From 4c33793678709eb4981be2c1fd5803bace5f5939 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Sun, 27 Apr 2025 02:32:42 +0900 Subject: [PATCH] fix: pass context to methods that were missing it (#487) --- .../internal/controller/oidc_controller.go | 42 +++++++++---------- backend/internal/service/oidc_service.go | 23 ++++++---- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 7503a7e4..4958ce61 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -155,23 +155,22 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth() } - idToken, refreshToken, accessToken, expiresIn, err := oc.oidcService.CreateTokens( - c, - input, - ) - if err != nil { - switch { - case errors.Is(err, &common.OidcAuthorizationPendingError{}): - c.JSON(http.StatusBadRequest, gin.H{ - "error": "authorization_pending", - }) - case errors.Is(err, &common.OidcSlowDownError{}): - c.JSON(http.StatusBadRequest, gin.H{ - "error": "slow_down", - }) - default: - _ = c.Error(err) - } + idToken, refreshToken, accessToken, expiresIn, err := + oc.oidcService.CreateTokens(c.Request.Context(), input) + + switch { + case errors.Is(err, &common.OidcAuthorizationPendingError{}): + c.JSON(http.StatusBadRequest, gin.H{ + "error": "authorization_pending", + }) + return + case errors.Is(err, &common.OidcSlowDownError{}): + c.JSON(http.StatusBadRequest, gin.H{ + "error": "slow_down", + }) + return + case err != nil: + _ = c.Error(err) return } @@ -308,7 +307,6 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) { // @Success 200 {object} dto.OidcIntrospectionResponseDto "Response with the introspection result." // @Router /api/oidc/introspect [post] func (oc *OidcController) introspectTokenHandler(c *gin.Context) { - var input dto.OidcIntrospectDto if err := c.ShouldBind(&input); err != nil { _ = c.Error(err) @@ -322,7 +320,7 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) { // and client_secret anyway). clientID, clientSecret, _ := c.Request.BasicAuth() - response, err := oc.oidcService.IntrospectToken(clientID, clientSecret, input.Token) + response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token) if err != nil { _ = c.Error(err) return @@ -634,7 +632,7 @@ func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) { input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth() } - response, err := oc.oidcService.CreateDeviceAuthorization(input) + response, err := oc.oidcService.CreateDeviceAuthorization(c.Request.Context(), input) if err != nil { _ = c.Error(err) return @@ -654,7 +652,7 @@ func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) { ipAddress := c.ClientIP() userAgent := c.Request.UserAgent() - err := oc.oidcService.VerifyDeviceCode(c, userCode, c.GetString("userID"), ipAddress, userAgent) + err := oc.oidcService.VerifyDeviceCode(c.Request.Context(), userCode, c.GetString("userID"), ipAddress, userAgent) if err != nil { _ = c.Error(err) return @@ -670,7 +668,7 @@ func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) { return } - deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c, userCode, c.GetString("userID")) + deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c.Request.Context(), userCode, c.GetString("userID")) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 2ac7c51b..f8db7d6d 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -396,12 +396,12 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo return accessToken, newRefreshToken, 3600, nil } -func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { +func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { if clientID == "" || clientSecret == "" { return introspectDto, &common.OidcMissingClientCredentialsError{} } - _, err = s.VerifyClientCredentials(context.Background(), clientID, clientSecret, s.db) + _, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, s.db) if err != nil { return introspectDto, err } @@ -410,7 +410,7 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string if err != nil { if errors.Is(err, jwt.ParseError()) { // It's apparently not a valid JWT token, so we check if it's a valid refresh_token. - return s.introspectRefreshToken(tokenString) + return s.introspectRefreshToken(ctx, tokenString) } // Every failure we get means the token is invalid. Nothing more to do with the error. @@ -454,9 +454,11 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string return introspectDto, nil } -func (s *OidcService) introspectRefreshToken(refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { +func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { var storedRefreshToken model.OidcRefreshToken - err = s.db.Preload("User"). + err = s.db. + WithContext(ctx). + Preload("User"). Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). First(&storedRefreshToken). Error @@ -996,8 +998,8 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca return "", &common.OidcInvalidCallbackURLError{} } -func (s *OidcService) CreateDeviceAuthorization(input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { - client, err := s.VerifyClientCredentials(context.Background(), input.ClientID, input.ClientSecret, s.db) +func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { + client, err := s.VerifyClientCredentials(ctx, input.ClientID, input.ClientSecret, s.db) if err != nil { return nil, err } @@ -1120,7 +1122,12 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, userID string) (*dto.DeviceCodeInfoDto, error) { var deviceAuth model.OidcDeviceCode - if err := s.db.Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil { + err := s.db. + WithContext(ctx). + Preload("Client"). + First(&deviceAuth, "user_code = ?", userCode). + Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, &common.OidcInvalidDeviceCodeError{} }