mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-22 22:35:06 +00:00
fix: federated client credentials not working if sub ≠ client_id (#1342)
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
590e495c1d
commit
4d22c2dbcf
@@ -20,7 +20,7 @@ type AlreadyInUseError struct {
|
||||
func (e *AlreadyInUseError) Error() string {
|
||||
return e.Property + " is already in use"
|
||||
}
|
||||
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
|
||||
func (e *AlreadyInUseError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
func (e *AlreadyInUseError) Is(target error) bool {
|
||||
// Ignore the field property when checking if an error is of the type AlreadyInUseError
|
||||
@@ -31,26 +31,26 @@ func (e *AlreadyInUseError) Is(target error) bool {
|
||||
type SetupAlreadyCompletedError struct{}
|
||||
|
||||
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
|
||||
func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return 400 }
|
||||
func (e *SetupAlreadyCompletedError) HttpStatusCode() int { return http.StatusConflict }
|
||||
|
||||
type TokenInvalidOrExpiredError struct{}
|
||||
|
||||
func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" }
|
||||
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 }
|
||||
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type DeviceCodeInvalid struct{}
|
||||
|
||||
func (e *DeviceCodeInvalid) Error() string {
|
||||
return "one time access code must be used on the device it was generated for"
|
||||
}
|
||||
func (e *DeviceCodeInvalid) HttpStatusCode() int { return 400 }
|
||||
func (e *DeviceCodeInvalid) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type TokenInvalidError struct{}
|
||||
|
||||
func (e *TokenInvalidError) Error() string {
|
||||
return "Token is invalid"
|
||||
}
|
||||
func (e *TokenInvalidError) HttpStatusCode() int { return 400 }
|
||||
func (e *TokenInvalidError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type OidcMissingAuthorizationError struct{}
|
||||
|
||||
@@ -60,46 +60,51 @@ func (e *OidcMissingAuthorizationError) HttpStatusCode() int { return http.Statu
|
||||
type OidcGrantTypeNotSupportedError struct{}
|
||||
|
||||
func (e *OidcGrantTypeNotSupportedError) Error() string { return "grant type not supported" }
|
||||
func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcGrantTypeNotSupportedError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type OidcMissingClientCredentialsError struct{}
|
||||
|
||||
func (e *OidcMissingClientCredentialsError) Error() string { return "client id or secret not provided" }
|
||||
func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcMissingClientCredentialsError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type OidcClientSecretInvalidError struct{}
|
||||
|
||||
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
|
||||
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type OidcClientAssertionInvalidError struct{}
|
||||
|
||||
func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
|
||||
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type OidcInvalidAuthorizationCodeError struct{}
|
||||
|
||||
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
|
||||
func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type OidcClientNotFoundError struct{}
|
||||
|
||||
func (e *OidcClientNotFoundError) Error() string { return "client not found" }
|
||||
func (e *OidcClientNotFoundError) HttpStatusCode() int { return http.StatusNotFound }
|
||||
|
||||
type OidcMissingCallbackURLError struct{}
|
||||
|
||||
func (e *OidcMissingCallbackURLError) Error() string {
|
||||
return "unable to detect callback url, it might be necessary for an admin to fix this"
|
||||
}
|
||||
func (e *OidcMissingCallbackURLError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcMissingCallbackURLError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type OidcInvalidCallbackURLError struct{}
|
||||
|
||||
func (e *OidcInvalidCallbackURLError) Error() string {
|
||||
return "invalid callback URL, it might be necessary for an admin to fix this"
|
||||
}
|
||||
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 }
|
||||
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type FileTypeNotSupportedError struct{}
|
||||
|
||||
func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
|
||||
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }
|
||||
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type FileTooLargeError struct {
|
||||
MaxSize string
|
||||
|
||||
@@ -335,11 +335,13 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
|
||||
)
|
||||
creds.ClientID, creds.ClientSecret, ok = utils.OAuthClientBasicAuth(c.Request)
|
||||
if !ok {
|
||||
// If there's no basic auth, check if we have a bearer token
|
||||
// If there's no basic auth, check if we have a bearer token (used as client assertion)
|
||||
bearer, ok := utils.BearerAuth(c.Request)
|
||||
if ok {
|
||||
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
|
||||
creds.ClientAssertion = bearer
|
||||
// When using client assertions, client_id can be passed as a form field
|
||||
creds.ClientID = input.ClientID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -662,8 +664,13 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
|
||||
// Per RFC 8628 (OAuth 2.0 Device Authorization Grant), parameters for the device authorization request MUST be sent in the body of the POST request
|
||||
// Gin's "ShouldBind" by default reads from the query string too, so we need to reset all query string args before invoking ShouldBind
|
||||
c.Request.URL.RawQuery = ""
|
||||
|
||||
var input dto.OidcDeviceAuthorizationRequestDto
|
||||
if err := c.ShouldBind(&input); err != nil {
|
||||
err := c.ShouldBind(&input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -98,7 +98,8 @@ type OidcCreateTokensDto struct {
|
||||
}
|
||||
|
||||
type OidcIntrospectDto struct {
|
||||
Token string `form:"token" binding:"required"`
|
||||
Token string `form:"token" binding:"required"`
|
||||
ClientID string `form:"client_id"`
|
||||
}
|
||||
|
||||
type OidcUpdateAllowedUserGroupsDto struct {
|
||||
|
||||
@@ -1644,34 +1644,19 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
|
||||
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
|
||||
|
||||
// Determine the client ID based on the authentication method
|
||||
var clientID string
|
||||
switch {
|
||||
case isClientAssertion:
|
||||
// Extract client ID from the JWT assertion's 'sub' claim
|
||||
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
|
||||
if err != nil {
|
||||
slog.Error("Failed to extract client ID from assertion", "error", err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
case input.ClientID != "":
|
||||
// Use the provided client ID for other authentication methods
|
||||
clientID = input.ClientID
|
||||
default:
|
||||
if input.ClientID == "" {
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Load the OIDC client's configuration
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
First(&client, "id = ?", input.ClientID).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
slog.WarnContext(ctx, "Client not found", slog.String("client", input.ClientID))
|
||||
return nil, &common.OidcClientNotFoundError{}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1686,7 +1671,7 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
|
||||
return client, nil
|
||||
|
||||
// Next, check if we want to use client assertions from federated identities
|
||||
case isClientAssertion:
|
||||
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
|
||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "Invalid assertion for client", slog.String("client", client.ID), slog.Any("error", err))
|
||||
@@ -1783,36 +1768,20 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
|
||||
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
|
||||
_, err = jwt.Parse(assertion,
|
||||
jwt.WithValidate(true),
|
||||
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
|
||||
jwt.WithAudience(audience),
|
||||
jwt.WithSubject(subject),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client assertion is not valid: %w", err)
|
||||
return fmt.Errorf("client assertion could not be verified: %w", err)
|
||||
}
|
||||
|
||||
// If we're here, the assertion is valid
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
|
||||
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
|
||||
// Parse the JWT without verification first to get the claims
|
||||
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
|
||||
}
|
||||
|
||||
// Extract the subject claim which must be the client_id according to RFC 7523
|
||||
sub, ok := insecureToken.Subject()
|
||||
if !ok || sub == "" {
|
||||
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
|
||||
@@ -229,6 +229,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
Subject: federatedClient.ID,
|
||||
JWKS: federatedClientIssuer + "/jwks.json",
|
||||
},
|
||||
{
|
||||
Issuer: "federated-issuer-2",
|
||||
Audience: federatedClientAudience,
|
||||
Subject: "my-federated-client",
|
||||
JWKS: federatedClientIssuer + "/jwks.json",
|
||||
},
|
||||
{Issuer: federatedClientIssuerDefaults},
|
||||
},
|
||||
},
|
||||
@@ -461,6 +467,43 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
// Generate a token
|
||||
input := dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertion: string(signedToken),
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
}
|
||||
createdToken, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, token)
|
||||
|
||||
// Verify the token
|
||||
claims, err := s.jwtService.VerifyOAuthAccessToken(createdToken.AccessToken)
|
||||
require.NoError(t, err, "Failed to verify generated token")
|
||||
|
||||
// Check the claims
|
||||
subject, ok := claims.Subject()
|
||||
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||
assert.Equal(t, "client-"+federatedClient.ID, subject, "Token subject should match federated client ID with prefix")
|
||||
audience, ok := claims.Audience()
|
||||
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||
assert.Equal(t, []string{federatedClient.ID}, audience, "Audience should contain the federated client ID")
|
||||
})
|
||||
|
||||
t.Run("Succeeds with valid assertion and custom subject", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer("federated-issuer-2").
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject("my-federated-client").
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a token
|
||||
input := dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertion: string(signedToken),
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
}
|
||||
@@ -483,6 +526,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
t.Run("Fails with invalid assertion", func(t *testing.T) {
|
||||
input := dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientAssertion: "invalid.jwt.token",
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user