mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-23 01:25:10 +00:00
fix: improve wildcard matching by using go-urlpattern (#1332)
This commit is contained in:
@@ -1,14 +1,31 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dunglas/go-urlpattern"
|
||||
)
|
||||
|
||||
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL
|
||||
// ValidateCallbackURLPattern checks if the given callback URL pattern
|
||||
// is valid according to the rules defined in this package.
|
||||
func ValidateCallbackURLPattern(pattern string) error {
|
||||
if pattern == "*" {
|
||||
return nil
|
||||
}
|
||||
|
||||
pattern, _, _ = strings.Cut(pattern, "#")
|
||||
pattern = normalizeToURLPatternStandard(pattern)
|
||||
|
||||
_, err := urlpattern.New(pattern, "", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL.
|
||||
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
|
||||
// Special case for Loopback Interface Redirection. Quoting from RFC 8252 section 7.3:
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
|
||||
@@ -24,7 +41,12 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
|
||||
host := u.Hostname()
|
||||
ip := net.ParseIP(host)
|
||||
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
|
||||
u.Host = host
|
||||
// For IPv6 loopback hosts, brackets are required when serializing without a port.
|
||||
if strings.Contains(host, ":") {
|
||||
u.Host = "[" + host + "]"
|
||||
} else {
|
||||
u.Host = host
|
||||
}
|
||||
loopbackCallbackURLWithoutPort = u.String()
|
||||
}
|
||||
}
|
||||
@@ -64,143 +86,129 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Strip fragment part
|
||||
// Strip fragment part.
|
||||
// The endpoint URI MUST NOT include a fragment component.
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
|
||||
pattern, _, _ = strings.Cut(pattern, "#")
|
||||
inputCallbackURL, _, _ = strings.Cut(inputCallbackURL, "#")
|
||||
|
||||
// Store and strip query part
|
||||
var patternQuery url.Values
|
||||
if i := strings.Index(pattern, "?"); i >= 0 {
|
||||
patternQuery, err = url.ParseQuery(pattern[i+1:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
pattern = pattern[:i]
|
||||
}
|
||||
var inputQuery url.Values
|
||||
if i := strings.Index(inputCallbackURL, "?"); i >= 0 {
|
||||
inputQuery, err = url.ParseQuery(inputCallbackURL[i+1:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
inputCallbackURL = inputCallbackURL[:i]
|
||||
}
|
||||
|
||||
// Split both pattern and input parts
|
||||
patternParts, patternPath := splitParts(pattern)
|
||||
inputParts, inputPath := splitParts(inputCallbackURL)
|
||||
|
||||
// Verify everything except the path and query parameters
|
||||
if len(patternParts) != len(inputParts) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for i, patternPart := range patternParts {
|
||||
matched, err := path.Match(patternPart, inputParts[i])
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Verify path with wildcard support
|
||||
matched, err := matchPath(patternPath, inputPath)
|
||||
if err != nil || !matched {
|
||||
pattern, patternQuery, err := extractQueryParams(pattern)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Verify query parameters
|
||||
if len(patternQuery) != len(inputQuery) {
|
||||
inputCallbackURL, inputQuery, err := extractQueryParams(inputCallbackURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
pattern = normalizeToURLPatternStandard(pattern)
|
||||
|
||||
// Validate query params
|
||||
v := validateQueryParams(patternQuery, inputQuery)
|
||||
if !v {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Validate the rest of the URL using urlpattern
|
||||
p, err := urlpattern.New(pattern, "", nil)
|
||||
if err != nil {
|
||||
//nolint:nilerr
|
||||
slog.Warn("invalid callback URL pattern, skipping", "pattern", pattern, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return p.Test(inputCallbackURL, ""), nil
|
||||
}
|
||||
|
||||
// normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards
|
||||
// into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards
|
||||
// and ** for multi-segment wildcards.
|
||||
func normalizeToURLPatternStandard(pattern string) string {
|
||||
patternBase, patternPath := extractPath(pattern)
|
||||
|
||||
var result strings.Builder
|
||||
for i := 0; i < len(patternPath); i++ {
|
||||
if patternPath[i] == '*' {
|
||||
// Replace globstar with a single asterisk
|
||||
if i+1 < len(patternPath) && patternPath[i+1] == '*' {
|
||||
result.WriteString("*")
|
||||
i++ // skip next *
|
||||
} else {
|
||||
// Replace single asterisk with :p{index}
|
||||
result.WriteString(":p")
|
||||
result.WriteString(strconv.Itoa(i))
|
||||
}
|
||||
} else {
|
||||
result.WriteByte(patternPath[i])
|
||||
}
|
||||
}
|
||||
patternPath = result.String()
|
||||
|
||||
return patternBase + patternPath
|
||||
}
|
||||
|
||||
func extractPath(url string) (base string, path string) {
|
||||
pathStart := -1
|
||||
|
||||
// Look for scheme:// first
|
||||
if i := strings.Index(url, "://"); i >= 0 {
|
||||
// Look for the next slash after scheme://
|
||||
rest := url[i+3:]
|
||||
if j := strings.IndexByte(rest, '/'); j >= 0 {
|
||||
pathStart = i + 3 + j
|
||||
}
|
||||
} else {
|
||||
// Otherwise, first slash is path start
|
||||
pathStart = strings.IndexByte(url, '/')
|
||||
}
|
||||
|
||||
if pathStart >= 0 {
|
||||
path = url[pathStart:]
|
||||
base = url[:pathStart]
|
||||
} else {
|
||||
path = ""
|
||||
base = url
|
||||
}
|
||||
|
||||
return base, path
|
||||
}
|
||||
|
||||
func extractQueryParams(rawUrl string) (base string, query url.Values, err error) {
|
||||
if i := strings.IndexByte(rawUrl, '?'); i >= 0 {
|
||||
query, err = url.ParseQuery(rawUrl[i+1:])
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rawUrl = rawUrl[:i]
|
||||
}
|
||||
|
||||
return rawUrl, query, nil
|
||||
}
|
||||
|
||||
func validateQueryParams(patternQuery, inputQuery url.Values) bool {
|
||||
if len(patternQuery) != len(inputQuery) {
|
||||
return false
|
||||
}
|
||||
|
||||
for patternKey, patternValues := range patternQuery {
|
||||
inputValues, exists := inputQuery[patternKey]
|
||||
if !exists {
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
if len(patternValues) != len(inputValues) {
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range patternValues {
|
||||
matched, err := path.Match(patternValues[i], inputValues[i])
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// matchPath matches the input path against the pattern with wildcard support
|
||||
// Supported wildcards:
|
||||
//
|
||||
// '*' matches any sequence of characters except '/'
|
||||
// '**' matches any sequence of characters including '/'
|
||||
func matchPath(pattern string, input string) (matches bool, err error) {
|
||||
var regexPattern strings.Builder
|
||||
regexPattern.WriteString("^")
|
||||
|
||||
runes := []rune(pattern)
|
||||
n := len(runes)
|
||||
|
||||
for i := 0; i < n; {
|
||||
switch runes[i] {
|
||||
case '*':
|
||||
// Check if it's a ** (globstar)
|
||||
if i+1 < n && runes[i+1] == '*' {
|
||||
// globstar = .* (match slashes too)
|
||||
regexPattern.WriteString(".*")
|
||||
i += 2
|
||||
} else {
|
||||
// single * = [^/]* (no slash)
|
||||
regexPattern.WriteString(`[^/]*`)
|
||||
i++
|
||||
}
|
||||
default:
|
||||
regexPattern.WriteString(regexp.QuoteMeta(string(runes[i])))
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
regexPattern.WriteString("$")
|
||||
|
||||
matched, err := regexp.MatchString(regexPattern.String(), input)
|
||||
return matched, err
|
||||
}
|
||||
|
||||
// splitParts splits the URL into parts by special characters and returns the path separately
|
||||
func splitParts(s string) (parts []string, path string) {
|
||||
split := func(r rune) bool {
|
||||
return r == ':' || r == '/' || r == '[' || r == ']' || r == '@' || r == '.'
|
||||
}
|
||||
|
||||
pathStart := -1
|
||||
|
||||
// Look for scheme:// first
|
||||
if i := strings.Index(s, "://"); i >= 0 {
|
||||
// Look for the next slash after scheme://
|
||||
rest := s[i+3:]
|
||||
if j := strings.IndexRune(rest, '/'); j >= 0 {
|
||||
pathStart = i + 3 + j
|
||||
}
|
||||
} else {
|
||||
// Otherwise, first slash is path start
|
||||
pathStart = strings.IndexRune(s, '/')
|
||||
}
|
||||
|
||||
if pathStart >= 0 {
|
||||
path = s[pathStart:]
|
||||
base := s[:pathStart]
|
||||
parts = strings.FieldsFunc(base, split)
|
||||
} else {
|
||||
parts = strings.FieldsFunc(s, split)
|
||||
path = ""
|
||||
}
|
||||
|
||||
return parts, path
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user