1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-23 13:40:08 +00:00

fix: federated client credentials not working if sub ≠ client_id (#1342)

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala
2026-03-01 09:48:20 -08:00
committed by GitHub
parent 590e495c1d
commit 4d22c2dbcf
28 changed files with 85 additions and 78 deletions

View File

@@ -1644,34 +1644,19 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
// Determine the client ID based on the authentication method
var clientID string
switch {
case isClientAssertion:
// Extract client ID from the JWT assertion's 'sub' claim
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
if err != nil {
slog.Error("Failed to extract client ID from assertion", "error", err)
return nil, &common.OidcClientAssertionInvalidError{}
}
case input.ClientID != "":
// Use the provided client ID for other authentication methods
clientID = input.ClientID
default:
if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
return nil, &common.OidcClientAssertionInvalidError{}
}
if errors.Is(err, gorm.ErrRecordNotFound) {
slog.WarnContext(ctx, "Client not found", slog.String("client", input.ClientID))
return nil, &common.OidcClientNotFoundError{}
} else if err != nil {
return nil, err
}
@@ -1686,7 +1671,7 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
return client, nil
// Next, check if we want to use client assertions from federated identities
case isClientAssertion:
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
if err != nil {
slog.WarnContext(ctx, "Invalid assertion for client", slog.String("client", client.ID), slog.Any("error", err))
@@ -1783,36 +1768,20 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
_, err = jwt.Parse(assertion,
jwt.WithValidate(true),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
jwt.WithAudience(audience),
jwt.WithSubject(subject),
)
if err != nil {
return fmt.Errorf("client assertion is not valid: %w", err)
return fmt.Errorf("client assertion could not be verified: %w", err)
}
// If we're here, the assertion is valid
return nil
}
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
// Parse the JWT without verification first to get the claims
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
if err != nil {
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
}
// Extract the subject claim which must be the client_id according to RFC 7523
sub, ok := insecureToken.Subject()
if !ok || sub == "" {
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
}
return sub, nil
}
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
tx := s.db.Begin()
defer func() {

View File

@@ -229,6 +229,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
Subject: federatedClient.ID,
JWKS: federatedClientIssuer + "/jwks.json",
},
{
Issuer: "federated-issuer-2",
Audience: federatedClientAudience,
Subject: "my-federated-client",
JWKS: federatedClientIssuer + "/jwks.json",
},
{Issuer: federatedClientIssuerDefaults},
},
},
@@ -461,6 +467,43 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Generate a token
input := dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertion: string(signedToken),
ClientAssertionType: ClientAssertionTypeJWTBearer,
}
createdToken, err := s.createTokenFromClientCredentials(t.Context(), input)
require.NoError(t, err)
require.NotNil(t, token)
// Verify the token
claims, err := s.jwtService.VerifyOAuthAccessToken(createdToken.AccessToken)
require.NoError(t, err, "Failed to verify generated token")
// Check the claims
subject, ok := claims.Subject()
_ = assert.True(t, ok, "User ID not found in token") &&
assert.Equal(t, "client-"+federatedClient.ID, subject, "Token subject should match federated client ID with prefix")
audience, ok := claims.Audience()
_ = assert.True(t, ok, "Audience not found in token") &&
assert.Equal(t, []string{federatedClient.ID}, audience, "Audience should contain the federated client ID")
})
t.Run("Succeeds with valid assertion and custom subject", func(t *testing.T) {
// Create JWT for federated identity
token, err := jwt.NewBuilder().
Issuer("federated-issuer-2").
Audience([]string{federatedClientAudience}).
Subject("my-federated-client").
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
require.NoError(t, err)
// Generate a token
input := dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertion: string(signedToken),
ClientAssertionType: ClientAssertionTypeJWTBearer,
}
@@ -483,6 +526,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with invalid assertion", func(t *testing.T) {
input := dto.OidcCreateTokensDto{
ClientID: confidentialClient.ID,
ClientAssertion: "invalid.jwt.token",
ClientAssertionType: ClientAssertionTypeJWTBearer,
}