diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 57406de9..5dd9404f 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -1,6 +1,7 @@ package controller import ( + "context" "errors" "log/slog" "net/http" @@ -24,7 +25,11 @@ import ( // @Description Initializes all OIDC-related API endpoints for authentication and client management // @Tags OIDC func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) { - oc := &OidcController{oidcService: oidcService, jwtService: jwtService} + oc := &OidcController{ + oidcService: oidcService, + jwtService: jwtService, + createTokens: oidcService.CreateTokens, + } group.POST("/oidc/authorize", authMiddleware.WithAdminNotRequired().Add(), oc.authorizeHandler) group.POST("/oidc/authorization-required", authMiddleware.WithAdminNotRequired().Add(), oc.authorizationConfirmationRequiredHandler) @@ -68,8 +73,9 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi } type OidcController struct { - oidcService *service.OidcService - jwtService *service.JwtService + oidcService *service.OidcService + jwtService *service.JwtService + createTokens func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) } // authorizeHandler godoc @@ -144,8 +150,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex // @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token" // @Router /api/oidc/token [post] func (oc *OidcController) createTokensHandler(c *gin.Context) { + // Per RFC-6749, parameters passed to the /token endpoint MUST be passed in the body of the request + // Gin's "ShouldBind" by default reads from the query string too, so we need to reset all query string args before invoking ShouldBind + c.Request.URL.RawQuery = "" + var input dto.OidcCreateTokensDto - if err := c.ShouldBind(&input); err != nil { + err := c.ShouldBind(&input) + if err != nil { _ = c.Error(err) return } @@ -167,7 +178,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { input.ClientID, input.ClientSecret, _ = utils.OAuthClientBasicAuth(c.Request) } - tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input) + tokens, err := oc.createTokens(c.Request.Context(), input) switch { case errors.Is(err, &common.OidcAuthorizationPendingError{}): diff --git a/backend/internal/controller/oidc_controller_test.go b/backend/internal/controller/oidc_controller_test.go new file mode 100644 index 00000000..09c23e8f --- /dev/null +++ b/backend/internal/controller/oidc_controller_test.go @@ -0,0 +1,227 @@ +package controller + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/dto" + "github.com/pocket-id/pocket-id/backend/internal/service" +) + +func TestCreateTokensHandler(t *testing.T) { + createTestContext := func(t *testing.T, rawURL string, form url.Values, authHeader string, noCT bool) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + + mode := gin.Mode() + gin.SetMode(gin.TestMode) + t.Cleanup(func() { gin.SetMode(mode) }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, rawURL, strings.NewReader(form.Encode())) + require.NoError(t, err) + + if !noCT { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + if authHeader != "" { + req.Header.Set("Authorization", authHeader) + } + + c.Request = req + return c, recorder + } + + t.Run("Ignores Query String Parameters For Binding", func(t *testing.T) { + oc := &OidcController{} + + c, _ := createTestContext( + t, + "http://example.com/oidc/token?grant_type=refresh_token&refresh_token=query-value", + url.Values{}, + "", + false, + ) + + oc.createTokensHandler(c) + + require.Len(t, c.Errors, 1) + assert.Contains(t, c.Errors[0].Err.Error(), "GrantType") + }) + + t.Run("Missing Authorization Code", func(t *testing.T) { + oc := &OidcController{} + + c, _ := createTestContext( + t, + "http://example.com/oidc/token", + url.Values{ + "grant_type": {service.GrantTypeAuthorizationCode}, + }, + "", + false, + ) + + oc.createTokensHandler(c) + + require.Len(t, c.Errors, 1) + var missingCodeErr *common.OidcMissingAuthorizationCodeError + require.ErrorAs(t, c.Errors[0].Err, &missingCodeErr) + }) + + t.Run("Missing Refresh Token", func(t *testing.T) { + oc := &OidcController{} + + c, _ := createTestContext( + t, + "http://example.com/oidc/token", + url.Values{ + "grant_type": {service.GrantTypeRefreshToken}, + }, + "", + false, + ) + + oc.createTokensHandler(c) + + require.Len(t, c.Errors, 1) + var missingRefreshErr *common.OidcMissingRefreshTokenError + require.ErrorAs(t, c.Errors[0].Err, &missingRefreshErr) + }) + + t.Run("Uses Basic Auth Credentials When Body Credentials Missing", func(t *testing.T) { + var capturedInput dto.OidcCreateTokensDto + oc := &OidcController{ + createTokens: func(_ context.Context, input dto.OidcCreateTokensDto) (service.CreatedTokens, error) { + capturedInput = input + return service.CreatedTokens{ + AccessToken: "access-token", + IdToken: "id-token", + RefreshToken: "refresh-token", + ExpiresIn: 2 * time.Minute, + }, nil + }, + } + + basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("client-id:client-secret")) + c, recorder := createTestContext( + t, + "http://example.com/oidc/token", + url.Values{ + "grant_type": {service.GrantTypeRefreshToken}, + "refresh_token": {"input-refresh-token"}, + }, + basicAuth, + false, + ) + + oc.createTokensHandler(c) + + require.Empty(t, c.Errors) + assert.Equal(t, "client-id", capturedInput.ClientID) + assert.Equal(t, "client-secret", capturedInput.ClientSecret) + assert.Equal(t, "input-refresh-token", capturedInput.RefreshToken) + + require.Equal(t, http.StatusOK, recorder.Code) + var response dto.OidcTokenResponseDto + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response)) + assert.Equal(t, "access-token", response.AccessToken) + assert.Equal(t, "Bearer", response.TokenType) + assert.Equal(t, "id-token", response.IdToken) + assert.Equal(t, "refresh-token", response.RefreshToken) + assert.Equal(t, 120, response.ExpiresIn) + }) + + t.Run("Maps Authorization Pending Error", func(t *testing.T) { + oc := &OidcController{ + createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) { + return service.CreatedTokens{}, &common.OidcAuthorizationPendingError{} + }, + } + + c, recorder := createTestContext( + t, + "http://example.com/oidc/token", + url.Values{ + "grant_type": {service.GrantTypeRefreshToken}, + "refresh_token": {"input-refresh-token"}, + }, + "", + false, + ) + + oc.createTokensHandler(c) + + require.Empty(t, c.Errors) + require.Equal(t, http.StatusBadRequest, recorder.Code) + var response map[string]string + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response)) + assert.Equal(t, "authorization_pending", response["error"]) + }) + + t.Run("Maps Slow Down Error", func(t *testing.T) { + oc := &OidcController{ + createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) { + return service.CreatedTokens{}, &common.OidcSlowDownError{} + }, + } + + c, recorder := createTestContext( + t, + "http://example.com/oidc/token", + url.Values{ + "grant_type": {service.GrantTypeRefreshToken}, + "refresh_token": {"input-refresh-token"}, + }, + "", + false, + ) + + oc.createTokensHandler(c) + + require.Empty(t, c.Errors) + require.Equal(t, http.StatusBadRequest, recorder.Code) + var response map[string]string + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response)) + assert.Equal(t, "slow_down", response["error"]) + }) + + t.Run("Returns Generic Service Error In Context", func(t *testing.T) { + expectedErr := errors.New("boom") + oc := &OidcController{ + createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) { + return service.CreatedTokens{}, expectedErr + }, + } + + c, _ := createTestContext( + t, + "http://example.com/oidc/token", + url.Values{ + "grant_type": {service.GrantTypeRefreshToken}, + "refresh_token": {"input-refresh-token"}, + }, + "", + false, + ) + + oc.createTokensHandler(c) + + require.Len(t, c.Errors, 1) + assert.ErrorIs(t, c.Errors[0].Err, expectedErr) + }) +} diff --git a/backend/internal/service/geolite_service.go b/backend/internal/service/geolite_service.go index aeadd3b6..66a2a5f0 100644 --- a/backend/internal/service/geolite_service.go +++ b/backend/internal/service/geolite_service.go @@ -162,6 +162,8 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error { } // Check if the file starts with the gzip magic number + // Gosec returns false positive for "G602: slice index out of range" + //nolint:gosec isGzip := buf[0] == 0x1f && buf[1] == 0x8b if !isGzip {