diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index f8db7d6d..e37bbe0b 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "gorm.io/gorm/clause" "log" "mime/multipart" "os" @@ -94,24 +95,8 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie // If the user has not authorized the client, create a new authorization in the database if !hasAuthorizedClient { - userAuthorizedClient := model.UserAuthorizedOidcClient{ - UserID: userID, - ClientID: input.ClientID, - Scope: input.Scope, - } - - err = tx. - WithContext(ctx). - Create(&userAuthorizedClient). - Error - if errors.Is(err, gorm.ErrDuplicatedKey) { - // The client has already been authorized but with a different scope so we need to update the scope - if err := tx. - WithContext(ctx). - Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil { - return "", "", err - } - } else if err != nil { + err := s.createAuthorizedClientInternal(ctx, userID, input.ClientID, input.Scope, tx) + if err != nil { return "", "", err } } @@ -201,7 +186,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, tx.Rollback() }() - _, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx) + _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) if err != nil { return "", "", "", 0, err } @@ -269,7 +254,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code tx.Rollback() }() - client, err := s.VerifyClientCredentials(ctx, clientID, clientSecret, tx) + client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) if err != nil { return "", "", "", 0, err } @@ -342,7 +327,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo tx.Rollback() }() - _, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx) + _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) if err != nil { return "", "", 0, err } @@ -401,7 +386,7 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre return introspectDto, &common.OidcMissingClientCredentialsError{} } - _, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, s.db) + _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db) if err != nil { return introspectDto, err } @@ -999,7 +984,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca } func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { - client, err := s.VerifyClientCredentials(ctx, input.ClientID, input.ClientSecret, s.db) + client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db) if err != nil { return nil, err } @@ -1095,23 +1080,11 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use } if !hasAuthorizedClient { - userAuthorizedClient := model.UserAuthorizedOidcClient{ - UserID: userID, - ClientID: deviceAuth.ClientID, - Scope: deviceAuth.Scope, + err := s.createAuthorizedClientInternal(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope, tx) + if err != nil { + return err } - if err := tx.WithContext(ctx).Create(&userAuthorizedClient).Error; err != nil { - if !errors.Is(err, gorm.ErrDuplicatedKey) { - return err - } - // If duplicate, update scope - if err := tx.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{}). - Where("user_id = ? AND client_id = ?", userID, deviceAuth.ClientID). - Update("scope", deviceAuth.Scope).Error; err != nil { - return err - } - } s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx) } else { s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx) @@ -1188,7 +1161,25 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u return refreshToken, nil } -func (s *OidcService) VerifyClientCredentials(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) { +func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error { + userAuthorizedClient := model.UserAuthorizedOidcClient{ + UserID: userID, + ClientID: clientID, + Scope: scope, + } + + err := tx.WithContext(ctx). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "user_id"}, {Name: "client_id"}}, + DoUpdates: clause.AssignmentColumns([]string{"scope"}), + }). + Create(&userAuthorizedClient). + Error + + return err +} + +func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) { if clientID == "" { return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} }