1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-15 05:15:14 +00:00

fix: pass context to methods that were missing it (#487)

This commit is contained in:
Alessandro (Ale) Segala
2025-04-27 02:32:42 +09:00
committed by GitHub
parent 9e06f70380
commit 4c33793678
2 changed files with 35 additions and 30 deletions

View File

@@ -396,12 +396,12 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
return accessToken, newRefreshToken, 3600, nil
}
func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
if clientID == "" || clientSecret == "" {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
_, err = s.VerifyClientCredentials(context.Background(), clientID, clientSecret, s.db)
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, s.db)
if err != nil {
return introspectDto, err
}
@@ -410,7 +410,7 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string
if err != nil {
if errors.Is(err, jwt.ParseError()) {
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
return s.introspectRefreshToken(tokenString)
return s.introspectRefreshToken(ctx, tokenString)
}
// Every failure we get means the token is invalid. Nothing more to do with the error.
@@ -454,9 +454,11 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string
return introspectDto, nil
}
func (s *OidcService) introspectRefreshToken(refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
var storedRefreshToken model.OidcRefreshToken
err = s.db.Preload("User").
err = s.db.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
@@ -996,8 +998,8 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{}
}
func (s *OidcService) CreateDeviceAuthorization(input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.VerifyClientCredentials(context.Background(), input.ClientID, input.ClientSecret, s.db)
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.VerifyClientCredentials(ctx, input.ClientID, input.ClientSecret, s.db)
if err != nil {
return nil, err
}
@@ -1120,7 +1122,12 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, userID string) (*dto.DeviceCodeInfoDto, error) {
var deviceAuth model.OidcDeviceCode
if err := s.db.Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
err := s.db.
WithContext(ctx).
Preload("Client").
First(&deviceAuth, "user_code = ?", userCode).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, &common.OidcInvalidDeviceCodeError{}
}