mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-16 21:04:12 +00:00
fix: use transactions when operations involve multiple database queries (#392)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
c810fec8c4
commit
ec626ee797
@@ -1,34 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Reserved claims
|
||||
var reservedClaims = map[string]struct{}{
|
||||
"given_name": {},
|
||||
"family_name": {},
|
||||
"name": {},
|
||||
"email": {},
|
||||
"preferred_username": {},
|
||||
"groups": {},
|
||||
"sub": {},
|
||||
"iss": {},
|
||||
"aud": {},
|
||||
"exp": {},
|
||||
"iat": {},
|
||||
"auth_time": {},
|
||||
"nonce": {},
|
||||
"acr": {},
|
||||
"amr": {},
|
||||
"azp": {},
|
||||
"nbf": {},
|
||||
"jti": {},
|
||||
}
|
||||
|
||||
type CustomClaimService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
@@ -39,8 +19,29 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
|
||||
|
||||
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
|
||||
func isReservedClaim(key string) bool {
|
||||
_, ok := reservedClaims[key]
|
||||
return ok
|
||||
switch key {
|
||||
case "given_name",
|
||||
"family_name",
|
||||
"name",
|
||||
"email",
|
||||
"preferred_username",
|
||||
"groups",
|
||||
"sub",
|
||||
"iss",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"azp",
|
||||
"nbf",
|
||||
"jti":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// idType is the type of the id used to identify the user or user group
|
||||
@@ -52,28 +53,38 @@ const (
|
||||
)
|
||||
|
||||
// UpdateCustomClaimsForUser updates the custom claims for a user
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(UserID, userID, claims)
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(ctx, UserID, userID, claims)
|
||||
}
|
||||
|
||||
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
|
||||
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims)
|
||||
}
|
||||
|
||||
// updateCustomClaims updates the custom claims for a user or user group
|
||||
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
|
||||
// Check for duplicate keys in the claims slice
|
||||
seenKeys := make(map[string]bool)
|
||||
seenKeys := make(map[string]struct{})
|
||||
for _, claim := range claims {
|
||||
if seenKeys[claim.Key] {
|
||||
if _, ok := seenKeys[claim.Key]; ok {
|
||||
return nil, &common.DuplicateClaimError{Key: claim.Key}
|
||||
}
|
||||
seenKeys[claim.Key] = true
|
||||
seenKeys[claim.Key] = struct{}{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var existingClaims []model.CustomClaim
|
||||
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where(string(idType), value).
|
||||
Find(&existingClaims).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
err = s.db.Delete(&existingClaim).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&existingClaim).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,7 +128,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
}
|
||||
|
||||
// Update the claim if it already exists or create a new one
|
||||
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
|
||||
Assign(&customClaim).
|
||||
FirstOrCreate(&model.CustomClaim{}).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
|
||||
// Get the updated claims
|
||||
var updatedClaims []model.CustomClaim
|
||||
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where(string(idType)+" = ?", value).
|
||||
Find(&updatedClaims).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -129,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
|
||||
return updatedClaims, nil
|
||||
}
|
||||
|
||||
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
||||
var customClaims []model.CustomClaim
|
||||
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Find(&customClaims).
|
||||
Error
|
||||
return customClaims, err
|
||||
}
|
||||
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
||||
var customClaims []model.CustomClaim
|
||||
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("user_group_id = ?", userGroupID).
|
||||
Find(&customClaims).
|
||||
Error
|
||||
return customClaims, err
|
||||
}
|
||||
|
||||
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
|
||||
// prioritizing the user's claims over user group claims with the same key.
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
|
||||
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
|
||||
// Get the custom claims of the user
|
||||
customClaims, err := s.GetCustomClaimsForUser(userID)
|
||||
customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
|
||||
|
||||
// Get all user groups of the user
|
||||
var userGroupsOfUser []model.UserGroup
|
||||
err = s.db.Preload("CustomClaims").
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("CustomClaims").
|
||||
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
|
||||
Where("user_groups_users.user_id = ?", userID).
|
||||
Find(&userGroupsOfUser).Error
|
||||
@@ -186,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
|
||||
}
|
||||
|
||||
// GetSuggestions returns a list of custom claim keys that have been used before
|
||||
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
|
||||
func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
|
||||
var customClaimsKeys []string
|
||||
|
||||
err := s.db.Model(&model.CustomClaim{}).
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.CustomClaim{}).
|
||||
Group("key").
|
||||
Order("COUNT(*) DESC").
|
||||
Pluck("key", &customClaimsKeys).Error
|
||||
|
||||
Reference in New Issue
Block a user