1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 20:15:07 +00:00

Address review comments + more tests

This commit is contained in:
ItalyPaleAle
2026-03-04 20:25:01 -08:00
parent 62e7882493
commit 058ac29b8f
4 changed files with 124 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ func (s *Scheduler) RegisterScimJobs(ctx context.Context, scimService *service.S
jobs := &ScimJobs{scimService: scimService} jobs := &ScimJobs{scimService: scimService}
// Register the job to run every hour (with some jitter) // Register the job to run every hour (with some jitter)
return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(24*time.Hour), jobs.SyncScim, true) return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(time.Hour), jobs.SyncScim, true)
} }
func (j *ScimJobs) SyncScim(ctx context.Context) error { func (j *ScimJobs) SyncScim(ctx context.Context) error {

View File

@@ -73,7 +73,10 @@ func (lv *lockValue) Unmarshal(raw string) error {
// Acquire obtains the lock. When force is true, the lock is stolen from any existing owner. // Acquire obtains the lock. When force is true, the lock is stolen from any existing owner.
// If the lock is forcefully acquired, it blocks until the previous lock has expired. // If the lock is forcefully acquired, it blocks until the previous lock has expired.
func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) { func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) {
tx := s.db.Begin() tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return time.Time{}, fmt.Errorf("begin lock transaction: %w", tx.Error)
}
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()

View File

@@ -49,6 +49,23 @@ func readLockValue(t *testing.T, db *gorm.DB) lockValue {
return value return value
} }
func lockDatabaseForWrite(t *testing.T, db *gorm.DB) *gorm.DB {
t.Helper()
tx := db.Begin()
require.NoError(t, tx.Error)
// Keep a write transaction open to block other queries.
err := tx.Exec(
`INSERT INTO kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO NOTHING`,
lockKey,
`{"expires_at":0}`,
).Error
require.NoError(t, err)
return tx
}
func TestAppLockServiceAcquire(t *testing.T) { func TestAppLockServiceAcquire(t *testing.T) {
t.Run("creates new lock when none exists", func(t *testing.T) { t.Run("creates new lock when none exists", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t) db := testutils.NewDatabaseForTest(t)
@@ -99,6 +116,66 @@ func TestAppLockServiceAcquire(t *testing.T) {
require.Equal(t, service.hostID, stored.HostID) require.Equal(t, service.hostID, stored.HostID)
require.Greater(t, stored.ExpiresAt, time.Now().Unix()) require.Greater(t, stored.ExpiresAt, time.Now().Unix())
}) })
t.Run("force acquisition returns wait duration when stealing active lock", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
existing := lockValue{
ProcessID: 99,
HostID: "other-host",
LockID: "other-lock-id",
ExpiresAt: time.Now().Add(ttl).Unix(),
}
insertLock(t, db, existing)
waitUntil, err := service.Acquire(context.Background(), true)
require.NoError(t, err)
require.WithinDuration(t, time.Unix(existing.ExpiresAt, 0), waitUntil, time.Second)
})
t.Run("force acquisition does not wait when lock id is unchanged", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
insertLock(t, db, lockValue{
ProcessID: 99,
HostID: "other-host",
LockID: service.lockID,
ExpiresAt: time.Now().Add(ttl).Unix(),
})
waitUntil, err := service.Acquire(context.Background(), true)
require.NoError(t, err)
require.True(t, waitUntil.IsZero())
})
t.Run("returns error when existing lock value is invalid JSON", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
raw := "this-is-not-json"
err := db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
require.NoError(t, err)
_, err = service.Acquire(context.Background(), false)
require.ErrorContains(t, err, "decode existing lock value")
})
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
tx := lockDatabaseForWrite(t, db)
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
_, err := service.Acquire(ctx, false)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.ErrorContains(t, err, "begin lock transaction")
})
} }
func TestAppLockServiceRelease(t *testing.T) { func TestAppLockServiceRelease(t *testing.T) {
@@ -134,6 +211,24 @@ func TestAppLockServiceRelease(t *testing.T) {
stored := readLockValue(t, db) stored := readLockValue(t, db)
require.Equal(t, existing, stored) require.Equal(t, existing, stored)
}) })
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
_, err := service.Acquire(context.Background(), false)
require.NoError(t, err)
tx := lockDatabaseForWrite(t, db)
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
err = service.Release(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.ErrorContains(t, err, "release lock failed")
})
} }
func TestAppLockServiceRenew(t *testing.T) { func TestAppLockServiceRenew(t *testing.T) {
@@ -186,4 +281,21 @@ func TestAppLockServiceRenew(t *testing.T) {
err = service.renew(context.Background()) err = service.renew(context.Background())
require.ErrorIs(t, err, ErrLockLost) require.ErrorIs(t, err, ErrLockLost)
}) })
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
service := newTestAppLockService(t, db)
_, err := service.Acquire(context.Background(), false)
require.NoError(t, err)
tx := lockDatabaseForWrite(t, db)
defer tx.Rollback()
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
err = service.renew(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
} }

View File

@@ -1,6 +1,12 @@
PRAGMA foreign_keys= OFF;
BEGIN;
CREATE INDEX IF NOT EXISTS idx_webauthn_sessions_expires_at ON webauthn_sessions (expires_at); CREATE INDEX IF NOT EXISTS idx_webauthn_sessions_expires_at ON webauthn_sessions (expires_at);
CREATE INDEX IF NOT EXISTS idx_one_time_access_tokens_expires_at ON one_time_access_tokens (expires_at); CREATE INDEX IF NOT EXISTS idx_one_time_access_tokens_expires_at ON one_time_access_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_oidc_authorization_codes_expires_at ON oidc_authorization_codes (expires_at); CREATE INDEX IF NOT EXISTS idx_oidc_authorization_codes_expires_at ON oidc_authorization_codes (expires_at);
CREATE INDEX IF NOT EXISTS idx_oidc_refresh_tokens_expires_at ON oidc_refresh_tokens (expires_at); CREATE INDEX IF NOT EXISTS idx_oidc_refresh_tokens_expires_at ON oidc_refresh_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_reauthentication_tokens_expires_at ON reauthentication_tokens (expires_at); CREATE INDEX IF NOT EXISTS idx_reauthentication_tokens_expires_at ON reauthentication_tokens (expires_at);
CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_expires_at ON email_verification_tokens (expires_at); CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_expires_at ON email_verification_tokens (expires_at);
COMMIT;
PRAGMA foreign_keys=ON;