diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index ac6042b2..05bd1e4c 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -35,6 +35,7 @@ type LdapService struct { userService *UserService groupService *UserGroupService fileStorage storage.FileStorage + clientFactory func() (ldapClient, error) } type savePicture struct { @@ -43,8 +44,33 @@ type savePicture struct { picture string } +type ldapDesiredUser struct { + ldapID string + input dto.UserCreateDto + picture string +} + +type ldapDesiredGroup struct { + ldapID string + input dto.UserGroupCreateDto + memberUsernames []string +} + +type ldapDesiredState struct { + users []ldapDesiredUser + userIDs map[string]struct{} + groups []ldapDesiredGroup + groupIDs map[string]struct{} +} + +type ldapClient interface { + Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) + Bind(username, password string) error + Close() error +} + func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService { - return &LdapService{ + service := &LdapService{ db: db, httpClient: httpClient, appConfigService: appConfigService, @@ -52,9 +78,12 @@ func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppC groupService: groupService, fileStorage: fileStorage, } + + service.clientFactory = service.createClient + return service } -func (s *LdapService) createClient() (*ldap.Conn, error) { +func (s *LdapService) createClient() (ldapClient, error) { dbConfig := s.appConfigService.GetDbConfig() if !dbConfig.LdapEnabled.IsTrue() { @@ -79,24 +108,33 @@ func (s *LdapService) createClient() (*ldap.Conn, error) { func (s *LdapService) SyncAll(ctx context.Context) error { // Setup LDAP connection - client, err := s.createClient() + client, err := s.clientFactory() if err != nil { return fmt.Errorf("failed to create LDAP client: %w", err) } defer client.Close() - // Start a transaction - tx := s.db.Begin() - defer func() { - tx.Rollback() - }() + // First, we fetch all users and group from LDAP, which is our "desired state" + desiredState, err := s.fetchDesiredState(ctx, client) + if err != nil { + return fmt.Errorf("failed to fetch LDAP state: %w", err) + } - savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client) + // Start a transaction + tx := s.db.WithContext(ctx).Begin() + if tx.Error != nil { + return fmt.Errorf("failed to begin database transaction: %w", tx.Error) + } + defer tx.Rollback() + + // Reconcile users + savePictures, deleteFiles, err := s.reconcileUsers(ctx, tx, desiredState.users, desiredState.userIDs) if err != nil { return fmt.Errorf("failed to sync users: %w", err) } - err = s.SyncGroups(ctx, tx, client) + // Reconcile groups + err = s.reconcileGroups(ctx, tx, desiredState.groups, desiredState.groupIDs) if err != nil { return fmt.Errorf("failed to sync groups: %w", err) } @@ -129,10 +167,31 @@ func (s *LdapService) SyncAll(ctx context.Context) error { return nil } -//nolint:gocognit -func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error { +func (s *LdapService) fetchDesiredState(ctx context.Context, client ldapClient) (ldapDesiredState, error) { + // Fetch users first so we can use their DNs when resolving group members + users, userIDs, usernamesByDN, err := s.fetchUsersFromLDAP(ctx, client) + if err != nil { + return ldapDesiredState{}, err + } + + // Then fetch groups to complete the desired LDAP state snapshot + groups, groupIDs, err := s.fetchGroupsFromLDAP(ctx, client, usernamesByDN) + if err != nil { + return ldapDesiredState{}, err + } + + return ldapDesiredState{ + users: users, + userIDs: userIDs, + groups: groups, + groupIDs: groupIDs, + }, nil +} + +func (s *LdapService) fetchGroupsFromLDAP(ctx context.Context, client ldapClient, usernamesByDN map[string]string) (desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}, err error) { dbConfig := s.appConfigService.GetDbConfig() + // Query LDAP for all groups we want to manage searchAttrs := []string{ dbConfig.LdapAttributeGroupName.Value, dbConfig.LdapAttributeGroupUniqueIdentifier.Value, @@ -149,90 +208,42 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. ) result, err := client.Search(searchReq) if err != nil { - return fmt.Errorf("failed to query LDAP: %w", err) + return nil, nil, fmt.Errorf("failed to query LDAP groups: %w", err) } - // Create a mapping for groups that exist - ldapGroupIDs := make(map[string]struct{}, len(result.Entries)) + // Build the in-memory desired state for groups + ldapGroupIDs = make(map[string]struct{}, len(result.Entries)) + desiredGroups = make([]ldapDesiredGroup, 0, len(result.Entries)) for _, value := range result.Entries { - ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)) + ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)) // Skip groups without a valid LDAP ID - if ldapId == "" { + if ldapID == "" { slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value)) continue } - ldapGroupIDs[ldapId] = struct{}{} - - // Try to find the group in the database - var databaseGroup model.UserGroup - err = tx. - WithContext(ctx). - Where("ldap_id = ?", ldapId). - First(&databaseGroup). - Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - // This could error with ErrRecordNotFound and we want to ignore that here - return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err) - } + ldapGroupIDs[ldapID] = struct{}{} // Get group members and add to the correct Group groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value) - membersUserId := make([]string, 0, len(groupMembers)) + memberUsernames := make([]string, 0, len(groupMembers)) for _, member := range groupMembers { - username := getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member) - - // If username extraction fails, try to query LDAP directly for the user + username := s.resolveGroupMemberUsername(ctx, client, member, usernamesByDN) if username == "" { - // Query LDAP to get the user by their DN - userSearchReq := ldap.NewSearchRequest( - member, - ldap.ScopeBaseObject, - 0, 0, 0, false, - "(objectClass=*)", - []string{dbConfig.LdapAttributeUserUsername.Value, dbConfig.LdapAttributeUserUniqueIdentifier.Value}, - []ldap.Control{}, - ) - - userResult, err := client.Search(userSearchReq) - if err != nil || len(userResult.Entries) == 0 { - slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err)) - continue - } - - username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value) - if username == "" { - slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member)) - continue - } - } - - username = norm.NFC.String(username) - - var databaseUser model.User - err = tx. - WithContext(ctx). - Where("username = ? AND ldap_id IS NOT NULL", username). - First(&databaseUser). - Error - if errors.Is(err, gorm.ErrRecordNotFound) { - // The user collides with a non-LDAP user, so we skip it continue - } else if err != nil { - return fmt.Errorf("failed to query for existing user '%s': %w", username, err) } - membersUserId = append(membersUserId, databaseUser.ID) + memberUsernames = append(memberUsernames, username) } syncGroup := dto.UserGroupCreateDto{ Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value), FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value), - LdapID: ldapId, + LdapID: ldapID, } - dto.Normalize(syncGroup) + dto.Normalize(&syncGroup) err = syncGroup.Validate() if err != nil { @@ -240,64 +251,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. continue } - if databaseGroup.ID == "" { - newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx) - if err != nil { - return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err) - } - - _, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx) - if err != nil { - return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err) - } - } else { - _, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx) - if err != nil { - return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err) - } - - _, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx) - if err != nil { - return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err) - } - } + desiredGroups = append(desiredGroups, ldapDesiredGroup{ + ldapID: ldapID, + input: syncGroup, + memberUsernames: memberUsernames, + }) } - // Get all LDAP groups from the database - var ldapGroupsInDb []model.UserGroup - err = tx. - WithContext(ctx). - Find(&ldapGroupsInDb, "ldap_id IS NOT NULL"). - Select("ldap_id"). - Error - if err != nil { - return fmt.Errorf("failed to fetch groups from database: %w", err) - } - - // Delete groups that no longer exist in LDAP - for _, group := range ldapGroupsInDb { - if _, exists := ldapGroupIDs[*group.LdapID]; exists { - continue - } - - err = tx. - WithContext(ctx). - Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID). - Error - if err != nil { - return fmt.Errorf("failed to delete group '%s': %w", group.Name, err) - } - - slog.Info("Deleted group", slog.String("group", group.Name)) - } - - return nil + return desiredGroups, ldapGroupIDs, nil } -//nolint:gocognit -func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) { +func (s *LdapService) fetchUsersFromLDAP(ctx context.Context, client ldapClient) (desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}, usernamesByDN map[string]string, err error) { dbConfig := s.appConfigService.GetDbConfig() + // Query LDAP for all users we want to manage searchAttrs := []string{ "memberOf", "sn", @@ -323,50 +290,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C result, err := client.Search(searchReq) if err != nil { - return nil, nil, fmt.Errorf("failed to query LDAP: %w", err) + return nil, nil, nil, fmt.Errorf("failed to query LDAP users: %w", err) } - // Create a mapping for users that exist - ldapUserIDs := make(map[string]struct{}, len(result.Entries)) - savePictures = make([]savePicture, 0, len(result.Entries)) + // Build the in-memory desired state for users and a DN lookup for group membership resolution + ldapUserIDs = make(map[string]struct{}, len(result.Entries)) + usernamesByDN = make(map[string]string, len(result.Entries)) + desiredUsers = make([]ldapDesiredUser, 0, len(result.Entries)) for _, value := range result.Entries { - ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)) + username := norm.NFC.String(value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)) + if normalizedDN := normalizeLDAPDN(value.DN); normalizedDN != "" && username != "" { + usernamesByDN[normalizedDN] = username + } + + ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)) // Skip users without a valid LDAP ID - if ldapId == "" { + if ldapID == "" { slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value)) continue } - ldapUserIDs[ldapId] = struct{}{} - - // Get the user from the database - var databaseUser model.User - err = tx. - WithContext(ctx). - Where("ldap_id = ?", ldapId). - First(&databaseUser). - Error - - // If a user is found (even if disabled), enable them since they're now back in LDAP - if databaseUser.ID != "" && databaseUser.Disabled { - err = tx. - WithContext(ctx). - Model(&model.User{}). - Where("id = ?", databaseUser.ID). - Update("disabled", false). - Error - - if err != nil { - return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err) - } - } - - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - // This could error with ErrRecordNotFound and we want to ignore that here - return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err) - } + ldapUserIDs[ldapID] = struct{}{} // Check if user is admin by checking if they are in the admin group isAdmin := false @@ -385,14 +331,14 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value), DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value), IsAdmin: isAdmin, - LdapID: ldapId, + LdapID: ldapID, } if newUser.DisplayName == "" { newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName) } - dto.Normalize(newUser) + dto.Normalize(&newUser) err = newUser.Validate() if err != nil { @@ -400,53 +346,201 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C continue } - userID := databaseUser.ID - if databaseUser.ID == "" { - createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx) - if errors.Is(err, &common.AlreadyInUseError{}) { - slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err)) + desiredUsers = append(desiredUsers, ldapDesiredUser{ + ldapID: ldapID, + input: newUser, + picture: value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value), + }) + } + + return desiredUsers, ldapUserIDs, usernamesByDN, nil +} + +func (s *LdapService) resolveGroupMemberUsername(ctx context.Context, client ldapClient, member string, usernamesByDN map[string]string) string { + dbConfig := s.appConfigService.GetDbConfig() + + // First try the DN cache we built while loading users + username, exists := usernamesByDN[normalizeLDAPDN(member)] + if exists && username != "" { + return username + } + + // Then try to extract the username directly from the DN + username = getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member) + if username != "" { + return norm.NFC.String(username) + } + + // As a fallback, query LDAP for the referenced entry + userSearchReq := ldap.NewSearchRequest( + member, + ldap.ScopeBaseObject, + 0, 0, 0, false, + "(objectClass=*)", + []string{dbConfig.LdapAttributeUserUsername.Value}, + []ldap.Control{}, + ) + + userResult, err := client.Search(userSearchReq) + if err != nil || len(userResult.Entries) == 0 { + slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err)) + return "" + } + + username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value) + if username == "" { + slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member)) + return "" + } + + return norm.NFC.String(username) +} + +func (s *LdapService) reconcileGroups(ctx context.Context, tx *gorm.DB, desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}) error { + // Load the current LDAP-managed state from the database + ldapGroupsInDB, ldapGroupsByID, err := s.loadLDAPGroupsInDB(ctx, tx) + if err != nil { + return fmt.Errorf("failed to fetch groups from database: %w", err) + } + + _, _, ldapUsersByUsername, err := s.loadLDAPUsersInDB(ctx, tx) + if err != nil { + return fmt.Errorf("failed to fetch users from database: %w", err) + } + + // Apply creates and updates to match the desired LDAP group state + for _, desiredGroup := range desiredGroups { + memberUserIDs := make([]string, 0, len(desiredGroup.memberUsernames)) + for _, username := range desiredGroup.memberUsernames { + databaseUser, exists := ldapUsersByUsername[username] + if !exists { + // The user collides with a non-LDAP user or was skipped during user sync, so we ignore it continue - } else if err != nil { - return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err) - } - userID = createdUser.ID - } else { - _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx) - if errors.Is(err, &common.AlreadyInUseError{}) { - slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err)) - continue - } else if err != nil { - return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err) } + + memberUserIDs = append(memberUserIDs, databaseUser.ID) } - // Save profile picture - pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value) - if pictureString != "" { - // Storage operations must be executed outside of a transaction - savePictures = append(savePictures, savePicture{ - userID: databaseUser.ID, - username: userID, - picture: pictureString, - }) + databaseGroup := ldapGroupsByID[desiredGroup.ldapID] + if databaseGroup.ID == "" { + newGroup, err := s.groupService.createInternal(ctx, desiredGroup.input, tx) + if err != nil { + return fmt.Errorf("failed to create group '%s': %w", desiredGroup.input.Name, err) + } + ldapGroupsByID[desiredGroup.ldapID] = newGroup + + _, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, memberUserIDs, tx) + if err != nil { + return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err) + } + continue + } + + _, err = s.groupService.updateInternal(ctx, databaseGroup.ID, desiredGroup.input, true, tx) + if err != nil { + return fmt.Errorf("failed to update group '%s': %w", desiredGroup.input.Name, err) + } + + _, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, memberUserIDs, tx) + if err != nil { + return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err) } } - // Get all LDAP users from the database - var ldapUsersInDb []model.User - err = tx. - WithContext(ctx). - Find(&ldapUsersInDb, "ldap_id IS NOT NULL"). - Select("id, username, ldap_id, disabled"). - Error + // Delete groups that are no longer present in LDAP + for _, group := range ldapGroupsInDB { + if group.LdapID == nil { + continue + } + + if _, exists := ldapGroupIDs[*group.LdapID]; exists { + continue + } + + err = tx. + WithContext(ctx). + Delete(&model.UserGroup{}, "ldap_id = ?", *group.LdapID). + Error + if err != nil { + return fmt.Errorf("failed to delete group '%s': %w", group.Name, err) + } + + slog.Info("Deleted group", slog.String("group", group.Name)) + } + + return nil +} + +//nolint:gocognit +func (s *LdapService) reconcileUsers(ctx context.Context, tx *gorm.DB, desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}) (savePictures []savePicture, deleteFiles []string, err error) { + dbConfig := s.appConfigService.GetDbConfig() + + // Load the current LDAP-managed state from the database + ldapUsersInDB, ldapUsersByID, _, err := s.loadLDAPUsersInDB(ctx, tx) if err != nil { return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err) } - // Mark users as disabled or delete users that no longer exist in LDAP - deleteFiles = make([]string, 0, len(ldapUserIDs)) - for _, user := range ldapUsersInDb { - // Skip if the user ID exists in the fetched LDAP results + // Apply creates and updates to match the desired LDAP user state + savePictures = make([]savePicture, 0, len(desiredUsers)) + + for _, desiredUser := range desiredUsers { + databaseUser := ldapUsersByID[desiredUser.ldapID] + + // If a user is found (even if disabled), enable them since they're now back in LDAP. + if databaseUser.ID != "" && databaseUser.Disabled { + err = tx. + WithContext(ctx). + Model(&model.User{}). + Where("id = ?", databaseUser.ID). + Update("disabled", false). + Error + if err != nil { + return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err) + } + + databaseUser.Disabled = false + ldapUsersByID[desiredUser.ldapID] = databaseUser + } + + userID := databaseUser.ID + if databaseUser.ID == "" { + createdUser, err := s.userService.createUserInternal(ctx, desiredUser.input, true, tx) + if errors.Is(err, &common.AlreadyInUseError{}) { + slog.Warn("Skipping creating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err)) + continue + } else if err != nil { + return nil, nil, fmt.Errorf("error creating user '%s': %w", desiredUser.input.Username, err) + } + + userID = createdUser.ID + ldapUsersByID[desiredUser.ldapID] = createdUser + } else { + _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, desiredUser.input, false, true, tx) + if errors.Is(err, &common.AlreadyInUseError{}) { + slog.Warn("Skipping updating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err)) + continue + } else if err != nil { + return nil, nil, fmt.Errorf("error updating user '%s': %w", desiredUser.input.Username, err) + } + } + + if desiredUser.picture != "" { + savePictures = append(savePictures, savePicture{ + userID: userID, + username: desiredUser.input.Username, + picture: desiredUser.picture, + }) + } + } + + // Disable or delete users that are no longer present in LDAP + deleteFiles = make([]string, 0, len(ldapUsersInDB)) + for _, user := range ldapUsersInDB { + if user.LdapID == nil { + continue + } + if _, exists := ldapUserIDs[*user.LdapID]; exists { continue } @@ -458,29 +552,73 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C } slog.Info("Disabled user", slog.String("username", user.Username)) - } else { - err = s.userService.deleteUserInternal(ctx, tx, user.ID, true) - if err != nil { - target := &common.LdapUserUpdateError{} - if errors.As(err, &target) { - return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username) - } - return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err) - } - - slog.Info("Deleted user", slog.String("username", user.Username)) - - // Storage operations must be executed outside of a transaction - deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png")) + continue } + + err = s.userService.deleteUserInternal(ctx, tx, user.ID, true) + if err != nil { + target := &common.LdapUserUpdateError{} + if errors.As(err, &target) { + return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username) + } + return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err) + } + + slog.Info("Deleted user", slog.String("username", user.Username)) + deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png")) } return savePictures, deleteFiles, nil } +func (s *LdapService) loadLDAPUsersInDB(ctx context.Context, tx *gorm.DB) (users []model.User, byLdapID map[string]model.User, byUsername map[string]model.User, err error) { + // Load all LDAP-managed users and index them by LDAP ID and by username + err = tx. + WithContext(ctx). + Select("id, username, ldap_id, disabled"). + Where("ldap_id IS NOT NULL"). + Find(&users). + Error + if err != nil { + return nil, nil, nil, err + } + + byLdapID = make(map[string]model.User, len(users)) + byUsername = make(map[string]model.User, len(users)) + for _, user := range users { + byLdapID[*user.LdapID] = user + byUsername[user.Username] = user + } + + return users, byLdapID, byUsername, nil +} + +func (s *LdapService) loadLDAPGroupsInDB(ctx context.Context, tx *gorm.DB) ([]model.UserGroup, map[string]model.UserGroup, error) { + var groups []model.UserGroup + + // Load all LDAP-managed groups and index them by LDAP ID + err := tx. + WithContext(ctx). + Select("id, name, ldap_id"). + Where("ldap_id IS NOT NULL"). + Find(&groups). + Error + if err != nil { + return nil, nil, err + } + + groupsByID := make(map[string]model.UserGroup, len(groups)) + for _, group := range groups { + groupsByID[*group.LdapID] = group + } + + return groups, groupsByID, nil +} + func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error { var reader io.ReadSeeker + // Accept either a URL, a base64-encoded payload, or raw binary data _, err := url.ParseRequestURI(pictureString) if err == nil { ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second) @@ -522,6 +660,31 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin return nil } +// normalizeLDAPDN returns a canonical lowercase form of a DN for use as a map key. +// Different LDAP servers may format the same DN with varying attribute type casing (e.g. "CN=" vs "cn=") or extra whitespace (e.g. "dc=example, dc=com"). +// Without normalization, cache lookups in usernamesByDN would miss when a member attribute value uses a different format than the DN returned in the search entry +// +// ldap.ParseDN is used instead of simple lowercasing because it correctly handles multi-valued RDNs (joined with "+") and strips inter-component whitespace. +// If parsing fails for any reason, we fall back to a simple lowercase+trim. +func normalizeLDAPDN(dn string) string { + parsed, err := ldap.ParseDN(dn) + if err != nil { + return strings.ToLower(strings.TrimSpace(dn)) + } + + // Reconstruct the DN in a canonical form: lowercase type=lowercase value, with RDN components separated by "," and multi-value attributes by "+" + parts := make([]string, 0, len(parsed.RDNs)) + for _, rdn := range parsed.RDNs { + attrs := make([]string, 0, len(rdn.Attributes)) + for _, attr := range rdn.Attributes { + attrs = append(attrs, strings.ToLower(attr.Type)+"="+strings.ToLower(attr.Value)) + } + parts = append(parts, strings.Join(attrs, "+")) + } + + return strings.Join(parts, ",") +} + // getDNProperty returns the value of a property from a LDAP identifier // See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names func getDNProperty(property string, str string) string { diff --git a/backend/internal/service/ldap_service_test.go b/backend/internal/service/ldap_service_test.go index 1a049bfe..22553f8b 100644 --- a/backend/internal/service/ldap_service_test.go +++ b/backend/internal/service/ldap_service_test.go @@ -1,9 +1,286 @@ package service import ( + "net/http" "testing" + + "github.com/go-ldap/ldap/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/model" + "github.com/pocket-id/pocket-id/backend/internal/storage" + testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" ) +type fakeLDAPClient struct { + searchFn func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) +} + +func (c *fakeLDAPClient) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + if c.searchFn == nil { + return nil, nil + } + + return c.searchFn(searchRequest) +} + +func (c *fakeLDAPClient) Bind(_, _ string) error { + return nil +} + +func (c *fakeLDAPClient) Close() error { + return nil +} + +func TestLdapServiceSyncAllReconcilesUsersAndGroups(t *testing.T) { + service, db := newTestLdapService(t, newFakeLDAPClient( + ldapSearchResult( + ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{ + "entryUUID": {"u-alice"}, + "uid": {"alice"}, + "mail": {"alice@example.com"}, + "givenName": {"Alice"}, + "sn": {"Jones"}, + "displayName": {""}, + "memberOf": {"cn=admins,ou=groups,dc=example,dc=com"}, + }), + ldapEntry("uid=bob,ou=people,dc=example,dc=com", map[string][]string{ + "entryUUID": {"u-bob"}, + "uid": {"bob"}, + "mail": {"bob@example.com"}, + "givenName": {"Bob"}, + "sn": {"Brown"}, + "displayName": {""}, + }), + ), + ldapSearchResult( + ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{ + "entryUUID": {"g-team"}, + "cn": {"team"}, + "member": { + "UID=Alice, OU=People, DC=example, DC=com", + "uid=bob, ou=people, dc=example, dc=com", + }, + }), + ), + )) + + aliceLdapID := "u-alice" + missingLdapID := "u-missing" + teamLdapID := "g-team" + oldGroupLdapID := "g-old" + + require.NoError(t, db.Create(&model.User{ + Username: "alice-old", + Email: new("alice-old@example.com"), + EmailVerified: true, + FirstName: "Old", + LastName: "Name", + DisplayName: "Old Name", + LdapID: &aliceLdapID, + Disabled: true, + }).Error) + + require.NoError(t, db.Create(&model.User{ + Username: "missing", + Email: new("missing@example.com"), + EmailVerified: true, + FirstName: "Missing", + LastName: "User", + DisplayName: "Missing User", + LdapID: &missingLdapID, + }).Error) + + require.NoError(t, db.Create(&model.UserGroup{ + Name: "team-old", + FriendlyName: "team-old", + LdapID: &teamLdapID, + }).Error) + + require.NoError(t, db.Create(&model.UserGroup{ + Name: "old-group", + FriendlyName: "old-group", + LdapID: &oldGroupLdapID, + }).Error) + + require.NoError(t, service.SyncAll(t.Context())) + + var alice model.User + require.NoError(t, db.First(&alice, "ldap_id = ?", aliceLdapID).Error) + assert.Equal(t, "alice", alice.Username) + assert.Equal(t, new("alice@example.com"), alice.Email) + assert.Equal(t, "Alice", alice.FirstName) + assert.Equal(t, "Jones", alice.LastName) + assert.Equal(t, "Alice Jones", alice.DisplayName) + assert.True(t, alice.IsAdmin) + assert.False(t, alice.Disabled) + + var bob model.User + require.NoError(t, db.First(&bob, "ldap_id = ?", "u-bob").Error) + assert.Equal(t, "bob", bob.Username) + assert.Equal(t, "Bob Brown", bob.DisplayName) + + var missing model.User + require.NoError(t, db.First(&missing, "ldap_id = ?", missingLdapID).Error) + assert.True(t, missing.Disabled) + + var oldGroupCount int64 + require.NoError(t, db.Model(&model.UserGroup{}).Where("ldap_id = ?", oldGroupLdapID).Count(&oldGroupCount).Error) + assert.Zero(t, oldGroupCount) + + var team model.UserGroup + require.NoError(t, db.Preload("Users").First(&team, "ldap_id = ?", teamLdapID).Error) + assert.Equal(t, "team", team.Name) + assert.Equal(t, "team", team.FriendlyName) + assert.ElementsMatch(t, []string{"alice", "bob"}, usernames(team.Users)) +} + +func TestLdapServiceSyncAllHandlesDuplicateLDAPIDsInSingleRun(t *testing.T) { + service, db := newTestLdapService(t, newFakeLDAPClient( + ldapSearchResult( + ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{ + "entryUUID": {"u-dup"}, + "uid": {"alice"}, + "mail": {"alice@example.com"}, + "givenName": {"Alice"}, + "sn": {"Doe"}, + "displayName": {"Alice Doe"}, + }), + ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{ + "entryUUID": {"u-dup"}, + "uid": {"alice"}, + "mail": {"alice@example.com"}, + "givenName": {"Alicia"}, + "sn": {"Doe"}, + "displayName": {"Alicia Doe"}, + }), + ), + ldapSearchResult( + ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{ + "entryUUID": {"g-dup"}, + "cn": {"team"}, + "member": {"uid=alice,ou=people,dc=example,dc=com"}, + }), + ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{ + "entryUUID": {"g-dup"}, + "cn": {"team-renamed"}, + "member": {"uid=alice,ou=people,dc=example,dc=com"}, + }), + ), + )) + + require.NoError(t, service.SyncAll(t.Context())) + + var users []model.User + require.NoError(t, db.Find(&users, "ldap_id = ?", "u-dup").Error) + require.Len(t, users, 1) + assert.Equal(t, "alice", users[0].Username) + assert.Equal(t, "Alicia", users[0].FirstName) + assert.Equal(t, "Alicia Doe", users[0].DisplayName) + + var groups []model.UserGroup + require.NoError(t, db.Preload("Users").Find(&groups, "ldap_id = ?", "g-dup").Error) + require.Len(t, groups, 1) + assert.Equal(t, "team-renamed", groups[0].Name) + assert.Equal(t, "team-renamed", groups[0].FriendlyName) + assert.ElementsMatch(t, []string{"alice"}, usernames(groups[0].Users)) +} + +func newTestLdapService(t *testing.T, client ldapClient) (*LdapService, *gorm.DB) { + t.Helper() + + db := testutils.NewDatabaseForTest(t) + + fileStorage, err := storage.NewDatabaseStorage(db) + require.NoError(t, err) + + appConfig := NewTestAppConfigService(&model.AppConfig{ + RequireUserEmail: model.AppConfigVariable{Value: "false"}, + LdapEnabled: model.AppConfigVariable{Value: "true"}, + LdapBase: model.AppConfigVariable{Value: "dc=example,dc=com"}, + LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"}, + LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"}, + LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"}, + LdapAttributeUserUsername: model.AppConfigVariable{Value: "uid"}, + LdapAttributeUserEmail: model.AppConfigVariable{Value: "mail"}, + LdapAttributeUserFirstName: model.AppConfigVariable{Value: "givenName"}, + LdapAttributeUserLastName: model.AppConfigVariable{Value: "sn"}, + LdapAttributeUserDisplayName: model.AppConfigVariable{Value: "displayName"}, + LdapAttributeUserProfilePicture: model.AppConfigVariable{Value: "jpegPhoto"}, + LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"}, + LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"}, + LdapAttributeGroupName: model.AppConfigVariable{Value: "cn"}, + LdapAdminGroupName: model.AppConfigVariable{Value: "admins"}, + LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"}, + }) + + groupService := NewUserGroupService(db, appConfig, nil) + userService := NewUserService( + db, + nil, + nil, + nil, + appConfig, + NewCustomClaimService(db), + NewAppImagesService(map[string]string{}, fileStorage), + nil, + fileStorage, + ) + + service := NewLdapService(db, &http.Client{}, appConfig, userService, groupService, fileStorage) + service.clientFactory = func() (ldapClient, error) { + return client, nil + } + + return service, db +} + +func newFakeLDAPClient(userResult, groupResult *ldap.SearchResult) ldapClient { + return &fakeLDAPClient{ + searchFn: func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + switch searchRequest.Filter { + case "(objectClass=person)": + return userResult, nil + case "(objectClass=groupOfNames)": + return groupResult, nil + default: + return &ldap.SearchResult{}, nil + } + }, + } +} + +func ldapSearchResult(entries ...*ldap.Entry) *ldap.SearchResult { + return &ldap.SearchResult{Entries: entries} +} + +func ldapEntry(dn string, attrs map[string][]string) *ldap.Entry { + entry := &ldap.Entry{ + DN: dn, + Attributes: make([]*ldap.EntryAttribute, 0, len(attrs)), + } + + for name, values := range attrs { + entry.Attributes = append(entry.Attributes, &ldap.EntryAttribute{ + Name: name, + Values: values, + }) + } + + return entry +} + +func usernames(users []model.User) []string { + result := make([]string, 0, len(users)) + for _, user := range users { + result = append(result, user.Username) + } + + return result +} + func TestGetDNProperty(t *testing.T) { tests := []struct { name string @@ -64,10 +341,58 @@ func TestGetDNProperty(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := getDNProperty(tt.property, tt.dn) - if result != tt.expectedResult { - t.Errorf("getDNProperty(%q, %q) = %q, want %q", - tt.property, tt.dn, result, tt.expectedResult) - } + assert.Equalf(t, tt.expectedResult, result, "getDNProperty(%q, %q)", tt.property, tt.dn) + }) + } +} + +func TestNormalizeLDAPDN(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "already normalized", + input: "cn=alice,dc=example,dc=com", + expected: "cn=alice,dc=example,dc=com", + }, + { + name: "uppercase attribute types", + input: "CN=Alice,DC=example,DC=com", + expected: "cn=alice,dc=example,dc=com", + }, + { + name: "spaces after commas", + input: "cn=alice, dc=example, dc=com", + expected: "cn=alice,dc=example,dc=com", + }, + { + name: "uppercase types and spaces", + input: "CN=Alice, DC=example, DC=com", + expected: "cn=alice,dc=example,dc=com", + }, + { + name: "multi-valued RDN", + input: "cn=alice+uid=a123,dc=example,dc=com", + expected: "cn=alice+uid=a123,dc=example,dc=com", + }, + { + name: "invalid DN falls back to lowercase+trim", + input: " NOT A VALID DN ", + expected: "not a valid dn", + }, + { + name: "empty string", + input: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeLDAPDN(tt.input) + assert.Equalf(t, tt.expected, result, "normalizeLDAPDN(%q)", tt.input) }) } } @@ -98,9 +423,7 @@ func TestConvertLdapIdToString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := convertLdapIdToString(tt.input) - if got != tt.expected { - t.Errorf("Expected %q, got %q", tt.expected, got) - } + assert.Equal(t, tt.expected, got) }) } } diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index e55c9085..0e37a5ac 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -96,7 +96,10 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error { return err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return nil } @@ -126,7 +129,10 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro return model.UserGroup{}, err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return group, nil } @@ -175,7 +181,10 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input return model.UserGroup{}, err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return group, nil } @@ -238,7 +247,10 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u return model.UserGroup{}, err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return group, nil } @@ -315,6 +327,9 @@ func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id strin return model.UserGroup{}, err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return group, nil } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index f0ad2369..f3ec4fbc 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -225,7 +225,10 @@ func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userI return fmt.Errorf("failed to delete user: %w", err) } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return nil } @@ -310,7 +313,10 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea } } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return user, nil } @@ -456,7 +462,10 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd return user, err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return user, nil } @@ -515,7 +524,10 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup return model.User{}, err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return user, nil } @@ -576,7 +588,10 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user return err } - s.scimService.ScheduleSync() + if s.scimService != nil { + s.scimService.ScheduleSync() + } + return nil }