mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-22 15:00:07 +00:00
302 lines
8.1 KiB
Go
302 lines
8.1 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
|
)
|
|
|
|
func newTestAppLockService(t *testing.T, db *gorm.DB) *AppLockService {
|
|
t.Helper()
|
|
|
|
return &AppLockService{
|
|
db: db,
|
|
processID: 1,
|
|
hostID: "test-host",
|
|
lockID: "a13c7673-c7ae-49f1-9112-2cd2d0d4b0c1",
|
|
}
|
|
}
|
|
|
|
func insertLock(t *testing.T, db *gorm.DB, value lockValue) {
|
|
t.Helper()
|
|
|
|
raw, err := value.Marshal()
|
|
require.NoError(t, err)
|
|
|
|
err = db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func readLockValue(t *testing.T, db *gorm.DB) lockValue {
|
|
t.Helper()
|
|
|
|
var row model.KV
|
|
err := db.Take(&row, "key = ?", lockKey).Error
|
|
require.NoError(t, err)
|
|
|
|
require.NotNil(t, row.Value)
|
|
|
|
var value lockValue
|
|
err = value.Unmarshal(*row.Value)
|
|
require.NoError(t, err)
|
|
|
|
return value
|
|
}
|
|
|
|
func lockDatabaseForWrite(t *testing.T, db *gorm.DB) *gorm.DB {
|
|
t.Helper()
|
|
|
|
tx := db.Begin()
|
|
require.NoError(t, tx.Error)
|
|
|
|
// Keep a write transaction open to block other queries.
|
|
err := tx.Exec(
|
|
`INSERT INTO kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO NOTHING`,
|
|
lockKey,
|
|
`{"expires_at":0}`,
|
|
).Error
|
|
require.NoError(t, err)
|
|
|
|
return tx
|
|
}
|
|
|
|
func TestAppLockServiceAcquire(t *testing.T) {
|
|
t.Run("creates new lock when none exists", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.NoError(t, err)
|
|
|
|
stored := readLockValue(t, db)
|
|
require.Equal(t, service.processID, stored.ProcessID)
|
|
require.Equal(t, service.hostID, stored.HostID)
|
|
require.Greater(t, stored.ExpiresAt, time.Now().Unix())
|
|
})
|
|
|
|
t.Run("returns ErrLockUnavailable when lock held by another process", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
existing := lockValue{
|
|
ProcessID: 99,
|
|
HostID: "other-host",
|
|
ExpiresAt: time.Now().Add(ttl).Unix(),
|
|
}
|
|
insertLock(t, db, existing)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.ErrorIs(t, err, ErrLockUnavailable)
|
|
|
|
current := readLockValue(t, db)
|
|
require.Equal(t, existing, current)
|
|
})
|
|
|
|
t.Run("force acquisition steals lock", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
insertLock(t, db, lockValue{
|
|
ProcessID: 99,
|
|
HostID: "other-host",
|
|
ExpiresAt: time.Now().Unix(),
|
|
})
|
|
|
|
_, err := service.Acquire(context.Background(), true)
|
|
require.NoError(t, err)
|
|
|
|
stored := readLockValue(t, db)
|
|
require.Equal(t, service.processID, stored.ProcessID)
|
|
require.Equal(t, service.hostID, stored.HostID)
|
|
require.Greater(t, stored.ExpiresAt, time.Now().Unix())
|
|
})
|
|
|
|
t.Run("force acquisition returns wait duration when stealing active lock", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
existing := lockValue{
|
|
ProcessID: 99,
|
|
HostID: "other-host",
|
|
LockID: "other-lock-id",
|
|
ExpiresAt: time.Now().Add(ttl).Unix(),
|
|
}
|
|
insertLock(t, db, existing)
|
|
|
|
waitUntil, err := service.Acquire(context.Background(), true)
|
|
require.NoError(t, err)
|
|
require.WithinDuration(t, time.Unix(existing.ExpiresAt, 0), waitUntil, time.Second)
|
|
})
|
|
|
|
t.Run("force acquisition does not wait when lock id is unchanged", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
insertLock(t, db, lockValue{
|
|
ProcessID: 99,
|
|
HostID: "other-host",
|
|
LockID: service.lockID,
|
|
ExpiresAt: time.Now().Add(ttl).Unix(),
|
|
})
|
|
|
|
waitUntil, err := service.Acquire(context.Background(), true)
|
|
require.NoError(t, err)
|
|
require.True(t, waitUntil.IsZero())
|
|
})
|
|
|
|
t.Run("returns error when existing lock value is invalid JSON", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
raw := "this-is-not-json"
|
|
err := db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
|
|
require.NoError(t, err)
|
|
|
|
_, err = service.Acquire(context.Background(), false)
|
|
require.ErrorContains(t, err, "decode existing lock value")
|
|
})
|
|
|
|
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
tx := lockDatabaseForWrite(t, db)
|
|
defer tx.Rollback()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err := service.Acquire(ctx, false)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.ErrorContains(t, err, "begin lock transaction")
|
|
})
|
|
}
|
|
|
|
func TestAppLockServiceRelease(t *testing.T) {
|
|
t.Run("removes owned lock", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.NoError(t, err)
|
|
|
|
err = service.Release(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
var row model.KV
|
|
err = db.Take(&row, "key = ?", lockKey).Error
|
|
require.ErrorIs(t, err, gorm.ErrRecordNotFound)
|
|
})
|
|
|
|
t.Run("ignores lock held by another owner", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
existing := lockValue{
|
|
ProcessID: 2,
|
|
HostID: "other-host",
|
|
ExpiresAt: time.Now().Add(ttl).Unix(),
|
|
}
|
|
insertLock(t, db, existing)
|
|
|
|
err := service.Release(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
stored := readLockValue(t, db)
|
|
require.Equal(t, existing, stored)
|
|
})
|
|
|
|
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.NoError(t, err)
|
|
|
|
tx := lockDatabaseForWrite(t, db)
|
|
defer tx.Rollback()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
|
defer cancel()
|
|
|
|
err = service.Release(ctx)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.ErrorContains(t, err, "release lock failed")
|
|
})
|
|
}
|
|
|
|
func TestAppLockServiceRenew(t *testing.T) {
|
|
t.Run("extends expiration when lock is still owned", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.NoError(t, err)
|
|
|
|
before := readLockValue(t, db)
|
|
|
|
err = service.renew(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
after := readLockValue(t, db)
|
|
require.Equal(t, service.processID, after.ProcessID)
|
|
require.Equal(t, service.hostID, after.HostID)
|
|
require.GreaterOrEqual(t, after.ExpiresAt, before.ExpiresAt)
|
|
})
|
|
|
|
t.Run("returns ErrLockLost when lock is missing", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
err := service.renew(context.Background())
|
|
require.ErrorIs(t, err, ErrLockLost)
|
|
})
|
|
|
|
t.Run("returns ErrLockLost when ownership changed", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.NoError(t, err)
|
|
|
|
// Simulate a different process taking the lock.
|
|
newOwner := lockValue{
|
|
ProcessID: 9,
|
|
HostID: "stolen-host",
|
|
ExpiresAt: time.Now().Add(ttl).Unix(),
|
|
}
|
|
raw, marshalErr := newOwner.Marshal()
|
|
require.NoError(t, marshalErr)
|
|
updateErr := db.Model(&model.KV{}).
|
|
Where("key = ?", lockKey).
|
|
Update("value", raw).Error
|
|
require.NoError(t, updateErr)
|
|
|
|
err = service.renew(context.Background())
|
|
require.ErrorIs(t, err, ErrLockLost)
|
|
})
|
|
|
|
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
|
db := testutils.NewDatabaseForTest(t)
|
|
service := newTestAppLockService(t, db)
|
|
|
|
_, err := service.Acquire(context.Background(), false)
|
|
require.NoError(t, err)
|
|
|
|
tx := lockDatabaseForWrite(t, db)
|
|
defer tx.Rollback()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
|
defer cancel()
|
|
|
|
err = service.renew(ctx)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
})
|
|
}
|