From 796bc7ed3453839b1dc8d846b71fe9fac9a2d646 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Sat, 12 Apr 2025 15:38:19 -0700 Subject: [PATCH] fix: improve LDAP error handling (#425) Co-authored-by: Kyle Mendell --- backend/internal/common/errors.go | 7 + .../controller/user_group_controller.go | 2 +- backend/internal/service/ldap_service.go | 206 +++++++++++------- backend/internal/service/ldap_service_test.go | 73 +++++++ .../internal/service/user_group_service.go | 15 +- backend/internal/service/user_service.go | 39 +++- 6 files changed, 239 insertions(+), 103 deletions(-) create mode 100644 backend/internal/service/ldap_service_test.go diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 5307431c..8dd28b1a 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "net/http" ) @@ -21,6 +22,12 @@ func (e *AlreadyInUseError) Error() string { } func (e *AlreadyInUseError) HttpStatusCode() int { return 400 } +func (e *AlreadyInUseError) Is(target error) bool { + // Ignore the field property when checking if an error is of the type AlreadyInUseError + x := &AlreadyInUseError{} + return errors.As(target, &x) +} + type SetupAlreadyCompletedError struct{} func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" } diff --git a/backend/internal/controller/user_group_controller.go b/backend/internal/controller/user_group_controller.go index 8f47b40b..6176c3d1 100644 --- a/backend/internal/controller/user_group_controller.go +++ b/backend/internal/controller/user_group_controller.go @@ -160,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) { return } - group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input, false) + group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index 6d6f412c..ce878cf0 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -15,6 +15,7 @@ import ( "time" "github.com/go-ldap/ldap/v3" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" "gorm.io/gorm" @@ -28,7 +29,12 @@ type LdapService struct { } func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService { - return &LdapService{db: db, appConfigService: appConfigService, userService: userService, groupService: groupService} + return &LdapService{ + db: db, + appConfigService: appConfigService, + userService: userService, + groupService: groupService, + } } func (s *LdapService) createClient() (*ldap.Conn, error) { @@ -39,19 +45,15 @@ func (s *LdapService) createClient() (*ldap.Conn, error) { } // Setup LDAP connection - ldapURL := dbConfig.LdapUrl.Value - skipTLSVerify := dbConfig.LdapSkipCertVerify.IsTrue() - client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: skipTLSVerify, //nolint:gosec + client, err := ldap.DialURL(dbConfig.LdapUrl.Value, ldap.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: dbConfig.LdapSkipCertVerify.IsTrue(), //nolint:gosec })) if err != nil { return nil, fmt.Errorf("failed to connect to LDAP: %w", err) } // Bind as service account - bindDn := dbConfig.LdapBindDn.Value - bindPassword := dbConfig.LdapBindPassword.Value - err = client.Bind(bindDn, bindPassword) + err = client.Bind(dbConfig.LdapBindDn.Value, dbConfig.LdapBindPassword.Value) if err != nil { return nil, fmt.Errorf("failed to bind to LDAP: %w", err) } @@ -65,12 +67,19 @@ func (s *LdapService) SyncAll(ctx context.Context) error { tx.Rollback() }() - err := s.SyncUsers(ctx, tx) + // Setup LDAP connection + client, err := s.createClient() + if err != nil { + return fmt.Errorf("failed to create LDAP client: %w", err) + } + defer client.Close() + + err = s.SyncUsers(ctx, tx, client) if err != nil { return fmt.Errorf("failed to sync users: %w", err) } - err = s.SyncGroups(ctx, tx) + err = s.SyncGroups(ctx, tx, client) if err != nil { return fmt.Errorf("failed to sync groups: %w", err) } @@ -85,16 +94,9 @@ func (s *LdapService) SyncAll(ctx context.Context) error { } //nolint:gocognit -func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { +func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error { dbConfig := s.appConfigService.GetDbConfig() - // Setup LDAP connection - client, err := s.createClient() - if err != nil { - return fmt.Errorf("failed to create LDAP client: %w", err) - } - defer client.Close() - searchAttrs := []string{ dbConfig.LdapAttributeGroupName.Value, dbConfig.LdapAttributeGroupUniqueIdentifier.Value, @@ -115,11 +117,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { } // Create a mapping for groups that exist - ldapGroupIDs := make(map[string]bool) + ldapGroupIDs := make(map[string]struct{}, len(result.Entries)) for _, value := range result.Entries { - var membersUserId []string - ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value) // Skip groups without a valid LDAP ID @@ -128,29 +128,40 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { continue } - ldapGroupIDs[ldapId] = true + ldapGroupIDs[ldapId] = struct{}{} // Try to find the group in the database var databaseGroup model.UserGroup - tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup) + err = tx. + WithContext(ctx). + Where("ldap_id = ?", ldapId). + First(&databaseGroup). + Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + // This could error with ErrRecordNotFound and we want to ignore that here + return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err) + } // Get group members and add to the correct Group groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value) + membersUserId := make([]string, 0, len(groupMembers)) for _, member := range groupMembers { - // Normal output of this would be CN=username,ou=people,dc=example,dc=com - // Splitting at the "=" and "," then just grabbing the username for that string - singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0] + ldapId := getDNProperty("uid", member) + if ldapId == "" { + continue + } var databaseUser model.User - err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - // The user collides with a non-LDAP user, so we skip it - continue - } else { - return err - } - + err = tx. + WithContext(ctx). + Where("username = ? AND ldap_id IS NOT NULL", ldapId). + First(&databaseUser). + Error + if errors.Is(err, gorm.ErrRecordNotFound) { + // The user collides with a non-LDAP user, so we skip it + continue + } else if err != nil { + return fmt.Errorf("failed to query for existing user '%s': %w", ldapId, err) } membersUserId = append(membersUserId, databaseUser.ID) @@ -165,26 +176,22 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { if databaseGroup.ID == "" { newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx) if err != nil { - log.Printf("Error syncing group %s: %v", syncGroup.Name, err) - continue + return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err) } _, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx) if err != nil { - log.Printf("Error syncing group %s: %v", syncGroup.Name, err) - continue + return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err) } } else { _, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx) if err != nil { - log.Printf("Error syncing group %s: %v", syncGroup.Name, err) - continue + return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err) } _, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx) if err != nil { - log.Printf("Error syncing group %s: %v", syncGroup.Name, err) - continue + return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err) } } } @@ -197,38 +204,33 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { Select("ldap_id"). Error if err != nil { - log.Printf("Failed to fetch groups from database: %v", err) + return fmt.Errorf("failed to fetch groups from database: %w", err) } // Delete groups that no longer exist in LDAP for _, group := range ldapGroupsInDb { - if _, exists := ldapGroupIDs[*group.LdapID]; !exists { - err = tx. - WithContext(ctx). - Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID). - Error - if err != nil { - log.Printf("Failed to delete group %s with: %v", group.Name, err) - } else { - log.Printf("Deleted group %s", group.Name) - } + if _, exists := ldapGroupIDs[*group.LdapID]; exists { + continue } + + err = tx. + WithContext(ctx). + Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID). + Error + if err != nil { + return fmt.Errorf("failed to delete group '%s': %w", group.Name, err) + } + + log.Printf("Deleted group '%s'", group.Name) } return nil } //nolint:gocognit -func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { +func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error { dbConfig := s.appConfigService.GetDbConfig() - // Setup LDAP connection - client, err := s.createClient() - if err != nil { - return fmt.Errorf("failed to create LDAP client: %w", err) - } - defer client.Close() - searchAttrs := []string{ "memberOf", "sn", @@ -253,11 +255,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { result, err := client.Search(searchReq) if err != nil { - fmt.Println(fmt.Errorf("failed to query LDAP: %w", err)) + return fmt.Errorf("failed to query LDAP: %w", err) } // Create a mapping for users that exist - ldapUserIDs := make(map[string]bool) + ldapUserIDs := make(map[string]struct{}, len(result.Entries)) for _, value := range result.Entries { ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value) @@ -268,17 +270,26 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { continue } - ldapUserIDs[ldapId] = true + ldapUserIDs[ldapId] = struct{}{} // Get the user from the database var databaseUser model.User - tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser) + err = tx. + WithContext(ctx). + Where("ldap_id = ?", ldapId). + First(&databaseUser). + Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + // This could error with ErrRecordNotFound and we want to ignore that here + return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err) + } // Check if user is admin by checking if they are in the admin group isAdmin := false for _, group := range value.GetAttributeValues("memberOf") { - if strings.Contains(group, dbConfig.LdapAttributeAdminGroup.Value) { + if getDNProperty("cn", group) == dbConfig.LdapAttributeAdminGroup.Value { isAdmin = true + break } } @@ -292,20 +303,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { } if databaseUser.ID == "" { - _, err = s.userService.createUserInternal(ctx, newUser, tx) - if err != nil { - log.Printf("Error syncing user %s: %v", newUser.Username, err) + _, err = s.userService.createUserInternal(ctx, newUser, true, tx) + if errors.Is(err, &common.AlreadyInUseError{}) { + log.Printf("Skipping creating LDAP user '%s': %v", newUser.Username, err) + continue + } else if err != nil { + return fmt.Errorf("error creating user '%s': %w", newUser.Username, err) } } else { _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx) - if err != nil { - log.Printf("Error syncing user %s: %v", newUser.Username, err) + if errors.Is(err, &common.AlreadyInUseError{}) { + log.Printf("Skipping updating LDAP user '%s': %v", newUser.Username, err) + continue + } else if err != nil { + return fmt.Errorf("error updating user '%s': %w", newUser.Username, err) } } // Save profile picture - if pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value); pictureString != "" { - if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil { + pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value) + if pictureString != "" { + err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString) + if err != nil { + // This is not a fatal error log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err) } } @@ -319,18 +339,21 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { Select("ldap_id"). Error if err != nil { - log.Printf("Failed to fetch users from database: %v", err) + return fmt.Errorf("failed to fetch users from database: %w", err) } // Delete users that no longer exist in LDAP for _, user := range ldapUsersInDb { - if _, exists := ldapUserIDs[*user.LdapID]; !exists { - if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil { - log.Printf("Failed to delete user %s with: %v", user.Username, err) - } else { - log.Printf("Deleted user %s", user.Username) - } + if _, exists := ldapUserIDs[*user.LdapID]; exists { + continue } + + err = s.userService.deleteUserInternal(ctx, user.ID, true, tx) + if err != nil { + return fmt.Errorf("failed to delete user '%s': %w", user.Username, err) + } + + log.Printf("Deleted user '%s'", user.Username) } return nil @@ -367,9 +390,28 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin } // Update the profile picture - if err := s.userService.UpdateProfilePicture(userId, reader); err != nil { + err = s.userService.UpdateProfilePicture(userId, reader) + if err != nil { return fmt.Errorf("failed to update profile picture: %w", err) } return nil } + +// getDNProperty returns the value of a property from a LDAP identifier +// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names +func getDNProperty(property string, str string) string { + // Example format is "CN=username,ou=people,dc=example,dc=com" + // First we split at the comma + property = strings.ToLower(property) + l := len(property) + 1 + for _, v := range strings.Split(str, ",") { + v = strings.TrimSpace(v) + if len(v) > l && strings.ToLower(v)[0:l] == property+"=" { + return v[l:] + } + } + + // CN not found, return an empty string + return "" +} diff --git a/backend/internal/service/ldap_service_test.go b/backend/internal/service/ldap_service_test.go new file mode 100644 index 00000000..3a3e8c0b --- /dev/null +++ b/backend/internal/service/ldap_service_test.go @@ -0,0 +1,73 @@ +package service + +import ( + "testing" +) + +func TestGetDNProperty(t *testing.T) { + tests := []struct { + name string + property string + dn string + expectedResult string + }{ + { + name: "simple case", + property: "cn", + dn: "cn=username,ou=people,dc=example,dc=com", + expectedResult: "username", + }, + { + name: "property not found", + property: "uid", + dn: "cn=username,ou=people,dc=example,dc=com", + expectedResult: "", + }, + { + name: "mixed case property", + property: "CN", + dn: "cn=username,ou=people,dc=example,dc=com", + expectedResult: "username", + }, + { + name: "mixed case DN", + property: "cn", + dn: "CN=username,OU=people,DC=example,DC=com", + expectedResult: "username", + }, + { + name: "spaces in DN", + property: "cn", + dn: "cn=username, ou=people, dc=example, dc=com", + expectedResult: "username", + }, + { + name: "value with special characters", + property: "cn", + dn: "cn=user.name+123,ou=people,dc=example,dc=com", + expectedResult: "user.name+123", + }, + { + name: "empty DN", + property: "cn", + dn: "", + expectedResult: "", + }, + { + name: "empty property", + property: "", + dn: "cn=username,ou=people,dc=example,dc=com", + expectedResult: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getDNProperty(tt.property, tt.dn) + if result != tt.expectedResult { + t.Errorf("getDNProperty(%q, %q) = %q, want %q", + tt.property, tt.dn, result, tt.expectedResult) + } + }) + } +} diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index efa3ac9f..a6aedbc5 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -122,13 +122,13 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro return group, nil } -func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) { +func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto) (group model.UserGroup, err error) { tx := s.db.Begin() defer func() { tx.Rollback() }() - group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx) + group, err = s.updateInternal(ctx, id, input, false, tx) if err != nil { return model.UserGroup{}, err } @@ -141,14 +141,14 @@ func (s *UserGroupService) Update(ctx context.Context, id string, input dto.User return group, nil } -func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) { +func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, isLdapSync bool, tx *gorm.DB) (group model.UserGroup, err error) { group, err = s.getInternal(ctx, id, tx) if err != nil { return model.UserGroup{}, err } // Disallow updating the group if it is an LDAP group and LDAP is enabled - if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() { + if !isLdapSync && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() { return model.UserGroup{}, &common.LdapUserGroupUpdateError{} } @@ -160,10 +160,9 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input Preload("Users"). Save(&group). Error - if err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} - } + if errors.Is(err, gorm.ErrDuplicatedKey) { + return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} + } else if err != nil { return model.UserGroup{}, err } return group, nil diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 1724a260..d2764d1d 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -184,7 +184,7 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all First(&user). Error if err != nil { - return err + return fmt.Errorf("failed to load user to delete: %w", err) } // Disallow deleting the user if it is an LDAP user and LDAP is enabled @@ -199,7 +199,12 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all return err } - return tx.WithContext(ctx).Delete(&user).Error + err = tx.WithContext(ctx).Delete(&user).Error + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + + return nil } func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) { @@ -208,7 +213,7 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) ( tx.Rollback() }() - user, err := s.createUserInternal(ctx, input, tx) + user, err := s.createUserInternal(ctx, input, false, tx) if err != nil { return model.User{}, err } @@ -221,7 +226,7 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) ( return user, nil } -func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) { +func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) { user := model.User{ FirstName: input.FirstName, LastName: input.LastName, @@ -236,10 +241,15 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea err := tx.WithContext(ctx).Create(&user).Error if errors.Is(err, gorm.ErrDuplicatedKey) { - tx.Rollback() + // Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here + if !isLdapSync { + tx.Rollback() + // If we are here, the transaction is already aborted due to an error, so we pass s.db + err = s.checkDuplicatedFields(ctx, user, s.db) + } else { + err = s.checkDuplicatedFields(ctx, user, tx) + } - // If we are here, the transaction is already aborted due to an error, so we pass s.db - err = s.checkDuplicatedFields(ctx, user, s.db) return model.User{}, err } else if err != nil { return model.User{}, err @@ -266,7 +276,7 @@ func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser return user, nil } -func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) { +func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) { var user model.User err := tx. WithContext(ctx). @@ -278,7 +288,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd } // Disallow updating the user if it is an LDAP group and LDAP is enabled - if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() { + if !isLdapSync && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() { return model.User{}, &common.LdapUserUpdateError{} } @@ -296,10 +306,15 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd Save(&user). Error if errors.Is(err, gorm.ErrDuplicatedKey) { - tx.Rollback() + // Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here + if !isLdapSync { + tx.Rollback() + // If we are here, the transaction is already aborted due to an error, so we pass s.db + err = s.checkDuplicatedFields(ctx, user, s.db) + } else { + err = s.checkDuplicatedFields(ctx, user, tx) + } - // If we are here, the transaction is already aborted due to an error, so we pass s.db - err = s.checkDuplicatedFields(ctx, user, s.db) return user, err } else if err != nil { return user, err