mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 15:04:43 +00:00
188 lines
5.3 KiB
Go
188 lines
5.3 KiB
Go
package cmds
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
"github.com/spf13/cobra"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
|
)
|
|
|
|
// encryptionKeyRotateFlags collects the command-line options of the
// encryption-key-rotate command.
type encryptionKeyRotateFlags struct {
	// NewKey is bound to --new-key: the encryption key the data is
	// re-encrypted with.
	NewKey string
	// Yes is bound to --yes / -y: when true, the confirmation prompt is
	// skipped.
	Yes bool
}
|
|
|
|
func init() {
|
|
var flags encryptionKeyRotateFlags
|
|
|
|
encryptionKeyRotateCmd := &cobra.Command{
|
|
Use: "encryption-key-rotate",
|
|
Short: "Re-encrypts data using a new encryption key",
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
db, err := bootstrap.NewDatabase()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return encryptionKeyRotate(cmd.Context(), flags, db, &common.EnvConfig)
|
|
},
|
|
}
|
|
|
|
encryptionKeyRotateCmd.Flags().StringVar(&flags.NewKey, "new-key", "", "New encryption key to re-encrypt data with")
|
|
encryptionKeyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation")
|
|
|
|
rootCmd.AddCommand(encryptionKeyRotateCmd)
|
|
}
|
|
|
|
func encryptionKeyRotate(ctx context.Context, flags encryptionKeyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
|
|
oldKey := envConfig.EncryptionKey
|
|
newKey := []byte(flags.NewKey)
|
|
if len(newKey) == 0 {
|
|
return errors.New("new encryption key is required (--new-key)")
|
|
}
|
|
if len(newKey) < 16 {
|
|
return errors.New("new encryption key must be at least 16 bytes long")
|
|
}
|
|
|
|
if !flags.Yes {
|
|
fmt.Println("WARNING: Rotating the encryption key will re-encrypt secrets in the database. Pocket-ID must be restarted with the new ENCRYPTION_KEY after rotation is complete.")
|
|
ok, err := utils.PromptForConfirmation("Continue")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !ok {
|
|
fmt.Println("Aborted")
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
appConfigService, err := service.NewAppConfigService(ctx, db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create app config service: %w", err)
|
|
}
|
|
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
|
|
|
// Derive the encryption keys used for the JWK encryption
|
|
oldKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: oldKey}, instanceID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to derive old key encryption key: %w", err)
|
|
}
|
|
newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to derive new key encryption key: %w", err)
|
|
}
|
|
|
|
// Derive the encryption keys used for EncryptedString fields
|
|
oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to derive old encrypted string key: %w", err)
|
|
}
|
|
newEncKey, err := datatype.DeriveEncryptedStringKey(newKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to derive new encrypted string key: %w", err)
|
|
}
|
|
|
|
err = db.Transaction(func(tx *gorm.DB) error {
|
|
err = rotateSigningKeyEncryption(ctx, tx, oldKek, newKek)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = rotateScimTokens(tx, oldEncKey, newEncKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fmt.Println("Encryption key rotation completed successfully.")
|
|
fmt.Println("Restart pocket-id with the new ENCRYPTION_KEY to use the rotated data.")
|
|
|
|
return nil
|
|
}
|
|
|
|
func rotateSigningKeyEncryption(ctx context.Context, db *gorm.DB, oldKek []byte, newKek []byte) error {
|
|
oldProvider := &jwkutils.KeyProviderDatabase{}
|
|
err := oldProvider.Init(jwkutils.KeyProviderOpts{
|
|
DB: db,
|
|
Kek: oldKek,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to init key provider with old encryption key: %w", err)
|
|
}
|
|
|
|
key, err := oldProvider.LoadKey(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load signing key using old encryption key: %w", err)
|
|
}
|
|
if key == nil {
|
|
return nil
|
|
}
|
|
|
|
newProvider := &jwkutils.KeyProviderDatabase{}
|
|
err = newProvider.Init(jwkutils.KeyProviderOpts{
|
|
DB: db,
|
|
Kek: newKek,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to init key provider with new encryption key: %w", err)
|
|
}
|
|
|
|
if err := newProvider.SaveKey(ctx, key); err != nil {
|
|
return fmt.Errorf("failed to store signing key with new encryption key: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// scimTokenRow is the scan target for the "id" and "token" columns that
// rotateScimTokens selects from the SCIM service provider model.
type scimTokenRow struct {
	// ID identifies the SCIM service provider row.
	ID string
	// Token is the stored token value as read from the database; it is
	// passed to DecryptEncryptedStringWithKey during rotation.
	Token string
}
|
|
|
|
func rotateScimTokens(db *gorm.DB, oldEncKey []byte, newEncKey []byte) error {
|
|
var rows []scimTokenRow
|
|
err := db.Model(&model.ScimServiceProvider{}).Select("id, token").Scan(&rows).Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to list SCIM service providers: %w", err)
|
|
}
|
|
|
|
for _, row := range rows {
|
|
if row.Token == "" {
|
|
continue
|
|
}
|
|
|
|
decBytes, err := datatype.DecryptEncryptedStringWithKey(oldEncKey, row.Token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt SCIM token for provider %s: %w", row.ID, err)
|
|
}
|
|
|
|
encValue, err := datatype.EncryptEncryptedStringWithKey(newEncKey, decBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt SCIM token for provider %s: %w", row.ID, err)
|
|
}
|
|
|
|
err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", row.ID).Update("token", encValue).Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update SCIM token for provider %s: %w", row.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|