1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 18:30:09 +00:00

refactor: separate querying LDAP and updating DB during sync (#1371)

This commit is contained in:
Alessandro (Ale) Segala
2026-03-08 12:03:58 -07:00
committed by GitHub
parent cad80e7d74
commit d71966f996
4 changed files with 743 additions and 227 deletions

View File

@@ -35,6 +35,7 @@ type LdapService struct {
userService *UserService
groupService *UserGroupService
fileStorage storage.FileStorage
clientFactory func() (ldapClient, error)
}
type savePicture struct {
@@ -43,8 +44,33 @@ type savePicture struct {
picture string
}
type ldapDesiredUser struct {
ldapID string
input dto.UserCreateDto
picture string
}
type ldapDesiredGroup struct {
ldapID string
input dto.UserGroupCreateDto
memberUsernames []string
}
type ldapDesiredState struct {
users []ldapDesiredUser
userIDs map[string]struct{}
groups []ldapDesiredGroup
groupIDs map[string]struct{}
}
type ldapClient interface {
Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
Bind(username, password string) error
Close() error
}
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
return &LdapService{
service := &LdapService{
db: db,
httpClient: httpClient,
appConfigService: appConfigService,
@@ -52,9 +78,12 @@ func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppC
groupService: groupService,
fileStorage: fileStorage,
}
service.clientFactory = service.createClient
return service
}
func (s *LdapService) createClient() (*ldap.Conn, error) {
func (s *LdapService) createClient() (ldapClient, error) {
dbConfig := s.appConfigService.GetDbConfig()
if !dbConfig.LdapEnabled.IsTrue() {
@@ -79,24 +108,33 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
func (s *LdapService) SyncAll(ctx context.Context) error {
// Setup LDAP connection
client, err := s.createClient()
client, err := s.clientFactory()
if err != nil {
return fmt.Errorf("failed to create LDAP client: %w", err)
}
defer client.Close()
// Start a transaction
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
// First, we fetch all users and group from LDAP, which is our "desired state"
desiredState, err := s.fetchDesiredState(ctx, client)
if err != nil {
return fmt.Errorf("failed to fetch LDAP state: %w", err)
}
savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
// Start a transaction
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return fmt.Errorf("failed to begin database transaction: %w", tx.Error)
}
defer tx.Rollback()
// Reconcile users
savePictures, deleteFiles, err := s.reconcileUsers(ctx, tx, desiredState.users, desiredState.userIDs)
if err != nil {
return fmt.Errorf("failed to sync users: %w", err)
}
err = s.SyncGroups(ctx, tx, client)
// Reconcile groups
err = s.reconcileGroups(ctx, tx, desiredState.groups, desiredState.groupIDs)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
@@ -129,10 +167,31 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
func (s *LdapService) fetchDesiredState(ctx context.Context, client ldapClient) (ldapDesiredState, error) {
// Fetch users first so we can use their DNs when resolving group members
users, userIDs, usernamesByDN, err := s.fetchUsersFromLDAP(ctx, client)
if err != nil {
return ldapDesiredState{}, err
}
// Then fetch groups to complete the desired LDAP state snapshot
groups, groupIDs, err := s.fetchGroupsFromLDAP(ctx, client, usernamesByDN)
if err != nil {
return ldapDesiredState{}, err
}
return ldapDesiredState{
users: users,
userIDs: userIDs,
groups: groups,
groupIDs: groupIDs,
}, nil
}
func (s *LdapService) fetchGroupsFromLDAP(ctx context.Context, client ldapClient, usernamesByDN map[string]string) (desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}, err error) {
dbConfig := s.appConfigService.GetDbConfig()
// Query LDAP for all groups we want to manage
searchAttrs := []string{
dbConfig.LdapAttributeGroupName.Value,
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
@@ -149,90 +208,42 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
)
result, err := client.Search(searchReq)
if err != nil {
return fmt.Errorf("failed to query LDAP: %w", err)
return nil, nil, fmt.Errorf("failed to query LDAP groups: %w", err)
}
// Create a mapping for groups that exist
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
// Build the in-memory desired state for groups
ldapGroupIDs = make(map[string]struct{}, len(result.Entries))
desiredGroups = make([]ldapDesiredGroup, 0, len(result.Entries))
for _, value := range result.Entries {
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
// Skip groups without a valid LDAP ID
if ldapId == "" {
if ldapID == "" {
slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
continue
}
ldapGroupIDs[ldapId] = struct{}{}
// Try to find the group in the database
var databaseGroup model.UserGroup
err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseGroup).
Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
}
ldapGroupIDs[ldapID] = struct{}{}
// Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
membersUserId := make([]string, 0, len(groupMembers))
memberUsernames := make([]string, 0, len(groupMembers))
for _, member := range groupMembers {
username := getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
// If username extraction fails, try to query LDAP directly for the user
username := s.resolveGroupMemberUsername(ctx, client, member, usernamesByDN)
if username == "" {
// Query LDAP to get the user by their DN
userSearchReq := ldap.NewSearchRequest(
member,
ldap.ScopeBaseObject,
0, 0, 0, false,
"(objectClass=*)",
[]string{dbConfig.LdapAttributeUserUsername.Value, dbConfig.LdapAttributeUserUniqueIdentifier.Value},
[]ldap.Control{},
)
userResult, err := client.Search(userSearchReq)
if err != nil || len(userResult.Entries) == 0 {
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
continue
}
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
if username == "" {
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
continue
}
}
username = norm.NFC.String(username)
var databaseUser model.User
err = tx.
WithContext(ctx).
Where("username = ? AND ldap_id IS NOT NULL", username).
First(&databaseUser).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it
continue
} else if err != nil {
return fmt.Errorf("failed to query for existing user '%s': %w", username, err)
}
membersUserId = append(membersUserId, databaseUser.ID)
memberUsernames = append(memberUsernames, username)
}
syncGroup := dto.UserGroupCreateDto{
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: ldapId,
LdapID: ldapID,
}
dto.Normalize(syncGroup)
dto.Normalize(&syncGroup)
err = syncGroup.Validate()
if err != nil {
@@ -240,64 +251,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
continue
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil {
return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
desiredGroups = append(desiredGroups, ldapDesiredGroup{
ldapID: ldapID,
input: syncGroup,
memberUsernames: memberUsernames,
})
}
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
} else {
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
if err != nil {
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
}
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
}
}
// Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup
err = tx.
WithContext(ctx).
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
return fmt.Errorf("failed to fetch groups from database: %w", err)
}
// Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
}
slog.Info("Deleted group", slog.String("group", group.Name))
}
return nil
return desiredGroups, ldapGroupIDs, nil
}
//nolint:gocognit
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
func (s *LdapService) fetchUsersFromLDAP(ctx context.Context, client ldapClient) (desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}, usernamesByDN map[string]string, err error) {
dbConfig := s.appConfigService.GetDbConfig()
// Query LDAP for all users we want to manage
searchAttrs := []string{
"memberOf",
"sn",
@@ -323,50 +290,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
result, err := client.Search(searchReq)
if err != nil {
return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
return nil, nil, nil, fmt.Errorf("failed to query LDAP users: %w", err)
}
// Create a mapping for users that exist
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
savePictures = make([]savePicture, 0, len(result.Entries))
// Build the in-memory desired state for users and a DN lookup for group membership resolution
ldapUserIDs = make(map[string]struct{}, len(result.Entries))
usernamesByDN = make(map[string]string, len(result.Entries))
desiredUsers = make([]ldapDesiredUser, 0, len(result.Entries))
for _, value := range result.Entries {
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
username := norm.NFC.String(value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value))
if normalizedDN := normalizeLDAPDN(value.DN); normalizedDN != "" && username != "" {
usernamesByDN[normalizedDN] = username
}
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
// Skip users without a valid LDAP ID
if ldapId == "" {
if ldapID == "" {
slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value))
continue
}
ldapUserIDs[ldapId] = struct{}{}
// Get the user from the database
var databaseUser model.User
err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseUser).
Error
// If a user is found (even if disabled), enable them since they're now back in LDAP
if databaseUser.ID != "" && databaseUser.Disabled {
err = tx.
WithContext(ctx).
Model(&model.User{}).
Where("id = ?", databaseUser.ID).
Update("disabled", false).
Error
if err != nil {
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
}
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
}
ldapUserIDs[ldapID] = struct{}{}
// Check if user is admin by checking if they are in the admin group
isAdmin := false
@@ -385,14 +331,14 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value),
IsAdmin: isAdmin,
LdapID: ldapId,
LdapID: ldapID,
}
if newUser.DisplayName == "" {
newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName)
}
dto.Normalize(newUser)
dto.Normalize(&newUser)
err = newUser.Validate()
if err != nil {
@@ -400,53 +346,201 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
continue
}
userID := databaseUser.ID
if databaseUser.ID == "" {
createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
continue
} else if err != nil {
return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
}
userID = createdUser.ID
} else {
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
continue
} else if err != nil {
return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
}
}
// Save profile picture
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
if pictureString != "" {
// Storage operations must be executed outside of a transaction
savePictures = append(savePictures, savePicture{
userID: databaseUser.ID,
username: userID,
picture: pictureString,
desiredUsers = append(desiredUsers, ldapDesiredUser{
ldapID: ldapID,
input: newUser,
picture: value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value),
})
}
return desiredUsers, ldapUserIDs, usernamesByDN, nil
}
func (s *LdapService) resolveGroupMemberUsername(ctx context.Context, client ldapClient, member string, usernamesByDN map[string]string) string {
dbConfig := s.appConfigService.GetDbConfig()
// First try the DN cache we built while loading users
username, exists := usernamesByDN[normalizeLDAPDN(member)]
if exists && username != "" {
return username
}
// Then try to extract the username directly from the DN
username = getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
if username != "" {
return norm.NFC.String(username)
}
// As a fallback, query LDAP for the referenced entry
userSearchReq := ldap.NewSearchRequest(
member,
ldap.ScopeBaseObject,
0, 0, 0, false,
"(objectClass=*)",
[]string{dbConfig.LdapAttributeUserUsername.Value},
[]ldap.Control{},
)
userResult, err := client.Search(userSearchReq)
if err != nil || len(userResult.Entries) == 0 {
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
return ""
}
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
if username == "" {
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
return ""
}
return norm.NFC.String(username)
}
func (s *LdapService) reconcileGroups(ctx context.Context, tx *gorm.DB, desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}) error {
// Load the current LDAP-managed state from the database
ldapGroupsInDB, ldapGroupsByID, err := s.loadLDAPGroupsInDB(ctx, tx)
if err != nil {
return fmt.Errorf("failed to fetch groups from database: %w", err)
}
_, _, ldapUsersByUsername, err := s.loadLDAPUsersInDB(ctx, tx)
if err != nil {
return fmt.Errorf("failed to fetch users from database: %w", err)
}
// Apply creates and updates to match the desired LDAP group state
for _, desiredGroup := range desiredGroups {
memberUserIDs := make([]string, 0, len(desiredGroup.memberUsernames))
for _, username := range desiredGroup.memberUsernames {
databaseUser, exists := ldapUsersByUsername[username]
if !exists {
// The user collides with a non-LDAP user or was skipped during user sync, so we ignore it
continue
}
memberUserIDs = append(memberUserIDs, databaseUser.ID)
}
databaseGroup := ldapGroupsByID[desiredGroup.ldapID]
if databaseGroup.ID == "" {
newGroup, err := s.groupService.createInternal(ctx, desiredGroup.input, tx)
if err != nil {
return fmt.Errorf("failed to create group '%s': %w", desiredGroup.input.Name, err)
}
ldapGroupsByID[desiredGroup.ldapID] = newGroup
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, memberUserIDs, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
}
continue
}
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, desiredGroup.input, true, tx)
if err != nil {
return fmt.Errorf("failed to update group '%s': %w", desiredGroup.input.Name, err)
}
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, memberUserIDs, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
}
}
// Delete groups that are no longer present in LDAP
for _, group := range ldapGroupsInDB {
if group.LdapID == nil {
continue
}
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
continue
}
// Get all LDAP users from the database
var ldapUsersInDb []model.User
err = tx.
WithContext(ctx).
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
Select("id, username, ldap_id, disabled").
Delete(&model.UserGroup{}, "ldap_id = ?", *group.LdapID).
Error
if err != nil {
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
}
slog.Info("Deleted group", slog.String("group", group.Name))
}
return nil
}
//nolint:gocognit
func (s *LdapService) reconcileUsers(ctx context.Context, tx *gorm.DB, desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}) (savePictures []savePicture, deleteFiles []string, err error) {
dbConfig := s.appConfigService.GetDbConfig()
// Load the current LDAP-managed state from the database
ldapUsersInDB, ldapUsersByID, _, err := s.loadLDAPUsersInDB(ctx, tx)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
}
// Mark users as disabled or delete users that no longer exist in LDAP
deleteFiles = make([]string, 0, len(ldapUserIDs))
for _, user := range ldapUsersInDb {
// Skip if the user ID exists in the fetched LDAP results
// Apply creates and updates to match the desired LDAP user state
savePictures = make([]savePicture, 0, len(desiredUsers))
for _, desiredUser := range desiredUsers {
databaseUser := ldapUsersByID[desiredUser.ldapID]
// If a user is found (even if disabled), enable them since they're now back in LDAP.
if databaseUser.ID != "" && databaseUser.Disabled {
err = tx.
WithContext(ctx).
Model(&model.User{}).
Where("id = ?", databaseUser.ID).
Update("disabled", false).
Error
if err != nil {
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
}
databaseUser.Disabled = false
ldapUsersByID[desiredUser.ldapID] = databaseUser
}
userID := databaseUser.ID
if databaseUser.ID == "" {
createdUser, err := s.userService.createUserInternal(ctx, desiredUser.input, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
slog.Warn("Skipping creating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
continue
} else if err != nil {
return nil, nil, fmt.Errorf("error creating user '%s': %w", desiredUser.input.Username, err)
}
userID = createdUser.ID
ldapUsersByID[desiredUser.ldapID] = createdUser
} else {
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, desiredUser.input, false, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
slog.Warn("Skipping updating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
continue
} else if err != nil {
return nil, nil, fmt.Errorf("error updating user '%s': %w", desiredUser.input.Username, err)
}
}
if desiredUser.picture != "" {
savePictures = append(savePictures, savePicture{
userID: userID,
username: desiredUser.input.Username,
picture: desiredUser.picture,
})
}
}
// Disable or delete users that are no longer present in LDAP
deleteFiles = make([]string, 0, len(ldapUsersInDB))
for _, user := range ldapUsersInDB {
if user.LdapID == nil {
continue
}
if _, exists := ldapUserIDs[*user.LdapID]; exists {
continue
}
@@ -458,7 +552,9 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
}
slog.Info("Disabled user", slog.String("username", user.Username))
} else {
continue
}
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
if err != nil {
target := &common.LdapUserUpdateError{}
@@ -469,18 +565,60 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
}
slog.Info("Deleted user", slog.String("username", user.Username))
// Storage operations must be executed outside of a transaction
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
}
}
return savePictures, deleteFiles, nil
}
func (s *LdapService) loadLDAPUsersInDB(ctx context.Context, tx *gorm.DB) (users []model.User, byLdapID map[string]model.User, byUsername map[string]model.User, err error) {
// Load all LDAP-managed users and index them by LDAP ID and by username
err = tx.
WithContext(ctx).
Select("id, username, ldap_id, disabled").
Where("ldap_id IS NOT NULL").
Find(&users).
Error
if err != nil {
return nil, nil, nil, err
}
byLdapID = make(map[string]model.User, len(users))
byUsername = make(map[string]model.User, len(users))
for _, user := range users {
byLdapID[*user.LdapID] = user
byUsername[user.Username] = user
}
return users, byLdapID, byUsername, nil
}
func (s *LdapService) loadLDAPGroupsInDB(ctx context.Context, tx *gorm.DB) ([]model.UserGroup, map[string]model.UserGroup, error) {
var groups []model.UserGroup
// Load all LDAP-managed groups and index them by LDAP ID
err := tx.
WithContext(ctx).
Select("id, name, ldap_id").
Where("ldap_id IS NOT NULL").
Find(&groups).
Error
if err != nil {
return nil, nil, err
}
groupsByID := make(map[string]model.UserGroup, len(groups))
for _, group := range groups {
groupsByID[*group.LdapID] = group
}
return groups, groupsByID, nil
}
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
var reader io.ReadSeeker
// Accept either a URL, a base64-encoded payload, or raw binary data
_, err := url.ParseRequestURI(pictureString)
if err == nil {
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
@@ -522,6 +660,31 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
return nil
}
// normalizeLDAPDN returns a canonical lowercase form of a DN for use as a map key.
// Different LDAP servers may format the same DN with varying attribute type casing (e.g. "CN=" vs "cn=") or extra whitespace (e.g. "dc=example, dc=com").
// Without normalization, cache lookups in usernamesByDN would miss when a member attribute value uses a different format than the DN returned in the search entry
//
// ldap.ParseDN is used instead of simple lowercasing because it correctly handles multi-valued RDNs (joined with "+") and strips inter-component whitespace.
// If parsing fails for any reason, we fall back to a simple lowercase+trim.
func normalizeLDAPDN(dn string) string {
parsed, err := ldap.ParseDN(dn)
if err != nil {
return strings.ToLower(strings.TrimSpace(dn))
}
// Reconstruct the DN in a canonical form: lowercase type=lowercase value, with RDN components separated by "," and multi-value attributes by "+"
parts := make([]string, 0, len(parsed.RDNs))
for _, rdn := range parsed.RDNs {
attrs := make([]string, 0, len(rdn.Attributes))
for _, attr := range rdn.Attributes {
attrs = append(attrs, strings.ToLower(attr.Type)+"="+strings.ToLower(attr.Value))
}
parts = append(parts, strings.Join(attrs, "+"))
}
return strings.Join(parts, ",")
}
// getDNProperty returns the value of a property from a LDAP identifier
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
func getDNProperty(property string, str string) string {

View File

@@ -1,9 +1,286 @@
package service
import (
"net/http"
"testing"
"github.com/go-ldap/ldap/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/storage"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
type fakeLDAPClient struct {
searchFn func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
}
func (c *fakeLDAPClient) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
if c.searchFn == nil {
return nil, nil
}
return c.searchFn(searchRequest)
}
func (c *fakeLDAPClient) Bind(_, _ string) error {
return nil
}
func (c *fakeLDAPClient) Close() error {
return nil
}
func TestLdapServiceSyncAllReconcilesUsersAndGroups(t *testing.T) {
service, db := newTestLdapService(t, newFakeLDAPClient(
ldapSearchResult(
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
"entryUUID": {"u-alice"},
"uid": {"alice"},
"mail": {"alice@example.com"},
"givenName": {"Alice"},
"sn": {"Jones"},
"displayName": {""},
"memberOf": {"cn=admins,ou=groups,dc=example,dc=com"},
}),
ldapEntry("uid=bob,ou=people,dc=example,dc=com", map[string][]string{
"entryUUID": {"u-bob"},
"uid": {"bob"},
"mail": {"bob@example.com"},
"givenName": {"Bob"},
"sn": {"Brown"},
"displayName": {""},
}),
),
ldapSearchResult(
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
"entryUUID": {"g-team"},
"cn": {"team"},
"member": {
"UID=Alice, OU=People, DC=example, DC=com",
"uid=bob, ou=people, dc=example, dc=com",
},
}),
),
))
aliceLdapID := "u-alice"
missingLdapID := "u-missing"
teamLdapID := "g-team"
oldGroupLdapID := "g-old"
require.NoError(t, db.Create(&model.User{
Username: "alice-old",
Email: new("alice-old@example.com"),
EmailVerified: true,
FirstName: "Old",
LastName: "Name",
DisplayName: "Old Name",
LdapID: &aliceLdapID,
Disabled: true,
}).Error)
require.NoError(t, db.Create(&model.User{
Username: "missing",
Email: new("missing@example.com"),
EmailVerified: true,
FirstName: "Missing",
LastName: "User",
DisplayName: "Missing User",
LdapID: &missingLdapID,
}).Error)
require.NoError(t, db.Create(&model.UserGroup{
Name: "team-old",
FriendlyName: "team-old",
LdapID: &teamLdapID,
}).Error)
require.NoError(t, db.Create(&model.UserGroup{
Name: "old-group",
FriendlyName: "old-group",
LdapID: &oldGroupLdapID,
}).Error)
require.NoError(t, service.SyncAll(t.Context()))
var alice model.User
require.NoError(t, db.First(&alice, "ldap_id = ?", aliceLdapID).Error)
assert.Equal(t, "alice", alice.Username)
assert.Equal(t, new("alice@example.com"), alice.Email)
assert.Equal(t, "Alice", alice.FirstName)
assert.Equal(t, "Jones", alice.LastName)
assert.Equal(t, "Alice Jones", alice.DisplayName)
assert.True(t, alice.IsAdmin)
assert.False(t, alice.Disabled)
var bob model.User
require.NoError(t, db.First(&bob, "ldap_id = ?", "u-bob").Error)
assert.Equal(t, "bob", bob.Username)
assert.Equal(t, "Bob Brown", bob.DisplayName)
var missing model.User
require.NoError(t, db.First(&missing, "ldap_id = ?", missingLdapID).Error)
assert.True(t, missing.Disabled)
var oldGroupCount int64
require.NoError(t, db.Model(&model.UserGroup{}).Where("ldap_id = ?", oldGroupLdapID).Count(&oldGroupCount).Error)
assert.Zero(t, oldGroupCount)
var team model.UserGroup
require.NoError(t, db.Preload("Users").First(&team, "ldap_id = ?", teamLdapID).Error)
assert.Equal(t, "team", team.Name)
assert.Equal(t, "team", team.FriendlyName)
assert.ElementsMatch(t, []string{"alice", "bob"}, usernames(team.Users))
}
func TestLdapServiceSyncAllHandlesDuplicateLDAPIDsInSingleRun(t *testing.T) {
service, db := newTestLdapService(t, newFakeLDAPClient(
ldapSearchResult(
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
"entryUUID": {"u-dup"},
"uid": {"alice"},
"mail": {"alice@example.com"},
"givenName": {"Alice"},
"sn": {"Doe"},
"displayName": {"Alice Doe"},
}),
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
"entryUUID": {"u-dup"},
"uid": {"alice"},
"mail": {"alice@example.com"},
"givenName": {"Alicia"},
"sn": {"Doe"},
"displayName": {"Alicia Doe"},
}),
),
ldapSearchResult(
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
"entryUUID": {"g-dup"},
"cn": {"team"},
"member": {"uid=alice,ou=people,dc=example,dc=com"},
}),
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
"entryUUID": {"g-dup"},
"cn": {"team-renamed"},
"member": {"uid=alice,ou=people,dc=example,dc=com"},
}),
),
))
require.NoError(t, service.SyncAll(t.Context()))
var users []model.User
require.NoError(t, db.Find(&users, "ldap_id = ?", "u-dup").Error)
require.Len(t, users, 1)
assert.Equal(t, "alice", users[0].Username)
assert.Equal(t, "Alicia", users[0].FirstName)
assert.Equal(t, "Alicia Doe", users[0].DisplayName)
var groups []model.UserGroup
require.NoError(t, db.Preload("Users").Find(&groups, "ldap_id = ?", "g-dup").Error)
require.Len(t, groups, 1)
assert.Equal(t, "team-renamed", groups[0].Name)
assert.Equal(t, "team-renamed", groups[0].FriendlyName)
assert.ElementsMatch(t, []string{"alice"}, usernames(groups[0].Users))
}
func newTestLdapService(t *testing.T, client ldapClient) (*LdapService, *gorm.DB) {
t.Helper()
db := testutils.NewDatabaseForTest(t)
fileStorage, err := storage.NewDatabaseStorage(db)
require.NoError(t, err)
appConfig := NewTestAppConfigService(&model.AppConfig{
RequireUserEmail: model.AppConfigVariable{Value: "false"},
LdapEnabled: model.AppConfigVariable{Value: "true"},
LdapBase: model.AppConfigVariable{Value: "dc=example,dc=com"},
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
LdapAttributeUserUsername: model.AppConfigVariable{Value: "uid"},
LdapAttributeUserEmail: model.AppConfigVariable{Value: "mail"},
LdapAttributeUserFirstName: model.AppConfigVariable{Value: "givenName"},
LdapAttributeUserLastName: model.AppConfigVariable{Value: "sn"},
LdapAttributeUserDisplayName: model.AppConfigVariable{Value: "displayName"},
LdapAttributeUserProfilePicture: model.AppConfigVariable{Value: "jpegPhoto"},
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
LdapAttributeGroupName: model.AppConfigVariable{Value: "cn"},
LdapAdminGroupName: model.AppConfigVariable{Value: "admins"},
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
})
groupService := NewUserGroupService(db, appConfig, nil)
userService := NewUserService(
db,
nil,
nil,
nil,
appConfig,
NewCustomClaimService(db),
NewAppImagesService(map[string]string{}, fileStorage),
nil,
fileStorage,
)
service := NewLdapService(db, &http.Client{}, appConfig, userService, groupService, fileStorage)
service.clientFactory = func() (ldapClient, error) {
return client, nil
}
return service, db
}
func newFakeLDAPClient(userResult, groupResult *ldap.SearchResult) ldapClient {
return &fakeLDAPClient{
searchFn: func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
switch searchRequest.Filter {
case "(objectClass=person)":
return userResult, nil
case "(objectClass=groupOfNames)":
return groupResult, nil
default:
return &ldap.SearchResult{}, nil
}
},
}
}
func ldapSearchResult(entries ...*ldap.Entry) *ldap.SearchResult {
return &ldap.SearchResult{Entries: entries}
}
func ldapEntry(dn string, attrs map[string][]string) *ldap.Entry {
entry := &ldap.Entry{
DN: dn,
Attributes: make([]*ldap.EntryAttribute, 0, len(attrs)),
}
for name, values := range attrs {
entry.Attributes = append(entry.Attributes, &ldap.EntryAttribute{
Name: name,
Values: values,
})
}
return entry
}
func usernames(users []model.User) []string {
result := make([]string, 0, len(users))
for _, user := range users {
result = append(result, user.Username)
}
return result
}
func TestGetDNProperty(t *testing.T) {
tests := []struct {
name string
@@ -64,10 +341,58 @@ func TestGetDNProperty(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getDNProperty(tt.property, tt.dn)
if result != tt.expectedResult {
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
tt.property, tt.dn, result, tt.expectedResult)
assert.Equalf(t, tt.expectedResult, result, "getDNProperty(%q, %q)", tt.property, tt.dn)
})
}
}
func TestNormalizeLDAPDN(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "already normalized",
input: "cn=alice,dc=example,dc=com",
expected: "cn=alice,dc=example,dc=com",
},
{
name: "uppercase attribute types",
input: "CN=Alice,DC=example,DC=com",
expected: "cn=alice,dc=example,dc=com",
},
{
name: "spaces after commas",
input: "cn=alice, dc=example, dc=com",
expected: "cn=alice,dc=example,dc=com",
},
{
name: "uppercase types and spaces",
input: "CN=Alice, DC=example, DC=com",
expected: "cn=alice,dc=example,dc=com",
},
{
name: "multi-valued RDN",
input: "cn=alice+uid=a123,dc=example,dc=com",
expected: "cn=alice+uid=a123,dc=example,dc=com",
},
{
name: "invalid DN falls back to lowercase+trim",
input: " NOT A VALID DN ",
expected: "not a valid dn",
},
{
name: "empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeLDAPDN(tt.input)
assert.Equalf(t, tt.expected, result, "normalizeLDAPDN(%q)", tt.input)
})
}
}
@@ -98,9 +423,7 @@ func TestConvertLdapIdToString(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := convertLdapIdToString(tt.input)
if got != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, got)
}
assert.Equal(t, tt.expected, got)
})
}
}

View File

@@ -96,7 +96,10 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
return err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return nil
}
@@ -126,7 +129,10 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
return model.UserGroup{}, err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return group, nil
}
@@ -175,7 +181,10 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
return model.UserGroup{}, err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return group, nil
}
@@ -238,7 +247,10 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u
return model.UserGroup{}, err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return group, nil
}
@@ -315,6 +327,9 @@ func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id strin
return model.UserGroup{}, err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return group, nil
}

View File

@@ -225,7 +225,10 @@ func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userI
return fmt.Errorf("failed to delete user: %w", err)
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return nil
}
@@ -310,7 +313,10 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
}
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return user, nil
}
@@ -456,7 +462,10 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
return user, err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return user, nil
}
@@ -515,7 +524,10 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
return model.User{}, err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return user, nil
}
@@ -576,7 +588,10 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user
return err
}
if s.scimService != nil {
s.scimService.ScheduleSync()
}
return nil
}