mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 11:36:46 +00:00
refactor: run SCIM jobs in context of gocron instead of custom implementation
This commit is contained in:
@@ -48,8 +48,13 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
return fmt.Errorf("failed to initialize application images: %w", err)
|
return fmt.Errorf("failed to initialize application images: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scheduler, err := job.NewScheduler()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create job scheduler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create all services
|
// Create all services
|
||||||
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
|
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage, scheduler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize services: %w", err)
|
return fmt.Errorf("failed to initialize services: %w", err)
|
||||||
}
|
}
|
||||||
@@ -74,11 +79,7 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
shutdownFns = append(shutdownFns, shutdownFn)
|
shutdownFns = append(shutdownFns, shutdownFn)
|
||||||
|
|
||||||
// Init the job scheduler
|
// Register scheduled jobs
|
||||||
scheduler, err := job.NewScheduler()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create job scheduler: %w", err)
|
|
||||||
}
|
|
||||||
err = registerScheduledJobs(ctx, db, svc, httpClient, scheduler)
|
err = registerScheduledJobs(ctx, db, svc, httpClient, scheduler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to register scheduled jobs: %w", err)
|
return fmt.Errorf("failed to register scheduled jobs: %w", err)
|
||||||
|
|||||||
@@ -35,6 +35,10 @@ func registerScheduledJobs(ctx context.Context, db *gorm.DB, svc *services, http
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to register analytics job in scheduler: %w", err)
|
return fmt.Errorf("failed to register analytics job in scheduler: %w", err)
|
||||||
}
|
}
|
||||||
|
err = scheduler.RegisterScimJobs(ctx, svc.scimService)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to register SCIM scheduler job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/job"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
@@ -12,28 +13,27 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type services struct {
|
type services struct {
|
||||||
appConfigService *service.AppConfigService
|
appConfigService *service.AppConfigService
|
||||||
appImagesService *service.AppImagesService
|
appImagesService *service.AppImagesService
|
||||||
emailService *service.EmailService
|
emailService *service.EmailService
|
||||||
geoLiteService *service.GeoLiteService
|
geoLiteService *service.GeoLiteService
|
||||||
auditLogService *service.AuditLogService
|
auditLogService *service.AuditLogService
|
||||||
jwtService *service.JwtService
|
jwtService *service.JwtService
|
||||||
webauthnService *service.WebAuthnService
|
webauthnService *service.WebAuthnService
|
||||||
scimService *service.ScimService
|
scimService *service.ScimService
|
||||||
scimSchedulerService *service.ScimSchedulerService
|
userService *service.UserService
|
||||||
userService *service.UserService
|
customClaimService *service.CustomClaimService
|
||||||
customClaimService *service.CustomClaimService
|
oidcService *service.OidcService
|
||||||
oidcService *service.OidcService
|
userGroupService *service.UserGroupService
|
||||||
userGroupService *service.UserGroupService
|
ldapService *service.LdapService
|
||||||
ldapService *service.LdapService
|
apiKeyService *service.ApiKeyService
|
||||||
apiKeyService *service.ApiKeyService
|
versionService *service.VersionService
|
||||||
versionService *service.VersionService
|
fileStorage storage.FileStorage
|
||||||
fileStorage storage.FileStorage
|
appLockService *service.AppLockService
|
||||||
appLockService *service.AppLockService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializes all services
|
// Initializes all services
|
||||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string, fileStorage storage.FileStorage) (svc *services, err error) {
|
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string, fileStorage storage.FileStorage, scheduler *job.Scheduler) (svc *services, err error) {
|
||||||
svc = &services{}
|
svc = &services{}
|
||||||
|
|
||||||
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
|
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
|
||||||
@@ -63,20 +63,17 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
|||||||
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient, fileStorage)
|
svc.scimService = service.NewScimService(db, scheduler, httpClient)
|
||||||
|
|
||||||
|
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, svc.scimService, httpClient, fileStorage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService, svc.scimService)
|
||||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
|
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, svc.scimService, fileStorage)
|
||||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage)
|
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage)
|
||||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||||
svc.scimService = service.NewScimService(db, httpClient)
|
|
||||||
svc.scimSchedulerService, err = service.NewScimSchedulerService(ctx, svc.scimService)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create SCIM scheduler service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
svc.versionService = service.NewVersionService(httpClient)
|
svc.versionService = service.NewVersionService(httpClient)
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func (s *Scheduler) RegisterAnalyticsJob(ctx context.Context, appConfig *service
|
|||||||
appConfig: appConfig,
|
appConfig: appConfig,
|
||||||
httpClient: httpClient,
|
httpClient: httpClient,
|
||||||
}
|
}
|
||||||
return s.registerJob(ctx, "SendHeartbeat", gocron.DurationJob(24*time.Hour), jobs.sendHeartbeat, true)
|
return s.RegisterJob(ctx, "SendHeartbeat", gocron.DurationJob(24*time.Hour), jobs.sendHeartbeat, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnalyticsJob struct {
|
type AnalyticsJob struct {
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func (s *Scheduler) RegisterApiKeyExpiryJob(ctx context.Context, apiKeyService *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send every day at midnight
|
// Send every day at midnight
|
||||||
return s.registerJob(ctx, "ExpiredApiKeyEmailJob", gocron.CronJob("0 0 * * *", false), jobs.checkAndNotifyExpiringApiKeys, false)
|
return s.RegisterJob(ctx, "ExpiredApiKeyEmailJob", gocron.CronJob("0 0 * * *", false), jobs.checkAndNotifyExpiringApiKeys, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error {
|
func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error {
|
||||||
|
|||||||
@@ -21,13 +21,13 @@ func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) erro
|
|||||||
// Run every 24 hours (but with some jitter so they don't run at the exact same time), and now
|
// Run every 24 hours (but with some jitter so they don't run at the exact same time), and now
|
||||||
def := gocron.DurationRandomJob(24*time.Hour-2*time.Minute, 24*time.Hour+2*time.Minute)
|
def := gocron.DurationRandomJob(24*time.Hour-2*time.Minute, 24*time.Hour+2*time.Minute)
|
||||||
return errors.Join(
|
return errors.Join(
|
||||||
s.registerJob(ctx, "ClearWebauthnSessions", def, jobs.clearWebauthnSessions, true),
|
s.RegisterJob(ctx, "ClearWebauthnSessions", def, jobs.clearWebauthnSessions, true),
|
||||||
s.registerJob(ctx, "ClearOneTimeAccessTokens", def, jobs.clearOneTimeAccessTokens, true),
|
s.RegisterJob(ctx, "ClearOneTimeAccessTokens", def, jobs.clearOneTimeAccessTokens, true),
|
||||||
s.registerJob(ctx, "ClearSignupTokens", def, jobs.clearSignupTokens, true),
|
s.RegisterJob(ctx, "ClearSignupTokens", def, jobs.clearSignupTokens, true),
|
||||||
s.registerJob(ctx, "ClearOidcAuthorizationCodes", def, jobs.clearOidcAuthorizationCodes, true),
|
s.RegisterJob(ctx, "ClearOidcAuthorizationCodes", def, jobs.clearOidcAuthorizationCodes, true),
|
||||||
s.registerJob(ctx, "ClearOidcRefreshTokens", def, jobs.clearOidcRefreshTokens, true),
|
s.RegisterJob(ctx, "ClearOidcRefreshTokens", def, jobs.clearOidcRefreshTokens, true),
|
||||||
s.registerJob(ctx, "ClearReauthenticationTokens", def, jobs.clearReauthenticationTokens, true),
|
s.RegisterJob(ctx, "ClearReauthenticationTokens", def, jobs.clearReauthenticationTokens, true),
|
||||||
s.registerJob(ctx, "ClearAuditLogs", def, jobs.clearAuditLogs, true),
|
s.RegisterJob(ctx, "ClearAuditLogs", def, jobs.clearAuditLogs, true),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ import (
|
|||||||
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error {
|
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error {
|
||||||
jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage}
|
jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage}
|
||||||
|
|
||||||
err := s.registerJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
|
err := s.RegisterJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
|
||||||
|
|
||||||
// Only necessary for file system storage
|
// Only necessary for file system storage
|
||||||
if fileStorage.Type() == storage.TypeFileSystem {
|
if fileStorage.Type() == storage.TypeFileSystem {
|
||||||
err = errors.Join(err, s.registerJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, true))
|
err = errors.Join(err, s.RegisterJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, true))
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func (s *Scheduler) RegisterGeoLiteUpdateJobs(ctx context.Context, geoLiteServic
|
|||||||
jobs := &GeoLiteUpdateJobs{geoLiteService: geoLiteService}
|
jobs := &GeoLiteUpdateJobs{geoLiteService: geoLiteService}
|
||||||
|
|
||||||
// Run every 24 hours (and right away)
|
// Run every 24 hours (and right away)
|
||||||
return s.registerJob(ctx, "UpdateGeoLiteDB", gocron.DurationJob(24*time.Hour), jobs.updateGoeLiteDB, true)
|
return s.RegisterJob(ctx, "UpdateGeoLiteDB", gocron.DurationJob(24*time.Hour), jobs.updateGoeLiteDB, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *GeoLiteUpdateJobs) updateGoeLiteDB(ctx context.Context) error {
|
func (j *GeoLiteUpdateJobs) updateGoeLiteDB(ctx context.Context) error {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func (s *Scheduler) RegisterLdapJobs(ctx context.Context, ldapService *service.L
|
|||||||
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
|
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
|
||||||
|
|
||||||
// Register the job to run every hour
|
// Register the job to run every hour
|
||||||
return s.registerJob(ctx, "SyncLdap", gocron.DurationJob(time.Hour), jobs.syncLdap, true)
|
return s.RegisterJob(ctx, "SyncLdap", gocron.DurationJob(time.Hour), jobs.syncLdap, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *LdapJobs) syncLdap(ctx context.Context) error {
|
func (j *LdapJobs) syncLdap(ctx context.Context) error {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package job
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
@@ -24,6 +25,26 @@ func NewScheduler() (*Scheduler, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) RemoveJob(name string) error {
|
||||||
|
jobs := s.scheduler.Jobs()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, job := range jobs {
|
||||||
|
if job.Name() == name {
|
||||||
|
err := s.scheduler.RemoveJob(job.ID())
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("failed to unqueue job %q with ID %q: %w", name, job.ID().String(), err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Run the scheduler.
|
// Run the scheduler.
|
||||||
// This function blocks until the context is canceled.
|
// This function blocks until the context is canceled.
|
||||||
func (s *Scheduler) Run(ctx context.Context) error {
|
func (s *Scheduler) Run(ctx context.Context) error {
|
||||||
@@ -43,9 +64,10 @@ func (s *Scheduler) Run(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scheduler) registerJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool) error {
|
func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error {
|
||||||
jobOptions := []gocron.JobOption{
|
jobOptions := []gocron.JobOption{
|
||||||
gocron.WithContext(ctx),
|
gocron.WithContext(ctx),
|
||||||
|
gocron.WithName(name),
|
||||||
gocron.WithEventListeners(
|
gocron.WithEventListeners(
|
||||||
gocron.BeforeJobRuns(func(jobID uuid.UUID, jobName string) {
|
gocron.BeforeJobRuns(func(jobID uuid.UUID, jobName string) {
|
||||||
slog.Info("Starting job",
|
slog.Info("Starting job",
|
||||||
@@ -73,6 +95,8 @@ func (s *Scheduler) registerJob(ctx context.Context, name string, def gocron.Job
|
|||||||
jobOptions = append(jobOptions, gocron.JobOption(gocron.WithStartImmediately()))
|
jobOptions = append(jobOptions, gocron.JobOption(gocron.WithStartImmediately()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
jobOptions = append(jobOptions, extraOptions...)
|
||||||
|
|
||||||
_, err := s.scheduler.NewJob(def, gocron.NewTask(job), jobOptions...)
|
_, err := s.scheduler.NewJob(def, gocron.NewTask(job), jobOptions...)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
25
backend/internal/job/scim_job.go
Normal file
25
backend/internal/job/scim_job.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package job
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-co-op/gocron/v2"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ScimJobs struct {
|
||||||
|
scimService *service.ScimService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) RegisterScimJobs(ctx context.Context, scimService *service.ScimService) error {
|
||||||
|
jobs := &ScimJobs{scimService: scimService}
|
||||||
|
|
||||||
|
// Register the job to run every hour
|
||||||
|
return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(time.Hour), jobs.SyncScim, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *ScimJobs) SyncScim(ctx context.Context) error {
|
||||||
|
return j.scimService.SyncAll(ctx)
|
||||||
|
}
|
||||||
@@ -56,6 +56,7 @@ type OidcService struct {
|
|||||||
auditLogService *AuditLogService
|
auditLogService *AuditLogService
|
||||||
customClaimService *CustomClaimService
|
customClaimService *CustomClaimService
|
||||||
webAuthnService *WebAuthnService
|
webAuthnService *WebAuthnService
|
||||||
|
scimService *ScimService
|
||||||
|
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
jwkCache *jwk.Cache
|
jwkCache *jwk.Cache
|
||||||
@@ -70,6 +71,7 @@ func NewOidcService(
|
|||||||
auditLogService *AuditLogService,
|
auditLogService *AuditLogService,
|
||||||
customClaimService *CustomClaimService,
|
customClaimService *CustomClaimService,
|
||||||
webAuthnService *WebAuthnService,
|
webAuthnService *WebAuthnService,
|
||||||
|
scimService *ScimService,
|
||||||
httpClient *http.Client,
|
httpClient *http.Client,
|
||||||
fileStorage storage.FileStorage,
|
fileStorage storage.FileStorage,
|
||||||
) (s *OidcService, err error) {
|
) (s *OidcService, err error) {
|
||||||
@@ -80,6 +82,7 @@ func NewOidcService(
|
|||||||
auditLogService: auditLogService,
|
auditLogService: auditLogService,
|
||||||
customClaimService: customClaimService,
|
customClaimService: customClaimService,
|
||||||
webAuthnService: webAuthnService,
|
webAuthnService: webAuthnService,
|
||||||
|
scimService: scimService,
|
||||||
httpClient: httpClient,
|
httpClient: httpClient,
|
||||||
fileStorage: fileStorage,
|
fileStorage: fileStorage,
|
||||||
}
|
}
|
||||||
@@ -1088,6 +1091,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in
|
|||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,136 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ScimSchedulerService schedules and triggers periodic synchronization
|
|
||||||
// of SCIM service providers. Each provider is tracked independently,
|
|
||||||
// and sync operations are run at or after their scheduled time.
|
|
||||||
type ScimSchedulerService struct {
|
|
||||||
scimService *ScimService
|
|
||||||
providerSyncTime map[string]time.Time
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewScimSchedulerService(ctx context.Context, scimService *ScimService) (*ScimSchedulerService, error) {
|
|
||||||
s := &ScimSchedulerService{
|
|
||||||
scimService: scimService,
|
|
||||||
providerSyncTime: make(map[string]time.Time),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.start(ctx)
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScheduleSync forces the given provider to be synced soon by
|
|
||||||
// moving its next scheduled time to 5 minutes from now.
|
|
||||||
func (s *ScimSchedulerService) ScheduleSync(providerID string) {
|
|
||||||
s.setSyncTime(providerID, 5*time.Minute)
|
|
||||||
}
|
|
||||||
|
|
||||||
// start initializes the scheduler and begins the synchronization loop.
|
|
||||||
// Syncs happen every hour by default, but ScheduleSync can be called to schedule a sync sooner.
|
|
||||||
func (s *ScimSchedulerService) start(ctx context.Context) error {
|
|
||||||
if err := s.refreshProviders(ctx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
const (
|
|
||||||
syncCheckInterval = 5 * time.Second
|
|
||||||
providerRefreshDelay = time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
ticker := time.NewTicker(syncCheckInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
lastProviderRefresh := time.Now()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
// Runs every 5 seconds to check if any provider is due for sync
|
|
||||||
case <-ticker.C:
|
|
||||||
now := time.Now()
|
|
||||||
if now.Sub(lastProviderRefresh) >= providerRefreshDelay {
|
|
||||||
err := s.refreshProviders(ctx)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("Error refreshing SCIM service providers",
|
|
||||||
slog.Any("error", err),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
lastProviderRefresh = now
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var due []string
|
|
||||||
s.mu.RLock()
|
|
||||||
for providerID, syncTime := range s.providerSyncTime {
|
|
||||||
if !syncTime.After(now) {
|
|
||||||
due = append(due, providerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
|
|
||||||
s.syncProviders(ctx, due)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ScimSchedulerService) refreshProviders(ctx context.Context) error {
|
|
||||||
providers, err := s.scimService.ListServiceProviders(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
inAHour := time.Now().Add(time.Hour)
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
for _, provider := range providers {
|
|
||||||
if _, exists := s.providerSyncTime[provider.ID]; !exists {
|
|
||||||
s.providerSyncTime[provider.ID] = inAHour
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ScimSchedulerService) syncProviders(ctx context.Context, providerIDs []string) {
|
|
||||||
for _, providerID := range providerIDs {
|
|
||||||
err := s.scimService.SyncServiceProvider(ctx, providerID)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
// Remove the provider from the schedule if it no longer exists
|
|
||||||
s.mu.Lock()
|
|
||||||
delete(s.providerSyncTime, providerID)
|
|
||||||
s.mu.Unlock()
|
|
||||||
} else {
|
|
||||||
slog.Error("Error syncing SCIM client",
|
|
||||||
slog.String("provider_id", providerID),
|
|
||||||
slog.Any("error", err),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// A successful sync schedules the next sync in an hour
|
|
||||||
s.setSyncTime(providerID, time.Hour)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ScimSchedulerService) setSyncTime(providerID string, t time.Duration) {
|
|
||||||
s.mu.Lock()
|
|
||||||
s.providerSyncTime[providerID] = time.Now().Add(t)
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-co-op/gocron/v2"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
@@ -32,6 +33,11 @@ const scimErrorBodyLimit = 4096
|
|||||||
|
|
||||||
type scimSyncAction int
|
type scimSyncAction int
|
||||||
|
|
||||||
|
type Scheduler interface {
|
||||||
|
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error
|
||||||
|
RemoveJob(name string) error
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
scimActionNone scimSyncAction = iota
|
scimActionNone scimSyncAction = iota
|
||||||
scimActionCreated
|
scimActionCreated
|
||||||
@@ -48,15 +54,16 @@ type scimSyncStats struct {
|
|||||||
// ScimService handles SCIM provisioning to external service providers.
|
// ScimService handles SCIM provisioning to external service providers.
|
||||||
type ScimService struct {
|
type ScimService struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
scheduler Scheduler
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewScimService(db *gorm.DB, httpClient *http.Client) *ScimService {
|
func NewScimService(db *gorm.DB, scheduler Scheduler, httpClient *http.Client) *ScimService {
|
||||||
if httpClient == nil {
|
if httpClient == nil {
|
||||||
httpClient = &http.Client{Timeout: 20 * time.Second}
|
httpClient = &http.Client{Timeout: 20 * time.Second}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ScimService{db: db, httpClient: httpClient}
|
return &ScimService{db: db, scheduler: scheduler, httpClient: httpClient}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ScimService) GetServiceProvider(
|
func (s *ScimService) GetServiceProvider(
|
||||||
@@ -132,6 +139,41 @@ func (s *ScimService) DeleteServiceProvider(ctx context.Context, serviceProvider
|
|||||||
Error
|
Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:contextcheck
|
||||||
|
func (s *ScimService) ScheduleSync() {
|
||||||
|
jobName := "ScheduledScimSync"
|
||||||
|
start := time.Now().Add(5 * time.Minute)
|
||||||
|
|
||||||
|
_ = s.scheduler.RemoveJob(jobName)
|
||||||
|
|
||||||
|
err := s.scheduler.RegisterJob(
|
||||||
|
context.Background(), jobName,
|
||||||
|
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, false)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to schedule SCIM sync", slog.Any("error", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScimService) SyncAll(ctx context.Context) error {
|
||||||
|
providers, err := s.ListServiceProviders(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, provider := range providers {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
errs = append(errs, ctx.Err())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := s.SyncServiceProvider(ctx, provider.ID); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID string) error {
|
func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID string) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
provider, err := s.GetServiceProvider(ctx, serviceProviderID)
|
provider, err := s.GetServiceProvider(ctx, serviceProviderID)
|
||||||
|
|||||||
@@ -16,11 +16,12 @@ import (
|
|||||||
|
|
||||||
type UserGroupService struct {
|
type UserGroupService struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
scimService *ScimService
|
||||||
appConfigService *AppConfigService
|
appConfigService *AppConfigService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserGroupService {
|
func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService, scimService *ScimService) *UserGroupService {
|
||||||
return &UserGroupService{db: db, appConfigService: appConfigService}
|
return &UserGroupService{db: db, appConfigService: appConfigService, scimService: scimService}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserGroupService) List(ctx context.Context, name string, listRequestOptions utils.ListRequestOptions) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
func (s *UserGroupService) List(ctx context.Context, name string, listRequestOptions utils.ListRequestOptions) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
|
||||||
@@ -90,7 +91,13 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit().Error
|
err = tx.Commit().Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
|
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
|
||||||
@@ -118,6 +125,8 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
|
|||||||
}
|
}
|
||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +174,8 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
|
|||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,6 +238,7 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u
|
|||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,5 +315,6 @@ func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id strin
|
|||||||
return model.UserGroup{}, err
|
return model.UserGroup{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,10 +37,11 @@ type UserService struct {
|
|||||||
appConfigService *AppConfigService
|
appConfigService *AppConfigService
|
||||||
customClaimService *CustomClaimService
|
customClaimService *CustomClaimService
|
||||||
appImagesService *AppImagesService
|
appImagesService *AppImagesService
|
||||||
|
scimService *ScimService
|
||||||
fileStorage storage.FileStorage
|
fileStorage storage.FileStorage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService, fileStorage storage.FileStorage) *UserService {
|
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService, scimService *ScimService, fileStorage storage.FileStorage) *UserService {
|
||||||
return &UserService{
|
return &UserService{
|
||||||
db: db,
|
db: db,
|
||||||
jwtService: jwtService,
|
jwtService: jwtService,
|
||||||
@@ -49,6 +50,7 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
|
|||||||
appConfigService: appConfigService,
|
appConfigService: appConfigService,
|
||||||
customClaimService: customClaimService,
|
customClaimService: customClaimService,
|
||||||
appImagesService: appImagesService,
|
appImagesService: appImagesService,
|
||||||
|
scimService: scimService,
|
||||||
fileStorage: fileStorage,
|
fileStorage: fileStorage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -226,6 +228,7 @@ func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userI
|
|||||||
return fmt.Errorf("failed to delete user: %w", err)
|
return fmt.Errorf("failed to delete user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,6 +312,7 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -447,6 +451,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
|||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -663,6 +668,7 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
|
|||||||
return model.User{}, err
|
return model.User{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -753,12 +759,19 @@ func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
|
func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
|
||||||
return tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Model(&model.User{}).
|
Model(&model.User{}).
|
||||||
Where("id = ?", userID).
|
Where("id = ?", userID).
|
||||||
Update("disabled", true).
|
Update("disabled", true).
|
||||||
Error
|
Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.scimService.ScheduleSync()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int, userGroupIDs []string) (model.SignupToken, error) {
|
func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int, userGroupIDs []string) (model.SignupToken, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user