From 2af70d9b4d4536f4266ba1eb0b5919b16a8d2024 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Wed, 7 Jan 2026 09:34:23 +0100 Subject: [PATCH] feat: add CLI command for encryption key rotation (#1209) --- .../internal/bootstrap/services_bootstrap.go | 2 +- .../internal/cmds/encryption_key_rotate.go | 187 ++++++++++++++++++ .../cmds/encryption_key_rotate_test.go | 89 +++++++++ backend/internal/cmds/key_rotate.go | 2 +- backend/internal/cmds/key_rotate_test.go | 2 +- .../internal/model/types/encrypted_string.go | 45 +++-- backend/internal/service/e2etest_service.go | 2 +- backend/internal/service/jwt_service.go | 14 +- backend/internal/service/jwt_service_test.go | 6 +- backend/internal/service/oidc_service_test.go | 2 +- backend/internal/utils/jwk/key_provider.go | 5 +- .../utils/jwk/key_provider_database.go | 8 +- .../utils/jwk/key_provider_database_test.go | 18 +- 13 files changed, 340 insertions(+), 42 deletions(-) create mode 100644 backend/internal/cmds/encryption_key_rotate.go create mode 100644 backend/internal/cmds/encryption_key_rotate_test.go diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go index 30d10419..86a33a93 100644 --- a/backend/internal/bootstrap/services_bootstrap.go +++ b/backend/internal/bootstrap/services_bootstrap.go @@ -52,7 +52,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima svc.geoLiteService = service.NewGeoLiteService(httpClient) svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService) - svc.jwtService, err = service.NewJwtService(db, svc.appConfigService) + svc.jwtService, err = service.NewJwtService(ctx, db, svc.appConfigService) if err != nil { return nil, fmt.Errorf("failed to create JWT service: %w", err) } diff --git a/backend/internal/cmds/encryption_key_rotate.go b/backend/internal/cmds/encryption_key_rotate.go new file mode 100644 index 00000000..7b5ea98d --- /dev/null +++ b/backend/internal/cmds/encryption_key_rotate.go @@ -0,0 +1,187 @@ +package cmds + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/pocket-id/pocket-id/backend/internal/model" + "github.com/spf13/cobra" + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/bootstrap" + "github.com/pocket-id/pocket-id/backend/internal/common" + datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" + "github.com/pocket-id/pocket-id/backend/internal/service" + "github.com/pocket-id/pocket-id/backend/internal/utils" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" +) + +type encryptionKeyRotateFlags struct { + NewKey string + Yes bool +} + +func init() { + var flags encryptionKeyRotateFlags + + encryptionKeyRotateCmd := &cobra.Command{ + Use: "encryption-key-rotate", + Short: "Re-encrypts data using a new encryption key", + RunE: func(cmd *cobra.Command, args []string) error { + db, err := bootstrap.NewDatabase() + if err != nil { + return err + } + + return encryptionKeyRotate(cmd.Context(), flags, db, &common.EnvConfig) + }, + } + + encryptionKeyRotateCmd.Flags().StringVar(&flags.NewKey, "new-key", "", "New encryption key to re-encrypt data with") + encryptionKeyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation") + + rootCmd.AddCommand(encryptionKeyRotateCmd) +} + +func encryptionKeyRotate(ctx context.Context, flags encryptionKeyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error { + oldKey := envConfig.EncryptionKey + newKey := []byte(flags.NewKey) + if len(newKey) == 0 { + return errors.New("new encryption key is required (--new-key)") + } + if len(newKey) < 16 { + return errors.New("new encryption key must be at least 16 bytes long") + } + + if !flags.Yes { + fmt.Println("WARNING: Rotating the encryption key will re-encrypt secrets in the database. Pocket-ID must be restarted with the new ENCRYPTION_KEY after rotation is complete.") + ok, err := utils.PromptForConfirmation("Continue") + if err != nil { + return err + } + if !ok { + fmt.Println("Aborted") + os.Exit(1) + } + } + + appConfigService, err := service.NewAppConfigService(ctx, db) + if err != nil { + return fmt.Errorf("failed to create app config service: %w", err) + } + instanceID := appConfigService.GetDbConfig().InstanceID.Value + + // Derive the encryption keys used for the JWK encryption + oldKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: oldKey}, instanceID) + if err != nil { + return fmt.Errorf("failed to derive old key encryption key: %w", err) + } + newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID) + if err != nil { + return fmt.Errorf("failed to derive new key encryption key: %w", err) + } + + // Derive the encryption keys used for EncryptedString fields + oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey) + if err != nil { + return fmt.Errorf("failed to derive old encrypted string key: %w", err) + } + newEncKey, err := datatype.DeriveEncryptedStringKey(newKey) + if err != nil { + return fmt.Errorf("failed to derive new encrypted string key: %w", err) + } + + err = db.Transaction(func(tx *gorm.DB) error { + err = rotateSigningKeyEncryption(ctx, tx, oldKek, newKek) + if err != nil { + return err + } + + err = rotateScimTokens(tx, oldEncKey, newEncKey) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + + fmt.Println("Encryption key rotation completed successfully.") + fmt.Println("Restart pocket-id with the new ENCRYPTION_KEY to use the rotated data.") + + return nil +} + +func rotateSigningKeyEncryption(ctx context.Context, db *gorm.DB, oldKek []byte, newKek []byte) error { + oldProvider := &jwkutils.KeyProviderDatabase{} + err := oldProvider.Init(jwkutils.KeyProviderOpts{ + DB: db, + Kek: oldKek, + }) + if err != nil { + return fmt.Errorf("failed to init key provider with old encryption key: %w", err) + } + + key, err := oldProvider.LoadKey(ctx) + if err != nil { + return fmt.Errorf("failed to load signing key using old encryption key: %w", err) + } + if key == nil { + return nil + } + + newProvider := &jwkutils.KeyProviderDatabase{} + err = newProvider.Init(jwkutils.KeyProviderOpts{ + DB: db, + Kek: newKek, + }) + if err != nil { + return fmt.Errorf("failed to init key provider with new encryption key: %w", err) + } + + if err := newProvider.SaveKey(ctx, key); err != nil { + return fmt.Errorf("failed to store signing key with new encryption key: %w", err) + } + + return nil +} + +type scimTokenRow struct { + ID string + Token string +} + +func rotateScimTokens(db *gorm.DB, oldEncKey []byte, newEncKey []byte) error { + var rows []scimTokenRow + err := db.Model(&model.ScimServiceProvider{}).Select("id, token").Scan(&rows).Error + if err != nil { + return fmt.Errorf("failed to list SCIM service providers: %w", err) + } + + for _, row := range rows { + if row.Token == "" { + continue + } + + decBytes, err := datatype.DecryptEncryptedStringWithKey(oldEncKey, row.Token) + if err != nil { + return fmt.Errorf("failed to decrypt SCIM token for provider %s: %w", row.ID, err) + } + + encValue, err := datatype.EncryptEncryptedStringWithKey(newEncKey, decBytes) + if err != nil { + return fmt.Errorf("failed to encrypt SCIM token for provider %s: %w", row.ID, err) + } + + err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", row.ID).Update("token", encValue).Error + if err != nil { + return fmt.Errorf("failed to update SCIM token for provider %s: %w", row.ID, err) + } + } + + return nil +} diff --git a/backend/internal/cmds/encryption_key_rotate_test.go b/backend/internal/cmds/encryption_key_rotate_test.go new file mode 100644 index 00000000..8d27e975 --- /dev/null +++ b/backend/internal/cmds/encryption_key_rotate_test.go @@ -0,0 +1,89 @@ +package cmds + +import ( + "testing" + "time" + + "github.com/pocket-id/pocket-id/backend/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/common" + datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" + "github.com/pocket-id/pocket-id/backend/internal/service" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" + testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" +) + +func TestEncryptionKeyRotate(t *testing.T) { + oldKey := []byte("old-encryption-key-123456") + newKey := []byte("new-encryption-key-654321") + + envConfig := &common.EnvConfigSchema{ + EncryptionKey: oldKey, + } + + db := testingutils.NewDatabaseForTest(t) + + appConfigService, err := service.NewAppConfigService(t.Context(), db) + require.NoError(t, err) + instanceID := appConfigService.GetDbConfig().InstanceID.Value + + oldKek, err := jwkutils.LoadKeyEncryptionKey(envConfig, instanceID) + require.NoError(t, err) + + oldProvider := &jwkutils.KeyProviderDatabase{} + require.NoError(t, oldProvider.Init(jwkutils.KeyProviderOpts{ + DB: db, + Kek: oldKek, + })) + + signingKey, err := jwkutils.GenerateKey("RS256", "") + require.NoError(t, err) + require.NoError(t, oldProvider.SaveKey(t.Context(), signingKey)) + + oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey) + require.NoError(t, err) + encToken, err := datatype.EncryptEncryptedStringWithKey(oldEncKey, []byte("scim-token-123")) + require.NoError(t, err) + + err = db.Exec( + `INSERT INTO scim_service_providers (id, created_at, endpoint, token, oidc_client_id) VALUES (?, ?, ?, ?, ?)`, + "scim-1", + time.Now(), + "https://example.com/scim", + encToken, + "client-1", + ).Error + require.NoError(t, err) + + flags := encryptionKeyRotateFlags{ + NewKey: string(newKey), + Yes: true, + } + require.NoError(t, encryptionKeyRotate(t.Context(), flags, db, envConfig)) + + newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID) + require.NoError(t, err) + + newProvider := &jwkutils.KeyProviderDatabase{} + require.NoError(t, newProvider.Init(jwkutils.KeyProviderOpts{ + DB: db, + Kek: newKek, + })) + + rotatedKey, err := newProvider.LoadKey(t.Context()) + require.NoError(t, err) + require.NotNil(t, rotatedKey) + + var storedToken string + err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", "scim-1").Pluck("token", &storedToken).Error + require.NoError(t, err) + + newEncKey, err := datatype.DeriveEncryptedStringKey(newKey) + require.NoError(t, err) + + decBytes, err := datatype.DecryptEncryptedStringWithKey(newEncKey, storedToken) + require.NoError(t, err) + assert.Equal(t, "scim-token-123", string(decBytes)) +} diff --git a/backend/internal/cmds/key_rotate.go b/backend/internal/cmds/key_rotate.go index 5225c623..34a1b26e 100644 --- a/backend/internal/cmds/key_rotate.go +++ b/backend/internal/cmds/key_rotate.go @@ -102,7 +102,7 @@ func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig } // Save the key - err = keyProvider.SaveKey(key) + err = keyProvider.SaveKey(ctx, key) if err != nil { return fmt.Errorf("failed to store new key: %w", err) } diff --git a/backend/internal/cmds/key_rotate_test.go b/backend/internal/cmds/key_rotate_test.go index bccd3a3d..fb4f5548 100644 --- a/backend/internal/cmds/key_rotate_test.go +++ b/backend/internal/cmds/key_rotate_test.go @@ -104,7 +104,7 @@ func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantEr require.NoError(t, err) // Verify key was created - key, err := keyProvider.LoadKey() + key, err := keyProvider.LoadKey(t.Context()) require.NoError(t, err) require.NotNil(t, key) diff --git a/backend/internal/model/types/encrypted_string.go b/backend/internal/model/types/encrypted_string.go index 6d17d4b1..846274b3 100644 --- a/backend/internal/model/types/encrypted_string.go +++ b/backend/internal/model/types/encrypted_string.go @@ -40,14 +40,9 @@ func (e *EncryptedString) Scan(value any) error { return nil } - encBytes, err := base64.StdEncoding.DecodeString(raw) + decBytes, err := DecryptEncryptedStringWithKey(encStringKey, raw) if err != nil { - return fmt.Errorf("failed to decode encrypted string: %w", err) - } - - decBytes, err := cryptoutils.Decrypt(encStringKey, encBytes, []byte(encryptedStringAAD)) - if err != nil { - return fmt.Errorf("failed to decrypt encrypted string: %w", err) + return err } *e = EncryptedString(decBytes) @@ -59,19 +54,20 @@ func (e EncryptedString) Value() (driver.Value, error) { return "", nil } - encBytes, err := cryptoutils.Encrypt(encStringKey, []byte(e), []byte(encryptedStringAAD)) + encValue, err := EncryptEncryptedStringWithKey(encStringKey, []byte(e)) if err != nil { - return nil, fmt.Errorf("failed to encrypt string: %w", err) + return nil, err } - return base64.StdEncoding.EncodeToString(encBytes), nil + return encValue, nil } func (e EncryptedString) String() string { return string(e) } -func deriveEncryptedStringKey(master []byte) ([]byte, error) { +// DeriveEncryptedStringKey derives a key for encrypting EncryptedString values from the master key. +func DeriveEncryptedStringKey(master []byte) ([]byte, error) { const info = "pocketid/encrypted_string" r := hkdf.New(sha256.New, master, nil, []byte(info)) @@ -82,8 +78,33 @@ func deriveEncryptedStringKey(master []byte) ([]byte, error) { return key, nil } +// DecryptEncryptedStringWithKey decrypts an EncryptedString value using the derived key. +func DecryptEncryptedStringWithKey(key []byte, encoded string) ([]byte, error) { + encBytes, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("failed to decode encrypted string: %w", err) + } + + decBytes, err := cryptoutils.Decrypt(key, encBytes, []byte(encryptedStringAAD)) + if err != nil { + return nil, fmt.Errorf("failed to decrypt encrypted string: %w", err) + } + + return decBytes, nil +} + +// EncryptEncryptedStringWithKey encrypts an EncryptedString value using the derived key. +func EncryptEncryptedStringWithKey(key []byte, plaintext []byte) (string, error) { + encBytes, err := cryptoutils.Encrypt(key, plaintext, []byte(encryptedStringAAD)) + if err != nil { + return "", fmt.Errorf("failed to encrypt string: %w", err) + } + + return base64.StdEncoding.EncodeToString(encBytes), nil +} + func init() { - key, err := deriveEncryptedStringKey(common.EnvConfig.EncryptionKey) + key, err := DeriveEncryptedStringKey(common.EnvConfig.EncryptionKey) if err != nil { panic(fmt.Sprintf("failed to derive encrypted string key: %v", err)) } diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index 077f714b..1be77436 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -526,7 +526,7 @@ func (s *TestService) ResetAppConfig(ctx context.Context) error { } // Reload the JWK - if err := s.jwtService.LoadOrGenerateKey(); err != nil { + if err := s.jwtService.LoadOrGenerateKey(ctx); err != nil { return err } diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index 07203a99..a993e9e4 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -56,10 +56,10 @@ type JwtService struct { jwksEncoded []byte } -func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) { +func NewJwtService(ctx context.Context, db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) { service := &JwtService{} - err := service.init(db, appConfigService, &common.EnvConfig) + err := service.init(ctx, db, appConfigService, &common.EnvConfig) if err != nil { return nil, err } @@ -67,16 +67,16 @@ func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService return service, nil } -func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) { +func (s *JwtService) init(ctx context.Context, db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) { s.appConfigService = appConfigService s.envConfig = envConfig s.db = db // Ensure keys are generated or loaded - return s.LoadOrGenerateKey() + return s.LoadOrGenerateKey(ctx) } -func (s *JwtService) LoadOrGenerateKey() error { +func (s *JwtService) LoadOrGenerateKey(ctx context.Context) error { // Get the key provider keyProvider, err := jwkutils.GetKeyProvider(s.db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value) if err != nil { @@ -84,7 +84,7 @@ func (s *JwtService) LoadOrGenerateKey() error { } // Try loading a key - key, err := keyProvider.LoadKey() + key, err := keyProvider.LoadKey(ctx) if err != nil { return fmt.Errorf("failed to load key: %w", err) } @@ -105,7 +105,7 @@ func (s *JwtService) LoadOrGenerateKey() error { } // Save the newly-generated key - err = keyProvider.SaveKey(s.privateKey) + err = keyProvider.SaveKey(ctx, s.privateKey) if err != nil { return fmt.Errorf("failed to save private key: %w", err) } diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index 0aa33c5f..e25dc966 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -38,7 +38,7 @@ func initJwtService(t *testing.T, db *gorm.DB, appConfig *AppConfigService, envC t.Helper() service := &JwtService{} - err := service.init(db, appConfig, envConfig) + err := service.init(t.Context(), db, appConfig, envConfig) require.NoError(t, err, "Failed to initialize JWT service") return service @@ -65,7 +65,7 @@ func saveKeyToDatabase(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSch keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfig.GetDbConfig().InstanceID.Value) require.NoError(t, err, "Failed to init key provider") - err = keyProvider.SaveKey(key) + err = keyProvider.SaveKey(t.Context(), key) require.NoError(t, err, "Failed to save key") kid, ok := key.KeyID() @@ -93,7 +93,7 @@ func TestJwtService_Init(t *testing.T) { // Verify the key has been persisted in the database keyProvider, err := jwkutils.GetKeyProvider(db, mockEnvConfig, mockConfig.GetDbConfig().InstanceID.Value) require.NoError(t, err, "Failed to init key provider") - key, err := keyProvider.LoadKey() + key, err := keyProvider.LoadKey(t.Context()) require.NoError(t, err, "Failed to load key from provider") require.NotNil(t, key, "Key should be present in the database") diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 7c8ddfc5..4dbab3f6 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -160,7 +160,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { mockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes }) - mockJwtService, err := NewJwtService(db, mockConfig) + mockJwtService, err := NewJwtService(t.Context(), db, mockConfig) require.NoError(t, err) // Create a mock HTTP client with custom transport to return the JWKS diff --git a/backend/internal/utils/jwk/key_provider.go b/backend/internal/utils/jwk/key_provider.go index 0cc32c8b..e47ec116 100644 --- a/backend/internal/utils/jwk/key_provider.go +++ b/backend/internal/utils/jwk/key_provider.go @@ -1,6 +1,7 @@ package jwk import ( + "context" "fmt" "github.com/lestrrat-go/jwx/v3/jwk" @@ -17,8 +18,8 @@ type KeyProviderOpts struct { type KeyProvider interface { Init(opts KeyProviderOpts) error - LoadKey() (jwk.Key, error) - SaveKey(key jwk.Key) error + LoadKey(ctx context.Context) (jwk.Key, error) + SaveKey(ctx context.Context, key jwk.Key) error } func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) { diff --git a/backend/internal/utils/jwk/key_provider_database.go b/backend/internal/utils/jwk/key_provider_database.go index 0158450b..1b46f400 100644 --- a/backend/internal/utils/jwk/key_provider_database.go +++ b/backend/internal/utils/jwk/key_provider_database.go @@ -33,12 +33,12 @@ func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error { return nil } -func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) { +func (f *KeyProviderDatabase) LoadKey(ctx context.Context) (key jwk.Key, err error) { row := model.KV{ Key: PrivateKeyDBKey, } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() err = f.db.WithContext(ctx).First(&row).Error if errors.Is(err, gorm.ErrRecordNotFound) { @@ -74,7 +74,7 @@ func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) { return key, nil } -func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error { +func (f *KeyProviderDatabase) SaveKey(ctx context.Context, key jwk.Key) error { // Encode the key to JSON data, err := EncodeJWKBytes(key) if err != nil { @@ -94,7 +94,7 @@ func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error { Value: &encB64, } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() err = f.db. WithContext(ctx). diff --git a/backend/internal/utils/jwk/key_provider_database_test.go b/backend/internal/utils/jwk/key_provider_database_test.go index fd5dd2bd..fa783d40 100644 --- a/backend/internal/utils/jwk/key_provider_database_test.go +++ b/backend/internal/utils/jwk/key_provider_database_test.go @@ -59,7 +59,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { require.NoError(t, err) // Load key when none exists - loadedKey, err := provider.LoadKey() + loadedKey, err := provider.LoadKey(t.Context()) require.NoError(t, err) assert.Nil(t, loadedKey, "Expected nil key when no key exists in database") }) @@ -76,11 +76,11 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { require.NoError(t, err) // Save a key - err = provider.SaveKey(key) + err = provider.SaveKey(t.Context(), key) require.NoError(t, err) // Load the key - loadedKey, err := provider.LoadKey() + loadedKey, err := provider.LoadKey(t.Context()) require.NoError(t, err) assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database") @@ -114,7 +114,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { require.NoError(t, err) // Attempt to load the key - loadedKey, err := provider.LoadKey() + loadedKey, err := provider.LoadKey(t.Context()) require.Error(t, err, "Expected error when loading key with invalid base64") require.ErrorContains(t, err, "not a valid base64-encoded value") assert.Nil(t, loadedKey, "Expected nil key when loading fails") @@ -140,7 +140,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { require.NoError(t, err) // Attempt to load the key - loadedKey, err := provider.LoadKey() + loadedKey, err := provider.LoadKey(t.Context()) require.Error(t, err, "Expected error when loading key with invalid encrypted data") require.ErrorContains(t, err, "failed to decrypt") assert.Nil(t, loadedKey, "Expected nil key when loading fails") @@ -158,7 +158,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { }) require.NoError(t, err) - err = originalProvider.SaveKey(key) + err = originalProvider.SaveKey(t.Context(), key) require.NoError(t, err) // Now try to load with a different KEK @@ -171,7 +171,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { require.NoError(t, err) // Attempt to load the key with the wrong KEK - loadedKey, err := differentProvider.LoadKey() + loadedKey, err := differentProvider.LoadKey(t.Context()) require.Error(t, err, "Expected error when loading key with wrong KEK") require.ErrorContains(t, err, "failed to decrypt") assert.Nil(t, loadedKey, "Expected nil key when loading fails") @@ -206,7 +206,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) { require.NoError(t, err) // Attempt to load the key - loadedKey, err := provider.LoadKey() + loadedKey, err := provider.LoadKey(t.Context()) require.Error(t, err, "Expected error when loading invalid key data") require.ErrorContains(t, err, "failed to parse") assert.Nil(t, loadedKey, "Expected nil key when loading fails") @@ -233,7 +233,7 @@ func TestKeyProviderDatabase_SaveKey(t *testing.T) { require.NoError(t, err) // Save the key - err = provider.SaveKey(key) + err = provider.SaveKey(t.Context(), key) require.NoError(t, err, "Expected no error when saving key") // Verify record exists in database