1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 22:00:08 +00:00
Files
pocket-id/backend/internal/utils/callback_url_util.go
2026-03-03 10:52:42 +01:00

269 lines
6.8 KiB
Go

package utils
import (
"log/slog"
"net"
"net/url"
"path"
"strconv"
"strings"
"github.com/dunglas/go-urlpattern"
)
// ValidateCallbackURLPattern checks if the given callback URL pattern
// is valid according to the rules defined in this package.
func ValidateCallbackURLPattern(pattern string) error {
if pattern == "*" {
return nil
}
pattern, _, _ = strings.Cut(pattern, "#")
pattern = normalizeToURLPatternStandard(pattern)
_, err := urlpattern.New(pattern, "", nil)
return err
}
// GetCallbackURLFromList walks urls in order and returns inputCallbackURL as soon as
// one of the patterns matches it. When nothing matches it returns an empty string and
// a nil error; an error from matchCallbackURL stops the search and is returned as is.
//
// Loopback redirects get special treatment, following RFC 8252 section 7.3
// (https://datatracker.ietf.org/doc/html/rfc8252#section-7.3):
//
//	The authorization server MUST allow any port to be specified at the
//	time of the request for loopback IP redirect URIs, to accommodate
//	clients that obtain an available ephemeral port from the operating
//	system at the time of the request.
//
// To honor that, each pattern is also tested against a port-less variant of the input
// when loopbackURLWithWildcardPort produces one. The value returned on success is
// always the original inputCallbackURL, never the port-less variant.
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
	// The original URL is always tried first, then the optional loopback variant.
	candidates := []string{inputCallbackURL}
	if portless := loopbackURLWithWildcardPort(inputCallbackURL); portless != "" {
		candidates = append(candidates, portless)
	}

	for _, pattern := range urls {
		for _, candidate := range candidates {
			ok, matchErr := matchCallbackURL(pattern, candidate)
			if matchErr != nil {
				return "", matchErr
			}
			if ok {
				return inputCallbackURL, nil
			}
		}
	}

	return "", nil
}
// loopbackURLWithWildcardPort returns input with its port removed when input is an
// "http" URL whose host is "localhost" or a loopback IP address. For any other input
// (unparsable, different scheme, non-loopback host) it returns an empty string.
func loopbackURLWithWildcardPort(input string) string {
	// A parse failure yields a nil URL, which is rejected right below, so the error
	// value itself carries no extra information here.
	parsed, _ := url.Parse(input)
	if parsed == nil || parsed.Scheme != "http" {
		return ""
	}

	hostname := parsed.Hostname()
	if hostname != "localhost" {
		if addr := net.ParseIP(hostname); addr == nil || !addr.IsLoopback() {
			return ""
		}
	}

	// Hostname() strips the brackets from IPv6 literals; they have to be put back
	// when the host is written without a port.
	portless := hostname
	if strings.Contains(hostname, ":") {
		portless = "[" + hostname + "]"
	}
	parsed.Host = portless

	return parsed.String()
}
// matchCallbackURL reports whether inputCallbackURL satisfies pattern.
//
// A pattern that is byte-for-byte equal to the input, or that is the bare "*", matches
// immediately. Otherwise the comparison is done in two parts:
//
//   - query parameters are parsed on both sides and compared by validateQueryParams;
//   - the remainder of the URL is compared with the urlpattern package after the
//     pattern has been rewritten by normalizeToURLPatternStandard.
//
// The base URL (scheme, userinfo, host, port) and the query parameters support single
// '*' wildcards only, while the path supports both single '*' and double '**' wildcards.
//
// An error is returned only when a query string cannot be parsed. A pattern that
// urlpattern rejects is logged and treated as a non-match.
func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, err error) {
	switch pattern {
	case inputCallbackURL, "*":
		return true, nil
	}

	// Fragments are ignored on both sides: the endpoint URI MUST NOT include a
	// fragment component.
	// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
	patternNoFragment, _, _ := strings.Cut(pattern, "#")
	inputNoFragment, _, _ := strings.Cut(inputCallbackURL, "#")

	// Split the query off each URL; it is compared separately from the rest.
	patternURL, patternQuery, err := extractQueryParams(patternNoFragment)
	if err != nil {
		return false, err
	}
	inputURL, inputQuery, err := extractQueryParams(inputNoFragment)
	if err != nil {
		return false, err
	}

	if !validateQueryParams(patternQuery, inputQuery) {
		return false, nil
	}

	// Everything except the query is delegated to urlpattern.
	normalized := normalizeToURLPatternStandard(patternURL)
	compiled, err := urlpattern.New(normalized, "", nil)
	if err != nil {
		//nolint:nilerr
		slog.Warn("invalid callback URL pattern, skipping", "pattern", normalized, "error", err)
		return false, nil
	}

	return compiled.Test(inputURL, ""), nil
}
// normalizeToURLPatternStandard rewrites a callback URL pattern written with '*' and
// '**' wildcards into the syntax expected by the urlpattern package.
//
// The pattern is split with extractPath into a base (everything before the path) and
// a path, which are handled differently:
//
//   - Base: copied unchanged, except that every ':' inside a bracketed IPv6 literal
//     that follows "://" is escaped with a backslash.
//   - Path: '**' becomes '*' (urlpattern's multi-segment wildcard) and a single '*'
//     becomes a named group ":p<index>", where <index> is the byte offset of the '*'
//     within the path, which keeps the generated names unique.
//
// When a single '*' is directly followed by a character that could continue a group
// name (for example "/*callback"), the group is written as "{:p<index>}" so that the
// following literal text is not absorbed into the name.
func normalizeToURLPatternStandard(pattern string) string {
	patternBase, patternPath := extractPath(pattern)

	var result strings.Builder
	result.Grow(len(pattern) + 5) // A little headroom for escapes and group names

	// States of the scanner that walks the base.
	const (
		inScheme   = iota // still before the "://" separator
		beforeHost        // after "://", no '[' seen yet (may include userinfo)
		inIPv6            // between '[' and ']'
		afterHost         // nothing left to rewrite
	)

	state := inScheme
	for i := 0; i < len(patternBase); i++ {
		c := patternBase[i]
		switch state {
		case inScheme:
			// i >= 2 only guards the look-behind, so one-character schemes such as
			// the "*" wildcard are recognized too.
			if i >= 2 && c == '/' && patternBase[i-1] == '/' && patternBase[i-2] == ':' {
				state = beforeHost
			}
		case beforeHost:
			switch c {
			case '[':
				state = inIPv6
			case '/', ']':
				// Not an IPv6 literal
				state = afterHost
			}
		case inIPv6:
			switch c {
			case ':':
				// Escape the colon so it is not read as the start of a named group
				result.WriteByte('\\')
			case '/', ']', '[':
				state = afterHost
			}
		}
		result.WriteByte(c)
	}

	for i := 0; i < len(patternPath); i++ {
		c := patternPath[i]
		if c != '*' {
			result.WriteByte(c)
			continue
		}

		hasNext := i+1 < len(patternPath)
		if hasNext && patternPath[i+1] == '*' {
			// Globstar: urlpattern spells it as a single asterisk
			result.WriteByte('*')
			i++ // skip the second '*'
			continue
		}

		name := ":p" + strconv.Itoa(i)
		if hasNext && isURLPatternNameByte(patternPath[i+1]) {
			// Delimit the group, otherwise the text that follows would extend its name
			result.WriteByte('{')
			result.WriteString(name)
			result.WriteByte('}')
		} else {
			result.WriteString(name)
		}
	}

	return result.String()
}

// isURLPatternNameByte reports whether b could continue a group name when it directly
// follows one. Every non-ASCII byte is treated as a possible name character; wrapping
// a group in braces when it was not strictly required does not change what it matches.
func isURLPatternNameByte(b byte) bool {
	switch {
	case b >= 'a' && b <= 'z', b >= 'A' && b <= 'Z', b >= '0' && b <= '9':
		return true
	case b == '_', b == '$':
		return true
	}
	return b >= 0x80
}
// extractPath splits url at the first '/' that begins its path.
//
// When url contains "://", the search for that slash starts right after it; otherwise
// it starts at the beginning of the string. base receives everything before the slash
// and path everything from the slash onwards (including any query or fragment). If no
// such slash exists, base is the whole input and path is empty.
func extractPath(url string) (base string, path string) {
	// Skip past the scheme separator, if there is one, so its slashes are not
	// mistaken for the start of the path.
	searchFrom := 0
	if schemeEnd := strings.Index(url, "://"); schemeEnd >= 0 {
		searchFrom = schemeEnd + len("://")
	}

	slash := strings.IndexByte(url[searchFrom:], '/')
	if slash < 0 {
		return url, ""
	}

	cut := searchFrom + slash
	return url[:cut], url[cut:]
}
// extractQueryParams separates rawUrl at its first '?'.
//
// Without a '?', rawUrl is returned unchanged together with a nil query. Otherwise
// base is the part before the '?' and query holds the parsed parameters from the part
// after it. If the query string cannot be parsed, the parse error is returned along
// with an empty base and a nil query.
func extractQueryParams(rawUrl string) (base string, query url.Values, err error) {
	head, rawQuery, hasQuery := strings.Cut(rawUrl, "?")
	if !hasQuery {
		return rawUrl, nil, nil
	}

	parsed, parseErr := url.ParseQuery(rawQuery)
	if parseErr != nil {
		return "", nil, parseErr
	}

	return head, parsed, nil
}
// validateQueryParams reports whether inputQuery satisfies patternQuery.
//
// Both must contain exactly the same keys, each key must carry the same number of
// values on both sides, and every input value must match the pattern value at the same
// position according to path.Match. A malformed pattern value counts as a mismatch.
func validateQueryParams(patternQuery, inputQuery url.Values) bool {
	if len(inputQuery) != len(patternQuery) {
		return false
	}

	for key, wanted := range patternQuery {
		got, present := inputQuery[key]
		if !present || len(got) != len(wanted) {
			return false
		}

		// Values are compared pairwise, so their order matters for repeated keys.
		for idx, glob := range wanted {
			if ok, err := path.Match(glob, got[idx]); err != nil || !ok {
				return false
			}
		}
	}

	return true
}