mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-16 18:45:17 +00:00
feat: add support for ECDSA and EdDSA keys (#359)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
5c198c280c
commit
96876a99c5
@@ -11,13 +11,11 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
@@ -34,6 +32,13 @@ const (
|
||||
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
|
||||
// IsAdminClaim is a boolean claim used in access tokens for admin users
|
||||
// This may be omitted on non-admin tokens
|
||||
IsAdminClaim = "isAdmin"
|
||||
|
||||
// Acceptable clock skew for verifying tokens
|
||||
clockSkew = time.Minute
|
||||
)
|
||||
|
||||
type JwtService struct {
|
||||
@@ -61,11 +66,6 @@ func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) e
|
||||
return s.loadOrGenerateKey(keysPath)
|
||||
}
|
||||
|
||||
type AccessTokenJWTClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
}
|
||||
|
||||
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
|
||||
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||
var key jwk.Key
|
||||
@@ -170,133 +170,164 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error {
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
|
||||
claim := AccessTokenJWTClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{common.EnvConfig.AppURL},
|
||||
},
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
token.Header["kid"] = s.keyId
|
||||
|
||||
var privateKeyRaw any
|
||||
err := jwk.Export(s.privateKey, &privateKeyRaw)
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
Expiration(now.Add(s.appConfigService.DbConfig.SessionDuration.AsDurationMinutes())).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to export private key object: %w", err)
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
signed, err := token.SignedString(privateKeyRaw)
|
||||
err = SetAudienceString(token, common.EnvConfig.AppURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetIsAdmin(token, user.IsAdmin)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return signed, nil
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return s.getPublicKeyRaw()
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithAudience(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*AccessTokenJWTClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, common.EnvConfig.AppURL) {
|
||||
return nil, errors.New("audience doesn't match")
|
||||
}
|
||||
return claims, nil
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
|
||||
// Initialize with capacity for userClaims, + 4 fixed claims, + 2 claims which may be set in some cases, to avoid re-allocations
|
||||
claims := make(jwt.MapClaims, len(userClaims)+6)
|
||||
claims["aud"] = clientID
|
||||
claims["exp"] = jwt.NewNumericDate(time.Now().Add(1 * time.Hour))
|
||||
claims["iat"] = jwt.NewNumericDate(time.Now())
|
||||
claims["iss"] = common.EnvConfig.AppURL
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range userClaims {
|
||||
claims[k] = v
|
||||
err = token.Set(k, v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
|
||||
}
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
err = token.Set("nonce", nonce)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = s.keyId
|
||||
|
||||
var privateKeyRaw any
|
||||
err := jwk.Export(s.privateKey, &privateKeyRaw)
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to export private key object: %w", err)
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return token.SignedString(privateKeyRaw)
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return s.getPublicKeyRaw()
|
||||
}, jwt.WithIssuer(common.EnvConfig.AppURL))
|
||||
func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) (jwt.Token, error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
|
||||
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
opts := make([]jwt.ParseOption, 0)
|
||||
|
||||
// These options are always present
|
||||
opts = append(opts,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
)
|
||||
|
||||
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
|
||||
// In case we want to accept expired tokens (during logout), we need to set the validators explicitly without validating "exp"
|
||||
if acceptExpiredTokens {
|
||||
// This is equivalent to the default validators except it doesn't validate "exp"
|
||||
opts = append(opts,
|
||||
jwt.WithResetValidators(true),
|
||||
jwt.WithValidator(jwt.IsIssuedAtValid()),
|
||||
jwt.WithValidator(jwt.IsNbfValid()),
|
||||
)
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
token, err := jwt.ParseString(tokenString, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
claim := jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{clientID},
|
||||
Issuer: common.EnvConfig.AppURL,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
token.Header["kid"] = s.keyId
|
||||
|
||||
var privateKeyRaw any
|
||||
err := jwk.Export(s.privateKey, &privateKeyRaw)
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to export private key object: %w", err)
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
return token.SignedString(privateKeyRaw)
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return s.getPublicKeyRaw()
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
|
||||
@@ -325,17 +356,18 @@ func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) {
|
||||
return s.jwksEncoded, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) getPublicKeyRaw() (any, error) {
|
||||
pubKey, err := s.privateKey.PublicKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
// GetKeyAlg returns the algorithm of the key
|
||||
func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
|
||||
if len(s.jwksEncoded) == 0 {
|
||||
return nil, errors.New("key is not initialized")
|
||||
}
|
||||
var pubKeyRaw any
|
||||
err = jwk.Export(pubKey, &pubKeyRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to export raw public key: %w", err)
|
||||
|
||||
alg, ok := s.privateKey.Algorithm()
|
||||
if !ok || alg == nil {
|
||||
return nil, errors.New("failed to retrieve algorithm for key")
|
||||
}
|
||||
return pubKeyRaw, nil
|
||||
|
||||
return alg, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
||||
@@ -438,3 +470,28 @@ func generateRandomKeyID() (string, error) {
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
||||
func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
if !token.Has(IsAdminClaim) {
|
||||
return false, nil
|
||||
}
|
||||
var isAdmin bool
|
||||
err := token.Get(IsAdminClaim, &isAdmin)
|
||||
return isAdmin, err
|
||||
}
|
||||
|
||||
// SetIsAdmin sets the "isAdmin" claim in the token
|
||||
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
|
||||
// Only set if true
|
||||
if !isAdmin {
|
||||
return nil
|
||||
}
|
||||
return token.Set(IsAdminClaim, isAdmin)
|
||||
}
|
||||
|
||||
// SetAudienceString sets the "aud" claim with a value that is a string, and not an array
|
||||
// This is permitted by RFC 7519, and it's done here for backwards-compatibility
|
||||
func SetAudienceString(token jwt.Token, audience string) error {
|
||||
return token.Set(jwt.AudienceKey, audience)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user