From 5550729120ac9f5e9361c7f9cf25b9075a33a94a Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Thu, 3 Jul 2025 11:34:34 -0700 Subject: [PATCH] feat: encrypt private keys saved on disk and in database (#682) Co-authored-by: Kyle Mendell --- .../internal/bootstrap/services_bootstrap.go | 2 +- backend/internal/common/env_config.go | 99 ++++-- backend/internal/common/env_config_test.go | 188 ++++++++++ backend/internal/model/kv.go | 11 + .../service/app_config_service_test.go | 32 +- backend/internal/service/e2etest_service.go | 4 +- backend/internal/service/jwt_service.go | 148 +++----- backend/internal/service/jwt_service_test.go | 243 ++++++++----- backend/internal/service/oidc_service_test.go | 15 +- backend/internal/utils/crypto/crypto.go | 69 ++++ backend/internal/utils/crypto/crypto_test.go | 208 +++++++++++ backend/internal/utils/jwk/key_provider.go | 50 +++ .../utils/jwk/key_provider_database.go | 109 ++++++ .../utils/jwk/key_provider_database_test.go | 275 +++++++++++++++ .../internal/utils/jwk/key_provider_file.go | 202 +++++++++++ .../utils/jwk/key_provider_file_test.go | 320 +++++++++++++++++ backend/internal/utils/jwk/utils.go | 180 ++++++++++ backend/internal/utils/jwk/utils_test.go | 324 ++++++++++++++++++ backend/internal/utils/jwk_util.go | 69 ---- .../testing/database.go} | 39 +-- .../internal/utils/testing/round_tripper.go | 38 ++ .../postgres/20250630000000_kv_table.down.sql | 1 + .../postgres/20250630000000_kv_table.up.sql | 6 + .../sqlite/20250630000000_kv_table.down.sql | 1 + .../sqlite/20250630000000_kv_table.up.sql | 6 + 25 files changed, 2311 insertions(+), 328 deletions(-) create mode 100644 backend/internal/common/env_config_test.go create mode 100644 backend/internal/model/kv.go create mode 100644 backend/internal/utils/crypto/crypto.go create mode 100644 backend/internal/utils/crypto/crypto_test.go create mode 100644 backend/internal/utils/jwk/key_provider.go create mode 100644 backend/internal/utils/jwk/key_provider_database.go create mode 100644 backend/internal/utils/jwk/key_provider_database_test.go create mode 100644 backend/internal/utils/jwk/key_provider_file.go create mode 100644 backend/internal/utils/jwk/key_provider_file_test.go create mode 100644 backend/internal/utils/jwk/utils.go create mode 100644 backend/internal/utils/jwk/utils_test.go delete mode 100644 backend/internal/utils/jwk_util.go rename backend/internal/{service/testutils_test.go => utils/testing/database.go} (68%) create mode 100644 backend/internal/utils/testing/round_tripper.go create mode 100644 backend/resources/migrations/postgres/20250630000000_kv_table.down.sql create mode 100644 backend/resources/migrations/postgres/20250630000000_kv_table.up.sql create mode 100644 backend/resources/migrations/sqlite/20250630000000_kv_table.down.sql create mode 100644 backend/resources/migrations/sqlite/20250630000000_kv_table.up.sql diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go index 892f47d0..3ab04da5 100644 --- a/backend/internal/bootstrap/services_bootstrap.go +++ b/backend/internal/bootstrap/services_bootstrap.go @@ -38,7 +38,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv svc.geoLiteService = service.NewGeoLiteService(httpClient) svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService) - svc.jwtService = service.NewJwtService(svc.appConfigService) + svc.jwtService = service.NewJwtService(db, svc.appConfigService) svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService) svc.customClaimService = service.NewCustomClaimService(db) diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go index 5e43c65a..2902d8dc 100644 --- a/backend/internal/common/env_config.go +++ b/backend/internal/common/env_config.go @@ -1,6 +1,8 @@ package common import ( + "errors" + "fmt" "log" "net/url" @@ -18,9 +20,10 @@ const ( ) const ( - DbProviderSqlite DbProvider = "sqlite" - DbProviderPostgres DbProvider = "postgres" - MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz" + DbProviderSqlite DbProvider = "sqlite" + DbProviderPostgres DbProvider = "postgres" + MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz" + defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate" ) type EnvConfigSchema struct { @@ -30,6 +33,9 @@ type EnvConfigSchema struct { DbConnectionString string `env:"DB_CONNECTION_STRING"` UploadPath string `env:"UPLOAD_PATH"` KeysPath string `env:"KEYS_PATH"` + KeysStorage string `env:"KEYS_STORAGE"` + EncryptionKey string `env:"ENCRYPTION_KEY"` + EncryptionKeyFile string `env:"ENCRYPTION_KEY_FILE"` Port string `env:"PORT"` Host string `env:"HOST"` UnixSocket string `env:"UNIX_SOCKET"` @@ -45,52 +51,83 @@ type EnvConfigSchema struct { AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"` } -var EnvConfig = &EnvConfigSchema{ - AppEnv: "production", - DbProvider: "sqlite", - DbConnectionString: "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate", - UploadPath: "data/uploads", - KeysPath: "data/keys", - AppURL: "http://localhost:1411", - Port: "1411", - Host: "0.0.0.0", - UnixSocket: "", - UnixSocketMode: "", - MaxMindLicenseKey: "", - GeoLiteDBPath: "data/GeoLite2-City.mmdb", - GeoLiteDBUrl: MaxMindGeoLiteCityUrl, - LocalIPv6Ranges: "", - UiConfigDisabled: false, - MetricsEnabled: false, - TracingEnabled: false, - TrustProxy: false, - AnalyticsDisabled: false, -} +var EnvConfig = defaultConfig() func init() { - if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil { - log.Fatal(err) + err := parseEnvConfig() + if err != nil { + log.Fatalf("Configuration error: %v", err) + } +} + +func defaultConfig() EnvConfigSchema { + return EnvConfigSchema{ + AppEnv: "production", + DbProvider: "sqlite", + DbConnectionString: "", + UploadPath: "data/uploads", + KeysPath: "data/keys", + KeysStorage: "", // "database" or "file" + EncryptionKey: "", + AppURL: "http://localhost:1411", + Port: "1411", + Host: "0.0.0.0", + UnixSocket: "", + UnixSocketMode: "", + MaxMindLicenseKey: "", + GeoLiteDBPath: "data/GeoLite2-City.mmdb", + GeoLiteDBUrl: MaxMindGeoLiteCityUrl, + LocalIPv6Ranges: "", + UiConfigDisabled: false, + MetricsEnabled: false, + TracingEnabled: false, + TrustProxy: false, + AnalyticsDisabled: false, + } +} + +func parseEnvConfig() error { + err := env.ParseWithOptions(&EnvConfig, env.Options{}) + if err != nil { + return fmt.Errorf("error parsing env config: %w", err) } // Validate the environment variables switch EnvConfig.DbProvider { case DbProviderSqlite: if EnvConfig.DbConnectionString == "" { - log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database") + EnvConfig.DbConnectionString = defaultSqliteConnString } case DbProviderPostgres: if EnvConfig.DbConnectionString == "" { - log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database") + return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database") } default: - log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'") + return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'") } parsedAppUrl, err := url.Parse(EnvConfig.AppURL) if err != nil { - log.Fatal("APP_URL is not a valid URL") + return errors.New("APP_URL is not a valid URL") } if parsedAppUrl.Path != "" { - log.Fatal("APP_URL must not contain a path") + return errors.New("APP_URL must not contain a path") } + + switch EnvConfig.KeysStorage { + // KeysStorage defaults to "file" if empty + case "": + EnvConfig.KeysStorage = "file" + case "database": + // If KeysStorage is "database", a key must be specified + if EnvConfig.EncryptionKey == "" && EnvConfig.EncryptionKeyFile == "" { + return errors.New("ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty when KEYS_STORAGE is database") + } + case "file": + // All good, these are valid values + default: + return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage) + } + + return nil } diff --git a/backend/internal/common/env_config_test.go b/backend/internal/common/env_config_test.go new file mode 100644 index 00000000..024eb4c4 --- /dev/null +++ b/backend/internal/common/env_config_test.go @@ -0,0 +1,188 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseEnvConfig(t *testing.T) { + // Store original config to restore later + originalConfig := EnvConfig + t.Cleanup(func() { + EnvConfig = originalConfig + }) + + t.Run("should parse valid SQLite config correctly", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + + err := parseEnvConfig() + require.NoError(t, err) + assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider) + }) + + t.Run("should parse valid Postgres config correctly", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "postgres") + t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db") + t.Setenv("APP_URL", "https://example.com") + + err := parseEnvConfig() + require.NoError(t, err) + assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider) + }) + + t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "invalid") + t.Setenv("DB_CONNECTION_STRING", "test") + t.Setenv("APP_URL", "http://localhost:3000") + + err := parseEnvConfig() + require.Error(t, err) + assert.ErrorContains(t, err, "invalid DB_PROVIDER value") + }) + + t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty + t.Setenv("APP_URL", "http://localhost:3000") + + err := parseEnvConfig() + require.NoError(t, err) + assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString) + }) + + t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "postgres") + t.Setenv("APP_URL", "http://localhost:3000") + + err := parseEnvConfig() + require.Error(t, err) + assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres") + }) + + t.Run("should fail with invalid APP_URL", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "€://not-a-valid-url") + + err := parseEnvConfig() + require.Error(t, err) + assert.ErrorContains(t, err, "APP_URL is not a valid URL") + }) + + t.Run("should fail when APP_URL contains path", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000/path") + + err := parseEnvConfig() + require.Error(t, err) + assert.ErrorContains(t, err, "APP_URL must not contain a path") + }) + + t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + + err := parseEnvConfig() + require.NoError(t, err) + assert.Equal(t, "file", EnvConfig.KeysStorage) + }) + + t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("KEYS_STORAGE", "database") + + err := parseEnvConfig() + require.Error(t, err) + assert.ErrorContains(t, err, "ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty") + }) + + t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) { + validStorageTypes := []string{"file", "database"} + + for _, storage := range validStorageTypes { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("KEYS_STORAGE", storage) + if storage == "database" { + t.Setenv("ENCRYPTION_KEY", "test-key") + } + + err := parseEnvConfig() + require.NoError(t, err) + assert.Equal(t, storage, EnvConfig.KeysStorage) + } + }) + + t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("KEYS_STORAGE", "invalid") + + err := parseEnvConfig() + require.Error(t, err) + assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE") + }) + + t.Run("should parse boolean environment variables correctly", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("UI_CONFIG_DISABLED", "true") + t.Setenv("METRICS_ENABLED", "true") + t.Setenv("TRACING_ENABLED", "false") + t.Setenv("TRUST_PROXY", "true") + t.Setenv("ANALYTICS_DISABLED", "false") + + err := parseEnvConfig() + require.NoError(t, err) + assert.True(t, EnvConfig.UiConfigDisabled) + assert.True(t, EnvConfig.MetricsEnabled) + assert.False(t, EnvConfig.TracingEnabled) + assert.True(t, EnvConfig.TrustProxy) + assert.False(t, EnvConfig.AnalyticsDisabled) + }) + + t.Run("should parse string environment variables correctly", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "postgres") + t.Setenv("DB_CONNECTION_STRING", "postgres://test") + t.Setenv("APP_URL", "https://prod.example.com") + t.Setenv("APP_ENV", "staging") + t.Setenv("UPLOAD_PATH", "/custom/uploads") + t.Setenv("KEYS_PATH", "/custom/keys") + t.Setenv("PORT", "8080") + t.Setenv("HOST", "127.0.0.1") + t.Setenv("UNIX_SOCKET", "/tmp/app.sock") + t.Setenv("MAXMIND_LICENSE_KEY", "test-license") + t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb") + + err := parseEnvConfig() + require.NoError(t, err) + assert.Equal(t, "staging", EnvConfig.AppEnv) + assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath) + assert.Equal(t, "8080", EnvConfig.Port) + assert.Equal(t, "127.0.0.1", EnvConfig.Host) + }) +} diff --git a/backend/internal/model/kv.go b/backend/internal/model/kv.go new file mode 100644 index 00000000..a7a5d851 --- /dev/null +++ b/backend/internal/model/kv.go @@ -0,0 +1,11 @@ +package model + +type KV struct { + Key string `gorm:"primaryKey;not null"` + Value *string +} + +// TableName overrides the table name used by KV to `kv` +func (KV) TableName() string { + return "kv" +} diff --git a/backend/internal/service/app_config_service_test.go b/backend/internal/service/app_config_service_test.go index 5b538919..f22684fc 100644 --- a/backend/internal/service/app_config_service_test.go +++ b/backend/internal/service/app_config_service_test.go @@ -4,10 +4,12 @@ import ( "sync/atomic" "testing" + "github.com/stretchr/testify/require" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" - "github.com/stretchr/testify/require" + testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" ) // NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values @@ -22,7 +24,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService { func TestLoadDbConfig(t *testing.T) { t.Run("empty config table", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) service := &AppConfigService{ db: db, } @@ -36,7 +38,7 @@ func TestLoadDbConfig(t *testing.T) { }) t.Run("loads value from config table", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Populate the config table with some initial values err := db. @@ -66,7 +68,7 @@ func TestLoadDbConfig(t *testing.T) { }) t.Run("ignores unknown config keys", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Add an entry with a key that doesn't exist in the config struct err := db.Create([]model.AppConfigVariable{ @@ -87,7 +89,7 @@ func TestLoadDbConfig(t *testing.T) { }) t.Run("loading config multiple times", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Initial state err := db.Create([]model.AppConfigVariable{ @@ -129,7 +131,7 @@ func TestLoadDbConfig(t *testing.T) { common.EnvConfig.UiConfigDisabled = true // Create database with config that should be ignored - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) err := db.Create([]model.AppConfigVariable{ {Key: "appName", Value: "DB App"}, {Key: "sessionDuration", Value: "120"}, @@ -165,7 +167,7 @@ func TestLoadDbConfig(t *testing.T) { common.EnvConfig.UiConfigDisabled = false // Create database with config values that should take precedence - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) err := db.Create([]model.AppConfigVariable{ {Key: "appName", Value: "DB App"}, {Key: "sessionDuration", Value: "120"}, @@ -189,7 +191,7 @@ func TestLoadDbConfig(t *testing.T) { func TestUpdateAppConfigValues(t *testing.T) { t.Run("update single value", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -214,7 +216,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("update multiple values", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -258,7 +260,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("empty value resets to default", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -279,7 +281,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("error with odd number of arguments", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -295,7 +297,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("error with invalid key", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -313,7 +315,7 @@ func TestUpdateAppConfigValues(t *testing.T) { func TestUpdateAppConfig(t *testing.T) { t.Run("updates configuration values from DTO", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -386,7 +388,7 @@ func TestUpdateAppConfig(t *testing.T) { }) t.Run("empty values reset to defaults", func(t *testing.T) { - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create a service with default config and modify some values service := &AppConfigService{ @@ -451,7 +453,7 @@ func TestUpdateAppConfig(t *testing.T) { // Disable UI config common.EnvConfig.UiConfigDisabled = true - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) service := &AppConfigService{ db: db, } diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index 8e6f8ad7..b7632e3a 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -17,6 +17,7 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/go-webauthn/webauthn/protocol" + "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" "gorm.io/gorm" @@ -25,6 +26,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/model" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" "github.com/pocket-id/pocket-id/backend/internal/utils" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" "github.com/pocket-id/pocket-id/backend/resources" ) @@ -60,7 +62,7 @@ func (s *TestService) initExternalIdP() error { return fmt.Errorf("failed to generate private key: %w", err) } - s.externalIdPKey, err = utils.ImportRawKey(rawKey) + s.externalIdPKey, err = jwkutils.ImportRawKey(rawKey, jwa.ES256().String(), "") if err != nil { return fmt.Errorf("failed to import private key: %w", err) } diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index a74653d4..5c36533b 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -2,23 +2,20 @@ package service import ( "context" - "crypto/rand" - "crypto/rsa" "encoding/json" "errors" "fmt" "log" - "os" - "path/filepath" "time" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" + "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" - "github.com/pocket-id/pocket-id/backend/internal/utils" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" ) const ( @@ -26,8 +23,9 @@ const ( // This is a JSON file containing a key encoded as JWK PrivateKeyFile = "jwt_private_key.json" - // RsaKeySize is the size, in bits, of the RSA key to generate if none is found - RsaKeySize = 2048 + // PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored + // This is a encrypted JSON file containing a key encoded as JWK + PrivateKeyFileEncrypted = "jwt_private_key.json.enc" // KeyUsageSigning is the usage for the private keys, for the "use" property KeyUsageSigning = "sig" @@ -59,58 +57,74 @@ const ( ) type JwtService struct { + envConfig *common.EnvConfigSchema privateKey jwk.Key keyId string appConfigService *AppConfigService jwksEncoded []byte } -func NewJwtService(appConfigService *AppConfigService) *JwtService { +func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) *JwtService { service := &JwtService{} // Ensure keys are generated or loaded - if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil { + err := service.init(db, appConfigService, &common.EnvConfig) + if err != nil { log.Fatalf("Failed to initialize jwt service: %v", err) } return service } -func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error { +func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) { s.appConfigService = appConfigService + s.envConfig = envConfig // Ensure keys are generated or loaded - return s.loadOrGenerateKey(keysPath) + return s.loadOrGenerateKey(db) } -// loadOrGenerateKey loads the private key from the given path or generates it if not existing. -func (s *JwtService) loadOrGenerateKey(keysPath string) error { - var key jwk.Key - - // First, check if we have a JWK file - // If we do, then we just load that - jwkPath := filepath.Join(keysPath, PrivateKeyFile) - ok, err := utils.FileExists(jwkPath) +func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error { + // Get the key provider + keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value) if err != nil { - return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err) + return fmt.Errorf("failed to get key provider: %w", err) } - if ok { - key, err = s.loadKeyJWK(jwkPath) - if err != nil { - return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err) - } - // Set the key, and we are done + // Try loading a key + key, err := keyProvider.LoadKey() + if err != nil { + return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err) + } + + // If we have a key, store it in the object and we're done + if key != nil { err = s.SetKey(key) if err != nil { return fmt.Errorf("failed to set private key: %w", err) } - return nil } // If we are here, we need to generate a new key - key, err = s.generateNewRSAKey() + err = s.generateKey() + if err != nil { + return fmt.Errorf("failed to generate key: %w", err) + } + + // Save the newly-generated key + err = keyProvider.SaveKey(s.privateKey) + if err != nil { + return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err) + } + + return nil +} + +// generateKey generates a new key and stores it in the object +func (s *JwtService) generateKey() error { + // Default is to generate RS256 (RSA-2048) keys + key, err := jwkutils.GenerateKey(jwa.RS256().String(), "") if err != nil { return fmt.Errorf("failed to generate new private key: %w", err) } @@ -121,12 +135,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error { return fmt.Errorf("failed to set private key: %w", err) } - // Save the key as JWK - err = SaveKeyJWK(s.privateKey, jwkPath) - if err != nil { - return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err) - } - return nil } @@ -192,13 +200,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { Subject(user.ID). Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())). IssuedAt(now). - Issuer(common.EnvConfig.AppURL). + Issuer(s.envConfig.AppURL). Build() if err != nil { return "", fmt.Errorf("failed to build token: %w", err) } - err = SetAudienceString(token, common.EnvConfig.AppURL) + err = SetAudienceString(token, s.envConfig.AppURL) if err != nil { return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) } @@ -229,8 +237,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) { jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), - jwt.WithAudience(common.EnvConfig.AppURL), - jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithAudience(s.envConfig.AppURL), + jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)), ) if err != nil { @@ -246,7 +254,7 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no token, err := jwt.NewBuilder(). Expiration(now.Add(1 * time.Hour)). IssuedAt(now). - Issuer(common.EnvConfig.AppURL). + Issuer(s.envConfig.AppURL). Build() if err != nil { return nil, fmt.Errorf("failed to build token: %w", err) @@ -305,7 +313,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), - jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)), ) @@ -335,7 +343,7 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw Subject(user.ID). Expiration(now.Add(1 * time.Hour)). IssuedAt(now). - Issuer(common.EnvConfig.AppURL). + Issuer(s.envConfig.AppURL). Build() if err != nil { return nil, fmt.Errorf("failed to build token: %w", err) @@ -377,7 +385,7 @@ func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, erro jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), - jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)), ) if err != nil { @@ -393,7 +401,7 @@ func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, r Subject(userID). Expiration(now.Add(RefreshTokenDuration)). IssuedAt(now). - Issuer(common.EnvConfig.AppURL). + Issuer(s.envConfig.AppURL). Build() if err != nil { return "", fmt.Errorf("failed to build token: %w", err) @@ -430,7 +438,7 @@ func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, client jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), - jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)), ) if err != nil { @@ -488,7 +496,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) { return nil, fmt.Errorf("failed to get public key: %w", err) } - utils.EnsureAlgInKey(pubKey) + jwkutils.EnsureAlgInKey(pubKey, "", "") return pubKey, nil } @@ -517,56 +525,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) { return alg, nil } -func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read key data: %w", err) - } - - key, err := jwk.ParseKey(data) - if err != nil { - return nil, fmt.Errorf("failed to parse key: %w", err) - } - - return key, nil -} - -func (s *JwtService) generateNewRSAKey() (jwk.Key, error) { - // We generate RSA keys only - rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize) - if err != nil { - return nil, fmt.Errorf("failed to generate RSA private key: %w", err) - } - - // Import the raw key - return utils.ImportRawKey(rawKey) -} - -// SaveKeyJWK saves a JWK to a file -func SaveKeyJWK(key jwk.Key, path string) error { - dir := filepath.Dir(path) - err := os.MkdirAll(dir, 0700) - if err != nil { - return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err) - } - - keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return fmt.Errorf("failed to create key file: %w", err) - } - defer keyFile.Close() - - // Write the JSON file to disk - enc := json.NewEncoder(keyFile) - enc.SetEscapeHTML(false) - err = enc.Encode(key) - if err != nil { - return fmt.Errorf("failed to write key file: %w", err) - } - - return nil -} - // GetIsAdmin returns the value of the "isAdmin" claim in the token func GetIsAdmin(token jwt.Token) (bool, error) { if !token.Has(IsAdminClaim) { diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index 0a00f2fe..70d4915a 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -21,7 +21,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" - "github.com/pocket-id/pocket-id/backend/internal/utils" + jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" ) func TestJwtService_Init(t *testing.T) { @@ -33,9 +33,16 @@ func TestJwtService_Init(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // Initialize the JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Verify the private key was set @@ -66,9 +73,16 @@ func TestJwtService_Init(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // First create a service to generate a key firstService := &JwtService{} - err := firstService.init(mockConfig, tempDir) + err := firstService.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err) // Get the key ID of the first service @@ -77,7 +91,7 @@ func TestJwtService_Init(t *testing.T) { // Now create a new service that should load the existing key secondService := &JwtService{} - err = secondService.init(mockConfig, tempDir) + err = secondService.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err) // Verify the loaded key has the same ID as the original @@ -90,12 +104,19 @@ func TestJwtService_Init(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // Create a new JWK and save it to disk origKeyID := createECDSAKeyJWK(t, tempDir) // Now create a new service that should load the existing key svc := &JwtService{} - err := svc.init(mockConfig, tempDir) + err := svc.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err) // Ensure loaded key has the right algorithm @@ -113,12 +134,19 @@ func TestJwtService_Init(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // Create a new JWK and save it to disk origKeyID := createEdDSAKeyJWK(t, tempDir) // Now create a new service that should load the existing key svc := &JwtService{} - err := svc.init(mockConfig, tempDir) + err := svc.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err) // Ensure loaded key has the right algorithm and curve @@ -147,9 +175,16 @@ func TestJwtService_GetPublicJWK(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // Create a JWT service with initialized key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Get the JWK (public key) @@ -178,12 +213,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // Create an ECDSA key and save it as JWK originalKeyID := createECDSAKeyJWK(t, tempDir) // Create a JWT service that loads the ECDSA key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Get the JWK (public key) @@ -216,12 +258,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() + // Setup the environment variable required by the token verification + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } + // Create an EdDSA key and save it as JWK originalKeyID := createEdDSAKeyJWK(t, tempDir) // Create a JWT service that loads the EdDSA key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Get the JWK (public key) @@ -276,16 +325,16 @@ func TestGenerateVerifyAccessToken(t *testing.T) { }) // Setup the environment variable required by the token verification - originalAppURL := common.EnvConfig.AppURL - common.EnvConfig.AppURL = "https://test.example.com" - defer func() { - common.EnvConfig.AppURL = originalAppURL - }() + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } t.Run("generates token for regular user", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create a test user @@ -328,7 +377,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { t.Run("generates token for admin user", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create a test admin user @@ -364,7 +413,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { }) service := &JwtService{} - err := service.init(customMockConfig, tempDir) + err := service.init(nil, customMockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create a test user @@ -399,7 +448,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -453,7 +505,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -507,7 +562,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -563,16 +621,16 @@ func TestGenerateVerifyIdToken(t *testing.T) { }) // Setup the environment variable required by the token verification - originalAppURL := common.EnvConfig.AppURL - common.EnvConfig.AppURL = "https://test.example.com" - defer func() { - common.EnvConfig.AppURL = originalAppURL - }() + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } t.Run("generates and verifies ID token with standard claims", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create test claims @@ -601,7 +659,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID") issuer, ok := claims.Issuer() _ = assert.True(t, ok, "Issuer not found in token") && - assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL") + assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") // Check token expiration time is approximately 1 hour from now expectedExp := time.Now().Add(1 * time.Hour) @@ -614,7 +672,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { t.Run("can accept expired tokens if told so", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create test claims @@ -628,7 +686,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { // Create a token that's already expired token, err := jwt.NewBuilder(). Subject(userClaims["sub"].(string)). - Issuer(common.EnvConfig.AppURL). + Issuer(service.envConfig.AppURL). Audience([]string{clientID}). IssuedAt(time.Now().Add(-2 * time.Hour)). Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago @@ -666,13 +724,13 @@ func TestGenerateVerifyIdToken(t *testing.T) { assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID") issuer, ok := claims.Issuer() _ = assert.True(t, ok, "Issuer not found in token") && - assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL") + assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") }) t.Run("generates and verifies ID token with nonce", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create test claims with nonce @@ -703,7 +761,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { t.Run("fails verification with incorrect issuer", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Generate a token with standard claims @@ -714,7 +772,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { require.NoError(t, err, "Failed to generate ID token") // Temporarily change the app URL to simulate wrong issuer - common.EnvConfig.AppURL = "https://wrong-issuer.com" + service.envConfig.AppURL = "https://wrong-issuer.com" // Verify should fail due to issuer mismatch _, err = service.VerifyIdToken(tokenString, false) @@ -731,7 +789,10 @@ func TestGenerateVerifyIdToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -762,7 +823,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID") issuer, ok := claims.Issuer() _ = assert.True(t, ok, "Issuer not found in token") && - assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL") + assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") // Verify the key type is OKP publicKey, err := service.GetPublicJWK() @@ -784,7 +845,10 @@ func TestGenerateVerifyIdToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -795,7 +859,6 @@ func TestGenerateVerifyIdToken(t *testing.T) { // Create test claims userClaims := map[string]interface{}{ "sub": "ecdsauser456", - "name": "ECDSA User", "email": "ecdsauser@example.com", } const clientID = "ecdsa-client-123" @@ -815,7 +878,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID") issuer, ok := claims.Issuer() _ = assert.True(t, ok, "Issuer not found in token") && - assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL") + assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") // Verify the key type is EC publicKey, err := service.GetPublicJWK() @@ -837,7 +900,10 @@ func TestGenerateVerifyIdToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -868,17 +934,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { assert.Equal(t, "rsauser456", subject, "Token subject should match user ID") issuer, ok := claims.Issuer() _ = assert.True(t, ok, "Issuer not found in token") && - assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL") - - // Verify the key type is RSA - publicKey, err := service.GetPublicJWK() - require.NoError(t, err) - assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA") - - // Verify the algorithm is RS256 - alg, ok := publicKey.Algorithm() - require.True(t, ok) - assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256") + assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") }) } @@ -892,16 +948,16 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { }) // Setup the environment variable required by the token verification - originalAppURL := common.EnvConfig.AppURL - common.EnvConfig.AppURL = "https://test.example.com" - defer func() { - common.EnvConfig.AppURL = originalAppURL - }() + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create a test user @@ -931,7 +987,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID") issuer, ok := claims.Issuer() _ = assert.True(t, ok, "Issuer not found in token") && - assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL") + assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") // Check token expiration time is approximately 1 hour from now expectedExp := time.Now().Add(1 * time.Hour) @@ -944,7 +1000,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { t.Run("fails verification for expired token", func(t *testing.T) { // Create a JWT service with a mock function to generate an expired token service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create a test user @@ -961,7 +1017,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago IssuedAt(time.Now().Add(-2 * time.Hour)). Audience([]string{clientID}). - Issuer(common.EnvConfig.AppURL). + Issuer(service.envConfig.AppURL). Build() require.NoError(t, err, "Failed to build token") @@ -980,11 +1036,17 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { t.Run("fails verification with invalid signature", func(t *testing.T) { // Create two JWT services with different keys service1 := &JwtService{} - err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir + err := service1.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: t.TempDir(), // Use a different temp dir + }) require.NoError(t, err, "Failed to initialize first JWT service") service2 := &JwtService{} - err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir + err = service2.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: t.TempDir(), // Use a different temp dir + }) require.NoError(t, err, "Failed to initialize second JWT service") // Create a test user @@ -1014,7 +1076,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -1068,7 +1133,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -1122,7 +1190,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { // Create a JWT service that loads the key service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") // Verify it loaded the right key @@ -1176,16 +1247,16 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { mockConfig := NewTestAppConfigService(&model.AppConfig{}) // Setup the environment variable required by the token verification - originalAppURL := common.EnvConfig.AppURL - common.EnvConfig.AppURL = "https://test.example.com" - defer func() { - common.EnvConfig.AppURL = originalAppURL - }() + mockEnvConfig := &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + KeysStorage: "file", + KeysPath: tempDir, + } t.Run("generates and verifies refresh token", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Create a test user @@ -1211,7 +1282,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { t.Run("fails verification for expired token", func(t *testing.T) { // Create a JWT service service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, mockEnvConfig) require.NoError(t, err, "Failed to initialize JWT service") // Generate a token using JWT directly to create an expired token @@ -1220,7 +1291,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago IssuedAt(time.Now().Add(-2 * time.Hour)). Audience([]string{"client123"}). - Issuer(common.EnvConfig.AppURL). + Issuer(service.envConfig.AppURL). Build() require.NoError(t, err, "Failed to build token") @@ -1236,11 +1307,17 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { t.Run("fails verification with invalid signature", func(t *testing.T) { // Create two JWT services with different keys service1 := &JwtService{} - err := service1.init(mockConfig, t.TempDir()) + err := service1.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: t.TempDir(), // Use a different temp dir + }) require.NoError(t, err, "Failed to initialize first JWT service") service2 := &JwtService{} - err = service2.init(mockConfig, t.TempDir()) + err = service2.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: t.TempDir(), // Use a different temp dir + }) require.NoError(t, err, "Failed to initialize second JWT service") // Generate a token with the first service @@ -1308,7 +1385,10 @@ func TestGetTokenType(t *testing.T) { // Initialize the JWT service mockConfig := NewTestAppConfigService(&model.AppConfig{}) service := &JwtService{} - err := service.init(mockConfig, tempDir) + err := service.init(nil, mockConfig, &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: tempDir, + }) require.NoError(t, err, "Failed to initialize JWT service") buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string { @@ -1402,10 +1482,19 @@ func TestGetTokenType(t *testing.T) { func importKey(t *testing.T, privateKeyRaw any, path string) string { t.Helper() - privateKey, err := utils.ImportRawKey(privateKeyRaw) + privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "") require.NoError(t, err, "Failed to import private key") - err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile)) + keyProvider := &jwkutils.KeyProviderFile{} + err = keyProvider.Init(jwkutils.KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysStorage: "file", + KeysPath: path, + }, + }) + require.NoError(t, err, "Failed to init file key provider") + + err = keyProvider.SaveKey(privateKey) require.NoError(t, err, "Failed to save key") kid, _ := privateKey.KeyID() diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index c642a4e9..86448cad 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -18,6 +18,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" + testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" ) // generateTestECDSAKey creates an ECDSA key for testing @@ -62,12 +63,12 @@ func TestOidcService_jwkSetForURL(t *testing.T) { ) mockResponses := map[string]*http.Response{ //nolint:bodyclose - url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)), + url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)), //nolint:bodyclose - url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)), + url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)), } httpClient := &http.Client{ - Transport: &MockRoundTripper{ + Transport: &testutils.MockRoundTripper{ Responses: mockResponses, }, } @@ -139,7 +140,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { var err error // Create a test database - db := newDatabaseForTest(t) + db := testutils.NewDatabaseForTest(t) // Create two JWKs for testing privateJWK, jwkSetJSON := generateTestECDSAKey(t) @@ -149,12 +150,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { // Create a mock HTTP client with custom transport to return the JWKS httpClient := &http.Client{ - Transport: &MockRoundTripper{ + Transport: &testutils.MockRoundTripper{ Responses: map[string]*http.Response{ //nolint:bodyclose - federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)), + federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)), //nolint:bodyclose - federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)), + federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)), }, }, } diff --git a/backend/internal/utils/crypto/crypto.go b/backend/internal/utils/crypto/crypto.go new file mode 100644 index 00000000..90d831f0 --- /dev/null +++ b/backend/internal/utils/crypto/crypto.go @@ -0,0 +1,69 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "fmt" + "io" +) + +// ErrDecrypt is returned by Decrypt when the operation failed for any reason +var ErrDecrypt = errors.New("failed to decrypt data") + +// Encrypt a byte slice using AES-GCM and a random nonce +// Important: do not encrypt more than ~4 billion messages with the same key! +func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create block cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %w", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + _, err = io.ReadFull(rand.Reader, nonce) + if err != nil { + return nil, fmt.Errorf("failed to generate random nonce: %w", err) + } + + // Allocate the slice for the result, with additional space for the nonce and overhead + ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead()) + ciphertext = append(ciphertext, nonce...) + + // Encrypt the plaintext + // Tag is automatically added at the end + ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData) + + return ciphertext, nil +} + +// Decrypt a byte slice using AES-GCM +func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create block cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %w", err) + } + + // Extract the nonce + if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) { + return nil, ErrDecrypt + } + + // Decrypt the data + plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData) + if err != nil { + // Note: we do not return the exact error here, to avoid disclosing information + return nil, ErrDecrypt + } + + return plaintext, nil +} diff --git a/backend/internal/utils/crypto/crypto_test.go b/backend/internal/utils/crypto/crypto_test.go new file mode 100644 index 00000000..c9459dec --- /dev/null +++ b/backend/internal/utils/crypto/crypto_test.go @@ -0,0 +1,208 @@ +package crypto + +import ( + "crypto/rand" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncryptDecrypt(t *testing.T) { + tests := []struct { + name string + keySize int + plaintext string + associatedData []byte + }{ + { + name: "AES-128 with short plaintext", + keySize: 16, + plaintext: "Hello, World!", + associatedData: []byte("test-aad"), + }, + { + name: "AES-192 with medium plaintext", + keySize: 24, + plaintext: "This is a longer message to test encryption and decryption", + associatedData: []byte("associated-data-192"), + }, + { + name: "AES-256 with unicode", + keySize: 32, + plaintext: "Hello δΈ–η•Œ! 🌍 Testing unicode characters", //nolint:gosmopolitan + associatedData: []byte("unicode-test"), + }, + { + name: "No associated data", + keySize: 32, + plaintext: "Testing without associated data", + associatedData: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate random key + key := make([]byte, tt.keySize) + _, err := rand.Read(key) + require.NoError(t, err, "Failed to generate random key") + + plaintext := []byte(tt.plaintext) + + // Test encryption + ciphertext, err := Encrypt(key, plaintext, tt.associatedData) + require.NoError(t, err, "Encrypt should succeed") + + // Verify ciphertext is different from plaintext (unless empty) + if len(plaintext) > 0 { + assert.NotEqual(t, plaintext, ciphertext) + } + + // Test decryption + decrypted, err := Decrypt(key, ciphertext, tt.associatedData) + require.NoError(t, err, "Decrypt should succeed") + + // Verify decrypted text matches original + assert.Equal(t, plaintext, decrypted, "Decrypted text should match original") + }) + } +} + +func TestEncryptWithInvalidKeySize(t *testing.T) { + invalidKeySizes := []int{8, 12, 33, 47, 55, 128} + + for _, keySize := range invalidKeySizes { + t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) { + key := make([]byte, keySize) + plaintext := []byte("test message") + + _, err := Encrypt(key, plaintext, nil) + require.Error(t, err) + assert.ErrorContains(t, err, "invalid key size") + }) + } +} + +func TestDecryptWithInvalidKeySize(t *testing.T) { + invalidKeySizes := []int{8, 12, 33, 47, 55, 128} + + for _, keySize := range invalidKeySizes { + t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) { + key := make([]byte, keySize) + ciphertext := []byte("fake ciphertext") + + _, err := Decrypt(key, ciphertext, nil) + require.Error(t, err) + assert.ErrorContains(t, err, "invalid key size") + }) + } +} + +func TestDecryptWithInvalidCiphertext(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err, "Failed to generate random key") + + tests := []struct { + name string + ciphertext []byte + }{ + { + name: "empty ciphertext", + ciphertext: []byte{}, + }, + { + name: "too short ciphertext", + ciphertext: []byte("short"), + }, + { + name: "random invalid data", + ciphertext: []byte("this is not valid encrypted data"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Decrypt(key, tt.ciphertext, nil) + require.Error(t, err) + require.ErrorIs(t, err, ErrDecrypt) + }) + } +} + +func TestDecryptWithWrongKey(t *testing.T) { + // Generate two different keys + key1 := make([]byte, 32) + key2 := make([]byte, 32) + _, err := rand.Read(key1) + require.NoError(t, err) + _, err = rand.Read(key2) + require.NoError(t, err) + + plaintext := []byte("secret message") + + // Encrypt with key1 + ciphertext, err := Encrypt(key1, plaintext, nil) + require.NoError(t, err) + + // Try to decrypt with key2 + _, err = Decrypt(key2, ciphertext, nil) + require.Error(t, err) + require.ErrorIs(t, err, ErrDecrypt) +} + +func TestDecryptWithWrongAssociatedData(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err, "Failed to generate random key") + + plaintext := []byte("secret message") + correctAAD := []byte("correct-aad") + wrongAAD := []byte("wrong-aad") + + // Encrypt with correct AAD + ciphertext, err := Encrypt(key, plaintext, correctAAD) + require.NoError(t, err) + + // Try to decrypt with wrong AAD + _, err = Decrypt(key, ciphertext, wrongAAD) + require.Error(t, err) + require.ErrorIs(t, err, ErrDecrypt) + + // Verify correct AAD works + decrypted, err := Decrypt(key, ciphertext, correctAAD) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD") +} + +func TestEncryptDecryptConsistency(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + plaintext := []byte("consistency test message") + associatedData := []byte("test-aad") + + // Encrypt multiple times and verify we get different ciphertexts (due to random IV) + ciphertext1, err := Encrypt(key, plaintext, associatedData) + require.NoError(t, err) + + ciphertext2, err := Encrypt(key, plaintext, associatedData) + require.NoError(t, err) + + // Ciphertexts should be different (due to random IV) + assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts") + + // Both should decrypt to the same plaintext + decrypted1, err := Decrypt(key, ciphertext1, associatedData) + require.NoError(t, err) + + decrypted2, err := Decrypt(key, ciphertext2, associatedData) + require.NoError(t, err) + + assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original") + assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original") + assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical") +} diff --git a/backend/internal/utils/jwk/key_provider.go b/backend/internal/utils/jwk/key_provider.go new file mode 100644 index 00000000..46da3f3f --- /dev/null +++ b/backend/internal/utils/jwk/key_provider.go @@ -0,0 +1,50 @@ +package jwk + +import ( + "fmt" + + "github.com/lestrrat-go/jwx/v3/jwk" + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/common" +) + +type KeyProviderOpts struct { + EnvConfig *common.EnvConfigSchema + DB *gorm.DB + Kek []byte +} + +type KeyProvider interface { + Init(opts KeyProviderOpts) error + LoadKey() (jwk.Key, error) + SaveKey(key jwk.Key) error +} + +func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) { + // Load the encryption key (KEK) if present + kek, err := LoadKeyEncryptionKey(envConfig, instanceID) + if err != nil { + return nil, fmt.Errorf("failed to load encryption key: %w", err) + } + + // Get the key provider + switch envConfig.KeysStorage { + case "file", "": + keyProvider = &KeyProviderFile{} + case "database": + keyProvider = &KeyProviderDatabase{} + default: + return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage) + } + err = keyProvider.Init(KeyProviderOpts{ + DB: db, + EnvConfig: envConfig, + Kek: kek, + }) + if err != nil { + return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err) + } + + return keyProvider, nil +} diff --git a/backend/internal/utils/jwk/key_provider_database.go b/backend/internal/utils/jwk/key_provider_database.go new file mode 100644 index 00000000..bca0f782 --- /dev/null +++ b/backend/internal/utils/jwk/key_provider_database.go @@ -0,0 +1,109 @@ +package jwk + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "time" + + "github.com/lestrrat-go/jwx/v3/jwk" + "gorm.io/gorm" + + "github.com/pocket-id/pocket-id/backend/internal/model" + cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto" +) + +const PrivateKeyDBKey = "jwt_private_key.json" + +type KeyProviderDatabase struct { + db *gorm.DB + kek []byte +} + +func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error { + if len(opts.Kek) == 0 { + return errors.New("an encryption key is required when using the 'database' key provider") + } + + f.db = opts.DB + f.kek = opts.Kek + + return nil +} + +func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) { + row := model.KV{ + Key: PrivateKeyDBKey, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err = f.db.WithContext(ctx).First(&row).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + // Key not present in the database - return nil so a new one can be generated + return nil, nil + } else if err != nil { + return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err) + } + + if row.Value == nil || *row.Value == "" { + // Key not present in the database - return nil so a new one can be generated + return nil, nil + } + + // Decode from base64 + enc, err := base64.StdEncoding.DecodeString(*row.Value) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err) + } + + // Decrypt the data + data, err := cryptoutils.Decrypt(f.kek, enc, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt private key: %w", err) + } + + // Parse the key + key, err = jwk.ParseKey(data) + if err != nil { + return nil, fmt.Errorf("failed to parse encrypted private key: %w", err) + } + + return key, nil +} + +func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error { + // Encode the key to JSON + data, err := EncodeJWKBytes(key) + if err != nil { + return fmt.Errorf("failed to encode key to JSON: %w", err) + } + + // Encrypt the key then encode to Base64 + enc, err := cryptoutils.Encrypt(f.kek, data, nil) + if err != nil { + return fmt.Errorf("failed to encrypt key: %w", err) + } + encB64 := base64.StdEncoding.EncodeToString(enc) + + // Save to database + row := model.KV{ + Key: PrivateKeyDBKey, + Value: &encB64, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err = f.db.WithContext(ctx).Create(&row).Error + if err != nil { + // There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time + // In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database) + return fmt.Errorf("failed to store private key in database: %w", err) + } + + return nil +} + +// Compile-time interface check +var _ KeyProvider = (*KeyProviderDatabase)(nil) diff --git a/backend/internal/utils/jwk/key_provider_database_test.go b/backend/internal/utils/jwk/key_provider_database_test.go new file mode 100644 index 00000000..fd5dd2bd --- /dev/null +++ b/backend/internal/utils/jwk/key_provider_database_test.go @@ -0,0 +1,275 @@ +package jwk + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/base64" + "testing" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/model" + cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto" + testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" +) + +func TestKeyProviderDatabase_Init(t *testing.T) { + t.Run("Init fails when KEK is not provided", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: nil, // No KEK + }) + require.Error(t, err, "Expected error when KEK is not provided") + require.ErrorContains(t, err, "encryption key is required") + }) + + t.Run("Init succeeds with KEK", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: generateTestKEK(t), + }) + require.NoError(t, err, "Expected no error when KEK is provided") + }) +} + +func TestKeyProviderDatabase_LoadKey(t *testing.T) { + // Generate a test key to use in our tests + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + key, err := jwk.Import(pk) + require.NoError(t, err) + + t.Run("LoadKey with no existing key", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + kek := generateTestKEK(t) + + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: kek, + }) + require.NoError(t, err) + + // Load key when none exists + loadedKey, err := provider.LoadKey() + require.NoError(t, err) + assert.Nil(t, loadedKey, "Expected nil key when no key exists in database") + }) + + t.Run("LoadKey with existing key", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + kek := generateTestKEK(t) + + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: kek, + }) + require.NoError(t, err) + + // Save a key + err = provider.SaveKey(key) + require.NoError(t, err) + + // Load the key + loadedKey, err := provider.LoadKey() + require.NoError(t, err) + assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database") + + // Verify the loaded key is the same as the original + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + loadedKeyBytes, err := EncodeJWKBytes(loadedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key") + }) + + t.Run("LoadKey with invalid base64", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + kek := generateTestKEK(t) + + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: kek, + }) + require.NoError(t, err) + + // Insert invalid base64 data + invalidBase64 := "not-valid-base64" + err = db.Create(&model.KV{ + Key: PrivateKeyDBKey, + Value: &invalidBase64, + }).Error + require.NoError(t, err) + + // Attempt to load the key + loadedKey, err := provider.LoadKey() + require.Error(t, err, "Expected error when loading key with invalid base64") + require.ErrorContains(t, err, "not a valid base64-encoded value") + assert.Nil(t, loadedKey, "Expected nil key when loading fails") + }) + + t.Run("LoadKey with invalid encrypted data", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + kek := generateTestKEK(t) + + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: kek, + }) + require.NoError(t, err) + + // Insert valid base64 but invalid encrypted data + invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data")) + err = db.Create(&model.KV{ + Key: PrivateKeyDBKey, + Value: &invalidData, + }).Error + require.NoError(t, err) + + // Attempt to load the key + loadedKey, err := provider.LoadKey() + require.Error(t, err, "Expected error when loading key with invalid encrypted data") + require.ErrorContains(t, err, "failed to decrypt") + assert.Nil(t, loadedKey, "Expected nil key when loading fails") + }) + + t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + originalKek := generateTestKEK(t) + + // Save a key with the original KEK + originalProvider := &KeyProviderDatabase{} + err := originalProvider.Init(KeyProviderOpts{ + DB: db, + Kek: originalKek, + }) + require.NoError(t, err) + + err = originalProvider.SaveKey(key) + require.NoError(t, err) + + // Now try to load with a different KEK + differentKek := generateTestKEK(t) + differentProvider := &KeyProviderDatabase{} + err = differentProvider.Init(KeyProviderOpts{ + DB: db, + Kek: differentKek, + }) + require.NoError(t, err) + + // Attempt to load the key with the wrong KEK + loadedKey, err := differentProvider.LoadKey() + require.Error(t, err, "Expected error when loading key with wrong KEK") + require.ErrorContains(t, err, "failed to decrypt") + assert.Nil(t, loadedKey, "Expected nil key when loading fails") + }) + + t.Run("LoadKey with invalid key data", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + kek := generateTestKEK(t) + + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: kek, + }) + require.NoError(t, err) + + // Create invalid key data (valid JSON but not a valid JWK) + invalidKeyData := []byte(`{"not": "a valid jwk"}`) + + // Encrypt the invalid key data + encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil) + require.NoError(t, err) + + // Base64 encode the encrypted data + encodedData := base64.StdEncoding.EncodeToString(encryptedData) + + // Save to database + err = db.Create(&model.KV{ + Key: PrivateKeyDBKey, + Value: &encodedData, + }).Error + require.NoError(t, err) + + // Attempt to load the key + loadedKey, err := provider.LoadKey() + require.Error(t, err, "Expected error when loading invalid key data") + require.ErrorContains(t, err, "failed to parse") + assert.Nil(t, loadedKey, "Expected nil key when loading fails") + }) +} + +func TestKeyProviderDatabase_SaveKey(t *testing.T) { + // Generate a test key to use in our tests + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + key, err := jwk.Import(pk) + require.NoError(t, err) + + t.Run("SaveKey and verify database record", func(t *testing.T) { + db := testutils.NewDatabaseForTest(t) + kek := generateTestKEK(t) + + provider := &KeyProviderDatabase{} + err := provider.Init(KeyProviderOpts{ + DB: db, + Kek: kek, + }) + require.NoError(t, err) + + // Save the key + err = provider.SaveKey(key) + require.NoError(t, err, "Expected no error when saving key") + + // Verify record exists in database + var kv model.KV + err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error + require.NoError(t, err, "Expected to find key in database") + require.NotNil(t, kv.Value, "Expected non-nil value in database") + assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database") + + // Decode and decrypt to verify content + encBytes, err := base64.StdEncoding.DecodeString(*kv.Value) + require.NoError(t, err, "Expected valid base64 encoding") + + decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil) + require.NoError(t, err, "Expected valid encrypted data") + + parsedKey, err := jwk.ParseKey(decBytes) + require.NoError(t, err, "Expected valid JWK data") + + // Compare keys + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + parsedKeyBytes, err := EncodeJWKBytes(parsedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key") + }) +} + +func generateTestKEK(t *testing.T) []byte { + t.Helper() + + // Generate a 32-byte kek + kek := make([]byte, 32) + _, err := rand.Read(kek) + require.NoError(t, err) + return kek +} diff --git a/backend/internal/utils/jwk/key_provider_file.go b/backend/internal/utils/jwk/key_provider_file.go new file mode 100644 index 00000000..b8f2b07f --- /dev/null +++ b/backend/internal/utils/jwk/key_provider_file.go @@ -0,0 +1,202 @@ +package jwk + +import ( + "encoding/base64" + "fmt" + "os" + "path/filepath" + + "github.com/lestrrat-go/jwx/v3/jwk" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/utils" + cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto" +) + +const ( + // PrivateKeyFile is the path in the data/keys folder where the key is stored + // This is a JSON file containing a key encoded as JWK + PrivateKeyFile = "jwt_private_key.json" + + // PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored + // This is a encrypted JSON file containing a key encoded as JWK + PrivateKeyFileEncrypted = "jwt_private_key.json.enc" +) + +type KeyProviderFile struct { + envConfig *common.EnvConfigSchema + kek []byte +} + +func (f *KeyProviderFile) Init(opts KeyProviderOpts) error { + f.envConfig = opts.EnvConfig + f.kek = opts.Kek + + return nil +} + +func (f *KeyProviderFile) LoadKey() (jwk.Key, error) { + if len(f.kek) > 0 { + return f.loadEncryptedKey() + } + return f.loadKey() +} + +func (f *KeyProviderFile) SaveKey(key jwk.Key) error { + if len(f.kek) > 0 { + return f.saveKeyEncrypted(key) + } + return f.saveKey(key) +} + +func (f *KeyProviderFile) loadKey() (jwk.Key, error) { + var key jwk.Key + + // First, check if we have a JWK file + // If we do, then we just load that + jwkPath := f.jwkPath() + ok, err := utils.FileExists(jwkPath) + if err != nil { + return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err) + } + if !ok { + // File doesn't exist, no key was loaded + return nil, nil + } + + data, err := os.ReadFile(jwkPath) + if err != nil { + return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err) + } + + key, err = jwk.ParseKey(data) + if err != nil { + return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err) + } + + return key, nil +} + +func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) { + // First, check if we have an encrypted JWK file + // If we do, then we just load that + encJwkPath := f.encJwkPath() + ok, err := utils.FileExists(encJwkPath) + if err != nil { + return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err) + } + if ok { + encB64, err := os.ReadFile(encJwkPath) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err) + } + + // Decode from base64 + enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64))) + n, err := base64.StdEncoding.Decode(enc, encB64) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err) + } + + // Decrypt the data + data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err) + } + + // Parse the key + key, err = jwk.ParseKey(data) + if err != nil { + return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err) + } + + return key, nil + } + + // Check if we have an un-encrypted JWK file + key, err = f.loadKey() + if err != nil { + return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err) + } + if key == nil { + // No key exists, encrypted or un-encrypted + return nil, nil + } + + // If we are here, we have loaded a key that was un-encrypted + // We need to replace the plaintext key with the encrypted one before we return + err = f.saveKeyEncrypted(key) + if err != nil { + return nil, fmt.Errorf("failed to save encrypted key file: %w", err) + } + jwkPath := f.jwkPath() + err = os.Remove(jwkPath) + if err != nil { + return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err) + } + + return key, nil +} + +func (f *KeyProviderFile) saveKey(key jwk.Key) error { + err := os.MkdirAll(f.envConfig.KeysPath, 0700) + if err != nil { + return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err) + } + + jwkPath := f.jwkPath() + keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err) + } + defer keyFile.Close() + + // Write the JSON file to disk + err = EncodeJWK(keyFile, key) + if err != nil { + return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err) + } + + return nil +} + +func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error { + err := os.MkdirAll(f.envConfig.KeysPath, 0700) + if err != nil { + return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err) + } + + // Encode the key to JSON + data, err := EncodeJWKBytes(key) + if err != nil { + return fmt.Errorf("failed to encode key to JSON: %w", err) + } + + // Encrypt the key then encode to Base64 + enc, err := cryptoutils.Encrypt(f.kek, data, nil) + if err != nil { + return fmt.Errorf("failed to encrypt key: %w", err) + } + encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc))) + base64.StdEncoding.Encode(encB64, enc) + + // Write to disk + encJwkPath := f.encJwkPath() + err = os.WriteFile(encJwkPath, encB64, 0600) + if err != nil { + return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err) + } + + return nil +} + +func (f *KeyProviderFile) jwkPath() string { + return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile) +} + +func (f *KeyProviderFile) encJwkPath() string { + return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted) +} + +// Compile-time interface check +var _ KeyProvider = (*KeyProviderFile)(nil) diff --git a/backend/internal/utils/jwk/key_provider_file_test.go b/backend/internal/utils/jwk/key_provider_file_test.go new file mode 100644 index 00000000..768dbee2 --- /dev/null +++ b/backend/internal/utils/jwk/key_provider_file_test.go @@ -0,0 +1,320 @@ +package jwk + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/base64" + "os" + "path/filepath" + "testing" + + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/utils" + cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto" +) + +func TestKeyProviderFile_LoadKey(t *testing.T) { + // Generate a test key to use in our tests + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + key, err := jwk.Import(pk) + require.NoError(t, err) + + t.Run("LoadKey with no existing key", func(t *testing.T) { + tempDir := t.TempDir() + + provider := &KeyProviderFile{} + err := provider.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + }) + require.NoError(t, err) + + // Load key when none exists + loadedKey, err := provider.LoadKey() + require.NoError(t, err) + assert.Nil(t, loadedKey, "Expected nil key when no key exists") + }) + + t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) { + tempDir := t.TempDir() + + provider := &KeyProviderFile{} + err = provider.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + Kek: makeKEK(t), + }) + require.NoError(t, err) + + // Load key when none exists + loadedKey, err := provider.LoadKey() + require.NoError(t, err) + assert.Nil(t, loadedKey, "Expected nil key when no key exists") + }) + + t.Run("LoadKey with unencrypted key", func(t *testing.T) { + tempDir := t.TempDir() + + provider := &KeyProviderFile{} + err := provider.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + }) + require.NoError(t, err) + + // Save a key + err = provider.SaveKey(key) + require.NoError(t, err) + + // Make sure the key file exists + keyPath := filepath.Join(tempDir, PrivateKeyFile) + exists, err := utils.FileExists(keyPath) + require.NoError(t, err) + assert.True(t, exists, "Expected key file to exist") + + // Load the key + loadedKey, err := provider.LoadKey() + require.NoError(t, err) + assert.NotNil(t, loadedKey, "Expected non-nil key when key exists") + + // Verify the loaded key is the same as the original + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + loadedKeyBytes, err := EncodeJWKBytes(loadedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key") + }) + + t.Run("LoadKey with encrypted key", func(t *testing.T) { + tempDir := t.TempDir() + + provider := &KeyProviderFile{} + err = provider.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + Kek: makeKEK(t), + }) + require.NoError(t, err) + + // Save a key (will be encrypted) + err = provider.SaveKey(key) + require.NoError(t, err) + + // Make sure the encrypted key file exists + encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted) + exists, err := utils.FileExists(encKeyPath) + require.NoError(t, err) + assert.True(t, exists, "Expected encrypted key file to exist") + + // Make sure the unencrypted key file does not exist + keyPath := filepath.Join(tempDir, PrivateKeyFile) + exists, err = utils.FileExists(keyPath) + require.NoError(t, err) + assert.False(t, exists, "Expected unencrypted key file to not exist") + + // Load the key + loadedKey, err := provider.LoadKey() + require.NoError(t, err) + assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists") + + // Verify the loaded key is the same as the original + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + loadedKeyBytes, err := EncodeJWKBytes(loadedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key") + }) + + t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) { + tempDir := t.TempDir() + + // First, create an unencrypted key + providerNoKek := &KeyProviderFile{} + err := providerNoKek.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + }) + require.NoError(t, err) + + // Save an unencrypted key + err = providerNoKek.SaveKey(key) + require.NoError(t, err) + + // Verify unencrypted key exists + keyPath := filepath.Join(tempDir, PrivateKeyFile) + exists, err := utils.FileExists(keyPath) + require.NoError(t, err) + assert.True(t, exists, "Expected unencrypted key file to exist") + + // Now create a provider with a kek + kek := make([]byte, 32) + _, err = rand.Read(kek) + require.NoError(t, err) + + providerWithKek := &KeyProviderFile{} + err = providerWithKek.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + Kek: kek, + }) + require.NoError(t, err) + + // Load the key - this should convert the unencrypted key to encrypted + loadedKey, err := providerWithKek.LoadKey() + require.NoError(t, err) + assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key") + + // Verify the unencrypted key no longer exists + exists, err = utils.FileExists(keyPath) + require.NoError(t, err) + assert.False(t, exists, "Expected unencrypted key file to be removed") + + // Verify the encrypted key file exists + encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted) + exists, err = utils.FileExists(encKeyPath) + require.NoError(t, err) + assert.True(t, exists, "Expected encrypted key file to exist after conversion") + + // Verify the key data + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + loadedKeyBytes, err := EncodeJWKBytes(loadedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion") + }) +} + +func TestKeyProviderFile_SaveKey(t *testing.T) { + // Generate a test key to use in our tests + pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + key, err := jwk.Import(pk) + require.NoError(t, err) + + t.Run("SaveKey unencrypted", func(t *testing.T) { + tempDir := t.TempDir() + + provider := &KeyProviderFile{} + err := provider.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + }) + require.NoError(t, err) + + // Save the key + err = provider.SaveKey(key) + require.NoError(t, err) + + // Verify the key file exists + keyPath := filepath.Join(tempDir, PrivateKeyFile) + exists, err := utils.FileExists(keyPath) + require.NoError(t, err) + assert.True(t, exists, "Expected key file to exist") + + // Verify the content of the key file + data, err := os.ReadFile(keyPath) + require.NoError(t, err) + + parsedKey, err := jwk.ParseKey(data) + require.NoError(t, err) + + // Compare the saved key with the original + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + parsedKeyBytes, err := EncodeJWKBytes(parsedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key") + }) + + t.Run("SaveKey encrypted", func(t *testing.T) { + tempDir := t.TempDir() + + // Generate a 64-byte kek + kek := makeKEK(t) + + provider := &KeyProviderFile{} + err = provider.Init(KeyProviderOpts{ + EnvConfig: &common.EnvConfigSchema{ + KeysPath: tempDir, + }, + Kek: kek, + }) + require.NoError(t, err) + + // Save the key (will be encrypted) + err = provider.SaveKey(key) + require.NoError(t, err) + + // Verify the encrypted key file exists + encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted) + exists, err := utils.FileExists(encKeyPath) + require.NoError(t, err) + assert.True(t, exists, "Expected encrypted key file to exist") + + // Verify the unencrypted key file doesn't exist + keyPath := filepath.Join(tempDir, PrivateKeyFile) + exists, err = utils.FileExists(keyPath) + require.NoError(t, err) + assert.False(t, exists, "Expected unencrypted key file to not exist") + + // Manually decrypt the encrypted key file to verify it contains the correct key + encB64, err := os.ReadFile(encKeyPath) + require.NoError(t, err) + + // Decode from base64 + enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64))) + n, err := base64.StdEncoding.Decode(enc, encB64) + require.NoError(t, err) + enc = enc[:n] // Trim any padding + + // Decrypt the data + data, err := cryptoutils.Decrypt(kek, enc, nil) + require.NoError(t, err) + + // Parse the key + parsedKey, err := jwk.ParseKey(data) + require.NoError(t, err) + + // Compare the decrypted key with the original + keyBytes, err := EncodeJWKBytes(key) + require.NoError(t, err) + + parsedKeyBytes, err := EncodeJWKBytes(parsedKey) + require.NoError(t, err) + + assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key") + }) +} + +func makeKEK(t *testing.T) []byte { + t.Helper() + + // Generate a 32-byte kek + kek := make([]byte, 32) + _, err := rand.Read(kek) + require.NoError(t, err) + return kek +} diff --git a/backend/internal/utils/jwk/utils.go b/backend/internal/utils/jwk/utils.go new file mode 100644 index 00000000..815d5734 --- /dev/null +++ b/backend/internal/utils/jwk/utils.go @@ -0,0 +1,180 @@ +package jwk + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha3" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "hash" + "io" + "os" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + + "github.com/pocket-id/pocket-id/backend/internal/common" +) + +const ( + // KeyUsageSigning is the usage for the private keys, for the "use" property + KeyUsageSigning = "sig" +) + +// EncodeJWK encodes a jwk.Key to a writable stream. +func EncodeJWK(w io.Writer, key jwk.Key) error { + enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) + return enc.Encode(key) +} + +// EncodeJWKBytes encodes a jwk.Key to a byte slice. +func EncodeJWKBytes(key jwk.Key) ([]byte, error) { + b := &bytes.Buffer{} + err := EncodeJWK(b, key) + if err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// LoadKeyEncryptionKey loads the key encryption key for JWKs +func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) { + // Try getting the key from the env var as string + kekInput := []byte(envConfig.EncryptionKey) + + // If there's nothing in the env, try loading from file + if len(kekInput) == 0 && envConfig.EncryptionKeyFile != "" { + kekInput, err = os.ReadFile(envConfig.EncryptionKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to read key file '%s': %w", envConfig.EncryptionKeyFile, err) + } + } + + // If there's still no key, return + if len(kekInput) == 0 { + return nil, nil + } + + // We need a 256-bit key for encryption with AES-GCM-256 + // We use HMAC with SHA3-256 here to derive the key from the one passed as input + // The key is tied to a specific instance of Pocket ID + h := hmac.New(func() hash.Hash { return sha3.New256() }, kekInput) + fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek") + kek = h.Sum(nil) + + return kek, nil +} + +// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key. +// It also populates additional fields such as the key ID, usage, and alg. +func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) { + key, err := jwk.Import(rawKey) + if err != nil { + return nil, fmt.Errorf("failed to import generated private key: %w", err) + } + + // Generate the key ID + kid, err := generateRandomKeyID() + if err != nil { + return nil, fmt.Errorf("failed to generate key ID: %w", err) + } + _ = key.Set(jwk.KeyIDKey, kid) + + // Set other required fields + _ = key.Set(jwk.KeyUsageKey, KeyUsageSigning) + EnsureAlgInKey(key, alg, crv) + + return key, nil +} + +// generateRandomKeyID generates a random key ID. +func generateRandomKeyID() (string, error) { + buf := make([]byte, 8) + _, err := io.ReadFull(rand.Reader, buf) + if err != nil { + return "", fmt.Errorf("failed to read random bytes: %w", err) + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type +func EnsureAlgInKey(key jwk.Key, alg string, crv string) { + _, ok := key.Algorithm() + if ok { + // Algorithm is already set + return + } + + if alg != "" { + _ = key.Set(jwk.AlgorithmKey, alg) + if crv != "" { + eca, ok := jwa.LookupEllipticCurveAlgorithm(crv) + if ok { + switch key.KeyType() { + case jwa.EC(): + _ = key.Set(jwk.ECDSACrvKey, eca) + case jwa.OKP(): + _ = key.Set(jwk.OKPCrvKey, eca) + } + } + } + return + } + + // If we don't have an algorithm, set the default for the key type + switch key.KeyType() { + case jwa.RSA(): + // Default to RS256 for RSA keys + _ = key.Set(jwk.AlgorithmKey, jwa.RS256()) + case jwa.EC(): + // Default to ES256 for ECDSA keys + _ = key.Set(jwk.AlgorithmKey, jwa.ES256()) + _ = key.Set(jwk.ECDSACrvKey, jwa.P256()) + case jwa.OKP(): + // Default to EdDSA and Ed25519 for OKP keys + _ = key.Set(jwk.AlgorithmKey, jwa.EdDSA()) + _ = key.Set(jwk.OKPCrvKey, jwa.Ed25519()) + } +} + +// GenerateKey generates a new jwk.Key +func GenerateKey(alg string, crv string) (key jwk.Key, err error) { + var rawKey any + switch alg { + case jwa.RS256().String(): + rawKey, err = rsa.GenerateKey(rand.Reader, 2048) + case jwa.RS384().String(): + rawKey, err = rsa.GenerateKey(rand.Reader, 3072) + case jwa.RS512().String(): + rawKey, err = rsa.GenerateKey(rand.Reader, 4096) + case jwa.ES256().String(): + rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + case jwa.ES384().String(): + rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + case jwa.ES512().String(): + rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + case jwa.EdDSA().String(): + switch crv { + case jwa.Ed25519().String(): + _, rawKey, err = ed25519.GenerateKey(rand.Reader) + default: + return nil, errors.New("unsupported curve for EdDSA algorithm") + } + default: + return nil, errors.New("unsupported key algorithm") + } + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Import the raw key + return ImportRawKey(rawKey, alg, crv) +} diff --git a/backend/internal/utils/jwk/utils_test.go b/backend/internal/utils/jwk/utils_test.go new file mode 100644 index 00000000..e25f7277 --- /dev/null +++ b/backend/internal/utils/jwk/utils_test.go @@ -0,0 +1,324 @@ +package jwk + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "testing" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateKey(t *testing.T) { + tests := []struct { + name string + alg string + crv string + expectError bool + expectedAlg jwa.SignatureAlgorithm + }{ + { + name: "RS256", + alg: jwa.RS256().String(), + crv: "", + expectError: false, + expectedAlg: jwa.RS256(), + }, + { + name: "RS384", + alg: jwa.RS384().String(), + crv: "", + expectError: false, + expectedAlg: jwa.RS384(), + }, + // Skip the RS512 test as generating a RSA-4096 key can take some time + /* { + name: "RS512", + alg: jwa.RS512().String(), + crv: "", + expectError: false, + expectedAlg: jwa.RS512(), + }, */ + { + name: "ES256", + alg: jwa.ES256().String(), + crv: jwa.P256().String(), + expectError: false, + expectedAlg: jwa.ES256(), + }, + { + name: "ES384", + alg: jwa.ES384().String(), + crv: jwa.P384().String(), + expectError: false, + expectedAlg: jwa.ES384(), + }, + { + name: "ES512", + alg: jwa.ES512().String(), + crv: jwa.P521().String(), + expectError: false, + expectedAlg: jwa.ES512(), + }, + { + name: "EdDSA with Ed25519", + alg: jwa.EdDSA().String(), + crv: jwa.Ed25519().String(), + expectError: false, + expectedAlg: jwa.EdDSA(), + }, + { + name: "EdDSA with unsupported curve", + alg: jwa.EdDSA().String(), + crv: "unsupported", + expectError: true, + }, + { + name: "Unsupported algorithm", + alg: "UNSUPPORTED", + crv: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key, err := GenerateKey(tt.alg, tt.crv) + + if tt.expectError { + require.Error(t, err) + assert.Nil(t, key) + return + } + + require.NoError(t, err) + require.NotNil(t, key) + + // Verify the algorithm is set correctly + alg, ok := key.Algorithm() + require.True(t, ok, "algorithm should be set in the key") + assert.Equal(t, tt.expectedAlg.String(), alg.String()) + + // Verify other required fields are set + kid, ok := key.KeyID() + assert.True(t, ok, "key ID should be set") + assert.NotEmpty(t, kid, "key ID should not be empty") + + usage, ok := key.KeyUsage() + assert.True(t, ok, "key usage should be set") + assert.Equal(t, KeyUsageSigning, usage) + + var crv any + _ = key.Get("crv", &crv) + + // Verify key type matches expected algorithm + switch tt.expectedAlg { + case jwa.RS256(), jwa.RS384(), jwa.RS512(): + assert.Equal(t, jwa.RSA(), key.KeyType()) + assert.Nil(t, crv) + case jwa.ES256(), jwa.ES384(), jwa.ES512(): + assert.Equal(t, jwa.EC(), key.KeyType()) + eca, ok := crv.(jwa.EllipticCurveAlgorithm) + _ = assert.NotNil(t, crv) && + assert.True(t, ok) && + assert.Equal(t, tt.crv, eca.String()) + case jwa.EdDSA(): + assert.Equal(t, jwa.OKP(), key.KeyType()) + eca, ok := crv.(jwa.EllipticCurveAlgorithm) + _ = assert.NotNil(t, crv) && + assert.True(t, ok) && + assert.Equal(t, tt.crv, eca.String()) + } + }) + } +} + +func TestEnsureAlgInKey(t *testing.T) { + // Generate an RSA-2048 key + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + t.Run("does not change alg already set", func(t *testing.T) { + // Import the RSA key + key, err := jwk.Import(rsaKey) + require.NoError(t, err) + + // Pre-set the algorithm + _ = key.Set(jwk.AlgorithmKey, jwa.RS256()) + + // Call EnsureAlgInKey with a different algorithm + EnsureAlgInKey(key, jwa.RS384().String(), "") + + // Verify the algorithm wasn't changed + alg, ok := key.Algorithm() + require.True(t, ok) + assert.Equal(t, jwa.RS256().String(), alg.String()) + }) + + t.Run("set algorithm to explicitly-provided value", func(t *testing.T) { + tests := []struct { + name string + keyGen func() (any, error) + alg string + crv string + expectedAlg jwa.SignatureAlgorithm + expectedCrv string + }{ + { + name: "RSA key with RS384", + keyGen: func() (any, error) { + return rsaKey, nil + }, + alg: jwa.RS384().String(), + crv: "", + expectedAlg: jwa.RS384(), + expectedCrv: "", + }, + { + name: "ECDSA key with ES384", + keyGen: func() (any, error) { + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + }, + alg: jwa.ES384().String(), + crv: jwa.P384().String(), + expectedAlg: jwa.ES384(), + expectedCrv: jwa.P384().String(), + }, + { + name: "Ed25519 key with EdDSA", + keyGen: func() (any, error) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + return priv, err + }, + alg: jwa.EdDSA().String(), + crv: jwa.Ed25519().String(), + expectedAlg: jwa.EdDSA(), + expectedCrv: jwa.Ed25519().String(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rawKey, err := tt.keyGen() + require.NoError(t, err) + + key, err := jwk.Import(rawKey) + require.NoError(t, err) + + // Ensure no algorithm is set initially + _, ok := key.Algorithm() + assert.False(t, ok) + + // Call EnsureAlgInKey + EnsureAlgInKey(key, tt.alg, tt.crv) + + // Verify the algorithm was set correctly + alg, ok := key.Algorithm() + require.True(t, ok) + assert.Equal(t, tt.expectedAlg.String(), alg.String()) + + // Verify curve if expected + if tt.expectedCrv != "" { + var crv any + _ = key.Get("crv", &crv) + require.NotNil(t, crv) + eca, ok := crv.(jwa.EllipticCurveAlgorithm) + require.True(t, ok) + assert.Equal(t, tt.expectedCrv, eca.String()) + } + }) + } + }) + + t.Run("set default algorithms if not present", func(t *testing.T) { + tests := []struct { + name string + keyGen func() (any, error) + expectedAlg jwa.SignatureAlgorithm + expectedCrv string + }{ + { + name: "RSA key defaults to RS256", + keyGen: func() (any, error) { + return rsaKey, nil + }, + expectedAlg: jwa.RS256(), + expectedCrv: "", + }, + { + name: "ECDSA key defaults to ES256 with P256", + keyGen: func() (any, error) { + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + }, + expectedAlg: jwa.ES256(), + expectedCrv: jwa.P256().String(), + }, + { + name: "Ed25519 key defaults to EdDSA with Ed25519", + keyGen: func() (any, error) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + return priv, err + }, + expectedAlg: jwa.EdDSA(), + expectedCrv: jwa.Ed25519().String(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rawKey, err := tt.keyGen() + require.NoError(t, err) + + key, err := jwk.Import(rawKey) + require.NoError(t, err) + + // Ensure no algorithm is set initially + _, ok := key.Algorithm() + assert.False(t, ok) + + // Call EnsureAlgInKey with empty parameters + EnsureAlgInKey(key, "", "") + + // Verify the default algorithm was set + alg, ok := key.Algorithm() + require.True(t, ok) + assert.Equal(t, tt.expectedAlg.String(), alg.String()) + + // Verify curve if expected + if tt.expectedCrv != "" { + var crv any + _ = key.Get("crv", &crv) + require.NotNil(t, crv) + eca, ok := crv.(jwa.EllipticCurveAlgorithm) + require.True(t, ok) + assert.Equal(t, tt.expectedCrv, eca.String()) + } + }) + } + }) + + t.Run("invalid curve should not set curve parameter", func(t *testing.T) { + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + key, err := jwk.Import(rsaKey) + require.NoError(t, err) + + // Call EnsureAlgInKey with invalid curve + EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve") + + // Verify algorithm was set but curve was not + alg, ok := key.Algorithm() + require.True(t, ok) + assert.Equal(t, jwa.RS256().String(), alg.String()) + + var crv any + _ = key.Get("crv", &crv) + assert.Nil(t, crv) + }) +} diff --git a/backend/internal/utils/jwk_util.go b/backend/internal/utils/jwk_util.go deleted file mode 100644 index 2571b514..00000000 --- a/backend/internal/utils/jwk_util.go +++ /dev/null @@ -1,69 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "io" - - "github.com/lestrrat-go/jwx/v3/jwa" - "github.com/lestrrat-go/jwx/v3/jwk" -) - -const ( - // KeyUsageSigning is the usage for the private keys, for the "use" property - KeyUsageSigning = "sig" -) - -// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key. -// It also populates additional fields such as the key ID, usage, and alg. -func ImportRawKey(rawKey any) (jwk.Key, error) { - key, err := jwk.Import(rawKey) - if err != nil { - return nil, fmt.Errorf("failed to import generated private key: %w", err) - } - - // Generate the key ID - kid, err := generateRandomKeyID() - if err != nil { - return nil, fmt.Errorf("failed to generate key ID: %w", err) - } - _ = key.Set(jwk.KeyIDKey, kid) - - // Set other required fields - _ = key.Set(jwk.KeyUsageKey, KeyUsageSigning) - EnsureAlgInKey(key) - - return key, nil -} - -// generateRandomKeyID generates a random key ID. -func generateRandomKeyID() (string, error) { - buf := make([]byte, 8) - _, err := io.ReadFull(rand.Reader, buf) - if err != nil { - return "", fmt.Errorf("failed to read random bytes: %w", err) - } - return base64.RawURLEncoding.EncodeToString(buf), nil -} - -// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type -func EnsureAlgInKey(key jwk.Key) { - _, ok := key.Algorithm() - if ok { - // Algorithm is already set - return - } - - switch key.KeyType() { - case jwa.RSA(): - // Default to RS256 for RSA keys - _ = key.Set(jwk.AlgorithmKey, jwa.RS256()) - case jwa.EC(): - // Default to ES256 for ECDSA keys - _ = key.Set(jwk.AlgorithmKey, jwa.ES256()) - case jwa.OKP(): - // Default to EdDSA for OKP keys - _ = key.Set(jwk.AlgorithmKey, jwa.EdDSA()) - } -} diff --git a/backend/internal/service/testutils_test.go b/backend/internal/utils/testing/database.go similarity index 68% rename from backend/internal/service/testutils_test.go rename to backend/internal/utils/testing/database.go index 59cdd9fe..a58a789b 100644 --- a/backend/internal/service/testutils_test.go +++ b/backend/internal/utils/testing/database.go @@ -1,9 +1,8 @@ -package service +// This file is only imported by unit tests + +package testing import ( - "io" - "net/http" - "strings" "testing" "time" @@ -21,7 +20,10 @@ import ( "github.com/pocket-id/pocket-id/backend/resources" ) -func newDatabaseForTest(t *testing.T) *gorm.DB { +// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database. +// Each database connection is unique for the test. +// All migrations are automatically performed. +func NewDatabaseForTest(t *testing.T) *gorm.DB { t.Helper() // Get a name for this in-memory database that is specific to the test @@ -68,30 +70,3 @@ type testLoggerAdapter struct { func (l testLoggerAdapter) Printf(format string, args ...any) { l.t.Logf(format, args...) } - -// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL -type MockRoundTripper struct { - Err error - Responses map[string]*http.Response -} - -// RoundTrip implements the http.RoundTripper interface -func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - // Check if we have a specific response for this URL - for url, resp := range m.Responses { - if req.URL.String() == url { - return resp, nil - } - } - - return NewMockResponse(http.StatusNotFound, ""), nil -} - -// NewMockResponse creates an http.Response with the given status code and body -func NewMockResponse(statusCode int, body string) *http.Response { - return &http.Response{ - StatusCode: statusCode, - Body: io.NopCloser(strings.NewReader(body)), - Header: make(http.Header), - } -} diff --git a/backend/internal/utils/testing/round_tripper.go b/backend/internal/utils/testing/round_tripper.go new file mode 100644 index 00000000..806e2eaf --- /dev/null +++ b/backend/internal/utils/testing/round_tripper.go @@ -0,0 +1,38 @@ +// This file is only imported by unit tests + +package testing + +import ( + "io" + "net/http" + "strings" + + _ "github.com/golang-migrate/migrate/v4/source/file" +) + +// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL +type MockRoundTripper struct { + Err error + Responses map[string]*http.Response +} + +// RoundTrip implements the http.RoundTripper interface +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Check if we have a specific response for this URL + for url, resp := range m.Responses { + if req.URL.String() == url { + return resp, nil + } + } + + return NewMockResponse(http.StatusNotFound, ""), nil +} + +// NewMockResponse creates an http.Response with the given status code and body +func NewMockResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} diff --git a/backend/resources/migrations/postgres/20250630000000_kv_table.down.sql b/backend/resources/migrations/postgres/20250630000000_kv_table.down.sql new file mode 100644 index 00000000..7ec25a79 --- /dev/null +++ b/backend/resources/migrations/postgres/20250630000000_kv_table.down.sql @@ -0,0 +1 @@ +DROP TABLE kv; diff --git a/backend/resources/migrations/postgres/20250630000000_kv_table.up.sql b/backend/resources/migrations/postgres/20250630000000_kv_table.up.sql new file mode 100644 index 00000000..65eb379a --- /dev/null +++ b/backend/resources/migrations/postgres/20250630000000_kv_table.up.sql @@ -0,0 +1,6 @@ +-- The "kv" tables contains miscellaneous key-value pairs +CREATE TABLE kv +( + "key" TEXT NOT NULL PRIMARY KEY, + "value" TEXT +); diff --git a/backend/resources/migrations/sqlite/20250630000000_kv_table.down.sql b/backend/resources/migrations/sqlite/20250630000000_kv_table.down.sql new file mode 100644 index 00000000..7ec25a79 --- /dev/null +++ b/backend/resources/migrations/sqlite/20250630000000_kv_table.down.sql @@ -0,0 +1 @@ +DROP TABLE kv; diff --git a/backend/resources/migrations/sqlite/20250630000000_kv_table.up.sql b/backend/resources/migrations/sqlite/20250630000000_kv_table.up.sql new file mode 100644 index 00000000..fae1de92 --- /dev/null +++ b/backend/resources/migrations/sqlite/20250630000000_kv_table.up.sql @@ -0,0 +1,6 @@ +-- The "kv" tables contains miscellaneous key-value pairs +CREATE TABLE kv +( + "key" TEXT NOT NULL PRIMARY KEY, + "value" TEXT NOT NULL +);