diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index b0b01598..57406de9 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -164,7 +164,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { // Client id and secret can also be passed over the Authorization header if input.ClientID == "" && input.ClientSecret == "" { - input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth() + input.ClientID, input.ClientSecret, _ = utils.OAuthClientBasicAuth(c.Request) } tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input) @@ -322,7 +322,7 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) { creds service.ClientAuthCredentials ok bool ) - creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth() + creds.ClientID, creds.ClientSecret, ok = utils.OAuthClientBasicAuth(c.Request) if !ok { // If there's no basic auth, check if we have a bearer token bearer, ok := utils.BearerAuth(c.Request) @@ -659,7 +659,7 @@ func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) { // Client id and secret can also be passed over the Authorization header if input.ClientID == "" && input.ClientSecret == "" { - input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth() + input.ClientID, input.ClientSecret, _ = utils.OAuthClientBasicAuth(c.Request) } response, err := oc.oidcService.CreateDeviceAuthorization(c.Request.Context(), input) diff --git a/backend/internal/utils/http_util.go b/backend/internal/utils/http_util.go index 0ba10336..bd9cc028 100644 --- a/backend/internal/utils/http_util.go +++ b/backend/internal/utils/http_util.go @@ -2,6 +2,7 @@ package utils import ( "net/http" + "net/url" "strconv" "strings" "time" @@ -21,6 +22,27 @@ func BearerAuth(r *http.Request) (string, bool) { return "", false } +// OAuthClientBasicAuth returns the OAuth client ID and secret provided in the request's +// Authorization header, if present. See RFC 6749, Section 2.3. +func OAuthClientBasicAuth(r *http.Request) (clientID, clientSecret string, ok bool) { + clientID, clientSecret, ok = r.BasicAuth() + if !ok { + return "", "", false + } + + clientID, err := url.QueryUnescape(clientID) + if err != nil { + return "", "", false + } + + clientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return "", "", false + } + + return clientID, clientSecret, true +} + // SetCacheControlHeader sets the Cache-Control header for the response. func SetCacheControlHeader(ctx *gin.Context, maxAge, staleWhileRevalidate time.Duration) { _, ok := ctx.GetQuery("skipCache") diff --git a/backend/internal/utils/http_util_test.go b/backend/internal/utils/http_util_test.go index c754c878..db5af19a 100644 --- a/backend/internal/utils/http_util_test.go +++ b/backend/internal/utils/http_util_test.go @@ -63,3 +63,62 @@ func TestBearerAuth(t *testing.T) { }) } } + +func TestOAuthClientBasicAuth(t *testing.T) { + tests := []struct { + name string + authHeader string + expectedClientID string + expectedClientSecret string + expectedOk bool + }{ + { + name: "Valid client ID and secret in header (example from RFC 6749)", + authHeader: "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3", + expectedClientID: "s6BhdRkqt3", + expectedClientSecret: "7Fjfp0ZBr1KtDRbnfVdmIw", + expectedOk: true, + }, + { + name: "Valid client ID and secret in header (escaped values)", + authHeader: "Basic ZTUwOTcyYmQtNmUzMi00OTU3LWJhZmMtMzU0MTU3ZjI1NDViOislMjUlMjYlMkIlQzIlQTMlRTIlODIlQUM=", + expectedClientID: "e50972bd-6e32-4957-bafc-354157f2545b", + // This is the example string from RFC 6749, Appendix B. + expectedClientSecret: " %&+£€", + expectedOk: true, + }, + { + name: "Empty auth header", + authHeader: "", + expectedClientID: "", + expectedClientSecret: "", + expectedOk: false, + }, + { + name: "Basic prefix only", + authHeader: "Basic ", + expectedClientID: "", + expectedClientSecret: "", + expectedOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil) + require.NoError(t, err, "Failed to create request") + + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + clientId, clientSecret, ok := OAuthClientBasicAuth(req) + + assert.Equal(t, tt.expectedOk, ok) + + if tt.expectedOk { + assert.Equal(t, tt.expectedClientID, clientId) + assert.Equal(t, tt.expectedClientSecret, clientSecret) + } + }) + } +}