1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 15:39:45 +00:00

feat: add CLI command for encryption key rotation (#1209)

This commit is contained in:
Elias Schneider
2026-01-07 09:34:23 +01:00
committed by GitHub
parent 5828fa5779
commit 2af70d9b4d
13 changed files with 340 additions and 42 deletions

View File

@@ -52,7 +52,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
svc.geoLiteService = service.NewGeoLiteService(httpClient) svc.geoLiteService = service.NewGeoLiteService(httpClient)
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService) svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
svc.jwtService, err = service.NewJwtService(db, svc.appConfigService) svc.jwtService, err = service.NewJwtService(ctx, db, svc.appConfigService)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create JWT service: %w", err) return nil, fmt.Errorf("failed to create JWT service: %w", err)
} }

View File

@@ -0,0 +1,187 @@
package cmds
import (
"context"
"errors"
"fmt"
"os"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/spf13/cobra"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/common"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
type encryptionKeyRotateFlags struct {
NewKey string
Yes bool
}
func init() {
var flags encryptionKeyRotateFlags
encryptionKeyRotateCmd := &cobra.Command{
Use: "encryption-key-rotate",
Short: "Re-encrypts data using a new encryption key",
RunE: func(cmd *cobra.Command, args []string) error {
db, err := bootstrap.NewDatabase()
if err != nil {
return err
}
return encryptionKeyRotate(cmd.Context(), flags, db, &common.EnvConfig)
},
}
encryptionKeyRotateCmd.Flags().StringVar(&flags.NewKey, "new-key", "", "New encryption key to re-encrypt data with")
encryptionKeyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation")
rootCmd.AddCommand(encryptionKeyRotateCmd)
}
func encryptionKeyRotate(ctx context.Context, flags encryptionKeyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
oldKey := envConfig.EncryptionKey
newKey := []byte(flags.NewKey)
if len(newKey) == 0 {
return errors.New("new encryption key is required (--new-key)")
}
if len(newKey) < 16 {
return errors.New("new encryption key must be at least 16 bytes long")
}
if !flags.Yes {
fmt.Println("WARNING: Rotating the encryption key will re-encrypt secrets in the database. Pocket-ID must be restarted with the new ENCRYPTION_KEY after rotation is complete.")
ok, err := utils.PromptForConfirmation("Continue")
if err != nil {
return err
}
if !ok {
fmt.Println("Aborted")
os.Exit(1)
}
}
appConfigService, err := service.NewAppConfigService(ctx, db)
if err != nil {
return fmt.Errorf("failed to create app config service: %w", err)
}
instanceID := appConfigService.GetDbConfig().InstanceID.Value
// Derive the encryption keys used for the JWK encryption
oldKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: oldKey}, instanceID)
if err != nil {
return fmt.Errorf("failed to derive old key encryption key: %w", err)
}
newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID)
if err != nil {
return fmt.Errorf("failed to derive new key encryption key: %w", err)
}
// Derive the encryption keys used for EncryptedString fields
oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey)
if err != nil {
return fmt.Errorf("failed to derive old encrypted string key: %w", err)
}
newEncKey, err := datatype.DeriveEncryptedStringKey(newKey)
if err != nil {
return fmt.Errorf("failed to derive new encrypted string key: %w", err)
}
err = db.Transaction(func(tx *gorm.DB) error {
err = rotateSigningKeyEncryption(ctx, tx, oldKek, newKek)
if err != nil {
return err
}
err = rotateScimTokens(tx, oldEncKey, newEncKey)
if err != nil {
return err
}
return nil
})
if err != nil {
return err
}
fmt.Println("Encryption key rotation completed successfully.")
fmt.Println("Restart pocket-id with the new ENCRYPTION_KEY to use the rotated data.")
return nil
}
func rotateSigningKeyEncryption(ctx context.Context, db *gorm.DB, oldKek []byte, newKek []byte) error {
oldProvider := &jwkutils.KeyProviderDatabase{}
err := oldProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: oldKek,
})
if err != nil {
return fmt.Errorf("failed to init key provider with old encryption key: %w", err)
}
key, err := oldProvider.LoadKey(ctx)
if err != nil {
return fmt.Errorf("failed to load signing key using old encryption key: %w", err)
}
if key == nil {
return nil
}
newProvider := &jwkutils.KeyProviderDatabase{}
err = newProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: newKek,
})
if err != nil {
return fmt.Errorf("failed to init key provider with new encryption key: %w", err)
}
if err := newProvider.SaveKey(ctx, key); err != nil {
return fmt.Errorf("failed to store signing key with new encryption key: %w", err)
}
return nil
}
type scimTokenRow struct {
ID string
Token string
}
func rotateScimTokens(db *gorm.DB, oldEncKey []byte, newEncKey []byte) error {
var rows []scimTokenRow
err := db.Model(&model.ScimServiceProvider{}).Select("id, token").Scan(&rows).Error
if err != nil {
return fmt.Errorf("failed to list SCIM service providers: %w", err)
}
for _, row := range rows {
if row.Token == "" {
continue
}
decBytes, err := datatype.DecryptEncryptedStringWithKey(oldEncKey, row.Token)
if err != nil {
return fmt.Errorf("failed to decrypt SCIM token for provider %s: %w", row.ID, err)
}
encValue, err := datatype.EncryptEncryptedStringWithKey(newEncKey, decBytes)
if err != nil {
return fmt.Errorf("failed to encrypt SCIM token for provider %s: %w", row.ID, err)
}
err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", row.ID).Update("token", encValue).Error
if err != nil {
return fmt.Errorf("failed to update SCIM token for provider %s: %w", row.ID, err)
}
}
return nil
}

View File

@@ -0,0 +1,89 @@
package cmds
import (
"testing"
"time"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/service"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
func TestEncryptionKeyRotate(t *testing.T) {
oldKey := []byte("old-encryption-key-123456")
newKey := []byte("new-encryption-key-654321")
envConfig := &common.EnvConfigSchema{
EncryptionKey: oldKey,
}
db := testingutils.NewDatabaseForTest(t)
appConfigService, err := service.NewAppConfigService(t.Context(), db)
require.NoError(t, err)
instanceID := appConfigService.GetDbConfig().InstanceID.Value
oldKek, err := jwkutils.LoadKeyEncryptionKey(envConfig, instanceID)
require.NoError(t, err)
oldProvider := &jwkutils.KeyProviderDatabase{}
require.NoError(t, oldProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: oldKek,
}))
signingKey, err := jwkutils.GenerateKey("RS256", "")
require.NoError(t, err)
require.NoError(t, oldProvider.SaveKey(t.Context(), signingKey))
oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey)
require.NoError(t, err)
encToken, err := datatype.EncryptEncryptedStringWithKey(oldEncKey, []byte("scim-token-123"))
require.NoError(t, err)
err = db.Exec(
`INSERT INTO scim_service_providers (id, created_at, endpoint, token, oidc_client_id) VALUES (?, ?, ?, ?, ?)`,
"scim-1",
time.Now(),
"https://example.com/scim",
encToken,
"client-1",
).Error
require.NoError(t, err)
flags := encryptionKeyRotateFlags{
NewKey: string(newKey),
Yes: true,
}
require.NoError(t, encryptionKeyRotate(t.Context(), flags, db, envConfig))
newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID)
require.NoError(t, err)
newProvider := &jwkutils.KeyProviderDatabase{}
require.NoError(t, newProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: newKek,
}))
rotatedKey, err := newProvider.LoadKey(t.Context())
require.NoError(t, err)
require.NotNil(t, rotatedKey)
var storedToken string
err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", "scim-1").Pluck("token", &storedToken).Error
require.NoError(t, err)
newEncKey, err := datatype.DeriveEncryptedStringKey(newKey)
require.NoError(t, err)
decBytes, err := datatype.DecryptEncryptedStringWithKey(newEncKey, storedToken)
require.NoError(t, err)
assert.Equal(t, "scim-token-123", string(decBytes))
}

View File

@@ -102,7 +102,7 @@ func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig
} }
// Save the key // Save the key
err = keyProvider.SaveKey(key) err = keyProvider.SaveKey(ctx, key)
if err != nil { if err != nil {
return fmt.Errorf("failed to store new key: %w", err) return fmt.Errorf("failed to store new key: %w", err)
} }

View File

@@ -104,7 +104,7 @@ func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantEr
require.NoError(t, err) require.NoError(t, err)
// Verify key was created // Verify key was created
key, err := keyProvider.LoadKey() key, err := keyProvider.LoadKey(t.Context())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, key) require.NotNil(t, key)

View File

@@ -40,14 +40,9 @@ func (e *EncryptedString) Scan(value any) error {
return nil return nil
} }
encBytes, err := base64.StdEncoding.DecodeString(raw) decBytes, err := DecryptEncryptedStringWithKey(encStringKey, raw)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode encrypted string: %w", err) return err
}
decBytes, err := cryptoutils.Decrypt(encStringKey, encBytes, []byte(encryptedStringAAD))
if err != nil {
return fmt.Errorf("failed to decrypt encrypted string: %w", err)
} }
*e = EncryptedString(decBytes) *e = EncryptedString(decBytes)
@@ -59,19 +54,20 @@ func (e EncryptedString) Value() (driver.Value, error) {
return "", nil return "", nil
} }
encBytes, err := cryptoutils.Encrypt(encStringKey, []byte(e), []byte(encryptedStringAAD)) encValue, err := EncryptEncryptedStringWithKey(encStringKey, []byte(e))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encrypt string: %w", err) return nil, err
} }
return base64.StdEncoding.EncodeToString(encBytes), nil return encValue, nil
} }
func (e EncryptedString) String() string { func (e EncryptedString) String() string {
return string(e) return string(e)
} }
func deriveEncryptedStringKey(master []byte) ([]byte, error) { // DeriveEncryptedStringKey derives a key for encrypting EncryptedString values from the master key.
func DeriveEncryptedStringKey(master []byte) ([]byte, error) {
const info = "pocketid/encrypted_string" const info = "pocketid/encrypted_string"
r := hkdf.New(sha256.New, master, nil, []byte(info)) r := hkdf.New(sha256.New, master, nil, []byte(info))
@@ -82,8 +78,33 @@ func deriveEncryptedStringKey(master []byte) ([]byte, error) {
return key, nil return key, nil
} }
// DecryptEncryptedStringWithKey decrypts an EncryptedString value using the derived key.
func DecryptEncryptedStringWithKey(key []byte, encoded string) ([]byte, error) {
encBytes, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, fmt.Errorf("failed to decode encrypted string: %w", err)
}
decBytes, err := cryptoutils.Decrypt(key, encBytes, []byte(encryptedStringAAD))
if err != nil {
return nil, fmt.Errorf("failed to decrypt encrypted string: %w", err)
}
return decBytes, nil
}
// EncryptEncryptedStringWithKey encrypts an EncryptedString value using the derived key.
func EncryptEncryptedStringWithKey(key []byte, plaintext []byte) (string, error) {
encBytes, err := cryptoutils.Encrypt(key, plaintext, []byte(encryptedStringAAD))
if err != nil {
return "", fmt.Errorf("failed to encrypt string: %w", err)
}
return base64.StdEncoding.EncodeToString(encBytes), nil
}
func init() { func init() {
key, err := deriveEncryptedStringKey(common.EnvConfig.EncryptionKey) key, err := DeriveEncryptedStringKey(common.EnvConfig.EncryptionKey)
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to derive encrypted string key: %v", err)) panic(fmt.Sprintf("failed to derive encrypted string key: %v", err))
} }

View File

@@ -526,7 +526,7 @@ func (s *TestService) ResetAppConfig(ctx context.Context) error {
} }
// Reload the JWK // Reload the JWK
if err := s.jwtService.LoadOrGenerateKey(); err != nil { if err := s.jwtService.LoadOrGenerateKey(ctx); err != nil {
return err return err
} }

View File

@@ -56,10 +56,10 @@ type JwtService struct {
jwksEncoded []byte jwksEncoded []byte
} }
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) { func NewJwtService(ctx context.Context, db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) {
service := &JwtService{} service := &JwtService{}
err := service.init(db, appConfigService, &common.EnvConfig) err := service.init(ctx, db, appConfigService, &common.EnvConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -67,16 +67,16 @@ func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService
return service, nil return service, nil
} }
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) { func (s *JwtService) init(ctx context.Context, db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
s.appConfigService = appConfigService s.appConfigService = appConfigService
s.envConfig = envConfig s.envConfig = envConfig
s.db = db s.db = db
// Ensure keys are generated or loaded // Ensure keys are generated or loaded
return s.LoadOrGenerateKey() return s.LoadOrGenerateKey(ctx)
} }
func (s *JwtService) LoadOrGenerateKey() error { func (s *JwtService) LoadOrGenerateKey(ctx context.Context) error {
// Get the key provider // Get the key provider
keyProvider, err := jwkutils.GetKeyProvider(s.db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value) keyProvider, err := jwkutils.GetKeyProvider(s.db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
if err != nil { if err != nil {
@@ -84,7 +84,7 @@ func (s *JwtService) LoadOrGenerateKey() error {
} }
// Try loading a key // Try loading a key
key, err := keyProvider.LoadKey() key, err := keyProvider.LoadKey(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to load key: %w", err) return fmt.Errorf("failed to load key: %w", err)
} }
@@ -105,7 +105,7 @@ func (s *JwtService) LoadOrGenerateKey() error {
} }
// Save the newly-generated key // Save the newly-generated key
err = keyProvider.SaveKey(s.privateKey) err = keyProvider.SaveKey(ctx, s.privateKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to save private key: %w", err) return fmt.Errorf("failed to save private key: %w", err)
} }

View File

@@ -38,7 +38,7 @@ func initJwtService(t *testing.T, db *gorm.DB, appConfig *AppConfigService, envC
t.Helper() t.Helper()
service := &JwtService{} service := &JwtService{}
err := service.init(db, appConfig, envConfig) err := service.init(t.Context(), db, appConfig, envConfig)
require.NoError(t, err, "Failed to initialize JWT service") require.NoError(t, err, "Failed to initialize JWT service")
return service return service
@@ -65,7 +65,7 @@ func saveKeyToDatabase(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSch
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfig.GetDbConfig().InstanceID.Value) keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfig.GetDbConfig().InstanceID.Value)
require.NoError(t, err, "Failed to init key provider") require.NoError(t, err, "Failed to init key provider")
err = keyProvider.SaveKey(key) err = keyProvider.SaveKey(t.Context(), key)
require.NoError(t, err, "Failed to save key") require.NoError(t, err, "Failed to save key")
kid, ok := key.KeyID() kid, ok := key.KeyID()
@@ -93,7 +93,7 @@ func TestJwtService_Init(t *testing.T) {
// Verify the key has been persisted in the database // Verify the key has been persisted in the database
keyProvider, err := jwkutils.GetKeyProvider(db, mockEnvConfig, mockConfig.GetDbConfig().InstanceID.Value) keyProvider, err := jwkutils.GetKeyProvider(db, mockEnvConfig, mockConfig.GetDbConfig().InstanceID.Value)
require.NoError(t, err, "Failed to init key provider") require.NoError(t, err, "Failed to init key provider")
key, err := keyProvider.LoadKey() key, err := keyProvider.LoadKey(t.Context())
require.NoError(t, err, "Failed to load key from provider") require.NoError(t, err, "Failed to load key from provider")
require.NotNil(t, key, "Key should be present in the database") require.NotNil(t, key, "Key should be present in the database")

View File

@@ -160,7 +160,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
mockConfig := NewTestAppConfigService(&model.AppConfig{ mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
}) })
mockJwtService, err := NewJwtService(db, mockConfig) mockJwtService, err := NewJwtService(t.Context(), db, mockConfig)
require.NoError(t, err) require.NoError(t, err)
// Create a mock HTTP client with custom transport to return the JWKS // Create a mock HTTP client with custom transport to return the JWKS

View File

@@ -1,6 +1,7 @@
package jwk package jwk
import ( import (
"context"
"fmt" "fmt"
"github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwk"
@@ -17,8 +18,8 @@ type KeyProviderOpts struct {
type KeyProvider interface { type KeyProvider interface {
Init(opts KeyProviderOpts) error Init(opts KeyProviderOpts) error
LoadKey() (jwk.Key, error) LoadKey(ctx context.Context) (jwk.Key, error)
SaveKey(key jwk.Key) error SaveKey(ctx context.Context, key jwk.Key) error
} }
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) { func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {

View File

@@ -33,12 +33,12 @@ func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
return nil return nil
} }
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) { func (f *KeyProviderDatabase) LoadKey(ctx context.Context) (key jwk.Key, err error) {
row := model.KV{ row := model.KV{
Key: PrivateKeyDBKey, Key: PrivateKeyDBKey,
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
err = f.db.WithContext(ctx).First(&row).Error err = f.db.WithContext(ctx).First(&row).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -74,7 +74,7 @@ func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
return key, nil return key, nil
} }
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error { func (f *KeyProviderDatabase) SaveKey(ctx context.Context, key jwk.Key) error {
// Encode the key to JSON // Encode the key to JSON
data, err := EncodeJWKBytes(key) data, err := EncodeJWKBytes(key)
if err != nil { if err != nil {
@@ -94,7 +94,7 @@ func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
Value: &encB64, Value: &encB64,
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
err = f.db. err = f.db.
WithContext(ctx). WithContext(ctx).

View File

@@ -59,7 +59,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Load key when none exists // Load key when none exists
loadedKey, err := provider.LoadKey() loadedKey, err := provider.LoadKey(t.Context())
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database") assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
}) })
@@ -76,11 +76,11 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Save a key // Save a key
err = provider.SaveKey(key) err = provider.SaveKey(t.Context(), key)
require.NoError(t, err) require.NoError(t, err)
// Load the key // Load the key
loadedKey, err := provider.LoadKey() loadedKey, err := provider.LoadKey(t.Context())
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database") assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
@@ -114,7 +114,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Attempt to load the key // Attempt to load the key
loadedKey, err := provider.LoadKey() loadedKey, err := provider.LoadKey(t.Context())
require.Error(t, err, "Expected error when loading key with invalid base64") require.Error(t, err, "Expected error when loading key with invalid base64")
require.ErrorContains(t, err, "not a valid base64-encoded value") require.ErrorContains(t, err, "not a valid base64-encoded value")
assert.Nil(t, loadedKey, "Expected nil key when loading fails") assert.Nil(t, loadedKey, "Expected nil key when loading fails")
@@ -140,7 +140,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Attempt to load the key // Attempt to load the key
loadedKey, err := provider.LoadKey() loadedKey, err := provider.LoadKey(t.Context())
require.Error(t, err, "Expected error when loading key with invalid encrypted data") require.Error(t, err, "Expected error when loading key with invalid encrypted data")
require.ErrorContains(t, err, "failed to decrypt") require.ErrorContains(t, err, "failed to decrypt")
assert.Nil(t, loadedKey, "Expected nil key when loading fails") assert.Nil(t, loadedKey, "Expected nil key when loading fails")
@@ -158,7 +158,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
err = originalProvider.SaveKey(key) err = originalProvider.SaveKey(t.Context(), key)
require.NoError(t, err) require.NoError(t, err)
// Now try to load with a different KEK // Now try to load with a different KEK
@@ -171,7 +171,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Attempt to load the key with the wrong KEK // Attempt to load the key with the wrong KEK
loadedKey, err := differentProvider.LoadKey() loadedKey, err := differentProvider.LoadKey(t.Context())
require.Error(t, err, "Expected error when loading key with wrong KEK") require.Error(t, err, "Expected error when loading key with wrong KEK")
require.ErrorContains(t, err, "failed to decrypt") require.ErrorContains(t, err, "failed to decrypt")
assert.Nil(t, loadedKey, "Expected nil key when loading fails") assert.Nil(t, loadedKey, "Expected nil key when loading fails")
@@ -206,7 +206,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Attempt to load the key // Attempt to load the key
loadedKey, err := provider.LoadKey() loadedKey, err := provider.LoadKey(t.Context())
require.Error(t, err, "Expected error when loading invalid key data") require.Error(t, err, "Expected error when loading invalid key data")
require.ErrorContains(t, err, "failed to parse") require.ErrorContains(t, err, "failed to parse")
assert.Nil(t, loadedKey, "Expected nil key when loading fails") assert.Nil(t, loadedKey, "Expected nil key when loading fails")
@@ -233,7 +233,7 @@ func TestKeyProviderDatabase_SaveKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Save the key // Save the key
err = provider.SaveKey(key) err = provider.SaveKey(t.Context(), key)
require.NoError(t, err, "Expected no error when saving key") require.NoError(t, err, "Expected no error when saving key")
// Verify record exists in database // Verify record exists in database