diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 62ccc081..8a8ecb1d 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -70,6 +70,13 @@ type OidcInvalidAuthorizationCodeError struct{} func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" } func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 } +type OidcMissingCallbackURLError struct{} + +func (e *OidcMissingCallbackURLError) Error() string { + return "unable to detect callback url, it might be necessary for an admin to fix this" +} +func (e *OidcMissingCallbackURLError) HttpStatusCode() int { return 400 } + type OidcInvalidCallbackURLError struct{} func (e *OidcInvalidCallbackURLError) Error() string { diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index edbcf53b..e9304d86 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -26,7 +26,7 @@ type OidcClientWithAllowedGroupsCountDto struct { type OidcClientCreateDto struct { Name string `json:"name" binding:"required,max=50"` - CallbackURLs []string `json:"callbackURLs" binding:"required"` + CallbackURLs []string `json:"callbackURLs"` LogoutCallbackURLs []string `json:"logoutCallbackURLs"` IsPublic bool `json:"isPublic"` PkceEnabled bool `json:"pkceEnabled"` diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index e2cc363a..b42abfd7 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -73,7 +73,7 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie } // Get the callback URL of the client. Return an error if the provided callback URL is not allowed - callbackURL, err := s.getCallbackURL(client.CallbackURLs, input.CallbackURL) + callbackURL, err := s.getCallbackURL(client.CallbackURLs, input.CallbackURL, input.ClientID, tx, ctx) if err != nil { return "", "", err } @@ -947,7 +947,7 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo return "", &common.OidcNoCallbackURLError{} } - callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client.LogoutCallbackURLs, input.PostLogoutRedirectUri) + callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client.LogoutCallbackURLs, input.PostLogoutRedirectUri, userAuthorizedOIDCClient.Client.ID, s.db, ctx) if err != nil { return "", err } @@ -1006,23 +1006,55 @@ func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, c return encodedVerifierHash == codeChallenge } -func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (callbackURL string, err error) { +func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string, clientID string, tx *gorm.DB, ctx context.Context) (callbackURL string, err error) { + // If no input callback URL provided, use the first configured URL if inputCallbackURL == "" { - return urls[0], nil + if len(urls) > 0 { + return urls[0], nil + } + // If no URLs are configured and no input URL, this is an error + return "", &common.OidcMissingCallbackURLError{} } - for _, callbackPattern := range urls { - regexPattern := "^" + strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$" - matched, err := regexp.MatchString(regexPattern, inputCallbackURL) - if err != nil { - return "", err - } - if matched { - return inputCallbackURL, nil + // If URLs are already configured, validate against them + if len(urls) > 0 { + for _, callbackPattern := range urls { + regexPattern := "^" + strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$" + matched, err := regexp.MatchString(regexPattern, inputCallbackURL) + if err != nil { + return "", err + } + if matched { + return inputCallbackURL, nil + } } + return "", &common.OidcInvalidCallbackURLError{} } - return "", &common.OidcInvalidCallbackURLError{} + // If no URLs are configured, trust and store the first URL (TOFU) + err = s.addCallbackURLToClient(ctx, clientID, inputCallbackURL, tx) + if err != nil { + return "", err + } + return inputCallbackURL, nil +} + +func (s *OidcService) addCallbackURLToClient(ctx context.Context, clientID string, callbackURL string, tx *gorm.DB) error { + var client model.OidcClient + err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error + if err != nil { + return err + } + + // Add the new callback URL to the existing list + client.CallbackURLs = append(client.CallbackURLs, callbackURL) + + err = tx.WithContext(ctx).Save(&client).Error + if err != nil { + return err + } + + return nil } func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { diff --git a/frontend/messages/en-US.json b/frontend/messages/en-US.json index 6176f371..0317f8cd 100644 --- a/frontend/messages/en-US.json +++ b/frontend/messages/en-US.json @@ -340,7 +340,7 @@ "login_code_email_success": "The login code has been sent to the user.", "send_email": "Send Email", "show_code": "Show Code", - "callback_url_description": "URL(s) provided by your client. Wildcards (*) are supported, but best avoided for better security.", + "callback_url_description": "URL(s) provided by your client. Will be automatically added if left blank. Wildcards (*) are supported, but best avoided for better security.", "api_key_expiration": "API Key Expiration", "send_an_email_to_the_user_when_their_api_key_is_about_to_expire": "Send an email to the user when their API key is about to expire.", "authorize_device": "Authorize Device", diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts index 7c82e218..dfffa89c 100644 --- a/frontend/src/lib/types/oidc.type.ts +++ b/frontend/src/lib/types/oidc.type.ts @@ -7,7 +7,7 @@ export type OidcClientMetaData = { }; export type OidcClient = OidcClientMetaData & { - callbackURLs: [string, ...string[]]; + callbackURLs: string[]; // No longer requires at least one URL logoutCallbackURLs: string[]; isPublic: boolean; pkceEnabled: boolean; diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-callback-url-input.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-callback-url-input.svelte index 0df7fb59..cb10f56e 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-callback-url-input.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-callback-url-input.svelte @@ -11,13 +11,11 @@ label, callbackURLs = $bindable(), error = $bindable(null), - allowEmpty = false, ...restProps }: HTMLAttributes & { label: string; callbackURLs: string[]; error?: string | null; - allowEmpty?: boolean; children?: Snippet; } = $props(); @@ -32,15 +30,13 @@ data-testid={`callback-url-${i + 1}`} bind:value={callbackURLs[i]} /> - {#if callbackURLs.length > 1 || allowEmpty} - - {/if} + {/each} diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte index 5116b6c0..1c9f8767 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte @@ -30,7 +30,7 @@ const client: OidcClientCreate = { name: existingClient?.name || '', - callbackURLs: existingClient?.callbackURLs || [''], + callbackURLs: existingClient?.callbackURLs || [], logoutCallbackURLs: existingClient?.logoutCallbackURLs || [], isPublic: existingClient?.isPublic || false, pkceEnabled: existingClient?.pkceEnabled || false @@ -38,7 +38,7 @@ const formSchema = z.object({ name: z.string().min(2).max(50), - callbackURLs: z.array(z.string().nonempty()).nonempty(), + callbackURLs: z.array(z.string().nonempty()).default([]), logoutCallbackURLs: z.array(z.string().nonempty()), isPublic: z.boolean(), pkceEnabled: z.boolean() @@ -91,7 +91,6 @@ diff --git a/tests/specs/oidc-client-settings.spec.ts b/tests/specs/oidc-client-settings.spec.ts index ed8d3756..1d051bfb 100644 --- a/tests/specs/oidc-client-settings.spec.ts +++ b/tests/specs/oidc-client-settings.spec.ts @@ -11,13 +11,12 @@ test("Create OIDC client", async ({ page }) => { await page.getByRole("button", { name: "Add OIDC Client" }).click(); await page.getByLabel("Name").fill(oidcClient.name); + await page.getByRole("button", { name: "Add" }).nth(1).click(); await page.getByTestId("callback-url-1").fill(oidcClient.callbackUrl); await page.getByRole("button", { name: "Add another" }).click(); await page.getByTestId("callback-url-2").fill(oidcClient.secondCallbackUrl!); - await page - .getByLabel("logo") - .setInputFiles("assets/pingvin-share-logo.png"); + await page.getByLabel("logo").setInputFiles("assets/pingvin-share-logo.png"); await page.getByRole("button", { name: "Save" }).click(); const clientId = await page.getByTestId("client-id").textContent(); @@ -53,9 +52,7 @@ test("Edit OIDC client", async ({ page }) => { .getByTestId("callback-url-1") .first() .fill("http://nextcloud-updated/auth/callback"); - await page - .getByLabel("logo") - .setInputFiles("assets/nextcloud-logo.png"); + await page.getByLabel("logo").setInputFiles("assets/nextcloud-logo.png"); await page.getByRole("button", { name: "Save" }).click(); await expect(page.locator('[data-type="success"]')).toHaveText(