1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 15:04:43 +00:00
Files
pocket-id/backend/internal/cmds/encryption_key_rotate.go
2026-01-07 09:34:23 +01:00

188 lines
5.3 KiB
Go

package cmds
import (
"context"
"errors"
"fmt"
"os"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/spf13/cobra"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/common"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
// encryptionKeyRotateFlags holds the command-line flags accepted by the
// encryption-key-rotate command.
type encryptionKeyRotateFlags struct {
// NewKey is bound to --new-key: the encryption key data is re-encrypted with.
NewKey string
// Yes is bound to --yes / -y: when true, the confirmation prompt is skipped.
Yes bool
}
// init builds the "encryption-key-rotate" subcommand, binds its flags and
// attaches it to rootCmd.
func init() {
	var opts encryptionKeyRotateFlags

	// The handler reads opts when the command runs, i.e. after cobra has
	// parsed the flags into it. The database is only opened at that point.
	run := func(cmd *cobra.Command, _ []string) error {
		db, err := bootstrap.NewDatabase()
		if err != nil {
			return err
		}
		return encryptionKeyRotate(cmd.Context(), opts, db, &common.EnvConfig)
	}

	rotateCmd := &cobra.Command{
		Use:   "encryption-key-rotate",
		Short: "Re-encrypts data using a new encryption key",
		RunE:  run,
	}

	flagSet := rotateCmd.Flags()
	flagSet.StringVar(&opts.NewKey, "new-key", "", "New encryption key to re-encrypt data with")
	flagSet.BoolVarP(&opts.Yes, "yes", "y", false, "Do not prompt for confirmation")

	rootCmd.AddCommand(rotateCmd)
}
// encryptionKeyRotate re-encrypts the data protected by the encryption key in
// envConfig so that it is protected by flags.NewKey instead.
//
// The new key must be non-empty and at least 16 bytes long. Unless flags.Yes
// is set, the user is asked to confirm; declining prints "Aborted" and
// terminates the process with exit code 1.
//
// The signing key and the SCIM tokens are rotated inside one database
// transaction that is bound to ctx. If either step returns an error, that
// error is returned from the transaction callback (which makes GORM roll the
// transaction back) and then from this function.
func encryptionKeyRotate(ctx context.Context, flags encryptionKeyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
	oldKey := envConfig.EncryptionKey
	newKey := []byte(flags.NewKey)
	if len(newKey) == 0 {
		return errors.New("new encryption key is required (--new-key)")
	}
	if len(newKey) < 16 {
		return errors.New("new encryption key must be at least 16 bytes long")
	}

	if !flags.Yes {
		fmt.Println("WARNING: Rotating the encryption key will re-encrypt secrets in the database. Pocket-ID must be restarted with the new ENCRYPTION_KEY after rotation is complete.")
		ok, err := utils.PromptForConfirmation("Continue")
		if err != nil {
			return err
		}
		if !ok {
			// Nothing has been modified at this point, so exiting directly
			// does not leave partially rotated data behind.
			fmt.Println("Aborted")
			os.Exit(1)
		}
	}

	appConfigService, err := service.NewAppConfigService(ctx, db)
	if err != nil {
		return fmt.Errorf("failed to create app config service: %w", err)
	}
	instanceID := appConfigService.GetDbConfig().InstanceID.Value

	// Derive the encryption keys used for the JWK encryption
	oldKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: oldKey}, instanceID)
	if err != nil {
		return fmt.Errorf("failed to derive old key encryption key: %w", err)
	}
	newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID)
	if err != nil {
		return fmt.Errorf("failed to derive new key encryption key: %w", err)
	}

	// Derive the encryption keys used for EncryptedString fields
	oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey)
	if err != nil {
		return fmt.Errorf("failed to derive old encrypted string key: %w", err)
	}
	newEncKey, err := datatype.DeriveEncryptedStringKey(newKey)
	if err != nil {
		return fmt.Errorf("failed to derive new encrypted string key: %w", err)
	}

	// Bind the transaction to ctx so that every statement issued through tx,
	// including those of rotateScimTokens (which takes no context of its own),
	// observes cancellation of the command's context.
	err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := rotateSigningKeyEncryption(ctx, tx, oldKek, newKek); err != nil {
			return err
		}
		return rotateScimTokens(tx, oldEncKey, newEncKey)
	})
	if err != nil {
		return err
	}

	fmt.Println("Encryption key rotation completed successfully.")
	fmt.Println("Restart pocket-id with the new ENCRYPTION_KEY to use the rotated data.")
	return nil
}
// rotateSigningKeyEncryption loads the signing key through a database key
// provider initialised with oldKek and saves it again through a second
// provider initialised with newKek. Both providers operate on db.
//
// When the old provider yields no key (nil), there is nothing to rotate and
// nil is returned without initialising the new provider.
func rotateSigningKeyEncryption(ctx context.Context, db *gorm.DB, oldKek []byte, newKek []byte) error {
	reader := &jwkutils.KeyProviderDatabase{}
	if err := reader.Init(jwkutils.KeyProviderOpts{DB: db, Kek: oldKek}); err != nil {
		return fmt.Errorf("failed to init key provider with old encryption key: %w", err)
	}

	signingKey, err := reader.LoadKey(ctx)
	switch {
	case err != nil:
		return fmt.Errorf("failed to load signing key using old encryption key: %w", err)
	case signingKey == nil:
		return nil
	}

	writer := &jwkutils.KeyProviderDatabase{}
	if err := writer.Init(jwkutils.KeyProviderOpts{DB: db, Kek: newKek}); err != nil {
		return fmt.Errorf("failed to init key provider with new encryption key: %w", err)
	}

	if err := writer.SaveKey(ctx, signingKey); err != nil {
		return fmt.Errorf("failed to store signing key with new encryption key: %w", err)
	}
	return nil
}
// scimTokenRow is the scan target rotateScimTokens uses to read the id and
// token columns of SCIM service provider rows.
type scimTokenRow struct {
// ID identifies the row; it is used in the WHERE clause when the token is written back.
ID string
// Token is the stored token value as read from the database; an empty value is skipped during rotation.
Token string
}
// rotateScimTokens re-encrypts the token column of every SCIM service
// provider row: each non-empty token is decrypted with oldEncKey, encrypted
// with newEncKey and written back to the row with the same id.
//
// Rows with an empty token are left untouched. The first failure stops the
// loop and is returned, wrapped with the id of the affected provider.
func rotateScimTokens(db *gorm.DB, oldEncKey []byte, newEncKey []byte) error {
	var providers []scimTokenRow
	if err := db.Model(&model.ScimServiceProvider{}).Select("id, token").Scan(&providers).Error; err != nil {
		return fmt.Errorf("failed to list SCIM service providers: %w", err)
	}

	for i := range providers {
		id, stored := providers[i].ID, providers[i].Token
		if stored == "" {
			continue
		}

		plaintext, err := datatype.DecryptEncryptedStringWithKey(oldEncKey, stored)
		if err != nil {
			return fmt.Errorf("failed to decrypt SCIM token for provider %s: %w", id, err)
		}

		reencrypted, err := datatype.EncryptEncryptedStringWithKey(newEncKey, plaintext)
		if err != nil {
			return fmt.Errorf("failed to encrypt SCIM token for provider %s: %w", id, err)
		}

		result := db.Model(&model.ScimServiceProvider{}).Where("id = ?", id).Update("token", reencrypted)
		if result.Error != nil {
			return fmt.Errorf("failed to update SCIM token for provider %s: %w", id, result.Error)
		}
	}
	return nil
}