1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-15 15:10:08 +00:00

feat: JWT bearer assertions for client authentication (#566)

Co-authored-by: Kyle Mendell <ksm@ofkm.us>
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2025-06-06 03:23:51 -07:00
committed by GitHub
parent 035b2c022b
commit 05bfe00924
38 changed files with 1464 additions and 293 deletions

View File

@@ -3,18 +3,25 @@ package service
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
"mime/multipart"
"net/http"
"os"
"regexp"
"slices"
"strings"
"time"
"github.com/lestrrat-go/httprc/v3"
"github.com/lestrrat-go/httprc/v3/errsink"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
@@ -31,6 +38,8 @@ const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeRefreshToken = "refresh_token"
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
)
type OidcService struct {
@@ -39,16 +48,61 @@ type OidcService struct {
appConfigService *AppConfigService
auditLogService *AuditLogService
customClaimService *CustomClaimService
httpClient *http.Client
jwkCache *jwk.Cache
}
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
return &OidcService{
func NewOidcService(
ctx context.Context,
db *gorm.DB,
jwtService *JwtService,
appConfigService *AppConfigService,
auditLogService *AuditLogService,
customClaimService *CustomClaimService,
) (s *OidcService, err error) {
s = &OidcService{
db: db,
jwtService: jwtService,
appConfigService: appConfigService,
auditLogService: auditLogService,
customClaimService: customClaimService,
}
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
s.jwkCache, err = s.getJWKCache(ctx)
if err != nil {
return nil, err
}
return s, nil
}
func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
// We need to create a custom HTTP client to set a timeout.
client := s.httpClient
if client == nil {
client = &http.Client{
Timeout: 20 * time.Second,
}
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
// Indicates a development-time error
panic("Default transport is not of type *http.Transport")
}
transport := defaultTransport.Clone()
transport.TLSClientConfig.MinVersion = tls.VersionTLS12
client.Transport = transport
}
// Create the JWKS cache
return jwk.NewCache(ctx,
httprc.NewClient(
httprc.WithErrorSink(errsink.NewSlog(slog.Default())),
httprc.WithHTTPClient(client),
),
)
}
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
@@ -198,7 +252,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -279,7 +333,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -357,7 +411,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -420,7 +474,10 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: clientID,
ClientSecret: clientSecret,
})
if err != nil {
return introspectDto, err
}
@@ -440,33 +497,35 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
introspectDto.Active = true
introspectDto.TokenType = "access_token"
if token.Has("scope") {
var asString string
var asStrings []string
var (
asString string
asStrings []string
)
if err := token.Get("scope", &asString); err == nil {
introspectDto.Scope = asString
} else if err := token.Get("scope", &asStrings); err == nil {
introspectDto.Scope = strings.Join(asStrings, " ")
}
}
if expiration, hasExpiration := token.Expiration(); hasExpiration {
if expiration, ok := token.Expiration(); ok {
introspectDto.Expiration = expiration.Unix()
}
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
if issuedAt, ok := token.IssuedAt(); ok {
introspectDto.IssuedAt = issuedAt.Unix()
}
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
if notBefore, ok := token.NotBefore(); ok {
introspectDto.NotBefore = notBefore.Unix()
}
if subject, hasSubject := token.Subject(); hasSubject {
if subject, ok := token.Subject(); ok {
introspectDto.Subject = subject
}
if audience, hasAudience := token.Audience(); hasAudience {
if audience, ok := token.Audience(); ok {
introspectDto.Audience = audience
}
if issuer, hasIssuer := token.Issuer(); hasIssuer {
if issuer, ok := token.Issuer(); ok {
introspectDto.Issuer = issuer
}
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
if identifier, ok := token.JwtID(); ok {
introspectDto.Identifier = identifier
}
@@ -542,13 +601,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{
Name: input.Name,
CallbackURLs: input.CallbackURLs,
LogoutCallbackURLs: input.LogoutCallbackURLs,
CreatedByID: userID,
IsPublic: input.IsPublic,
PkceEnabled: input.PkceEnabled,
CreatedByID: userID,
}
updateOIDCClientModelFromDto(&client, &input)
err := s.db.
WithContext(ctx).
@@ -577,11 +632,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return model.OidcClient{}, err
}
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.LogoutCallbackURLs = input.LogoutCallbackURLs
client.IsPublic = input.IsPublic
client.PkceEnabled = input.IsPublic || input.PkceEnabled
updateOIDCClientModelFromDto(&client, &input)
err = tx.
WithContext(ctx).
@@ -599,6 +650,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return client, nil
}
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
// Base fields
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.LogoutCallbackURLs = input.LogoutCallbackURLs
client.IsPublic = input.IsPublic
// PKCE is required for public clients
client.PkceEnabled = input.IsPublic || input.PkceEnabled
// Credentials
if len(input.Credentials.FederatedIdentities) > 0 {
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
for i, fi := range input.Credentials.FederatedIdentities {
client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{
Issuer: fi.Issuer,
Audience: fi.Audience,
Subject: fi.Subject,
JWKS: fi.JWKS,
}
}
}
}
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
var client model.OidcClient
err := s.db.
@@ -1079,7 +1153,10 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
}
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: input.ClientID,
ClientSecret: input.ClientSecret,
})
if err != nil {
return nil, err
}
@@ -1305,33 +1382,140 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
return err
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
// First, ensure we have a valid client ID
if clientID == "" {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
var client model.OidcClient
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
return model.OidcClient{}, err
return nil, err
}
// If we have a client secret, we validate it
// Otherwise, we require the client to be public
if clientSecret != "" {
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
// We have 3 options
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
switch {
// First, if we have a client secret, we validate it
case input.ClientSecret != "":
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret))
if err != nil {
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
return nil, &common.OidcClientSecretInvalidError{}
}
return &client, nil
// Next, check if we want to use client assertions from federated identities
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
if err != nil {
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
return nil, &common.OidcClientAssertionInvalidError{}
}
return &client, nil
// There's no credentials
// This is allowed only if the client is public
case client.IsPublic:
return &client, nil
// If we're here, we have no credentials AND the client is not public, so credentials are required
default:
return nil, &common.OidcMissingClientCredentialsError{}
}
}
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
// Check if we have already registered the URL
if !s.jwkCache.IsRegistered(ctx, url) {
// We set a timeout because otherwise Register will keep trying in case of errors
registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second)
defer registerCancel()
// We need to register the URL
err = s.jwkCache.Register(
registerCtx,
url,
jwk.WithMaxInterval(24*time.Hour),
jwk.WithMinInterval(15*time.Minute),
jwk.WithWaitReady(true),
)
// In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error
if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) {
return nil, fmt.Errorf("failed to register JWK set: %w", err)
}
return client, nil
} else if !client.IsPublic {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
}
return client, nil
jwks, err := s.jwkCache.CachedSet(url)
if err != nil {
return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
}
return jwks, nil
}
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
// First, parse the assertion JWT, without validating it, to check the issuer
assertion := []byte(input.ClientAssertion)
insecureToken, err := jwt.ParseInsecure(assertion)
if err != nil {
return fmt.Errorf("failed to parse client assertion JWT: %w", err)
}
issuer, _ := insecureToken.Issuer()
if issuer == "" {
return errors.New("client assertion does not contain an issuer claim")
}
// Ensure that this client is federated with the one that issued the token
ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer)
if !ok {
return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer)
}
// Get the JWK set for the issuer
jwksURL := ocfi.JWKS
if jwksURL == "" {
// Default URL is from the issuer
if strings.HasSuffix(issuer, "/") {
jwksURL = issuer + ".well-known/jwks.json"
} else {
jwksURL = issuer + "/.well-known/jwks.json"
}
}
jwks, err := s.jwkSetForURL(ctx, jwksURL)
if err != nil {
return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err)
}
// Set default audience and subject if missing
audience := ocfi.Audience
if audience == "" {
// Default to the Pocket ID's URL
audience = common.EnvConfig.AppURL
}
subject := ocfi.Subject
if subject == "" {
// Default to the client ID, per RFC 7523
subject = client.ID
}
// Now re-parse the token with proper validation
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
_, err = jwt.Parse(assertion,
jwt.WithValidate(true),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
jwt.WithAudience(audience),
jwt.WithSubject(subject),
)
if err != nil {
return fmt.Errorf("client assertion is not valid: %w", err)
}
// If we're here, the assertion is valid
return nil
}