1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-23 17:10:06 +00:00

fix: handle IPv6 addresses in callback URLs (#1355)

This commit is contained in:
Alessandro (Ale) Segala
2026-03-03 01:52:42 -08:00
committed by GitHub
parent 6159e0bf96
commit 89349dc1ad
2 changed files with 290 additions and 31 deletions

View File

@@ -34,22 +34,7 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
// time of the request for loopback IP redirect URIs, to accommodate
// clients that obtain an available ephemeral port from the operating
// system at the time of the request.
loopbackCallbackURLWithoutPort := ""
u, _ := url.Parse(inputCallbackURL)
if u != nil && u.Scheme == "http" {
host := u.Hostname()
ip := net.ParseIP(host)
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
// For IPv6 loopback hosts, brackets are required when serializing without a port.
if strings.Contains(host, ":") {
u.Host = "[" + host + "]"
} else {
u.Host = host
}
loopbackCallbackURLWithoutPort = u.String()
}
}
loopbackCallbackURLWithoutPort := loopbackURLWithWildcardPort(inputCallbackURL)
for _, pattern := range urls {
// Try the original callback first
@@ -76,6 +61,28 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
return "", nil
}
func loopbackURLWithWildcardPort(input string) string {
u, _ := url.Parse(input)
if u == nil || u.Scheme != "http" {
return ""
}
host := u.Hostname()
ip := net.ParseIP(host)
if host != "localhost" && (ip == nil || !ip.IsLoopback()) {
return ""
}
// For IPv6 loopback hosts, brackets are required when serializing without a port.
if strings.Contains(host, ":") {
u.Host = "[" + host + "]"
} else {
u.Host = host
}
return u.String()
}
// matchCallbackURL checks if the input callback URL matches the given pattern.
// It supports wildcard matching for paths and query parameters.
//
@@ -125,10 +132,57 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
// normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards
// into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards
// and ** for multi-segment wildcards.
// Additionally, it escapes ":" with a backslash inside IPv6 addresses
func normalizeToURLPatternStandard(pattern string) string {
patternBase, patternPath := extractPath(pattern)
var result strings.Builder
result.Grow(len(pattern) + 5) // Add 5 for some extra capacity, hoping to avoid many re-allocations
// First, process the base
// 0 = scheme
// 1 = hostname (optionally with username/password) - before IPv6 start (no `[` found)
// 2 = is matching IPv6 (until `]`)
// 3 = after hostname
var step int
for i := 0; i < len(patternBase); i++ {
switch step {
case 0:
if i > 3 && patternBase[i] == '/' && patternBase[i-1] == '/' && patternBase[i-2] == ':' {
// We just passed the scheme
step = 1
}
case 1:
switch patternBase[i] {
case '/', ']':
// No IPv6, skip to end of this logic
step = 3
case '[':
// Start of IPv6 match
step = 2
}
case 2:
if patternBase[i] == '/' || patternBase[i] == ']' || patternBase[i] == '[' {
// End of IPv6 match
step = 3
}
switch patternBase[i] {
case ':':
// We are matching an IPv6 block and there's a colon, so escape that
result.WriteByte('\\')
case '/', ']', '[':
// End of IPv6 match
step = 3
}
}
// Write the byte
result.WriteByte(patternBase[i])
}
// Next, process the path
for i := 0; i < len(patternPath); i++ {
if patternPath[i] == '*' {
// Replace globstar with a single asterisk
@@ -141,19 +195,19 @@ func normalizeToURLPatternStandard(pattern string) string {
result.WriteString(strconv.Itoa(i))
}
} else {
// Add the byte
result.WriteByte(patternPath[i])
}
}
patternPath = result.String()
return patternBase + patternPath
return result.String()
}
func extractPath(url string) (base string, path string) {
pathStart := -1
// Look for scheme:// first
if i := strings.Index(url, "://"); i >= 0 {
i := strings.Index(url, "://")
if i >= 0 {
// Look for the next slash after scheme://
rest := url[i+3:]
if j := strings.IndexByte(rest, '/'); j >= 0 {