mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-09 16:39:15 +00:00
feat: add support for SCIM provisioning (#1182)
This commit is contained in:
@@ -83,6 +83,7 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService)
|
||||
controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService)
|
||||
controller.NewVersionController(apiGroup, svc.versionService)
|
||||
controller.NewScimController(apiGroup, authMiddleware, svc.scimService)
|
||||
|
||||
// Add test controller in non-production environments
|
||||
if !common.EnvConfig.AppEnv.IsProduction() {
|
||||
|
||||
@@ -12,22 +12,24 @@ import (
|
||||
)
|
||||
|
||||
type services struct {
|
||||
appConfigService *service.AppConfigService
|
||||
appImagesService *service.AppImagesService
|
||||
emailService *service.EmailService
|
||||
geoLiteService *service.GeoLiteService
|
||||
auditLogService *service.AuditLogService
|
||||
jwtService *service.JwtService
|
||||
webauthnService *service.WebAuthnService
|
||||
userService *service.UserService
|
||||
customClaimService *service.CustomClaimService
|
||||
oidcService *service.OidcService
|
||||
userGroupService *service.UserGroupService
|
||||
ldapService *service.LdapService
|
||||
apiKeyService *service.ApiKeyService
|
||||
versionService *service.VersionService
|
||||
fileStorage storage.FileStorage
|
||||
appLockService *service.AppLockService
|
||||
appConfigService *service.AppConfigService
|
||||
appImagesService *service.AppImagesService
|
||||
emailService *service.EmailService
|
||||
geoLiteService *service.GeoLiteService
|
||||
auditLogService *service.AuditLogService
|
||||
jwtService *service.JwtService
|
||||
webauthnService *service.WebAuthnService
|
||||
scimService *service.ScimService
|
||||
scimSchedulerService *service.ScimSchedulerService
|
||||
userService *service.UserService
|
||||
customClaimService *service.CustomClaimService
|
||||
oidcService *service.OidcService
|
||||
userGroupService *service.UserGroupService
|
||||
ldapService *service.LdapService
|
||||
apiKeyService *service.ApiKeyService
|
||||
versionService *service.VersionService
|
||||
fileStorage storage.FileStorage
|
||||
appLockService *service.AppLockService
|
||||
}
|
||||
|
||||
// Initializes all services
|
||||
@@ -70,6 +72,11 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
|
||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage)
|
||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||
svc.scimService = service.NewScimService(db, httpClient)
|
||||
svc.scimSchedulerService, err = service.NewScimSchedulerService(ctx, svc.scimService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SCIM scheduler service: %w", err)
|
||||
}
|
||||
|
||||
svc.versionService = service.NewVersionService(httpClient)
|
||||
|
||||
|
||||
@@ -63,6 +63,8 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
|
||||
|
||||
group.GET("/oidc/users/me/clients", authMiddleware.WithAdminNotRequired().Add(), oc.listOwnAccessibleClientsHandler)
|
||||
|
||||
group.GET("/oidc/clients/:id/scim-service-provider", authMiddleware.Add(), oc.getClientScimServiceProviderHandler)
|
||||
|
||||
}
|
||||
|
||||
type OidcController struct {
|
||||
@@ -845,3 +847,29 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, preview)
|
||||
}
|
||||
|
||||
// getClientScimServiceProviderHandler godoc
|
||||
// @Summary Get SCIM service provider
|
||||
// @Description Get the SCIM service provider configuration for an OIDC client
|
||||
// @Tags OIDC
|
||||
// @Produce json
|
||||
// @Param id path string true "Client ID"
|
||||
// @Success 200 {object} dto.ScimServiceProviderDTO "SCIM service provider configuration"
|
||||
// @Router /api/oidc/clients/{id}/scim-service-provider [get]
|
||||
func (oc *OidcController) getClientScimServiceProviderHandler(c *gin.Context) {
|
||||
clientID := c.Param("id")
|
||||
|
||||
provider, err := oc.oidcService.GetClientScimServiceProvider(c.Request.Context(), clientID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var providerDto dto.ScimServiceProviderDTO
|
||||
if err := dto.MapStruct(provider, &providerDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, providerDto)
|
||||
}
|
||||
|
||||
122
backend/internal/controller/scim_controller.go
Normal file
122
backend/internal/controller/scim_controller.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
)
|
||||
|
||||
func NewScimController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, scimService *service.ScimService) {
|
||||
ugc := ScimController{
|
||||
scimService: scimService,
|
||||
}
|
||||
|
||||
group.POST("/scim/service-provider", authMiddleware.Add(), ugc.createServiceProviderHandler)
|
||||
group.POST("/scim/service-provider/:id/sync", authMiddleware.Add(), ugc.syncServiceProviderHandler)
|
||||
group.PUT("/scim/service-provider/:id", authMiddleware.Add(), ugc.updateServiceProviderHandler)
|
||||
group.DELETE("/scim/service-provider/:id", authMiddleware.Add(), ugc.deleteServiceProviderHandler)
|
||||
}
|
||||
|
||||
type ScimController struct {
|
||||
scimService *service.ScimService
|
||||
}
|
||||
|
||||
// syncServiceProviderHandler godoc
|
||||
// @Summary Sync SCIM service provider
|
||||
// @Description Trigger synchronization for a SCIM service provider
|
||||
// @Tags SCIM
|
||||
// @Param id path string true "Service Provider ID"
|
||||
// @Success 200 "OK"
|
||||
// @Router /api/scim/service-provider/{id}/sync [post]
|
||||
func (c *ScimController) syncServiceProviderHandler(ctx *gin.Context) {
|
||||
err := c.scimService.SyncServiceProvider(ctx.Request.Context(), ctx.Param("id"))
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// createServiceProviderHandler godoc
|
||||
// @Summary Create SCIM service provider
|
||||
// @Description Create a new SCIM service provider
|
||||
// @Tags SCIM
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param serviceProvider body dto.ScimServiceProviderCreateDTO true "SCIM service provider information"
|
||||
// @Success 201 {object} dto.ScimServiceProviderDTO "Created SCIM service provider"
|
||||
// @Router /api/scim/service-provider [post]
|
||||
func (c *ScimController) createServiceProviderHandler(ctx *gin.Context) {
|
||||
var input dto.ScimServiceProviderCreateDTO
|
||||
if err := ctx.ShouldBindJSON(&input); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := c.scimService.CreateServiceProvider(ctx.Request.Context(), &input)
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var providerDTO dto.ScimServiceProviderDTO
|
||||
if err := dto.MapStruct(provider, &providerDTO); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusCreated, providerDTO)
|
||||
}
|
||||
|
||||
// updateServiceProviderHandler godoc
|
||||
// @Summary Update SCIM service provider
|
||||
// @Description Update an existing SCIM service provider
|
||||
// @Tags SCIM
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Service Provider ID"
|
||||
// @Param serviceProvider body dto.ScimServiceProviderCreateDTO true "SCIM service provider information"
|
||||
// @Success 200 {object} dto.ScimServiceProviderDTO "Updated SCIM service provider"
|
||||
// @Router /api/scim/service-provider/{id} [put]
|
||||
func (c *ScimController) updateServiceProviderHandler(ctx *gin.Context) {
|
||||
var input dto.ScimServiceProviderCreateDTO
|
||||
if err := ctx.ShouldBindJSON(&input); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := c.scimService.UpdateServiceProvider(ctx.Request.Context(), ctx.Param("id"), &input)
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var providerDTO dto.ScimServiceProviderDTO
|
||||
if err := dto.MapStruct(provider, &providerDTO); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, providerDTO)
|
||||
}
|
||||
|
||||
// deleteServiceProviderHandler godoc
|
||||
// @Summary Delete SCIM service provider
|
||||
// @Description Delete a SCIM service provider by ID
|
||||
// @Tags SCIM
|
||||
// @Param id path string true "Service Provider ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Router /api/scim/service-provider/{id} [delete]
|
||||
func (c *ScimController) deleteServiceProviderHandler(ctx *gin.Context) {
|
||||
err := c.scimService.DeleteServiceProvider(ctx.Request.Context(), ctx.Param("id"))
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
96
backend/internal/dto/scim_dto.go
Normal file
96
backend/internal/dto/scim_dto.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type ScimServiceProviderDTO struct {
|
||||
ID string `json:"id"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Token string `json:"token"`
|
||||
LastSyncedAt *datatype.DateTime `json:"lastSyncedAt"`
|
||||
OidcClient OidcClientMetaDataDto `json:"oidcClient"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
}
|
||||
|
||||
type ScimServiceProviderCreateDTO struct {
|
||||
Endpoint string `json:"endpoint" binding:"required,url"`
|
||||
Token string `json:"token"`
|
||||
OidcClientID string `json:"oidcClientId" binding:"required"`
|
||||
}
|
||||
|
||||
type ScimUser struct {
|
||||
ScimResourceData
|
||||
UserName string `json:"userName"`
|
||||
Name *ScimName `json:"name,omitempty"`
|
||||
Display string `json:"displayName,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
Emails []ScimEmail `json:"emails,omitempty"`
|
||||
}
|
||||
|
||||
type ScimName struct {
|
||||
GivenName string `json:"givenName,omitempty"`
|
||||
FamilyName string `json:"familyName,omitempty"`
|
||||
}
|
||||
|
||||
type ScimEmail struct {
|
||||
Value string `json:"value"`
|
||||
Primary bool `json:"primary,omitempty"`
|
||||
}
|
||||
|
||||
type ScimGroup struct {
|
||||
ScimResourceData
|
||||
Display string `json:"displayName"`
|
||||
Members []ScimGroupMember `json:"members,omitempty"`
|
||||
}
|
||||
|
||||
type ScimGroupMember struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type ScimListResponse[T any] struct {
|
||||
Resources []T `json:"Resources"`
|
||||
TotalResults int `json:"totalResults"`
|
||||
StartIndex int `json:"startIndex"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
}
|
||||
|
||||
type ScimResourceData struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
ExternalID string `json:"externalId,omitempty"`
|
||||
Schemas []string `json:"schemas"`
|
||||
Meta ScimResourceMeta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
type ScimResourceMeta struct {
|
||||
Location string `json:"location,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
Created time.Time `json:"created,omitempty"`
|
||||
LastModified time.Time `json:"lastModified,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
func (r ScimResourceData) GetID() string {
|
||||
return r.ID
|
||||
}
|
||||
|
||||
func (r ScimResourceData) GetExternalID() string {
|
||||
return r.ExternalID
|
||||
}
|
||||
|
||||
func (r ScimResourceData) GetSchemas() []string {
|
||||
return r.Schemas
|
||||
}
|
||||
|
||||
func (r ScimResourceData) GetMeta() ScimResourceMeta {
|
||||
return r.Meta
|
||||
}
|
||||
|
||||
type ScimResource interface {
|
||||
GetID() string
|
||||
GetExternalID() string
|
||||
GetSchemas() []string
|
||||
GetMeta() ScimResourceMeta
|
||||
}
|
||||
14
backend/internal/model/scim.go
Normal file
14
backend/internal/model/scim.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package model
|
||||
|
||||
import datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
|
||||
type ScimServiceProvider struct {
|
||||
Base
|
||||
|
||||
Endpoint string `sortable:"true"`
|
||||
Token datatype.EncryptedString
|
||||
LastSyncedAt *datatype.DateTime `sortable:"true"`
|
||||
|
||||
OidcClientID string
|
||||
OidcClient OidcClient `gorm:"foreignKey:OidcClientID;references:ID;"`
|
||||
}
|
||||
91
backend/internal/model/types/encrypted_string.go
Normal file
91
backend/internal/model/types/encrypted_string.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package datatype
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
const encryptedStringAAD = "encrypted_string"
|
||||
|
||||
var encStringKey []byte
|
||||
|
||||
// EncryptedString stores plaintext in memory and persists encrypted data in the database.
|
||||
type EncryptedString string //nolint:recvcheck
|
||||
|
||||
func (e *EncryptedString) Scan(value any) error {
|
||||
if value == nil {
|
||||
*e = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
raw = v
|
||||
case []byte:
|
||||
raw = string(v)
|
||||
default:
|
||||
return fmt.Errorf("unexpected type for EncryptedString: %T", value)
|
||||
}
|
||||
|
||||
if raw == "" {
|
||||
*e = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
encBytes, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode encrypted string: %w", err)
|
||||
}
|
||||
|
||||
decBytes, err := cryptoutils.Decrypt(encStringKey, encBytes, []byte(encryptedStringAAD))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt encrypted string: %w", err)
|
||||
}
|
||||
|
||||
*e = EncryptedString(decBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e EncryptedString) Value() (driver.Value, error) {
|
||||
if e == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
encBytes, err := cryptoutils.Encrypt(encStringKey, []byte(e), []byte(encryptedStringAAD))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt string: %w", err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(encBytes), nil
|
||||
}
|
||||
|
||||
func (e EncryptedString) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
func deriveEncryptedStringKey(master []byte) ([]byte, error) {
|
||||
const info = "pocketid/encrypted_string"
|
||||
r := hkdf.New(sha256.New, master, nil, []byte(info))
|
||||
|
||||
key := make([]byte, 32)
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
key, err := deriveEncryptedStringKey(common.EnvConfig.EncryptionKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to derive encrypted string key: %v", err))
|
||||
}
|
||||
encStringKey = key
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
@@ -22,6 +23,7 @@ type User struct {
|
||||
Locale *string
|
||||
LdapID *string
|
||||
Disabled bool `sortable:"true" filterable:"true"`
|
||||
UpdatedAt *datatype.DateTime
|
||||
|
||||
CustomClaims []CustomClaim
|
||||
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
||||
@@ -85,6 +87,13 @@ func (u User) Initials() string {
|
||||
return strings.ToUpper(first + last)
|
||||
}
|
||||
|
||||
func (u User) LastModified() time.Time {
|
||||
if u.UpdatedAt != nil {
|
||||
return u.UpdatedAt.ToTime()
|
||||
}
|
||||
return u.CreatedAt.ToTime()
|
||||
}
|
||||
|
||||
type OneTimeAccessToken struct {
|
||||
Base
|
||||
Token string
|
||||
|
||||
@@ -1,11 +1,25 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type UserGroup struct {
|
||||
Base
|
||||
FriendlyName string `sortable:"true"`
|
||||
Name string `sortable:"true"`
|
||||
LdapID *string
|
||||
UpdatedAt *datatype.DateTime
|
||||
Users []User `gorm:"many2many:user_groups_users;"`
|
||||
CustomClaims []CustomClaim
|
||||
AllowedOidcClients []OidcClient `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||
}
|
||||
|
||||
func (ug UserGroup) LastModified() time.Time {
|
||||
if ug.UpdatedAt != nil {
|
||||
return ug.UpdatedAt.ToTime()
|
||||
}
|
||||
return ug.CreatedAt.ToTime()
|
||||
}
|
||||
|
||||
@@ -98,6 +98,17 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
DisplayName: "Craig Federighi",
|
||||
IsAdmin: false,
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "d9256384-98ad-49a7-bc58-99ad0b4dc23c",
|
||||
},
|
||||
Username: "eddy",
|
||||
Email: utils.Ptr("eddy.cue@test.com"),
|
||||
FirstName: "Eddy",
|
||||
LastName: "Cue",
|
||||
DisplayName: "Eddy Cue",
|
||||
IsAdmin: false,
|
||||
},
|
||||
}
|
||||
for _, user := range users {
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
@@ -209,6 +220,20 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "c46d2090-37a0-4f2b-8748-6aa53b0c1afa",
|
||||
},
|
||||
Name: "SCIM Client",
|
||||
Secret: "$2a$10$h4wfa8gI7zavDAxwzSq1sOwYU4e8DwK1XZ8ZweNnY5KzlJ3Iz.qdK", // nQbiuMRG7FpdK2EnDd5MBivWQeKFXohn
|
||||
CallbackURLs: model.UrlList{"http://scimclient/auth/callback"},
|
||||
CreatedByID: utils.Ptr(users[0].ID),
|
||||
IsGroupRestricted: true,
|
||||
AllowedUserGroups: []model.UserGroup{
|
||||
userGroups[0],
|
||||
userGroups[1],
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, client := range oidcClients {
|
||||
if err := tx.Create(&client).Error; err != nil {
|
||||
|
||||
@@ -168,7 +168,7 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, client) {
|
||||
if !IsUserGroupAllowedToAuthorize(user, client) {
|
||||
return "", "", &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
@@ -224,7 +224,7 @@ func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID,
|
||||
}
|
||||
|
||||
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
|
||||
func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
|
||||
func IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
|
||||
if !client.IsGroupRestricted {
|
||||
return true
|
||||
}
|
||||
@@ -1325,7 +1325,7 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
||||
return fmt.Errorf("error finding user groups: %w", err)
|
||||
}
|
||||
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, deviceAuth.Client) {
|
||||
if !IsUserGroupAllowedToAuthorize(user, deviceAuth.Client) {
|
||||
return &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
@@ -1829,7 +1829,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, client) {
|
||||
if !IsUserGroupAllowedToAuthorize(user, client) {
|
||||
return nil, &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
@@ -1956,7 +1956,7 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
|
||||
return false, err
|
||||
}
|
||||
|
||||
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
||||
return IsUserGroupAllowedToAuthorize(user, client), nil
|
||||
}
|
||||
|
||||
var errLogoTooLarge = errors.New("logo is too large")
|
||||
@@ -2116,3 +2116,16 @@ func (s *OidcService) updateClientLogoType(ctx context.Context, clientID string,
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientScimServiceProvider(ctx context.Context, clientID string) (model.ScimServiceProvider, error) {
|
||||
var provider model.ScimServiceProvider
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
First(&provider, "oidc_client_id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.ScimServiceProvider{}, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
136
backend/internal/service/scim_scheduler_service.go
Normal file
136
backend/internal/service/scim_scheduler_service.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ScimSchedulerService schedules and triggers periodic synchronization
|
||||
// of SCIM service providers. Each provider is tracked independently,
|
||||
// and sync operations are run at or after their scheduled time.
|
||||
type ScimSchedulerService struct {
|
||||
scimService *ScimService
|
||||
providerSyncTime map[string]time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewScimSchedulerService(ctx context.Context, scimService *ScimService) (*ScimSchedulerService, error) {
|
||||
s := &ScimSchedulerService{
|
||||
scimService: scimService,
|
||||
providerSyncTime: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
err := s.start(ctx)
|
||||
return s, err
|
||||
}
|
||||
|
||||
// ScheduleSync forces the given provider to be synced soon by
|
||||
// moving its next scheduled time to 5 minutes from now.
|
||||
func (s *ScimSchedulerService) ScheduleSync(providerID string) {
|
||||
s.setSyncTime(providerID, 5*time.Minute)
|
||||
}
|
||||
|
||||
// start initializes the scheduler and begins the synchronization loop.
|
||||
// Syncs happen every hour by default, but ScheduleSync can be called to schedule a sync sooner.
|
||||
func (s *ScimSchedulerService) start(ctx context.Context) error {
|
||||
if err := s.refreshProviders(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
const (
|
||||
syncCheckInterval = 5 * time.Second
|
||||
providerRefreshDelay = time.Minute
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(syncCheckInterval)
|
||||
defer ticker.Stop()
|
||||
lastProviderRefresh := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// Runs every 5 seconds to check if any provider is due for sync
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
if now.Sub(lastProviderRefresh) >= providerRefreshDelay {
|
||||
err := s.refreshProviders(ctx)
|
||||
if err != nil {
|
||||
slog.Error("Error refreshing SCIM service providers",
|
||||
slog.Any("error", err),
|
||||
)
|
||||
} else {
|
||||
lastProviderRefresh = now
|
||||
}
|
||||
}
|
||||
|
||||
var due []string
|
||||
s.mu.RLock()
|
||||
for providerID, syncTime := range s.providerSyncTime {
|
||||
if !syncTime.After(now) {
|
||||
due = append(due, providerID)
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
s.syncProviders(ctx, due)
|
||||
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ScimSchedulerService) refreshProviders(ctx context.Context) error {
|
||||
providers, err := s.scimService.ListServiceProviders(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inAHour := time.Now().Add(time.Hour)
|
||||
|
||||
s.mu.Lock()
|
||||
for _, provider := range providers {
|
||||
if _, exists := s.providerSyncTime[provider.ID]; !exists {
|
||||
s.providerSyncTime[provider.ID] = inAHour
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ScimSchedulerService) syncProviders(ctx context.Context, providerIDs []string) {
|
||||
for _, providerID := range providerIDs {
|
||||
err := s.scimService.SyncServiceProvider(ctx, providerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Remove the provider from the schedule if it no longer exists
|
||||
s.mu.Lock()
|
||||
delete(s.providerSyncTime, providerID)
|
||||
s.mu.Unlock()
|
||||
} else {
|
||||
slog.Error("Error syncing SCIM client",
|
||||
slog.String("provider_id", providerID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// A successful sync schedules the next sync in an hour
|
||||
s.setSyncTime(providerID, time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScimSchedulerService) setSyncTime(providerID string, t time.Duration) {
|
||||
s.mu.Lock()
|
||||
s.providerSyncTime[providerID] = time.Now().Add(t)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
774
backend/internal/service/scim_service.go
Normal file
774
backend/internal/service/scim_service.go
Normal file
@@ -0,0 +1,774 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
scimUserSchema = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
scimGroupSchema = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
scimContentType = "application/scim+json"
|
||||
)
|
||||
|
||||
const scimErrorBodyLimit = 4096
|
||||
|
||||
type scimSyncAction int
|
||||
|
||||
const (
|
||||
scimActionNone scimSyncAction = iota
|
||||
scimActionCreated
|
||||
scimActionUpdated
|
||||
scimActionDeleted
|
||||
)
|
||||
|
||||
type scimSyncStats struct {
|
||||
Created int
|
||||
Updated int
|
||||
Deleted int
|
||||
}
|
||||
|
||||
// ScimService handles SCIM provisioning to external service providers.
|
||||
type ScimService struct {
|
||||
db *gorm.DB
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewScimService(db *gorm.DB, httpClient *http.Client) *ScimService {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{Timeout: 20 * time.Second}
|
||||
}
|
||||
|
||||
return &ScimService{db: db, httpClient: httpClient}
|
||||
}
|
||||
|
||||
func (s *ScimService) GetServiceProvider(
|
||||
ctx context.Context,
|
||||
serviceProviderID string,
|
||||
) (model.ScimServiceProvider, error) {
|
||||
var provider model.ScimServiceProvider
|
||||
err := s.db.WithContext(ctx).
|
||||
Preload("OidcClient").
|
||||
Preload("OidcClient.AllowedUserGroups").
|
||||
First(&provider, "id = ?", serviceProviderID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.ScimServiceProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) ListServiceProviders(ctx context.Context) ([]model.ScimServiceProvider, error) {
|
||||
var providers []model.ScimServiceProvider
|
||||
err := s.db.WithContext(ctx).
|
||||
Preload("OidcClient").
|
||||
Find(&providers).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) CreateServiceProvider(
|
||||
ctx context.Context,
|
||||
input *dto.ScimServiceProviderCreateDTO) (model.ScimServiceProvider, error) {
|
||||
provider := model.ScimServiceProvider{
|
||||
Endpoint: input.Endpoint,
|
||||
Token: datatype.EncryptedString(input.Token),
|
||||
OidcClientID: input.OidcClientID,
|
||||
}
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(&provider).Error; err != nil {
|
||||
return model.ScimServiceProvider{}, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) UpdateServiceProvider(ctx context.Context,
|
||||
serviceProviderID string,
|
||||
input *dto.ScimServiceProviderCreateDTO,
|
||||
) (model.ScimServiceProvider, error) {
|
||||
var provider model.ScimServiceProvider
|
||||
err := s.db.WithContext(ctx).
|
||||
First(&provider, "id = ?", serviceProviderID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.ScimServiceProvider{}, err
|
||||
}
|
||||
|
||||
provider.Endpoint = input.Endpoint
|
||||
provider.Token = datatype.EncryptedString(input.Token)
|
||||
provider.OidcClientID = input.OidcClientID
|
||||
|
||||
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
|
||||
return model.ScimServiceProvider{}, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) DeleteServiceProvider(ctx context.Context, serviceProviderID string) error {
|
||||
return s.db.WithContext(ctx).
|
||||
Delete(&model.ScimServiceProvider{}, "id = ?", serviceProviderID).
|
||||
Error
|
||||
}
|
||||
|
||||
func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID string) error {
|
||||
start := time.Now()
|
||||
provider, err := s.GetServiceProvider(ctx, serviceProviderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "Syncing SCIM service provider",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.String("oidc_client_id", provider.OidcClientID),
|
||||
)
|
||||
|
||||
allowedGroupIDs := groupIDs(provider.OidcClient.AllowedUserGroups)
|
||||
|
||||
// Load users and groups that should be synced to the SCIM provider
|
||||
groups, err := s.groupsForClient(ctx, provider.OidcClient, allowedGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
users, err := s.usersForClient(ctx, provider.OidcClient, allowedGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load users and groups that already exist in the SCIM provider
|
||||
userResources, err := listScimResources[dto.ScimUser](s, ctx, provider, "/Users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groupResources, err := listScimResources[dto.ScimGroup](s, ctx, provider, "/Groups")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errs []error
|
||||
var userStats scimSyncStats
|
||||
var groupStats scimSyncStats
|
||||
|
||||
// Sync users first, so that groups can reference them
|
||||
if stats, err := s.syncUsers(ctx, provider, users, &userResources); err != nil {
|
||||
errs = append(errs, err)
|
||||
userStats = stats
|
||||
} else {
|
||||
userStats = stats
|
||||
}
|
||||
|
||||
stats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
groupStats = stats
|
||||
} else {
|
||||
groupStats = stats
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
slog.WarnContext(ctx, "SCIM sync completed with errors",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.Int("error_count", len(errs)),
|
||||
slog.Int("users_created", userStats.Created),
|
||||
slog.Int("users_updated", userStats.Updated),
|
||||
slog.Int("users_deleted", userStats.Deleted),
|
||||
slog.Int("groups_created", groupStats.Created),
|
||||
slog.Int("groups_updated", groupStats.Updated),
|
||||
slog.Int("groups_deleted", groupStats.Deleted),
|
||||
slog.Duration("duration", time.Since(start)),
|
||||
)
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
provider.LastSyncedAt = utils.Ptr(datatype.DateTime(time.Now()))
|
||||
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "SCIM sync completed",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.Int("users_created", userStats.Created),
|
||||
slog.Int("users_updated", userStats.Updated),
|
||||
slog.Int("users_deleted", userStats.Deleted),
|
||||
slog.Int("groups_created", groupStats.Created),
|
||||
slog.Int("groups_updated", groupStats.Updated),
|
||||
slog.Int("groups_deleted", groupStats.Deleted),
|
||||
slog.Duration("duration", time.Since(start)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ScimService) syncUsers(
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
users []model.User,
|
||||
resourceList *dto.ScimListResponse[dto.ScimUser],
|
||||
) (stats scimSyncStats, err error) {
|
||||
var errs []error
|
||||
|
||||
// Update or create users
|
||||
for _, u := range users {
|
||||
existing := getResourceByExternalID[dto.ScimUser](u.ID, resourceList.Resources)
|
||||
|
||||
action, created, err := s.syncUser(ctx, provider, u, existing)
|
||||
if created != nil && existing == nil {
|
||||
resourceList.Resources = append(resourceList.Resources, *created)
|
||||
}
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Update stats based on action taken by syncUser
|
||||
switch action {
|
||||
case scimActionCreated:
|
||||
stats.Created++
|
||||
case scimActionUpdated:
|
||||
stats.Updated++
|
||||
case scimActionDeleted:
|
||||
stats.Deleted++
|
||||
case scimActionNone:
|
||||
}
|
||||
}
|
||||
|
||||
// Delete users that are present in SCIM provider but not locally.
|
||||
userSet := make(map[string]struct{})
|
||||
for _, u := range users {
|
||||
userSet[u.ID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, r := range resourceList.Resources {
|
||||
if _, ok := userSet[r.ExternalID]; !ok {
|
||||
if err := s.deleteScimResource(ctx, provider, "/Users/"+url.PathEscape(r.ID)); err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
stats.Deleted++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (s *ScimService) syncGroups(
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
groups []model.UserGroup,
|
||||
remoteGroups []dto.ScimGroup,
|
||||
userResources []dto.ScimUser,
|
||||
) (stats scimSyncStats, err error) {
|
||||
var errs []error
|
||||
|
||||
// Update or create groups
|
||||
for _, g := range groups {
|
||||
existing := getResourceByExternalID[dto.ScimGroup](g.ID, remoteGroups)
|
||||
|
||||
action, err := s.syncGroup(ctx, provider, g, existing, userResources)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Update stats based on action taken by syncGroup
|
||||
switch action {
|
||||
case scimActionCreated:
|
||||
stats.Created++
|
||||
case scimActionUpdated:
|
||||
stats.Updated++
|
||||
case scimActionDeleted:
|
||||
stats.Deleted++
|
||||
case scimActionNone:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Delete groups that are present in SCIM provider but not locally
|
||||
groupSet := make(map[string]struct{})
|
||||
for _, g := range groups {
|
||||
groupSet[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, r := range remoteGroups {
|
||||
if _, ok := groupSet[r.ExternalID]; !ok {
|
||||
if err := s.deleteScimResource(ctx, provider, "/Groups/"+url.PathEscape(r.GetID())); err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
stats.Deleted++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (s *ScimService) syncUser(ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
user model.User,
|
||||
userResource *dto.ScimUser,
|
||||
) (scimSyncAction, *dto.ScimUser, error) {
|
||||
// If user is not allowed for the client, delete it from SCIM provider
|
||||
if userResource != nil && !IsUserGroupAllowedToAuthorize(user, provider.OidcClient) {
|
||||
return scimActionDeleted, nil, s.deleteScimResource(ctx, provider, fmt.Sprintf("/Users/%s", url.PathEscape(userResource.ID)))
|
||||
}
|
||||
|
||||
payload := dto.ScimUser{
|
||||
ScimResourceData: dto.ScimResourceData{
|
||||
Schemas: []string{scimUserSchema},
|
||||
ExternalID: user.ID,
|
||||
},
|
||||
UserName: user.Username,
|
||||
Name: &dto.ScimName{
|
||||
GivenName: user.FirstName,
|
||||
FamilyName: user.LastName,
|
||||
},
|
||||
Display: user.DisplayName,
|
||||
Active: !user.Disabled,
|
||||
}
|
||||
|
||||
if user.Email != nil {
|
||||
payload.Emails = []dto.ScimEmail{{
|
||||
Value: *user.Email,
|
||||
Primary: true,
|
||||
}}
|
||||
}
|
||||
|
||||
// If the user exists on the SCIM provider, and it has been modified, update it
|
||||
if userResource != nil {
|
||||
if user.LastModified().Before(userResource.GetMeta().LastModified) {
|
||||
return scimActionNone, nil, nil
|
||||
}
|
||||
path := fmt.Sprintf("/Users/%s", url.PathEscape(userResource.GetID()))
|
||||
userResource, err := updateScimResource(s, ctx, provider, path, payload)
|
||||
if err != nil {
|
||||
return scimActionNone, nil, err
|
||||
}
|
||||
return scimActionUpdated, userResource, nil
|
||||
}
|
||||
|
||||
// Otherwise, create a new SCIM user
|
||||
userResource, err := createScimResource(s, ctx, provider, "/Users", payload)
|
||||
if err != nil {
|
||||
return scimActionNone, nil, err
|
||||
}
|
||||
|
||||
return scimActionCreated, userResource, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) syncGroup(
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
group model.UserGroup,
|
||||
groupResource *dto.ScimGroup,
|
||||
userResources []dto.ScimUser,
|
||||
) (scimSyncAction, error) {
|
||||
// If group is not allowed for the client, delete it from SCIM provider
|
||||
if groupResource != nil && !groupAllowedForClient(group.ID, provider.OidcClient) {
|
||||
return scimActionDeleted, s.deleteScimResource(ctx, provider, fmt.Sprintf("/Groups/%s", url.PathEscape(groupResource.GetID())))
|
||||
}
|
||||
|
||||
// Prepare group members
|
||||
members := make([]dto.ScimGroupMember, len(group.Users))
|
||||
for i, user := range group.Users {
|
||||
userResource := getResourceByExternalID[dto.ScimUser](user.ID, userResources)
|
||||
if userResource == nil {
|
||||
// Groups depend on user IDs already being provisioned
|
||||
return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID)
|
||||
}
|
||||
|
||||
members[i] = dto.ScimGroupMember{
|
||||
Value: userResource.GetID(),
|
||||
}
|
||||
}
|
||||
|
||||
groupPayload := dto.ScimGroup{
|
||||
ScimResourceData: dto.ScimResourceData{
|
||||
Schemas: []string{scimGroupSchema},
|
||||
ExternalID: group.ID,
|
||||
},
|
||||
Display: group.FriendlyName,
|
||||
Members: members,
|
||||
}
|
||||
|
||||
// If the group exists on the SCIM provider, and it has been modified, update it
|
||||
if groupResource != nil {
|
||||
if group.LastModified().Before(groupResource.GetMeta().LastModified) {
|
||||
return scimActionNone, nil
|
||||
}
|
||||
path := fmt.Sprintf("/Groups/%s", url.PathEscape(groupResource.GetID()))
|
||||
_, err := updateScimResource(s, ctx, provider, path, groupPayload)
|
||||
if err != nil {
|
||||
return scimActionNone, err
|
||||
}
|
||||
return scimActionUpdated, nil
|
||||
}
|
||||
|
||||
// Otherwise, create a new SCIM group
|
||||
_, err := createScimResource(s, ctx, provider, "/Groups", groupPayload)
|
||||
if err != nil {
|
||||
return scimActionNone, err
|
||||
}
|
||||
|
||||
return scimActionCreated, nil
|
||||
}
|
||||
|
||||
func groupAllowedForClient(groupID string, client model.OidcClient) bool {
|
||||
if !client.IsGroupRestricted {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, allowedGroup := range client.AllowedUserGroups {
|
||||
if allowedGroup.ID == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func groupIDs(groups []model.UserGroup) []string {
|
||||
ids := make([]string, len(groups))
|
||||
for i, g := range groups {
|
||||
ids[i] = g.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (s *ScimService) groupsForClient(
|
||||
ctx context.Context,
|
||||
client model.OidcClient,
|
||||
allowedGroupIDs []string,
|
||||
) ([]model.UserGroup, error) {
|
||||
var groups []model.UserGroup
|
||||
|
||||
query := s.db.WithContext(ctx).Preload("Users").Model(&model.UserGroup{})
|
||||
if client.IsGroupRestricted {
|
||||
if len(allowedGroupIDs) == 0 {
|
||||
return groups, nil
|
||||
}
|
||||
query = query.Where("id IN ?", allowedGroupIDs)
|
||||
}
|
||||
|
||||
if err := query.Find(&groups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) usersForClient(
|
||||
ctx context.Context,
|
||||
client model.OidcClient,
|
||||
allowedGroupIDs []string,
|
||||
) ([]model.User, error) {
|
||||
var users []model.User
|
||||
|
||||
query := s.db.WithContext(ctx).Model(&model.User{})
|
||||
if client.IsGroupRestricted {
|
||||
if len(allowedGroupIDs) == 0 {
|
||||
return users, nil
|
||||
}
|
||||
query = query.
|
||||
Joins("JOIN user_groups_users ON users.id = user_groups_users.user_id").
|
||||
Where("user_groups_users.user_group_id IN ?", allowedGroupIDs).
|
||||
Select("users.*").
|
||||
Distinct()
|
||||
}
|
||||
|
||||
query = query.Preload("UserGroups")
|
||||
|
||||
if err := query.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func getResourceByExternalID[T dto.ScimResource](externalID string, resource []T) *T {
|
||||
for i := range resource {
|
||||
if resource[i].GetExternalID() == externalID {
|
||||
return &resource[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listScimResources[T any](
|
||||
s *ScimService,
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
path string,
|
||||
) (result dto.ScimListResponse[T], err error) {
|
||||
startIndex := 1
|
||||
count := 1000
|
||||
|
||||
for {
|
||||
// Use SCIM pagination to avoid missing resources on large providers
|
||||
queryParams := map[string]string{
|
||||
"startIndex": strconv.Itoa(startIndex),
|
||||
"count": strconv.Itoa(count),
|
||||
}
|
||||
|
||||
resp, err := s.scimRequest(ctx, provider, http.MethodGet, path, nil, queryParams)
|
||||
if err != nil {
|
||||
return dto.ScimListResponse[T]{}, err
|
||||
}
|
||||
|
||||
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK); err != nil {
|
||||
return dto.ScimListResponse[T]{}, err
|
||||
}
|
||||
|
||||
var page dto.ScimListResponse[T]
|
||||
if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
|
||||
return dto.ScimListResponse[T]{}, fmt.Errorf("failed to decode SCIM list response: %w", err)
|
||||
}
|
||||
|
||||
resp.Body.Close()
|
||||
|
||||
// Initialize metadata only once
|
||||
if result.TotalResults == 0 {
|
||||
result.TotalResults = page.TotalResults
|
||||
}
|
||||
|
||||
result.Resources = append(result.Resources, page.Resources...)
|
||||
|
||||
// If we've fetched everything, stop
|
||||
if len(result.Resources) >= page.TotalResults || len(page.Resources) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
startIndex += page.ItemsPerPage
|
||||
}
|
||||
|
||||
result.ItemsPerPage = len(result.Resources)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func createScimResource[T dto.ScimResource](
|
||||
s *ScimService,
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
path string, payload T) (*T, error) {
|
||||
resp, err := s.scimRequest(ctx, provider, http.MethodPost, path, payload, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusCreated); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resource T
|
||||
if err := json.NewDecoder(resp.Body).Decode(&resource); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode SCIM create response: %w", err)
|
||||
}
|
||||
|
||||
return &resource, nil
|
||||
}
|
||||
|
||||
func updateScimResource[T dto.ScimResource](
|
||||
s *ScimService,
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
path string,
|
||||
payload T,
|
||||
) (*T, error) {
|
||||
resp, err := s.scimRequest(ctx, provider, http.MethodPut, path, payload, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusCreated); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resource T
|
||||
if err := json.NewDecoder(resp.Body).Decode(&resource); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode SCIM update response: %w", err)
|
||||
}
|
||||
|
||||
return &resource, nil
|
||||
}
|
||||
|
||||
func (s *ScimService) deleteScimResource(ctx context.Context, provider model.ScimServiceProvider, path string) error {
|
||||
resp, err := s.scimRequest(ctx, provider, http.MethodDelete, path, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *ScimService) scimRequest(
|
||||
ctx context.Context,
|
||||
provider model.ScimServiceProvider,
|
||||
method,
|
||||
path string,
|
||||
payload any,
|
||||
queryParams map[string]string,
|
||||
) (*http.Response, error) {
|
||||
urlString, err := scimURL(provider.Endpoint, path, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bodyBytes []byte
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode SCIM payload: %w", err)
|
||||
}
|
||||
bodyBytes = encoded
|
||||
}
|
||||
|
||||
retryAttempts := 3
|
||||
for attempt := 1; attempt <= retryAttempts; attempt++ {
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, urlString, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", scimContentType)
|
||||
if payload != nil {
|
||||
req.Header.Set("Content-Type", scimContentType)
|
||||
}
|
||||
token := string(provider.Token)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
slog.Debug("Sending SCIM request",
|
||||
slog.String("method", method),
|
||||
slog.String("url", urlString),
|
||||
slog.String("provider_id", provider.ID),
|
||||
)
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only retry on 429 to avoid masking other errors
|
||||
if resp.StatusCode != http.StatusTooManyRequests || attempt == retryAttempts {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
retryDelay := scimRetryDelay(resp.Header.Get("Retry-After"), attempt)
|
||||
slog.WarnContext(ctx, "SCIM provider rate-limited, retrying",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.String("method", method),
|
||||
slog.String("url", urlString),
|
||||
slog.Int("attempt", attempt),
|
||||
slog.Duration("retry_after", retryDelay),
|
||||
)
|
||||
|
||||
resp.Body.Close()
|
||||
if err := utils.SleepWithContext(ctx, retryDelay); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("scim request retry attempts exceeded")
|
||||
}
|
||||
|
||||
func scimRetryDelay(retryAfter string, attempt int) time.Duration {
|
||||
// Respect Retry-After when provided
|
||||
if retryAfter != "" {
|
||||
if seconds, err := strconv.Atoi(retryAfter); err == nil {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
if t, err := http.ParseTime(retryAfter); err == nil {
|
||||
if delay := time.Until(t); delay > 0 {
|
||||
return delay
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exponential backoff otherwise
|
||||
maxDelay := 10 * time.Second
|
||||
delay := 500 * time.Millisecond * (time.Duration(1) << (attempt - 1)) //nolint:gosec // attempt is bounded 1-3
|
||||
if delay > maxDelay {
|
||||
return maxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func scimURL(endpoint, p string, queryParams map[string]string) (string, error) {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid scim endpoint: %w", err)
|
||||
}
|
||||
|
||||
u.Path = path.Join(strings.TrimRight(u.Path, "/"), p)
|
||||
|
||||
q := u.Query()
|
||||
for key, value := range queryParams {
|
||||
q.Set(key, value)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func ensureScimStatus(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
provider model.ScimServiceProvider,
|
||||
allowedStatuses ...int) error {
|
||||
for _, status := range allowedStatuses {
|
||||
if resp.StatusCode == status {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
body := readScimErrorBody(resp.Body)
|
||||
|
||||
slog.ErrorContext(ctx, "SCIM request failed",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.String("method", resp.Request.Method),
|
||||
slog.String("url", resp.Request.URL.String()),
|
||||
slog.Int("status", resp.StatusCode),
|
||||
slog.String("response_body", body),
|
||||
)
|
||||
|
||||
return fmt.Errorf("scim request failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
func readScimErrorBody(body io.Reader) string {
|
||||
payload, err := io.ReadAll(io.LimitReader(body, scimErrorBodyLimit))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(payload))
|
||||
}
|
||||
@@ -3,7 +3,9 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
@@ -151,6 +153,7 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
|
||||
|
||||
group.Name = input.Name
|
||||
group.FriendlyName = input.FriendlyName
|
||||
group.UpdatedAt = utils.Ptr(datatype.DateTime(time.Now()))
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
@@ -214,6 +217,8 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u
|
||||
}
|
||||
|
||||
// Save the updated group
|
||||
group.UpdatedAt = utils.Ptr(datatype.DateTime(time.Now()))
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&group).
|
||||
|
||||
@@ -426,6 +426,8 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
}
|
||||
}
|
||||
|
||||
user.UpdatedAt = utils.Ptr(datatype.DateTime(time.Now()))
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&user).
|
||||
@@ -646,6 +648,16 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
// Update the UpdatedAt field for all affected groups
|
||||
now := time.Now()
|
||||
for _, group := range groups {
|
||||
group.UpdatedAt = utils.Ptr(datatype.DateTime(now))
|
||||
err = tx.WithContext(ctx).Save(&group).Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
|
||||
21
backend/internal/utils/sleep_util.go
Normal file
21
backend/internal/utils/sleep_util.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SleepWithContext(ctx context.Context, delay time.Duration) error {
|
||||
if delay <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user