mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-14 22:15:13 +00:00
feat: implement token introspection (#405)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us> Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
8d6c1e5c08
commit
7e5d16be9b
@@ -11,8 +11,11 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
@@ -37,6 +40,12 @@ const (
|
||||
// This may be omitted on non-admin tokens
|
||||
IsAdminClaim = "isAdmin"
|
||||
|
||||
// AccessTokenJWTType is the media type for access tokens
|
||||
AccessTokenJWTType = "AT+JWT"
|
||||
|
||||
// IDTokenJWTType is the media type for ID tokens
|
||||
IDTokenJWTType = "ID+JWT"
|
||||
|
||||
// Acceptable clock skew for verifying tokens
|
||||
clockSkew = time.Minute
|
||||
)
|
||||
@@ -247,8 +256,13 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
|
||||
}
|
||||
}
|
||||
|
||||
headers, err := CreateTokenTypeHeader(IDTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set token type: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
@@ -285,6 +299,11 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
err = VerifyTokenTypeHeader(tokenString, IDTokenJWTType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify token type: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -305,8 +324,13 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
headers, err := CreateTokenTypeHeader(AccessTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set token type: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
@@ -327,6 +351,11 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
err = VerifyTokenTypeHeader(tokenString, AccessTokenJWTType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify token type: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -481,6 +510,17 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
return isAdmin, err
|
||||
}
|
||||
|
||||
// CreateTokenTypeHeader creates a new JWS header with the given token type
|
||||
func CreateTokenTypeHeader(tokenType string) (jws.Headers, error) {
|
||||
headers := jws.NewHeaders()
|
||||
err := headers.Set(jws.TypeKey, tokenType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set token type: %w", err)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// SetIsAdmin sets the "isAdmin" claim in the token
|
||||
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
|
||||
// Only set if true
|
||||
@@ -495,3 +535,37 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error {
|
||||
func SetAudienceString(token jwt.Token, audience string) error {
|
||||
return token.Set(jwt.AudienceKey, audience)
|
||||
}
|
||||
|
||||
// VerifyTokenTypeHeader verifies that the "typ" header in the token matches the expected type
|
||||
func VerifyTokenTypeHeader(tokenBytes string, expectedTokenType string) error {
|
||||
// Parse the raw token string purely as a JWS message structure
|
||||
// We don't need to verify the signature at this stage, just inspect headers.
|
||||
msg, err := jws.Parse([]byte(tokenBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse token as JWS message: %w", err)
|
||||
}
|
||||
|
||||
// Get the list of signatures attached to the message. Usually just one for JWT.
|
||||
signatures := msg.Signatures()
|
||||
if len(signatures) == 0 {
|
||||
return errors.New("JWS message contains no signatures")
|
||||
}
|
||||
|
||||
protectedHeaders := signatures[0].ProtectedHeaders()
|
||||
if protectedHeaders == nil {
|
||||
return fmt.Errorf("JWS signature has no protected headers")
|
||||
}
|
||||
|
||||
// Retrieve the 'typ' header value from the PROTECTED headers.
|
||||
var typHeaderValue string
|
||||
err = protectedHeaders.Get(jws.TypeKey, &typHeaderValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token is missing required protected header '%s'", jws.TypeKey)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(typHeaderValue, expectedTokenType) {
|
||||
return fmt.Errorf("'%s' header mismatch: expected '%s', got '%s'", jws.TypeKey, expectedTokenType, typHeaderValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,12 +6,15 @@ import (
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
@@ -651,8 +654,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create headers with the specified type
|
||||
hdrs := jws.NewHeaders()
|
||||
err = hdrs.Set(jws.TypeKey, "ID+JWT")
|
||||
require.NoError(t, err, "Failed to set header type")
|
||||
|
||||
// Sign the token
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
tokenString := string(signed)
|
||||
|
||||
@@ -1172,6 +1180,63 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifyTokenTypeHeader(t *testing.T) {
|
||||
mockConfig := &AppConfigService{}
|
||||
tempDir := t.TempDir()
|
||||
// Helper function to create a token with a specific type header
|
||||
createTokenWithType := func(tokenType string) (string, error) {
|
||||
// Create a simple JWT token
|
||||
token := jwt.New()
|
||||
err := token.Set("test_claim", "test_value")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim: %w", err)
|
||||
}
|
||||
|
||||
// Create headers with the specified type
|
||||
hdrs := jws.NewHeaders()
|
||||
if tokenType != "" {
|
||||
err = hdrs.Set(jws.TypeKey, tokenType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set type header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sign the token with the headers
|
||||
service := &JwtService{}
|
||||
err = service.init(mockConfig, tempDir)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
t.Run("succeeds when token type matches expected type", func(t *testing.T) {
|
||||
// Create a token with "JWT" type
|
||||
tokenString, err := createTokenWithType("JWT")
|
||||
require.NoError(t, err, "Failed to create test token")
|
||||
|
||||
// Verify the token type
|
||||
err = VerifyTokenTypeHeader(tokenString, "JWT")
|
||||
assert.NoError(t, err, "Should accept token with matching type")
|
||||
})
|
||||
|
||||
t.Run("fails when token type doesn't match expected type", func(t *testing.T) {
|
||||
// Create a token with "AT+JWT" type
|
||||
tokenString, err := createTokenWithType("AT+JWT")
|
||||
require.NoError(t, err, "Failed to create test token")
|
||||
|
||||
// Verify the token with different expected type
|
||||
err = VerifyTokenTypeHeader(tokenString, "JWT")
|
||||
require.Error(t, err, "Should reject token with non-matching type")
|
||||
assert.Contains(t, err.Error(), "header mismatch: expected 'JWT', got 'AT+JWT'")
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
@@ -356,6 +358,93 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
||||
return accessToken, newRefreshToken, 3600, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Get the client to check if we are authorized.
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return introspectDto, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
|
||||
// Verify the client secret. This endpoint may not be used by public clients.
|
||||
if client.IsPublic {
|
||||
return introspectDto, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
|
||||
return introspectDto, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
|
||||
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ParseError()) {
|
||||
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
|
||||
return s.introspectRefreshToken(tokenString)
|
||||
}
|
||||
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "access_token"
|
||||
if token.Has("scope") {
|
||||
var asString string
|
||||
var asStrings []string
|
||||
if err := token.Get("scope", &asString); err == nil {
|
||||
introspectDto.Scope = asString
|
||||
} else if err := token.Get("scope", &asStrings); err == nil {
|
||||
introspectDto.Scope = strings.Join(asStrings, " ")
|
||||
}
|
||||
}
|
||||
if expiration, hasExpiration := token.Expiration(); hasExpiration {
|
||||
introspectDto.Expiration = expiration.Unix()
|
||||
}
|
||||
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
|
||||
introspectDto.IssuedAt = issuedAt.Unix()
|
||||
}
|
||||
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
|
||||
introspectDto.NotBefore = notBefore.Unix()
|
||||
}
|
||||
if subject, hasSubject := token.Subject(); hasSubject {
|
||||
introspectDto.Subject = subject
|
||||
}
|
||||
if audience, hasAudience := token.Audience(); hasAudience {
|
||||
introspectDto.Audience = audience
|
||||
}
|
||||
if issuer, hasIssuer := token.Issuer(); hasIssuer {
|
||||
introspectDto.Issuer = issuer
|
||||
}
|
||||
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
|
||||
introspectDto.Identifier = identifier
|
||||
}
|
||||
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectRefreshToken(refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = s.db.Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
return introspectDto, err
|
||||
}
|
||||
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "refresh_token"
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
|
||||
return s.getClientInternal(ctx, clientID, s.db)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user