mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 15:04:43 +00:00
817 lines
21 KiB
Go
817 lines
21 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-co-op/gocron/v2"
|
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
const (
|
|
scimUserSchema = "urn:ietf:params:scim:schemas:core:2.0:User"
|
|
scimGroupSchema = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
|
scimContentType = "application/scim+json"
|
|
)
|
|
|
|
const scimErrorBodyLimit = 4096
|
|
|
|
type scimSyncAction int
|
|
|
|
type Scheduler interface {
|
|
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error
|
|
RemoveJob(name string) error
|
|
}
|
|
|
|
const (
|
|
scimActionNone scimSyncAction = iota
|
|
scimActionCreated
|
|
scimActionUpdated
|
|
scimActionDeleted
|
|
)
|
|
|
|
type scimSyncStats struct {
|
|
Created int
|
|
Updated int
|
|
Deleted int
|
|
}
|
|
|
|
// ScimService handles SCIM provisioning to external service providers.
|
|
type ScimService struct {
|
|
db *gorm.DB
|
|
scheduler Scheduler
|
|
httpClient *http.Client
|
|
}
|
|
|
|
func NewScimService(db *gorm.DB, scheduler Scheduler, httpClient *http.Client) *ScimService {
|
|
if httpClient == nil {
|
|
httpClient = &http.Client{Timeout: 20 * time.Second}
|
|
}
|
|
|
|
return &ScimService{db: db, scheduler: scheduler, httpClient: httpClient}
|
|
}
|
|
|
|
func (s *ScimService) GetServiceProvider(
|
|
ctx context.Context,
|
|
serviceProviderID string,
|
|
) (model.ScimServiceProvider, error) {
|
|
var provider model.ScimServiceProvider
|
|
err := s.db.WithContext(ctx).
|
|
Preload("OidcClient").
|
|
Preload("OidcClient.AllowedUserGroups").
|
|
First(&provider, "id = ?", serviceProviderID).
|
|
Error
|
|
if err != nil {
|
|
return model.ScimServiceProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (s *ScimService) ListServiceProviders(ctx context.Context) ([]model.ScimServiceProvider, error) {
|
|
var providers []model.ScimServiceProvider
|
|
err := s.db.WithContext(ctx).
|
|
Preload("OidcClient").
|
|
Find(&providers).
|
|
Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return providers, nil
|
|
}
|
|
|
|
func (s *ScimService) CreateServiceProvider(
|
|
ctx context.Context,
|
|
input *dto.ScimServiceProviderCreateDTO) (model.ScimServiceProvider, error) {
|
|
provider := model.ScimServiceProvider{
|
|
Endpoint: input.Endpoint,
|
|
Token: datatype.EncryptedString(input.Token),
|
|
OidcClientID: input.OidcClientID,
|
|
}
|
|
|
|
if err := s.db.WithContext(ctx).Create(&provider).Error; err != nil {
|
|
return model.ScimServiceProvider{}, err
|
|
}
|
|
|
|
return provider, nil
|
|
}
|
|
|
|
func (s *ScimService) UpdateServiceProvider(ctx context.Context,
|
|
serviceProviderID string,
|
|
input *dto.ScimServiceProviderCreateDTO,
|
|
) (model.ScimServiceProvider, error) {
|
|
var provider model.ScimServiceProvider
|
|
err := s.db.WithContext(ctx).
|
|
First(&provider, "id = ?", serviceProviderID).
|
|
Error
|
|
if err != nil {
|
|
return model.ScimServiceProvider{}, err
|
|
}
|
|
|
|
provider.Endpoint = input.Endpoint
|
|
provider.Token = datatype.EncryptedString(input.Token)
|
|
provider.OidcClientID = input.OidcClientID
|
|
|
|
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
|
|
return model.ScimServiceProvider{}, err
|
|
}
|
|
|
|
return provider, nil
|
|
}
|
|
|
|
func (s *ScimService) DeleteServiceProvider(ctx context.Context, serviceProviderID string) error {
|
|
return s.db.WithContext(ctx).
|
|
Delete(&model.ScimServiceProvider{}, "id = ?", serviceProviderID).
|
|
Error
|
|
}
|
|
|
|
//nolint:contextcheck
|
|
func (s *ScimService) ScheduleSync() {
|
|
jobName := "ScheduledScimSync"
|
|
start := time.Now().Add(5 * time.Minute)
|
|
|
|
_ = s.scheduler.RemoveJob(jobName)
|
|
|
|
err := s.scheduler.RegisterJob(
|
|
context.Background(), jobName,
|
|
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, false)
|
|
|
|
if err != nil {
|
|
slog.Error("Failed to schedule SCIM sync", slog.Any("error", err))
|
|
}
|
|
}
|
|
|
|
func (s *ScimService) SyncAll(ctx context.Context) error {
|
|
providers, err := s.ListServiceProviders(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var errs []error
|
|
for _, provider := range providers {
|
|
if ctx.Err() != nil {
|
|
errs = append(errs, ctx.Err())
|
|
break
|
|
}
|
|
if err := s.SyncServiceProvider(ctx, provider.ID); err != nil {
|
|
errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err))
|
|
}
|
|
}
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID string) error {
|
|
start := time.Now()
|
|
provider, err := s.GetServiceProvider(ctx, serviceProviderID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.InfoContext(ctx, "Syncing SCIM service provider",
|
|
slog.String("provider_id", provider.ID),
|
|
slog.String("oidc_client_id", provider.OidcClientID),
|
|
)
|
|
|
|
allowedGroupIDs := groupIDs(provider.OidcClient.AllowedUserGroups)
|
|
|
|
// Load users and groups that should be synced to the SCIM provider
|
|
groups, err := s.groupsForClient(ctx, provider.OidcClient, allowedGroupIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
users, err := s.usersForClient(ctx, provider.OidcClient, allowedGroupIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Load users and groups that already exist in the SCIM provider
|
|
userResources, err := listScimResources[dto.ScimUser](s, ctx, provider, "/Users")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
groupResources, err := listScimResources[dto.ScimGroup](s, ctx, provider, "/Groups")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var errs []error
|
|
var userStats scimSyncStats
|
|
var groupStats scimSyncStats
|
|
|
|
// Sync users first, so that groups can reference them
|
|
if stats, err := s.syncUsers(ctx, provider, users, &userResources); err != nil {
|
|
errs = append(errs, err)
|
|
userStats = stats
|
|
} else {
|
|
userStats = stats
|
|
}
|
|
|
|
stats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
groupStats = stats
|
|
} else {
|
|
groupStats = stats
|
|
}
|
|
|
|
if len(errs) > 0 {
|
|
slog.WarnContext(ctx, "SCIM sync completed with errors",
|
|
slog.String("provider_id", provider.ID),
|
|
slog.Int("error_count", len(errs)),
|
|
slog.Int("users_created", userStats.Created),
|
|
slog.Int("users_updated", userStats.Updated),
|
|
slog.Int("users_deleted", userStats.Deleted),
|
|
slog.Int("groups_created", groupStats.Created),
|
|
slog.Int("groups_updated", groupStats.Updated),
|
|
slog.Int("groups_deleted", groupStats.Deleted),
|
|
slog.Duration("duration", time.Since(start)),
|
|
)
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
provider.LastSyncedAt = utils.Ptr(datatype.DateTime(time.Now()))
|
|
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.InfoContext(ctx, "SCIM sync completed",
|
|
slog.String("provider_id", provider.ID),
|
|
slog.Int("users_created", userStats.Created),
|
|
slog.Int("users_updated", userStats.Updated),
|
|
slog.Int("users_deleted", userStats.Deleted),
|
|
slog.Int("groups_created", groupStats.Created),
|
|
slog.Int("groups_updated", groupStats.Updated),
|
|
slog.Int("groups_deleted", groupStats.Deleted),
|
|
slog.Duration("duration", time.Since(start)),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ScimService) syncUsers(
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
users []model.User,
|
|
resourceList *dto.ScimListResponse[dto.ScimUser],
|
|
) (stats scimSyncStats, err error) {
|
|
var errs []error
|
|
|
|
// Update or create users
|
|
for _, u := range users {
|
|
existing := getResourceByExternalID[dto.ScimUser](u.ID, resourceList.Resources)
|
|
|
|
action, created, err := s.syncUser(ctx, provider, u, existing)
|
|
if created != nil && existing == nil {
|
|
resourceList.Resources = append(resourceList.Resources, *created)
|
|
}
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
|
|
// Update stats based on action taken by syncUser
|
|
switch action {
|
|
case scimActionCreated:
|
|
stats.Created++
|
|
case scimActionUpdated:
|
|
stats.Updated++
|
|
case scimActionDeleted:
|
|
stats.Deleted++
|
|
case scimActionNone:
|
|
}
|
|
}
|
|
|
|
// Delete users that are present in SCIM provider but not locally.
|
|
userSet := make(map[string]struct{})
|
|
for _, u := range users {
|
|
userSet[u.ID] = struct{}{}
|
|
}
|
|
|
|
for _, r := range resourceList.Resources {
|
|
if _, ok := userSet[r.ExternalID]; !ok {
|
|
if err := s.deleteScimResource(ctx, provider, "/Users/"+url.PathEscape(r.ID)); err != nil {
|
|
errs = append(errs, err)
|
|
} else {
|
|
stats.Deleted++
|
|
}
|
|
}
|
|
}
|
|
|
|
return stats, errors.Join(errs...)
|
|
}
|
|
|
|
func (s *ScimService) syncGroups(
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
groups []model.UserGroup,
|
|
remoteGroups []dto.ScimGroup,
|
|
userResources []dto.ScimUser,
|
|
) (stats scimSyncStats, err error) {
|
|
var errs []error
|
|
|
|
// Update or create groups
|
|
for _, g := range groups {
|
|
existing := getResourceByExternalID[dto.ScimGroup](g.ID, remoteGroups)
|
|
|
|
action, err := s.syncGroup(ctx, provider, g, existing, userResources)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
|
|
// Update stats based on action taken by syncGroup
|
|
switch action {
|
|
case scimActionCreated:
|
|
stats.Created++
|
|
case scimActionUpdated:
|
|
stats.Updated++
|
|
case scimActionDeleted:
|
|
stats.Deleted++
|
|
case scimActionNone:
|
|
}
|
|
|
|
}
|
|
|
|
// Delete groups that are present in SCIM provider but not locally
|
|
groupSet := make(map[string]struct{})
|
|
for _, g := range groups {
|
|
groupSet[g.ID] = struct{}{}
|
|
}
|
|
|
|
for _, r := range remoteGroups {
|
|
if _, ok := groupSet[r.ExternalID]; !ok {
|
|
if err := s.deleteScimResource(ctx, provider, "/Groups/"+url.PathEscape(r.GetID())); err != nil {
|
|
errs = append(errs, err)
|
|
} else {
|
|
stats.Deleted++
|
|
}
|
|
}
|
|
}
|
|
|
|
return stats, errors.Join(errs...)
|
|
}
|
|
|
|
func (s *ScimService) syncUser(ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
user model.User,
|
|
userResource *dto.ScimUser,
|
|
) (scimSyncAction, *dto.ScimUser, error) {
|
|
// If user is not allowed for the client, delete it from SCIM provider
|
|
if userResource != nil && !IsUserGroupAllowedToAuthorize(user, provider.OidcClient) {
|
|
return scimActionDeleted, nil, s.deleteScimResource(ctx, provider, fmt.Sprintf("/Users/%s", url.PathEscape(userResource.ID)))
|
|
}
|
|
|
|
payload := dto.ScimUser{
|
|
ScimResourceData: dto.ScimResourceData{
|
|
Schemas: []string{scimUserSchema},
|
|
ExternalID: user.ID,
|
|
},
|
|
UserName: user.Username,
|
|
Name: &dto.ScimName{
|
|
GivenName: user.FirstName,
|
|
FamilyName: user.LastName,
|
|
},
|
|
Display: user.DisplayName,
|
|
Active: !user.Disabled,
|
|
}
|
|
|
|
if user.Email != nil {
|
|
payload.Emails = []dto.ScimEmail{{
|
|
Value: *user.Email,
|
|
Primary: true,
|
|
}}
|
|
}
|
|
|
|
// If the user exists on the SCIM provider, and it has been modified, update it
|
|
if userResource != nil {
|
|
if user.LastModified().Before(userResource.GetMeta().LastModified) {
|
|
return scimActionNone, nil, nil
|
|
}
|
|
path := fmt.Sprintf("/Users/%s", url.PathEscape(userResource.GetID()))
|
|
userResource, err := updateScimResource(s, ctx, provider, path, payload)
|
|
if err != nil {
|
|
return scimActionNone, nil, err
|
|
}
|
|
return scimActionUpdated, userResource, nil
|
|
}
|
|
|
|
// Otherwise, create a new SCIM user
|
|
userResource, err := createScimResource(s, ctx, provider, "/Users", payload)
|
|
if err != nil {
|
|
return scimActionNone, nil, err
|
|
}
|
|
|
|
return scimActionCreated, userResource, nil
|
|
}
|
|
|
|
func (s *ScimService) syncGroup(
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
group model.UserGroup,
|
|
groupResource *dto.ScimGroup,
|
|
userResources []dto.ScimUser,
|
|
) (scimSyncAction, error) {
|
|
// If group is not allowed for the client, delete it from SCIM provider
|
|
if groupResource != nil && !groupAllowedForClient(group.ID, provider.OidcClient) {
|
|
return scimActionDeleted, s.deleteScimResource(ctx, provider, fmt.Sprintf("/Groups/%s", url.PathEscape(groupResource.GetID())))
|
|
}
|
|
|
|
// Prepare group members
|
|
members := make([]dto.ScimGroupMember, len(group.Users))
|
|
for i, user := range group.Users {
|
|
userResource := getResourceByExternalID[dto.ScimUser](user.ID, userResources)
|
|
if userResource == nil {
|
|
// Groups depend on user IDs already being provisioned
|
|
return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID)
|
|
}
|
|
|
|
members[i] = dto.ScimGroupMember{
|
|
Value: userResource.GetID(),
|
|
}
|
|
}
|
|
|
|
groupPayload := dto.ScimGroup{
|
|
ScimResourceData: dto.ScimResourceData{
|
|
Schemas: []string{scimGroupSchema},
|
|
ExternalID: group.ID,
|
|
},
|
|
Display: group.FriendlyName,
|
|
Members: members,
|
|
}
|
|
|
|
// If the group exists on the SCIM provider, and it has been modified, update it
|
|
if groupResource != nil {
|
|
if group.LastModified().Before(groupResource.GetMeta().LastModified) {
|
|
return scimActionNone, nil
|
|
}
|
|
path := fmt.Sprintf("/Groups/%s", url.PathEscape(groupResource.GetID()))
|
|
_, err := updateScimResource(s, ctx, provider, path, groupPayload)
|
|
if err != nil {
|
|
return scimActionNone, err
|
|
}
|
|
return scimActionUpdated, nil
|
|
}
|
|
|
|
// Otherwise, create a new SCIM group
|
|
_, err := createScimResource(s, ctx, provider, "/Groups", groupPayload)
|
|
if err != nil {
|
|
return scimActionNone, err
|
|
}
|
|
|
|
return scimActionCreated, nil
|
|
}
|
|
|
|
func groupAllowedForClient(groupID string, client model.OidcClient) bool {
|
|
if !client.IsGroupRestricted {
|
|
return true
|
|
}
|
|
|
|
for _, allowedGroup := range client.AllowedUserGroups {
|
|
if allowedGroup.ID == groupID {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func groupIDs(groups []model.UserGroup) []string {
|
|
ids := make([]string, len(groups))
|
|
for i, g := range groups {
|
|
ids[i] = g.ID
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func (s *ScimService) groupsForClient(
|
|
ctx context.Context,
|
|
client model.OidcClient,
|
|
allowedGroupIDs []string,
|
|
) ([]model.UserGroup, error) {
|
|
var groups []model.UserGroup
|
|
|
|
query := s.db.WithContext(ctx).Preload("Users").Model(&model.UserGroup{})
|
|
if client.IsGroupRestricted {
|
|
if len(allowedGroupIDs) == 0 {
|
|
return groups, nil
|
|
}
|
|
query = query.Where("id IN ?", allowedGroupIDs)
|
|
}
|
|
|
|
if err := query.Find(&groups).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return groups, nil
|
|
}
|
|
|
|
func (s *ScimService) usersForClient(
|
|
ctx context.Context,
|
|
client model.OidcClient,
|
|
allowedGroupIDs []string,
|
|
) ([]model.User, error) {
|
|
var users []model.User
|
|
|
|
query := s.db.WithContext(ctx).Model(&model.User{})
|
|
if client.IsGroupRestricted {
|
|
if len(allowedGroupIDs) == 0 {
|
|
return users, nil
|
|
}
|
|
query = query.
|
|
Joins("JOIN user_groups_users ON users.id = user_groups_users.user_id").
|
|
Where("user_groups_users.user_group_id IN ?", allowedGroupIDs).
|
|
Select("users.*").
|
|
Distinct()
|
|
}
|
|
|
|
query = query.Preload("UserGroups")
|
|
|
|
if err := query.Find(&users).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func getResourceByExternalID[T dto.ScimResource](externalID string, resource []T) *T {
|
|
for i := range resource {
|
|
if resource[i].GetExternalID() == externalID {
|
|
return &resource[i]
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func listScimResources[T any](
|
|
s *ScimService,
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
path string,
|
|
) (result dto.ScimListResponse[T], err error) {
|
|
startIndex := 1
|
|
count := 1000
|
|
|
|
for {
|
|
// Use SCIM pagination to avoid missing resources on large providers
|
|
queryParams := map[string]string{
|
|
"startIndex": strconv.Itoa(startIndex),
|
|
"count": strconv.Itoa(count),
|
|
}
|
|
|
|
resp, err := s.scimRequest(ctx, provider, http.MethodGet, path, nil, queryParams)
|
|
if err != nil {
|
|
return dto.ScimListResponse[T]{}, err
|
|
}
|
|
|
|
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK); err != nil {
|
|
return dto.ScimListResponse[T]{}, err
|
|
}
|
|
|
|
var page dto.ScimListResponse[T]
|
|
if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
|
|
return dto.ScimListResponse[T]{}, fmt.Errorf("failed to decode SCIM list response: %w", err)
|
|
}
|
|
|
|
resp.Body.Close()
|
|
|
|
// Initialize metadata only once
|
|
if result.TotalResults == 0 {
|
|
result.TotalResults = page.TotalResults
|
|
}
|
|
|
|
result.Resources = append(result.Resources, page.Resources...)
|
|
|
|
// If we've fetched everything, stop
|
|
if len(result.Resources) >= page.TotalResults || len(page.Resources) == 0 {
|
|
break
|
|
}
|
|
|
|
startIndex += page.ItemsPerPage
|
|
}
|
|
|
|
result.ItemsPerPage = len(result.Resources)
|
|
return result, nil
|
|
}
|
|
|
|
func createScimResource[T dto.ScimResource](
|
|
s *ScimService,
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
path string, payload T) (*T, error) {
|
|
resp, err := s.scimRequest(ctx, provider, http.MethodPost, path, payload, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusCreated); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resource T
|
|
if err := json.NewDecoder(resp.Body).Decode(&resource); err != nil {
|
|
return nil, fmt.Errorf("failed to decode SCIM create response: %w", err)
|
|
}
|
|
|
|
return &resource, nil
|
|
}
|
|
|
|
func updateScimResource[T dto.ScimResource](
|
|
s *ScimService,
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
path string,
|
|
payload T,
|
|
) (*T, error) {
|
|
resp, err := s.scimRequest(ctx, provider, http.MethodPut, path, payload, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if err := ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusCreated); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resource T
|
|
if err := json.NewDecoder(resp.Body).Decode(&resource); err != nil {
|
|
return nil, fmt.Errorf("failed to decode SCIM update response: %w", err)
|
|
}
|
|
|
|
return &resource, nil
|
|
}
|
|
|
|
func (s *ScimService) deleteScimResource(ctx context.Context, provider model.ScimServiceProvider, path string) error {
|
|
resp, err := s.scimRequest(ctx, provider, http.MethodDelete, path, nil, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusNotFound {
|
|
return nil
|
|
}
|
|
|
|
return ensureScimStatus(ctx, resp, provider, http.StatusOK, http.StatusNoContent)
|
|
}
|
|
|
|
func (s *ScimService) scimRequest(
|
|
ctx context.Context,
|
|
provider model.ScimServiceProvider,
|
|
method,
|
|
path string,
|
|
payload any,
|
|
queryParams map[string]string,
|
|
) (*http.Response, error) {
|
|
urlString, err := scimURL(provider.Endpoint, path, queryParams)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var bodyBytes []byte
|
|
if payload != nil {
|
|
encoded, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode SCIM payload: %w", err)
|
|
}
|
|
bodyBytes = encoded
|
|
}
|
|
|
|
retryAttempts := 3
|
|
for attempt := 1; attempt <= retryAttempts; attempt++ {
|
|
var body io.Reader
|
|
if bodyBytes != nil {
|
|
body = bytes.NewReader(bodyBytes)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, urlString, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Accept", scimContentType)
|
|
if payload != nil {
|
|
req.Header.Set("Content-Type", scimContentType)
|
|
}
|
|
token := string(provider.Token)
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
slog.Debug("Sending SCIM request",
|
|
slog.String("method", method),
|
|
slog.String("url", urlString),
|
|
slog.String("provider_id", provider.ID),
|
|
)
|
|
|
|
resp, err := s.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Only retry on 429 to avoid masking other errors
|
|
if resp.StatusCode != http.StatusTooManyRequests || attempt == retryAttempts {
|
|
return resp, nil
|
|
}
|
|
|
|
retryDelay := scimRetryDelay(resp.Header.Get("Retry-After"), attempt)
|
|
slog.WarnContext(ctx, "SCIM provider rate-limited, retrying",
|
|
slog.String("provider_id", provider.ID),
|
|
slog.String("method", method),
|
|
slog.String("url", urlString),
|
|
slog.Int("attempt", attempt),
|
|
slog.Duration("retry_after", retryDelay),
|
|
)
|
|
|
|
resp.Body.Close()
|
|
if err := utils.SleepWithContext(ctx, retryDelay); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("scim request retry attempts exceeded")
|
|
}
|
|
|
|
func scimRetryDelay(retryAfter string, attempt int) time.Duration {
|
|
// Respect Retry-After when provided
|
|
if retryAfter != "" {
|
|
if seconds, err := strconv.Atoi(retryAfter); err == nil {
|
|
return time.Duration(seconds) * time.Second
|
|
}
|
|
if t, err := http.ParseTime(retryAfter); err == nil {
|
|
if delay := time.Until(t); delay > 0 {
|
|
return delay
|
|
}
|
|
}
|
|
}
|
|
|
|
// Exponential backoff otherwise
|
|
maxDelay := 10 * time.Second
|
|
delay := 500 * time.Millisecond * (time.Duration(1) << (attempt - 1)) //nolint:gosec // attempt is bounded 1-3
|
|
if delay > maxDelay {
|
|
return maxDelay
|
|
}
|
|
return delay
|
|
}
|
|
|
|
func scimURL(endpoint, p string, queryParams map[string]string) (string, error) {
|
|
u, err := url.Parse(endpoint)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid scim endpoint: %w", err)
|
|
}
|
|
|
|
u.Path = path.Join(strings.TrimRight(u.Path, "/"), p)
|
|
|
|
q := u.Query()
|
|
for key, value := range queryParams {
|
|
q.Set(key, value)
|
|
}
|
|
u.RawQuery = q.Encode()
|
|
|
|
return u.String(), nil
|
|
}
|
|
|
|
func ensureScimStatus(
|
|
ctx context.Context,
|
|
resp *http.Response,
|
|
provider model.ScimServiceProvider,
|
|
allowedStatuses ...int) error {
|
|
for _, status := range allowedStatuses {
|
|
if resp.StatusCode == status {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
body := readScimErrorBody(resp.Body)
|
|
|
|
slog.ErrorContext(ctx, "SCIM request failed",
|
|
slog.String("provider_id", provider.ID),
|
|
slog.String("method", resp.Request.Method),
|
|
slog.String("url", resp.Request.URL.String()),
|
|
slog.Int("status", resp.StatusCode),
|
|
slog.String("response_body", body),
|
|
)
|
|
|
|
return fmt.Errorf("scim request failed with status %d: %s", resp.StatusCode, body)
|
|
}
|
|
|
|
func readScimErrorBody(body io.Reader) string {
|
|
payload, err := io.ReadAll(io.LimitReader(body, scimErrorBodyLimit))
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(string(payload))
|
|
}
|