mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-22 23:05:09 +00:00
fix: improve wildcard matching by using go-urlpattern (#1332)
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
@@ -67,19 +65,6 @@ func ValidateClientID(clientID string) bool {
|
||||
|
||||
// ValidateCallbackURL validates callback URLs with support for wildcards
|
||||
func ValidateCallbackURL(raw string) bool {
|
||||
// Don't validate if it contains a wildcard
|
||||
if strings.Contains(raw, "*") {
|
||||
return true
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !u.IsAbs() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
err := utils.ValidateCallbackURLPattern(raw)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,31 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dunglas/go-urlpattern"
|
||||
)
|
||||
|
||||
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL
|
||||
// ValidateCallbackURLPattern checks if the given callback URL pattern
|
||||
// is valid according to the rules defined in this package.
|
||||
func ValidateCallbackURLPattern(pattern string) error {
|
||||
if pattern == "*" {
|
||||
return nil
|
||||
}
|
||||
|
||||
pattern, _, _ = strings.Cut(pattern, "#")
|
||||
pattern = normalizeToURLPatternStandard(pattern)
|
||||
|
||||
_, err := urlpattern.New(pattern, "", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL.
|
||||
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
|
||||
// Special case for Loopback Interface Redirection. Quoting from RFC 8252 section 7.3:
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
|
||||
@@ -24,7 +41,12 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
|
||||
host := u.Hostname()
|
||||
ip := net.ParseIP(host)
|
||||
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
|
||||
u.Host = host
|
||||
// For IPv6 loopback hosts, brackets are required when serializing without a port.
|
||||
if strings.Contains(host, ":") {
|
||||
u.Host = "[" + host + "]"
|
||||
} else {
|
||||
u.Host = host
|
||||
}
|
||||
loopbackCallbackURLWithoutPort = u.String()
|
||||
}
|
||||
}
|
||||
@@ -64,143 +86,129 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Strip fragment part
|
||||
// Strip fragment part.
|
||||
// The endpoint URI MUST NOT include a fragment component.
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
|
||||
pattern, _, _ = strings.Cut(pattern, "#")
|
||||
inputCallbackURL, _, _ = strings.Cut(inputCallbackURL, "#")
|
||||
|
||||
// Store and strip query part
|
||||
var patternQuery url.Values
|
||||
if i := strings.Index(pattern, "?"); i >= 0 {
|
||||
patternQuery, err = url.ParseQuery(pattern[i+1:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
pattern = pattern[:i]
|
||||
}
|
||||
var inputQuery url.Values
|
||||
if i := strings.Index(inputCallbackURL, "?"); i >= 0 {
|
||||
inputQuery, err = url.ParseQuery(inputCallbackURL[i+1:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
inputCallbackURL = inputCallbackURL[:i]
|
||||
}
|
||||
|
||||
// Split both pattern and input parts
|
||||
patternParts, patternPath := splitParts(pattern)
|
||||
inputParts, inputPath := splitParts(inputCallbackURL)
|
||||
|
||||
// Verify everything except the path and query parameters
|
||||
if len(patternParts) != len(inputParts) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for i, patternPart := range patternParts {
|
||||
matched, err := path.Match(patternPart, inputParts[i])
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Verify path with wildcard support
|
||||
matched, err := matchPath(patternPath, inputPath)
|
||||
if err != nil || !matched {
|
||||
pattern, patternQuery, err := extractQueryParams(pattern)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Verify query parameters
|
||||
if len(patternQuery) != len(inputQuery) {
|
||||
inputCallbackURL, inputQuery, err := extractQueryParams(inputCallbackURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
pattern = normalizeToURLPatternStandard(pattern)
|
||||
|
||||
// Validate query params
|
||||
v := validateQueryParams(patternQuery, inputQuery)
|
||||
if !v {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Validate the rest of the URL using urlpattern
|
||||
p, err := urlpattern.New(pattern, "", nil)
|
||||
if err != nil {
|
||||
//nolint:nilerr
|
||||
slog.Warn("invalid callback URL pattern, skipping", "pattern", pattern, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return p.Test(inputCallbackURL, ""), nil
|
||||
}
|
||||
|
||||
// normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards
|
||||
// into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards
|
||||
// and ** for multi-segment wildcards.
|
||||
func normalizeToURLPatternStandard(pattern string) string {
|
||||
patternBase, patternPath := extractPath(pattern)
|
||||
|
||||
var result strings.Builder
|
||||
for i := 0; i < len(patternPath); i++ {
|
||||
if patternPath[i] == '*' {
|
||||
// Replace globstar with a single asterisk
|
||||
if i+1 < len(patternPath) && patternPath[i+1] == '*' {
|
||||
result.WriteString("*")
|
||||
i++ // skip next *
|
||||
} else {
|
||||
// Replace single asterisk with :p{index}
|
||||
result.WriteString(":p")
|
||||
result.WriteString(strconv.Itoa(i))
|
||||
}
|
||||
} else {
|
||||
result.WriteByte(patternPath[i])
|
||||
}
|
||||
}
|
||||
patternPath = result.String()
|
||||
|
||||
return patternBase + patternPath
|
||||
}
|
||||
|
||||
func extractPath(url string) (base string, path string) {
|
||||
pathStart := -1
|
||||
|
||||
// Look for scheme:// first
|
||||
if i := strings.Index(url, "://"); i >= 0 {
|
||||
// Look for the next slash after scheme://
|
||||
rest := url[i+3:]
|
||||
if j := strings.IndexByte(rest, '/'); j >= 0 {
|
||||
pathStart = i + 3 + j
|
||||
}
|
||||
} else {
|
||||
// Otherwise, first slash is path start
|
||||
pathStart = strings.IndexByte(url, '/')
|
||||
}
|
||||
|
||||
if pathStart >= 0 {
|
||||
path = url[pathStart:]
|
||||
base = url[:pathStart]
|
||||
} else {
|
||||
path = ""
|
||||
base = url
|
||||
}
|
||||
|
||||
return base, path
|
||||
}
|
||||
|
||||
func extractQueryParams(rawUrl string) (base string, query url.Values, err error) {
|
||||
if i := strings.IndexByte(rawUrl, '?'); i >= 0 {
|
||||
query, err = url.ParseQuery(rawUrl[i+1:])
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rawUrl = rawUrl[:i]
|
||||
}
|
||||
|
||||
return rawUrl, query, nil
|
||||
}
|
||||
|
||||
func validateQueryParams(patternQuery, inputQuery url.Values) bool {
|
||||
if len(patternQuery) != len(inputQuery) {
|
||||
return false
|
||||
}
|
||||
|
||||
for patternKey, patternValues := range patternQuery {
|
||||
inputValues, exists := inputQuery[patternKey]
|
||||
if !exists {
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
if len(patternValues) != len(inputValues) {
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range patternValues {
|
||||
matched, err := path.Match(patternValues[i], inputValues[i])
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// matchPath matches the input path against the pattern with wildcard support
|
||||
// Supported wildcards:
|
||||
//
|
||||
// '*' matches any sequence of characters except '/'
|
||||
// '**' matches any sequence of characters including '/'
|
||||
func matchPath(pattern string, input string) (matches bool, err error) {
|
||||
var regexPattern strings.Builder
|
||||
regexPattern.WriteString("^")
|
||||
|
||||
runes := []rune(pattern)
|
||||
n := len(runes)
|
||||
|
||||
for i := 0; i < n; {
|
||||
switch runes[i] {
|
||||
case '*':
|
||||
// Check if it's a ** (globstar)
|
||||
if i+1 < n && runes[i+1] == '*' {
|
||||
// globstar = .* (match slashes too)
|
||||
regexPattern.WriteString(".*")
|
||||
i += 2
|
||||
} else {
|
||||
// single * = [^/]* (no slash)
|
||||
regexPattern.WriteString(`[^/]*`)
|
||||
i++
|
||||
}
|
||||
default:
|
||||
regexPattern.WriteString(regexp.QuoteMeta(string(runes[i])))
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
regexPattern.WriteString("$")
|
||||
|
||||
matched, err := regexp.MatchString(regexPattern.String(), input)
|
||||
return matched, err
|
||||
}
|
||||
|
||||
// splitParts splits the URL into parts by special characters and returns the path separately
|
||||
func splitParts(s string) (parts []string, path string) {
|
||||
split := func(r rune) bool {
|
||||
return r == ':' || r == '/' || r == '[' || r == ']' || r == '@' || r == '.'
|
||||
}
|
||||
|
||||
pathStart := -1
|
||||
|
||||
// Look for scheme:// first
|
||||
if i := strings.Index(s, "://"); i >= 0 {
|
||||
// Look for the next slash after scheme://
|
||||
rest := s[i+3:]
|
||||
if j := strings.IndexRune(rest, '/'); j >= 0 {
|
||||
pathStart = i + 3 + j
|
||||
}
|
||||
} else {
|
||||
// Otherwise, first slash is path start
|
||||
pathStart = strings.IndexRune(s, '/')
|
||||
}
|
||||
|
||||
if pathStart >= 0 {
|
||||
path = s[pathStart:]
|
||||
base := s[:pathStart]
|
||||
parts = strings.FieldsFunc(base, split)
|
||||
} else {
|
||||
parts = strings.FieldsFunc(s, split)
|
||||
path = ""
|
||||
}
|
||||
|
||||
return parts, path
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -7,6 +7,77 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateCallbackURLPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "exact URL",
|
||||
pattern: "https://example.com/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard scheme",
|
||||
pattern: "*://example.com/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard port",
|
||||
pattern: "https://example.com:*/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "partial wildcard port",
|
||||
pattern: "https://example.com:80*/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard userinfo",
|
||||
pattern: "https://user:*@example.com/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "glob wildcard",
|
||||
pattern: "*",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "relative URL",
|
||||
pattern: "/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "missing scheme separator",
|
||||
pattern: "https//example.com/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "malformed wildcard host glob",
|
||||
pattern: "https://exa[mple.com/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "malformed authority",
|
||||
pattern: "https://[::1/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateCallbackURLPattern(tt.pattern)
|
||||
if tt.shouldError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCallbackURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -187,12 +258,6 @@ func TestMatchCallbackURL(t *testing.T) {
|
||||
"https://example.com/callback",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"unexpected credentials",
|
||||
"https://example.com/callback",
|
||||
"https://user:pass@example.com/callback",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"wildcard password",
|
||||
"https://user:*@example.com/callback",
|
||||
@@ -347,7 +412,7 @@ func TestMatchCallbackURL(t *testing.T) {
|
||||
"backslash instead of forward slash",
|
||||
"https://example.com/callback",
|
||||
"https://example.com\\callback",
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"double slash in hostname (protocol smuggling)",
|
||||
@@ -553,246 +618,3 @@ func TestGetCallbackURLFromList_MultiplePatterns(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
input string
|
||||
shouldMatch bool
|
||||
}{
|
||||
// Exact matches
|
||||
{
|
||||
name: "exact match",
|
||||
pattern: "/callback",
|
||||
input: "/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "exact mismatch",
|
||||
pattern: "/callback",
|
||||
input: "/other",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "empty paths",
|
||||
pattern: "",
|
||||
input: "",
|
||||
shouldMatch: true,
|
||||
},
|
||||
|
||||
// Single wildcard (*)
|
||||
{
|
||||
name: "single wildcard matches segment",
|
||||
pattern: "/api/*/callback",
|
||||
input: "/api/v1/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single wildcard doesn't match multiple segments",
|
||||
pattern: "/api/*/callback",
|
||||
input: "/api/v1/v2/callback",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "single wildcard at end",
|
||||
pattern: "/callback/*",
|
||||
input: "/callback/test",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single wildcard at start",
|
||||
pattern: "/*/callback",
|
||||
input: "/api/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "multiple single wildcards",
|
||||
pattern: "/*/test/*",
|
||||
input: "/api/test/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "partial wildcard prefix",
|
||||
pattern: "/test*",
|
||||
input: "/testing",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "partial wildcard suffix",
|
||||
pattern: "/*-callback",
|
||||
input: "/oauth-callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "partial wildcard middle",
|
||||
pattern: "/api-*-v1",
|
||||
input: "/api-internal-v1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
|
||||
// Double wildcard (**)
|
||||
{
|
||||
name: "double wildcard matches multiple segments",
|
||||
pattern: "/api/**/callback",
|
||||
input: "/api/v1/v2/v3/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "double wildcard matches single segment",
|
||||
pattern: "/api/**/callback",
|
||||
input: "/api/v1/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "double wildcard doesn't match when pattern has extra slashes",
|
||||
pattern: "/api/**/callback",
|
||||
input: "/api/callback",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "double wildcard at end",
|
||||
pattern: "/api/**",
|
||||
input: "/api/v1/v2/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "double wildcard in middle",
|
||||
pattern: "/api/**/v2/**/callback",
|
||||
input: "/api/v1/v2/v3/v4/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
|
||||
// Complex patterns
|
||||
{
|
||||
name: "mix of single and double wildcards",
|
||||
pattern: "/*/api/**/callback",
|
||||
input: "/app/api/v1/v2/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard with special characters",
|
||||
pattern: "/callback-*",
|
||||
input: "/callback-123",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "path with query-like string (no special handling)",
|
||||
pattern: "/callback?code=*",
|
||||
input: "/callback?code=abc",
|
||||
shouldMatch: true,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "single wildcard matches empty segment",
|
||||
pattern: "/api/*/callback",
|
||||
input: "/api//callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "pattern longer than input",
|
||||
pattern: "/api/v1/callback",
|
||||
input: "/api",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "input longer than pattern",
|
||||
pattern: "/api",
|
||||
input: "/api/v1/callback",
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches, err := matchPath(tt.pattern, tt.input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.shouldMatch, matches)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitParts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedParts []string
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "simple https URL",
|
||||
input: "https://example.com/callback",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
},
|
||||
{
|
||||
name: "URL with port",
|
||||
input: "https://example.com:8080/callback",
|
||||
expectedParts: []string{"https", "example", "com", "8080"},
|
||||
expectedPath: "/callback",
|
||||
},
|
||||
{
|
||||
name: "URL with subdomain",
|
||||
input: "https://api.example.com/callback",
|
||||
expectedParts: []string{"https", "api", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
},
|
||||
{
|
||||
name: "URL with credentials",
|
||||
input: "https://user:pass@example.com/callback",
|
||||
expectedParts: []string{"https", "user", "pass", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
},
|
||||
{
|
||||
name: "URL without path",
|
||||
input: "https://example.com",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "",
|
||||
},
|
||||
{
|
||||
name: "URL with deep path",
|
||||
input: "https://example.com/api/v1/callback",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/api/v1/callback",
|
||||
},
|
||||
{
|
||||
name: "URL with path and query",
|
||||
input: "https://example.com/callback?code=123",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/callback?code=123",
|
||||
},
|
||||
{
|
||||
name: "URL with trailing slash",
|
||||
input: "https://example.com/",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/",
|
||||
},
|
||||
{
|
||||
name: "URL with multiple subdomains",
|
||||
input: "https://api.v1.staging.example.com/callback",
|
||||
expectedParts: []string{"https", "api", "v1", "staging", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
},
|
||||
{
|
||||
name: "URL with port and credentials",
|
||||
input: "https://user:pass@example.com:8080/callback",
|
||||
expectedParts: []string{"https", "user", "pass", "example", "com", "8080"},
|
||||
expectedPath: "/callback",
|
||||
},
|
||||
{
|
||||
name: "scheme with authority separator but no slash",
|
||||
input: "http://example.com",
|
||||
expectedParts: []string{"http", "example", "com"},
|
||||
expectedPath: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parts, path := splitParts(tt.input)
|
||||
assert.Equal(t, tt.expectedParts, parts, "parts mismatch")
|
||||
assert.Equal(t, tt.expectedPath, path, "path mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user