1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 12:46:45 +00:00

feat!: drop support for storing JWK on the filesystem (#1088)

This commit is contained in:
Elias Schneider
2025-11-14 12:02:51 +01:00
parent e1c5021eee
commit f0144584af
12 changed files with 241 additions and 1326 deletions

View File

@@ -1,9 +1,12 @@
package main package main
import ( import (
"fmt"
"os"
_ "time/tzdata" _ "time/tzdata"
"github.com/pocket-id/pocket-id/backend/internal/cmds" "github.com/pocket-id/pocket-id/backend/internal/cmds"
"github.com/pocket-id/pocket-id/backend/internal/common"
) )
// @title Pocket ID API // @title Pocket ID API
@@ -11,5 +14,9 @@ import (
// @description.markdown // @description.markdown
func main() { func main() {
if err := common.ValidateEnvConfig(&common.EnvConfig); err != nil {
fmt.Fprintf(os.Stderr, "config error: %v\n", err)
os.Exit(1)
}
cmds.Execute() cmds.Execute()
} }

View File

@@ -1,8 +1,6 @@
package cmds package cmds
import ( import (
"os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -69,78 +67,14 @@ func TestKeyRotate(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Run("file storage", func(t *testing.T) { testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg)
})
t.Run("database storage", func(t *testing.T) {
testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
})
}) })
} }
} }
func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
// Create temporary directory for keys
tempDir := t.TempDir()
keysPath := filepath.Join(tempDir, "keys")
err := os.MkdirAll(keysPath, 0755)
require.NoError(t, err)
// Set up file storage config
envConfig := &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: keysPath,
}
// Create test database
db := testingutils.NewDatabaseForTest(t)
// Initialize app config service and create instance
appConfigService, err := service.NewAppConfigService(t.Context(), db)
require.NoError(t, err)
instanceID := appConfigService.GetDbConfig().InstanceID.Value
// Check if key exists before rotation
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
require.NoError(t, err)
// Run the key rotation
err = keyRotate(t.Context(), flags, db, envConfig)
if wantErr {
require.Error(t, err)
if errMsg != "" {
require.ErrorContains(t, err, errMsg)
}
return
}
require.NoError(t, err)
// Verify key was created
key, err := keyProvider.LoadKey()
require.NoError(t, err)
require.NotNil(t, key)
// Verify the algorithm matches what we requested
alg, _ := key.Algorithm()
assert.NotEmpty(t, alg)
if flags.Alg != "" {
expectedAlg := flags.Alg
if expectedAlg == "EdDSA" {
// EdDSA keys should have the EdDSA algorithm
assert.Equal(t, "EdDSA", alg.String())
} else {
assert.Equal(t, expectedAlg, alg.String())
}
}
}
func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) { func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
// Set up database storage config // Set up database storage config
envConfig := &common.EnvConfigSchema{ envConfig := &common.EnvConfigSchema{
KeysStorage: "database",
EncryptionKey: []byte("test-encryption-key-characters-long"), EncryptionKey: []byte("test-encryption-key-characters-long"),
} }

View File

@@ -52,8 +52,6 @@ type EnvConfigSchema struct {
S3SecretAccessKey string `env:"S3_SECRET_ACCESS_KEY"` S3SecretAccessKey string `env:"S3_SECRET_ACCESS_KEY"`
S3ForcePathStyle bool `env:"S3_FORCE_PATH_STYLE"` S3ForcePathStyle bool `env:"S3_FORCE_PATH_STYLE"`
S3DisableDefaultIntegrityChecks bool `env:"S3_DISABLE_DEFAULT_INTEGRITY_CHECKS"` S3DisableDefaultIntegrityChecks bool `env:"S3_DISABLE_DEFAULT_INTEGRITY_CHECKS"`
KeysPath string `env:"KEYS_PATH"`
KeysStorage string `env:"KEYS_STORAGE"`
EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"` EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"`
Port string `env:"PORT"` Port string `env:"PORT"`
Host string `env:"HOST" options:"toLower"` Host string `env:"HOST" options:"toLower"`
@@ -90,7 +88,6 @@ func defaultConfig() EnvConfigSchema {
LogLevel: "info", LogLevel: "info",
DbProvider: "sqlite", DbProvider: "sqlite",
FileBackend: "filesystem", FileBackend: "filesystem",
KeysPath: "data/keys",
AuditLogRetentionDays: 90, AuditLogRetentionDays: 90,
AppURL: AppUrl, AppURL: AppUrl,
Port: "1411", Port: "1411",
@@ -119,21 +116,20 @@ func parseEnvConfig() error {
return fmt.Errorf("error preparing env config: %w", err) return fmt.Errorf("error preparing env config: %w", err)
} }
err = validateEnvConfig(&EnvConfig)
if err != nil {
return err
}
return nil return nil
} }
// validateEnvConfig checks the EnvConfig for required fields and valid values // ValidateEnvConfig checks the EnvConfig for required fields and valid values
func validateEnvConfig(config *EnvConfigSchema) error { func ValidateEnvConfig(config *EnvConfigSchema) error {
if _, err := sloggin.ParseLevel(config.LogLevel); err != nil { if _, err := sloggin.ParseLevel(config.LogLevel); err != nil {
return errors.New("invalid LOG_LEVEL value. Must be 'debug', 'info', 'warn' or 'error'") return errors.New("invalid LOG_LEVEL value. Must be 'debug', 'info', 'warn' or 'error'")
} }
if len(config.EncryptionKey) < 16 {
return errors.New("ENCRYPTION_KEY must be at least 16 bytes long")
}
switch config.DbProvider { switch config.DbProvider {
case DbProviderSqlite: case DbProviderSqlite:
if config.DbConnectionString == "" { if config.DbConnectionString == "" {
@@ -168,28 +164,10 @@ func validateEnvConfig(config *EnvConfigSchema) error {
} }
} }
switch config.KeysStorage {
// KeysStorage defaults to "file" if empty
case "":
config.KeysStorage = "file"
case "database":
if config.EncryptionKey == nil {
return errors.New("ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
}
case "file":
// All good, these are valid values
default:
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
}
switch config.FileBackend { switch config.FileBackend {
case "s3": case "s3", "database":
if config.KeysStorage == "file" {
return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
}
case "database":
// All good, these are valid values // All good, these are valid values
case "", "filesystem": case "", "fs":
if config.UploadPath == "" { if config.UploadPath == "" {
config.UploadPath = defaultFsUploadPath config.UploadPath = defaultFsUploadPath
} }

View File

@@ -8,6 +8,20 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func parseAndValidateEnvConfig(t *testing.T) error {
t.Helper()
if _, exists := os.LookupEnv("ENCRYPTION_KEY"); !exists {
t.Setenv("ENCRYPTION_KEY", "0123456789abcdef")
}
if err := parseEnvConfig(); err != nil {
return err
}
return ValidateEnvConfig(&EnvConfig)
}
func TestParseEnvConfig(t *testing.T) { func TestParseEnvConfig(t *testing.T) {
// Store original config to restore later // Store original config to restore later
originalConfig := EnvConfig originalConfig := EnvConfig
@@ -21,7 +35,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "HTTP://LOCALHOST:3000") t.Setenv("APP_URL", "HTTP://LOCALHOST:3000")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider) assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
assert.Equal(t, "http://localhost:3000", EnvConfig.AppURL) assert.Equal(t, "http://localhost:3000", EnvConfig.AppURL)
@@ -33,7 +47,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db") t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
t.Setenv("APP_URL", "https://example.com") t.Setenv("APP_URL", "https://example.com")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider) assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
}) })
@@ -44,17 +58,29 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "test") t.Setenv("DB_CONNECTION_STRING", "test")
t.Setenv("APP_URL", "http://localhost:3000") t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "invalid DB_PROVIDER value") assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
}) })
t.Run("should fail when ENCRYPTION_KEY is too short", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("ENCRYPTION_KEY", "short")
err := parseAndValidateEnvConfig(t)
require.Error(t, err)
assert.ErrorContains(t, err, "ENCRYPTION_KEY must be at least 16 bytes long")
})
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) { t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
EnvConfig = defaultConfig() EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite") t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("APP_URL", "http://localhost:3000") t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString) assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
}) })
@@ -64,7 +90,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_PROVIDER", "postgres") t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("APP_URL", "http://localhost:3000") t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres") assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
}) })
@@ -75,7 +101,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "€://not-a-valid-url") t.Setenv("APP_URL", "€://not-a-valid-url")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "APP_URL is not a valid URL") assert.ErrorContains(t, err, "APP_URL is not a valid URL")
}) })
@@ -86,7 +112,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000/path") t.Setenv("APP_URL", "http://localhost:3000/path")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "APP_URL must not contain a path") assert.ErrorContains(t, err, "APP_URL must not contain a path")
}) })
@@ -97,7 +123,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("INTERNAL_APP_URL", "€://not-a-valid-url") t.Setenv("INTERNAL_APP_URL", "€://not-a-valid-url")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "INTERNAL_APP_URL is not a valid URL") assert.ErrorContains(t, err, "INTERNAL_APP_URL is not a valid URL")
}) })
@@ -108,65 +134,11 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("INTERNAL_APP_URL", "http://localhost:3000/path") t.Setenv("INTERNAL_APP_URL", "http://localhost:3000/path")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "INTERNAL_APP_URL must not contain a path") assert.ErrorContains(t, err, "INTERNAL_APP_URL must not contain a path")
}) })
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, "file", EnvConfig.KeysStorage)
})
t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", "database")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
})
t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
validStorageTypes := []string{"file", "database"}
for _, storage := range validStorageTypes {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", storage)
if storage == "database" {
t.Setenv("ENCRYPTION_KEY", "test-key")
}
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, storage, EnvConfig.KeysStorage)
}
})
t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", "invalid")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
})
t.Run("should parse boolean environment variables correctly", func(t *testing.T) { t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
EnvConfig = defaultConfig() EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite") t.Setenv("DB_PROVIDER", "sqlite")
@@ -178,7 +150,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("TRUST_PROXY", "true") t.Setenv("TRUST_PROXY", "true")
t.Setenv("ANALYTICS_DISABLED", "false") t.Setenv("ANALYTICS_DISABLED", "false")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, EnvConfig.UiConfigDisabled) assert.True(t, EnvConfig.UiConfigDisabled)
assert.True(t, EnvConfig.MetricsEnabled) assert.True(t, EnvConfig.MetricsEnabled)
@@ -229,14 +201,13 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("APP_URL", "https://prod.example.com") t.Setenv("APP_URL", "https://prod.example.com")
t.Setenv("APP_ENV", "PRODUCTION") t.Setenv("APP_ENV", "PRODUCTION")
t.Setenv("UPLOAD_PATH", "/custom/uploads") t.Setenv("UPLOAD_PATH", "/custom/uploads")
t.Setenv("KEYS_PATH", "/custom/keys")
t.Setenv("PORT", "8080") t.Setenv("PORT", "8080")
t.Setenv("HOST", "LOCALHOST") t.Setenv("HOST", "LOCALHOST")
t.Setenv("UNIX_SOCKET", "/tmp/app.sock") t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
t.Setenv("MAXMIND_LICENSE_KEY", "test-license") t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb") t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, AppEnvProduction, EnvConfig.AppEnv) // lowercased assert.Equal(t, AppEnvProduction, EnvConfig.AppEnv) // lowercased
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath) assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
@@ -252,24 +223,12 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("FILE_BACKEND", "FILESYSTEM") t.Setenv("FILE_BACKEND", "FILESYSTEM")
t.Setenv("UPLOAD_PATH", "") t.Setenv("UPLOAD_PATH", "")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "filesystem", EnvConfig.FileBackend) assert.Equal(t, "filesystem", EnvConfig.FileBackend)
assert.Equal(t, defaultFsUploadPath, EnvConfig.UploadPath) assert.Equal(t, defaultFsUploadPath, EnvConfig.UploadPath)
}) })
t.Run("should fail when FILE_BACKEND is s3 but keys are stored on filesystem", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("FILE_BACKEND", "s3")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
})
t.Run("should fail with invalid FILE_BACKEND value", func(t *testing.T) { t.Run("should fail with invalid FILE_BACKEND value", func(t *testing.T) {
EnvConfig = defaultConfig() EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite") t.Setenv("DB_PROVIDER", "sqlite")
@@ -277,7 +236,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("APP_URL", "http://localhost:3000") t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("FILE_BACKEND", "invalid") t.Setenv("FILE_BACKEND", "invalid")
err := parseEnvConfig() err := parseAndValidateEnvConfig(t)
require.Error(t, err) require.Error(t, err)
assert.ErrorContains(t, err, "invalid FILE_BACKEND value") assert.ErrorContains(t, err, "invalid FILE_BACKEND value")
}) })

View File

@@ -18,14 +18,6 @@ import (
) )
const ( const (
// PrivateKeyFile is the path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
// KeyUsageSigning is the usage for the private keys, for the "use" property // KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig" KeyUsageSigning = "sig"
@@ -93,7 +85,7 @@ func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
// Try loading a key // Try loading a key
key, err := keyProvider.LoadKey() key, err := keyProvider.LoadKey()
if err != nil { if err != nil {
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err) return fmt.Errorf("failed to load key: %w", err)
} }
// If we have a key, store it in the object and we're done // If we have a key, store it in the object and we're done
@@ -114,7 +106,7 @@ func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
// Save the newly-generated key // Save the newly-generated key
err = keyProvider.SaveKey(s.privateKey) err = keyProvider.SaveKey(s.privateKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err) return fmt.Errorf("failed to save private key: %w", err)
} }
return nil return nil

File diff suppressed because it is too large Load Diff

View File

@@ -148,6 +148,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
var err error var err error
// Create a test database // Create a test database
db := testutils.NewDatabaseForTest(t) db := testutils.NewDatabaseForTest(t)
common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")
// Create two JWKs for testing // Create two JWKs for testing
privateJWK, jwkSetJSON := generateTestECDSAKey(t) privateJWK, jwkSetJSON := generateTestECDSAKey(t)

View File

@@ -28,22 +28,14 @@ func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID s
return nil, fmt.Errorf("failed to load encryption key: %w", err) return nil, fmt.Errorf("failed to load encryption key: %w", err)
} }
// Get the key provider keyProvider = &KeyProviderDatabase{}
switch envConfig.KeysStorage {
case "file", "":
keyProvider = &KeyProviderFile{}
case "database":
keyProvider = &KeyProviderDatabase{}
default:
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
}
err = keyProvider.Init(KeyProviderOpts{ err = keyProvider.Init(KeyProviderOpts{
DB: db, DB: db,
EnvConfig: envConfig, EnvConfig: envConfig,
Kek: kek, Kek: kek,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err) return nil, fmt.Errorf("failed to init key provider: %w", err)
} }
return keyProvider, nil return keyProvider, nil

View File

@@ -1,202 +0,0 @@
package jwk
import (
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
const (
// PrivateKeyFile is the path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
)
type KeyProviderFile struct {
envConfig *common.EnvConfigSchema
kek []byte
}
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
f.envConfig = opts.EnvConfig
f.kek = opts.Kek
return nil
}
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
if len(f.kek) > 0 {
return f.loadEncryptedKey()
}
return f.loadKey()
}
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
if len(f.kek) > 0 {
return f.saveKeyEncrypted(key)
}
return f.saveKey(key)
}
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
var key jwk.Key
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := f.jwkPath()
ok, err := utils.FileExists(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
}
if !ok {
// File doesn't exist, no key was loaded
return nil, nil
}
data, err := os.ReadFile(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
}
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
}
return key, nil
}
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
// First, check if we have an encrypted JWK file
// If we do, then we just load that
encJwkPath := f.encJwkPath()
ok, err := utils.FileExists(encJwkPath)
if err != nil {
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
}
if ok {
encB64, err := os.ReadFile(encJwkPath)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
}
// Decode from base64
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
n, err := base64.StdEncoding.Decode(enc, encB64)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
}
// Decrypt the data
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
}
// Parse the key
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
}
return key, nil
}
// Check if we have an un-encrypted JWK file
key, err = f.loadKey()
if err != nil {
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
}
if key == nil {
// No key exists, encrypted or un-encrypted
return nil, nil
}
// If we are here, we have loaded a key that was un-encrypted
// We need to replace the plaintext key with the encrypted one before we return
err = f.saveKeyEncrypted(key)
if err != nil {
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
}
jwkPath := f.jwkPath()
err = os.Remove(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
}
return key, nil
}
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
}
jwkPath := f.jwkPath()
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
}
defer keyFile.Close()
// Write the JSON file to disk
err = EncodeJWK(keyFile, key)
if err != nil {
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
}
return nil
}
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
}
// Encode the key to JSON
data, err := EncodeJWKBytes(key)
if err != nil {
return fmt.Errorf("failed to encode key to JSON: %w", err)
}
// Encrypt the key then encode to Base64
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
if err != nil {
return fmt.Errorf("failed to encrypt key: %w", err)
}
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
base64.StdEncoding.Encode(encB64, enc)
// Write to disk
encJwkPath := f.encJwkPath()
err = os.WriteFile(encJwkPath, encB64, 0600)
if err != nil {
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
}
return nil
}
func (f *KeyProviderFile) jwkPath() string {
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
}
func (f *KeyProviderFile) encJwkPath() string {
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
}
// Compile-time interface check
var _ KeyProvider = (*KeyProviderFile)(nil)

View File

@@ -1,320 +0,0 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"os"
"path/filepath"
"testing"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
func TestKeyProviderFile_LoadKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("LoadKey with no existing key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
})
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: makeKEK(t),
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
})
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save a key
err = provider.SaveKey(key)
require.NoError(t, err)
// Make sure the key file exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected key file to exist")
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey with encrypted key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: makeKEK(t),
})
require.NoError(t, err)
// Save a key (will be encrypted)
err = provider.SaveKey(key)
require.NoError(t, err)
// Make sure the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err := utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist")
// Make sure the unencrypted key file does not exist
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to not exist")
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
tempDir := t.TempDir()
// First, create an unencrypted key
providerNoKek := &KeyProviderFile{}
err := providerNoKek.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save an unencrypted key
err = providerNoKek.SaveKey(key)
require.NoError(t, err)
// Verify unencrypted key exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected unencrypted key file to exist")
// Now create a provider with a kek
kek := make([]byte, 32)
_, err = rand.Read(kek)
require.NoError(t, err)
providerWithKek := &KeyProviderFile{}
err = providerWithKek.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: kek,
})
require.NoError(t, err)
// Load the key - this should convert the unencrypted key to encrypted
loadedKey, err := providerWithKek.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
// Verify the unencrypted key no longer exists
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to be removed")
// Verify the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err = utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
// Verify the key data
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
})
}
func TestKeyProviderFile_SaveKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("SaveKey unencrypted", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save the key
err = provider.SaveKey(key)
require.NoError(t, err)
// Verify the key file exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected key file to exist")
// Verify the content of the key file
data, err := os.ReadFile(keyPath)
require.NoError(t, err)
parsedKey, err := jwk.ParseKey(data)
require.NoError(t, err)
// Compare the saved key with the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
})
t.Run("SaveKey encrypted", func(t *testing.T) {
tempDir := t.TempDir()
// Generate a 64-byte kek
kek := makeKEK(t)
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: kek,
})
require.NoError(t, err)
// Save the key (will be encrypted)
err = provider.SaveKey(key)
require.NoError(t, err)
// Verify the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err := utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist")
// Verify the unencrypted key file doesn't exist
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to not exist")
// Manually decrypt the encrypted key file to verify it contains the correct key
encB64, err := os.ReadFile(encKeyPath)
require.NoError(t, err)
// Decode from base64
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
n, err := base64.StdEncoding.Decode(enc, encB64)
require.NoError(t, err)
enc = enc[:n] // Trim any padding
// Decrypt the data
data, err := cryptoutils.Decrypt(kek, enc, nil)
require.NoError(t, err)
// Parse the key
parsedKey, err := jwk.ParseKey(data)
require.NoError(t, err)
// Compare the decrypted key with the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
})
}
func makeKEK(t *testing.T) []byte {
t.Helper()
// Generate a 32-byte kek
kek := make([]byte, 32)
_, err := rand.Read(kek)
require.NoError(t, err)
return kek
}

View File

@@ -28,14 +28,13 @@ services:
file: docker-compose.yml file: docker-compose.yml
service: pocket-id service: pocket-id
environment: environment:
- S3_BUCKET=pocket-id-test FILE_BACKEND: s3
- S3_REGION=us-east-1 S3_BUCKET: pocket-id-test
- S3_ENDPOINT=http://localstack-s3:4566 S3_REGION: us-east-1
- S3_ACCESS_KEY_ID=test S3_ENDPOINT: http://localstack-s3:4566
- S3_SECRET_ACCESS_KEY=test S3_ACCESS_KEY_ID: test
- S3_FORCE_PATH_STYLE=true S3_SECRET_ACCESS_KEY: test
- KEYS_STORAGE=database S3_FORCE_PATH_STYLE: true
- ENCRYPTION_KEY=test1234test1234test1234test1234
depends_on: depends_on:
create-bucket: create-bucket:
condition: service_completed_successfully condition: service_completed_successfully

View File

@@ -13,8 +13,9 @@ services:
ports: ports:
- '1411:1411' - '1411:1411'
environment: environment:
- APP_ENV=test APP_ENV: test
- FILE_BACKEND=${FILE_BACKEND} ENCRYPTION_KEY: test-encryption-key
FILE_BACKEND: ${FILE_BACKEND}
build: build:
args: args:
- BUILD_TAGS=e2etest - BUILD_TAGS=e2etest