1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-14 15:17:26 +00:00

feat: display groups on the account page (#296)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-03-06 15:25:03 -06:00
committed by GitHub
parent 37b24bed91
commit 0f14a93e1d
14 changed files with 223 additions and 22 deletions

View File

@@ -27,9 +27,12 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler)
group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler)
group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler)
group.GET("/users/:id/groups", jwtAuthMiddleware.Add(true), uc.getUserGroupsHandler)
group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler)
group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler)
group.PUT("/users/:id/user-groups", jwtAuthMiddleware.Add(true), uc.updateUserGroups)
group.GET("/users/:id/profile-picture.png", uc.getUserProfilePictureHandler)
group.GET("/users/me/profile-picture.png", jwtAuthMiddleware.Add(false), uc.getCurrentUserProfilePictureHandler)
group.PUT("/users/:id/profile-picture", jwtAuthMiddleware.Add(true), uc.updateUserProfilePictureHandler)
@@ -46,6 +49,23 @@ type UserController struct {
appConfigService *service.AppConfigService
}
func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
userID := c.Param("id")
groups, err := uc.userService.GetUserGroups(userID)
if err != nil {
c.Error(err)
return
}
var groupsDto []dto.UserGroupDtoWithUsers
if err := dto.MapStructList(groups, &groupsDto); err != nil {
c.Error(err)
return
}
c.JSON(http.StatusOK, groupsDto)
}
func (uc *UserController) listUsersHandler(c *gin.Context) {
searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest
@@ -315,3 +335,25 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
c.JSON(http.StatusOK, userDto)
}
func (uc *UserController) updateUserGroups(c *gin.Context) {
var input dto.UserUpdateUserGroupDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}
user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds)
if err != nil {
c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
return
}
c.JSON(http.StatusOK, userDto)
}

View File

@@ -139,7 +139,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
return
}
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input)
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs)
if err != nil {
c.Error(err)
return

View File

@@ -10,6 +10,7 @@ type UserDto struct {
LastName string `json:"lastName"`
IsAdmin bool `json:"isAdmin"`
CustomClaims []CustomClaimDto `json:"customClaims"`
UserGroups []UserGroupDto `json:"userGroups"`
LdapID *string `json:"ldapId"`
}
@@ -31,3 +32,7 @@ type OneTimeAccessEmailDto struct {
Email string `json:"email" binding:"required,email"`
RedirectPath string `json:"redirectPath"`
}
type UserUpdateUserGroupDto struct {
UserGroupIds []string `json:"userGroupIds" binding:"required"`
}

View File

@@ -4,6 +4,15 @@ import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type UserGroupDto struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
LdapID *string `json:"ldapId"`
CreatedAt datatype.DateTime `json:"createdAt"`
}
type UserGroupDtoWithUsers struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`

View File

@@ -132,22 +132,18 @@ func (s *LdapService) SyncGroups() error {
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute),
}
usersToAddDto := dto.UserGroupUpdateUsersDto{
UserIDs: membersUserId,
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
} else {
if _, err = s.groupService.UpdateUsers(newGroup.ID, usersToAddDto); err != nil {
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
_, err = s.groupService.UpdateUsers(databaseGroup.ID, usersToAddDto)
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
return err

View File

@@ -103,16 +103,16 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
return group, nil
}
func (s *UserGroupService) UpdateUsers(id string, input dto.UserGroupUpdateUsersDto) (group model.UserGroup, err error) {
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id)
if err != nil {
return model.UserGroup{}, err
}
// Fetch the users based on UserIDs in input
// Fetch the users based on the userIds
var users []model.User
if len(input.UserIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserIDs).Find(&users).Error; err != nil {
if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
return model.UserGroup{}, err
}
}

View File

@@ -3,8 +3,6 @@ package service
import (
"errors"
"fmt"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/utils/image"
"io"
"log"
"net/url"
@@ -12,6 +10,9 @@ import (
"strings"
"time"
"github.com/google/uuid"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
@@ -48,7 +49,7 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils
func (s *UserService) GetUser(userID string) (model.User, error) {
var user model.User
err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
return user, err
}
@@ -83,6 +84,14 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error)
return defaultPicture, int64(defaultPicture.Len()), nil
}
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
var user model.User
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil {
return nil, err
}
return user.UserGroups, nil
}
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
@@ -269,6 +278,33 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
return oneTimeAccessToken.User, accessToken, nil
}
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) {
user, err = s.GetUser(id)
if err != nil {
return model.User{}, err
}
// Fetch the groups based on userGroupIds
var groups []model.UserGroup
if len(userGroupIds) > 0 {
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil {
return model.User{}, err
}
}
// Replace the current groups with the new set of groups
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil {
return model.User{}, err
}
// Save the updated user
if err := s.db.Save(&user).Error; err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {