mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-15 14:35:06 +00:00
fix: pass context to methods that were missing it (#487)
This commit is contained in:
committed by
GitHub
parent
9e06f70380
commit
4c33793678
@@ -155,23 +155,22 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
|||||||
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, refreshToken, accessToken, expiresIn, err := oc.oidcService.CreateTokens(
|
idToken, refreshToken, accessToken, expiresIn, err :=
|
||||||
c,
|
oc.oidcService.CreateTokens(c.Request.Context(), input)
|
||||||
input,
|
|
||||||
)
|
switch {
|
||||||
if err != nil {
|
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
||||||
switch {
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
"error": "authorization_pending",
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
})
|
||||||
"error": "authorization_pending",
|
return
|
||||||
})
|
case errors.Is(err, &common.OidcSlowDownError{}):
|
||||||
case errors.Is(err, &common.OidcSlowDownError{}):
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
"error": "slow_down",
|
||||||
"error": "slow_down",
|
})
|
||||||
})
|
return
|
||||||
default:
|
case err != nil:
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,7 +307,6 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
|
|||||||
// @Success 200 {object} dto.OidcIntrospectionResponseDto "Response with the introspection result."
|
// @Success 200 {object} dto.OidcIntrospectionResponseDto "Response with the introspection result."
|
||||||
// @Router /api/oidc/introspect [post]
|
// @Router /api/oidc/introspect [post]
|
||||||
func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
|
func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
|
||||||
|
|
||||||
var input dto.OidcIntrospectDto
|
var input dto.OidcIntrospectDto
|
||||||
if err := c.ShouldBind(&input); err != nil {
|
if err := c.ShouldBind(&input); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
@@ -322,7 +320,7 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
|
|||||||
// and client_secret anyway).
|
// and client_secret anyway).
|
||||||
clientID, clientSecret, _ := c.Request.BasicAuth()
|
clientID, clientSecret, _ := c.Request.BasicAuth()
|
||||||
|
|
||||||
response, err := oc.oidcService.IntrospectToken(clientID, clientSecret, input.Token)
|
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
@@ -634,7 +632,7 @@ func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
|
|||||||
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := oc.oidcService.CreateDeviceAuthorization(input)
|
response, err := oc.oidcService.CreateDeviceAuthorization(c.Request.Context(), input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
@@ -654,7 +652,7 @@ func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
|
|||||||
ipAddress := c.ClientIP()
|
ipAddress := c.ClientIP()
|
||||||
userAgent := c.Request.UserAgent()
|
userAgent := c.Request.UserAgent()
|
||||||
|
|
||||||
err := oc.oidcService.VerifyDeviceCode(c, userCode, c.GetString("userID"), ipAddress, userAgent)
|
err := oc.oidcService.VerifyDeviceCode(c.Request.Context(), userCode, c.GetString("userID"), ipAddress, userAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
@@ -670,7 +668,7 @@ func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c, userCode, c.GetString("userID"))
|
deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c.Request.Context(), userCode, c.GetString("userID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -396,12 +396,12 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
|||||||
return accessToken, newRefreshToken, 3600, nil
|
return accessToken, newRefreshToken, 3600, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||||
if clientID == "" || clientSecret == "" {
|
if clientID == "" || clientSecret == "" {
|
||||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.VerifyClientCredentials(context.Background(), clientID, clientSecret, s.db)
|
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, s.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return introspectDto, err
|
return introspectDto, err
|
||||||
}
|
}
|
||||||
@@ -410,7 +410,7 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, jwt.ParseError()) {
|
if errors.Is(err, jwt.ParseError()) {
|
||||||
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
|
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
|
||||||
return s.introspectRefreshToken(tokenString)
|
return s.introspectRefreshToken(ctx, tokenString)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||||
@@ -454,9 +454,11 @@ func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string
|
|||||||
return introspectDto, nil
|
return introspectDto, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) introspectRefreshToken(refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||||
var storedRefreshToken model.OidcRefreshToken
|
var storedRefreshToken model.OidcRefreshToken
|
||||||
err = s.db.Preload("User").
|
err = s.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Preload("User").
|
||||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||||
First(&storedRefreshToken).
|
First(&storedRefreshToken).
|
||||||
Error
|
Error
|
||||||
@@ -996,8 +998,8 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
|
|||||||
return "", &common.OidcInvalidCallbackURLError{}
|
return "", &common.OidcInvalidCallbackURLError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) CreateDeviceAuthorization(input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||||
client, err := s.VerifyClientCredentials(context.Background(), input.ClientID, input.ClientSecret, s.db)
|
client, err := s.VerifyClientCredentials(ctx, input.ClientID, input.ClientSecret, s.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1120,7 +1122,12 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
|||||||
|
|
||||||
func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, userID string) (*dto.DeviceCodeInfoDto, error) {
|
func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, userID string) (*dto.DeviceCodeInfoDto, error) {
|
||||||
var deviceAuth model.OidcDeviceCode
|
var deviceAuth model.OidcDeviceCode
|
||||||
if err := s.db.Preload("Client").First(&deviceAuth, "user_code = ?", userCode).Error; err != nil {
|
err := s.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Preload("Client").
|
||||||
|
First(&deviceAuth, "user_code = ?", userCode).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, &common.OidcInvalidDeviceCodeError{}
|
return nil, &common.OidcInvalidDeviceCodeError{}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user