mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-16 19:20:17 +00:00
refactor!: remove old DB env variables, and jwk migrations logic (#529)
This commit is contained in:
@@ -20,10 +20,6 @@ func Bootstrap() error {
|
|||||||
|
|
||||||
initApplicationImages()
|
initApplicationImages()
|
||||||
|
|
||||||
// Perform migrations for changes
|
|
||||||
migrateConfigDBConnstring()
|
|
||||||
migrateKey()
|
|
||||||
|
|
||||||
// Initialize the tracer and metrics exporter
|
// Initialize the tracer and metrics exporter
|
||||||
shutdownFns, httpClient, err := initOtel(ctx, common.EnvConfig.MetricsEnabled, common.EnvConfig.TracingEnabled)
|
shutdownFns, httpClient, err := initOtel(ctx, common.EnvConfig.MetricsEnabled, common.EnvConfig.TracingEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Performs the migration of the database connection string
|
|
||||||
// See: https://github.com/pocket-id/pocket-id/pull/388
|
|
||||||
func migrateConfigDBConnstring() {
|
|
||||||
switch common.EnvConfig.DbProvider {
|
|
||||||
case common.DbProviderSqlite:
|
|
||||||
// Check if we're using the deprecated SqliteDBPath env var
|
|
||||||
if common.EnvConfig.SqliteDBPath != "" {
|
|
||||||
connString := "file:" + common.EnvConfig.SqliteDBPath + "?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate"
|
|
||||||
common.EnvConfig.DbConnectionString = connString
|
|
||||||
common.EnvConfig.SqliteDBPath = ""
|
|
||||||
|
|
||||||
log.Printf("[WARN] Env var 'SQLITE_DB_PATH' is deprecated - use 'DB_CONNECTION_STRING' instead with the value: '%s'", connString)
|
|
||||||
}
|
|
||||||
case common.DbProviderPostgres:
|
|
||||||
// Check if we're using the deprecated PostgresConnectionString alias
|
|
||||||
if common.EnvConfig.PostgresConnectionString != "" {
|
|
||||||
common.EnvConfig.DbConnectionString = common.EnvConfig.PostgresConnectionString
|
|
||||||
common.EnvConfig.PostgresConnectionString = ""
|
|
||||||
|
|
||||||
log.Print("[WARN] Env var 'POSTGRES_CONNECTION_STRING' is deprecated - use 'DB_CONNECTION_STRING' instead with the same value")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// We don't do anything here in the default case
|
|
||||||
// This is an error, but will be handled later on
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
privateKeyFilePem = "jwt_private_key.pem"
|
|
||||||
)
|
|
||||||
|
|
||||||
func migrateKey() {
|
|
||||||
err := migrateKeyInternal(common.EnvConfig.KeysPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to perform migration of keys: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func migrateKeyInternal(basePath string) error {
|
|
||||||
// First, check if there's already a JWK stored
|
|
||||||
jwkPath := filepath.Join(basePath, service.PrivateKeyFile)
|
|
||||||
ok, err := utils.FileExists(jwkPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
// There's already a key as JWK, so we don't do anything else here
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if there's a PEM file
|
|
||||||
pemPath := filepath.Join(basePath, privateKeyFilePem)
|
|
||||||
ok, err = utils.FileExists(pemPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check if private key file (PEM) exists at path '%s': %w", pemPath, err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
// No file to migrate, return
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load and validate the key
|
|
||||||
key, err := loadKeyPEM(pemPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load private key file (PEM) at path '%s': %w", pemPath, err)
|
|
||||||
}
|
|
||||||
err = service.ValidateKey(key)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("key object is invalid: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the key as JWK
|
|
||||||
err = service.SaveKeyJWK(key, jwkPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, delete the PEM file
|
|
||||||
err = os.Remove(pemPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove migrated key at path '%s': %w", pemPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadKeyPEM(path string) (jwk.Key, error) {
|
|
||||||
// Load the key from disk and parse it
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read key data: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the key ID using the "legacy" algorithm
|
|
||||||
keyId, err := generateKeyID(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
|
||||||
}
|
|
||||||
err = key.Set(jwk.KeyIDKey, keyId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set key ID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate other required fields
|
|
||||||
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
|
|
||||||
service.EnsureAlgInKey(key)
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key's PKIX-serialized structure.
|
|
||||||
// This is used for legacy keys, imported from PEM.
|
|
||||||
func generateKeyID(key jwk.Key) (string, error) {
|
|
||||||
// Export the public key and serialize it to PKIX (not in a PEM block)
|
|
||||||
// This is for backwards-compatibility with the algorithm used before the switch to JWK
|
|
||||||
pubKey, err := key.PublicKey()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to get public key: %w", err)
|
|
||||||
}
|
|
||||||
var pubKeyRaw any
|
|
||||||
err = jwk.Export(pubKey, &pubKeyRaw)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to export public key: %w", err)
|
|
||||||
}
|
|
||||||
pubASN1, err := x509.MarshalPKIXPublicKey(pubKeyRaw)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to marshal public key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute SHA-256 hash of the public key
|
|
||||||
hash := sha256.New()
|
|
||||||
hash.Write(pubASN1)
|
|
||||||
hashed := hash.Sum(nil)
|
|
||||||
|
|
||||||
// Truncate the hash to the first 8 bytes for a shorter Key ID
|
|
||||||
shortHash := hashed[:8]
|
|
||||||
|
|
||||||
// Return Base64 encoded truncated hash as Key ID
|
|
||||||
return base64.RawURLEncoding.EncodeToString(shortHash), nil
|
|
||||||
}
|
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMigrateKey(t *testing.T) {
|
|
||||||
// Create a temporary directory for testing
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
t.Run("no keys exist", func(t *testing.T) {
|
|
||||||
// Test when no keys exist
|
|
||||||
err := migrateKeyInternal(tempDir)
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("jwk already exists", func(t *testing.T) {
|
|
||||||
// Create a JWK file
|
|
||||||
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
|
|
||||||
key, err := createTestRSAKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = service.SaveKeyJWK(key, jwkPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Run migration - should do nothing
|
|
||||||
err = migrateKeyInternal(tempDir)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Check the file still exists
|
|
||||||
exists, err := utils.FileExists(jwkPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, exists)
|
|
||||||
|
|
||||||
// Delete for next test
|
|
||||||
err = os.Remove(jwkPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("migrate pem to jwk", func(t *testing.T) {
|
|
||||||
// Create a PEM file
|
|
||||||
pemPath := filepath.Join(tempDir, privateKeyFilePem)
|
|
||||||
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
|
|
||||||
|
|
||||||
// Generate RSA key and save as PEM
|
|
||||||
createRSAPrivateKeyPEM(t, pemPath)
|
|
||||||
|
|
||||||
// Run migration
|
|
||||||
err := migrateKeyInternal(tempDir)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Check PEM file is gone
|
|
||||||
exists, err := utils.FileExists(pemPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, exists)
|
|
||||||
|
|
||||||
// Check JWK file exists
|
|
||||||
exists, err = utils.FileExists(jwkPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, exists)
|
|
||||||
|
|
||||||
// Verify the JWK can be loaded
|
|
||||||
data, err := os.ReadFile(jwkPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = jwk.ParseKey(data)
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadKeyPEM(t *testing.T) {
|
|
||||||
// Create a temporary directory for testing
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
t.Run("successfully load PEM key", func(t *testing.T) {
|
|
||||||
pemPath := filepath.Join(tempDir, "test_key.pem")
|
|
||||||
|
|
||||||
// Generate RSA key and save as PEM
|
|
||||||
createRSAPrivateKeyPEM(t, pemPath)
|
|
||||||
|
|
||||||
// Load the key
|
|
||||||
key, err := loadKeyPEM(pemPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify key properties
|
|
||||||
assert.NotEmpty(t, key)
|
|
||||||
|
|
||||||
// Check key ID is set
|
|
||||||
var keyID string
|
|
||||||
err = key.Get(jwk.KeyIDKey, &keyID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotEmpty(t, keyID)
|
|
||||||
|
|
||||||
// Check algorithm is set
|
|
||||||
var alg jwa.SignatureAlgorithm
|
|
||||||
err = key.Get(jwk.AlgorithmKey, &alg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotEmpty(t, alg)
|
|
||||||
|
|
||||||
// Check key usage is set
|
|
||||||
var keyUsage string
|
|
||||||
err = key.Get(jwk.KeyUsageKey, &keyUsage)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, service.KeyUsageSigning, keyUsage)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("file not found", func(t *testing.T) {
|
|
||||||
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Nil(t, key)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("invalid file content", func(t *testing.T) {
|
|
||||||
invalidPath := filepath.Join(tempDir, "invalid.pem")
|
|
||||||
err := os.WriteFile(invalidPath, []byte("not a valid PEM"), 0600)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
key, err := loadKeyPEM(invalidPath)
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Nil(t, key)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateKeyID(t *testing.T) {
|
|
||||||
key, err := createTestRSAKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
keyID, err := generateKeyID(key)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Key ID should be non-empty
|
|
||||||
assert.NotEmpty(t, keyID)
|
|
||||||
|
|
||||||
// Generate another key ID to prove it depends on the key
|
|
||||||
key2, err := createTestRSAKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
keyID2, err := generateKeyID(key2)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// The two key IDs should be different
|
|
||||||
assert.NotEqual(t, keyID, keyID2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
|
|
||||||
func createTestRSAKey() (jwk.Key, error) {
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := jwk.Import(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createRSAPrivateKeyPEM generates an RSA private key and returns its PEM-encoded form
|
|
||||||
func createRSAPrivateKeyPEM(t *testing.T, pemPath string) ([]byte, *rsa.PrivateKey) {
|
|
||||||
// Generate RSA key
|
|
||||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Encode to PEM format
|
|
||||||
pemData := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "RSA PRIVATE KEY",
|
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
err = os.WriteFile(pemPath, pemData, 0600)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return pemData, privKey
|
|
||||||
}
|
|
||||||
@@ -28,8 +28,6 @@ type EnvConfigSchema struct {
|
|||||||
AppURL string `env:"PUBLIC_APP_URL"`
|
AppURL string `env:"PUBLIC_APP_URL"`
|
||||||
DbProvider DbProvider `env:"DB_PROVIDER"`
|
DbProvider DbProvider `env:"DB_PROVIDER"`
|
||||||
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
||||||
SqliteDBPath string `env:"SQLITE_DB_PATH"` // Deprecated: use "DB_CONNECTION_STRING" instead
|
|
||||||
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` // Deprecated: use "DB_CONNECTION_STRING" instead
|
|
||||||
UploadPath string `env:"UPLOAD_PATH"`
|
UploadPath string `env:"UPLOAD_PATH"`
|
||||||
KeysPath string `env:"KEYS_PATH"`
|
KeysPath string `env:"KEYS_PATH"`
|
||||||
Port string `env:"BACKEND_PORT"`
|
Port string `env:"BACKEND_PORT"`
|
||||||
@@ -46,8 +44,6 @@ var EnvConfig = &EnvConfigSchema{
|
|||||||
AppEnv: "production",
|
AppEnv: "production",
|
||||||
DbProvider: "sqlite",
|
DbProvider: "sqlite",
|
||||||
DbConnectionString: "file:data/pocket-id.db?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate",
|
DbConnectionString: "file:data/pocket-id.db?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate",
|
||||||
SqliteDBPath: "",
|
|
||||||
PostgresConnectionString: "",
|
|
||||||
UploadPath: "data/uploads",
|
UploadPath: "data/uploads",
|
||||||
KeysPath: "data/keys",
|
KeysPath: "data/keys",
|
||||||
AppURL: "http://localhost",
|
AppURL: "http://localhost",
|
||||||
|
|||||||
Reference in New Issue
Block a user