mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-15 15:10:08 +00:00
feat: JWT bearer assertions for client authentication (#566)
Co-authored-by: Kyle Mendell <ksm@ofkm.us> Co-authored-by: Kyle Mendell <kmendell@ofkm.us> Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
035b2c022b
commit
05bfe00924
@@ -3,18 +3,25 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/httprc/v3"
|
||||
"github.com/lestrrat-go/httprc/v3/errsink"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
@@ -31,6 +38,8 @@ const (
|
||||
GrantTypeAuthorizationCode = "authorization_code"
|
||||
GrantTypeRefreshToken = "refresh_token"
|
||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
@@ -39,16 +48,61 @@ type OidcService struct {
|
||||
appConfigService *AppConfigService
|
||||
auditLogService *AuditLogService
|
||||
customClaimService *CustomClaimService
|
||||
|
||||
httpClient *http.Client
|
||||
jwkCache *jwk.Cache
|
||||
}
|
||||
|
||||
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
|
||||
return &OidcService{
|
||||
func NewOidcService(
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
jwtService *JwtService,
|
||||
appConfigService *AppConfigService,
|
||||
auditLogService *AuditLogService,
|
||||
customClaimService *CustomClaimService,
|
||||
) (s *OidcService, err error) {
|
||||
s = &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
appConfigService: appConfigService,
|
||||
auditLogService: auditLogService,
|
||||
customClaimService: customClaimService,
|
||||
}
|
||||
|
||||
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
||||
s.jwkCache, err = s.getJWKCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
|
||||
// We need to create a custom HTTP client to set a timeout.
|
||||
client := s.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
|
||||
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
// Indicates a development-time error
|
||||
panic("Default transport is not of type *http.Transport")
|
||||
}
|
||||
transport := defaultTransport.Clone()
|
||||
transport.TLSClientConfig.MinVersion = tls.VersionTLS12
|
||||
client.Transport = transport
|
||||
}
|
||||
|
||||
// Create the JWKS cache
|
||||
return jwk.NewCache(ctx,
|
||||
httprc.NewClient(
|
||||
httprc.WithErrorSink(errsink.NewSlog(slog.Default())),
|
||||
httprc.WithHTTPClient(client),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
@@ -198,7 +252,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -279,7 +333,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -357,7 +411,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -420,7 +474,10 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
})
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
@@ -440,33 +497,35 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "access_token"
|
||||
if token.Has("scope") {
|
||||
var asString string
|
||||
var asStrings []string
|
||||
var (
|
||||
asString string
|
||||
asStrings []string
|
||||
)
|
||||
if err := token.Get("scope", &asString); err == nil {
|
||||
introspectDto.Scope = asString
|
||||
} else if err := token.Get("scope", &asStrings); err == nil {
|
||||
introspectDto.Scope = strings.Join(asStrings, " ")
|
||||
}
|
||||
}
|
||||
if expiration, hasExpiration := token.Expiration(); hasExpiration {
|
||||
if expiration, ok := token.Expiration(); ok {
|
||||
introspectDto.Expiration = expiration.Unix()
|
||||
}
|
||||
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
|
||||
if issuedAt, ok := token.IssuedAt(); ok {
|
||||
introspectDto.IssuedAt = issuedAt.Unix()
|
||||
}
|
||||
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
|
||||
if notBefore, ok := token.NotBefore(); ok {
|
||||
introspectDto.NotBefore = notBefore.Unix()
|
||||
}
|
||||
if subject, hasSubject := token.Subject(); hasSubject {
|
||||
if subject, ok := token.Subject(); ok {
|
||||
introspectDto.Subject = subject
|
||||
}
|
||||
if audience, hasAudience := token.Audience(); hasAudience {
|
||||
if audience, ok := token.Audience(); ok {
|
||||
introspectDto.Audience = audience
|
||||
}
|
||||
if issuer, hasIssuer := token.Issuer(); hasIssuer {
|
||||
if issuer, ok := token.Issuer(); ok {
|
||||
introspectDto.Issuer = issuer
|
||||
}
|
||||
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
|
||||
if identifier, ok := token.JwtID(); ok {
|
||||
introspectDto.Identifier = identifier
|
||||
}
|
||||
|
||||
@@ -542,13 +601,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
||||
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
LogoutCallbackURLs: input.LogoutCallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.PkceEnabled,
|
||||
CreatedByID: userID,
|
||||
}
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
@@ -577,11 +632,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
@@ -599,6 +650,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
|
||||
// Base fields
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
// PKCE is required for public clients
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
|
||||
// Credentials
|
||||
if len(input.Credentials.FederatedIdentities) > 0 {
|
||||
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
|
||||
for i, fi := range input.Credentials.FederatedIdentities {
|
||||
client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{
|
||||
Issuer: fi.Issuer,
|
||||
Audience: fi.Audience,
|
||||
Subject: fi.Subject,
|
||||
JWKS: fi.JWKS,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
var client model.OidcClient
|
||||
err := s.db.
|
||||
@@ -1079,7 +1153,10 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: input.ClientID,
|
||||
ClientSecret: input.ClientSecret,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1305,33 +1382,140 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if clientID == "" {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
if input.ClientID == "" {
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Load the OIDC client's configuration
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
First(&client, "id = ?", input.ClientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we have a client secret, we validate it
|
||||
// Otherwise, we require the client to be public
|
||||
if clientSecret != "" {
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
// We have 3 options
|
||||
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
|
||||
switch {
|
||||
// First, if we have a client secret, we validate it
|
||||
case input.ClientSecret != "":
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret))
|
||||
if err != nil {
|
||||
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
||||
return nil, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
return &client, nil
|
||||
|
||||
// Next, check if we want to use client assertions from federated identities
|
||||
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
|
||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
|
||||
if err != nil {
|
||||
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
return &client, nil
|
||||
|
||||
// There's no credentials
|
||||
// This is allowed only if the client is public
|
||||
case client.IsPublic:
|
||||
return &client, nil
|
||||
|
||||
// If we're here, we have no credentials AND the client is not public, so credentials are required
|
||||
default:
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
|
||||
// Check if we have already registered the URL
|
||||
if !s.jwkCache.IsRegistered(ctx, url) {
|
||||
// We set a timeout because otherwise Register will keep trying in case of errors
|
||||
registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer registerCancel()
|
||||
// We need to register the URL
|
||||
err = s.jwkCache.Register(
|
||||
registerCtx,
|
||||
url,
|
||||
jwk.WithMaxInterval(24*time.Hour),
|
||||
jwk.WithMinInterval(15*time.Minute),
|
||||
jwk.WithWaitReady(true),
|
||||
)
|
||||
// In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error
|
||||
if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) {
|
||||
return nil, fmt.Errorf("failed to register JWK set: %w", err)
|
||||
}
|
||||
return client, nil
|
||||
} else if !client.IsPublic {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
jwks, err := s.jwkCache.CachedSet(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
|
||||
}
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
|
||||
// First, parse the assertion JWT, without validating it, to check the issuer
|
||||
assertion := []byte(input.ClientAssertion)
|
||||
insecureToken, err := jwt.ParseInsecure(assertion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse client assertion JWT: %w", err)
|
||||
}
|
||||
|
||||
issuer, _ := insecureToken.Issuer()
|
||||
if issuer == "" {
|
||||
return errors.New("client assertion does not contain an issuer claim")
|
||||
}
|
||||
|
||||
// Ensure that this client is federated with the one that issued the token
|
||||
ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer)
|
||||
if !ok {
|
||||
return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer)
|
||||
}
|
||||
|
||||
// Get the JWK set for the issuer
|
||||
jwksURL := ocfi.JWKS
|
||||
if jwksURL == "" {
|
||||
// Default URL is from the issuer
|
||||
if strings.HasSuffix(issuer, "/") {
|
||||
jwksURL = issuer + ".well-known/jwks.json"
|
||||
} else {
|
||||
jwksURL = issuer + "/.well-known/jwks.json"
|
||||
}
|
||||
}
|
||||
jwks, err := s.jwkSetForURL(ctx, jwksURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err)
|
||||
}
|
||||
|
||||
// Set default audience and subject if missing
|
||||
audience := ocfi.Audience
|
||||
if audience == "" {
|
||||
// Default to the Pocket ID's URL
|
||||
audience = common.EnvConfig.AppURL
|
||||
}
|
||||
subject := ocfi.Subject
|
||||
if subject == "" {
|
||||
// Default to the client ID, per RFC 7523
|
||||
subject = client.ID
|
||||
}
|
||||
|
||||
// Now re-parse the token with proper validation
|
||||
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
|
||||
_, err = jwt.Parse(assertion,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
|
||||
jwt.WithAudience(audience),
|
||||
jwt.WithSubject(subject),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client assertion is not valid: %w", err)
|
||||
}
|
||||
|
||||
// If we're here, the assertion is valid
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user