mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-22 20:15:07 +00:00
fix: token endpoint must not accept params as query string args (#1321)
This commit is contained in:
committed by
GitHub
parent
f0249377ac
commit
eb0456a395
@@ -1,6 +1,7 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -24,7 +25,11 @@ import (
|
|||||||
// @Description Initializes all OIDC-related API endpoints for authentication and client management
|
// @Description Initializes all OIDC-related API endpoints for authentication and client management
|
||||||
// @Tags OIDC
|
// @Tags OIDC
|
||||||
func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
|
func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
|
||||||
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
|
oc := &OidcController{
|
||||||
|
oidcService: oidcService,
|
||||||
|
jwtService: jwtService,
|
||||||
|
createTokens: oidcService.CreateTokens,
|
||||||
|
}
|
||||||
|
|
||||||
group.POST("/oidc/authorize", authMiddleware.WithAdminNotRequired().Add(), oc.authorizeHandler)
|
group.POST("/oidc/authorize", authMiddleware.WithAdminNotRequired().Add(), oc.authorizeHandler)
|
||||||
group.POST("/oidc/authorization-required", authMiddleware.WithAdminNotRequired().Add(), oc.authorizationConfirmationRequiredHandler)
|
group.POST("/oidc/authorization-required", authMiddleware.WithAdminNotRequired().Add(), oc.authorizationConfirmationRequiredHandler)
|
||||||
@@ -70,6 +75,7 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
|
|||||||
type OidcController struct {
|
type OidcController struct {
|
||||||
oidcService *service.OidcService
|
oidcService *service.OidcService
|
||||||
jwtService *service.JwtService
|
jwtService *service.JwtService
|
||||||
|
createTokens func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeHandler godoc
|
// authorizeHandler godoc
|
||||||
@@ -144,8 +150,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
|
|||||||
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
|
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
|
||||||
// @Router /api/oidc/token [post]
|
// @Router /api/oidc/token [post]
|
||||||
func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||||
|
// Per RFC-6749, parameters passed to the /token endpoint MUST be passed in the body of the request
|
||||||
|
// Gin's "ShouldBind" by default reads from the query string too, so we need to reset all query string args before invoking ShouldBind
|
||||||
|
c.Request.URL.RawQuery = ""
|
||||||
|
|
||||||
var input dto.OidcCreateTokensDto
|
var input dto.OidcCreateTokensDto
|
||||||
if err := c.ShouldBind(&input); err != nil {
|
err := c.ShouldBind(&input)
|
||||||
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -167,7 +178,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
|||||||
input.ClientID, input.ClientSecret, _ = utils.OAuthClientBasicAuth(c.Request)
|
input.ClientID, input.ClientSecret, _ = utils.OAuthClientBasicAuth(c.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input)
|
tokens, err := oc.createTokens(c.Request.Context(), input)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
||||||
|
|||||||
227
backend/internal/controller/oidc_controller_test.go
Normal file
227
backend/internal/controller/oidc_controller_test.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateTokensHandler(t *testing.T) {
|
||||||
|
createTestContext := func(t *testing.T, rawURL string, form url.Values, authHeader string, noCT bool) (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
mode := gin.Mode()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Cleanup(func() { gin.SetMode(mode) })
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, rawURL, strings.NewReader(form.Encode()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !noCT {
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
}
|
||||||
|
if authHeader != "" {
|
||||||
|
req.Header.Set("Authorization", authHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request = req
|
||||||
|
return c, recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Ignores Query String Parameters For Binding", func(t *testing.T) {
|
||||||
|
oc := &OidcController{}
|
||||||
|
|
||||||
|
c, _ := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token?grant_type=refresh_token&refresh_token=query-value",
|
||||||
|
url.Values{},
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Len(t, c.Errors, 1)
|
||||||
|
assert.Contains(t, c.Errors[0].Err.Error(), "GrantType")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Missing Authorization Code", func(t *testing.T) {
|
||||||
|
oc := &OidcController{}
|
||||||
|
|
||||||
|
c, _ := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token",
|
||||||
|
url.Values{
|
||||||
|
"grant_type": {service.GrantTypeAuthorizationCode},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Len(t, c.Errors, 1)
|
||||||
|
var missingCodeErr *common.OidcMissingAuthorizationCodeError
|
||||||
|
require.ErrorAs(t, c.Errors[0].Err, &missingCodeErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Missing Refresh Token", func(t *testing.T) {
|
||||||
|
oc := &OidcController{}
|
||||||
|
|
||||||
|
c, _ := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token",
|
||||||
|
url.Values{
|
||||||
|
"grant_type": {service.GrantTypeRefreshToken},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Len(t, c.Errors, 1)
|
||||||
|
var missingRefreshErr *common.OidcMissingRefreshTokenError
|
||||||
|
require.ErrorAs(t, c.Errors[0].Err, &missingRefreshErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Uses Basic Auth Credentials When Body Credentials Missing", func(t *testing.T) {
|
||||||
|
var capturedInput dto.OidcCreateTokensDto
|
||||||
|
oc := &OidcController{
|
||||||
|
createTokens: func(_ context.Context, input dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
||||||
|
capturedInput = input
|
||||||
|
return service.CreatedTokens{
|
||||||
|
AccessToken: "access-token",
|
||||||
|
IdToken: "id-token",
|
||||||
|
RefreshToken: "refresh-token",
|
||||||
|
ExpiresIn: 2 * time.Minute,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("client-id:client-secret"))
|
||||||
|
c, recorder := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token",
|
||||||
|
url.Values{
|
||||||
|
"grant_type": {service.GrantTypeRefreshToken},
|
||||||
|
"refresh_token": {"input-refresh-token"},
|
||||||
|
},
|
||||||
|
basicAuth,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Empty(t, c.Errors)
|
||||||
|
assert.Equal(t, "client-id", capturedInput.ClientID)
|
||||||
|
assert.Equal(t, "client-secret", capturedInput.ClientSecret)
|
||||||
|
assert.Equal(t, "input-refresh-token", capturedInput.RefreshToken)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
var response dto.OidcTokenResponseDto
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
||||||
|
assert.Equal(t, "access-token", response.AccessToken)
|
||||||
|
assert.Equal(t, "Bearer", response.TokenType)
|
||||||
|
assert.Equal(t, "id-token", response.IdToken)
|
||||||
|
assert.Equal(t, "refresh-token", response.RefreshToken)
|
||||||
|
assert.Equal(t, 120, response.ExpiresIn)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Maps Authorization Pending Error", func(t *testing.T) {
|
||||||
|
oc := &OidcController{
|
||||||
|
createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
||||||
|
return service.CreatedTokens{}, &common.OidcAuthorizationPendingError{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, recorder := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token",
|
||||||
|
url.Values{
|
||||||
|
"grant_type": {service.GrantTypeRefreshToken},
|
||||||
|
"refresh_token": {"input-refresh-token"},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Empty(t, c.Errors)
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
var response map[string]string
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
||||||
|
assert.Equal(t, "authorization_pending", response["error"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Maps Slow Down Error", func(t *testing.T) {
|
||||||
|
oc := &OidcController{
|
||||||
|
createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
||||||
|
return service.CreatedTokens{}, &common.OidcSlowDownError{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, recorder := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token",
|
||||||
|
url.Values{
|
||||||
|
"grant_type": {service.GrantTypeRefreshToken},
|
||||||
|
"refresh_token": {"input-refresh-token"},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Empty(t, c.Errors)
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
var response map[string]string
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
||||||
|
assert.Equal(t, "slow_down", response["error"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns Generic Service Error In Context", func(t *testing.T) {
|
||||||
|
expectedErr := errors.New("boom")
|
||||||
|
oc := &OidcController{
|
||||||
|
createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
||||||
|
return service.CreatedTokens{}, expectedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, _ := createTestContext(
|
||||||
|
t,
|
||||||
|
"http://example.com/oidc/token",
|
||||||
|
url.Values{
|
||||||
|
"grant_type": {service.GrantTypeRefreshToken},
|
||||||
|
"refresh_token": {"input-refresh-token"},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
oc.createTokensHandler(c)
|
||||||
|
|
||||||
|
require.Len(t, c.Errors, 1)
|
||||||
|
assert.ErrorIs(t, c.Errors[0].Err, expectedErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -162,6 +162,8 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the file starts with the gzip magic number
|
// Check if the file starts with the gzip magic number
|
||||||
|
// Gosec returns false positive for "G602: slice index out of range"
|
||||||
|
//nolint:gosec
|
||||||
isGzip := buf[0] == 0x1f && buf[1] == 0x8b
|
isGzip := buf[0] == 0x1f && buf[1] == 0x8b
|
||||||
|
|
||||||
if !isGzip {
|
if !isGzip {
|
||||||
|
|||||||
Reference in New Issue
Block a user