mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-15 00:35:14 +00:00
feat: encrypt private keys saved on disk and in database (#682)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
9872608d61
commit
5550729120
69
backend/internal/utils/crypto/crypto.go
Normal file
69
backend/internal/utils/crypto/crypto.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrDecrypt is returned by Decrypt when the operation failed for any reason
|
||||
var ErrDecrypt = errors.New("failed to decrypt data")
|
||||
|
||||
// Encrypt a byte slice using AES-GCM and a random nonce
|
||||
// Important: do not encrypt more than ~4 billion messages with the same key!
|
||||
func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
_, err = io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
|
||||
}
|
||||
|
||||
// Allocate the slice for the result, with additional space for the nonce and overhead
|
||||
ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead())
|
||||
ciphertext = append(ciphertext, nonce...)
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Tag is automatically added at the end
|
||||
ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData)
|
||||
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt a byte slice using AES-GCM
|
||||
func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||
}
|
||||
|
||||
// Extract the nonce
|
||||
if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) {
|
||||
return nil, ErrDecrypt
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData)
|
||||
if err != nil {
|
||||
// Note: we do not return the exact error here, to avoid disclosing information
|
||||
return nil, ErrDecrypt
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
208
backend/internal/utils/crypto/crypto_test.go
Normal file
208
backend/internal/utils/crypto/crypto_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keySize int
|
||||
plaintext string
|
||||
associatedData []byte
|
||||
}{
|
||||
{
|
||||
name: "AES-128 with short plaintext",
|
||||
keySize: 16,
|
||||
plaintext: "Hello, World!",
|
||||
associatedData: []byte("test-aad"),
|
||||
},
|
||||
{
|
||||
name: "AES-192 with medium plaintext",
|
||||
keySize: 24,
|
||||
plaintext: "This is a longer message to test encryption and decryption",
|
||||
associatedData: []byte("associated-data-192"),
|
||||
},
|
||||
{
|
||||
name: "AES-256 with unicode",
|
||||
keySize: 32,
|
||||
plaintext: "Hello 世界! 🌍 Testing unicode characters", //nolint:gosmopolitan
|
||||
associatedData: []byte("unicode-test"),
|
||||
},
|
||||
{
|
||||
name: "No associated data",
|
||||
keySize: 32,
|
||||
plaintext: "Testing without associated data",
|
||||
associatedData: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate random key
|
||||
key := make([]byte, tt.keySize)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
plaintext := []byte(tt.plaintext)
|
||||
|
||||
// Test encryption
|
||||
ciphertext, err := Encrypt(key, plaintext, tt.associatedData)
|
||||
require.NoError(t, err, "Encrypt should succeed")
|
||||
|
||||
// Verify ciphertext is different from plaintext (unless empty)
|
||||
if len(plaintext) > 0 {
|
||||
assert.NotEqual(t, plaintext, ciphertext)
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := Decrypt(key, ciphertext, tt.associatedData)
|
||||
require.NoError(t, err, "Decrypt should succeed")
|
||||
|
||||
// Verify decrypted text matches original
|
||||
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptWithInvalidKeySize(t *testing.T) {
|
||||
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||
|
||||
for _, keySize := range invalidKeySizes {
|
||||
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||
key := make([]byte, keySize)
|
||||
plaintext := []byte("test message")
|
||||
|
||||
_, err := Encrypt(key, plaintext, nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid key size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidKeySize(t *testing.T) {
|
||||
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||
|
||||
for _, keySize := range invalidKeySizes {
|
||||
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||
key := make([]byte, keySize)
|
||||
ciphertext := []byte("fake ciphertext")
|
||||
|
||||
_, err := Decrypt(key, ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid key size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidCiphertext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ciphertext []byte
|
||||
}{
|
||||
{
|
||||
name: "empty ciphertext",
|
||||
ciphertext: []byte{},
|
||||
},
|
||||
{
|
||||
name: "too short ciphertext",
|
||||
ciphertext: []byte("short"),
|
||||
},
|
||||
{
|
||||
name: "random invalid data",
|
||||
ciphertext: []byte("this is not valid encrypted data"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Decrypt(key, tt.ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongKey(t *testing.T) {
|
||||
// Generate two different keys
|
||||
key1 := make([]byte, 32)
|
||||
key2 := make([]byte, 32)
|
||||
_, err := rand.Read(key1)
|
||||
require.NoError(t, err)
|
||||
_, err = rand.Read(key2)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
|
||||
// Encrypt with key1
|
||||
ciphertext, err := Encrypt(key1, plaintext, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with key2
|
||||
_, err = Decrypt(key2, ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongAssociatedData(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
correctAAD := []byte("correct-aad")
|
||||
wrongAAD := []byte("wrong-aad")
|
||||
|
||||
// Encrypt with correct AAD
|
||||
ciphertext, err := Encrypt(key, plaintext, correctAAD)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with wrong AAD
|
||||
_, err = Decrypt(key, ciphertext, wrongAAD)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
|
||||
// Verify correct AAD works
|
||||
decrypted, err := Decrypt(key, ciphertext, correctAAD)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD")
|
||||
}
|
||||
|
||||
func TestEncryptDecryptConsistency(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("consistency test message")
|
||||
associatedData := []byte("test-aad")
|
||||
|
||||
// Encrypt multiple times and verify we get different ciphertexts (due to random IV)
|
||||
ciphertext1, err := Encrypt(key, plaintext, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
ciphertext2, err := Encrypt(key, plaintext, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ciphertexts should be different (due to random IV)
|
||||
assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts")
|
||||
|
||||
// Both should decrypt to the same plaintext
|
||||
decrypted1, err := Decrypt(key, ciphertext1, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted2, err := Decrypt(key, ciphertext2, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original")
|
||||
assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original")
|
||||
assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical")
|
||||
}
|
||||
50
backend/internal/utils/jwk/key_provider.go
Normal file
50
backend/internal/utils/jwk/key_provider.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
type KeyProviderOpts struct {
|
||||
EnvConfig *common.EnvConfigSchema
|
||||
DB *gorm.DB
|
||||
Kek []byte
|
||||
}
|
||||
|
||||
type KeyProvider interface {
|
||||
Init(opts KeyProviderOpts) error
|
||||
LoadKey() (jwk.Key, error)
|
||||
SaveKey(key jwk.Key) error
|
||||
}
|
||||
|
||||
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
|
||||
// Load the encryption key (KEK) if present
|
||||
kek, err := LoadKeyEncryptionKey(envConfig, instanceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load encryption key: %w", err)
|
||||
}
|
||||
|
||||
// Get the key provider
|
||||
switch envConfig.KeysStorage {
|
||||
case "file", "":
|
||||
keyProvider = &KeyProviderFile{}
|
||||
case "database":
|
||||
keyProvider = &KeyProviderDatabase{}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
|
||||
}
|
||||
err = keyProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
EnvConfig: envConfig,
|
||||
Kek: kek,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
return keyProvider, nil
|
||||
}
|
||||
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
const PrivateKeyDBKey = "jwt_private_key.json"
|
||||
|
||||
type KeyProviderDatabase struct {
|
||||
db *gorm.DB
|
||||
kek []byte
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
|
||||
if len(opts.Kek) == 0 {
|
||||
return errors.New("an encryption key is required when using the 'database' key provider")
|
||||
}
|
||||
|
||||
f.db = opts.DB
|
||||
f.kek = opts.Kek
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).First(&row).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Key not present in the database - return nil so a new one can be generated
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err)
|
||||
}
|
||||
|
||||
if row.Value == nil || *row.Value == "" {
|
||||
// Key not present in the database - return nil so a new one can be generated
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
enc, err := base64.StdEncoding.DecodeString(*row.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(f.kek, enc, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse encrypted private key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the key then encode to Base64
|
||||
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||
}
|
||||
encB64 := base64.StdEncoding.EncodeToString(enc)
|
||||
|
||||
// Save to database
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &encB64,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).Create(&row).Error
|
||||
if err != nil {
|
||||
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
|
||||
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
|
||||
return fmt.Errorf("failed to store private key in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile-time interface check
|
||||
var _ KeyProvider = (*KeyProviderDatabase)(nil)
|
||||
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestKeyProviderDatabase_Init(t *testing.T) {
|
||||
t.Run("Init fails when KEK is not provided", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: nil, // No KEK
|
||||
})
|
||||
require.Error(t, err, "Expected error when KEK is not provided")
|
||||
require.ErrorContains(t, err, "encryption key is required")
|
||||
})
|
||||
|
||||
t.Run("Init succeeds with KEK", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: generateTestKEK(t),
|
||||
})
|
||||
require.NoError(t, err, "Expected no error when KEK is provided")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with existing key", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid base64", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert invalid base64 data
|
||||
invalidBase64 := "not-valid-base64"
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &invalidBase64,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with invalid base64")
|
||||
require.ErrorContains(t, err, "not a valid base64-encoded value")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid encrypted data", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert valid base64 but invalid encrypted data
|
||||
invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data"))
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &invalidData,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
originalKek := generateTestKEK(t)
|
||||
|
||||
// Save a key with the original KEK
|
||||
originalProvider := &KeyProviderDatabase{}
|
||||
err := originalProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: originalKek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = originalProvider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to load with a different KEK
|
||||
differentKek := generateTestKEK(t)
|
||||
differentProvider := &KeyProviderDatabase{}
|
||||
err = differentProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: differentKek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key with the wrong KEK
|
||||
loadedKey, err := differentProvider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with wrong KEK")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid key data", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create invalid key data (valid JSON but not a valid JWK)
|
||||
invalidKeyData := []byte(`{"not": "a valid jwk"}`)
|
||||
|
||||
// Encrypt the invalid key data
|
||||
encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Base64 encode the encrypted data
|
||||
encodedData := base64.StdEncoding.EncodeToString(encryptedData)
|
||||
|
||||
// Save to database
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &encodedData,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading invalid key data")
|
||||
require.ErrorContains(t, err, "failed to parse")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderDatabase_SaveKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("SaveKey and verify database record", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err, "Expected no error when saving key")
|
||||
|
||||
// Verify record exists in database
|
||||
var kv model.KV
|
||||
err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error
|
||||
require.NoError(t, err, "Expected to find key in database")
|
||||
require.NotNil(t, kv.Value, "Expected non-nil value in database")
|
||||
assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database")
|
||||
|
||||
// Decode and decrypt to verify content
|
||||
encBytes, err := base64.StdEncoding.DecodeString(*kv.Value)
|
||||
require.NoError(t, err, "Expected valid base64 encoding")
|
||||
|
||||
decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil)
|
||||
require.NoError(t, err, "Expected valid encrypted data")
|
||||
|
||||
parsedKey, err := jwk.ParseKey(decBytes)
|
||||
require.NoError(t, err, "Expected valid JWK data")
|
||||
|
||||
// Compare keys
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||
})
|
||||
}
|
||||
|
||||
func generateTestKEK(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Generate a 32-byte kek
|
||||
kek := make([]byte, 32)
|
||||
_, err := rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
return kek
|
||||
}
|
||||
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrivateKeyFile is the path in the data/keys folder where the key is stored
|
||||
// This is a JSON file containing a key encoded as JWK
|
||||
PrivateKeyFile = "jwt_private_key.json"
|
||||
|
||||
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||
// This is a encrypted JSON file containing a key encoded as JWK
|
||||
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||
)
|
||||
|
||||
type KeyProviderFile struct {
|
||||
envConfig *common.EnvConfigSchema
|
||||
kek []byte
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
|
||||
f.envConfig = opts.EnvConfig
|
||||
f.kek = opts.Kek
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
|
||||
if len(f.kek) > 0 {
|
||||
return f.loadEncryptedKey()
|
||||
}
|
||||
return f.loadKey()
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
|
||||
if len(f.kek) > 0 {
|
||||
return f.saveKeyEncrypted(key)
|
||||
}
|
||||
return f.saveKey(key)
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
|
||||
var key jwk.Key
|
||||
|
||||
// First, check if we have a JWK file
|
||||
// If we do, then we just load that
|
||||
jwkPath := f.jwkPath()
|
||||
ok, err := utils.FileExists(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
if !ok {
|
||||
// File doesn't exist, no key was loaded
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
|
||||
// First, check if we have an encrypted JWK file
|
||||
// If we do, then we just load that
|
||||
encJwkPath := f.encJwkPath()
|
||||
ok, err := utils.FileExists(encJwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
if ok {
|
||||
encB64, err := os.ReadFile(encJwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check if we have an un-encrypted JWK file
|
||||
key, err = f.loadKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
|
||||
}
|
||||
if key == nil {
|
||||
// No key exists, encrypted or un-encrypted
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we are here, we have loaded a key that was un-encrypted
|
||||
// We need to replace the plaintext key with the encrypted one before we return
|
||||
err = f.saveKeyEncrypted(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
|
||||
}
|
||||
jwkPath := f.jwkPath()
|
||||
err = os.Remove(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
|
||||
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
|
||||
}
|
||||
|
||||
jwkPath := f.jwkPath()
|
||||
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
// Write the JSON file to disk
|
||||
err = EncodeJWK(keyFile, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
|
||||
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
|
||||
}
|
||||
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the key then encode to Base64
|
||||
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||
}
|
||||
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
|
||||
base64.StdEncoding.Encode(encB64, enc)
|
||||
|
||||
// Write to disk
|
||||
encJwkPath := f.encJwkPath()
|
||||
err = os.WriteFile(encJwkPath, encB64, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) jwkPath() string {
|
||||
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) encJwkPath() string {
|
||||
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
|
||||
}
|
||||
|
||||
// Compile-time interface check
|
||||
var _ KeyProvider = (*KeyProviderFile)(nil)
|
||||
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
func TestKeyProviderFile_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with encrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Make sure the unencrypted key file does not exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// First, create an unencrypted key
|
||||
providerNoKek := &KeyProviderFile{}
|
||||
err := providerNoKek.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save an unencrypted key
|
||||
err = providerNoKek.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify unencrypted key exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected unencrypted key file to exist")
|
||||
|
||||
// Now create a provider with a kek
|
||||
kek := make([]byte, 32)
|
||||
_, err = rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
|
||||
providerWithKek := &KeyProviderFile{}
|
||||
err = providerWithKek.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key - this should convert the unencrypted key to encrypted
|
||||
loadedKey, err := providerWithKek.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
|
||||
|
||||
// Verify the unencrypted key no longer exists
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to be removed")
|
||||
|
||||
// Verify the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err = utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
|
||||
|
||||
// Verify the key data
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderFile_SaveKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("SaveKey unencrypted", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Verify the content of the key file
|
||||
data, err := os.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKey, err := jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the saved key with the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||
})
|
||||
|
||||
t.Run("SaveKey encrypted", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate a 64-byte kek
|
||||
kek := makeKEK(t)
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Verify the unencrypted key file doesn't exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||
|
||||
// Manually decrypt the encrypted key file to verify it contains the correct key
|
||||
encB64, err := os.ReadFile(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode from base64
|
||||
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||
require.NoError(t, err)
|
||||
enc = enc[:n] // Trim any padding
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(kek, enc, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the key
|
||||
parsedKey, err := jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the decrypted key with the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
|
||||
})
|
||||
}
|
||||
|
||||
func makeKEK(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Generate a 32-byte kek
|
||||
kek := make([]byte, 32)
|
||||
_, err := rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
return kek
|
||||
}
|
||||
180
backend/internal/utils/jwk/utils.go
Normal file
180
backend/internal/utils/jwk/utils.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha3"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
)
|
||||
|
||||
// EncodeJWK encodes a jwk.Key to a writable stream.
|
||||
func EncodeJWK(w io.Writer, key jwk.Key) error {
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(key)
|
||||
}
|
||||
|
||||
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
|
||||
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
|
||||
b := &bytes.Buffer{}
|
||||
err := EncodeJWK(b, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// LoadKeyEncryptionKey loads the key encryption key for JWKs
|
||||
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
|
||||
// Try getting the key from the env var as string
|
||||
kekInput := []byte(envConfig.EncryptionKey)
|
||||
|
||||
// If there's nothing in the env, try loading from file
|
||||
if len(kekInput) == 0 && envConfig.EncryptionKeyFile != "" {
|
||||
kekInput, err = os.ReadFile(envConfig.EncryptionKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key file '%s': %w", envConfig.EncryptionKeyFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If there's still no key, return
|
||||
if len(kekInput) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// We need a 256-bit key for encryption with AES-GCM-256
|
||||
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
|
||||
// The key is tied to a specific instance of Pocket ID
|
||||
h := hmac.New(func() hash.Hash { return sha3.New256() }, kekInput)
|
||||
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
|
||||
kek = h.Sum(nil)
|
||||
|
||||
return kek, nil
|
||||
}
|
||||
|
||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||
// It also populates additional fields such as the key ID, usage, and alg.
|
||||
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key, alg, crv)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
if alg != "" {
|
||||
_ = key.Set(jwk.AlgorithmKey, alg)
|
||||
if crv != "" {
|
||||
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
|
||||
if ok {
|
||||
switch key.KeyType() {
|
||||
case jwa.EC():
|
||||
_ = key.Set(jwk.ECDSACrvKey, eca)
|
||||
case jwa.OKP():
|
||||
_ = key.Set(jwk.OKPCrvKey, eca)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't have an algorithm, set the default for the key type
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA and Ed25519 for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateKey generates a new jwk.Key
|
||||
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
|
||||
var rawKey any
|
||||
switch alg {
|
||||
case jwa.RS256().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
case jwa.RS384().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
|
||||
case jwa.RS512().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
case jwa.ES256().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case jwa.ES384().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case jwa.ES512().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
case jwa.EdDSA().String():
|
||||
switch crv {
|
||||
case jwa.Ed25519().String():
|
||||
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
default:
|
||||
return nil, errors.New("unsupported curve for EdDSA algorithm")
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported key algorithm")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
// Import the raw key
|
||||
return ImportRawKey(rawKey, alg, crv)
|
||||
}
|
||||
324
backend/internal/utils/jwk/utils_test.go
Normal file
324
backend/internal/utils/jwk/utils_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
alg string
|
||||
crv string
|
||||
expectError bool
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
}{
|
||||
{
|
||||
name: "RS256",
|
||||
alg: jwa.RS256().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS256(),
|
||||
},
|
||||
{
|
||||
name: "RS384",
|
||||
alg: jwa.RS384().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS384(),
|
||||
},
|
||||
// Skip the RS512 test as generating a RSA-4096 key can take some time
|
||||
/* {
|
||||
name: "RS512",
|
||||
alg: jwa.RS512().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS512(),
|
||||
}, */
|
||||
{
|
||||
name: "ES256",
|
||||
alg: jwa.ES256().String(),
|
||||
crv: jwa.P256().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES256(),
|
||||
},
|
||||
{
|
||||
name: "ES384",
|
||||
alg: jwa.ES384().String(),
|
||||
crv: jwa.P384().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES384(),
|
||||
},
|
||||
{
|
||||
name: "ES512",
|
||||
alg: jwa.ES512().String(),
|
||||
crv: jwa.P521().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES512(),
|
||||
},
|
||||
{
|
||||
name: "EdDSA with Ed25519",
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: jwa.Ed25519().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
},
|
||||
{
|
||||
name: "EdDSA with unsupported curve",
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: "unsupported",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Unsupported algorithm",
|
||||
alg: "UNSUPPORTED",
|
||||
crv: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key, err := GenerateKey(tt.alg, tt.crv)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
// Verify the algorithm is set correctly
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok, "algorithm should be set in the key")
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify other required fields are set
|
||||
kid, ok := key.KeyID()
|
||||
assert.True(t, ok, "key ID should be set")
|
||||
assert.NotEmpty(t, kid, "key ID should not be empty")
|
||||
|
||||
usage, ok := key.KeyUsage()
|
||||
assert.True(t, ok, "key usage should be set")
|
||||
assert.Equal(t, KeyUsageSigning, usage)
|
||||
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
|
||||
// Verify key type matches expected algorithm
|
||||
switch tt.expectedAlg {
|
||||
case jwa.RS256(), jwa.RS384(), jwa.RS512():
|
||||
assert.Equal(t, jwa.RSA(), key.KeyType())
|
||||
assert.Nil(t, crv)
|
||||
case jwa.ES256(), jwa.ES384(), jwa.ES512():
|
||||
assert.Equal(t, jwa.EC(), key.KeyType())
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
_ = assert.NotNil(t, crv) &&
|
||||
assert.True(t, ok) &&
|
||||
assert.Equal(t, tt.crv, eca.String())
|
||||
case jwa.EdDSA():
|
||||
assert.Equal(t, jwa.OKP(), key.KeyType())
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
_ = assert.NotNil(t, crv) &&
|
||||
assert.True(t, ok) &&
|
||||
assert.Equal(t, tt.crv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAlgInKey(t *testing.T) {
|
||||
// Generate an RSA-2048 key
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("does not change alg already set", func(t *testing.T) {
|
||||
// Import the RSA key
|
||||
key, err := jwk.Import(rsaKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pre-set the algorithm
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
|
||||
// Call EnsureAlgInKey with a different algorithm
|
||||
EnsureAlgInKey(key, jwa.RS384().String(), "")
|
||||
|
||||
// Verify the algorithm wasn't changed
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||
})
|
||||
|
||||
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyGen func() (any, error)
|
||||
alg string
|
||||
crv string
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
expectedCrv string
|
||||
}{
|
||||
{
|
||||
name: "RSA key with RS384",
|
||||
keyGen: func() (any, error) {
|
||||
return rsaKey, nil
|
||||
},
|
||||
alg: jwa.RS384().String(),
|
||||
crv: "",
|
||||
expectedAlg: jwa.RS384(),
|
||||
expectedCrv: "",
|
||||
},
|
||||
{
|
||||
name: "ECDSA key with ES384",
|
||||
keyGen: func() (any, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
alg: jwa.ES384().String(),
|
||||
crv: jwa.P384().String(),
|
||||
expectedAlg: jwa.ES384(),
|
||||
expectedCrv: jwa.P384().String(),
|
||||
},
|
||||
{
|
||||
name: "Ed25519 key with EdDSA",
|
||||
keyGen: func() (any, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
return priv, err
|
||||
},
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: jwa.Ed25519().String(),
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
expectedCrv: jwa.Ed25519().String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rawKey, err := tt.keyGen()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rawKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure no algorithm is set initially
|
||||
_, ok := key.Algorithm()
|
||||
assert.False(t, ok)
|
||||
|
||||
// Call EnsureAlgInKey
|
||||
EnsureAlgInKey(key, tt.alg, tt.crv)
|
||||
|
||||
// Verify the algorithm was set correctly
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify curve if expected
|
||||
if tt.expectedCrv != "" {
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
require.NotNil(t, crv)
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set default algorithms if not present", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyGen func() (any, error)
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
expectedCrv string
|
||||
}{
|
||||
{
|
||||
name: "RSA key defaults to RS256",
|
||||
keyGen: func() (any, error) {
|
||||
return rsaKey, nil
|
||||
},
|
||||
expectedAlg: jwa.RS256(),
|
||||
expectedCrv: "",
|
||||
},
|
||||
{
|
||||
name: "ECDSA key defaults to ES256 with P256",
|
||||
keyGen: func() (any, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
expectedAlg: jwa.ES256(),
|
||||
expectedCrv: jwa.P256().String(),
|
||||
},
|
||||
{
|
||||
name: "Ed25519 key defaults to EdDSA with Ed25519",
|
||||
keyGen: func() (any, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
return priv, err
|
||||
},
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
expectedCrv: jwa.Ed25519().String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rawKey, err := tt.keyGen()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rawKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure no algorithm is set initially
|
||||
_, ok := key.Algorithm()
|
||||
assert.False(t, ok)
|
||||
|
||||
// Call EnsureAlgInKey with empty parameters
|
||||
EnsureAlgInKey(key, "", "")
|
||||
|
||||
// Verify the default algorithm was set
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify curve if expected
|
||||
if tt.expectedCrv != "" {
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
require.NotNil(t, crv)
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rsaKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Call EnsureAlgInKey with invalid curve
|
||||
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
|
||||
|
||||
// Verify algorithm was set but curve was not
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
assert.Nil(t, crv)
|
||||
})
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
)
|
||||
|
||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||
// It also populates additional fields such as the key ID, usage, and alg.
|
||||
func ImportRawKey(rawKey any) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
}
|
||||
}
|
||||
72
backend/internal/utils/testing/database.go
Normal file
72
backend/internal/utils/testing/database.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// This file is only imported by unit tests
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database.
|
||||
// Each database connection is unique for the test.
|
||||
// All migrations are automatically performed.
|
||||
func NewDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
// Get a name for this in-memory database that is specific to the test
|
||||
dbName := utils.CreateSha256Hash(t.Name())
|
||||
|
||||
// Connect to a new in-memory SQL database
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
|
||||
&gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: logger.New(
|
||||
testLoggerAdapter{t: t},
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
ParameterizedQueries: false,
|
||||
Colorful: false,
|
||||
},
|
||||
),
|
||||
})
|
||||
require.NoError(t, err, "Failed to connect to test database")
|
||||
|
||||
// Perform migrations with the embedded migrations
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err, "Failed to get sql.DB")
|
||||
driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
|
||||
require.NoError(t, err, "Failed to create migration driver")
|
||||
source, err := iofs.New(resources.FS, "migrations/sqlite")
|
||||
require.NoError(t, err, "Failed to create embedded migration source")
|
||||
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
|
||||
require.NoError(t, err, "Failed to create migration instance")
|
||||
err = m.Up()
|
||||
require.NoError(t, err, "Failed to perform migrations")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Implements gorm's logger.Writer interface
|
||||
type testLoggerAdapter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||
l.t.Logf(format, args...)
|
||||
}
|
||||
38
backend/internal/utils/testing/round_tripper.go
Normal file
38
backend/internal/utils/testing/round_tripper.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// This file is only imported by unit tests
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
||||
type MockRoundTripper struct {
|
||||
Err error
|
||||
Responses map[string]*http.Response
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Check if we have a specific response for this URL
|
||||
for url, resp := range m.Responses {
|
||||
if req.URL.String() == url {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewMockResponse(http.StatusNotFound, ""), nil
|
||||
}
|
||||
|
||||
// NewMockResponse creates an http.Response with the given status code and body
|
||||
func NewMockResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user