mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-14 22:50:15 +00:00
fix: enable foreign key check for sqlite (#863)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
@@ -86,9 +86,6 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
if common.EnvConfig.DbConnectionString == "" {
|
if common.EnvConfig.DbConnectionString == "" {
|
||||||
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database")
|
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database")
|
||||||
}
|
}
|
||||||
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
|
|
||||||
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
|
|
||||||
}
|
|
||||||
sqliteutil.RegisterSqliteFunctions()
|
sqliteutil.RegisterSqliteFunctions()
|
||||||
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -123,25 +120,43 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// The official C implementation of SQLite allows some additional properties in the connection string
|
|
||||||
// that are not supported in the in the modernc.org/sqlite driver, and which must be passed as PRAGMA args instead.
|
|
||||||
// To ensure that people can use similar args as in the C driver, which was also used by Pocket ID
|
|
||||||
// previously (via github.com/mattn/go-sqlite3), we are converting some options.
|
|
||||||
func parseSqliteConnectionString(connString string) (string, error) {
|
func parseSqliteConnectionString(connString string) (string, error) {
|
||||||
if !strings.HasPrefix(connString, "file:") {
|
if !strings.HasPrefix(connString, "file:") {
|
||||||
connString = "file:" + connString
|
connString = "file:" + connString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we're using an in-memory database
|
||||||
|
isMemoryDB := isSqliteInMemory(connString)
|
||||||
|
|
||||||
|
// Parse the connection string
|
||||||
connStringUrl, err := url.Parse(connString)
|
connStringUrl, err := url.Parse(connString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
|
return "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert options for the old SQLite driver to the new one
|
||||||
|
convertSqlitePragmaArgs(connStringUrl)
|
||||||
|
|
||||||
|
// Add the default and required params
|
||||||
|
err = addSqliteDefaultParameters(connStringUrl, isMemoryDB)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid SQLite connection string: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return connStringUrl.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The official C implementation of SQLite allows some additional properties in the connection string
|
||||||
|
// that are not supported in the in the modernc.org/sqlite driver, and which must be passed as PRAGMA args instead.
|
||||||
|
// To ensure that people can use similar args as in the C driver, which was also used by Pocket ID
|
||||||
|
// previously (via github.com/mattn/go-sqlite3), we are converting some options.
|
||||||
|
// Note this function updates connStringUrl.
|
||||||
|
func convertSqlitePragmaArgs(connStringUrl *url.URL) {
|
||||||
// Reference: https://github.com/mattn/go-sqlite3?tab=readme-ov-file#connection-string
|
// Reference: https://github.com/mattn/go-sqlite3?tab=readme-ov-file#connection-string
|
||||||
// This only includes a subset of options, excluding those that are not relevant to us
|
// This only includes a subset of options, excluding those that are not relevant to us
|
||||||
qs := make(url.Values, len(connStringUrl.Query()))
|
qs := make(url.Values, len(connStringUrl.Query()))
|
||||||
for k, v := range connStringUrl.Query() {
|
for k, v := range connStringUrl.Query() {
|
||||||
switch k {
|
switch strings.ToLower(k) {
|
||||||
case "_auto_vacuum", "_vacuum":
|
case "_auto_vacuum", "_vacuum":
|
||||||
qs.Add("_pragma", "auto_vacuum("+v[0]+")")
|
qs.Add("_pragma", "auto_vacuum("+v[0]+")")
|
||||||
case "_busy_timeout", "_timeout":
|
case "_busy_timeout", "_timeout":
|
||||||
@@ -162,9 +177,123 @@ func parseSqliteConnectionString(connString string) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update the connStringUrl object
|
||||||
|
connStringUrl.RawQuery = qs.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds the default (and some required) parameters to the SQLite connection string.
|
||||||
|
// Note this function updates connStringUrl.
|
||||||
|
func addSqliteDefaultParameters(connStringUrl *url.URL, isMemoryDB bool) error {
|
||||||
|
// This function include code adapted from https://github.com/dapr/components-contrib/blob/v1.14.6/
|
||||||
|
// Copyright (C) 2023 The Dapr Authors
|
||||||
|
// License: Apache2
|
||||||
|
const defaultBusyTimeout = 2500 * time.Millisecond
|
||||||
|
|
||||||
|
// Get the "query string" from the connection string if present
|
||||||
|
qs := connStringUrl.Query()
|
||||||
|
if len(qs) == 0 {
|
||||||
|
qs = make(url.Values, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the database is in-memory, we must ensure that cache=shared is set
|
||||||
|
if isMemoryDB {
|
||||||
|
qs["cache"] = []string{"shared"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the database is read-only or immutable
|
||||||
|
isReadOnly := false
|
||||||
|
if len(qs["mode"]) > 0 {
|
||||||
|
// Keep the first value only
|
||||||
|
qs["mode"] = []string{
|
||||||
|
strings.ToLower(qs["mode"][0]),
|
||||||
|
}
|
||||||
|
if qs["mode"][0] == "ro" {
|
||||||
|
isReadOnly = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(qs["immutable"]) > 0 {
|
||||||
|
// Keep the first value only
|
||||||
|
qs["immutable"] = []string{
|
||||||
|
strings.ToLower(qs["immutable"][0]),
|
||||||
|
}
|
||||||
|
if qs["immutable"][0] == "1" {
|
||||||
|
isReadOnly = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We do not want to override a _txlock if set, but we'll show a warning if it's not "immediate"
|
||||||
|
if len(qs["_txlock"]) > 0 {
|
||||||
|
// Keep the first value only
|
||||||
|
qs["_txlock"] = []string{
|
||||||
|
strings.ToLower(qs["_txlock"][0]),
|
||||||
|
}
|
||||||
|
if qs["_txlock"][0] != "immediate" {
|
||||||
|
slog.Warn("SQLite connection is being created with a _txlock different from the recommended value 'immediate'")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
qs["_txlock"] = []string{"immediate"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add pragma values
|
||||||
|
var hasBusyTimeout, hasJournalMode bool
|
||||||
|
if len(qs["_pragma"]) == 0 {
|
||||||
|
qs["_pragma"] = make([]string, 0, 3)
|
||||||
|
} else {
|
||||||
|
for _, p := range qs["_pragma"] {
|
||||||
|
p = strings.ToLower(p)
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(p, "busy_timeout"):
|
||||||
|
hasBusyTimeout = true
|
||||||
|
case strings.HasPrefix(p, "journal_mode"):
|
||||||
|
hasJournalMode = true
|
||||||
|
case strings.HasPrefix(p, "foreign_keys"):
|
||||||
|
return errors.New("found forbidden option '_pragma=foreign_keys' in the connection string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasBusyTimeout {
|
||||||
|
qs["_pragma"] = append(qs["_pragma"], fmt.Sprintf("busy_timeout(%d)", defaultBusyTimeout.Milliseconds()))
|
||||||
|
}
|
||||||
|
if !hasJournalMode {
|
||||||
|
switch {
|
||||||
|
case isMemoryDB:
|
||||||
|
// For in-memory databases, set the journal to MEMORY, the only allowed option besides OFF (which would make transactions ineffective)
|
||||||
|
qs["_pragma"] = append(qs["_pragma"], "journal_mode(MEMORY)")
|
||||||
|
case isReadOnly:
|
||||||
|
// Set the journaling mode to "DELETE" (the default) if the database is read-only
|
||||||
|
qs["_pragma"] = append(qs["_pragma"], "journal_mode(DELETE)")
|
||||||
|
default:
|
||||||
|
// Enable WAL
|
||||||
|
qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forcefully enable foreign keys
|
||||||
|
qs["_pragma"] = append(qs["_pragma"], "foreign_keys(1)")
|
||||||
|
|
||||||
|
// Update the connStringUrl object
|
||||||
connStringUrl.RawQuery = qs.Encode()
|
connStringUrl.RawQuery = qs.Encode()
|
||||||
|
|
||||||
return connStringUrl.String(), nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSqliteInMemory returns true if the connection string is for an in-memory database.
|
||||||
|
func isSqliteInMemory(connString string) bool {
|
||||||
|
lc := strings.ToLower(connString)
|
||||||
|
|
||||||
|
// First way to define an in-memory database is to use ":memory:" or "file::memory:" as connection string
|
||||||
|
if strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Another way is to pass "mode=memory" in the "query string"
|
||||||
|
idx := strings.IndexRune(lc, '?')
|
||||||
|
if idx < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
qs, _ := url.ParseQuery(lc[(idx + 1):])
|
||||||
|
|
||||||
|
return len(qs["mode"]) > 0 && qs["mode"][0] == "memory"
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGormLogger() gormLogger.Interface {
|
func getGormLogger() gormLogger.Interface {
|
||||||
|
|||||||
@@ -8,23 +8,93 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSqliteConnectionString(t *testing.T) {
|
func TestIsSqliteInMemory(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input string
|
connStr string
|
||||||
expected string
|
expected bool
|
||||||
expectedError bool
|
}{
|
||||||
|
{
|
||||||
|
name: "memory database with :memory:",
|
||||||
|
connStr: ":memory:",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "memory database with file::memory:",
|
||||||
|
connStr: "file::memory:",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "memory database with :MEMORY: (uppercase)",
|
||||||
|
connStr: ":MEMORY:",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "memory database with FILE::MEMORY: (uppercase)",
|
||||||
|
connStr: "FILE::MEMORY:",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "memory database with mixed case",
|
||||||
|
connStr: ":Memory:",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "has mode=memory",
|
||||||
|
connStr: "file:data?mode=memory",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file database",
|
||||||
|
connStr: "data.db",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file database with path",
|
||||||
|
connStr: "/path/to/data.db",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file database with file: prefix",
|
||||||
|
connStr: "file:data.db",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
connStr: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string containing memory but not at start",
|
||||||
|
connStr: "data:memory:.db",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "has mode=ro",
|
||||||
|
connStr: "file:data?mode=ro",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isSqliteInMemory(tt.connStr)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertSqlitePragmaArgs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic file path",
|
name: "basic file path",
|
||||||
input: "file:test.db",
|
input: "file:test.db",
|
||||||
expected: "file:test.db",
|
expected: "file:test.db",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "adds file: prefix if missing",
|
|
||||||
input: "test.db",
|
|
||||||
expected: "file:test.db",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "converts _busy_timeout to pragma",
|
name: "converts _busy_timeout to pragma",
|
||||||
input: "file:test.db?_busy_timeout=5000",
|
input: "file:test.db?_busy_timeout=5000",
|
||||||
@@ -100,46 +170,161 @@ func TestParseSqliteConnectionString(t *testing.T) {
|
|||||||
input: "file:test.db?_fk=1&mode=rw&_timeout=5000",
|
input: "file:test.db?_fk=1&mode=rw&_timeout=5000",
|
||||||
expected: "file:test.db?_pragma=foreign_keys%281%29&_pragma=busy_timeout%285000%29&mode=rw",
|
expected: "file:test.db?_pragma=foreign_keys%281%29&_pragma=busy_timeout%285000%29&mode=rw",
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resultURL, _ := url.Parse(tt.input)
|
||||||
|
convertSqlitePragmaArgs(resultURL)
|
||||||
|
|
||||||
|
// Parse both URLs to compare components independently
|
||||||
|
expectedURL, err := url.Parse(tt.expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare scheme and path components
|
||||||
|
compareQueryStrings(t, expectedURL, resultURL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddSqliteDefaultParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
isMemoryDB bool
|
||||||
|
expected string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
{
|
{
|
||||||
name: "invalid URL format",
|
name: "basic file database",
|
||||||
input: "file:invalid#$%^&*@test.db",
|
input: "file:test.db",
|
||||||
expectedError: true,
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28WAL%29&_txlock=immediate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "in-memory database",
|
||||||
|
input: "file::memory:",
|
||||||
|
isMemoryDB: true,
|
||||||
|
expected: "file::memory:?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28MEMORY%29&_txlock=immediate&cache=shared",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "read-only database with mode=ro",
|
||||||
|
input: "file:test.db?mode=ro",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28DELETE%29&_txlock=immediate&mode=ro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "immutable database",
|
||||||
|
input: "file:test.db?immutable=1",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28DELETE%29&_txlock=immediate&immutable=1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with existing _txlock",
|
||||||
|
input: "file:test.db?_txlock=deferred",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28WAL%29&_txlock=deferred",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with existing busy_timeout pragma",
|
||||||
|
input: "file:test.db?_pragma=busy_timeout%285000%29",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%285000%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28WAL%29&_txlock=immediate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with existing journal_mode pragma",
|
||||||
|
input: "file:test.db?_pragma=journal_mode%28DELETE%29",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28DELETE%29&_txlock=immediate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with forbidden foreign_keys pragma",
|
||||||
|
input: "file:test.db?_pragma=foreign_keys%280%29",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with multiple existing pragmas",
|
||||||
|
input: "file:test.db?_pragma=busy_timeout%283000%29&_pragma=journal_mode%28TRUNCATE%29&_pragma=synchronous%28NORMAL%29",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%283000%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28TRUNCATE%29&_pragma=synchronous%28NORMAL%29&_txlock=immediate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "in-memory database with cache already set",
|
||||||
|
input: "file::memory:?cache=private",
|
||||||
|
isMemoryDB: true,
|
||||||
|
expected: "file::memory:?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28MEMORY%29&_txlock=immediate&cache=shared",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with mode=rw (not read-only)",
|
||||||
|
input: "file:test.db?mode=rw",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28WAL%29&_txlock=immediate&mode=rw",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with immutable=0 (not immutable)",
|
||||||
|
input: "file:test.db?immutable=0",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28WAL%29&_txlock=immediate&immutable=0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with mixed case mode=RO",
|
||||||
|
input: "file:test.db?mode=RO",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28DELETE%29&_txlock=immediate&mode=ro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database with mixed case immutable=1",
|
||||||
|
input: "file:test.db?immutable=1",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28DELETE%29&_txlock=immediate&immutable=1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex database configuration",
|
||||||
|
input: "file:test.db?cache=shared&mode=rwc&_txlock=immediate&_pragma=synchronous%28FULL%29",
|
||||||
|
isMemoryDB: false,
|
||||||
|
expected: "file:test.db?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28WAL%29&_pragma=synchronous%28FULL%29&_txlock=immediate&cache=shared&mode=rwc",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result, err := parseSqliteConnectionString(tt.input)
|
resultURL, err := url.Parse(tt.input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
if tt.expectedError {
|
err = addSqliteDefaultParameters(resultURL, tt.isMemoryDB)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Parse both URLs to compare components independently
|
|
||||||
expectedURL, err := url.Parse(tt.expected)
|
expectedURL, err := url.Parse(tt.expected)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resultURL, err := url.Parse(result)
|
compareQueryStrings(t, expectedURL, resultURL)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Compare scheme and path components
|
|
||||||
assert.Equal(t, expectedURL.Scheme, resultURL.Scheme)
|
|
||||||
assert.Equal(t, expectedURL.Path, resultURL.Path)
|
|
||||||
|
|
||||||
// Compare query parameters regardless of order
|
|
||||||
expectedQuery := expectedURL.Query()
|
|
||||||
resultQuery := resultURL.Query()
|
|
||||||
|
|
||||||
assert.Len(t, expectedQuery, len(resultQuery))
|
|
||||||
|
|
||||||
for key, expectedValues := range expectedQuery {
|
|
||||||
resultValues, ok := resultQuery[key]
|
|
||||||
_ = assert.True(t, ok) &&
|
|
||||||
assert.ElementsMatch(t, expectedValues, resultValues)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func compareQueryStrings(t *testing.T, expectedURL *url.URL, resultURL *url.URL) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Compare scheme and path components
|
||||||
|
assert.Equal(t, expectedURL.Scheme, resultURL.Scheme)
|
||||||
|
assert.Equal(t, expectedURL.Path, resultURL.Path)
|
||||||
|
|
||||||
|
// Compare query parameters regardless of order
|
||||||
|
expectedQuery := expectedURL.Query()
|
||||||
|
resultQuery := resultURL.Query()
|
||||||
|
|
||||||
|
assert.Len(t, expectedQuery, len(resultQuery))
|
||||||
|
|
||||||
|
for key, expectedValues := range expectedQuery {
|
||||||
|
resultValues, ok := resultQuery[key]
|
||||||
|
_ = assert.True(t, ok) &&
|
||||||
|
assert.ElementsMatch(t, expectedValues, resultValues)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ const (
|
|||||||
DbProviderSqlite DbProvider = "sqlite"
|
DbProviderSqlite DbProvider = "sqlite"
|
||||||
DbProviderPostgres DbProvider = "postgres"
|
DbProviderPostgres DbProvider = "postgres"
|
||||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||||
defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate"
|
defaultSqliteConnString string = "data/pocket-id.db"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EnvConfigSchema struct {
|
type EnvConfigSchema struct {
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ type OidcClient struct {
|
|||||||
LaunchURL *string
|
LaunchURL *string
|
||||||
|
|
||||||
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||||
CreatedByID string
|
CreatedByID *string
|
||||||
CreatedBy User
|
CreatedBy *User
|
||||||
UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"`
|
UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
|||||||
CallbackURLs: model.UrlList{"http://nextcloud/auth/callback"},
|
CallbackURLs: model.UrlList{"http://nextcloud/auth/callback"},
|
||||||
LogoutCallbackURLs: model.UrlList{"http://nextcloud/auth/logout/callback"},
|
LogoutCallbackURLs: model.UrlList{"http://nextcloud/auth/logout/callback"},
|
||||||
ImageType: utils.StringPointer("png"),
|
ImageType: utils.StringPointer("png"),
|
||||||
CreatedByID: users[0].ID,
|
CreatedByID: utils.Ptr(users[0].ID),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
@@ -168,7 +168,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
|||||||
Name: "Immich",
|
Name: "Immich",
|
||||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||||
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
|
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
|
||||||
CreatedByID: users[1].ID,
|
CreatedByID: utils.Ptr(users[1].ID),
|
||||||
AllowedUserGroups: []model.UserGroup{
|
AllowedUserGroups: []model.UserGroup{
|
||||||
userGroups[1],
|
userGroups[1],
|
||||||
},
|
},
|
||||||
@@ -181,7 +181,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
|||||||
Secret: "$2a$10$xcRReBsvkI1XI6FG8xu/pOgzeF00bH5Wy4d/NThwcdi3ZBpVq/B9a", // n4VfQeXlTzA6yKpWbR9uJcMdSx2qH0Lo
|
Secret: "$2a$10$xcRReBsvkI1XI6FG8xu/pOgzeF00bH5Wy4d/NThwcdi3ZBpVq/B9a", // n4VfQeXlTzA6yKpWbR9uJcMdSx2qH0Lo
|
||||||
CallbackURLs: model.UrlList{"http://tailscale/auth/callback"},
|
CallbackURLs: model.UrlList{"http://tailscale/auth/callback"},
|
||||||
LogoutCallbackURLs: model.UrlList{"http://tailscale/auth/logout/callback"},
|
LogoutCallbackURLs: model.UrlList{"http://tailscale/auth/logout/callback"},
|
||||||
CreatedByID: users[0].ID,
|
CreatedByID: utils.Ptr(users[0].ID),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
@@ -190,7 +190,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
|||||||
Name: "Federated",
|
Name: "Federated",
|
||||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||||
CallbackURLs: model.UrlList{"http://federated/auth/callback"},
|
CallbackURLs: model.UrlList{"http://federated/auth/callback"},
|
||||||
CreatedByID: users[1].ID,
|
CreatedByID: utils.Ptr(users[1].ID),
|
||||||
AllowedUserGroups: []model.UserGroup{},
|
AllowedUserGroups: []model.UserGroup{},
|
||||||
Credentials: model.OidcClientCredentials{
|
Credentials: model.OidcClientCredentials{
|
||||||
FederatedIdentities: []model.OidcClientFederatedIdentity{
|
FederatedIdentities: []model.OidcClientFederatedIdentity{
|
||||||
|
|||||||
@@ -670,7 +670,7 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
|||||||
|
|
||||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||||
client := model.OidcClient{
|
client := model.OidcClient{
|
||||||
CreatedByID: userID,
|
CreatedByID: utils.Ptr(userID),
|
||||||
}
|
}
|
||||||
updateOIDCClientModelFromDto(&client, &input)
|
updateOIDCClientModelFromDto(&client, &input)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
-- No-op
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
ALTER TABLE public.audit_logs
|
||||||
|
DROP CONSTRAINT IF EXISTS audit_logs_user_id_fkey,
|
||||||
|
ADD CONSTRAINT audit_logs_user_id_fkey
|
||||||
|
FOREIGN KEY (user_id) REFERENCES public.users (id) ON DELETE CASCADE;
|
||||||
|
|
||||||
|
ALTER TABLE public.oidc_authorization_codes
|
||||||
|
ADD CONSTRAINT oidc_authorization_codes_client_fk
|
||||||
|
FOREIGN KEY (client_id) REFERENCES public.oidc_clients (id) ON DELETE CASCADE;
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
-- No-op
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
---------------------------
|
||||||
|
-- Delete all orphaned rows
|
||||||
|
---------------------------
|
||||||
|
UPDATE oidc_clients
|
||||||
|
SET created_by_id = NULL
|
||||||
|
WHERE created_by_id IS NOT NULL
|
||||||
|
AND created_by_id NOT IN (SELECT id FROM users);
|
||||||
|
|
||||||
|
DELETE FROM oidc_authorization_codes WHERE user_id NOT IN (SELECT id FROM users);
|
||||||
|
DELETE FROM one_time_access_tokens WHERE user_id NOT IN (SELECT id FROM users);
|
||||||
|
DELETE FROM webauthn_credentials WHERE user_id NOT IN (SELECT id FROM users);
|
||||||
|
DELETE FROM audit_logs WHERE user_id IS NOT NULL AND user_id NOT IN (SELECT id FROM users);
|
||||||
|
DELETE FROM api_keys WHERE user_id IS NOT NULL AND user_id NOT IN (SELECT id FROM users);
|
||||||
|
|
||||||
|
DELETE FROM oidc_refresh_tokens WHERE user_id NOT IN (SELECT id FROM users) OR client_id NOT IN (SELECT id FROM oidc_clients);
|
||||||
|
DELETE FROM oidc_device_codes WHERE (user_id IS NOT NULL AND user_id NOT IN (SELECT id FROM users)) OR client_id NOT IN (SELECT id FROM oidc_clients);
|
||||||
|
DELETE FROM user_authorized_oidc_clients WHERE user_id NOT IN (SELECT id FROM users) OR client_id NOT IN (SELECT id FROM oidc_clients);
|
||||||
|
|
||||||
|
DELETE FROM user_groups_users WHERE user_id NOT IN (SELECT id FROM users) OR user_group_id NOT IN (SELECT id FROM user_groups);
|
||||||
|
|
||||||
|
DELETE FROM custom_claims WHERE (user_id IS NOT NULL AND user_id NOT IN (SELECT id FROM users)) OR (user_group_id IS NOT NULL AND user_group_id NOT IN (SELECT id FROM user_groups));
|
||||||
|
|
||||||
|
DELETE FROM oidc_clients_allowed_user_groups WHERE oidc_client_id NOT IN (SELECT id FROM oidc_clients) OR user_group_id NOT IN (SELECT id FROM user_groups);
|
||||||
|
|
||||||
|
DELETE FROM reauthentication_tokens WHERE user_id NOT IN (SELECT id FROM users);
|
||||||
|
|
||||||
|
---------------------------
|
||||||
|
-- Add missing foreign keys and edit cascade behavior where necessary
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
-- reauthentication_tokens: add missing FK user_id → users
|
||||||
|
CREATE TABLE reauthentication_tokens_new
|
||||||
|
(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
token TEXT NOT NULL UNIQUE,
|
||||||
|
expires_at INTEGER NOT NULL,
|
||||||
|
user_id TEXT NOT NULL REFERENCES users ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
INSERT INTO reauthentication_tokens_new (id, created_at, token, expires_at, user_id)
|
||||||
|
SELECT id, created_at, token, expires_at, user_id
|
||||||
|
FROM reauthentication_tokens;
|
||||||
|
DROP TABLE reauthentication_tokens;
|
||||||
|
ALTER TABLE reauthentication_tokens_new RENAME TO reauthentication_tokens;
|
||||||
|
CREATE INDEX idx_reauthentication_tokens_token
|
||||||
|
ON reauthentication_tokens (token);
|
||||||
|
|
||||||
|
-- oidc_authorization_codes: add FK client_id, user_id → CASCADE
|
||||||
|
CREATE TABLE oidc_authorization_codes_new
|
||||||
|
(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
code TEXT NOT NULL UNIQUE,
|
||||||
|
scope TEXT NOT NULL,
|
||||||
|
nonce TEXT,
|
||||||
|
expires_at DATETIME NOT NULL,
|
||||||
|
user_id TEXT NOT NULL REFERENCES users ON DELETE CASCADE,
|
||||||
|
client_id TEXT NOT NULL REFERENCES oidc_clients ON DELETE CASCADE,
|
||||||
|
code_challenge TEXT,
|
||||||
|
code_challenge_method_sha256 NUMERIC
|
||||||
|
);
|
||||||
|
INSERT INTO oidc_authorization_codes_new
|
||||||
|
(id, created_at, code, scope, nonce, expires_at, user_id, client_id, code_challenge, code_challenge_method_sha256)
|
||||||
|
SELECT id, created_at, code, scope, nonce, expires_at, user_id, client_id, code_challenge, code_challenge_method_sha256
|
||||||
|
FROM oidc_authorization_codes;
|
||||||
|
DROP TABLE oidc_authorization_codes;
|
||||||
|
ALTER TABLE oidc_authorization_codes_new RENAME TO oidc_authorization_codes;
|
||||||
|
|
||||||
|
-- user_authorized_oidc_clients: add FK user_id, cascade client_id
|
||||||
|
CREATE TABLE user_authorized_oidc_clients_new
|
||||||
|
(
|
||||||
|
scope TEXT,
|
||||||
|
user_id TEXT NOT NULL REFERENCES users ON DELETE CASCADE,
|
||||||
|
client_id TEXT NOT NULL REFERENCES oidc_clients ON DELETE CASCADE,
|
||||||
|
last_used_at DATETIME NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, client_id)
|
||||||
|
);
|
||||||
|
INSERT INTO user_authorized_oidc_clients_new (scope, user_id, client_id, last_used_at)
|
||||||
|
SELECT scope, user_id, client_id, last_used_at
|
||||||
|
FROM user_authorized_oidc_clients;
|
||||||
|
DROP TABLE user_authorized_oidc_clients;
|
||||||
|
ALTER TABLE user_authorized_oidc_clients_new RENAME TO user_authorized_oidc_clients;
|
||||||
|
|
||||||
|
-- audit_logs: user_id → CASCADE
|
||||||
|
CREATE TABLE audit_logs_new
|
||||||
|
(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
event TEXT NOT NULL,
|
||||||
|
ip_address TEXT,
|
||||||
|
user_agent TEXT NOT NULL,
|
||||||
|
data BLOB NOT NULL,
|
||||||
|
user_id TEXT REFERENCES users ON DELETE CASCADE,
|
||||||
|
country TEXT,
|
||||||
|
city TEXT
|
||||||
|
);
|
||||||
|
INSERT INTO audit_logs_new
|
||||||
|
(id, created_at, event, ip_address, user_agent, data, user_id, country, city)
|
||||||
|
SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
|
||||||
|
FROM audit_logs;
|
||||||
|
DROP TABLE audit_logs;
|
||||||
|
ALTER TABLE audit_logs_new RENAME TO audit_logs;
|
||||||
|
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
|
||||||
|
CREATE INDEX idx_audit_logs_country ON audit_logs (country);
|
||||||
|
CREATE INDEX idx_audit_logs_created_at ON audit_logs (created_at);
|
||||||
|
CREATE INDEX idx_audit_logs_event ON audit_logs (event);
|
||||||
|
CREATE INDEX idx_audit_logs_user_agent ON audit_logs (user_agent);
|
||||||
|
CREATE INDEX idx_audit_logs_user_id ON audit_logs (user_id);
|
||||||
|
|
||||||
|
-- oidc_clients: created_by_id → SET NULL
|
||||||
|
CREATE TABLE oidc_clients_new
|
||||||
|
(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
name TEXT,
|
||||||
|
secret TEXT,
|
||||||
|
callback_urls BLOB,
|
||||||
|
image_type TEXT,
|
||||||
|
created_by_id TEXT REFERENCES users ON DELETE SET NULL,
|
||||||
|
is_public BOOLEAN DEFAULT FALSE,
|
||||||
|
pkce_enabled BOOLEAN DEFAULT FALSE,
|
||||||
|
logout_callback_urls BLOB,
|
||||||
|
credentials TEXT,
|
||||||
|
launch_url TEXT,
|
||||||
|
requires_reauthentication BOOLEAN DEFAULT FALSE NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO oidc_clients_new
|
||||||
|
(id, created_at, name, secret, callback_urls, image_type, created_by_id,
|
||||||
|
is_public, pkce_enabled, logout_callback_urls, credentials, launch_url, requires_reauthentication)
|
||||||
|
SELECT id, created_at, name, secret, callback_urls, image_type, created_by_id,
|
||||||
|
is_public, pkce_enabled, logout_callback_urls, credentials, launch_url, requires_reauthentication
|
||||||
|
FROM oidc_clients;
|
||||||
|
DROP TABLE oidc_clients;
|
||||||
|
ALTER TABLE oidc_clients_new RENAME TO oidc_clients;
|
||||||
|
|
||||||
|
-- one_time_access_tokens: user_id → CASCADE
|
||||||
|
CREATE TABLE one_time_access_tokens_new
|
||||||
|
(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
token TEXT NOT NULL UNIQUE,
|
||||||
|
expires_at DATETIME NOT NULL,
|
||||||
|
user_id TEXT NOT NULL REFERENCES users ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
INSERT INTO one_time_access_tokens_new
|
||||||
|
(id, created_at, token, expires_at, user_id)
|
||||||
|
SELECT id, created_at, token, expires_at, user_id
|
||||||
|
FROM one_time_access_tokens;
|
||||||
|
DROP TABLE one_time_access_tokens;
|
||||||
|
ALTER TABLE one_time_access_tokens_new RENAME TO one_time_access_tokens;
|
||||||
|
|
||||||
|
-- webauthn_credentials: user_id → CASCADE
|
||||||
|
CREATE TABLE webauthn_credentials_new
|
||||||
|
(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
credential_id TEXT NOT NULL UNIQUE,
|
||||||
|
public_key BLOB NOT NULL,
|
||||||
|
attestation_type TEXT NOT NULL,
|
||||||
|
transport BLOB NOT NULL,
|
||||||
|
user_id TEXT REFERENCES users ON DELETE CASCADE,
|
||||||
|
backup_eligible BOOLEAN DEFAULT FALSE NOT NULL,
|
||||||
|
backup_state BOOLEAN DEFAULT FALSE NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO webauthn_credentials_new
|
||||||
|
(id, created_at, name, credential_id, public_key, attestation_type,
|
||||||
|
transport, user_id, backup_eligible, backup_state)
|
||||||
|
SELECT id, created_at, name, credential_id, public_key, attestation_type,
|
||||||
|
transport, user_id, backup_eligible, backup_state
|
||||||
|
FROM webauthn_credentials;
|
||||||
|
DROP TABLE webauthn_credentials;
|
||||||
|
ALTER TABLE webauthn_credentials_new RENAME TO webauthn_credentials;
|
||||||
Reference in New Issue
Block a user