From 89349dc1adc2a6defd2bc05ff96c1e8c6c2ae721 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Tue, 3 Mar 2026 01:52:42 -0800 Subject: [PATCH] fix: handle IPv6 addresses in callback URLs (#1355) --- backend/internal/utils/callback_url_util.go | 94 ++++++-- .../internal/utils/callback_url_util_test.go | 227 +++++++++++++++++- 2 files changed, 290 insertions(+), 31 deletions(-) diff --git a/backend/internal/utils/callback_url_util.go b/backend/internal/utils/callback_url_util.go index f4c3306b..7fa44a97 100644 --- a/backend/internal/utils/callback_url_util.go +++ b/backend/internal/utils/callback_url_util.go @@ -34,22 +34,7 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL // time of the request for loopback IP redirect URIs, to accommodate // clients that obtain an available ephemeral port from the operating // system at the time of the request. - loopbackCallbackURLWithoutPort := "" - u, _ := url.Parse(inputCallbackURL) - - if u != nil && u.Scheme == "http" { - host := u.Hostname() - ip := net.ParseIP(host) - if host == "localhost" || (ip != nil && ip.IsLoopback()) { - // For IPv6 loopback hosts, brackets are required when serializing without a port. - if strings.Contains(host, ":") { - u.Host = "[" + host + "]" - } else { - u.Host = host - } - loopbackCallbackURLWithoutPort = u.String() - } - } + loopbackCallbackURLWithoutPort := loopbackURLWithWildcardPort(inputCallbackURL) for _, pattern := range urls { // Try the original callback first @@ -76,6 +61,28 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL return "", nil } +func loopbackURLWithWildcardPort(input string) string { + u, _ := url.Parse(input) + + if u == nil || u.Scheme != "http" { + return "" + } + + host := u.Hostname() + ip := net.ParseIP(host) + if host != "localhost" && (ip == nil || !ip.IsLoopback()) { + return "" + } + + // For IPv6 loopback hosts, brackets are required when serializing without a port. + if strings.Contains(host, ":") { + u.Host = "[" + host + "]" + } else { + u.Host = host + } + return u.String() +} + // matchCallbackURL checks if the input callback URL matches the given pattern. // It supports wildcard matching for paths and query parameters. // @@ -125,10 +132,57 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er // normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards // into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards // and ** for multi-segment wildcards. +// Additionally, it escapes ":" with a backslash inside IPv6 addresses func normalizeToURLPatternStandard(pattern string) string { patternBase, patternPath := extractPath(pattern) var result strings.Builder + result.Grow(len(pattern) + 5) // Add 5 for some extra capacity, hoping to avoid many re-allocations + + // First, process the base + + // 0 = scheme + // 1 = hostname (optionally with username/password) - before IPv6 start (no `[` found) + // 2 = is matching IPv6 (until `]`) + // 3 = after hostname + var step int + for i := 0; i < len(patternBase); i++ { + switch step { + case 0: + if i > 3 && patternBase[i] == '/' && patternBase[i-1] == '/' && patternBase[i-2] == ':' { + // We just passed the scheme + step = 1 + } + case 1: + switch patternBase[i] { + case '/', ']': + // No IPv6, skip to end of this logic + step = 3 + case '[': + // Start of IPv6 match + step = 2 + } + case 2: + if patternBase[i] == '/' || patternBase[i] == ']' || patternBase[i] == '[' { + // End of IPv6 match + step = 3 + } + + switch patternBase[i] { + case ':': + // We are matching an IPv6 block and there's a colon, so escape that + result.WriteByte('\\') + case '/', ']', '[': + // End of IPv6 match + step = 3 + } + } + + // Write the byte + result.WriteByte(patternBase[i]) + } + + // Next, process the path for i := 0; i < len(patternPath); i++ { if patternPath[i] == '*' { // Replace globstar with a single asterisk @@ -141,19 +195,19 @@ func normalizeToURLPatternStandard(pattern string) string { result.WriteString(strconv.Itoa(i)) } } else { + // Add the byte result.WriteByte(patternPath[i]) } } - patternPath = result.String() - - return patternBase + patternPath + return result.String() } func extractPath(url string) (base string, path string) { pathStart := -1 // Look for scheme:// first - if i := strings.Index(url, "://"); i >= 0 { + i := strings.Index(url, "://") + if i >= 0 { // Look for the next slash after scheme:// rest := url[i+3:] if j := strings.IndexByte(rest, '/'); j >= 0 { diff --git a/backend/internal/utils/callback_url_util_test.go b/backend/internal/utils/callback_url_util_test.go index c3423dde..9e109620 100644 --- a/backend/internal/utils/callback_url_util_test.go +++ b/backend/internal/utils/callback_url_util_test.go @@ -58,11 +58,6 @@ func TestValidateCallbackURLPattern(t *testing.T) { pattern: "https://exa[mple.com/callback", shouldError: true, }, - { - name: "malformed authority", - pattern: "https://[::1/callback", - shouldError: true, - }, } for _, tt := range tests { @@ -78,6 +73,76 @@ func TestValidateCallbackURLPattern(t *testing.T) { } } +func TestNormalizeToURLPatternStandard(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "exact URL unchanged", + input: "https://example.com/callback", + expected: "https://example.com/callback", + }, + { + name: "single wildcard path segment converted to named parameter", + input: "https://example.com/api/*/callback", + expected: "https://example.com/api/:p5/callback", + }, + { + name: "single wildcard in path suffix converted to named parameter", + input: "https://example.com/test*", + expected: "https://example.com/test:p5", + }, + { + name: "globstar converted to single asterisk", + input: "https://example.com/**/callback", + expected: "https://example.com/*/callback", + }, + { + name: "mixed globstar and single wildcard conversion", + input: "https://example.com/**/v1/**/callback/*", + expected: "https://example.com/*/v1/*/callback/:p19", + }, + { + name: "URL without path unchanged", + input: "https://example.com", + expected: "https://example.com", + }, + { + name: "relative path conversion", + input: "/foo/*/bar", + expected: "/foo/:p5/bar", + }, + { + name: "wildcard in hostname is not normalized by this function", + input: "https://*.example.com/callback", + expected: "https://*.example.com/callback", + }, + { + name: "IPv6 hostname escapes all colons inside address", + input: "https://[2001:db8:1:1::a:1]/callback", + expected: "https://[2001\\:db8\\:1\\:1\\:\\:a\\:1]/callback", + }, + { + name: "IPv6 hostname with port escapes only address colons", + input: "https://[::1]:8080/callback", + expected: "https://[\\:\\:1]:8080/callback", + }, + { + name: "wildcard in query is converted when query is part of input", + input: "https://example.com/callback?code=*", + expected: "https://example.com/callback?code=:p15", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, normalizeToURLPatternStandard(tt.input)) + }) + } +} + func TestMatchCallbackURL(t *testing.T) { tests := []struct { name string @@ -98,6 +163,18 @@ func TestMatchCallbackURL(t *testing.T) { "https://example.com/callback", false, }, + { + "exact match - IPv4", + "https://10.1.0.1/callback", + "https://10.1.0.1/callback", + true, + }, + { + "exact match - IPv6", + "https://[2001:db8:1:1::a:1]/callback", + "https://[2001:db8:1:1::a:1]/callback", + true, + }, // Scheme { @@ -182,6 +259,30 @@ func TestMatchCallbackURL(t *testing.T) { "https://example.com:8080/callback", true, }, + { + "wildcard port - IPv4", + "https://10.1.0.1:*/callback", + "https://10.1.0.1:8080/callback", + true, + }, + { + "partial wildcard in port prefix - IPv4", + "https://10.1.0.1:80*/callback", + "https://10.1.0.1:8080/callback", + true, + }, + { + "wildcard port - IPv6", + "https://[2001:db8:1:1::a:1]:*/callback", + "https://[2001:db8:1:1::a:1]:8080/callback", + true, + }, + { + "partial wildcard in port prefix - IPv6", + "https://[2001:db8:1:1::a:1]:80*/callback", + "https://[2001:db8:1:1::a:1]:8080/callback", + true, + }, // Path { @@ -202,6 +303,18 @@ func TestMatchCallbackURL(t *testing.T) { "https://example.com/callback", true, }, + { + "wildcard entire path - IPv4", + "https://10.1.0.1/*", + "https://10.1.0.1/callback", + true, + }, + { + "wildcard entire path - IPv6", + "https://[2001:db8:1:1::a:1]/*", + "https://[2001:db8:1:1::a:1]/callback", + true, + }, { "partial wildcard in path prefix", "https://example.com/test*", @@ -435,10 +548,11 @@ func TestMatchCallbackURL(t *testing.T) { } for _, tt := range tests { - matches, err := matchCallbackURL(tt.pattern, tt.input) - require.NoError(t, err, tt.name) - assert.Equal(t, tt.shouldMatch, matches, tt.name) - + t.Run(tt.name, func(t *testing.T) { + matches, err := matchCallbackURL(tt.pattern, tt.input) + require.NoError(t, err) + assert.Equal(t, tt.shouldMatch, matches) + }) } } @@ -472,14 +586,21 @@ func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) { expectMatch: true, }, { - name: "IPv6 loopback with dynamic port", + name: "IPv6 loopback with dynamic port - exact match", urls: []string{"http://[::1]/callback"}, inputCallbackURL: "http://[::1]:8080/callback", expectedURL: "http://[::1]:8080/callback", expectMatch: true, }, { - name: "IPv6 loopback with wildcard path", + name: "IPv6 loopback with same port - exact match", + urls: []string{"http://[::1]:8080/callback"}, + inputCallbackURL: "http://[::1]:8080/callback", + expectedURL: "http://[::1]:8080/callback", + expectMatch: true, + }, + { + name: "IPv6 loopback with path match", urls: []string{"http://[::1]/auth/*"}, inputCallbackURL: "http://[::1]:8080/auth/callback", expectedURL: "http://[::1]:8080/auth/callback", @@ -506,6 +627,20 @@ func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) { expectedURL: "http://127.0.0.1:3000/auth/callback", expectMatch: true, }, + { + name: "loopback with path port", + urls: []string{"http://127.0.0.1:*/auth/callback"}, + inputCallbackURL: "http://127.0.0.1:3000/auth/callback", + expectedURL: "http://127.0.0.1:3000/auth/callback", + expectMatch: true, + }, + { + name: "IPv6 loopback with path port", + urls: []string{"http://[::1]:*/auth/callback"}, + inputCallbackURL: "http://[::1]:3000/auth/callback", + expectedURL: "http://[::1]:3000/auth/callback", + expectMatch: true, + }, { name: "loopback with path mismatch", urls: []string{"http://127.0.0.1/callback"}, @@ -549,6 +684,76 @@ func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) { } } +func TestLoopbackURLWithWildcardPort(t *testing.T) { + tests := []struct { + name string + input string + output string + }{ + { + name: "localhost http with port strips port", + input: "http://localhost:3000/callback", + output: "http://localhost/callback", + }, + { + name: "localhost http without port stays same", + input: "http://localhost/callback", + output: "http://localhost/callback", + }, + { + name: "IPv4 loopback with port strips port", + input: "http://127.0.0.1:8080/callback", + output: "http://127.0.0.1/callback", + }, + { + name: "IPv4 loopback without port stays same", + input: "http://127.0.0.1/callback", + output: "http://127.0.0.1/callback", + }, + { + name: "IPv6 loopback with port strips port and keeps brackets", + input: "http://[::1]:8080/callback", + output: "http://[::1]/callback", + }, + { + name: "IPv6 loopback preserves path query and fragment", + input: "http://[::1]:8080/auth/callback?code=123#state", + output: "http://[::1]/auth/callback?code=123#state", + }, + { + name: "https loopback returns empty", + input: "https://127.0.0.1:8080/callback", + output: "", + }, + { + name: "non loopback host returns empty", + input: "http://example.com:8080/callback", + output: "", + }, + { + name: "non loopback IP returns empty", + input: "http://192.168.1.10:8080/callback", + output: "", + }, + { + name: "malformed URL returns empty", + input: "http://[::1:8080/callback", + output: "", + }, + { + name: "relative URL returns empty", + input: "/callback", + output: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.output, loopbackURLWithWildcardPort(tt.input)) + }) + } +} + func TestGetCallbackURLFromList_MultiplePatterns(t *testing.T) { tests := []struct { name string