1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 23:05:09 +00:00

fix: improve wildcard matching by using go-urlpattern (#1332)

This commit is contained in:
Elias Schneider
2026-02-28 14:08:35 +01:00
committed by GitHub
parent d98db79d5e
commit 3a339e3319
5 changed files with 242 additions and 382 deletions

View File

@@ -1,14 +1,31 @@
package utils
import (
"log/slog"
"net"
"net/url"
"path"
"regexp"
"strconv"
"strings"
"github.com/dunglas/go-urlpattern"
)
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL
// ValidateCallbackURLPattern checks if the given callback URL pattern
// is valid according to the rules defined in this package.
func ValidateCallbackURLPattern(pattern string) error {
if pattern == "*" {
return nil
}
pattern, _, _ = strings.Cut(pattern, "#")
pattern = normalizeToURLPatternStandard(pattern)
_, err := urlpattern.New(pattern, "", nil)
return err
}
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL.
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
// Special case for Loopback Interface Redirection. Quoting from RFC 8252 section 7.3:
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
@@ -24,7 +41,12 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
host := u.Hostname()
ip := net.ParseIP(host)
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
u.Host = host
// For IPv6 loopback hosts, brackets are required when serializing without a port.
if strings.Contains(host, ":") {
u.Host = "[" + host + "]"
} else {
u.Host = host
}
loopbackCallbackURLWithoutPort = u.String()
}
}
@@ -64,143 +86,129 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
return true, nil
}
// Strip fragment part
// Strip fragment part.
// The endpoint URI MUST NOT include a fragment component.
// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
pattern, _, _ = strings.Cut(pattern, "#")
inputCallbackURL, _, _ = strings.Cut(inputCallbackURL, "#")
// Store and strip query part
var patternQuery url.Values
if i := strings.Index(pattern, "?"); i >= 0 {
patternQuery, err = url.ParseQuery(pattern[i+1:])
if err != nil {
return false, err
}
pattern = pattern[:i]
}
var inputQuery url.Values
if i := strings.Index(inputCallbackURL, "?"); i >= 0 {
inputQuery, err = url.ParseQuery(inputCallbackURL[i+1:])
if err != nil {
return false, err
}
inputCallbackURL = inputCallbackURL[:i]
}
// Split both pattern and input parts
patternParts, patternPath := splitParts(pattern)
inputParts, inputPath := splitParts(inputCallbackURL)
// Verify everything except the path and query parameters
if len(patternParts) != len(inputParts) {
return false, nil
}
for i, patternPart := range patternParts {
matched, err := path.Match(patternPart, inputParts[i])
if err != nil || !matched {
return false, err
}
}
// Verify path with wildcard support
matched, err := matchPath(patternPath, inputPath)
if err != nil || !matched {
pattern, patternQuery, err := extractQueryParams(pattern)
if err != nil {
return false, err
}
// Verify query parameters
if len(patternQuery) != len(inputQuery) {
inputCallbackURL, inputQuery, err := extractQueryParams(inputCallbackURL)
if err != nil {
return false, err
}
pattern = normalizeToURLPatternStandard(pattern)
// Validate query params
v := validateQueryParams(patternQuery, inputQuery)
if !v {
return false, nil
}
// Validate the rest of the URL using urlpattern
p, err := urlpattern.New(pattern, "", nil)
if err != nil {
//nolint:nilerr
slog.Warn("invalid callback URL pattern, skipping", "pattern", pattern, "error", err)
return false, nil
}
return p.Test(inputCallbackURL, ""), nil
}
// normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards
// into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards
// and ** for multi-segment wildcards.
func normalizeToURLPatternStandard(pattern string) string {
patternBase, patternPath := extractPath(pattern)
var result strings.Builder
for i := 0; i < len(patternPath); i++ {
if patternPath[i] == '*' {
// Replace globstar with a single asterisk
if i+1 < len(patternPath) && patternPath[i+1] == '*' {
result.WriteString("*")
i++ // skip next *
} else {
// Replace single asterisk with :p{index}
result.WriteString(":p")
result.WriteString(strconv.Itoa(i))
}
} else {
result.WriteByte(patternPath[i])
}
}
patternPath = result.String()
return patternBase + patternPath
}
func extractPath(url string) (base string, path string) {
pathStart := -1
// Look for scheme:// first
if i := strings.Index(url, "://"); i >= 0 {
// Look for the next slash after scheme://
rest := url[i+3:]
if j := strings.IndexByte(rest, '/'); j >= 0 {
pathStart = i + 3 + j
}
} else {
// Otherwise, first slash is path start
pathStart = strings.IndexByte(url, '/')
}
if pathStart >= 0 {
path = url[pathStart:]
base = url[:pathStart]
} else {
path = ""
base = url
}
return base, path
}
func extractQueryParams(rawUrl string) (base string, query url.Values, err error) {
if i := strings.IndexByte(rawUrl, '?'); i >= 0 {
query, err = url.ParseQuery(rawUrl[i+1:])
if err != nil {
return "", nil, err
}
rawUrl = rawUrl[:i]
}
return rawUrl, query, nil
}
func validateQueryParams(patternQuery, inputQuery url.Values) bool {
if len(patternQuery) != len(inputQuery) {
return false
}
for patternKey, patternValues := range patternQuery {
inputValues, exists := inputQuery[patternKey]
if !exists {
return false, nil
return false
}
if len(patternValues) != len(inputValues) {
return false, nil
return false
}
for i := range patternValues {
matched, err := path.Match(patternValues[i], inputValues[i])
if err != nil || !matched {
return false, err
return false
}
}
}
return true, nil
}
// matchPath matches the input path against the pattern with wildcard support
// Supported wildcards:
//
// '*' matches any sequence of characters except '/'
// '**' matches any sequence of characters including '/'
func matchPath(pattern string, input string) (matches bool, err error) {
var regexPattern strings.Builder
regexPattern.WriteString("^")
runes := []rune(pattern)
n := len(runes)
for i := 0; i < n; {
switch runes[i] {
case '*':
// Check if it's a ** (globstar)
if i+1 < n && runes[i+1] == '*' {
// globstar = .* (match slashes too)
regexPattern.WriteString(".*")
i += 2
} else {
// single * = [^/]* (no slash)
regexPattern.WriteString(`[^/]*`)
i++
}
default:
regexPattern.WriteString(regexp.QuoteMeta(string(runes[i])))
i++
}
}
regexPattern.WriteString("$")
matched, err := regexp.MatchString(regexPattern.String(), input)
return matched, err
}
// splitParts splits the URL into parts by special characters and returns the path separately
func splitParts(s string) (parts []string, path string) {
split := func(r rune) bool {
return r == ':' || r == '/' || r == '[' || r == ']' || r == '@' || r == '.'
}
pathStart := -1
// Look for scheme:// first
if i := strings.Index(s, "://"); i >= 0 {
// Look for the next slash after scheme://
rest := s[i+3:]
if j := strings.IndexRune(rest, '/'); j >= 0 {
pathStart = i + 3 + j
}
} else {
// Otherwise, first slash is path start
pathStart = strings.IndexRune(s, '/')
}
if pathStart >= 0 {
path = s[pathStart:]
base := s[:pathStart]
parts = strings.FieldsFunc(base, split)
} else {
parts = strings.FieldsFunc(s, split)
path = ""
}
return parts, path
return true
}

View File

@@ -7,6 +7,77 @@ import (
"github.com/stretchr/testify/require"
)
func TestValidateCallbackURLPattern(t *testing.T) {
tests := []struct {
name string
pattern string
shouldError bool
}{
{
name: "exact URL",
pattern: "https://example.com/callback",
shouldError: false,
},
{
name: "wildcard scheme",
pattern: "*://example.com/callback",
shouldError: false,
},
{
name: "wildcard port",
pattern: "https://example.com:*/callback",
shouldError: false,
},
{
name: "partial wildcard port",
pattern: "https://example.com:80*/callback",
shouldError: false,
},
{
name: "wildcard userinfo",
pattern: "https://user:*@example.com/callback",
shouldError: false,
},
{
name: "glob wildcard",
pattern: "*",
shouldError: false,
},
{
name: "relative URL",
pattern: "/callback",
shouldError: true,
},
{
name: "missing scheme separator",
pattern: "https//example.com/callback",
shouldError: true,
},
{
name: "malformed wildcard host glob",
pattern: "https://exa[mple.com/callback",
shouldError: true,
},
{
name: "malformed authority",
pattern: "https://[::1/callback",
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateCallbackURLPattern(tt.pattern)
if tt.shouldError {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func TestMatchCallbackURL(t *testing.T) {
tests := []struct {
name string
@@ -187,12 +258,6 @@ func TestMatchCallbackURL(t *testing.T) {
"https://example.com/callback",
false,
},
{
"unexpected credentials",
"https://example.com/callback",
"https://user:pass@example.com/callback",
false,
},
{
"wildcard password",
"https://user:*@example.com/callback",
@@ -347,7 +412,7 @@ func TestMatchCallbackURL(t *testing.T) {
"backslash instead of forward slash",
"https://example.com/callback",
"https://example.com\\callback",
false,
true,
},
{
"double slash in hostname (protocol smuggling)",
@@ -553,246 +618,3 @@ func TestGetCallbackURLFromList_MultiplePatterns(t *testing.T) {
})
}
}
func TestMatchPath(t *testing.T) {
tests := []struct {
name string
pattern string
input string
shouldMatch bool
}{
// Exact matches
{
name: "exact match",
pattern: "/callback",
input: "/callback",
shouldMatch: true,
},
{
name: "exact mismatch",
pattern: "/callback",
input: "/other",
shouldMatch: false,
},
{
name: "empty paths",
pattern: "",
input: "",
shouldMatch: true,
},
// Single wildcard (*)
{
name: "single wildcard matches segment",
pattern: "/api/*/callback",
input: "/api/v1/callback",
shouldMatch: true,
},
{
name: "single wildcard doesn't match multiple segments",
pattern: "/api/*/callback",
input: "/api/v1/v2/callback",
shouldMatch: false,
},
{
name: "single wildcard at end",
pattern: "/callback/*",
input: "/callback/test",
shouldMatch: true,
},
{
name: "single wildcard at start",
pattern: "/*/callback",
input: "/api/callback",
shouldMatch: true,
},
{
name: "multiple single wildcards",
pattern: "/*/test/*",
input: "/api/test/callback",
shouldMatch: true,
},
{
name: "partial wildcard prefix",
pattern: "/test*",
input: "/testing",
shouldMatch: true,
},
{
name: "partial wildcard suffix",
pattern: "/*-callback",
input: "/oauth-callback",
shouldMatch: true,
},
{
name: "partial wildcard middle",
pattern: "/api-*-v1",
input: "/api-internal-v1",
shouldMatch: true,
},
// Double wildcard (**)
{
name: "double wildcard matches multiple segments",
pattern: "/api/**/callback",
input: "/api/v1/v2/v3/callback",
shouldMatch: true,
},
{
name: "double wildcard matches single segment",
pattern: "/api/**/callback",
input: "/api/v1/callback",
shouldMatch: true,
},
{
name: "double wildcard doesn't match when pattern has extra slashes",
pattern: "/api/**/callback",
input: "/api/callback",
shouldMatch: false,
},
{
name: "double wildcard at end",
pattern: "/api/**",
input: "/api/v1/v2/callback",
shouldMatch: true,
},
{
name: "double wildcard in middle",
pattern: "/api/**/v2/**/callback",
input: "/api/v1/v2/v3/v4/callback",
shouldMatch: true,
},
// Complex patterns
{
name: "mix of single and double wildcards",
pattern: "/*/api/**/callback",
input: "/app/api/v1/v2/callback",
shouldMatch: true,
},
{
name: "wildcard with special characters",
pattern: "/callback-*",
input: "/callback-123",
shouldMatch: true,
},
{
name: "path with query-like string (no special handling)",
pattern: "/callback?code=*",
input: "/callback?code=abc",
shouldMatch: true,
},
// Edge cases
{
name: "single wildcard matches empty segment",
pattern: "/api/*/callback",
input: "/api//callback",
shouldMatch: true,
},
{
name: "pattern longer than input",
pattern: "/api/v1/callback",
input: "/api",
shouldMatch: false,
},
{
name: "input longer than pattern",
pattern: "/api",
input: "/api/v1/callback",
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches, err := matchPath(tt.pattern, tt.input)
require.NoError(t, err)
assert.Equal(t, tt.shouldMatch, matches)
})
}
}
func TestSplitParts(t *testing.T) {
tests := []struct {
name string
input string
expectedParts []string
expectedPath string
}{
{
name: "simple https URL",
input: "https://example.com/callback",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/callback",
},
{
name: "URL with port",
input: "https://example.com:8080/callback",
expectedParts: []string{"https", "example", "com", "8080"},
expectedPath: "/callback",
},
{
name: "URL with subdomain",
input: "https://api.example.com/callback",
expectedParts: []string{"https", "api", "example", "com"},
expectedPath: "/callback",
},
{
name: "URL with credentials",
input: "https://user:pass@example.com/callback",
expectedParts: []string{"https", "user", "pass", "example", "com"},
expectedPath: "/callback",
},
{
name: "URL without path",
input: "https://example.com",
expectedParts: []string{"https", "example", "com"},
expectedPath: "",
},
{
name: "URL with deep path",
input: "https://example.com/api/v1/callback",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/api/v1/callback",
},
{
name: "URL with path and query",
input: "https://example.com/callback?code=123",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/callback?code=123",
},
{
name: "URL with trailing slash",
input: "https://example.com/",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/",
},
{
name: "URL with multiple subdomains",
input: "https://api.v1.staging.example.com/callback",
expectedParts: []string{"https", "api", "v1", "staging", "example", "com"},
expectedPath: "/callback",
},
{
name: "URL with port and credentials",
input: "https://user:pass@example.com:8080/callback",
expectedParts: []string{"https", "user", "pass", "example", "com", "8080"},
expectedPath: "/callback",
},
{
name: "scheme with authority separator but no slash",
input: "http://example.com",
expectedParts: []string{"http", "example", "com"},
expectedPath: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parts, path := splitParts(tt.input)
assert.Equal(t, tt.expectedParts, parts, "parts mismatch")
assert.Equal(t, tt.expectedPath, path, "path mismatch")
})
}
}