1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-14 21:40:06 +00:00

feat: oidc client data preview (#624)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-06-09 10:46:03 -05:00
committed by GitHub
parent 61bf14225b
commit c111b79147
12 changed files with 626 additions and 113 deletions

View File

@@ -234,7 +234,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
return token, nil
}
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
// BuildIDToken creates an ID token with all claims
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
@@ -242,33 +243,43 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
return nil, fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, IDTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
for k, v := range userClaims {
err = token.Set(k, v)
if err != nil {
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
return nil, fmt.Errorf("failed to set claim '%s': %w", k, err)
}
}
if nonce != "" {
err = token.Set("nonce", nonce)
if err != nil {
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
return nil, fmt.Errorf("failed to set claim 'nonce': %w", err)
}
}
return token, nil
}
// GenerateIDToken creates and signs an ID token
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
token, err := s.BuildIDToken(userClaims, clientID, nonce)
if err != nil {
return "", err
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
@@ -311,7 +322,8 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
return token, nil
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
// BuildOauthAccessToken creates an OAuth access token with all claims
func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
@@ -320,17 +332,27 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
return nil, fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, OAuthAccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
return token, nil
}
// GenerateOauthAccessToken creates and signs an OAuth access token
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
token, err := s.BuildOauthAccessToken(user, clientID)
if err != nil {
return "", err
}
alg, _ := s.privateKey.Algorithm()

View File

@@ -841,97 +841,6 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
return nil
}
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
err := tx.
WithContext(ctx).
Preload("User.UserGroups").
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
Error
if err != nil {
return nil, err
}
user := authorizedOidcClient.User
scopes := strings.Split(authorizedOidcClient.Scope, " ")
claims := map[string]interface{}{
"sub": user.ID,
}
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
}
if slices.Contains(scopes, "groups") {
userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups {
userGroups[i] = group.Name
}
claims["groups"] = userGroups
}
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if slices.Contains(scopes, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
if err != nil {
return nil, err
}
for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string
var jsonValue interface{}
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if err == nil {
// It's JSON so we store it as an object
claims[customClaim.Key] = jsonValue
} else {
// Marshalling failed, so we store it as a string
claims[customClaim.Key] = customClaim.Value
}
}
}
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
}
return claims, nil
}
func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
tx := s.db.Begin()
defer func() {
@@ -1519,3 +1428,168 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
// If we're here, the assertion is valid
return nil
}
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
client, err := s.getClientInternal(ctx, clientID, tx)
if err != nil {
return nil, err
}
var user model.User
err = tx.
WithContext(ctx).
Preload("UserGroups").
First(&user, "id = ?", userID).
Error
if err != nil {
return nil, err
}
if !s.IsUserGroupAllowedToAuthorize(user, client) {
return nil, &common.OidcAccessDeniedError{}
}
dummyAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: clientID,
Scope: scopes,
User: user,
}
userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &dummyAuthorizedClient, tx)
if err != nil {
return nil, err
}
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
if err != nil {
return nil, err
}
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
if err != nil {
return nil, err
}
idTokenPayload, err := utils.GetClaimsFromToken(idToken)
if err != nil {
return nil, err
}
accessTokenPayload, err := utils.GetClaimsFromToken(accessToken)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return &dto.OidcClientPreviewDto{
IdToken: idTokenPayload,
AccessToken: accessTokenPayload,
UserInfo: userClaims,
}, nil
}
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
err := tx.
WithContext(ctx).
Preload("User.UserGroups").
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
Error
if err != nil {
return nil, err
}
return s.getUserClaimsFromAuthorizedClient(ctx, &authorizedOidcClient, tx)
}
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
user := authorizedClient.User
scopes := strings.Split(authorizedClient.Scope, " ")
claims := map[string]interface{}{
"sub": user.ID,
}
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
}
if slices.Contains(scopes, "groups") {
userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups {
userGroups[i] = group.Name
}
claims["groups"] = userGroups
}
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if slices.Contains(scopes, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
if err != nil {
return nil, err
}
for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string
var jsonValue interface{}
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if err == nil {
// It's JSON, so we store it as an object
claims[customClaim.Key] = jsonValue
} else {
// Marshaling failed, so we store it as a string
claims[customClaim.Key] = customClaim.Value
}
}
}
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
}
return claims, nil
}