mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-10 03:14:19 +00:00
feat: add database storage backend (#1091)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
12125713a2
commit
29a1d3b778
@@ -1,7 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
@@ -56,6 +59,23 @@ func IsPrivateIP(ip net.IP) bool {
|
||||
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
|
||||
}
|
||||
|
||||
func IsURLPrivate(ctx context.Context, u *url.URL) (bool, error) {
|
||||
var r net.Resolver
|
||||
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||
if err != nil || len(ips) == 0 {
|
||||
return false, errors.New("cannot resolve hostname")
|
||||
}
|
||||
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
for _, addr := range ips {
|
||||
if IsPrivateIP(addr.IP) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
|
||||
for _, ipNet := range ipNets {
|
||||
if ipNet.Contains(ip) {
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
@@ -20,9 +26,8 @@ func TestIsLocalhostIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsLocalhostIP(ip); got != tt.expected {
|
||||
t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsLocalhostIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,9 +45,8 @@ func TestIsPrivateLanIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsPrivateLanIP(ip); got != tt.expected {
|
||||
t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsPrivateLanIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,9 +63,9 @@ func TestIsTailscaleIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsTailscaleIP(ip); got != tt.expected {
|
||||
t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
|
||||
got := IsTailscaleIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,16 +90,17 @@ func TestIsLocalIPv6(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsLocalIPv6(ip); got != tt.expected {
|
||||
t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsLocalIPv6(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
t.Cleanup(func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = origRanges
|
||||
})
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
|
||||
localIPv6Ranges = nil // reset
|
||||
@@ -115,9 +120,8 @@ func TestIsPrivateIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsPrivateIP(ip); got != tt.expected {
|
||||
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsPrivateIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,22 +142,202 @@ func TestListContainsIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := listContainsIP(list, ip); got != tt.expected {
|
||||
t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := listContainsIP(list, ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_LocalIPv6Ranges(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
t.Cleanup(func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = origRanges
|
||||
})
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
|
||||
localIPv6Ranges = nil
|
||||
loadLocalIPv6Ranges()
|
||||
|
||||
if len(localIPv6Ranges) != 2 {
|
||||
t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
|
||||
assert.Len(t, localIPv6Ranges, 2)
|
||||
}
|
||||
|
||||
func TestIsURLPrivate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectPriv bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "localhost by name",
|
||||
urlStr: "http://localhost",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "localhost with port",
|
||||
urlStr: "http://localhost:8080",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 IP",
|
||||
urlStr: "http://127.0.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with port",
|
||||
urlStr: "http://127.0.0.1:3000",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
urlStr: "http://[::1]",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback with port",
|
||||
urlStr: "http://[::1]:8080",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 10.x.x.x",
|
||||
urlStr: "http://10.0.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 192.168.x.x",
|
||||
urlStr: "http://192.168.1.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 172.16.x.x",
|
||||
urlStr: "http://172.16.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Tailscale IP",
|
||||
urlStr: "http://100.64.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "public IP - Google DNS",
|
||||
urlStr: "http://8.8.8.8",
|
||||
expectPriv: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "public IP - Cloudflare DNS",
|
||||
urlStr: "http://1.1.1.1",
|
||||
expectPriv: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid hostname",
|
||||
urlStr: "http://this-should-not-resolve-ever-123456789.invalid",
|
||||
expectPriv: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.urlStr)
|
||||
require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
|
||||
|
||||
isPriv, err := IsURLPrivate(ctx, u)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err, "IsURLPrivate(%s) expected error but got none", tt.urlStr)
|
||||
} else {
|
||||
require.NoError(t, err, "IsURLPrivate(%s) unexpected error", tt.urlStr)
|
||||
assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsURLPrivate_WithDomainName(t *testing.T) {
|
||||
// Note: These tests rely on actual DNS resolution
|
||||
// They test real public domains to ensure they are not flagged as private
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectPriv bool
|
||||
}{
|
||||
{
|
||||
name: "Google public domain",
|
||||
urlStr: "https://www.google.com",
|
||||
expectPriv: false,
|
||||
},
|
||||
{
|
||||
name: "GitHub public domain",
|
||||
urlStr: "https://github.com",
|
||||
expectPriv: false,
|
||||
},
|
||||
{
|
||||
// localhost.localtest.me is a well-known domain that resolves to 127.0.0.1
|
||||
name: "localhost.localtest.me resolves to 127.0.0.1",
|
||||
urlStr: "http://localhost.localtest.me",
|
||||
expectPriv: true,
|
||||
},
|
||||
{
|
||||
// 10.0.0.1.nip.io resolves to 10.0.0.1 (private IP)
|
||||
name: "nip.io domain resolving to private 10.x IP",
|
||||
urlStr: "http://10.0.0.1.nip.io",
|
||||
expectPriv: true,
|
||||
},
|
||||
{
|
||||
// 192.168.1.1.nip.io resolves to 192.168.1.1 (private IP)
|
||||
name: "nip.io domain resolving to private 192.168.x IP",
|
||||
urlStr: "http://192.168.1.1.nip.io",
|
||||
expectPriv: true,
|
||||
},
|
||||
{
|
||||
// 127.0.0.1.nip.io resolves to 127.0.0.1 (localhost)
|
||||
name: "nip.io domain resolving to localhost",
|
||||
urlStr: "http://127.0.0.1.nip.io",
|
||||
expectPriv: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.urlStr)
|
||||
require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
|
||||
|
||||
isPriv, err := IsURLPrivate(ctx, u)
|
||||
if err != nil {
|
||||
t.Skipf("DNS resolution failed for %s (network issue?): %v", tt.urlStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsURLPrivate_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
u, err := url.Parse("http://example.com")
|
||||
require.NoError(t, err, "Failed to parse URL")
|
||||
|
||||
_, err = IsURLPrivate(ctx, u)
|
||||
assert.Error(t, err, "IsURLPrivate with cancelled context expected error but got none")
|
||||
}
|
||||
|
||||
34
backend/internal/utils/stream_util.go
Normal file
34
backend/internal/utils/stream_util.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrSizeExceeded = errors.New("stream size exceeded")
|
||||
|
||||
// LimitReader is like io.LimitReader but throws an error if the stream exceeds the max size
|
||||
// io.LimitReader instead just returns io.EOF
|
||||
// Adapted from https://github.com/golang/go/issues/51115#issuecomment-1079761212
|
||||
type LimitReader struct {
|
||||
io.ReadCloser
|
||||
N int64
|
||||
}
|
||||
|
||||
func NewLimitReader(r io.ReadCloser, limit int64) *LimitReader {
|
||||
return &LimitReader{r, limit}
|
||||
}
|
||||
|
||||
func (r *LimitReader) Read(p []byte) (n int, err error) {
|
||||
if r.N <= 0 {
|
||||
return 0, ErrSizeExceeded
|
||||
}
|
||||
|
||||
if int64(len(p)) > r.N {
|
||||
p = p[0:r.N]
|
||||
}
|
||||
|
||||
n, err = r.ReadCloser.Read(p)
|
||||
r.N -= int64(n)
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user