mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 11:36:46 +00:00
feat: restrict oidc clients by user groups per default (#1164)
This commit is contained in:
@@ -72,7 +72,7 @@ type UserController struct {
|
||||
// @Description Retrieve all groups a specific user belongs to
|
||||
// @Tags Users,User Groups
|
||||
// @Param id path string true "User ID"
|
||||
// @Success 200 {array} dto.UserGroupDtoWithUsers
|
||||
// @Success 200 {array} dto.UserGroupDto
|
||||
// @Router /api/users/{id}/groups [get]
|
||||
func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
@@ -82,7 +82,7 @@ func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var groupsDto []dto.UserGroupDtoWithUsers
|
||||
var groupsDto []dto.UserGroupDto
|
||||
if err := dto.MapStructList(groups, &groupsDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -28,6 +28,7 @@ func NewUserGroupController(group *gin.RouterGroup, authMiddleware *middleware.A
|
||||
userGroupsGroup.PUT("/:id", ugc.update)
|
||||
userGroupsGroup.DELETE("/:id", ugc.delete)
|
||||
userGroupsGroup.PUT("/:id/users", ugc.updateUsers)
|
||||
userGroupsGroup.PUT("/:id/allowed-oidc-clients", ugc.updateAllowedOidcClients)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +45,7 @@ type UserGroupController struct {
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
|
||||
// @Success 200 {object} dto.Paginated[dto.UserGroupMinimalDto]
|
||||
// @Router /api/user-groups [get]
|
||||
func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
searchTerm := c.Query("search")
|
||||
@@ -57,9 +58,9 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Map the user groups to DTOs
|
||||
var groupsDto = make([]dto.UserGroupDtoWithUserCount, len(groups))
|
||||
var groupsDto = make([]dto.UserGroupMinimalDto, len(groups))
|
||||
for i, group := range groups {
|
||||
var groupDto dto.UserGroupDtoWithUserCount
|
||||
var groupDto dto.UserGroupMinimalDto
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -72,7 +73,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
groupsDto[i] = groupDto
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dto.Paginated[dto.UserGroupDtoWithUserCount]{
|
||||
c.JSON(http.StatusOK, dto.Paginated[dto.UserGroupMinimalDto]{
|
||||
Data: groupsDto,
|
||||
Pagination: pagination,
|
||||
})
|
||||
@@ -85,7 +86,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Success 200 {object} dto.UserGroupDtoWithUsers
|
||||
// @Success 200 {object} dto.UserGroupDto
|
||||
// @Router /api/user-groups/{id} [get]
|
||||
func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
|
||||
@@ -94,7 +95,7 @@ func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
var groupDto dto.UserGroupDto
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -110,7 +111,7 @@ func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
|
||||
// @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group"
|
||||
// @Success 201 {object} dto.UserGroupDto "Created user group"
|
||||
// @Router /api/user-groups [post]
|
||||
func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
var input dto.UserGroupCreateDto
|
||||
@@ -125,7 +126,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
var groupDto dto.UserGroupDto
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -142,7 +143,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
|
||||
// @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group"
|
||||
// @Success 200 {object} dto.UserGroupDto "Updated user group"
|
||||
// @Router /api/user-groups/{id} [put]
|
||||
func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
var input dto.UserGroupCreateDto
|
||||
@@ -157,7 +158,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
var groupDto dto.UserGroupDto
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -192,7 +193,7 @@ func (ugc *UserGroupController) delete(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group"
|
||||
// @Success 200 {object} dto.UserGroupDtoWithUsers
|
||||
// @Success 200 {object} dto.UserGroupDto
|
||||
// @Router /api/user-groups/{id}/users [put]
|
||||
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
|
||||
var input dto.UserGroupUpdateUsersDto
|
||||
@@ -207,7 +208,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var groupDto dto.UserGroupDtoWithUsers
|
||||
var groupDto dto.UserGroupDto
|
||||
if err := dto.MapStruct(group, &groupDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -215,3 +216,35 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, groupDto)
|
||||
}
|
||||
|
||||
// updateAllowedOidcClients godoc
|
||||
// @Summary Update allowed OIDC clients
|
||||
// @Description Update the OIDC clients allowed for a specific user group
|
||||
// @Tags OIDC
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Param groups body dto.UserGroupUpdateAllowedOidcClientsDto true "OIDC client IDs to allow"
|
||||
// @Success 200 {object} dto.UserGroupDto "Updated user group"
|
||||
// @Router /api/user-groups/{id}/allowed-oidc-clients [put]
|
||||
func (ugc *UserGroupController) updateAllowedOidcClients(c *gin.Context) {
|
||||
var input dto.UserGroupUpdateAllowedOidcClientsDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
userGroup, err := ugc.UserGroupService.UpdateAllowedOidcClient(c.Request.Context(), c.Param("id"), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var userGroupDto dto.UserGroupDto
|
||||
if err := dto.MapStruct(userGroup, &userGroupDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userGroupDto)
|
||||
}
|
||||
|
||||
@@ -18,11 +18,12 @@ type OidcClientDto struct {
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||
IsGroupRestricted bool `json:"isGroupRestricted"`
|
||||
}
|
||||
|
||||
type OidcClientWithAllowedUserGroupsDto struct {
|
||||
OidcClientDto
|
||||
AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"`
|
||||
AllowedUserGroups []UserGroupMinimalDto `json:"allowedUserGroups"`
|
||||
}
|
||||
|
||||
type OidcClientWithAllowedGroupsCountDto struct {
|
||||
@@ -43,6 +44,7 @@ type OidcClientUpdateDto struct {
|
||||
HasDarkLogo bool `json:"hasDarkLogo"`
|
||||
LogoURL *string `json:"logoUrl"`
|
||||
DarkLogoURL *string `json:"darkLogoUrl"`
|
||||
IsGroupRestricted bool `json:"isGroupRestricted"`
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
|
||||
@@ -12,11 +12,11 @@ type SignupTokenCreateDto struct {
|
||||
}
|
||||
|
||||
type SignupTokenDto struct {
|
||||
ID string `json:"id"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt datatype.DateTime `json:"expiresAt"`
|
||||
UsageLimit int `json:"usageLimit"`
|
||||
UsageCount int `json:"usageCount"`
|
||||
UserGroups []UserGroupDto `json:"userGroups"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
ID string `json:"id"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt datatype.DateTime `json:"expiresAt"`
|
||||
UsageLimit int `json:"usageLimit"`
|
||||
UsageCount int `json:"usageCount"`
|
||||
UserGroups []UserGroupMinimalDto `json:"userGroups"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
}
|
||||
|
||||
@@ -8,18 +8,18 @@ import (
|
||||
)
|
||||
|
||||
type UserDto struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email *string `json:"email" `
|
||||
FirstName string `json:"firstName"`
|
||||
LastName *string `json:"lastName"`
|
||||
DisplayName string `json:"displayName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Locale *string `json:"locale"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
UserGroups []UserGroupDto `json:"userGroups"`
|
||||
LdapID *string `json:"ldapId"`
|
||||
Disabled bool `json:"disabled"`
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email *string `json:"email" `
|
||||
FirstName string `json:"firstName"`
|
||||
LastName *string `json:"lastName"`
|
||||
DisplayName string `json:"displayName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Locale *string `json:"locale"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
UserGroups []UserGroupMinimalDto `json:"userGroups"`
|
||||
LdapID *string `json:"ldapId"`
|
||||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
|
||||
type UserCreateDto struct {
|
||||
|
||||
@@ -8,25 +8,17 @@ import (
|
||||
)
|
||||
|
||||
type UserGroupDto struct {
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
LdapID *string `json:"ldapId"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
LdapID *string `json:"ldapId"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
Users []UserDto `json:"users"`
|
||||
AllowedOidcClients []OidcClientMetaDataDto `json:"allowedOidcClients"`
|
||||
}
|
||||
|
||||
type UserGroupDtoWithUsers struct {
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
Users []UserDto `json:"users"`
|
||||
LdapID *string `json:"ldapId"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
}
|
||||
|
||||
type UserGroupDtoWithUserCount struct {
|
||||
type UserGroupMinimalDto struct {
|
||||
ID string `json:"id"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
Name string `json:"name"`
|
||||
@@ -36,6 +28,10 @@ type UserGroupDtoWithUserCount struct {
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
}
|
||||
|
||||
type UserGroupUpdateAllowedOidcClientsDto struct {
|
||||
OidcClientIDs []string `json:"oidcClientIds" binding:"required"`
|
||||
}
|
||||
|
||||
type UserGroupCreateDto struct {
|
||||
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50" unorm:"nfc"`
|
||||
Name string `json:"name" binding:"required,min=2,max=255" unorm:"nfc"`
|
||||
|
||||
@@ -58,6 +58,7 @@ type OidcClient struct {
|
||||
RequiresReauthentication bool `sortable:"true" filterable:"true"`
|
||||
Credentials OidcClientCredentials
|
||||
LaunchURL *string
|
||||
IsGroupRestricted bool
|
||||
|
||||
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||
CreatedByID *string
|
||||
|
||||
@@ -2,9 +2,10 @@ package model
|
||||
|
||||
type UserGroup struct {
|
||||
Base
|
||||
FriendlyName string `sortable:"true"`
|
||||
Name string `sortable:"true"`
|
||||
LdapID *string
|
||||
Users []User `gorm:"many2many:user_groups_users;"`
|
||||
CustomClaims []CustomClaim
|
||||
FriendlyName string `sortable:"true"`
|
||||
Name string `sortable:"true"`
|
||||
LdapID *string
|
||||
Users []User `gorm:"many2many:user_groups_users;"`
|
||||
CustomClaims []CustomClaim
|
||||
AllowedOidcClients []OidcClient `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||
}
|
||||
|
||||
@@ -169,10 +169,11 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
Base: model.Base{
|
||||
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
|
||||
},
|
||||
Name: "Immich",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
|
||||
CreatedByID: utils.Ptr(users[1].ID),
|
||||
Name: "Immich",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
|
||||
CreatedByID: utils.Ptr(users[1].ID),
|
||||
IsGroupRestricted: true,
|
||||
AllowedUserGroups: []model.UserGroup{
|
||||
userGroups[1],
|
||||
},
|
||||
@@ -185,6 +186,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
Secret: "$2a$10$xcRReBsvkI1XI6FG8xu/pOgzeF00bH5Wy4d/NThwcdi3ZBpVq/B9a", // n4VfQeXlTzA6yKpWbR9uJcMdSx2qH0Lo
|
||||
CallbackURLs: model.UrlList{"http://tailscale/auth/callback"},
|
||||
LogoutCallbackURLs: model.UrlList{"http://tailscale/auth/logout/callback"},
|
||||
IsGroupRestricted: true,
|
||||
CreatedByID: utils.Ptr(users[0].ID),
|
||||
},
|
||||
{
|
||||
|
||||
@@ -226,7 +226,7 @@ func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID,
|
||||
|
||||
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
|
||||
func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
|
||||
if len(client.AllowedUserGroups) == 0 {
|
||||
if !client.IsGroupRestricted {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -778,6 +778,14 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
if !input.IsGroupRestricted {
|
||||
// Clear allowed user groups if the restriction is removed
|
||||
err = tx.Model(&client).Association("AllowedUserGroups").Clear()
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.WithContext(ctx).Save(&client).Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
@@ -816,6 +824,7 @@ func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClien
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
client.RequiresReauthentication = input.RequiresReauthentication
|
||||
client.LaunchURL = input.LaunchURL
|
||||
client.IsGroupRestricted = input.IsGroupRestricted
|
||||
|
||||
// Credentials
|
||||
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
|
||||
|
||||
@@ -53,6 +53,7 @@ func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.
|
||||
Where("id = ?", id).
|
||||
Preload("CustomClaims").
|
||||
Preload("Users").
|
||||
Preload("AllowedOidcClients").
|
||||
First(&group).
|
||||
Error
|
||||
return group, err
|
||||
@@ -248,3 +249,54 @@ func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (
|
||||
Count()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id string, input dto.UserGroupUpdateAllowedOidcClientsDto) (group model.UserGroup, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
group, err = s.getInternal(ctx, id, tx)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
// Fetch the clients based on the client IDs
|
||||
var clients []model.OidcClient
|
||||
if len(input.OidcClientIDs) > 0 {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("id IN (?)", input.OidcClientIDs).
|
||||
Find(&clients).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the current clients with the new set of clients
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&group).
|
||||
Association("AllowedOidcClients").
|
||||
Replace(clients)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
// Save the updated group
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients DROP COLUMN is_group_restricted;
|
||||
@@ -0,0 +1,10 @@
|
||||
ALTER TABLE oidc_clients
|
||||
ADD COLUMN is_group_restricted boolean NOT NULL DEFAULT false;
|
||||
|
||||
UPDATE oidc_clients oc
|
||||
SET is_group_restricted =
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM oidc_clients_allowed_user_groups a
|
||||
WHERE a.oidc_client_id = oc.id
|
||||
);
|
||||
@@ -0,0 +1,7 @@
|
||||
PRAGMA foreign_keys=OFF;
|
||||
BEGIN;
|
||||
|
||||
ALTER TABLE oidc_clients DROP COLUMN is_group_restricted;
|
||||
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -0,0 +1,13 @@
|
||||
PRAGMA foreign_keys= OFF;
|
||||
BEGIN;
|
||||
|
||||
ALTER TABLE oidc_clients
|
||||
ADD COLUMN is_group_restricted BOOLEAN NOT NULL DEFAULT 0;
|
||||
|
||||
UPDATE oidc_clients
|
||||
SET is_group_restricted = (SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END
|
||||
FROM oidc_clients_allowed_user_groups
|
||||
WHERE oidc_clients_allowed_user_groups.oidc_client_id = oidc_clients.id);
|
||||
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys= ON;
|
||||
Reference in New Issue
Block a user