1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 13:21:45 +00:00
Files
pocket-id/backend/internal/service/scim_service.go

817 lines
21 KiB
Go

package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/go-co-op/gocron/v2"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
const (
scimUserSchema = "urn:ietf:params:scim:schemas:core:2.0:User"
scimGroupSchema = "urn:ietf:params:scim:schemas:core:2.0:Group"
scimContentType = "application/scim+json"
)
const scimErrorBodyLimit = 4096
type scimSyncAction int
type Scheduler interface {
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error
RemoveJob(name string) error
}
const (
scimActionNone scimSyncAction = iota
scimActionCreated
scimActionUpdated
scimActionDeleted
)
type scimSyncStats struct {
Created int
Updated int
Deleted int
}
// ScimService handles SCIM provisioning to external service providers.
type ScimService struct {
db *gorm.DB
scheduler Scheduler
httpClient *http.Client
}
func NewScimService(db *gorm.DB, scheduler Scheduler, httpClient *http.Client) *ScimService {
if httpClient == nil {
httpClient = &http.Client{Timeout: 20 * time.Second}
}
return &ScimService{db: db, scheduler: scheduler, httpClient: httpClient}
}
func (s *ScimService) GetServiceProvider(
ctx context.Context,
serviceProviderID string,
) (model.ScimServiceProvider, error) {
var provider model.ScimServiceProvider
err := s.db.WithContext(ctx).
Preload("OidcClient").
Preload("OidcClient.AllowedUserGroups").
First(&provider, "id = ?", serviceProviderID).
Error
if err != nil {
return model.ScimServiceProvider{}, err
}
return provider, nil
}
func (s *ScimService) ListServiceProviders(ctx context.Context) ([]model.ScimServiceProvider, error) {
var providers []model.ScimServiceProvider
err := s.db.WithContext(ctx).
Preload("OidcClient").
Find(&providers).
Error
if err != nil {
return nil, err
}
return providers, nil
}
func (s *ScimService) CreateServiceProvider(
ctx context.Context,
input *dto.ScimServiceProviderCreateDTO) (model.ScimServiceProvider, error) {
provider := model.ScimServiceProvider{
Endpoint: input.Endpoint,
Token: datatype.EncryptedString(input.Token),
OidcClientID: input.OidcClientID,
}
if err := s.db.WithContext(ctx).Create(&provider).Error; err != nil {
return model.ScimServiceProvider{}, err
}
return provider, nil
}
func (s *ScimService) UpdateServiceProvider(ctx context.Context,
serviceProviderID string,
input *dto.ScimServiceProviderCreateDTO,
) (model.ScimServiceProvider, error) {
var provider model.ScimServiceProvider
err := s.db.WithContext(ctx).
First(&provider, "id = ?", serviceProviderID).
Error
if err != nil {
return model.ScimServiceProvider{}, err
}
provider.Endpoint = input.Endpoint
provider.Token = datatype.EncryptedString(input.Token)
provider.OidcClientID = input.OidcClientID
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
return model.ScimServiceProvider{}, err
}
return provider, nil
}
func (s *ScimService) DeleteServiceProvider(ctx context.Context, serviceProviderID string) error {
return s.db.WithContext(ctx).
Delete(&model.ScimServiceProvider{}, "id = ?", serviceProviderID).
Error
}
//nolint:contextcheck
func (s *ScimService) ScheduleSync() {
jobName := "ScheduledScimSync"
start := time.Now().Add(5 * time.Minute)
_ = s.scheduler.RemoveJob(jobName)
err := s.scheduler.RegisterJob(
context.Background(), jobName,
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, false)
if err != nil {
slog.Error("Failed to schedule SCIM sync", slog.Any("error", err))
}
}
func (s *ScimService) SyncAll(ctx context.Context) error {
providers, err := s.ListServiceProviders(ctx)
if err != nil {
return err
}
var errs []error
for _, provider := range providers {
if ctx.Err() != nil {
errs = append(errs, ctx.Err())
break
}
if err := s.SyncServiceProvider(ctx, provider.ID); err != nil {
errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err))
}
}
return errors.Join(errs...)
}
func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID string) error {
start := time.Now()
provider, err := s.GetServiceProvider(ctx, serviceProviderID)
if err != nil {
return err
}
slog.InfoContext(ctx, "Syncing SCIM service provider",
slog.String("provider_id", provider.ID),
slog.String("oidc_client_id", provider.OidcClientID),
)
allowedGroupIDs := groupIDs(provider.OidcClient.AllowedUserGroups)
// Load users and groups that should be synced to the SCIM provider
groups, err := s.groupsForClient(ctx, provider.OidcClient, allowedGroupIDs)
if err != nil {
return err
}
users, err := s.usersForClient(ctx, provider.OidcClient, allowedGroupIDs)
if err != nil {
return err
}
// Load users and groups that already exist in the SCIM provider
userResources, err := listScimResources[dto.ScimUser](s, ctx, provider, "/Users")
if err != nil {
return err
}
groupResources, err := listScimResources[dto.ScimGroup](s, ctx, provider, "/Groups")
if err != nil {
return err
}
var errs []error
var userStats scimSyncStats
var groupStats scimSyncStats
// Sync users first, so that groups can reference them
if stats, err := s.syncUsers(ctx, provider, users, &userResources); err != nil {
errs = append(errs, err)
userStats = stats
} else {
userStats = stats
}
stats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
if err != nil {
errs = append(errs, err)
groupStats = stats
} else {
groupStats = stats
}
if len(errs) > 0 {
slog.WarnContext(ctx, "SCIM sync completed with errors",
slog.String("provider_id", provider.ID),
slog.Int("error_count", len(errs)),
slog.Int("users_created", userStats.Created),
slog.Int("users_updated", userStats.Updated),
slog.Int("users_deleted", userStats.Deleted),
slog.Int("groups_created", groupStats.Created),
slog.Int("groups_updated", groupStats.Updated),
slog.Int("groups_deleted", groupStats.Deleted),
slog.Duration("duration", time.Since(start)),
)
return errors.Join(errs...)
}
provider.LastSyncedAt = utils.Ptr(datatype.DateTime(time.Now()))
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
return err
}
slog.InfoContext(ctx, "SCIM sync completed",
slog.String("provider_id", provider.ID),
slog.Int("users_created", userStats.Created),
slog.Int("users_updated", userStats.Updated),
slog.Int("users_deleted", userStats.Deleted),
slog.Int("groups_created", groupStats.Created),
slog.Int("groups_updated", groupStats.Updated),
slog.Int("groups_deleted", groupStats.Deleted),
slog.Duration("duration", time.Since(start)),
)
return nil
}
func (s *ScimService) syncUsers(
ctx context.Context,
provider model.ScimServiceProvider,
users []model.User,
resourceList *dto.ScimListResponse[dto.ScimUser],
) (stats scimSyncStats, err error) {
var errs []error
// Update or create users
for _, u := range users {
existing := getResourceByExternalID[dto.ScimUser](u.ID, resourceList.Resources)
action, created, err := s.syncUser(ctx, provider, u, existing)
if created != nil && existing == nil {
resourceList.Resources = append(resourceList.Resources, *created)
}
if err != nil {
errs = append(errs, err)
continue
}
// Update stats based on action taken by syncUser
switch action {
case scimActionCreated:
stats.Created++
case scimActionUpdated:
stats.Updated++
case scimActionDeleted:
stats.Deleted++
case scimActionNone:
}
}
// Delete users that are present in SCIM provider but not locally.
userSet := make(map[string]struct{})
for _, u := range users {
userSet[u.ID] = struct{}{}
}
for _, r := range resourceList.Resources {
if _, ok := userSet[r.ExternalID]; !ok {
if err := s.deleteScimResource(ctx, provider, "/Users/"+url.PathEscape(r.ID)); err != nil {
errs = append(errs, err)
} else {
stats.Deleted++
}
}
}
return stats, errors.Join(errs...)
}
func (s *ScimService) syncGroups(
ctx context.Context,
provider model.ScimServiceProvider,
groups []model.UserGroup,
remoteGroups []dto.ScimGroup,
userResources []dto.ScimUser,
) (stats scimSyncStats, err error) {
var errs []error
// Update or create groups
for _, g := range groups {
existing := getResourceByExternalID[dto.ScimGroup](g.ID, remoteGroups)
action, err := s.syncGroup(ctx, provider, g, existing, userResources)
if err != nil {
errs = append(errs, err)
continue
}
// Update stats based on action taken by syncGroup
switch action {
case scimActionCreated:
stats.Created++
case scimActionUpdated:
stats.Updated++
case scimActionDeleted:
stats.Deleted++
case scimActionNone:
}
}
// Delete groups that are present in SCIM provider but not locally
groupSet := make(map[string]struct{})
for _, g := range groups {
groupSet[g.ID] = struct{}{}
}
for _, r := range remoteGroups {
if _, ok := groupSet[r.ExternalID]; !ok {
if err := s.deleteScimResource(ctx, provider, "/Groups/"+url.PathEscape(r.GetID())); err != nil {
errs = append(errs, err)
} else {
stats.Deleted++
}
}
}
return stats, errors.Join(errs...)
}
func (s *ScimService) syncUser(ctx context.Context,
provider model.ScimServiceProvider,
user model.User,
userResource *dto.ScimUser,
) (scimSyncAction, *dto.ScimUser, error) {
// If user is not allowed for the client, delete it from SCIM provider
if userResource != nil && !IsUserGroupAllowedToAuthorize(user, provider.OidcClient) {
return scimActionDeleted, nil, s.deleteScimResource(ctx, provider, fmt.Sprintf("/Users/%s", url.PathEscape(userResource.ID)))
}
payload := dto.ScimUser{
ScimResourceData: dto.ScimResourceData{
Schemas: []string{scimUserSchema},
ExternalID: user.ID,
},
UserName: user.Username,
Name: &dto.ScimName{
GivenName: user.FirstName,
FamilyName: user.LastName,
},
Display: user.DisplayName,
Active: !user.Disabled,
}
if user.Email != nil {
payload.Emails = []dto.ScimEmail{{
Value: *user.Email,
Primary: true,
}}
}
// If the user exists on the SCIM provider, and it has been modified, update it
if userResource != nil {
if user.LastModified().Before(userResource.GetMeta().LastModified) {
return scimActionNone, nil, nil
}
path := fmt.Sprintf("/Users/%s", url.PathEscape(userResource.GetID()))
userResource, err := updateScimResource(s, ctx, provider, path, payload)
if err != nil {
return scimActionNone, nil, err
}
return scimActionUpdated, userResource, nil
}
// Otherwise, create a new SCIM user
userResource, err := createScimResource(s, ctx, provider, "/Users", payload)
if err != nil {
return scimActionNone, nil, err
}
return scimActionCreated, userResource, nil
}
func (s *ScimService) syncGroup(
ctx context.Context,
provider model.ScimServiceProvider,
group model.UserGroup,
groupResource *dto.ScimGroup,
userResources []dto.ScimUser,
) (scimSyncAction, error) {
// If group is not allowed for the client, delete it from SCIM provider
if groupResource != nil && !groupAllowedForClient(group.ID, provider.OidcClient) {
return scimActionDeleted, s.deleteScimResource(ctx, provider, fmt.Sprintf("/Groups/%s", url.PathEscape(groupResource.GetID())))
}
// Prepare group members
members := make([]dto.ScimGroupMember, len(group.Users))
for i, user := range group.Users {
userResource := getResourceByExternalID[dto.ScimUser](user.ID, userResources)
if userResource == nil {
// Groups depend on user IDs already being provisioned
return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID)
}
members[i] = dto.ScimGroupMember{
Value: userResource.GetID(),
}
}
groupPayload := dto.ScimGroup{
ScimResourceData: dto.ScimResourceData{
Schemas: []string{scimGroupSchema},
ExternalID: group.ID,
},
Display: group.FriendlyName,
Members: members,
}
// If the group exists on the SCIM provider, and it has been modified, update it
if groupResource != nil {
if group.LastModified().Before(groupResource.GetMeta().LastModified) {
return scimActionNone, nil
}
path := fmt.Sprintf("/Groups/%s", url.PathEscape(groupResource.GetID()))
_, err := updateScimResource(s, ctx, provider, path, groupPayload)
if err != nil {
return scimActionNone, err
}
return scimActionUpdated, nil
}
// Otherwise, create a new SCIM group
_, err := createScimResource(s, ctx, provider, "/Groups", groupPayload)
if err != nil {
return scimActionNone, err
}
return scimActionCreated, nil
}
func groupAllowedForClient(groupID string, client model.OidcClient) bool {
if !client.IsGroupRestricted {
return true
}
for _, allowedGroup := range client.AllowedUserGroups {
if allowedGroup.ID == groupID {
return true
}
}
return false
}
func groupIDs(groups []model.UserGroup) []string {
ids := make([]string, len(groups))
for i, g := range groups {
ids[i] = g.ID
}
return ids
}
func (s *ScimService) groupsForClient(
ctx context.Context,
client model.OidcClient,
allowedGroupIDs []string,
) ([]model.UserGroup, error) {
var groups []model.UserGroup
query := s.db.WithContext(ctx).Preload("Users").Model(&model.UserGroup{})
if client.IsGroupRestricted {
if len(allowedGroupIDs) == 0 {
return groups, nil
}
query = query.Where("id IN ?", allowedGroupIDs)
}
if err := query.Find(&groups).Error; err != nil {
return nil, err
}
return groups, nil
}
func (s *ScimService) usersForClient(
ctx context.Context,
client model.OidcClient,
allowedGroupIDs []string,
) ([]model.User, error) {
var users []model.User
query := s.db.WithContext(ctx).Model(&model.User{})
if client.IsGroupRestricted {
if len(allowedGroupIDs) == 0 {
return users, nil
}
query = query.
Joins("JOIN user_groups_users ON users.id = user_groups_users.user_id").
Where("user_groups_users.user_group_id IN ?", allowedGroupIDs).
Select("users.*").
Distinct()
}
query = query.Preload("UserGroups")
if err := query.Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
func getResourceByExternalID[T dto.ScimResource](externalID string, resource []T) *T {
for i := range resource {
if resource[i].GetExternalID() == externalID {
return &resource[i]
}
}
return nil
}
func listScimResources[T any](
s *ScimService,
ctx context.Context,
provider model.ScimServiceProvider,
path string,
) (result dto.ScimListResponse[T], err error) {
startIndex := 1
count := 1000
for {
// Use SCIM pagination to avoid missing resources on large providers
queryParams := map[string]string{
"startIndex": strconv.Itoa(startIndex),
"count": strconv.Itoa(count),
}
resp, err := s.scimRequest(ctx, provider, http.MethodGet, path, nil, queryParams)
if err != nil {
return dto.ScimListResponse[T]{}, err
}
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK); err != nil {
return dto.ScimListResponse[T]{}, err
}
var page dto.ScimListResponse[T]
if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
return dto.ScimListResponse[T]{}, fmt.Errorf("failed to decode SCIM list response: %w", err)
}
resp.Body.Close()
// Initialize metadata only once
if result.TotalResults == 0 {
result.TotalResults = page.TotalResults
}
result.Resources = append(result.Resources, page.Resources...)
// If we've fetched everything, stop
if len(result.Resources) >= page.TotalResults || len(page.Resources) == 0 {
break
}
startIndex += page.ItemsPerPage
}
result.ItemsPerPage = len(result.Resources)
return result, nil
}
func createScimResource[T dto.ScimResource](
s *ScimService,
ctx context.Context,
provider model.ScimServiceProvider,
path string, payload T) (*T, error) {
resp, err := s.scimRequest(ctx, provider, http.MethodPost, path, payload, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusCreated); err != nil {
return nil, err
}
var resource T
if err := json.NewDecoder(resp.Body).Decode(&resource); err != nil {
return nil, fmt.Errorf("failed to decode SCIM create response: %w", err)
}
return &resource, nil
}
func updateScimResource[T dto.ScimResource](
s *ScimService,
ctx context.Context,
provider model.ScimServiceProvider,
path string,
payload T,
) (*T, error) {
resp, err := s.scimRequest(ctx, provider, http.MethodPut, path, payload, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusCreated); err != nil {
return nil, err
}
var resource T
if err := json.NewDecoder(resp.Body).Decode(&resource); err != nil {
return nil, fmt.Errorf("failed to decode SCIM update response: %w", err)
}
return &resource, nil
}
func (s *ScimService) deleteScimResource(ctx context.Context, provider model.ScimServiceProvider, path string) error {
resp, err := s.scimRequest(ctx, provider, http.MethodDelete, path, nil, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil
}
return ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusNoContent)
}
func (s *ScimService) scimRequest(
ctx context.Context,
provider model.ScimServiceProvider,
method,
path string,
payload any,
queryParams map[string]string,
) (*http.Response, error) {
urlString, err := scimURL(provider.Endpoint, path, queryParams)
if err != nil {
return nil, err
}
var bodyBytes []byte
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to encode SCIM payload: %w", err)
}
bodyBytes = encoded
}
retryAttempts := 3
for attempt := 1; attempt <= retryAttempts; attempt++ {
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
}
req, err := http.NewRequestWithContext(ctx, method, urlString, body)
if err != nil {
return nil, err
}
req.Header.Set("Accept", scimContentType)
if payload != nil {
req.Header.Set("Content-Type", scimContentType)
}
token := string(provider.Token)
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
slog.Debug("Sending SCIM request",
slog.String("method", method),
slog.String("url", urlString),
slog.String("provider_id", provider.ID),
)
resp, err := s.httpClient.Do(req)
if err != nil {
return nil, err
}
// Only retry on 429 to avoid masking other errors
if resp.StatusCode != http.StatusTooManyRequests || attempt == retryAttempts {
return resp, nil
}
retryDelay := scimRetryDelay(resp.Header.Get("Retry-After"), attempt)
slog.WarnContext(ctx, "SCIM provider rate-limited, retrying",
slog.String("provider_id", provider.ID),
slog.String("method", method),
slog.String("url", urlString),
slog.Int("attempt", attempt),
slog.Duration("retry_after", retryDelay),
)
resp.Body.Close()
if err := utils.SleepWithContext(ctx, retryDelay); err != nil {
return nil, err
}
}
return nil, fmt.Errorf("scim request retry attempts exceeded")
}
func scimRetryDelay(retryAfter string, attempt int) time.Duration {
// Respect Retry-After when provided
if retryAfter != "" {
if seconds, err := strconv.Atoi(retryAfter); err == nil {
return time.Duration(seconds) * time.Second
}
if t, err := http.ParseTime(retryAfter); err == nil {
if delay := time.Until(t); delay > 0 {
return delay
}
}
}
// Exponential backoff otherwise
maxDelay := 10 * time.Second
delay := 500 * time.Millisecond * (time.Duration(1) << (attempt - 1)) //nolint:gosec // attempt is bounded 1-3
if delay > maxDelay {
return maxDelay
}
return delay
}
func scimURL(endpoint, p string, queryParams map[string]string) (string, error) {
u, err := url.Parse(endpoint)
if err != nil {
return "", fmt.Errorf("invalid scim endpoint: %w", err)
}
u.Path = path.Join(strings.TrimRight(u.Path, "/"), p)
q := u.Query()
for key, value := range queryParams {
q.Set(key, value)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func ensureScimStatus(
ctx context.Context,
resp *http.Response,
provider model.ScimServiceProvider,
allowedStatuses ...int) error {
for _, status := range allowedStatuses {
if resp.StatusCode == status {
return nil
}
}
body := readScimErrorBody(resp.Body)
slog.ErrorContext(ctx, "SCIM request failed",
slog.String("provider_id", provider.ID),
slog.String("method", resp.Request.Method),
slog.String("url", resp.Request.URL.String()),
slog.Int("status", resp.StatusCode),
slog.String("response_body", body),
)
return fmt.Errorf("scim request failed with status %d: %s", resp.StatusCode, body)
}
func readScimErrorBody(body io.Reader) string {
payload, err := io.ReadAll(io.LimitReader(body, scimErrorBodyLimit))
if err != nil {
return ""
}
return strings.TrimSpace(string(payload))
}