1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-16 11:21:12 +00:00

feat: encrypt private keys saved on disk and in database (#682)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-07-03 11:34:34 -07:00
committed by GitHub
parent 9872608d61
commit 5550729120
25 changed files with 2311 additions and 328 deletions

View File

@@ -2,23 +2,20 @@ package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
const (
@@ -26,8 +23,9 @@ const (
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
@@ -59,58 +57,74 @@ const (
)
type JwtService struct {
envConfig *common.EnvConfigSchema
privateKey jwk.Key
keyId string
appConfigService *AppConfigService
jwksEncoded []byte
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) *JwtService {
service := &JwtService{}
// Ensure keys are generated or loaded
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
err := service.init(db, appConfigService, &common.EnvConfig)
if err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
return service
}
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
s.appConfigService = appConfigService
s.envConfig = envConfig
// Ensure keys are generated or loaded
return s.loadOrGenerateKey(keysPath)
return s.loadOrGenerateKey(db)
}
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
var key jwk.Key
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
// Get the key provider
keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
return fmt.Errorf("failed to get key provider: %w", err)
}
if ok {
key, err = s.loadKeyJWK(jwkPath)
if err != nil {
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
}
// Set the key, and we are done
// Try loading a key
key, err := keyProvider.LoadKey()
if err != nil {
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
}
// If we have a key, store it in the object and we're done
if key != nil {
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
return nil
}
// If we are here, we need to generate a new key
key, err = s.generateNewRSAKey()
err = s.generateKey()
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Save the newly-generated key
err = keyProvider.SaveKey(s.privateKey)
if err != nil {
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
}
return nil
}
// generateKey generates a new key and stores it in the object
func (s *JwtService) generateKey() error {
// Default is to generate RS256 (RSA-2048) keys
key, err := jwkutils.GenerateKey(jwa.RS256().String(), "")
if err != nil {
return fmt.Errorf("failed to generate new private key: %w", err)
}
@@ -121,12 +135,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error {
return fmt.Errorf("failed to set private key: %w", err)
}
// Save the key as JWK
err = SaveKeyJWK(s.privateKey, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
return nil
}
@@ -192,13 +200,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
Subject(user.ID).
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, common.EnvConfig.AppURL)
err = SetAudienceString(token, s.envConfig.AppURL)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
@@ -229,8 +237,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithAudience(common.EnvConfig.AppURL),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithAudience(s.envConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
)
if err != nil {
@@ -246,7 +254,7 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build token: %w", err)
@@ -305,7 +313,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
)
@@ -335,7 +343,7 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
Subject(user.ID).
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build token: %w", err)
@@ -377,7 +385,7 @@ func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, erro
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
)
if err != nil {
@@ -393,7 +401,7 @@ func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, r
Subject(userID).
Expiration(now.Add(RefreshTokenDuration)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
@@ -430,7 +438,7 @@ func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, client
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
)
if err != nil {
@@ -488,7 +496,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
utils.EnsureAlgInKey(pubKey)
jwkutils.EnsureAlgInKey(pubKey, "", "")
return pubKey, nil
}
@@ -517,56 +525,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
return alg, nil
}
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
return key, nil
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
// Import the raw key
return utils.ImportRawKey(rawKey)
}
// SaveKeyJWK saves a JWK to a file
func SaveKeyJWK(key jwk.Key, path string) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
}
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
// Write the JSON file to disk
enc := json.NewEncoder(keyFile)
enc.SetEscapeHTML(false)
err = enc.Encode(key)
if err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {