mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-15 15:45:05 +00:00
feat: display groups on the account page (#296)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
@@ -132,22 +132,18 @@ func (s *LdapService) SyncGroups() error {
|
||||
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute),
|
||||
}
|
||||
|
||||
usersToAddDto := dto.UserGroupUpdateUsersDto{
|
||||
UserIDs: membersUserId,
|
||||
}
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.Create(syncGroup)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
} else {
|
||||
if _, err = s.groupService.UpdateUsers(newGroup.ID, usersToAddDto); err != nil {
|
||||
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
|
||||
_, err = s.groupService.UpdateUsers(databaseGroup.ID, usersToAddDto)
|
||||
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
return err
|
||||
|
||||
@@ -103,16 +103,16 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) UpdateUsers(id string, input dto.UserGroupUpdateUsersDto) (group model.UserGroup, err error) {
|
||||
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
|
||||
group, err = s.Get(id)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
// Fetch the users based on UserIDs in input
|
||||
// Fetch the users based on the userIds
|
||||
var users []model.User
|
||||
if len(input.UserIDs) > 0 {
|
||||
if err := s.db.Where("id IN (?)", input.UserIDs).Find(&users).Error; err != nil {
|
||||
if len(userIds) > 0 {
|
||||
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/image"
|
||||
"io"
|
||||
"log"
|
||||
"net/url"
|
||||
@@ -12,6 +10,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
@@ -48,7 +49,7 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils
|
||||
|
||||
func (s *UserService) GetUser(userID string) (model.User, error) {
|
||||
var user model.User
|
||||
err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
|
||||
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
|
||||
return user, err
|
||||
}
|
||||
|
||||
@@ -83,6 +84,14 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error)
|
||||
return defaultPicture, int64(defaultPicture.Len()), nil
|
||||
}
|
||||
|
||||
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
|
||||
var user model.User
|
||||
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.UserGroups, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
|
||||
// Validate the user ID to prevent directory traversal
|
||||
if err := uuid.Validate(userID); err != nil {
|
||||
@@ -269,6 +278,33 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
|
||||
return oneTimeAccessToken.User, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) {
|
||||
user, err = s.GetUser(id)
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
// Fetch the groups based on userGroupIds
|
||||
var groups []model.UserGroup
|
||||
if len(userGroupIds) > 0 {
|
||||
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the current groups with the new set of groups
|
||||
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
// Save the updated user
|
||||
if err := s.db.Save(&user).Error; err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
var userCount int64
|
||||
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||
|
||||
Reference in New Issue
Block a user