1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 20:50:07 +00:00

fix: federated client credentials not working if sub ≠ client_id (#1342)

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala
2026-03-01 09:48:20 -08:00
committed by GitHub
parent 590e495c1d
commit 4d22c2dbcf
28 changed files with 85 additions and 78 deletions

View File

@@ -20,7 +20,7 @@ type AlreadyInUseError struct {
func (e *AlreadyInUseError) Error() string {
return e.Property + " is already in use"
}
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
func (e *AlreadyInUseError) HttpStatusCode() int { return http.StatusBadRequest }
func (e *AlreadyInUseError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AlreadyInUseError
@@ -31,26 +31,26 @@ func (e *AlreadyInUseError) Is(target error) bool {
type SetupAlreadyCompletedError struct{}
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return 400 }
func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return http.StatusConflict }
type TokenInvalidOrExpiredError struct{}
func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" }
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 }
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return http.StatusUnauthorized }
type DeviceCodeInvalid struct{}
func (e *DeviceCodeInvalid) Error() string {
return "one time access code must be used on the device it was generated for"
}
func (e *DeviceCodeInvalid) HttpStatusCode() int { return 400 }
func (e *DeviceCodeInvalid) HttpStatusCode() int { return http.StatusUnauthorized }
type TokenInvalidError struct{}
func (e *TokenInvalidError) Error() string {
return "Token is invalid"
}
func (e *TokenInvalidError) HttpStatusCode() int { return 400 }
func (e *TokenInvalidError) HttpStatusCode() int { return http.StatusUnauthorized }
type OidcMissingAuthorizationError struct{}
@@ -60,46 +60,51 @@ func (e *OidcMissingAuthorizationError) HttpStatusCode() int { return http.Statu
type OidcGrantTypeNotSupportedError struct{}
func (e *OidcGrantTypeNotSupportedError) Error() string { return "grant type not supported" }
func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return 400 }
func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcMissingClientCredentialsError struct{}
func (e *OidcMissingClientCredentialsError) Error() string { return "client id or secret not provided" }
func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return 400 }
func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcClientSecretInvalidError struct{}
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return http.StatusUnauthorized }
type OidcClientAssertionInvalidError struct{}
func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return http.StatusUnauthorized }
type OidcInvalidAuthorizationCodeError struct{}
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 }
func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcClientNotFoundError struct{}
func (e *OidcClientNotFoundError) Error() string { return "client not found" }
func (e *OidcClientNotFoundError) HttpStatusCode() int { return http.StatusNotFound }
type OidcMissingCallbackURLError struct{}
func (e *OidcMissingCallbackURLError) Error() string {
return "unable to detect callback url, it might be necessary for an admin to fix this"
}
func (e *OidcMissingCallbackURLError) HttpStatusCode() int { return 400 }
func (e *OidcMissingCallbackURLError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcInvalidCallbackURLError struct{}
func (e *OidcInvalidCallbackURLError) Error() string {
return "invalid callback URL, it might be necessary for an admin to fix this"
}
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 }
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return http.StatusBadRequest }
type FileTypeNotSupportedError struct{}
func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return http.StatusBadRequest }
type FileTooLargeError struct {
MaxSize string

View File

@@ -335,11 +335,13 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
)
creds.ClientID, creds.ClientSecret, ok = utils.OAuthClientBasicAuth(c.Request)
if !ok {
// If there's no basic auth, check if we have a bearer token
// If there's no basic auth, check if we have a bearer token (used as client assertion)
bearer, ok := utils.BearerAuth(c.Request)
if ok {
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
creds.ClientAssertion = bearer
// When using client assertions, client_id can be passed as a form field
creds.ClientID = input.ClientID
}
}
@@ -662,8 +664,13 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
}
func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
// Per RFC 8628 (OAuth 2.0 Device Authorization Grant), parameters for the device authorization request MUST be sent in the body of the POST request
// Gin's "ShouldBind" by default reads from the query string too, so we need to reset all query string args before invoking ShouldBind
c.Request.URL.RawQuery = ""
var input dto.OidcDeviceAuthorizationRequestDto
if err := c.ShouldBind(&input); err != nil {
err := c.ShouldBind(&input)
if err != nil {
_ = c.Error(err)
return
}

View File

@@ -98,7 +98,8 @@ type OidcCreateTokensDto struct {
}
type OidcIntrospectDto struct {
Token string `form:"token" binding:"required"`
Token string `form:"token" binding:"required"`
ClientID string `form:"client_id"`
}
type OidcUpdateAllowedUserGroupsDto struct {

View File

@@ -1644,34 +1644,19 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
// Determine the client ID based on the authentication method
var clientID string
switch {
case isClientAssertion:
// Extract client ID from the JWT assertion's 'sub' claim
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
if err != nil {
slog.Error("Failed to extract client ID from assertion", "error", err)
return nil, &common.OidcClientAssertionInvalidError{}
}
case input.ClientID != "":
// Use the provided client ID for other authentication methods
clientID = input.ClientID
default:
if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
return nil, &common.OidcClientAssertionInvalidError{}
}
if errors.Is(err, gorm.ErrRecordNotFound) {
slog.WarnContext(ctx, "Client not found", slog.String("client", input.ClientID))
return nil, &common.OidcClientNotFoundError{}
} else if err != nil {
return nil, err
}
@@ -1686,7 +1671,7 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
return client, nil
// Next, check if we want to use client assertions from federated identities
case isClientAssertion:
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
if err != nil {
slog.WarnContext(ctx, "Invalid assertion for client", slog.String("client", client.ID), slog.Any("error", err))
@@ -1783,36 +1768,20 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
_, err = jwt.Parse(assertion,
jwt.WithValidate(true),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
jwt.WithAudience(audience),
jwt.WithSubject(subject),
)
if err != nil {
return fmt.Errorf("client assertion is not valid: %w", err)
return fmt.Errorf("client assertion could not be verified: %w", err)
}
// If we're here, the assertion is valid
return nil
}
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
// Parse the JWT without verification first to get the claims
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
if err != nil {
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
}
// Extract the subject claim which must be the client_id according to RFC 7523
sub, ok := insecureToken.Subject()
if !ok || sub == "" {
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
}
return sub, nil
}
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
tx := s.db.Begin()
defer func() {

View File

@@ -229,6 +229,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
Subject: federatedClient.ID,
JWKS: federatedClientIssuer + "/jwks.json",
},
{
Issuer: "federated-issuer-2",
Audience: federatedClientAudience,
Subject: "my-federated-client",
JWKS: federatedClientIssuer + "/jwks.json",
},
{Issuer: federatedClientIssuerDefaults},
},
},
@@ -461,6 +467,43 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Generate a token
input := dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertion: string(signedToken),
ClientAssertionType: ClientAssertionTypeJWTBearer,
}
createdToken, err := s.createTokenFromClientCredentials(t.Context(), input)
require.NoError(t, err)
require.NotNil(t, token)
// Verify the token
claims, err := s.jwtService.VerifyOAuthAccessToken(createdToken.AccessToken)
require.NoError(t, err, "Failed to verify generated token")
// Check the claims
subject, ok := claims.Subject()
_ = assert.True(t, ok, "User ID not found in token") &&
assert.Equal(t, "client-"+federatedClient.ID, subject, "Token subject should match federated client ID with prefix")
audience, ok := claims.Audience()
_ = assert.True(t, ok, "Audience not found in token") &&
assert.Equal(t, []string{federatedClient.ID}, audience, "Audience should contain the federated client ID")
})
t.Run("Succeeds with valid assertion and custom subject", func(t *testing.T) {
// Create JWT for federated identity
token, err := jwt.NewBuilder().
Issuer("federated-issuer-2").
Audience([]string{federatedClientAudience}).
Subject("my-federated-client").
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
require.NoError(t, err)
// Generate a token
input := dto.OidcCreateTokensDto{
ClientID: federatedClient.ID,
ClientAssertion: string(signedToken),
ClientAssertionType: ClientAssertionTypeJWTBearer,
}
@@ -483,6 +526,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with invalid assertion", func(t *testing.T) {
input := dto.OidcCreateTokensDto{
ClientID: confidentialClient.ID,
ClientAssertion: "invalid.jwt.token",
ClientAssertionType: ClientAssertionTypeJWTBearer,
}