mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-22 20:15:07 +00:00
refactor: separate querying LDAP and updating DB during sync (#1371)
This commit is contained in:
committed by
GitHub
parent
cad80e7d74
commit
d71966f996
@@ -35,6 +35,7 @@ type LdapService struct {
|
|||||||
userService *UserService
|
userService *UserService
|
||||||
groupService *UserGroupService
|
groupService *UserGroupService
|
||||||
fileStorage storage.FileStorage
|
fileStorage storage.FileStorage
|
||||||
|
clientFactory func() (ldapClient, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type savePicture struct {
|
type savePicture struct {
|
||||||
@@ -43,8 +44,33 @@ type savePicture struct {
|
|||||||
picture string
|
picture string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ldapDesiredUser struct {
|
||||||
|
ldapID string
|
||||||
|
input dto.UserCreateDto
|
||||||
|
picture string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ldapDesiredGroup struct {
|
||||||
|
ldapID string
|
||||||
|
input dto.UserGroupCreateDto
|
||||||
|
memberUsernames []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ldapDesiredState struct {
|
||||||
|
users []ldapDesiredUser
|
||||||
|
userIDs map[string]struct{}
|
||||||
|
groups []ldapDesiredGroup
|
||||||
|
groupIDs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ldapClient interface {
|
||||||
|
Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
|
||||||
|
Bind(username, password string) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
|
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
|
||||||
return &LdapService{
|
service := &LdapService{
|
||||||
db: db,
|
db: db,
|
||||||
httpClient: httpClient,
|
httpClient: httpClient,
|
||||||
appConfigService: appConfigService,
|
appConfigService: appConfigService,
|
||||||
@@ -52,9 +78,12 @@ func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppC
|
|||||||
groupService: groupService,
|
groupService: groupService,
|
||||||
fileStorage: fileStorage,
|
fileStorage: fileStorage,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
service.clientFactory = service.createClient
|
||||||
|
return service
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LdapService) createClient() (*ldap.Conn, error) {
|
func (s *LdapService) createClient() (ldapClient, error) {
|
||||||
dbConfig := s.appConfigService.GetDbConfig()
|
dbConfig := s.appConfigService.GetDbConfig()
|
||||||
|
|
||||||
if !dbConfig.LdapEnabled.IsTrue() {
|
if !dbConfig.LdapEnabled.IsTrue() {
|
||||||
@@ -79,24 +108,33 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
|||||||
|
|
||||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||||
// Setup LDAP connection
|
// Setup LDAP connection
|
||||||
client, err := s.createClient()
|
client, err := s.clientFactory()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create LDAP client: %w", err)
|
return fmt.Errorf("failed to create LDAP client: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
// Start a transaction
|
// First, we fetch all users and group from LDAP, which is our "desired state"
|
||||||
tx := s.db.Begin()
|
desiredState, err := s.fetchDesiredState(ctx, client)
|
||||||
defer func() {
|
if err != nil {
|
||||||
tx.Rollback()
|
return fmt.Errorf("failed to fetch LDAP state: %w", err)
|
||||||
}()
|
}
|
||||||
|
|
||||||
savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
|
// Start a transaction
|
||||||
|
tx := s.db.WithContext(ctx).Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
return fmt.Errorf("failed to begin database transaction: %w", tx.Error)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
// Reconcile users
|
||||||
|
savePictures, deleteFiles, err := s.reconcileUsers(ctx, tx, desiredState.users, desiredState.userIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sync users: %w", err)
|
return fmt.Errorf("failed to sync users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.SyncGroups(ctx, tx, client)
|
// Reconcile groups
|
||||||
|
err = s.reconcileGroups(ctx, tx, desiredState.groups, desiredState.groupIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sync groups: %w", err)
|
return fmt.Errorf("failed to sync groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -129,10 +167,31 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit
|
func (s *LdapService) fetchDesiredState(ctx context.Context, client ldapClient) (ldapDesiredState, error) {
|
||||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
// Fetch users first so we can use their DNs when resolving group members
|
||||||
|
users, userIDs, usernamesByDN, err := s.fetchUsersFromLDAP(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return ldapDesiredState{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then fetch groups to complete the desired LDAP state snapshot
|
||||||
|
groups, groupIDs, err := s.fetchGroupsFromLDAP(ctx, client, usernamesByDN)
|
||||||
|
if err != nil {
|
||||||
|
return ldapDesiredState{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ldapDesiredState{
|
||||||
|
users: users,
|
||||||
|
userIDs: userIDs,
|
||||||
|
groups: groups,
|
||||||
|
groupIDs: groupIDs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LdapService) fetchGroupsFromLDAP(ctx context.Context, client ldapClient, usernamesByDN map[string]string) (desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}, err error) {
|
||||||
dbConfig := s.appConfigService.GetDbConfig()
|
dbConfig := s.appConfigService.GetDbConfig()
|
||||||
|
|
||||||
|
// Query LDAP for all groups we want to manage
|
||||||
searchAttrs := []string{
|
searchAttrs := []string{
|
||||||
dbConfig.LdapAttributeGroupName.Value,
|
dbConfig.LdapAttributeGroupName.Value,
|
||||||
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
|
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
|
||||||
@@ -149,90 +208,42 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
)
|
)
|
||||||
result, err := client.Search(searchReq)
|
result, err := client.Search(searchReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to query LDAP: %w", err)
|
return nil, nil, fmt.Errorf("failed to query LDAP groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a mapping for groups that exist
|
// Build the in-memory desired state for groups
|
||||||
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
|
ldapGroupIDs = make(map[string]struct{}, len(result.Entries))
|
||||||
|
desiredGroups = make([]ldapDesiredGroup, 0, len(result.Entries))
|
||||||
|
|
||||||
for _, value := range result.Entries {
|
for _, value := range result.Entries {
|
||||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||||
|
|
||||||
// Skip groups without a valid LDAP ID
|
// Skip groups without a valid LDAP ID
|
||||||
if ldapId == "" {
|
if ldapID == "" {
|
||||||
slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ldapGroupIDs[ldapId] = struct{}{}
|
ldapGroupIDs[ldapID] = struct{}{}
|
||||||
|
|
||||||
// Try to find the group in the database
|
|
||||||
var databaseGroup model.UserGroup
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Where("ldap_id = ?", ldapId).
|
|
||||||
First(&databaseGroup).
|
|
||||||
Error
|
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
|
||||||
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get group members and add to the correct Group
|
// Get group members and add to the correct Group
|
||||||
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
|
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
|
||||||
membersUserId := make([]string, 0, len(groupMembers))
|
memberUsernames := make([]string, 0, len(groupMembers))
|
||||||
for _, member := range groupMembers {
|
for _, member := range groupMembers {
|
||||||
username := getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
|
username := s.resolveGroupMemberUsername(ctx, client, member, usernamesByDN)
|
||||||
|
|
||||||
// If username extraction fails, try to query LDAP directly for the user
|
|
||||||
if username == "" {
|
if username == "" {
|
||||||
// Query LDAP to get the user by their DN
|
|
||||||
userSearchReq := ldap.NewSearchRequest(
|
|
||||||
member,
|
|
||||||
ldap.ScopeBaseObject,
|
|
||||||
0, 0, 0, false,
|
|
||||||
"(objectClass=*)",
|
|
||||||
[]string{dbConfig.LdapAttributeUserUsername.Value, dbConfig.LdapAttributeUserUniqueIdentifier.Value},
|
|
||||||
[]ldap.Control{},
|
|
||||||
)
|
|
||||||
|
|
||||||
userResult, err := client.Search(userSearchReq)
|
|
||||||
if err != nil || len(userResult.Entries) == 0 {
|
|
||||||
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
memberUsernames = append(memberUsernames, username)
|
||||||
if username == "" {
|
|
||||||
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
username = norm.NFC.String(username)
|
|
||||||
|
|
||||||
var databaseUser model.User
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Where("username = ? AND ldap_id IS NOT NULL", username).
|
|
||||||
First(&databaseUser).
|
|
||||||
Error
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
// The user collides with a non-LDAP user, so we skip it
|
|
||||||
continue
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("failed to query for existing user '%s': %w", username, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
membersUserId = append(membersUserId, databaseUser.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
syncGroup := dto.UserGroupCreateDto{
|
syncGroup := dto.UserGroupCreateDto{
|
||||||
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||||
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||||
LdapID: ldapId,
|
LdapID: ldapID,
|
||||||
}
|
}
|
||||||
dto.Normalize(syncGroup)
|
dto.Normalize(&syncGroup)
|
||||||
|
|
||||||
err = syncGroup.Validate()
|
err = syncGroup.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -240,64 +251,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if databaseGroup.ID == "" {
|
desiredGroups = append(desiredGroups, ldapDesiredGroup{
|
||||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
ldapID: ldapID,
|
||||||
if err != nil {
|
input: syncGroup,
|
||||||
return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
|
memberUsernames: memberUsernames,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
|
return desiredGroups, ldapGroupIDs, nil
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all LDAP groups from the database
|
|
||||||
var ldapGroupsInDb []model.UserGroup
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
|
|
||||||
Select("ldap_id").
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete groups that no longer exist in LDAP
|
|
||||||
for _, group := range ldapGroupsInDb {
|
|
||||||
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("Deleted group", slog.String("group", group.Name))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit
|
func (s *LdapService) fetchUsersFromLDAP(ctx context.Context, client ldapClient) (desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}, usernamesByDN map[string]string, err error) {
|
||||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
|
|
||||||
dbConfig := s.appConfigService.GetDbConfig()
|
dbConfig := s.appConfigService.GetDbConfig()
|
||||||
|
|
||||||
|
// Query LDAP for all users we want to manage
|
||||||
searchAttrs := []string{
|
searchAttrs := []string{
|
||||||
"memberOf",
|
"memberOf",
|
||||||
"sn",
|
"sn",
|
||||||
@@ -323,50 +290,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
|
|
||||||
result, err := client.Search(searchReq)
|
result, err := client.Search(searchReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to query LDAP users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a mapping for users that exist
|
// Build the in-memory desired state for users and a DN lookup for group membership resolution
|
||||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
ldapUserIDs = make(map[string]struct{}, len(result.Entries))
|
||||||
savePictures = make([]savePicture, 0, len(result.Entries))
|
usernamesByDN = make(map[string]string, len(result.Entries))
|
||||||
|
desiredUsers = make([]ldapDesiredUser, 0, len(result.Entries))
|
||||||
|
|
||||||
for _, value := range result.Entries {
|
for _, value := range result.Entries {
|
||||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
username := norm.NFC.String(value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value))
|
||||||
|
if normalizedDN := normalizeLDAPDN(value.DN); normalizedDN != "" && username != "" {
|
||||||
|
usernamesByDN[normalizedDN] = username
|
||||||
|
}
|
||||||
|
|
||||||
|
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||||
|
|
||||||
// Skip users without a valid LDAP ID
|
// Skip users without a valid LDAP ID
|
||||||
if ldapId == "" {
|
if ldapID == "" {
|
||||||
slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ldapUserIDs[ldapId] = struct{}{}
|
ldapUserIDs[ldapID] = struct{}{}
|
||||||
|
|
||||||
// Get the user from the database
|
|
||||||
var databaseUser model.User
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Where("ldap_id = ?", ldapId).
|
|
||||||
First(&databaseUser).
|
|
||||||
Error
|
|
||||||
|
|
||||||
// If a user is found (even if disabled), enable them since they're now back in LDAP
|
|
||||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Model(&model.User{}).
|
|
||||||
Where("id = ?", databaseUser.ID).
|
|
||||||
Update("disabled", false).
|
|
||||||
Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
|
||||||
return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if user is admin by checking if they are in the admin group
|
// Check if user is admin by checking if they are in the admin group
|
||||||
isAdmin := false
|
isAdmin := false
|
||||||
@@ -385,14 +331,14 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
|
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
|
||||||
DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value),
|
DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value),
|
||||||
IsAdmin: isAdmin,
|
IsAdmin: isAdmin,
|
||||||
LdapID: ldapId,
|
LdapID: ldapID,
|
||||||
}
|
}
|
||||||
|
|
||||||
if newUser.DisplayName == "" {
|
if newUser.DisplayName == "" {
|
||||||
newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName)
|
newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName)
|
||||||
}
|
}
|
||||||
|
|
||||||
dto.Normalize(newUser)
|
dto.Normalize(&newUser)
|
||||||
|
|
||||||
err = newUser.Validate()
|
err = newUser.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -400,53 +346,201 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := databaseUser.ID
|
desiredUsers = append(desiredUsers, ldapDesiredUser{
|
||||||
if databaseUser.ID == "" {
|
ldapID: ldapID,
|
||||||
createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
|
input: newUser,
|
||||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
picture: value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value),
|
||||||
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
|
||||||
continue
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
|
||||||
}
|
|
||||||
userID = createdUser.ID
|
|
||||||
} else {
|
|
||||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
|
||||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
|
||||||
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
|
||||||
continue
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save profile picture
|
|
||||||
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
|
||||||
if pictureString != "" {
|
|
||||||
// Storage operations must be executed outside of a transaction
|
|
||||||
savePictures = append(savePictures, savePicture{
|
|
||||||
userID: databaseUser.ID,
|
|
||||||
username: userID,
|
|
||||||
picture: pictureString,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return desiredUsers, ldapUserIDs, usernamesByDN, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LdapService) resolveGroupMemberUsername(ctx context.Context, client ldapClient, member string, usernamesByDN map[string]string) string {
|
||||||
|
dbConfig := s.appConfigService.GetDbConfig()
|
||||||
|
|
||||||
|
// First try the DN cache we built while loading users
|
||||||
|
username, exists := usernamesByDN[normalizeLDAPDN(member)]
|
||||||
|
if exists && username != "" {
|
||||||
|
return username
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then try to extract the username directly from the DN
|
||||||
|
username = getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
|
||||||
|
if username != "" {
|
||||||
|
return norm.NFC.String(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// As a fallback, query LDAP for the referenced entry
|
||||||
|
userSearchReq := ldap.NewSearchRequest(
|
||||||
|
member,
|
||||||
|
ldap.ScopeBaseObject,
|
||||||
|
0, 0, 0, false,
|
||||||
|
"(objectClass=*)",
|
||||||
|
[]string{dbConfig.LdapAttributeUserUsername.Value},
|
||||||
|
[]ldap.Control{},
|
||||||
|
)
|
||||||
|
|
||||||
|
userResult, err := client.Search(userSearchReq)
|
||||||
|
if err != nil || len(userResult.Entries) == 0 {
|
||||||
|
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
||||||
|
if username == "" {
|
||||||
|
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return norm.NFC.String(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LdapService) reconcileGroups(ctx context.Context, tx *gorm.DB, desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}) error {
|
||||||
|
// Load the current LDAP-managed state from the database
|
||||||
|
ldapGroupsInDB, ldapGroupsByID, err := s.loadLDAPGroupsInDB(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, ldapUsersByUsername, err := s.loadLDAPUsersInDB(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch users from database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply creates and updates to match the desired LDAP group state
|
||||||
|
for _, desiredGroup := range desiredGroups {
|
||||||
|
memberUserIDs := make([]string, 0, len(desiredGroup.memberUsernames))
|
||||||
|
for _, username := range desiredGroup.memberUsernames {
|
||||||
|
databaseUser, exists := ldapUsersByUsername[username]
|
||||||
|
if !exists {
|
||||||
|
// The user collides with a non-LDAP user or was skipped during user sync, so we ignore it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
memberUserIDs = append(memberUserIDs, databaseUser.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
databaseGroup := ldapGroupsByID[desiredGroup.ldapID]
|
||||||
|
if databaseGroup.ID == "" {
|
||||||
|
newGroup, err := s.groupService.createInternal(ctx, desiredGroup.input, tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create group '%s': %w", desiredGroup.input.Name, err)
|
||||||
|
}
|
||||||
|
ldapGroupsByID[desiredGroup.ldapID] = newGroup
|
||||||
|
|
||||||
|
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, memberUserIDs, tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, desiredGroup.input, true, tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update group '%s': %w", desiredGroup.input.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, memberUserIDs, tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete groups that are no longer present in LDAP
|
||||||
|
for _, group := range ldapGroupsInDB {
|
||||||
|
if group.LdapID == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all LDAP users from the database
|
|
||||||
var ldapUsersInDb []model.User
|
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
|
Delete(&model.UserGroup{}, "ldap_id = ?", *group.LdapID).
|
||||||
Select("id, username, ldap_id, disabled").
|
|
||||||
Error
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("Deleted group", slog.String("group", group.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gocognit
|
||||||
|
func (s *LdapService) reconcileUsers(ctx context.Context, tx *gorm.DB, desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||||
|
dbConfig := s.appConfigService.GetDbConfig()
|
||||||
|
|
||||||
|
// Load the current LDAP-managed state from the database
|
||||||
|
ldapUsersInDB, ldapUsersByID, _, err := s.loadLDAPUsersInDB(ctx, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
|
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark users as disabled or delete users that no longer exist in LDAP
|
// Apply creates and updates to match the desired LDAP user state
|
||||||
deleteFiles = make([]string, 0, len(ldapUserIDs))
|
savePictures = make([]savePicture, 0, len(desiredUsers))
|
||||||
for _, user := range ldapUsersInDb {
|
|
||||||
// Skip if the user ID exists in the fetched LDAP results
|
for _, desiredUser := range desiredUsers {
|
||||||
|
databaseUser := ldapUsersByID[desiredUser.ldapID]
|
||||||
|
|
||||||
|
// If a user is found (even if disabled), enable them since they're now back in LDAP.
|
||||||
|
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||||
|
err = tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Model(&model.User{}).
|
||||||
|
Where("id = ?", databaseUser.ID).
|
||||||
|
Update("disabled", false).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
databaseUser.Disabled = false
|
||||||
|
ldapUsersByID[desiredUser.ldapID] = databaseUser
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := databaseUser.ID
|
||||||
|
if databaseUser.ID == "" {
|
||||||
|
createdUser, err := s.userService.createUserInternal(ctx, desiredUser.input, true, tx)
|
||||||
|
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||||
|
slog.Warn("Skipping creating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error creating user '%s': %w", desiredUser.input.Username, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userID = createdUser.ID
|
||||||
|
ldapUsersByID[desiredUser.ldapID] = createdUser
|
||||||
|
} else {
|
||||||
|
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, desiredUser.input, false, true, tx)
|
||||||
|
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||||
|
slog.Warn("Skipping updating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error updating user '%s': %w", desiredUser.input.Username, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if desiredUser.picture != "" {
|
||||||
|
savePictures = append(savePictures, savePicture{
|
||||||
|
userID: userID,
|
||||||
|
username: desiredUser.input.Username,
|
||||||
|
picture: desiredUser.picture,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable or delete users that are no longer present in LDAP
|
||||||
|
deleteFiles = make([]string, 0, len(ldapUsersInDB))
|
||||||
|
for _, user := range ldapUsersInDB {
|
||||||
|
if user.LdapID == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -458,7 +552,9 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Disabled user", slog.String("username", user.Username))
|
slog.Info("Disabled user", slog.String("username", user.Username))
|
||||||
} else {
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
target := &common.LdapUserUpdateError{}
|
target := &common.LdapUserUpdateError{}
|
||||||
@@ -469,18 +565,60 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||||
|
|
||||||
// Storage operations must be executed outside of a transaction
|
|
||||||
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return savePictures, deleteFiles, nil
|
return savePictures, deleteFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *LdapService) loadLDAPUsersInDB(ctx context.Context, tx *gorm.DB) (users []model.User, byLdapID map[string]model.User, byUsername map[string]model.User, err error) {
|
||||||
|
// Load all LDAP-managed users and index them by LDAP ID and by username
|
||||||
|
err = tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Select("id, username, ldap_id, disabled").
|
||||||
|
Where("ldap_id IS NOT NULL").
|
||||||
|
Find(&users).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
byLdapID = make(map[string]model.User, len(users))
|
||||||
|
byUsername = make(map[string]model.User, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
byLdapID[*user.LdapID] = user
|
||||||
|
byUsername[user.Username] = user
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, byLdapID, byUsername, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LdapService) loadLDAPGroupsInDB(ctx context.Context, tx *gorm.DB) ([]model.UserGroup, map[string]model.UserGroup, error) {
|
||||||
|
var groups []model.UserGroup
|
||||||
|
|
||||||
|
// Load all LDAP-managed groups and index them by LDAP ID
|
||||||
|
err := tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Select("id, name, ldap_id").
|
||||||
|
Where("ldap_id IS NOT NULL").
|
||||||
|
Find(&groups).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsByID := make(map[string]model.UserGroup, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsByID[*group.LdapID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups, groupsByID, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||||
var reader io.ReadSeeker
|
var reader io.ReadSeeker
|
||||||
|
|
||||||
|
// Accept either a URL, a base64-encoded payload, or raw binary data
|
||||||
_, err := url.ParseRequestURI(pictureString)
|
_, err := url.ParseRequestURI(pictureString)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||||
@@ -522,6 +660,31 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeLDAPDN returns a canonical lowercase form of a DN for use as a map key.
|
||||||
|
// Different LDAP servers may format the same DN with varying attribute type casing (e.g. "CN=" vs "cn=") or extra whitespace (e.g. "dc=example, dc=com").
|
||||||
|
// Without normalization, cache lookups in usernamesByDN would miss when a member attribute value uses a different format than the DN returned in the search entry
|
||||||
|
//
|
||||||
|
// ldap.ParseDN is used instead of simple lowercasing because it correctly handles multi-valued RDNs (joined with "+") and strips inter-component whitespace.
|
||||||
|
// If parsing fails for any reason, we fall back to a simple lowercase+trim.
|
||||||
|
func normalizeLDAPDN(dn string) string {
|
||||||
|
parsed, err := ldap.ParseDN(dn)
|
||||||
|
if err != nil {
|
||||||
|
return strings.ToLower(strings.TrimSpace(dn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct the DN in a canonical form: lowercase type=lowercase value, with RDN components separated by "," and multi-value attributes by "+"
|
||||||
|
parts := make([]string, 0, len(parsed.RDNs))
|
||||||
|
for _, rdn := range parsed.RDNs {
|
||||||
|
attrs := make([]string, 0, len(rdn.Attributes))
|
||||||
|
for _, attr := range rdn.Attributes {
|
||||||
|
attrs = append(attrs, strings.ToLower(attr.Type)+"="+strings.ToLower(attr.Value))
|
||||||
|
}
|
||||||
|
parts = append(parts, strings.Join(attrs, "+"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(parts, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// getDNProperty returns the value of a property from a LDAP identifier
|
// getDNProperty returns the value of a property from a LDAP identifier
|
||||||
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
|
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
|
||||||
func getDNProperty(property string, str string) string {
|
func getDNProperty(property string, str string) string {
|
||||||
|
|||||||
@@ -1,9 +1,286 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-ldap/ldap/v3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||||
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type fakeLDAPClient struct {
|
||||||
|
searchFn func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeLDAPClient) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||||
|
if c.searchFn == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.searchFn(searchRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeLDAPClient) Bind(_, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeLDAPClient) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLdapServiceSyncAllReconcilesUsersAndGroups(t *testing.T) {
|
||||||
|
service, db := newTestLdapService(t, newFakeLDAPClient(
|
||||||
|
ldapSearchResult(
|
||||||
|
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"u-alice"},
|
||||||
|
"uid": {"alice"},
|
||||||
|
"mail": {"alice@example.com"},
|
||||||
|
"givenName": {"Alice"},
|
||||||
|
"sn": {"Jones"},
|
||||||
|
"displayName": {""},
|
||||||
|
"memberOf": {"cn=admins,ou=groups,dc=example,dc=com"},
|
||||||
|
}),
|
||||||
|
ldapEntry("uid=bob,ou=people,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"u-bob"},
|
||||||
|
"uid": {"bob"},
|
||||||
|
"mail": {"bob@example.com"},
|
||||||
|
"givenName": {"Bob"},
|
||||||
|
"sn": {"Brown"},
|
||||||
|
"displayName": {""},
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
ldapSearchResult(
|
||||||
|
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"g-team"},
|
||||||
|
"cn": {"team"},
|
||||||
|
"member": {
|
||||||
|
"UID=Alice, OU=People, DC=example, DC=com",
|
||||||
|
"uid=bob, ou=people, dc=example, dc=com",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
aliceLdapID := "u-alice"
|
||||||
|
missingLdapID := "u-missing"
|
||||||
|
teamLdapID := "g-team"
|
||||||
|
oldGroupLdapID := "g-old"
|
||||||
|
|
||||||
|
require.NoError(t, db.Create(&model.User{
|
||||||
|
Username: "alice-old",
|
||||||
|
Email: new("alice-old@example.com"),
|
||||||
|
EmailVerified: true,
|
||||||
|
FirstName: "Old",
|
||||||
|
LastName: "Name",
|
||||||
|
DisplayName: "Old Name",
|
||||||
|
LdapID: &aliceLdapID,
|
||||||
|
Disabled: true,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
require.NoError(t, db.Create(&model.User{
|
||||||
|
Username: "missing",
|
||||||
|
Email: new("missing@example.com"),
|
||||||
|
EmailVerified: true,
|
||||||
|
FirstName: "Missing",
|
||||||
|
LastName: "User",
|
||||||
|
DisplayName: "Missing User",
|
||||||
|
LdapID: &missingLdapID,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
require.NoError(t, db.Create(&model.UserGroup{
|
||||||
|
Name: "team-old",
|
||||||
|
FriendlyName: "team-old",
|
||||||
|
LdapID: &teamLdapID,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
require.NoError(t, db.Create(&model.UserGroup{
|
||||||
|
Name: "old-group",
|
||||||
|
FriendlyName: "old-group",
|
||||||
|
LdapID: &oldGroupLdapID,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
require.NoError(t, service.SyncAll(t.Context()))
|
||||||
|
|
||||||
|
var alice model.User
|
||||||
|
require.NoError(t, db.First(&alice, "ldap_id = ?", aliceLdapID).Error)
|
||||||
|
assert.Equal(t, "alice", alice.Username)
|
||||||
|
assert.Equal(t, new("alice@example.com"), alice.Email)
|
||||||
|
assert.Equal(t, "Alice", alice.FirstName)
|
||||||
|
assert.Equal(t, "Jones", alice.LastName)
|
||||||
|
assert.Equal(t, "Alice Jones", alice.DisplayName)
|
||||||
|
assert.True(t, alice.IsAdmin)
|
||||||
|
assert.False(t, alice.Disabled)
|
||||||
|
|
||||||
|
var bob model.User
|
||||||
|
require.NoError(t, db.First(&bob, "ldap_id = ?", "u-bob").Error)
|
||||||
|
assert.Equal(t, "bob", bob.Username)
|
||||||
|
assert.Equal(t, "Bob Brown", bob.DisplayName)
|
||||||
|
|
||||||
|
var missing model.User
|
||||||
|
require.NoError(t, db.First(&missing, "ldap_id = ?", missingLdapID).Error)
|
||||||
|
assert.True(t, missing.Disabled)
|
||||||
|
|
||||||
|
var oldGroupCount int64
|
||||||
|
require.NoError(t, db.Model(&model.UserGroup{}).Where("ldap_id = ?", oldGroupLdapID).Count(&oldGroupCount).Error)
|
||||||
|
assert.Zero(t, oldGroupCount)
|
||||||
|
|
||||||
|
var team model.UserGroup
|
||||||
|
require.NoError(t, db.Preload("Users").First(&team, "ldap_id = ?", teamLdapID).Error)
|
||||||
|
assert.Equal(t, "team", team.Name)
|
||||||
|
assert.Equal(t, "team", team.FriendlyName)
|
||||||
|
assert.ElementsMatch(t, []string{"alice", "bob"}, usernames(team.Users))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLdapServiceSyncAllHandlesDuplicateLDAPIDsInSingleRun(t *testing.T) {
|
||||||
|
service, db := newTestLdapService(t, newFakeLDAPClient(
|
||||||
|
ldapSearchResult(
|
||||||
|
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"u-dup"},
|
||||||
|
"uid": {"alice"},
|
||||||
|
"mail": {"alice@example.com"},
|
||||||
|
"givenName": {"Alice"},
|
||||||
|
"sn": {"Doe"},
|
||||||
|
"displayName": {"Alice Doe"},
|
||||||
|
}),
|
||||||
|
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"u-dup"},
|
||||||
|
"uid": {"alice"},
|
||||||
|
"mail": {"alice@example.com"},
|
||||||
|
"givenName": {"Alicia"},
|
||||||
|
"sn": {"Doe"},
|
||||||
|
"displayName": {"Alicia Doe"},
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
ldapSearchResult(
|
||||||
|
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"g-dup"},
|
||||||
|
"cn": {"team"},
|
||||||
|
"member": {"uid=alice,ou=people,dc=example,dc=com"},
|
||||||
|
}),
|
||||||
|
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||||
|
"entryUUID": {"g-dup"},
|
||||||
|
"cn": {"team-renamed"},
|
||||||
|
"member": {"uid=alice,ou=people,dc=example,dc=com"},
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
require.NoError(t, service.SyncAll(t.Context()))
|
||||||
|
|
||||||
|
var users []model.User
|
||||||
|
require.NoError(t, db.Find(&users, "ldap_id = ?", "u-dup").Error)
|
||||||
|
require.Len(t, users, 1)
|
||||||
|
assert.Equal(t, "alice", users[0].Username)
|
||||||
|
assert.Equal(t, "Alicia", users[0].FirstName)
|
||||||
|
assert.Equal(t, "Alicia Doe", users[0].DisplayName)
|
||||||
|
|
||||||
|
var groups []model.UserGroup
|
||||||
|
require.NoError(t, db.Preload("Users").Find(&groups, "ldap_id = ?", "g-dup").Error)
|
||||||
|
require.Len(t, groups, 1)
|
||||||
|
assert.Equal(t, "team-renamed", groups[0].Name)
|
||||||
|
assert.Equal(t, "team-renamed", groups[0].FriendlyName)
|
||||||
|
assert.ElementsMatch(t, []string{"alice"}, usernames(groups[0].Users))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestLdapService(t *testing.T, client ldapClient) (*LdapService, *gorm.DB) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
|
fileStorage, err := storage.NewDatabaseStorage(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
appConfig := NewTestAppConfigService(&model.AppConfig{
|
||||||
|
RequireUserEmail: model.AppConfigVariable{Value: "false"},
|
||||||
|
LdapEnabled: model.AppConfigVariable{Value: "true"},
|
||||||
|
LdapBase: model.AppConfigVariable{Value: "dc=example,dc=com"},
|
||||||
|
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
|
||||||
|
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
|
||||||
|
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
|
||||||
|
LdapAttributeUserUsername: model.AppConfigVariable{Value: "uid"},
|
||||||
|
LdapAttributeUserEmail: model.AppConfigVariable{Value: "mail"},
|
||||||
|
LdapAttributeUserFirstName: model.AppConfigVariable{Value: "givenName"},
|
||||||
|
LdapAttributeUserLastName: model.AppConfigVariable{Value: "sn"},
|
||||||
|
LdapAttributeUserDisplayName: model.AppConfigVariable{Value: "displayName"},
|
||||||
|
LdapAttributeUserProfilePicture: model.AppConfigVariable{Value: "jpegPhoto"},
|
||||||
|
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
|
||||||
|
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
|
||||||
|
LdapAttributeGroupName: model.AppConfigVariable{Value: "cn"},
|
||||||
|
LdapAdminGroupName: model.AppConfigVariable{Value: "admins"},
|
||||||
|
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
|
||||||
|
})
|
||||||
|
|
||||||
|
groupService := NewUserGroupService(db, appConfig, nil)
|
||||||
|
userService := NewUserService(
|
||||||
|
db,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
appConfig,
|
||||||
|
NewCustomClaimService(db),
|
||||||
|
NewAppImagesService(map[string]string{}, fileStorage),
|
||||||
|
nil,
|
||||||
|
fileStorage,
|
||||||
|
)
|
||||||
|
|
||||||
|
service := NewLdapService(db, &http.Client{}, appConfig, userService, groupService, fileStorage)
|
||||||
|
service.clientFactory = func() (ldapClient, error) {
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeLDAPClient(userResult, groupResult *ldap.SearchResult) ldapClient {
|
||||||
|
return &fakeLDAPClient{
|
||||||
|
searchFn: func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||||
|
switch searchRequest.Filter {
|
||||||
|
case "(objectClass=person)":
|
||||||
|
return userResult, nil
|
||||||
|
case "(objectClass=groupOfNames)":
|
||||||
|
return groupResult, nil
|
||||||
|
default:
|
||||||
|
return &ldap.SearchResult{}, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ldapSearchResult(entries ...*ldap.Entry) *ldap.SearchResult {
|
||||||
|
return &ldap.SearchResult{Entries: entries}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ldapEntry(dn string, attrs map[string][]string) *ldap.Entry {
|
||||||
|
entry := &ldap.Entry{
|
||||||
|
DN: dn,
|
||||||
|
Attributes: make([]*ldap.EntryAttribute, 0, len(attrs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, values := range attrs {
|
||||||
|
entry.Attributes = append(entry.Attributes, &ldap.EntryAttribute{
|
||||||
|
Name: name,
|
||||||
|
Values: values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func usernames(users []model.User) []string {
|
||||||
|
result := make([]string, 0, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
result = append(result, user.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetDNProperty(t *testing.T) {
|
func TestGetDNProperty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -64,10 +341,58 @@ func TestGetDNProperty(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := getDNProperty(tt.property, tt.dn)
|
result := getDNProperty(tt.property, tt.dn)
|
||||||
if result != tt.expectedResult {
|
assert.Equalf(t, tt.expectedResult, result, "getDNProperty(%q, %q)", tt.property, tt.dn)
|
||||||
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
|
})
|
||||||
tt.property, tt.dn, result, tt.expectedResult)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeLDAPDN(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "already normalized",
|
||||||
|
input: "cn=alice,dc=example,dc=com",
|
||||||
|
expected: "cn=alice,dc=example,dc=com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase attribute types",
|
||||||
|
input: "CN=Alice,DC=example,DC=com",
|
||||||
|
expected: "cn=alice,dc=example,dc=com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "spaces after commas",
|
||||||
|
input: "cn=alice, dc=example, dc=com",
|
||||||
|
expected: "cn=alice,dc=example,dc=com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase types and spaces",
|
||||||
|
input: "CN=Alice, DC=example, DC=com",
|
||||||
|
expected: "cn=alice,dc=example,dc=com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-valued RDN",
|
||||||
|
input: "cn=alice+uid=a123,dc=example,dc=com",
|
||||||
|
expected: "cn=alice+uid=a123,dc=example,dc=com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid DN falls back to lowercase+trim",
|
||||||
|
input: " NOT A VALID DN ",
|
||||||
|
expected: "not a valid dn",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := normalizeLDAPDN(tt.input)
|
||||||
|
assert.Equalf(t, tt.expected, result, "normalizeLDAPDN(%q)", tt.input)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,9 +423,7 @@ func TestConvertLdapIdToString(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := convertLdapIdToString(tt.input)
|
got := convertLdapIdToString(tt.input)
|
||||||
if got != tt.expected {
|
assert.Equal(t, tt.expected, got)
|
||||||
t.Errorf("Expected %q, got %q", tt.expected, got)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,7 +96,10 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +129,10 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
|
|||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +181,10 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
|
|||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +247,10 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u
|
|||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,6 +327,9 @@ func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id strin
|
|||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -225,7 +225,10 @@ func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userI
|
|||||||
return fmt.Errorf("failed to delete user: %w", err)
|
return fmt.Errorf("failed to delete user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,7 +313,10 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,7 +462,10 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
|||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,7 +524,10 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
|
|||||||
return model.User{}, err
|
return model.User{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -576,7 +588,10 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.scimService != nil {
|
||||||
s.scimService.ScheduleSync()
|
s.scimService.ScheduleSync()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user