1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 11:36:46 +00:00

feat: restrict oidc clients by user groups per default (#1164)

This commit is contained in:
Elias Schneider
2025-12-24 09:09:25 +01:00
committed by GitHub
parent e358c433f0
commit f75cef83d5
30 changed files with 469 additions and 102 deletions

View File

@@ -72,7 +72,7 @@ type UserController struct {
// @Description Retrieve all groups a specific user belongs to
// @Tags Users,User Groups
// @Param id path string true "User ID"
// @Success 200 {array} dto.UserGroupDtoWithUsers
// @Success 200 {array} dto.UserGroupDto
// @Router /api/users/{id}/groups [get]
func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
userID := c.Param("id")
@@ -82,7 +82,7 @@ func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
return
}
var groupsDto []dto.UserGroupDtoWithUsers
var groupsDto []dto.UserGroupDto
if err := dto.MapStructList(groups, &groupsDto); err != nil {
_ = c.Error(err)
return

View File

@@ -28,6 +28,7 @@ func NewUserGroupController(group *gin.RouterGroup, authMiddleware *middleware.A
userGroupsGroup.PUT("/:id", ugc.update)
userGroupsGroup.DELETE("/:id", ugc.delete)
userGroupsGroup.PUT("/:id/users", ugc.updateUsers)
userGroupsGroup.PUT("/:id/allowed-oidc-clients", ugc.updateAllowedOidcClients)
}
}
@@ -44,7 +45,7 @@ type UserGroupController struct {
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
// @Success 200 {object} dto.Paginated[dto.UserGroupMinimalDto]
// @Router /api/user-groups [get]
func (ugc *UserGroupController) list(c *gin.Context) {
searchTerm := c.Query("search")
@@ -57,9 +58,9 @@ func (ugc *UserGroupController) list(c *gin.Context) {
}
// Map the user groups to DTOs
var groupsDto = make([]dto.UserGroupDtoWithUserCount, len(groups))
var groupsDto = make([]dto.UserGroupMinimalDto, len(groups))
for i, group := range groups {
var groupDto dto.UserGroupDtoWithUserCount
var groupDto dto.UserGroupMinimalDto
if err := dto.MapStruct(group, &groupDto); err != nil {
_ = c.Error(err)
return
@@ -72,7 +73,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
groupsDto[i] = groupDto
}
c.JSON(http.StatusOK, dto.Paginated[dto.UserGroupDtoWithUserCount]{
c.JSON(http.StatusOK, dto.Paginated[dto.UserGroupMinimalDto]{
Data: groupsDto,
Pagination: pagination,
})
@@ -85,7 +86,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
// @Accept json
// @Produce json
// @Param id path string true "User Group ID"
// @Success 200 {object} dto.UserGroupDtoWithUsers
// @Success 200 {object} dto.UserGroupDto
// @Router /api/user-groups/{id} [get]
func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
@@ -94,7 +95,7 @@ func (ugc *UserGroupController) get(c *gin.Context) {
return
}
var groupDto dto.UserGroupDtoWithUsers
var groupDto dto.UserGroupDto
if err := dto.MapStruct(group, &groupDto); err != nil {
_ = c.Error(err)
return
@@ -110,7 +111,7 @@ func (ugc *UserGroupController) get(c *gin.Context) {
// @Accept json
// @Produce json
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group"
// @Success 201 {object} dto.UserGroupDto "Created user group"
// @Router /api/user-groups [post]
func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto
@@ -125,7 +126,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
return
}
var groupDto dto.UserGroupDtoWithUsers
var groupDto dto.UserGroupDto
if err := dto.MapStruct(group, &groupDto); err != nil {
_ = c.Error(err)
return
@@ -142,7 +143,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
// @Produce json
// @Param id path string true "User Group ID"
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group"
// @Success 200 {object} dto.UserGroupDto "Updated user group"
// @Router /api/user-groups/{id} [put]
func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto
@@ -157,7 +158,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
return
}
var groupDto dto.UserGroupDtoWithUsers
var groupDto dto.UserGroupDto
if err := dto.MapStruct(group, &groupDto); err != nil {
_ = c.Error(err)
return
@@ -192,7 +193,7 @@ func (ugc *UserGroupController) delete(c *gin.Context) {
// @Produce json
// @Param id path string true "User Group ID"
// @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group"
// @Success 200 {object} dto.UserGroupDtoWithUsers
// @Success 200 {object} dto.UserGroupDto
// @Router /api/user-groups/{id}/users [put]
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
var input dto.UserGroupUpdateUsersDto
@@ -207,7 +208,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
return
}
var groupDto dto.UserGroupDtoWithUsers
var groupDto dto.UserGroupDto
if err := dto.MapStruct(group, &groupDto); err != nil {
_ = c.Error(err)
return
@@ -215,3 +216,35 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
c.JSON(http.StatusOK, groupDto)
}
// updateAllowedOidcClients godoc
// @Summary Update allowed OIDC clients
// @Description Update the OIDC clients allowed for a specific user group
// @Tags OIDC
// @Accept json
// @Produce json
// @Param id path string true "User Group ID"
// @Param groups body dto.UserGroupUpdateAllowedOidcClientsDto true "OIDC client IDs to allow"
// @Success 200 {object} dto.UserGroupDto "Updated user group"
// @Router /api/user-groups/{id}/allowed-oidc-clients [put]
func (ugc *UserGroupController) updateAllowedOidcClients(c *gin.Context) {
var input dto.UserGroupUpdateAllowedOidcClientsDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
userGroup, err := ugc.UserGroupService.UpdateAllowedOidcClient(c.Request.Context(), c.Param("id"), input)
if err != nil {
_ = c.Error(err)
return
}
var userGroupDto dto.UserGroupDto
if err := dto.MapStruct(userGroup, &userGroupDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, userGroupDto)
}

View File

@@ -18,11 +18,12 @@ type OidcClientDto struct {
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Credentials OidcClientCredentialsDto `json:"credentials"`
IsGroupRestricted bool `json:"isGroupRestricted"`
}
type OidcClientWithAllowedUserGroupsDto struct {
OidcClientDto
AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"`
AllowedUserGroups []UserGroupMinimalDto `json:"allowedUserGroups"`
}
type OidcClientWithAllowedGroupsCountDto struct {
@@ -43,6 +44,7 @@ type OidcClientUpdateDto struct {
HasDarkLogo bool `json:"hasDarkLogo"`
LogoURL *string `json:"logoUrl"`
DarkLogoURL *string `json:"darkLogoUrl"`
IsGroupRestricted bool `json:"isGroupRestricted"`
}
type OidcClientCreateDto struct {

View File

@@ -12,11 +12,11 @@ type SignupTokenCreateDto struct {
}
type SignupTokenDto struct {
ID string `json:"id"`
Token string `json:"token"`
ExpiresAt datatype.DateTime `json:"expiresAt"`
UsageLimit int `json:"usageLimit"`
UsageCount int `json:"usageCount"`
UserGroups []UserGroupDto `json:"userGroups"`
CreatedAt datatype.DateTime `json:"createdAt"`
ID string `json:"id"`
Token string `json:"token"`
ExpiresAt datatype.DateTime `json:"expiresAt"`
UsageLimit int `json:"usageLimit"`
UsageCount int `json:"usageCount"`
UserGroups []UserGroupMinimalDto `json:"userGroups"`
CreatedAt datatype.DateTime `json:"createdAt"`
}

View File

@@ -8,18 +8,18 @@ import (
)
type UserDto struct {
ID string `json:"id"`
Username string `json:"username"`
Email *string `json:"email" `
FirstName string `json:"firstName"`
LastName *string `json:"lastName"`
DisplayName string `json:"displayName"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
CustomClaims []CustomClaimDto `json:"customClaims"`
UserGroups []UserGroupDto `json:"userGroups"`
LdapID *string `json:"ldapId"`
Disabled bool `json:"disabled"`
ID string `json:"id"`
Username string `json:"username"`
Email *string `json:"email" `
FirstName string `json:"firstName"`
LastName *string `json:"lastName"`
DisplayName string `json:"displayName"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
CustomClaims []CustomClaimDto `json:"customClaims"`
UserGroups []UserGroupMinimalDto `json:"userGroups"`
LdapID *string `json:"ldapId"`
Disabled bool `json:"disabled"`
}
type UserCreateDto struct {

View File

@@ -8,25 +8,17 @@ import (
)
type UserGroupDto struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
LdapID *string `json:"ldapId"`
CreatedAt datatype.DateTime `json:"createdAt"`
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
LdapID *string `json:"ldapId"`
CreatedAt datatype.DateTime `json:"createdAt"`
Users []UserDto `json:"users"`
AllowedOidcClients []OidcClientMetaDataDto `json:"allowedOidcClients"`
}
type UserGroupDtoWithUsers struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
Users []UserDto `json:"users"`
LdapID *string `json:"ldapId"`
CreatedAt datatype.DateTime `json:"createdAt"`
}
type UserGroupDtoWithUserCount struct {
type UserGroupMinimalDto struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
@@ -36,6 +28,10 @@ type UserGroupDtoWithUserCount struct {
CreatedAt datatype.DateTime `json:"createdAt"`
}
type UserGroupUpdateAllowedOidcClientsDto struct {
OidcClientIDs []string `json:"oidcClientIds" binding:"required"`
}
type UserGroupCreateDto struct {
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50" unorm:"nfc"`
Name string `json:"name" binding:"required,min=2,max=255" unorm:"nfc"`

View File

@@ -58,6 +58,7 @@ type OidcClient struct {
RequiresReauthentication bool `sortable:"true" filterable:"true"`
Credentials OidcClientCredentials
LaunchURL *string
IsGroupRestricted bool
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
CreatedByID *string

View File

@@ -2,9 +2,10 @@ package model
type UserGroup struct {
Base
FriendlyName string `sortable:"true"`
Name string `sortable:"true"`
LdapID *string
Users []User `gorm:"many2many:user_groups_users;"`
CustomClaims []CustomClaim
FriendlyName string `sortable:"true"`
Name string `sortable:"true"`
LdapID *string
Users []User `gorm:"many2many:user_groups_users;"`
CustomClaims []CustomClaim
AllowedOidcClients []OidcClient `gorm:"many2many:oidc_clients_allowed_user_groups;"`
}

View File

@@ -169,10 +169,11 @@ func (s *TestService) SeedDatabase(baseURL string) error {
Base: model.Base{
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
},
Name: "Immich",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
CreatedByID: utils.Ptr(users[1].ID),
Name: "Immich",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
CreatedByID: utils.Ptr(users[1].ID),
IsGroupRestricted: true,
AllowedUserGroups: []model.UserGroup{
userGroups[1],
},
@@ -185,6 +186,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
Secret: "$2a$10$xcRReBsvkI1XI6FG8xu/pOgzeF00bH5Wy4d/NThwcdi3ZBpVq/B9a", // n4VfQeXlTzA6yKpWbR9uJcMdSx2qH0Lo
CallbackURLs: model.UrlList{"http://tailscale/auth/callback"},
LogoutCallbackURLs: model.UrlList{"http://tailscale/auth/logout/callback"},
IsGroupRestricted: true,
CreatedByID: utils.Ptr(users[0].ID),
},
{

View File

@@ -226,7 +226,7 @@ func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID,
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
if len(client.AllowedUserGroups) == 0 {
if !client.IsGroupRestricted {
return true
}
@@ -778,6 +778,14 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
updateOIDCClientModelFromDto(&client, &input)
if !input.IsGroupRestricted {
// Clear allowed user groups if the restriction is removed
err = tx.Model(&client).Association("AllowedUserGroups").Clear()
if err != nil {
return model.OidcClient{}, err
}
}
err = tx.WithContext(ctx).Save(&client).Error
if err != nil {
return model.OidcClient{}, err
@@ -816,6 +824,7 @@ func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClien
client.PkceEnabled = input.IsPublic || input.PkceEnabled
client.RequiresReauthentication = input.RequiresReauthentication
client.LaunchURL = input.LaunchURL
client.IsGroupRestricted = input.IsGroupRestricted
// Credentials
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))

View File

@@ -53,6 +53,7 @@ func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.
Where("id = ?", id).
Preload("CustomClaims").
Preload("Users").
Preload("AllowedOidcClients").
First(&group).
Error
return group, err
@@ -248,3 +249,54 @@ func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (
Count()
return count, nil
}
func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id string, input dto.UserGroupUpdateAllowedOidcClientsDto) (group model.UserGroup, err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
// Fetch the clients based on the client IDs
var clients []model.OidcClient
if len(input.OidcClientIDs) > 0 {
err = tx.
WithContext(ctx).
Where("id IN (?)", input.OidcClientIDs).
Find(&clients).
Error
if err != nil {
return model.UserGroup{}, err
}
}
// Replace the current clients with the new set of clients
err = tx.
WithContext(ctx).
Model(&group).
Association("AllowedOidcClients").
Replace(clients)
if err != nil {
return model.UserGroup{}, err
}
// Save the updated group
err = tx.
WithContext(ctx).
Save(&group).
Error
if err != nil {
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}

View File

@@ -0,0 +1 @@
ALTER TABLE oidc_clients DROP COLUMN is_group_restricted;

View File

@@ -0,0 +1,10 @@
ALTER TABLE oidc_clients
ADD COLUMN is_group_restricted boolean NOT NULL DEFAULT false;
UPDATE oidc_clients oc
SET is_group_restricted =
EXISTS (
SELECT 1
FROM oidc_clients_allowed_user_groups a
WHERE a.oidc_client_id = oc.id
);

View File

@@ -0,0 +1,7 @@
PRAGMA foreign_keys=OFF;
BEGIN;
ALTER TABLE oidc_clients DROP COLUMN is_group_restricted;
COMMIT;
PRAGMA foreign_keys=ON;

View File

@@ -0,0 +1,13 @@
PRAGMA foreign_keys= OFF;
BEGIN;
ALTER TABLE oidc_clients
ADD COLUMN is_group_restricted BOOLEAN NOT NULL DEFAULT 0;
UPDATE oidc_clients
SET is_group_restricted = (SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END
FROM oidc_clients_allowed_user_groups
WHERE oidc_clients_allowed_user_groups.oidc_client_id = oidc_clients.id);
COMMIT;
PRAGMA foreign_keys= ON;