mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-22 20:15:07 +00:00
fix: various fixes in background jobs (#1362)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
f4eb8db509
commit
2f56d16f98
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
@@ -205,36 +206,33 @@ func (s *ApiKeyService) ListExpiringApiKeys(ctx context.Context, daysAhead int)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey model.ApiKey) error {
|
||||
user := apiKey.User
|
||||
|
||||
if user.ID == "" {
|
||||
if err := s.db.WithContext(ctx).First(&user, "id = ?", apiKey.UserID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if user.Email == nil {
|
||||
if apiKey.User.Email == nil {
|
||||
return &common.UserEmailNotSetError{}
|
||||
}
|
||||
|
||||
err := SendEmail(ctx, s.emailService, email.Address{
|
||||
Name: user.FullName(),
|
||||
Email: *user.Email,
|
||||
Name: apiKey.User.FullName(),
|
||||
Email: *apiKey.User.Email,
|
||||
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
|
||||
ApiKeyName: apiKey.Name,
|
||||
ExpiresAt: apiKey.ExpiresAt.ToTime(),
|
||||
Name: user.FirstName,
|
||||
Name: apiKey.User.FirstName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error sending notification email: %w", err)
|
||||
}
|
||||
|
||||
// Mark the API key as having had an expiration email sent
|
||||
return s.db.WithContext(ctx).
|
||||
err = s.db.WithContext(ctx).
|
||||
Model(&model.ApiKey{}).
|
||||
Where("id = ?", apiKey.ID).
|
||||
Update("expiration_email_sent", true).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error recording expiration sent email in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) initStaticApiKeyUser(ctx context.Context) (user model.User, err error) {
|
||||
|
||||
@@ -73,7 +73,10 @@ func (lv *lockValue) Unmarshal(raw string) error {
|
||||
// Acquire obtains the lock. When force is true, the lock is stolen from any existing owner.
|
||||
// If the lock is forcefully acquired, it blocks until the previous lock has expired.
|
||||
func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) {
|
||||
tx := s.db.Begin()
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return time.Time{}, fmt.Errorf("begin lock transaction: %w", tx.Error)
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
@@ -174,7 +177,8 @@ func (s *AppLockService) RunRenewal(ctx context.Context) error {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := s.renew(ctx); err != nil {
|
||||
err := s.renew(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renew lock: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -183,33 +187,43 @@ func (s *AppLockService) RunRenewal(ctx context.Context) error {
|
||||
|
||||
// Release releases the lock if it is held by this process.
|
||||
func (s *AppLockService) Release(ctx context.Context) error {
|
||||
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
db, err := s.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get DB connection: %w", err)
|
||||
}
|
||||
|
||||
var query string
|
||||
switch s.db.Name() {
|
||||
case "sqlite":
|
||||
query = `
|
||||
DELETE FROM kv
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
`
|
||||
DELETE FROM kv
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
`
|
||||
case "postgres":
|
||||
query = `
|
||||
DELETE FROM kv
|
||||
WHERE key = $1
|
||||
AND value::json->>'lock_id' = $2
|
||||
`
|
||||
DELETE FROM kv
|
||||
WHERE key = $1
|
||||
AND value::json->>'lock_id' = $2
|
||||
`
|
||||
default:
|
||||
return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
|
||||
}
|
||||
|
||||
res := s.db.WithContext(opCtx).Exec(query, lockKey, s.lockID)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("release lock failed: %w", res.Error)
|
||||
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := db.ExecContext(opCtx, query, lockKey, s.lockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release lock failed: %w", err)
|
||||
}
|
||||
|
||||
if res.RowsAffected == 0 {
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
slog.Warn("Application lock not held by this process, cannot release",
|
||||
slog.Int64("process_id", s.processID),
|
||||
slog.String("host_id", s.hostID),
|
||||
@@ -225,6 +239,11 @@ func (s *AppLockService) Release(ctx context.Context) error {
|
||||
|
||||
// renew tries to renew the lock, retrying up to renewRetries times (sleeping 1s between attempts).
|
||||
func (s *AppLockService) renew(ctx context.Context) error {
|
||||
db, err := s.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get DB connection: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= renewRetries; attempt++ {
|
||||
now := time.Now()
|
||||
@@ -246,42 +265,56 @@ func (s *AppLockService) renew(ctx context.Context) error {
|
||||
switch s.db.Name() {
|
||||
case "sqlite":
|
||||
query = `
|
||||
UPDATE kv
|
||||
SET value = ?
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
AND json_extract(value, '$.expires_at') > ?
|
||||
`
|
||||
UPDATE kv
|
||||
SET value = ?
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
AND json_extract(value, '$.expires_at') > ?
|
||||
`
|
||||
case "postgres":
|
||||
query = `
|
||||
UPDATE kv
|
||||
SET value = $1
|
||||
WHERE key = $2
|
||||
AND value::json->>'lock_id' = $3
|
||||
AND ((value::json->>'expires_at')::bigint > $4)
|
||||
`
|
||||
UPDATE kv
|
||||
SET value = $1
|
||||
WHERE key = $2
|
||||
AND value::json->>'lock_id' = $3
|
||||
AND ((value::json->>'expires_at')::bigint > $4)
|
||||
`
|
||||
default:
|
||||
return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
|
||||
}
|
||||
|
||||
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
res := s.db.WithContext(opCtx).Exec(query, raw, lockKey, s.lockID, nowUnix)
|
||||
res, err := db.ExecContext(opCtx, query, raw, lockKey, s.lockID, nowUnix)
|
||||
cancel()
|
||||
|
||||
switch {
|
||||
case res.Error != nil:
|
||||
lastErr = fmt.Errorf("lock renewal failed: %w", res.Error)
|
||||
case res.RowsAffected == 0:
|
||||
// Must be after checking res.Error
|
||||
return ErrLockLost
|
||||
default:
|
||||
// Query succeeded, but may have updated 0 rows
|
||||
if err == nil {
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
// If no rows were updated, we lost the lock
|
||||
if count == 0 {
|
||||
return ErrLockLost
|
||||
}
|
||||
|
||||
// All good
|
||||
slog.Debug("Renewed application lock",
|
||||
slog.Int64("process_id", s.processID),
|
||||
slog.String("host_id", s.hostID),
|
||||
slog.Duration("duration", time.Since(now)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we're here, we have an error that can be retried
|
||||
slog.Debug("Application lock renewal attempt failed",
|
||||
slog.Any("error", err),
|
||||
slog.Duration("duration", time.Since(now)),
|
||||
)
|
||||
lastErr = fmt.Errorf("lock renewal failed: %w", err)
|
||||
|
||||
// Wait before next attempt or cancel if context is done
|
||||
if attempt < renewRetries {
|
||||
select {
|
||||
|
||||
@@ -49,6 +49,23 @@ func readLockValue(t *testing.T, db *gorm.DB) lockValue {
|
||||
return value
|
||||
}
|
||||
|
||||
func lockDatabaseForWrite(t *testing.T, db *gorm.DB) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
tx := db.Begin()
|
||||
require.NoError(t, tx.Error)
|
||||
|
||||
// Keep a write transaction open to block other queries.
|
||||
err := tx.Exec(
|
||||
`INSERT INTO kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO NOTHING`,
|
||||
lockKey,
|
||||
`{"expires_at":0}`,
|
||||
).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
func TestAppLockServiceAcquire(t *testing.T) {
|
||||
t.Run("creates new lock when none exists", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
@@ -99,6 +116,66 @@ func TestAppLockServiceAcquire(t *testing.T) {
|
||||
require.Equal(t, service.hostID, stored.HostID)
|
||||
require.Greater(t, stored.ExpiresAt, time.Now().Unix())
|
||||
})
|
||||
|
||||
t.Run("force acquisition returns wait duration when stealing active lock", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
existing := lockValue{
|
||||
ProcessID: 99,
|
||||
HostID: "other-host",
|
||||
LockID: "other-lock-id",
|
||||
ExpiresAt: time.Now().Add(ttl).Unix(),
|
||||
}
|
||||
insertLock(t, db, existing)
|
||||
|
||||
waitUntil, err := service.Acquire(context.Background(), true)
|
||||
require.NoError(t, err)
|
||||
require.WithinDuration(t, time.Unix(existing.ExpiresAt, 0), waitUntil, time.Second)
|
||||
})
|
||||
|
||||
t.Run("force acquisition does not wait when lock id is unchanged", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
insertLock(t, db, lockValue{
|
||||
ProcessID: 99,
|
||||
HostID: "other-host",
|
||||
LockID: service.lockID,
|
||||
ExpiresAt: time.Now().Add(ttl).Unix(),
|
||||
})
|
||||
|
||||
waitUntil, err := service.Acquire(context.Background(), true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, waitUntil.IsZero())
|
||||
})
|
||||
|
||||
t.Run("returns error when existing lock value is invalid JSON", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
raw := "this-is-not-json"
|
||||
err := db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Acquire(context.Background(), false)
|
||||
require.ErrorContains(t, err, "decode existing lock value")
|
||||
})
|
||||
|
||||
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
tx := lockDatabaseForWrite(t, db)
|
||||
defer tx.Rollback()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := service.Acquire(ctx, false)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorContains(t, err, "begin lock transaction")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppLockServiceRelease(t *testing.T) {
|
||||
@@ -134,6 +211,24 @@ func TestAppLockServiceRelease(t *testing.T) {
|
||||
stored := readLockValue(t, db)
|
||||
require.Equal(t, existing, stored)
|
||||
})
|
||||
|
||||
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
_, err := service.Acquire(context.Background(), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx := lockDatabaseForWrite(t, db)
|
||||
defer tx.Rollback()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = service.Release(ctx)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorContains(t, err, "release lock failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppLockServiceRenew(t *testing.T) {
|
||||
@@ -186,4 +281,21 @@ func TestAppLockServiceRenew(t *testing.T) {
|
||||
err = service.renew(context.Background())
|
||||
require.ErrorIs(t, err, ErrLockLost)
|
||||
})
|
||||
|
||||
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
_, err := service.Acquire(context.Background(), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx := lockDatabaseForWrite(t, db)
|
||||
defer tx.Rollback()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = service.renew(ctx)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -150,7 +150,8 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
|
||||
}
|
||||
|
||||
// Send the email
|
||||
if err := srv.sendEmailContent(client, toEmail, c); err != nil {
|
||||
err = srv.sendEmailContent(client, toEmail, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send email content: %w", err)
|
||||
}
|
||||
|
||||
|
||||
25
backend/internal/service/scheduler.go
Normal file
25
backend/internal/service/scheduler.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v5"
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
)
|
||||
|
||||
// RegisterJobOpts holds optional configuration for registering a scheduled job.
|
||||
type RegisterJobOpts struct {
|
||||
// RunImmediately runs the job immediately after registration.
|
||||
RunImmediately bool
|
||||
// ExtraOptions are additional gocron job options.
|
||||
ExtraOptions []gocron.JobOption
|
||||
// BackOff is an optional backoff strategy. If non-nil, the job will be wrapped
|
||||
// with automatic retry logic using the provided backoff on transient failures.
|
||||
BackOff backoff.BackOff
|
||||
}
|
||||
|
||||
// Scheduler is an interface for registering and managing background jobs.
|
||||
type Scheduler interface {
|
||||
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, opts RegisterJobOpts) error
|
||||
RemoveJob(name string) error
|
||||
}
|
||||
@@ -34,11 +34,6 @@ const scimErrorBodyLimit = 4096
|
||||
|
||||
type scimSyncAction int
|
||||
|
||||
type Scheduler interface {
|
||||
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error
|
||||
RemoveJob(name string) error
|
||||
}
|
||||
|
||||
const (
|
||||
scimActionNone scimSyncAction = iota
|
||||
scimActionCreated
|
||||
@@ -149,7 +144,7 @@ func (s *ScimService) ScheduleSync() {
|
||||
|
||||
err := s.scheduler.RegisterJob(
|
||||
context.Background(), jobName,
|
||||
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, false)
|
||||
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, RegisterJobOpts{})
|
||||
|
||||
if err != nil {
|
||||
slog.Error("Failed to schedule SCIM sync", slog.Any("error", err))
|
||||
@@ -168,7 +163,8 @@ func (s *ScimService) SyncAll(ctx context.Context) error {
|
||||
errs = append(errs, ctx.Err())
|
||||
break
|
||||
}
|
||||
if err := s.SyncServiceProvider(ctx, provider.ID); err != nil {
|
||||
err = s.SyncServiceProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err))
|
||||
}
|
||||
}
|
||||
@@ -210,26 +206,20 @@ func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID
|
||||
}
|
||||
|
||||
var errs []error
|
||||
var userStats scimSyncStats
|
||||
var groupStats scimSyncStats
|
||||
|
||||
// Sync users first, so that groups can reference them
|
||||
if stats, err := s.syncUsers(ctx, provider, users, &userResources); err != nil {
|
||||
errs = append(errs, err)
|
||||
userStats = stats
|
||||
} else {
|
||||
userStats = stats
|
||||
}
|
||||
|
||||
stats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
|
||||
userStats, err := s.syncUsers(ctx, provider, users, &userResources)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
groupStats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
groupStats = stats
|
||||
} else {
|
||||
groupStats = stats
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
err = errors.Join(errs...)
|
||||
slog.WarnContext(ctx, "SCIM sync completed with errors",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.Int("error_count", len(errs)),
|
||||
@@ -240,12 +230,14 @@ func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID
|
||||
slog.Int("groups_updated", groupStats.Updated),
|
||||
slog.Int("groups_deleted", groupStats.Deleted),
|
||||
slog.Duration("duration", time.Since(start)),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
return errors.Join(errs...)
|
||||
return err
|
||||
}
|
||||
|
||||
provider.LastSyncedAt = new(datatype.DateTime(time.Now()))
|
||||
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
|
||||
err = s.db.WithContext(ctx).Save(&provider).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -273,7 +265,7 @@ func (s *ScimService) syncUsers(
|
||||
|
||||
// Update or create users
|
||||
for _, u := range users {
|
||||
existing := getResourceByExternalID[dto.ScimUser](u.ID, resourceList.Resources)
|
||||
existing := getResourceByExternalID(u.ID, resourceList.Resources)
|
||||
|
||||
action, created, err := s.syncUser(ctx, provider, u, existing)
|
||||
if created != nil && existing == nil {
|
||||
@@ -434,7 +426,7 @@ func (s *ScimService) syncGroup(
|
||||
// Prepare group members
|
||||
members := make([]dto.ScimGroupMember, len(group.Users))
|
||||
for i, user := range group.Users {
|
||||
userResource := getResourceByExternalID[dto.ScimUser](user.ID, userResources)
|
||||
userResource := getResourceByExternalID(user.ID, userResources)
|
||||
if userResource == nil {
|
||||
// Groups depend on user IDs already being provisioned
|
||||
return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID)
|
||||
|
||||
Reference in New Issue
Block a user