diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 32204aca..e2b55f59 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -1,9 +1,12 @@ package main import ( + "fmt" + "os" _ "time/tzdata" "github.com/pocket-id/pocket-id/backend/internal/cmds" + "github.com/pocket-id/pocket-id/backend/internal/common" ) // @title Pocket ID API @@ -11,5 +14,9 @@ import ( // @description.markdown func main() { + if err := common.ValidateEnvConfig(&common.EnvConfig); err != nil { + fmt.Fprintf(os.Stderr, "config error: %v\n", err) + os.Exit(1) + } cmds.Execute() } diff --git a/backend/internal/cmds/key_rotate_test.go b/backend/internal/cmds/key_rotate_test.go index c6528019..bccd3a3d 100644 --- a/backend/internal/cmds/key_rotate_test.go +++ b/backend/internal/cmds/key_rotate_test.go @@ -1,8 +1,6 @@ package cmds import ( - "os" - "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -69,78 +67,14 @@ func TestKeyRotate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Run("file storage", func(t *testing.T) { - testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg) - }) - - t.Run("database storage", func(t *testing.T) { - testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg) - }) + testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg) }) } } -func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) { - // Create temporary directory for keys - tempDir := t.TempDir() - keysPath := filepath.Join(tempDir, "keys") - err := os.MkdirAll(keysPath, 0755) - require.NoError(t, err) - - // Set up file storage config - envConfig := &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: keysPath, - } - - // Create test database - db := testingutils.NewDatabaseForTest(t) - - // Initialize app config service and create instance - appConfigService, err := service.NewAppConfigService(t.Context(), db) - require.NoError(t, err) - instanceID := appConfigService.GetDbConfig().InstanceID.Value - - // Check if key exists before rotation - keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID) - require.NoError(t, err) - - // Run the key rotation - err = keyRotate(t.Context(), flags, db, envConfig) - - if wantErr { - require.Error(t, err) - if errMsg != "" { - require.ErrorContains(t, err, errMsg) - } - return - } - - require.NoError(t, err) - - // Verify key was created - key, err := keyProvider.LoadKey() - require.NoError(t, err) - require.NotNil(t, key) - - // Verify the algorithm matches what we requested - alg, _ := key.Algorithm() - assert.NotEmpty(t, alg) - if flags.Alg != "" { - expectedAlg := flags.Alg - if expectedAlg == "EdDSA" { - // EdDSA keys should have the EdDSA algorithm - assert.Equal(t, "EdDSA", alg.String()) - } else { - assert.Equal(t, expectedAlg, alg.String()) - } - } -} - func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) { // Set up database storage config envConfig := &common.EnvConfigSchema{ - KeysStorage: "database", EncryptionKey: []byte("test-encryption-key-characters-long"), } diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go index cb81eb0e..4011a36f 100644 --- a/backend/internal/common/env_config.go +++ b/backend/internal/common/env_config.go @@ -52,8 +52,6 @@ type EnvConfigSchema struct { S3SecretAccessKey string `env:"S3_SECRET_ACCESS_KEY"` S3ForcePathStyle bool `env:"S3_FORCE_PATH_STYLE"` S3DisableDefaultIntegrityChecks bool `env:"S3_DISABLE_DEFAULT_INTEGRITY_CHECKS"` - KeysPath string `env:"KEYS_PATH"` - KeysStorage string `env:"KEYS_STORAGE"` EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"` Port string `env:"PORT"` Host string `env:"HOST" options:"toLower"` @@ -90,7 +88,6 @@ func defaultConfig() EnvConfigSchema { LogLevel: "info", DbProvider: "sqlite", FileBackend: "filesystem", - KeysPath: "data/keys", AuditLogRetentionDays: 90, AppURL: AppUrl, Port: "1411", @@ -119,21 +116,20 @@ func parseEnvConfig() error { return fmt.Errorf("error preparing env config: %w", err) } - err = validateEnvConfig(&EnvConfig) - if err != nil { - return err - } - return nil } -// validateEnvConfig checks the EnvConfig for required fields and valid values -func validateEnvConfig(config *EnvConfigSchema) error { +// ValidateEnvConfig checks the EnvConfig for required fields and valid values +func ValidateEnvConfig(config *EnvConfigSchema) error { if _, err := sloggin.ParseLevel(config.LogLevel); err != nil { return errors.New("invalid LOG_LEVEL value. Must be 'debug', 'info', 'warn' or 'error'") } + if len(config.EncryptionKey) < 16 { + return errors.New("ENCRYPTION_KEY must be at least 16 bytes long") + } + switch config.DbProvider { case DbProviderSqlite: if config.DbConnectionString == "" { @@ -168,28 +164,10 @@ func validateEnvConfig(config *EnvConfigSchema) error { } } - switch config.KeysStorage { - // KeysStorage defaults to "file" if empty - case "": - config.KeysStorage = "file" - case "database": - if config.EncryptionKey == nil { - return errors.New("ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database") - } - case "file": - // All good, these are valid values - default: - return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage) - } - switch config.FileBackend { - case "s3": - if config.KeysStorage == "file" { - return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'") - } - case "database": + case "s3", "database": // All good, these are valid values - case "", "filesystem": + case "", "fs": if config.UploadPath == "" { config.UploadPath = defaultFsUploadPath } diff --git a/backend/internal/common/env_config_test.go b/backend/internal/common/env_config_test.go index bfd8a887..6d023586 100644 --- a/backend/internal/common/env_config_test.go +++ b/backend/internal/common/env_config_test.go @@ -8,6 +8,20 @@ import ( "github.com/stretchr/testify/require" ) +func parseAndValidateEnvConfig(t *testing.T) error { + t.Helper() + + if _, exists := os.LookupEnv("ENCRYPTION_KEY"); !exists { + t.Setenv("ENCRYPTION_KEY", "0123456789abcdef") + } + + if err := parseEnvConfig(); err != nil { + return err + } + + return ValidateEnvConfig(&EnvConfig) +} + func TestParseEnvConfig(t *testing.T) { // Store original config to restore later originalConfig := EnvConfig @@ -21,7 +35,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("APP_URL", "HTTP://LOCALHOST:3000") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.NoError(t, err) assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider) assert.Equal(t, "http://localhost:3000", EnvConfig.AppURL) @@ -33,7 +47,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db") t.Setenv("APP_URL", "https://example.com") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.NoError(t, err) assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider) }) @@ -44,17 +58,29 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "test") t.Setenv("APP_URL", "http://localhost:3000") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "invalid DB_PROVIDER value") }) + t.Run("should fail when ENCRYPTION_KEY is too short", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_PROVIDER", "sqlite") + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("ENCRYPTION_KEY", "short") + + err := parseAndValidateEnvConfig(t) + require.Error(t, err) + assert.ErrorContains(t, err, "ENCRYPTION_KEY must be at least 16 bytes long") + }) + t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) { EnvConfig = defaultConfig() t.Setenv("DB_PROVIDER", "sqlite") t.Setenv("APP_URL", "http://localhost:3000") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.NoError(t, err) assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString) }) @@ -64,7 +90,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_PROVIDER", "postgres") t.Setenv("APP_URL", "http://localhost:3000") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres") }) @@ -75,7 +101,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("APP_URL", "€://not-a-valid-url") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "APP_URL is not a valid URL") }) @@ -86,7 +112,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("APP_URL", "http://localhost:3000/path") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "APP_URL must not contain a path") }) @@ -97,7 +123,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("INTERNAL_APP_URL", "€://not-a-valid-url") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "INTERNAL_APP_URL is not a valid URL") }) @@ -108,65 +134,11 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("DB_CONNECTION_STRING", "file:test.db") t.Setenv("INTERNAL_APP_URL", "http://localhost:3000/path") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "INTERNAL_APP_URL must not contain a path") }) - t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) { - EnvConfig = defaultConfig() - t.Setenv("DB_PROVIDER", "sqlite") - t.Setenv("DB_CONNECTION_STRING", "file:test.db") - t.Setenv("APP_URL", "http://localhost:3000") - - err := parseEnvConfig() - require.NoError(t, err) - assert.Equal(t, "file", EnvConfig.KeysStorage) - }) - - t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) { - EnvConfig = defaultConfig() - t.Setenv("DB_PROVIDER", "sqlite") - t.Setenv("DB_CONNECTION_STRING", "file:test.db") - t.Setenv("APP_URL", "http://localhost:3000") - t.Setenv("KEYS_STORAGE", "database") - - err := parseEnvConfig() - require.Error(t, err) - assert.ErrorContains(t, err, "ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database") - }) - - t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) { - validStorageTypes := []string{"file", "database"} - - for _, storage := range validStorageTypes { - EnvConfig = defaultConfig() - t.Setenv("DB_PROVIDER", "sqlite") - t.Setenv("DB_CONNECTION_STRING", "file:test.db") - t.Setenv("APP_URL", "http://localhost:3000") - t.Setenv("KEYS_STORAGE", storage) - if storage == "database" { - t.Setenv("ENCRYPTION_KEY", "test-key") - } - - err := parseEnvConfig() - require.NoError(t, err) - assert.Equal(t, storage, EnvConfig.KeysStorage) - } - }) - - t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) { - EnvConfig = defaultConfig() - t.Setenv("DB_PROVIDER", "sqlite") - t.Setenv("DB_CONNECTION_STRING", "file:test.db") - t.Setenv("APP_URL", "http://localhost:3000") - t.Setenv("KEYS_STORAGE", "invalid") - - err := parseEnvConfig() - require.Error(t, err) - assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE") - }) - t.Run("should parse boolean environment variables correctly", func(t *testing.T) { EnvConfig = defaultConfig() t.Setenv("DB_PROVIDER", "sqlite") @@ -178,7 +150,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("TRUST_PROXY", "true") t.Setenv("ANALYTICS_DISABLED", "false") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.NoError(t, err) assert.True(t, EnvConfig.UiConfigDisabled) assert.True(t, EnvConfig.MetricsEnabled) @@ -229,14 +201,13 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("APP_URL", "https://prod.example.com") t.Setenv("APP_ENV", "PRODUCTION") t.Setenv("UPLOAD_PATH", "/custom/uploads") - t.Setenv("KEYS_PATH", "/custom/keys") t.Setenv("PORT", "8080") t.Setenv("HOST", "LOCALHOST") t.Setenv("UNIX_SOCKET", "/tmp/app.sock") t.Setenv("MAXMIND_LICENSE_KEY", "test-license") t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.NoError(t, err) assert.Equal(t, AppEnvProduction, EnvConfig.AppEnv) // lowercased assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath) @@ -252,24 +223,12 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("FILE_BACKEND", "FILESYSTEM") t.Setenv("UPLOAD_PATH", "") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.NoError(t, err) assert.Equal(t, "filesystem", EnvConfig.FileBackend) assert.Equal(t, defaultFsUploadPath, EnvConfig.UploadPath) }) - t.Run("should fail when FILE_BACKEND is s3 but keys are stored on filesystem", func(t *testing.T) { - EnvConfig = defaultConfig() - t.Setenv("DB_PROVIDER", "sqlite") - t.Setenv("DB_CONNECTION_STRING", "file:test.db") - t.Setenv("APP_URL", "http://localhost:3000") - t.Setenv("FILE_BACKEND", "s3") - - err := parseEnvConfig() - require.Error(t, err) - assert.ErrorContains(t, err, "KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'") - }) - t.Run("should fail with invalid FILE_BACKEND value", func(t *testing.T) { EnvConfig = defaultConfig() t.Setenv("DB_PROVIDER", "sqlite") @@ -277,7 +236,7 @@ func TestParseEnvConfig(t *testing.T) { t.Setenv("APP_URL", "http://localhost:3000") t.Setenv("FILE_BACKEND", "invalid") - err := parseEnvConfig() + err := parseAndValidateEnvConfig(t) require.Error(t, err) assert.ErrorContains(t, err, "invalid FILE_BACKEND value") }) diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index b156a27f..eb9aef11 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -18,14 +18,6 @@ import ( ) const ( - // PrivateKeyFile is the path in the data/keys folder where the key is stored - // This is a JSON file containing a key encoded as JWK - PrivateKeyFile = "jwt_private_key.json" - - // PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored - // This is a encrypted JSON file containing a key encoded as JWK - PrivateKeyFileEncrypted = "jwt_private_key.json.enc" - // KeyUsageSigning is the usage for the private keys, for the "use" property KeyUsageSigning = "sig" @@ -93,7 +85,7 @@ func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error { // Try loading a key key, err := keyProvider.LoadKey() if err != nil { - return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err) + return fmt.Errorf("failed to load key: %w", err) } // If we have a key, store it in the object and we're done @@ -114,7 +106,7 @@ func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error { // Save the newly-generated key err = keyProvider.SaveKey(s.privateKey) if err != nil { - return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err) + return fmt.Errorf("failed to save private key: %w", err) } return nil diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index 46425b3a..0aa33c5f 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -7,8 +7,6 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" - "os" - "path/filepath" "sync" "testing" "time" @@ -16,49 +14,88 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" - "github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" + "github.com/pocket-id/pocket-id/backend/internal/utils" jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" + testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" ) +const testEncryptionKey = "0123456789abcdef0123456789abcdef" + +func newTestEnvConfig() *common.EnvConfigSchema { + return &common.EnvConfigSchema{ + AppURL: "https://test.example.com", + EncryptionKey: []byte(testEncryptionKey), + } +} + +func initJwtService(t *testing.T, db *gorm.DB, appConfig *AppConfigService, envConfig *common.EnvConfigSchema) *JwtService { + t.Helper() + + service := &JwtService{} + err := service.init(db, appConfig, envConfig) + require.NoError(t, err, "Failed to initialize JWT service") + + return service +} + +func setupJwtService(t *testing.T, appConfig *AppConfigService) (*JwtService, *gorm.DB, *common.EnvConfigSchema) { + t.Helper() + + db := testutils.NewDatabaseForTest(t) + envConfig := newTestEnvConfig() + + return initJwtService(t, db, appConfig, envConfig), db, envConfig +} + +func newTestDbAndEnv(t *testing.T) (*gorm.DB, *common.EnvConfigSchema) { + t.Helper() + + return testutils.NewDatabaseForTest(t), newTestEnvConfig() +} + +func saveKeyToDatabase(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSchema, appConfig *AppConfigService, key jwk.Key) string { + t.Helper() + + keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfig.GetDbConfig().InstanceID.Value) + require.NoError(t, err, "Failed to init key provider") + + err = keyProvider.SaveKey(key) + require.NoError(t, err, "Failed to save key") + + kid, ok := key.KeyID() + require.True(t, ok, "Key ID must be set") + require.NotEmpty(t, kid, "Key ID must not be empty") + + return kid +} + func TestJwtService_Init(t *testing.T) { mockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes }) t.Run("should generate new key when none exists", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } + db := testutils.NewDatabaseForTest(t) + mockEnvConfig := newTestEnvConfig() // Initialize the JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service := initJwtService(t, db, mockConfig, mockEnvConfig) // Verify the private key was set require.NotNil(t, service.privateKey, "Private key should be set") - // Verify the key has been saved to disk as JWK - jwkPath := filepath.Join(tempDir, PrivateKeyFile) - _, err = os.Stat(jwkPath) - require.NoError(t, err, "JWK file should exist") - - // Verify the generated key is valid - keyData, err := os.ReadFile(jwkPath) - require.NoError(t, err) - key, err := jwk.ParseKey(keyData) - require.NoError(t, err) + // Verify the key has been persisted in the database + keyProvider, err := jwkutils.GetKeyProvider(db, mockEnvConfig, mockConfig.GetDbConfig().InstanceID.Value) + require.NoError(t, err, "Failed to init key provider") + key, err := keyProvider.LoadKey() + require.NoError(t, err, "Failed to load key from provider") + require.NotNil(t, key, "Key should be present in the database") // Key should have required properties keyID, ok := key.KeyID() @@ -67,33 +104,22 @@ func TestJwtService_Init(t *testing.T) { keyUsage, ok := key.KeyUsage() assert.True(t, ok, "Key should have a key usage") - assert.Equal(t, "sig", keyUsage) + assert.Equal(t, KeyUsageSigning, keyUsage) }) t.Run("should load existing JWK key", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } + db := testutils.NewDatabaseForTest(t) + mockEnvConfig := newTestEnvConfig() // First create a service to generate a key - firstService := &JwtService{} - err := firstService.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err) + firstService := initJwtService(t, db, mockConfig, mockEnvConfig) // Get the key ID of the first service origKeyID, ok := firstService.privateKey.KeyID() require.True(t, ok) // Now create a new service that should load the existing key - secondService := &JwtService{} - err = secondService.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err) + secondService := initJwtService(t, db, mockConfig, mockEnvConfig) // Verify the loaded key has the same ID as the original loadedKeyID, ok := secondService.privateKey.KeyID() @@ -102,23 +128,14 @@ func TestJwtService_Init(t *testing.T) { }) t.Run("should load existing JWK for ECDSA keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db := testutils.NewDatabaseForTest(t) + mockEnvConfig := newTestEnvConfig() - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - - // Create a new JWK and save it to disk - origKeyID := createECDSAKeyJWK(t, tempDir) + // Create a new JWK and save it to the database + origKeyID := createECDSAKeyJWK(t, db, mockEnvConfig, mockConfig) // Now create a new service that should load the existing key - svc := &JwtService{} - err := svc.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err) + svc := initJwtService(t, db, mockConfig, mockEnvConfig) // Ensure loaded key has the right algorithm alg, ok := svc.privateKey.Algorithm() @@ -132,23 +149,14 @@ func TestJwtService_Init(t *testing.T) { }) t.Run("should load existing JWK for EdDSA keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db := testutils.NewDatabaseForTest(t) + mockEnvConfig := newTestEnvConfig() - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - - // Create a new JWK and save it to disk - origKeyID := createEdDSAKeyJWK(t, tempDir) + // Create a new JWK and save it to the database + origKeyID := createEdDSAKeyJWK(t, db, mockEnvConfig, mockConfig) // Now create a new service that should load the existing key - svc := &JwtService{} - err := svc.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err) + svc := initJwtService(t, db, mockConfig, mockEnvConfig) // Ensure loaded key has the right algorithm and curve alg, ok := svc.privateKey.Algorithm() @@ -156,7 +164,7 @@ func TestJwtService_Init(t *testing.T) { assert.Equal(t, jwa.EdDSA().String(), alg.String(), "Loaded key has the incorrect algorithm") var curve jwa.EllipticCurveAlgorithm - err = svc.privateKey.Get("crv", &curve) + err := svc.privateKey.Get("crv", &curve) _ = assert.NoError(t, err, "Failed to get 'crv' claim") && assert.Equal(t, jwa.Ed25519().String(), curve.String(), "Curve does not match expected value") @@ -173,20 +181,7 @@ func TestJwtService_GetPublicJWK(t *testing.T) { }) t.Run("returns public key when private key is initialized", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - - // Create a JWT service with initialized key - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) // Get the JWK (public key) publicKey, err := service.GetPublicJWK() @@ -211,23 +206,14 @@ func TestJwtService_GetPublicJWK(t *testing.T) { }) t.Run("returns public key when ECDSA private key is initialized", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db := testutils.NewDatabaseForTest(t) + mockEnvConfig := newTestEnvConfig() - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - - // Create an ECDSA key and save it as JWK - originalKeyID := createECDSAKeyJWK(t, tempDir) + // Create an ECDSA key and save it in the database + originalKeyID := createECDSAKeyJWK(t, db, mockEnvConfig, mockConfig) // Create a JWT service that loads the ECDSA key - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service := initJwtService(t, db, mockConfig, mockEnvConfig) // Get the JWK (public key) publicKey, err := service.GetPublicJWK() @@ -256,23 +242,14 @@ func TestJwtService_GetPublicJWK(t *testing.T) { }) t.Run("returns public key when EdDSA private key is initialized", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db := testutils.NewDatabaseForTest(t) + mockEnvConfig := newTestEnvConfig() - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - - // Create an EdDSA key and save it as JWK - originalKeyID := createEdDSAKeyJWK(t, tempDir) + // Create an EdDSA key and save it in the database + originalKeyID := createEdDSAKeyJWK(t, db, mockEnvConfig, mockConfig) // Create a JWT service that loads the EdDSA key - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service := initJwtService(t, db, mockConfig, mockEnvConfig) // Get the JWK (public key) publicKey, err := service.GetPublicJWK() @@ -317,46 +294,26 @@ func TestJwtService_GetPublicJWK(t *testing.T) { } func TestGenerateVerifyAccessToken(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Initialize the JWT service with a mock AppConfigService mockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes }) - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - t.Run("generates token for regular user", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create a test user user := model.User{ - Base: model.Base{ - ID: "user123", - }, + Base: model.Base{ID: "user123"}, Email: utils.Ptr("user@example.com"), IsAdmin: false, } - // Generate a token tokenString, err := service.GenerateAccessToken(user) require.NoError(t, err, "Failed to generate access token") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated token") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -365,9 +322,8 @@ func TestGenerateVerifyAccessToken(t *testing.T) { assert.False(t, isAdmin, "isAdmin should be false") audience, ok := claims.Audience() _ = assert.True(t, ok, "Audience not found in token") && - assert.Equal(t, []string{"https://test.example.com"}, audience, "Audience should contain the app URL") + assert.Equal(t, []string{service.envConfig.AppURL}, audience, "Audience should contain the app URL") - // Check token expiration time is approximately 1 hour from now expectedExp := time.Now().Add(1 * time.Hour) expiration, ok := claims.Expiration() assert.True(t, ok, "Expiration not found in token") @@ -376,29 +332,20 @@ func TestGenerateVerifyAccessToken(t *testing.T) { }) t.Run("generates token for admin user", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create a test admin user adminUser := model.User{ - Base: model.Base{ - ID: "admin123", - }, + Base: model.Base{ID: "admin123"}, Email: utils.Ptr("admin@example.com"), IsAdmin: true, } - // Generate a token tokenString, err := service.GenerateAccessToken(adminUser) require.NoError(t, err, "Failed to generate access token") - // Verify the token claims, err := service.VerifyAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated token") - // Check the IsAdmin claim is true isAdmin, err := GetIsAdmin(claims) _ = assert.NoError(t, err, "Failed to get isAdmin claim") && assert.True(t, isAdmin, "isAdmin should be true") @@ -408,31 +355,21 @@ func TestGenerateVerifyAccessToken(t *testing.T) { }) t.Run("uses session duration from config", func(t *testing.T) { - // Create a JWT service with a different session duration customMockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes }) + service, _, _ := setupJwtService(t, customMockConfig) - service := &JwtService{} - err := service.init(nil, customMockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") - - // Create a test user user := model.User{ - Base: model.Base{ - ID: "user456", - }, + Base: model.Base{ID: "user456"}, } - // Generate a token tokenString, err := service.GenerateAccessToken(user) require.NoError(t, err, "Failed to generate access token") - // Verify the token claims, err := service.VerifyAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated token") - // Check token expiration time is approximately 30 minutes from now expectedExp := time.Now().Add(30 * time.Minute) expiration, ok := claims.Expiration() assert.True(t, ok, "Expiration not found in token") @@ -441,44 +378,27 @@ func TestGenerateVerifyAccessToken(t *testing.T) { }) t.Run("works with Ed25519 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createEdDSAKeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an Ed25519 key and save it as JWK - origKeyID := createEdDSAKeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create a test user user := model.User{ - Base: model.Base{ - ID: "eddsauser123", - }, + Base: model.Base{ID: "eddsauser123"}, Email: utils.Ptr("eddsauser@example.com"), IsAdmin: true, } - // Generate a token tokenString, err := service.GenerateAccessToken(user) require.NoError(t, err, "Failed to generate access token with Ed25519 key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated token with Ed25519 key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -486,56 +406,36 @@ func TestGenerateVerifyAccessToken(t *testing.T) { _ = assert.NoError(t, err, "Failed to get isAdmin claim") && assert.True(t, isAdmin, "isAdmin should be true") - // Verify the key type is OKP publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, "OKP", publicKey.KeyType().String(), "Key type should be OKP") - - // Verify the algorithm is EdDSA alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, "EdDSA", alg.String(), "Algorithm should be EdDSA") }) t.Run("works with P-256 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createECDSAKeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an ECDSA key and save it as JWK - origKeyID := createECDSAKeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create a test user user := model.User{ - Base: model.Base{ - ID: "ecdsauser123", - }, + Base: model.Base{ID: "ecdsauser123"}, Email: utils.Ptr("ecdsauser@example.com"), IsAdmin: true, } - // Generate a token tokenString, err := service.GenerateAccessToken(user) require.NoError(t, err, "Failed to generate access token with ECDSA key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated token with ECDSA key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -543,56 +443,36 @@ func TestGenerateVerifyAccessToken(t *testing.T) { _ = assert.NoError(t, err, "Failed to get isAdmin claim") && assert.True(t, isAdmin, "isAdmin should be true") - // Verify the key type is EC publicKey, err := service.GetPublicJWK() require.NoError(t, err) - assert.Equal(t, jwa.EC().String(), publicKey.KeyType().String(), "Key type should be EC") - - // Verify the algorithm is ES256 + assert.Equal(t, "EC", publicKey.KeyType().String(), "Key type should be EC") alg, ok := publicKey.Algorithm() require.True(t, ok) - assert.Equal(t, jwa.ES256().String(), alg.String(), "Algorithm should be ES256") + assert.Equal(t, "ES256", alg.String(), "Algorithm should be ES256") }) t.Run("works with RSA-4096 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createRSA4096KeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an RSA-4096 key and save it as JWK - origKeyID := createRSA4096KeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create a test user user := model.User{ - Base: model.Base{ - ID: "rsauser123", - }, + Base: model.Base{ID: "rsauser123"}, Email: utils.Ptr("rsauser@example.com"), IsAdmin: true, } - // Generate a token tokenString, err := service.GenerateAccessToken(user) require.NoError(t, err, "Failed to generate access token with RSA key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated token with RSA key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -600,12 +480,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) { _ = assert.NoError(t, err, "Failed to get isAdmin claim") && assert.True(t, isAdmin, "isAdmin should be true") - // Verify the key type is RSA publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA") - - // Verify the algorithm is RS256 alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256") @@ -613,28 +490,13 @@ func TestGenerateVerifyAccessToken(t *testing.T) { } func TestGenerateVerifyIdToken(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Initialize the JWT service with a mock AppConfigService mockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes }) - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - t.Run("generates and verifies ID token with standard claims", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create test claims userClaims := map[string]interface{}{ "sub": "user123", "name": "Test User", @@ -642,16 +504,13 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "test-client-123" - // Generate a token tokenString, err := service.GenerateIDToken(userClaims, clientID, "") require.NoError(t, err, "Failed to generate ID token") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyIdToken(tokenString, false) require.NoError(t, err, "Failed to verify generated ID token") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, "user123", subject, "Token subject should match user ID") @@ -662,7 +521,6 @@ func TestGenerateVerifyIdToken(t *testing.T) { _ = assert.True(t, ok, "Issuer not found in token") && assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") - // Check token expiration time is approximately 1 hour from now expectedExp := time.Now().Add(1 * time.Hour) expiration, ok := claims.Expiration() assert.True(t, ok, "Expiration not found in token") @@ -671,12 +529,8 @@ func TestGenerateVerifyIdToken(t *testing.T) { }) t.Run("can accept expired tokens if told so", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create test claims userClaims := map[string]interface{}{ "sub": "user123", "name": "Test User", @@ -684,42 +538,36 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "test-client-123" - // Create a token that's already expired token, err := jwt.NewBuilder(). Subject(userClaims["sub"].(string)). Issuer(service.envConfig.AppURL). Audience([]string{clientID}). IssuedAt(time.Now().Add(-2 * time.Hour)). - Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago + Expiration(time.Now().Add(-1 * time.Hour)). Build() require.NoError(t, err, "Failed to build token") err = SetTokenType(token, IDTokenJWTType) require.NoError(t, err, "Failed to set token type") - // Add custom claims for k, v := range userClaims { - if k != "sub" { // Already set above + if k != "sub" { err = token.Set(k, v) require.NoError(t, err, "Failed to set claim") } } - // Sign the token signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) require.NoError(t, err, "Failed to sign token") tokenString := string(signed) - // Verify the token without allowExpired flag - should fail _, err = service.VerifyIdToken(tokenString, false) require.Error(t, err, "Verification should fail with expired token when not allowing expired tokens") - assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure") + assert.Contains(t, err.Error(), "\"exp\" not satisfied", "Error message should indicate token verification failure") - // Verify the token with allowExpired flag - should succeed claims, err := service.VerifyIdToken(tokenString, true) require.NoError(t, err, "Verification should succeed with expired token when allowing expired tokens") - // Validate the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID") @@ -729,12 +577,8 @@ func TestGenerateVerifyIdToken(t *testing.T) { }) t.Run("generates and verifies ID token with nonce", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create test claims with nonce userClaims := map[string]interface{}{ "sub": "user456", "name": "Another User", @@ -742,11 +586,9 @@ func TestGenerateVerifyIdToken(t *testing.T) { const clientID = "test-client-456" nonce := "random-nonce-value" - // Generate a token with nonce tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce) require.NoError(t, err, "Failed to generate ID token with nonce") - // Parse the token manually to check nonce publicKey, err := service.GetPublicJWK() require.NoError(t, err, "Failed to get public key") token, err := jwt.Parse([]byte(tokenString), jwt.WithKey(jwa.RS256(), publicKey)) @@ -760,48 +602,30 @@ func TestGenerateVerifyIdToken(t *testing.T) { }) t.Run("fails verification with incorrect issuer", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Generate a token with standard claims userClaims := map[string]interface{}{ "sub": "user789", } tokenString, err := service.GenerateIDToken(userClaims, "client-789", "") require.NoError(t, err, "Failed to generate ID token") - // Temporarily change the app URL to simulate wrong issuer service.envConfig.AppURL = "https://wrong-issuer.com" - // Verify should fail due to issuer mismatch _, err = service.VerifyIdToken(tokenString, false) require.Error(t, err, "Verification should fail with incorrect issuer") - assert.Contains(t, err.Error(), `"iss" not satisfied`, "Error message should indicate token verification failure") + assert.Contains(t, err.Error(), "\"iss\" not satisfied", "Error message should indicate token verification failure") }) t.Run("works with Ed25519 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createEdDSAKeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an Ed25519 key and save it as JWK - origKeyID := createEdDSAKeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create test claims userClaims := map[string]interface{}{ "sub": "eddsauser456", "name": "EdDSA User", @@ -809,16 +633,13 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "eddsa-client-123" - // Generate a token tokenString, err := service.GenerateIDToken(userClaims, clientID, "") require.NoError(t, err, "Failed to generate ID token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyIdToken(tokenString, false) require.NoError(t, err, "Failed to verify generated ID token with key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID") @@ -826,54 +647,36 @@ func TestGenerateVerifyIdToken(t *testing.T) { _ = assert.True(t, ok, "Issuer not found in token") && assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") - // Verify the key type is OKP publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, jwa.OKP().String(), publicKey.KeyType().String(), "Key type should be OKP") - - // Verify the algorithm is EdDSA alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, jwa.EdDSA().String(), alg.String(), "Algorithm should be EdDSA") }) t.Run("works with P-256 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createECDSAKeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an ECDSA key and save it as JWK - origKeyID := createECDSAKeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create test claims userClaims := map[string]interface{}{ "sub": "ecdsauser456", "email": "ecdsauser@example.com", } const clientID = "ecdsa-client-123" - // Generate a token tokenString, err := service.GenerateIDToken(userClaims, clientID, "") require.NoError(t, err, "Failed to generate ID token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyIdToken(tokenString, false) require.NoError(t, err, "Failed to verify generated ID token with key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID") @@ -881,38 +684,23 @@ func TestGenerateVerifyIdToken(t *testing.T) { _ = assert.True(t, ok, "Issuer not found in token") && assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") - // Verify the key type is EC publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, jwa.EC().String(), publicKey.KeyType().String(), "Key type should be EC") - - // Verify the algorithm is ES256 alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, jwa.ES256().String(), alg.String(), "Algorithm should be ES256") }) t.Run("works with RSA-4096 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createRSA4096KeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an RSA-4096 key and save it as JWK - origKeyID := createRSA4096KeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create test claims userClaims := map[string]interface{}{ "sub": "rsauser456", "name": "RSA User", @@ -920,16 +708,13 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "rsa-client-123" - // Generate a token tokenString, err := service.GenerateIDToken(userClaims, clientID, "") require.NoError(t, err, "Failed to generate ID token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyIdToken(tokenString, false) require.NoError(t, err, "Failed to verify generated ID token with key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, "rsauser456", subject, "Token subject should match user ID") @@ -940,46 +725,26 @@ func TestGenerateVerifyIdToken(t *testing.T) { } func TestGenerateVerifyOAuthAccessToken(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Initialize the JWT service with a mock AppConfigService mockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes }) - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create a test user user := model.User{ - Base: model.Base{ - ID: "user123", - }, + Base: model.Base{ID: "user123"}, Email: utils.Ptr("user@example.com"), } const clientID = "test-client-123" - // Generate a token tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -990,7 +755,6 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { _ = assert.True(t, ok, "Issuer not found in token") && assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL") - // Check token expiration time is approximately 1 hour from now expectedExp := time.Now().Add(1 * time.Hour) expiration, ok := claims.Expiration() assert.True(t, ok, "Expiration not found in token") @@ -999,23 +763,14 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { }) t.Run("fails verification for expired token", func(t *testing.T) { - // Create a JWT service with a mock function to generate an expired token - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create a test user - user := model.User{ - Base: model.Base{ - ID: "user456", - }, - } + user := model.User{Base: model.Base{ID: "user456"}} const clientID = "test-client-456" - // Generate a token using JWT directly to create an expired token token, err := jwt.NewBuilder(). Subject(user.ID). - Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago + Expiration(time.Now().Add(-1 * time.Hour)). IssuedAt(time.Now().Add(-2 * time.Hour)). Audience([]string{clientID}). Issuer(service.envConfig.AppURL). @@ -1028,85 +783,48 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) require.NoError(t, err, "Failed to sign token") - // Verify should fail due to expiration _, err = service.VerifyOAuthAccessToken(string(signed)) require.Error(t, err, "Verification should fail with expired token") - assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure") + assert.Contains(t, err.Error(), "\"exp\" not satisfied", "Error message should indicate token verification failure") }) t.Run("fails verification with invalid signature", func(t *testing.T) { - // Create two JWT services with different keys - service1 := &JwtService{} - err := service1.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: t.TempDir(), // Use a different temp dir - }) - require.NoError(t, err, "Failed to initialize first JWT service") + service1, _, _ := setupJwtService(t, mockConfig) + service2, _, _ := setupJwtService(t, mockConfig) - service2 := &JwtService{} - err = service2.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: t.TempDir(), // Use a different temp dir - }) - require.NoError(t, err, "Failed to initialize second JWT service") - - // Create a test user - user := model.User{ - Base: model.Base{ - ID: "user789", - }, - } + user := model.User{Base: model.Base{ID: "user789"}} const clientID = "test-client-789" - // Generate a token with the first service tokenString, err := service1.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token") - // Verify with the second service should fail due to different keys _, err = service2.VerifyOAuthAccessToken(tokenString) require.Error(t, err, "Verification should fail with invalid signature") assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure") }) t.Run("works with Ed25519 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createEdDSAKeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an Ed25519 key and save it as JWK - origKeyID := createEdDSAKeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create a test user user := model.User{ - Base: model.Base{ - ID: "eddsauser789", - }, + Base: model.Base{ID: "eddsauser789"}, Email: utils.Ptr("eddsaoauth@example.com"), } const clientID = "eddsa-oauth-client" - // Generate a token tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token with key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -1114,56 +832,36 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { _ = assert.True(t, ok, "Audience not found in token") && assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID") - // Verify the key type is OKP publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, jwa.OKP().String(), publicKey.KeyType().String(), "Key type should be OKP") - - // Verify the algorithm is EdDSA alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, jwa.EdDSA().String(), alg.String(), "Algorithm should be EdDSA") }) t.Run("works with ECDSA keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + db, envConfig := newTestDbAndEnv(t) + origKeyID := createECDSAKeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an ECDSA key and save it as JWK - origKeyID := createECDSAKeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create a test user user := model.User{ - Base: model.Base{ - ID: "ecdsauser789", - }, + Base: model.Base{ID: "ecdsauser789"}, Email: utils.Ptr("ecdsaoauth@example.com"), } const clientID = "ecdsa-oauth-client" - // Generate a token tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token with key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -1171,56 +869,36 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { _ = assert.True(t, ok, "Audience not found in token") && assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID") - // Verify the key type is EC publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, jwa.EC().String(), publicKey.KeyType().String(), "Key type should be EC") - - // Verify the algorithm is ES256 alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, jwa.ES256().String(), alg.String(), "Algorithm should be ES256") }) - t.Run("works with RSA-4096 keys", func(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() + t.Run("works with RSA keys", func(t *testing.T) { + db, envConfig := newTestDbAndEnv(t) + origKeyID := createRSA4096KeyJWK(t, db, envConfig, mockConfig) + service := initJwtService(t, db, mockConfig, envConfig) - // Create an RSA-4096 key and save it as JWK - origKeyID := createRSA4096KeyJWK(t, tempDir) - - // Create a JWT service that loads the key - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") - - // Verify it loaded the right key loadedKeyID, ok := service.privateKey.KeyID() require.True(t, ok) assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") - // Create a test user user := model.User{ - Base: model.Base{ - ID: "rsauser789", - }, + Base: model.Base{ID: "rsauser789"}, Email: utils.Ptr("rsaoauth@example.com"), } const clientID = "rsa-oauth-client" - // Generate a token tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token with key") - // Check the claims subject, ok := claims.Subject() _ = assert.True(t, ok, "User ID not found in token") && assert.Equal(t, user.ID, subject, "Token subject should match user ID") @@ -1228,12 +906,9 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { _ = assert.True(t, ok, "Audience not found in token") && assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID") - // Verify the key type is RSA publicKey, err := service.GetPublicJWK() require.NoError(t, err) assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA") - - // Verify the algorithm is RS256 alg, ok := publicKey.Algorithm() require.True(t, ok) assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256") @@ -1241,38 +916,21 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { } func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Initialize the JWT service with a mock AppConfigService mockConfig := NewTestAppConfigService(&model.AppConfig{}) - // Setup the environment variable required by the token verification - mockEnvConfig := &common.EnvConfigSchema{ - AppURL: "https://test.example.com", - KeysStorage: "file", - KeysPath: tempDir, - } - t.Run("generates and verifies refresh token", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Create a test user const ( userID = "user123" clientID = "client123" refreshToken = "rt-123" ) - // Generate a token tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken) require.NoError(t, err, "Failed to generate refresh token") assert.NotEmpty(t, tokenString, "Token should not be empty") - // Verify the token resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString) require.NoError(t, err, "Failed to verify generated token") assert.Equal(t, userID, resUser, "Should return correct user ID") @@ -1281,15 +939,11 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { }) t.Run("fails verification for expired token", func(t *testing.T) { - // Create a JWT service - service := &JwtService{} - err := service.init(nil, mockConfig, mockEnvConfig) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) - // Generate a token using JWT directly to create an expired token token, err := jwt.NewBuilder(). Subject("user789"). - Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago + Expiration(time.Now().Add(-1 * time.Hour)). IssuedAt(time.Now().Add(-2 * time.Hour)). Audience([]string{"client123"}). Issuer(service.envConfig.AppURL). @@ -1299,33 +953,18 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) require.NoError(t, err, "Failed to sign token") - // Verify should fail due to expiration _, _, _, err = service.VerifyOAuthRefreshToken(string(signed)) require.Error(t, err, "Verification should fail with expired token") - assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure") + assert.Contains(t, err.Error(), "\"exp\" not satisfied", "Error message should indicate token verification failure") }) t.Run("fails verification with invalid signature", func(t *testing.T) { - // Create two JWT services with different keys - service1 := &JwtService{} - err := service1.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: t.TempDir(), // Use a different temp dir - }) - require.NoError(t, err, "Failed to initialize first JWT service") + service1, _, _ := setupJwtService(t, mockConfig) + service2, _, _ := setupJwtService(t, mockConfig) - service2 := &JwtService{} - err = service2.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: t.TempDir(), // Use a different temp dir - }) - require.NoError(t, err, "Failed to initialize second JWT service") - - // Generate a token with the first service tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123") require.NoError(t, err, "Failed to generate refresh token") - // Verify with the second service should fail due to different keys _, _, _, err = service2.VerifyOAuthRefreshToken(tokenString) require.Error(t, err, "Verification should fail with invalid signature") assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure") @@ -1380,17 +1019,8 @@ func TestTokenTypeValidator(t *testing.T) { } func TestGetTokenType(t *testing.T) { - // Create a temporary directory for the test - tempDir := t.TempDir() - - // Initialize the JWT service mockConfig := NewTestAppConfigService(&model.AppConfig{}) - service := &JwtService{} - err := service.init(nil, mockConfig, &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: tempDir, - }) - require.NoError(t, err, "Failed to initialize JWT service") + service, _, _ := setupJwtService(t, mockConfig) buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string { t.Helper() @@ -1417,91 +1047,35 @@ func TestGetTokenType(t *testing.T) { t.Run("correctly identifies access tokens", func(t *testing.T) { tokenString := buildTokenForType(t, AccessTokenJWTType, nil) - // Get the token type without validating tokenType, _, err := service.GetTokenType(tokenString) require.NoError(t, err, "GetTokenType should not return an error") - assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token") + assert.Equal(t, AccessTokenJWTType, tokenType, "Should identify access token type") }) t.Run("correctly identifies ID tokens", func(t *testing.T) { tokenString := buildTokenForType(t, IDTokenJWTType, nil) - // Get the token type without validating tokenType, _, err := service.GetTokenType(tokenString) require.NoError(t, err, "GetTokenType should not return an error") - assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token") + assert.Equal(t, IDTokenJWTType, tokenType, "Should identify ID token type") }) - t.Run("correctly identifies OAuth access tokens", func(t *testing.T) { - tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil) - - // Get the token type without validating - tokenType, _, err := service.GetTokenType(tokenString) - require.NoError(t, err, "GetTokenType should not return an error") - assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token") - }) - - t.Run("correctly identifies refresh tokens", func(t *testing.T) { - tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil) - - // Get the token type without validating - tokenType, _, err := service.GetTokenType(tokenString) - require.NoError(t, err, "GetTokenType should not return an error") - assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token") - }) - - t.Run("works with expired tokens", func(t *testing.T) { - tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) { - b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago - }) - - // Get the token type without validating - tokenType, _, err := service.GetTokenType(tokenString) - require.NoError(t, err, "GetTokenType should not return an error for expired tokens") - assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens") - }) - - t.Run("returns error for malformed tokens", func(t *testing.T) { - // Try to get the token type of a malformed token - tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token") - require.Error(t, err, "GetTokenType should return an error for malformed tokens") - assert.Empty(t, tokenType, "Token type should be empty for malformed tokens") - }) - - t.Run("returns error for tokens without type claim", func(t *testing.T) { - // Create a token without type claim + t.Run("fails when token type claim is missing", func(t *testing.T) { tokenString := buildTokenForType(t, "", nil) - // Get the token type without validating - tokenType, _, err := service.GetTokenType(tokenString) + _, _, err := service.GetTokenType(tokenString) require.Error(t, err, "GetTokenType should return an error for tokens without type claim") - assert.Empty(t, tokenType, "Token type should be empty when type claim is missing") assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim") }) } -func importKey(t *testing.T, privateKeyRaw any, path string) string { +func importKey(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSchema, appConfig *AppConfigService, privateKeyRaw any) string { t.Helper() privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "") require.NoError(t, err, "Failed to import private key") - keyProvider := &jwkutils.KeyProviderFile{} - err = keyProvider.Init(jwkutils.KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysStorage: "file", - KeysPath: path, - }, - }) - require.NoError(t, err, "Failed to init file key provider") - - err = keyProvider.SaveKey(privateKey) - require.NoError(t, err, "Failed to save key") - - kid, _ := privateKey.KeyID() - require.NotEmpty(t, kid, "Key ID must be set") - - return kid + return saveKeyToDatabase(t, db, envConfig, appConfig, privateKey) } // Because generating a RSA-406 key isn't immediate, we pre-compute one @@ -1510,7 +1084,7 @@ var ( rsaKeyPrecomputeOnce sync.Once ) -func createRSA4096KeyJWK(t *testing.T, path string) string { +func createRSA4096KeyJWK(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSchema, appConfig *AppConfigService) string { t.Helper() rsaKeyPrecomputeOnce.Do(func() { @@ -1521,29 +1095,29 @@ func createRSA4096KeyJWK(t *testing.T, path string) string { } }) - // Import as JWK and save to disk - return importKey(t, rsaKeyPrecomputed, path) + // Import as JWK and save it + return importKey(t, db, envConfig, appConfig, rsaKeyPrecomputed) } -func createECDSAKeyJWK(t *testing.T, path string) string { +func createECDSAKeyJWK(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSchema, appConfig *AppConfigService) string { t.Helper() // Generate a new P-256 ECDSA key privateKeyRaw, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err, "Failed to generate ECDSA key") - // Import as JWK and save to disk - return importKey(t, privateKeyRaw, path) + // Import as JWK and save it + return importKey(t, db, envConfig, appConfig, privateKeyRaw) } // Helper function to create an Ed25519 key and save it as JWK -func createEdDSAKeyJWK(t *testing.T, path string) string { +func createEdDSAKeyJWK(t *testing.T, db *gorm.DB, envConfig *common.EnvConfigSchema, appConfig *AppConfigService) string { t.Helper() // Generate a new Ed25519 key pair _, privateKeyRaw, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err, "Failed to generate Ed25519 key") - // Import as JWK and save to disk - return importKey(t, privateKeyRaw, path) + // Import as JWK and save it + return importKey(t, db, envConfig, appConfig, privateKeyRaw) } diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 0f4b5b6c..7c8ddfc5 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -148,6 +148,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { var err error // Create a test database db := testutils.NewDatabaseForTest(t) + common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef") // Create two JWKs for testing privateJWK, jwkSetJSON := generateTestECDSAKey(t) diff --git a/backend/internal/utils/jwk/key_provider.go b/backend/internal/utils/jwk/key_provider.go index 46da3f3f..0cc32c8b 100644 --- a/backend/internal/utils/jwk/key_provider.go +++ b/backend/internal/utils/jwk/key_provider.go @@ -28,22 +28,14 @@ func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID s return nil, fmt.Errorf("failed to load encryption key: %w", err) } - // Get the key provider - switch envConfig.KeysStorage { - case "file", "": - keyProvider = &KeyProviderFile{} - case "database": - keyProvider = &KeyProviderDatabase{} - default: - return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage) - } + keyProvider = &KeyProviderDatabase{} err = keyProvider.Init(KeyProviderOpts{ DB: db, EnvConfig: envConfig, Kek: kek, }) if err != nil { - return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err) + return nil, fmt.Errorf("failed to init key provider: %w", err) } return keyProvider, nil diff --git a/backend/internal/utils/jwk/key_provider_file.go b/backend/internal/utils/jwk/key_provider_file.go deleted file mode 100644 index b8f2b07f..00000000 --- a/backend/internal/utils/jwk/key_provider_file.go +++ /dev/null @@ -1,202 +0,0 @@ -package jwk - -import ( - "encoding/base64" - "fmt" - "os" - "path/filepath" - - "github.com/lestrrat-go/jwx/v3/jwk" - - "github.com/pocket-id/pocket-id/backend/internal/common" - "github.com/pocket-id/pocket-id/backend/internal/utils" - cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto" -) - -const ( - // PrivateKeyFile is the path in the data/keys folder where the key is stored - // This is a JSON file containing a key encoded as JWK - PrivateKeyFile = "jwt_private_key.json" - - // PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored - // This is a encrypted JSON file containing a key encoded as JWK - PrivateKeyFileEncrypted = "jwt_private_key.json.enc" -) - -type KeyProviderFile struct { - envConfig *common.EnvConfigSchema - kek []byte -} - -func (f *KeyProviderFile) Init(opts KeyProviderOpts) error { - f.envConfig = opts.EnvConfig - f.kek = opts.Kek - - return nil -} - -func (f *KeyProviderFile) LoadKey() (jwk.Key, error) { - if len(f.kek) > 0 { - return f.loadEncryptedKey() - } - return f.loadKey() -} - -func (f *KeyProviderFile) SaveKey(key jwk.Key) error { - if len(f.kek) > 0 { - return f.saveKeyEncrypted(key) - } - return f.saveKey(key) -} - -func (f *KeyProviderFile) loadKey() (jwk.Key, error) { - var key jwk.Key - - // First, check if we have a JWK file - // If we do, then we just load that - jwkPath := f.jwkPath() - ok, err := utils.FileExists(jwkPath) - if err != nil { - return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err) - } - if !ok { - // File doesn't exist, no key was loaded - return nil, nil - } - - data, err := os.ReadFile(jwkPath) - if err != nil { - return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err) - } - - key, err = jwk.ParseKey(data) - if err != nil { - return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err) - } - - return key, nil -} - -func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) { - // First, check if we have an encrypted JWK file - // If we do, then we just load that - encJwkPath := f.encJwkPath() - ok, err := utils.FileExists(encJwkPath) - if err != nil { - return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err) - } - if ok { - encB64, err := os.ReadFile(encJwkPath) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err) - } - - // Decode from base64 - enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64))) - n, err := base64.StdEncoding.Decode(enc, encB64) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err) - } - - // Decrypt the data - data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil) - if err != nil { - return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err) - } - - // Parse the key - key, err = jwk.ParseKey(data) - if err != nil { - return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err) - } - - return key, nil - } - - // Check if we have an un-encrypted JWK file - key, err = f.loadKey() - if err != nil { - return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err) - } - if key == nil { - // No key exists, encrypted or un-encrypted - return nil, nil - } - - // If we are here, we have loaded a key that was un-encrypted - // We need to replace the plaintext key with the encrypted one before we return - err = f.saveKeyEncrypted(key) - if err != nil { - return nil, fmt.Errorf("failed to save encrypted key file: %w", err) - } - jwkPath := f.jwkPath() - err = os.Remove(jwkPath) - if err != nil { - return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err) - } - - return key, nil -} - -func (f *KeyProviderFile) saveKey(key jwk.Key) error { - err := os.MkdirAll(f.envConfig.KeysPath, 0700) - if err != nil { - return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err) - } - - jwkPath := f.jwkPath() - keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err) - } - defer keyFile.Close() - - // Write the JSON file to disk - err = EncodeJWK(keyFile, key) - if err != nil { - return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err) - } - - return nil -} - -func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error { - err := os.MkdirAll(f.envConfig.KeysPath, 0700) - if err != nil { - return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err) - } - - // Encode the key to JSON - data, err := EncodeJWKBytes(key) - if err != nil { - return fmt.Errorf("failed to encode key to JSON: %w", err) - } - - // Encrypt the key then encode to Base64 - enc, err := cryptoutils.Encrypt(f.kek, data, nil) - if err != nil { - return fmt.Errorf("failed to encrypt key: %w", err) - } - encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc))) - base64.StdEncoding.Encode(encB64, enc) - - // Write to disk - encJwkPath := f.encJwkPath() - err = os.WriteFile(encJwkPath, encB64, 0600) - if err != nil { - return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err) - } - - return nil -} - -func (f *KeyProviderFile) jwkPath() string { - return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile) -} - -func (f *KeyProviderFile) encJwkPath() string { - return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted) -} - -// Compile-time interface check -var _ KeyProvider = (*KeyProviderFile)(nil) diff --git a/backend/internal/utils/jwk/key_provider_file_test.go b/backend/internal/utils/jwk/key_provider_file_test.go deleted file mode 100644 index 768dbee2..00000000 --- a/backend/internal/utils/jwk/key_provider_file_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package jwk - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "encoding/base64" - "os" - "path/filepath" - "testing" - - "github.com/lestrrat-go/jwx/v3/jwk" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/pocket-id/pocket-id/backend/internal/common" - "github.com/pocket-id/pocket-id/backend/internal/utils" - cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto" -) - -func TestKeyProviderFile_LoadKey(t *testing.T) { - // Generate a test key to use in our tests - pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - key, err := jwk.Import(pk) - require.NoError(t, err) - - t.Run("LoadKey with no existing key", func(t *testing.T) { - tempDir := t.TempDir() - - provider := &KeyProviderFile{} - err := provider.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - }) - require.NoError(t, err) - - // Load key when none exists - loadedKey, err := provider.LoadKey() - require.NoError(t, err) - assert.Nil(t, loadedKey, "Expected nil key when no key exists") - }) - - t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) { - tempDir := t.TempDir() - - provider := &KeyProviderFile{} - err = provider.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - Kek: makeKEK(t), - }) - require.NoError(t, err) - - // Load key when none exists - loadedKey, err := provider.LoadKey() - require.NoError(t, err) - assert.Nil(t, loadedKey, "Expected nil key when no key exists") - }) - - t.Run("LoadKey with unencrypted key", func(t *testing.T) { - tempDir := t.TempDir() - - provider := &KeyProviderFile{} - err := provider.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - }) - require.NoError(t, err) - - // Save a key - err = provider.SaveKey(key) - require.NoError(t, err) - - // Make sure the key file exists - keyPath := filepath.Join(tempDir, PrivateKeyFile) - exists, err := utils.FileExists(keyPath) - require.NoError(t, err) - assert.True(t, exists, "Expected key file to exist") - - // Load the key - loadedKey, err := provider.LoadKey() - require.NoError(t, err) - assert.NotNil(t, loadedKey, "Expected non-nil key when key exists") - - // Verify the loaded key is the same as the original - keyBytes, err := EncodeJWKBytes(key) - require.NoError(t, err) - - loadedKeyBytes, err := EncodeJWKBytes(loadedKey) - require.NoError(t, err) - - assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key") - }) - - t.Run("LoadKey with encrypted key", func(t *testing.T) { - tempDir := t.TempDir() - - provider := &KeyProviderFile{} - err = provider.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - Kek: makeKEK(t), - }) - require.NoError(t, err) - - // Save a key (will be encrypted) - err = provider.SaveKey(key) - require.NoError(t, err) - - // Make sure the encrypted key file exists - encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted) - exists, err := utils.FileExists(encKeyPath) - require.NoError(t, err) - assert.True(t, exists, "Expected encrypted key file to exist") - - // Make sure the unencrypted key file does not exist - keyPath := filepath.Join(tempDir, PrivateKeyFile) - exists, err = utils.FileExists(keyPath) - require.NoError(t, err) - assert.False(t, exists, "Expected unencrypted key file to not exist") - - // Load the key - loadedKey, err := provider.LoadKey() - require.NoError(t, err) - assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists") - - // Verify the loaded key is the same as the original - keyBytes, err := EncodeJWKBytes(key) - require.NoError(t, err) - - loadedKeyBytes, err := EncodeJWKBytes(loadedKey) - require.NoError(t, err) - - assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key") - }) - - t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) { - tempDir := t.TempDir() - - // First, create an unencrypted key - providerNoKek := &KeyProviderFile{} - err := providerNoKek.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - }) - require.NoError(t, err) - - // Save an unencrypted key - err = providerNoKek.SaveKey(key) - require.NoError(t, err) - - // Verify unencrypted key exists - keyPath := filepath.Join(tempDir, PrivateKeyFile) - exists, err := utils.FileExists(keyPath) - require.NoError(t, err) - assert.True(t, exists, "Expected unencrypted key file to exist") - - // Now create a provider with a kek - kek := make([]byte, 32) - _, err = rand.Read(kek) - require.NoError(t, err) - - providerWithKek := &KeyProviderFile{} - err = providerWithKek.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - Kek: kek, - }) - require.NoError(t, err) - - // Load the key - this should convert the unencrypted key to encrypted - loadedKey, err := providerWithKek.LoadKey() - require.NoError(t, err) - assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key") - - // Verify the unencrypted key no longer exists - exists, err = utils.FileExists(keyPath) - require.NoError(t, err) - assert.False(t, exists, "Expected unencrypted key file to be removed") - - // Verify the encrypted key file exists - encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted) - exists, err = utils.FileExists(encKeyPath) - require.NoError(t, err) - assert.True(t, exists, "Expected encrypted key file to exist after conversion") - - // Verify the key data - keyBytes, err := EncodeJWKBytes(key) - require.NoError(t, err) - - loadedKeyBytes, err := EncodeJWKBytes(loadedKey) - require.NoError(t, err) - - assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion") - }) -} - -func TestKeyProviderFile_SaveKey(t *testing.T) { - // Generate a test key to use in our tests - pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - key, err := jwk.Import(pk) - require.NoError(t, err) - - t.Run("SaveKey unencrypted", func(t *testing.T) { - tempDir := t.TempDir() - - provider := &KeyProviderFile{} - err := provider.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - }) - require.NoError(t, err) - - // Save the key - err = provider.SaveKey(key) - require.NoError(t, err) - - // Verify the key file exists - keyPath := filepath.Join(tempDir, PrivateKeyFile) - exists, err := utils.FileExists(keyPath) - require.NoError(t, err) - assert.True(t, exists, "Expected key file to exist") - - // Verify the content of the key file - data, err := os.ReadFile(keyPath) - require.NoError(t, err) - - parsedKey, err := jwk.ParseKey(data) - require.NoError(t, err) - - // Compare the saved key with the original - keyBytes, err := EncodeJWKBytes(key) - require.NoError(t, err) - - parsedKeyBytes, err := EncodeJWKBytes(parsedKey) - require.NoError(t, err) - - assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key") - }) - - t.Run("SaveKey encrypted", func(t *testing.T) { - tempDir := t.TempDir() - - // Generate a 64-byte kek - kek := makeKEK(t) - - provider := &KeyProviderFile{} - err = provider.Init(KeyProviderOpts{ - EnvConfig: &common.EnvConfigSchema{ - KeysPath: tempDir, - }, - Kek: kek, - }) - require.NoError(t, err) - - // Save the key (will be encrypted) - err = provider.SaveKey(key) - require.NoError(t, err) - - // Verify the encrypted key file exists - encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted) - exists, err := utils.FileExists(encKeyPath) - require.NoError(t, err) - assert.True(t, exists, "Expected encrypted key file to exist") - - // Verify the unencrypted key file doesn't exist - keyPath := filepath.Join(tempDir, PrivateKeyFile) - exists, err = utils.FileExists(keyPath) - require.NoError(t, err) - assert.False(t, exists, "Expected unencrypted key file to not exist") - - // Manually decrypt the encrypted key file to verify it contains the correct key - encB64, err := os.ReadFile(encKeyPath) - require.NoError(t, err) - - // Decode from base64 - enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64))) - n, err := base64.StdEncoding.Decode(enc, encB64) - require.NoError(t, err) - enc = enc[:n] // Trim any padding - - // Decrypt the data - data, err := cryptoutils.Decrypt(kek, enc, nil) - require.NoError(t, err) - - // Parse the key - parsedKey, err := jwk.ParseKey(data) - require.NoError(t, err) - - // Compare the decrypted key with the original - keyBytes, err := EncodeJWKBytes(key) - require.NoError(t, err) - - parsedKeyBytes, err := EncodeJWKBytes(parsedKey) - require.NoError(t, err) - - assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key") - }) -} - -func makeKEK(t *testing.T) []byte { - t.Helper() - - // Generate a 32-byte kek - kek := make([]byte, 32) - _, err := rand.Read(kek) - require.NoError(t, err) - return kek -} diff --git a/tests/setup/docker-compose-s3.yml b/tests/setup/docker-compose-s3.yml index 8159ee0b..159c1c8a 100644 --- a/tests/setup/docker-compose-s3.yml +++ b/tests/setup/docker-compose-s3.yml @@ -28,14 +28,13 @@ services: file: docker-compose.yml service: pocket-id environment: - - S3_BUCKET=pocket-id-test - - S3_REGION=us-east-1 - - S3_ENDPOINT=http://localstack-s3:4566 - - S3_ACCESS_KEY_ID=test - - S3_SECRET_ACCESS_KEY=test - - S3_FORCE_PATH_STYLE=true - - KEYS_STORAGE=database - - ENCRYPTION_KEY=test1234test1234test1234test1234 + FILE_BACKEND: s3 + S3_BUCKET: pocket-id-test + S3_REGION: us-east-1 + S3_ENDPOINT: http://localstack-s3:4566 + S3_ACCESS_KEY_ID: test + S3_SECRET_ACCESS_KEY: test + S3_FORCE_PATH_STYLE: true depends_on: create-bucket: condition: service_completed_successfully diff --git a/tests/setup/docker-compose.yml b/tests/setup/docker-compose.yml index 74e1e778..8ac7d80f 100644 --- a/tests/setup/docker-compose.yml +++ b/tests/setup/docker-compose.yml @@ -13,8 +13,9 @@ services: ports: - '1411:1411' environment: - - APP_ENV=test - - FILE_BACKEND=${FILE_BACKEND} + APP_ENV: test + ENCRYPTION_KEY: test-encryption-key + FILE_BACKEND: ${FILE_BACKEND} build: args: - BUILD_TAGS=e2etest