1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 20:15:07 +00:00

fix: various fixes in background jobs (#1362)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2026-03-07 09:07:26 -08:00
committed by GitHub
parent f4eb8db509
commit 2f56d16f98
21 changed files with 343 additions and 115 deletions

View File

@@ -28,7 +28,7 @@ func (s *Scheduler) RegisterAnalyticsJob(ctx context.Context, appConfig *service
appConfig: appConfig, appConfig: appConfig,
httpClient: httpClient, httpClient: httpClient,
} }
return s.RegisterJob(ctx, "SendHeartbeat", gocron.DurationJob(24*time.Hour), jobs.sendHeartbeat, true) return s.RegisterJob(ctx, "SendHeartbeat", gocron.DurationJob(24*time.Hour), jobs.sendHeartbeat, service.RegisterJobOpts{RunImmediately: true})
} }
type AnalyticsJob struct { type AnalyticsJob struct {

View File

@@ -22,7 +22,7 @@ func (s *Scheduler) RegisterApiKeyExpiryJob(ctx context.Context, apiKeyService *
} }
// Send every day at midnight // Send every day at midnight
return s.RegisterJob(ctx, "ExpiredApiKeyEmailJob", gocron.CronJob("0 0 * * *", false), jobs.checkAndNotifyExpiringApiKeys, false) return s.RegisterJob(ctx, "ExpiredApiKeyEmailJob", gocron.CronJob("0 0 * * *", false), jobs.checkAndNotifyExpiringApiKeys, service.RegisterJobOpts{})
} }
func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error { func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error {
@@ -42,7 +42,11 @@ func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) err
} }
err = j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key) err = j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Failed to send expiring API key notification email", slog.String("key", key.ID), slog.Any("error", err)) slog.ErrorContext(ctx, "Failed to send expiring API key notification email",
slog.String("key", key.ID),
slog.String("user", key.User.ID),
slog.Any("error", err),
)
} }
} }
return nil return nil

View File

@@ -7,28 +7,37 @@ import (
"log/slog" "log/slog"
"time" "time"
"github.com/go-co-op/gocron/v2" backoff "github.com/cenkalti/backoff/v5"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/service"
) )
func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) error { func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) error {
jobs := &DbCleanupJobs{db: db} jobs := &DbCleanupJobs{db: db}
// Run every 24 hours (but with some jitter so they don't run at the exact same time), and now newBackOff := func() *backoff.ExponentialBackOff {
def := gocron.DurationRandomJob(24*time.Hour-2*time.Minute, 24*time.Hour+2*time.Minute) bo := backoff.NewExponentialBackOff()
bo.Multiplier = 4
bo.RandomizationFactor = 0.1
bo.InitialInterval = time.Second
bo.MaxInterval = 45 * time.Second
return bo
}
// Use exponential backoff for each DB cleanup job so transient query failures are retried automatically rather than causing an immediate job failure
return errors.Join( return errors.Join(
s.RegisterJob(ctx, "ClearWebauthnSessions", def, jobs.clearWebauthnSessions, true), s.RegisterJob(ctx, "ClearWebauthnSessions", jobDefWithJitter(24*time.Hour), jobs.clearWebauthnSessions, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearOneTimeAccessTokens", def, jobs.clearOneTimeAccessTokens, true), s.RegisterJob(ctx, "ClearOneTimeAccessTokens", jobDefWithJitter(24*time.Hour), jobs.clearOneTimeAccessTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearSignupTokens", def, jobs.clearSignupTokens, true), s.RegisterJob(ctx, "ClearSignupTokens", jobDefWithJitter(24*time.Hour), jobs.clearSignupTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearEmailVerificationTokens", def, jobs.clearEmailVerificationTokens, true), s.RegisterJob(ctx, "ClearEmailVerificationTokens", jobDefWithJitter(24*time.Hour), jobs.clearEmailVerificationTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearOidcAuthorizationCodes", def, jobs.clearOidcAuthorizationCodes, true), s.RegisterJob(ctx, "ClearOidcAuthorizationCodes", jobDefWithJitter(24*time.Hour), jobs.clearOidcAuthorizationCodes, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearOidcRefreshTokens", def, jobs.clearOidcRefreshTokens, true), s.RegisterJob(ctx, "ClearOidcRefreshTokens", jobDefWithJitter(24*time.Hour), jobs.clearOidcRefreshTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearReauthenticationTokens", def, jobs.clearReauthenticationTokens, true), s.RegisterJob(ctx, "ClearReauthenticationTokens", jobDefWithJitter(24*time.Hour), jobs.clearReauthenticationTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
s.RegisterJob(ctx, "ClearAuditLogs", def, jobs.clearAuditLogs, true), s.RegisterJob(ctx, "ClearAuditLogs", jobDefWithJitter(24*time.Hour), jobs.clearAuditLogs, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
) )
} }

View File

@@ -13,20 +13,26 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/storage" "github.com/pocket-id/pocket-id/backend/internal/storage"
) )
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error { func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error {
jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage} jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage}
err := s.RegisterJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false) var errs []error
errs = append(errs,
s.RegisterJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, service.RegisterJobOpts{}),
)
// Only necessary for file system storage // Only necessary for file system storage
if fileStorage.Type() == storage.TypeFileSystem { if fileStorage.Type() == storage.TypeFileSystem {
err = errors.Join(err, s.RegisterJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, true)) errs = append(errs,
s.RegisterJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, service.RegisterJobOpts{RunImmediately: true}),
)
} }
return err return errors.Join(errs...)
} }
type FileCleanupJobs struct { type FileCleanupJobs struct {
@@ -68,7 +74,8 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context)
// If these initials aren't used by any user, delete the file // If these initials aren't used by any user, delete the file
if _, ok := initialsInUse[initials]; !ok { if _, ok := initialsInUse[initials]; !ok {
filePath := path.Join(defaultPicturesDir, filename) filePath := path.Join(defaultPicturesDir, filename)
if err := j.fileStorage.Delete(ctx, filePath); err != nil { err = j.fileStorage.Delete(ctx, filePath)
if err != nil {
slog.ErrorContext(ctx, "Failed to delete unused default profile picture", slog.String("path", filePath), slog.Any("error", err)) slog.ErrorContext(ctx, "Failed to delete unused default profile picture", slog.String("path", filePath), slog.Any("error", err))
} else { } else {
filesDeleted++ filesDeleted++
@@ -95,8 +102,9 @@ func (j *FileCleanupJobs) clearOrphanedTempFiles(ctx context.Context) error {
return nil return nil
} }
if err := j.fileStorage.Delete(ctx, p.Path); err != nil { rErr := j.fileStorage.Delete(ctx, p.Path)
slog.ErrorContext(ctx, "Failed to delete temp file", slog.String("path", p.Path), slog.Any("error", err)) if rErr != nil {
slog.ErrorContext(ctx, "Failed to delete temp file", slog.String("path", p.Path), slog.Any("error", rErr))
return nil return nil
} }
deleted++ deleted++

View File

@@ -23,7 +23,7 @@ func (s *Scheduler) RegisterGeoLiteUpdateJobs(ctx context.Context, geoLiteServic
jobs := &GeoLiteUpdateJobs{geoLiteService: geoLiteService} jobs := &GeoLiteUpdateJobs{geoLiteService: geoLiteService}
// Run every 24 hours (and right away) // Run every 24 hours (and right away)
return s.RegisterJob(ctx, "UpdateGeoLiteDB", gocron.DurationJob(24*time.Hour), jobs.updateGoeLiteDB, true) return s.RegisterJob(ctx, "UpdateGeoLiteDB", gocron.DurationJob(24*time.Hour), jobs.updateGoeLiteDB, service.RegisterJobOpts{RunImmediately: true})
} }
func (j *GeoLiteUpdateJobs) updateGoeLiteDB(ctx context.Context) error { func (j *GeoLiteUpdateJobs) updateGoeLiteDB(ctx context.Context) error {

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"time" "time"
"github.com/go-co-op/gocron/v2"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )
@@ -17,8 +15,8 @@ type LdapJobs struct {
func (s *Scheduler) RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) error { func (s *Scheduler) RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) error {
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService} jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
// Register the job to run every hour // Register the job to run every hour (with some jitter)
return s.RegisterJob(ctx, "SyncLdap", gocron.DurationJob(time.Hour), jobs.syncLdap, true) return s.RegisterJob(ctx, "SyncLdap", jobDefWithJitter(time.Hour), jobs.syncLdap, service.RegisterJobOpts{RunImmediately: true})
} }
func (j *LdapJobs) syncLdap(ctx context.Context) error { func (j *LdapJobs) syncLdap(ctx context.Context) error {

View File

@@ -5,9 +5,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"time"
backoff "github.com/cenkalti/backoff/v5"
"github.com/go-co-op/gocron/v2" "github.com/go-co-op/gocron/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/service"
) )
type Scheduler struct { type Scheduler struct {
@@ -33,16 +37,12 @@ func (s *Scheduler) RemoveJob(name string) error {
if job.Name() == name { if job.Name() == name {
err := s.scheduler.RemoveJob(job.ID()) err := s.scheduler.RemoveJob(job.ID())
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("failed to unqueue job %q with ID %q: %w", name, job.ID().String(), err)) errs = append(errs, fmt.Errorf("failed to dequeue job %q with ID %q: %w", name, job.ID().String(), err))
} }
} }
} }
if len(errs) > 0 {
return errors.Join(errs...) return errors.Join(errs...)
}
return nil
} }
// Run the scheduler. // Run the scheduler.
@@ -64,7 +64,29 @@ func (s *Scheduler) Run(ctx context.Context) error {
return nil return nil
} }
func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error { func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, jobFn func(ctx context.Context) error, opts service.RegisterJobOpts) error {
// If a BackOff strategy is provided, wrap the job with retry logic
if opts.BackOff != nil {
origJob := jobFn
jobFn = func(ctx context.Context) error {
_, err := backoff.Retry(
ctx,
func() (struct{}, error) {
return struct{}{}, origJob(ctx)
},
backoff.WithBackOff(opts.BackOff),
backoff.WithNotify(func(err error, d time.Duration) {
slog.WarnContext(ctx, "Job failed, retrying",
slog.String("name", name),
slog.Any("error", err),
slog.Duration("retryIn", d),
)
}),
)
return err
}
}
jobOptions := []gocron.JobOption{ jobOptions := []gocron.JobOption{
gocron.WithContext(ctx), gocron.WithContext(ctx),
gocron.WithName(name), gocron.WithName(name),
@@ -91,13 +113,13 @@ func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.Job
), ),
} }
if runImmediately { if opts.RunImmediately {
jobOptions = append(jobOptions, gocron.JobOption(gocron.WithStartImmediately())) jobOptions = append(jobOptions, gocron.JobOption(gocron.WithStartImmediately()))
} }
jobOptions = append(jobOptions, extraOptions...) jobOptions = append(jobOptions, opts.ExtraOptions...)
_, err := s.scheduler.NewJob(def, gocron.NewTask(job), jobOptions...) _, err := s.scheduler.NewJob(def, gocron.NewTask(jobFn), jobOptions...)
if err != nil { if err != nil {
return fmt.Errorf("failed to register job %q: %w", name, err) return fmt.Errorf("failed to register job %q: %w", name, err)
@@ -105,3 +127,9 @@ func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.Job
return nil return nil
} }
// jobDefWithJitter returns a job definition that fires roughly once per
// interval, with each run offset by a random amount so that jobs registered
// with the same interval do not all fire at the same instant.
//
// The jitter is 5 minutes in either direction, but never more than a tenth of
// interval. For intervals of 50 minutes or more the window is therefore
// [interval-5m, interval+5m]. For shorter intervals the window shrinks
// proportionally, so the lower bound stays positive for any positive interval
// instead of dropping to zero or below.
func jobDefWithJitter(interval time.Duration) gocron.JobDefinition {
	const maxJitter = 5 * time.Minute

	jitter := maxJitter
	// Cap the jitter relative to the interval so that interval-jitter cannot
	// become zero or negative when a short interval is requested.
	if limit := interval / 10; limit < jitter {
		jitter = limit
	}

	return gocron.DurationRandomJob(interval-jitter, interval+jitter)
}

View File

@@ -16,8 +16,8 @@ type ScimJobs struct {
func (s *Scheduler) RegisterScimJobs(ctx context.Context, scimService *service.ScimService) error { func (s *Scheduler) RegisterScimJobs(ctx context.Context, scimService *service.ScimService) error {
jobs := &ScimJobs{scimService: scimService} jobs := &ScimJobs{scimService: scimService}
// Register the job to run every hour // Register the job to run every hour
return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(time.Hour), jobs.SyncScim, true) return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(time.Hour), jobs.SyncScim, service.RegisterJobOpts{RunImmediately: true})
} }
func (j *ScimJobs) SyncScim(ctx context.Context) error { func (j *ScimJobs) SyncScim(ctx context.Context) error {

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"time" "time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
@@ -205,36 +206,33 @@ func (s *ApiKeyService) ListExpiringApiKeys(ctx context.Context, daysAhead int)
} }
func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey model.ApiKey) error { func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey model.ApiKey) error {
user := apiKey.User if apiKey.User.Email == nil {
if user.ID == "" {
if err := s.db.WithContext(ctx).First(&user, "id = ?", apiKey.UserID).Error; err != nil {
return err
}
}
if user.Email == nil {
return &common.UserEmailNotSetError{} return &common.UserEmailNotSetError{}
} }
err := SendEmail(ctx, s.emailService, email.Address{ err := SendEmail(ctx, s.emailService, email.Address{
Name: user.FullName(), Name: apiKey.User.FullName(),
Email: *user.Email, Email: *apiKey.User.Email,
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{ }, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
ApiKeyName: apiKey.Name, ApiKeyName: apiKey.Name,
ExpiresAt: apiKey.ExpiresAt.ToTime(), ExpiresAt: apiKey.ExpiresAt.ToTime(),
Name: user.FirstName, Name: apiKey.User.FirstName,
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("error sending notification email: %w", err)
} }
// Mark the API key as having had an expiration email sent // Mark the API key as having had an expiration email sent
return s.db.WithContext(ctx). err = s.db.WithContext(ctx).
Model(&model.ApiKey{}). Model(&model.ApiKey{}).
Where("id = ?", apiKey.ID). Where("id = ?", apiKey.ID).
Update("expiration_email_sent", true). Update("expiration_email_sent", true).
Error Error
if err != nil {
return fmt.Errorf("error recording expiration sent email in database: %w", err)
}
return nil
} }
func (s *ApiKeyService) initStaticApiKeyUser(ctx context.Context) (user model.User, err error) { func (s *ApiKeyService) initStaticApiKeyUser(ctx context.Context) (user model.User, err error) {

View File

@@ -73,7 +73,10 @@ func (lv *lockValue) Unmarshal(raw string) error {
// Acquire obtains the lock. When force is true, the lock is stolen from any existing owner. // Acquire obtains the lock. When force is true, the lock is stolen from any existing owner.
// If the lock is forcefully acquired, it blocks until the previous lock has expired. // If the lock is forcefully acquired, it blocks until the previous lock has expired.
func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) { func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) {
tx := s.db.Begin() tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return time.Time{}, fmt.Errorf("begin lock transaction: %w", tx.Error)
}
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
@@ -174,7 +177,8 @@ func (s *AppLockService) RunRenewal(ctx context.Context) error {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case <-ticker.C: case <-ticker.C:
if err := s.renew(ctx); err != nil { err := s.renew(ctx)
if err != nil {
return fmt.Errorf("renew lock: %w", err) return fmt.Errorf("renew lock: %w", err)
} }
} }
@@ -183,33 +187,43 @@ func (s *AppLockService) RunRenewal(ctx context.Context) error {
// Release releases the lock if it is held by this process. // Release releases the lock if it is held by this process.
func (s *AppLockService) Release(ctx context.Context) error { func (s *AppLockService) Release(ctx context.Context) error {
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second) db, err := s.db.DB()
defer cancel() if err != nil {
return fmt.Errorf("failed to get DB connection: %w", err)
}
var query string var query string
switch s.db.Name() { switch s.db.Name() {
case "sqlite": case "sqlite":
query = ` query = `
DELETE FROM kv DELETE FROM kv
WHERE key = ? WHERE key = ?
AND json_extract(value, '$.lock_id') = ? AND json_extract(value, '$.lock_id') = ?
` `
case "postgres": case "postgres":
query = ` query = `
DELETE FROM kv DELETE FROM kv
WHERE key = $1 WHERE key = $1
AND value::json->>'lock_id' = $2 AND value::json->>'lock_id' = $2
` `
default: default:
return fmt.Errorf("unsupported database dialect: %s", s.db.Name()) return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
} }
res := s.db.WithContext(opCtx).Exec(query, lockKey, s.lockID) opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
if res.Error != nil { defer cancel()
return fmt.Errorf("release lock failed: %w", res.Error)
res, err := db.ExecContext(opCtx, query, lockKey, s.lockID)
if err != nil {
return fmt.Errorf("release lock failed: %w", err)
} }
if res.RowsAffected == 0 { count, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("failed to count affected rows: %w", err)
}
if count == 0 {
slog.Warn("Application lock not held by this process, cannot release", slog.Warn("Application lock not held by this process, cannot release",
slog.Int64("process_id", s.processID), slog.Int64("process_id", s.processID),
slog.String("host_id", s.hostID), slog.String("host_id", s.hostID),
@@ -225,6 +239,11 @@ func (s *AppLockService) Release(ctx context.Context) error {
// renew tries to renew the lock, retrying up to renewRetries times (sleeping 1s between attempts). // renew tries to renew the lock, retrying up to renewRetries times (sleeping 1s between attempts).
func (s *AppLockService) renew(ctx context.Context) error { func (s *AppLockService) renew(ctx context.Context) error {
db, err := s.db.DB()
if err != nil {
return fmt.Errorf("failed to get DB connection: %w", err)
}
var lastErr error var lastErr error
for attempt := 1; attempt <= renewRetries; attempt++ { for attempt := 1; attempt <= renewRetries; attempt++ {
now := time.Now() now := time.Now()
@@ -246,42 +265,56 @@ func (s *AppLockService) renew(ctx context.Context) error {
switch s.db.Name() { switch s.db.Name() {
case "sqlite": case "sqlite":
query = ` query = `
UPDATE kv UPDATE kv
SET value = ? SET value = ?
WHERE key = ? WHERE key = ?
AND json_extract(value, '$.lock_id') = ? AND json_extract(value, '$.lock_id') = ?
AND json_extract(value, '$.expires_at') > ? AND json_extract(value, '$.expires_at') > ?
` `
case "postgres": case "postgres":
query = ` query = `
UPDATE kv UPDATE kv
SET value = $1 SET value = $1
WHERE key = $2 WHERE key = $2
AND value::json->>'lock_id' = $3 AND value::json->>'lock_id' = $3
AND ((value::json->>'expires_at')::bigint > $4) AND ((value::json->>'expires_at')::bigint > $4)
` `
default: default:
return fmt.Errorf("unsupported database dialect: %s", s.db.Name()) return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
} }
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second) opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
res := s.db.WithContext(opCtx).Exec(query, raw, lockKey, s.lockID, nowUnix) res, err := db.ExecContext(opCtx, query, raw, lockKey, s.lockID, nowUnix)
cancel() cancel()
switch { // Query succeeded, but may have updated 0 rows
case res.Error != nil: if err == nil {
lastErr = fmt.Errorf("lock renewal failed: %w", res.Error) count, err := res.RowsAffected()
case res.RowsAffected == 0: if err != nil {
// Must be after checking res.Error return fmt.Errorf("failed to count affected rows: %w", err)
}
// If no rows were updated, we lost the lock
if count == 0 {
return ErrLockLost return ErrLockLost
default: }
// All good
slog.Debug("Renewed application lock", slog.Debug("Renewed application lock",
slog.Int64("process_id", s.processID), slog.Int64("process_id", s.processID),
slog.String("host_id", s.hostID), slog.String("host_id", s.hostID),
slog.Duration("duration", time.Since(now)),
) )
return nil return nil
} }
// If we're here, we have an error that can be retried
slog.Debug("Application lock renewal attempt failed",
slog.Any("error", err),
slog.Duration("duration", time.Since(now)),
)
lastErr = fmt.Errorf("lock renewal failed: %w", err)
// Wait before next attempt or cancel if context is done // Wait before next attempt or cancel if context is done
if attempt < renewRetries { if attempt < renewRetries {
select { select {

View File

@@ -49,6 +49,23 @@ func readLockValue(t *testing.T, db *gorm.DB) lockValue {
return value return value
} }
// lockDatabaseForWrite begins a transaction on db, performs a write to the kv
// table inside it, and returns the transaction while it is still open.
// The test fails immediately if the transaction cannot be started or the write
// returns an error. The caller owns the returned transaction and is expected
// to end it (the tests using this helper defer a Rollback).
func lockDatabaseForWrite(t *testing.T, db *gorm.DB) *gorm.DB {
	t.Helper()

	const insertSQL = `INSERT INTO kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO NOTHING`

	writeTx := db.Begin()
	require.NoError(t, writeTx.Error)

	// Leaving this write uncommitted keeps a write transaction open so that
	// other queries are blocked until the caller ends it.
	res := writeTx.Exec(insertSQL, lockKey, `{"expires_at":0}`)
	require.NoError(t, res.Error)

	return writeTx
}
func TestAppLockServiceAcquire(t *testing.T) { func TestAppLockServiceAcquire(t *testing.T) {
t.Run("creates new lock when none exists", func(t *testing.T) { t.Run("creates new lock when none exists", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t) db := testutils.NewDatabaseForTest(t)
@@ -99,6 +116,66 @@ func TestAppLockServiceAcquire(t *testing.T) {
require.Equal(t, service.hostID, stored.HostID) require.Equal(t, service.hostID, stored.HostID)
require.Greater(t, stored.ExpiresAt, time.Now().Unix()) require.Greater(t, stored.ExpiresAt, time.Now().Unix())
}) })
t.Run("force acquisition returns wait duration when stealing active lock", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
existing := lockValue{
ProcessID: 99,
HostID: "other-host",
LockID: "other-lock-id",
ExpiresAt: time.Now().Add(ttl).Unix(),
}
insertLock(t, db, existing)
waitUntil, err := service.Acquire(context.Background(), true)
require.NoError(t, err)
require.WithinDuration(t, time.Unix(existing.ExpiresAt, 0), waitUntil, time.Second)
})
t.Run("force acquisition does not wait when lock id is unchanged", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
insertLock(t, db, lockValue{
ProcessID: 99,
HostID: "other-host",
LockID: service.lockID,
ExpiresAt: time.Now().Add(ttl).Unix(),
})
waitUntil, err := service.Acquire(context.Background(), true)
require.NoError(t, err)
require.True(t, waitUntil.IsZero())
})
t.Run("returns error when existing lock value is invalid JSON", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
raw := "this-is-not-json"
err := db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
require.NoError(t, err)
_, err = service.Acquire(context.Background(), false)
require.ErrorContains(t, err, "decode existing lock value")
})
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
tx := lockDatabaseForWrite(t, db)
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
_, err := service.Acquire(ctx, false)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.ErrorContains(t, err, "begin lock transaction")
})
} }
func TestAppLockServiceRelease(t *testing.T) { func TestAppLockServiceRelease(t *testing.T) {
@@ -134,6 +211,24 @@ func TestAppLockServiceRelease(t *testing.T) {
stored := readLockValue(t, db) stored := readLockValue(t, db)
require.Equal(t, existing, stored) require.Equal(t, existing, stored)
}) })
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
_, err := service.Acquire(context.Background(), false)
require.NoError(t, err)
tx := lockDatabaseForWrite(t, db)
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
err = service.Release(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.ErrorContains(t, err, "release lock failed")
})
} }
func TestAppLockServiceRenew(t *testing.T) { func TestAppLockServiceRenew(t *testing.T) {
@@ -186,4 +281,21 @@ func TestAppLockServiceRenew(t *testing.T) {
err = service.renew(context.Background()) err = service.renew(context.Background())
require.ErrorIs(t, err, ErrLockLost) require.ErrorIs(t, err, ErrLockLost)
}) })
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
_, err := service.Acquire(context.Background(), false)
require.NoError(t, err)
tx := lockDatabaseForWrite(t, db)
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
err = service.renew(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
} }

View File

@@ -150,7 +150,8 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
} }
// Send the email // Send the email
if err := srv.sendEmailContent(client, toEmail, c); err != nil { err = srv.sendEmailContent(client, toEmail, c)
if err != nil {
return fmt.Errorf("send email content: %w", err) return fmt.Errorf("send email content: %w", err)
} }

View File

@@ -0,0 +1,25 @@
package service
import (
"context"
backoff "github.com/cenkalti/backoff/v5"
"github.com/go-co-op/gocron/v2"
)
// RegisterJobOpts holds optional configuration for registering a scheduled job.
// The zero value is usable: callers pass RegisterJobOpts{} to register a job
// with no immediate run, no extra gocron options, and no retry wrapper.
type RegisterJobOpts struct {
// RunImmediately runs the job immediately after registration.
// When true, RegisterJob adds gocron.WithStartImmediately() to the job options.
RunImmediately bool
// ExtraOptions are additional gocron job options.
// They are appended after the options RegisterJob sets itself.
ExtraOptions []gocron.JobOption
// BackOff is an optional backoff strategy. If non-nil, the job will be wrapped
// with automatic retry logic using the provided backoff on transient failures.
// Each retry is logged as a warning together with the delay before the next attempt.
// NOTE(review): callers in this codebase build a separate BackOff value per job;
// confirm whether an instance may safely be shared between jobs before doing so.
BackOff backoff.BackOff
}
// Scheduler is an interface for registering and managing background jobs.
// It lets services schedule jobs without depending on a concrete scheduler type.
type Scheduler interface {
// RegisterJob schedules job under the given name according to def, applying
// the optional settings in opts. It returns an error if the job cannot be registered.
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, opts RegisterJobOpts) error
// RemoveJob removes the job(s) registered under name.
// NOTE(review): the concrete scheduler removes every job whose name matches and
// joins any removal errors; confirm other implementations behave the same way.
RemoveJob(name string) error
}

View File

@@ -34,11 +34,6 @@ const scimErrorBodyLimit = 4096
type scimSyncAction int type scimSyncAction int
type Scheduler interface {
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error
RemoveJob(name string) error
}
const ( const (
scimActionNone scimSyncAction = iota scimActionNone scimSyncAction = iota
scimActionCreated scimActionCreated
@@ -149,7 +144,7 @@ func (s *ScimService) ScheduleSync() {
err := s.scheduler.RegisterJob( err := s.scheduler.RegisterJob(
context.Background(), jobName, context.Background(), jobName,
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, false) gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, RegisterJobOpts{})
if err != nil { if err != nil {
slog.Error("Failed to schedule SCIM sync", slog.Any("error", err)) slog.Error("Failed to schedule SCIM sync", slog.Any("error", err))
@@ -168,7 +163,8 @@ func (s *ScimService) SyncAll(ctx context.Context) error {
errs = append(errs, ctx.Err()) errs = append(errs, ctx.Err())
break break
} }
if err := s.SyncServiceProvider(ctx, provider.ID); err != nil { err = s.SyncServiceProvider(ctx, provider.ID)
if err != nil {
errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err)) errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err))
} }
} }
@@ -210,26 +206,20 @@ func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID
} }
var errs []error var errs []error
var userStats scimSyncStats
var groupStats scimSyncStats
// Sync users first, so that groups can reference them // Sync users first, so that groups can reference them
if stats, err := s.syncUsers(ctx, provider, users, &userResources); err != nil { userStats, err := s.syncUsers(ctx, provider, users, &userResources)
errs = append(errs, err) if err != nil {
userStats = stats errs = append(errs, err)
} else { }
userStats = stats
} groupStats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
stats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
groupStats = stats
} else {
groupStats = stats
} }
if len(errs) > 0 { if len(errs) > 0 {
err = errors.Join(errs...)
slog.WarnContext(ctx, "SCIM sync completed with errors", slog.WarnContext(ctx, "SCIM sync completed with errors",
slog.String("provider_id", provider.ID), slog.String("provider_id", provider.ID),
slog.Int("error_count", len(errs)), slog.Int("error_count", len(errs)),
@@ -240,12 +230,14 @@ func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID
slog.Int("groups_updated", groupStats.Updated), slog.Int("groups_updated", groupStats.Updated),
slog.Int("groups_deleted", groupStats.Deleted), slog.Int("groups_deleted", groupStats.Deleted),
slog.Duration("duration", time.Since(start)), slog.Duration("duration", time.Since(start)),
slog.Any("error", err),
) )
return errors.Join(errs...) return err
} }
provider.LastSyncedAt = new(datatype.DateTime(time.Now())) provider.LastSyncedAt = new(datatype.DateTime(time.Now()))
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil { err = s.db.WithContext(ctx).Save(&provider).Error
if err != nil {
return err return err
} }
@@ -273,7 +265,7 @@ func (s *ScimService) syncUsers(
// Update or create users // Update or create users
for _, u := range users { for _, u := range users {
existing := getResourceByExternalID[dto.ScimUser](u.ID, resourceList.Resources) existing := getResourceByExternalID(u.ID, resourceList.Resources)
action, created, err := s.syncUser(ctx, provider, u, existing) action, created, err := s.syncUser(ctx, provider, u, existing)
if created != nil && existing == nil { if created != nil && existing == nil {
@@ -434,7 +426,7 @@ func (s *ScimService) syncGroup(
// Prepare group members // Prepare group members
members := make([]dto.ScimGroupMember, len(group.Users)) members := make([]dto.ScimGroupMember, len(group.Users))
for i, user := range group.Users { for i, user := range group.Users {
userResource := getResourceByExternalID[dto.ScimUser](user.ID, userResources) userResource := getResourceByExternalID(user.ID, userResources)
if userResource == nil { if userResource == nil {
// Groups depend on user IDs already being provisioned // Groups depend on user IDs already being provisioned
return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID) return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID)

View File

@@ -1 +1 @@
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="background-color:#FBFBFB"><!--$--><!--html--><!--head--><!--body--><table border="0" width="100%" cellPadding="0" cellSpacing="0" role="presentation" align="center"><tbody><tr><td style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px"><img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">API Key Expiring Soon</h1></td><td align="right" data-id="__react-email-column"><p 
style="font-size:12px;line-height:24px;background-color:#ffd966;color:#7f6000;padding:1px 12px;border-radius:50px;display:inline-block;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Warning</p></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Hello <!-- -->{{.Data.Name}}<!-- -->, <br/>This is a reminder that your API key <strong>{{.Data.APIKeyName}}</strong> <!-- -->will expire on <strong>{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}</strong>.</p><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Please generate a new API key if you need continued access.</p></div></td></tr></tbody></table></td></tr></tbody></table><!--/$--></body></html>{{end}} {{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="background-color:#FBFBFB"><!--$--><!--html--><!--head--><!--body--><table border="0" width="100%" cellPadding="0" cellSpacing="0" role="presentation" align="center"><tbody><tr><td style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px"><img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" 
style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">API Key Expiring Soon</h1></td><td align="right" data-id="__react-email-column"><p style="font-size:12px;line-height:24px;background-color:#ffd966;color:#7f6000;padding:1px 12px;border-radius:50px;display:inline-block;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Warning</p></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Hello <!-- -->{{.Data.Name}}<!-- -->, <br/>This is a reminder that your API key <strong>{{.Data.ApiKeyName}}</strong> <!-- -->will expire on <strong>{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}</strong>.</p><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Please generate a new API key if you need continued access.</p></div></td></tr></tbody></table></td></tr></tbody></table><!--/$--></body></html>{{end}}

View File

@@ -6,6 +6,6 @@ API KEY EXPIRING SOON
Warning Warning
Hello {{.Data.Name}}, Hello {{.Data.Name}},
This is a reminder that your API key {{.Data.APIKeyName}} will expire on {{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}. This is a reminder that your API key {{.Data.ApiKeyName}} will expire on {{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}.
Please generate a new API key if you need continued access.{{end}} Please generate a new API key if you need continued access.{{end}}

View File

@@ -0,0 +1 @@
-- No-op

View File

@@ -0,0 +1,6 @@
-- Adds an index on the expires_at column of each table below.
-- NOTE(review): presumably these support the background jobs that look up or
-- delete expired rows by expires_at -- confirm against the cleanup queries.
-- IF NOT EXISTS lets the migration run without error when an index with the
-- same name is already present.
CREATE INDEX IF NOT EXISTS idx_webauthn_sessions_expires_at ON webauthn_sessions (expires_at);
CREATE INDEX IF NOT EXISTS idx_one_time_access_tokens_expires_at ON one_time_access_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_oidc_authorization_codes_expires_at ON oidc_authorization_codes (expires_at);
CREATE INDEX IF NOT EXISTS idx_oidc_refresh_tokens_expires_at ON oidc_refresh_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_reauthentication_tokens_expires_at ON reauthentication_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_expires_at ON email_verification_tokens (expires_at);

View File

@@ -0,0 +1 @@
-- No-op

View File

@@ -0,0 +1,12 @@
-- Adds an index on the expires_at column of each table below.
-- NOTE(review): presumably these support the background jobs that look up or
-- delete expired rows by expires_at -- confirm against the cleanup queries.

-- Foreign key enforcement is switched off before the transaction starts and
-- switched back on after it commits; the PRAGMA statements are deliberately
-- kept outside BEGIN/COMMIT.
PRAGMA foreign_keys= OFF;
BEGIN;
-- All indexes are created in a single transaction. IF NOT EXISTS lets the
-- migration run without error when an index with the same name already exists.
CREATE INDEX IF NOT EXISTS idx_webauthn_sessions_expires_at ON webauthn_sessions (expires_at);
CREATE INDEX IF NOT EXISTS idx_one_time_access_tokens_expires_at ON one_time_access_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_oidc_authorization_codes_expires_at ON oidc_authorization_codes (expires_at);
CREATE INDEX IF NOT EXISTS idx_oidc_refresh_tokens_expires_at ON oidc_refresh_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_reauthentication_tokens_expires_at ON reauthentication_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_expires_at ON email_verification_tokens (expires_at);
COMMIT;
PRAGMA foreign_keys=ON;

View File

@@ -40,7 +40,7 @@ ApiKeyExpiringEmail.TemplateProps = {
...sharedTemplateProps, ...sharedTemplateProps,
data: { data: {
name: "{{.Data.Name}}", name: "{{.Data.Name}}",
apiKeyName: "{{.Data.APIKeyName}}", apiKeyName: "{{.Data.ApiKeyName}}",
expiresAt: '{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}', expiresAt: '{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}',
}, },
}; };