diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml
index acd002fb..ab558411 100644
--- a/.github/workflows/e2e-tests.yml
+++ b/.github/workflows/e2e-tests.yml
@@ -57,7 +57,17 @@ jobs:
strategy:
fail-fast: false
matrix:
- db: [sqlite, postgres, sqlite-s3]
+ include:
+ - db: sqlite
+ storage: fs
+ - db: postgres
+ storage: fs
+ - db: sqlite
+ storage: s3
+ - db: sqlite
+ storage: database
+ - db: postgres
+ storage: database
steps:
- uses: actions/checkout@v5
@@ -71,65 +81,74 @@ jobs:
node-version: 22
- name: Cache Playwright Browsers
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }}
- name: Cache PostgreSQL Docker image
       if: matrix.db == 'postgres'
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: postgres-cache
with:
path: /tmp/postgres-image.tar
key: postgres-17-${{ runner.os }}
-
- name: Pull and save PostgreSQL image
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit != 'true'
run: |
docker pull postgres:17
docker save postgres:17 > /tmp/postgres-image.tar
-
- name: Load PostgreSQL image
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit == 'true'
run: docker load < /tmp/postgres-image.tar
- name: Cache LLDAP Docker image
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: lldap-cache
with:
path: /tmp/lldap-image.tar
key: lldap-stable-${{ runner.os }}
-
- name: Pull and save LLDAP image
if: steps.lldap-cache.outputs.cache-hit != 'true'
run: |
- docker pull nitnelave/lldap:stable
- docker save nitnelave/lldap:stable > /tmp/lldap-image.tar
-
+ docker pull lldap/lldap:2025-05-19
+ docker save lldap/lldap:2025-05-19 > /tmp/lldap-image.tar
- name: Load LLDAP image
if: steps.lldap-cache.outputs.cache-hit == 'true'
run: docker load < /tmp/lldap-image.tar
- name: Cache Localstack S3 Docker image
- if: matrix.db == 'sqlite-s3'
- uses: actions/cache@v3
+ if: matrix.storage == 's3'
+ uses: actions/cache@v4
id: s3-cache
with:
path: /tmp/localstack-s3-image.tar
key: localstack-s3-latest-${{ runner.os }}
-
- name: Pull and save Localstack S3 image
- if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit != 'true'
+ if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit != 'true'
run: |
docker pull localstack/localstack:s3-latest
docker save localstack/localstack:s3-latest > /tmp/localstack-s3-image.tar
-
- name: Load Localstack S3 image
- if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit == 'true'
+ if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit == 'true'
run: docker load < /tmp/localstack-s3-image.tar
+ - name: Cache AWS CLI Docker image
+ if: matrix.storage == 's3'
+ uses: actions/cache@v4
+ id: aws-cli-cache
+ with:
+ path: /tmp/aws-cli-image.tar
+ key: aws-cli-latest-${{ runner.os }}
+ - name: Pull and save AWS CLI image
+ if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit != 'true'
+ run: |
+ docker pull amazon/aws-cli:latest
+ docker save amazon/aws-cli:latest > /tmp/aws-cli-image.tar
+ - name: Load AWS CLI image
+ if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit == 'true'
+ run: docker load < /tmp/aws-cli-image.tar
+
- name: Download Docker image artifact
uses: actions/download-artifact@v4
with:
@@ -147,26 +166,20 @@ jobs:
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: pnpm exec playwright install --with-deps chromium
- - name: Run Docker Container (sqlite) with LDAP
- if: matrix.db == 'sqlite'
+ - name: Run Docker containers
working-directory: ./tests/setup
run: |
- docker compose up -d
- docker compose logs -f pocket-id &> /tmp/backend.log &
+ DOCKER_COMPOSE_FILE=docker-compose.yml
- - name: Run Docker Container (postgres) with LDAP
- if: matrix.db == 'postgres'
- working-directory: ./tests/setup
- run: |
- docker compose -f docker-compose-postgres.yml up -d
- docker compose -f docker-compose-postgres.yml logs -f pocket-id &> /tmp/backend.log &
+ export FILE_BACKEND="${{ matrix.storage }}"
+ if [ "${{ matrix.db }}" = "postgres" ]; then
+ DOCKER_COMPOSE_FILE=docker-compose-postgres.yml
+ elif [ "${{ matrix.storage }}" = "s3" ]; then
+ DOCKER_COMPOSE_FILE=docker-compose-s3.yml
+ fi
- - name: Run Docker Container (sqlite-s3) with LDAP + S3
- if: matrix.db == 'sqlite-s3'
- working-directory: ./tests/setup
- run: |
- docker compose -f docker-compose-s3.yml up -d
- docker compose -f docker-compose-s3.yml logs -f pocket-id &> /tmp/backend.log &
+ docker compose -f "$DOCKER_COMPOSE_FILE" up -d
+ docker compose -f "$DOCKER_COMPOSE_FILE" logs -f pocket-id &> /tmp/backend.log &
- name: Run Playwright tests
working-directory: ./tests
@@ -176,7 +189,7 @@ jobs:
uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with:
- name: playwright-report-${{ matrix.db }}
+ name: playwright-report-${{ matrix.db }}-${{ matrix.storage }}
path: tests/.report
include-hidden-files: true
retention-days: 15
@@ -185,7 +198,7 @@ jobs:
uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with:
- name: backend-${{ matrix.db }}
+ name: backend-${{ matrix.db }}-${{ matrix.storage }}
path: /tmp/backend.log
include-hidden-files: true
retention-days: 15
diff --git a/backend/internal/bootstrap/bootstrap.go b/backend/internal/bootstrap/bootstrap.go
index a4e22700..ba1ea993 100644
--- a/backend/internal/bootstrap/bootstrap.go
+++ b/backend/internal/bootstrap/bootstrap.go
@@ -22,12 +22,20 @@ func Bootstrap(ctx context.Context) error {
}
slog.InfoContext(ctx, "Pocket ID is starting")
+ // Connect to the database
+ db, err := NewDatabase()
+ if err != nil {
+ return fmt.Errorf("failed to initialize database: %w", err)
+ }
+
// Initialize the file storage backend
var fileStorage storage.FileStorage
switch common.EnvConfig.FileBackend {
case storage.TypeFileSystem:
fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
+ case storage.TypeDatabase:
+ fileStorage, err = storage.NewDatabaseStorage(db)
case storage.TypeS3:
s3Cfg := storage.S3Config{
Bucket: common.EnvConfig.S3Bucket,
@@ -43,7 +51,7 @@ func Bootstrap(ctx context.Context) error {
err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
}
if err != nil {
- return fmt.Errorf("failed to initialize file storage: %w", err)
+ return fmt.Errorf("failed to initialize file storage (backend: %s): %w", common.EnvConfig.FileBackend, err)
}
imageExtensions, err := initApplicationImages(ctx, fileStorage)
@@ -51,12 +59,6 @@ func Bootstrap(ctx context.Context) error {
return fmt.Errorf("failed to initialize application images: %w", err)
}
- // Connect to the database
- db, err := NewDatabase()
- if err != nil {
- return fmt.Errorf("failed to initialize database: %w", err)
- }
-
// Create all services
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
if err != nil {
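
For readers following the storage changes: the call sites in this diff (Save, Open, Delete, DeleteAll) imply a `storage.FileStorage` contract roughly like the sketch below. The actual interface lives elsewhere in the repository and is not part of this diff, so the signatures here are inferred rather than authoritative.

```go
package storage

import (
	"context"
	"io"
)

// FileStorage is an inferred sketch of the backend contract used in this PR.
// The diff shows Save, Open, Delete, and DeleteAll being called on it; the
// real interface may differ in details (e.g. Save may prefer an io.ReadSeeker,
// as hinted by the comment in user_service.go).
type FileStorage interface {
	// Save writes the content of r at the given path, creating or replacing it.
	Save(ctx context.Context, path string, r io.Reader) error
	// Open returns a reader for the file plus its size in bytes.
	Open(ctx context.Context, path string) (io.ReadCloser, int64, error)
	// Delete removes a single file.
	Delete(ctx context.Context, path string) error
	// DeleteAll removes everything under the given prefix.
	DeleteAll(ctx context.Context, prefix string) error
}
```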
diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go
index 2e362475..3c1072be 100644
--- a/backend/internal/bootstrap/router_bootstrap.go
+++ b/backend/internal/bootstrap/router_bootstrap.go
@@ -41,11 +41,11 @@ func initRouter(db *gorm.DB, svc *services) utils.Service {
func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
// Set the appropriate Gin mode based on the environment
switch common.EnvConfig.AppEnv {
- case "production":
+ case common.AppEnvProduction:
gin.SetMode(gin.ReleaseMode)
- case "development":
+ case common.AppEnvDevelopment:
gin.SetMode(gin.DebugMode)
- case "test":
+ case common.AppEnvTest:
gin.SetMode(gin.TestMode)
}
@@ -92,7 +92,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
controller.NewVersionController(apiGroup, svc.versionService)
// Add test controller in non-production environments
- if common.EnvConfig.AppEnv != "production" {
+ if !common.EnvConfig.AppEnv.IsProduction() {
for _, f := range registerTestControllers {
f(apiGroup, db, svc)
}
diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go
index 31a9967e..21254627 100644
--- a/backend/internal/bootstrap/services_bootstrap.go
+++ b/backend/internal/bootstrap/services_bootstrap.go
@@ -66,7 +66,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
- svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
+ svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage)
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
svc.versionService = service.NewVersionService(httpClient)
diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go
index 3d54972e..7c17c37d 100644
--- a/backend/internal/common/env_config.go
+++ b/backend/internal/common/env_config.go
@@ -15,6 +15,7 @@ import (
_ "github.com/joho/godotenv/autoload"
)
+type AppEnv string
type DbProvider string
const (
@@ -25,6 +26,9 @@ const (
)
const (
+ AppEnvProduction AppEnv = "production"
+ AppEnvDevelopment AppEnv = "development"
+ AppEnvTest AppEnv = "test"
DbProviderSqlite DbProvider = "sqlite"
DbProviderPostgres DbProvider = "postgres"
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
@@ -34,7 +38,7 @@ const (
)
type EnvConfigSchema struct {
- AppEnv string `env:"APP_ENV" options:"toLower"`
+ AppEnv AppEnv `env:"APP_ENV" options:"toLower"`
LogLevel string `env:"LOG_LEVEL" options:"toLower"`
AppURL string `env:"APP_URL" options:"toLower,trimTrailingSlash"`
DbProvider DbProvider `env:"DB_PROVIDER" options:"toLower"`
@@ -78,7 +82,7 @@ func init() {
func defaultConfig() EnvConfigSchema {
return EnvConfigSchema{
- AppEnv: "production",
+ AppEnv: AppEnvProduction,
LogLevel: "info",
DbProvider: "sqlite",
FileBackend: "fs",
@@ -158,13 +162,13 @@ func ValidateEnvConfig(config *EnvConfigSchema) error {
}
switch config.FileBackend {
- case "s3":
+ case "s3", "database":
case "", "fs":
if config.UploadPath == "" {
config.UploadPath = defaultFsUploadPath
}
default:
- return errors.New("invalid FILE_BACKEND value. Must be 'fs' or 's3'")
+ return errors.New("invalid FILE_BACKEND value. Must be 'fs', 'database', or 's3'")
}
// Validate LOCAL_IPV6_RANGES
@@ -265,3 +269,11 @@ func resolveFileBasedEnvVariable(field reflect.Value, fieldType reflect.StructFi
return nil
}
+
+func (a AppEnv) IsProduction() bool {
+ return a == AppEnvProduction
+}
+
+func (a AppEnv) IsTest() bool {
+ return a == AppEnvTest
+}
diff --git a/backend/internal/common/env_config_test.go b/backend/internal/common/env_config_test.go
index c3e41f4e..d5e42a17 100644
--- a/backend/internal/common/env_config_test.go
+++ b/backend/internal/common/env_config_test.go
@@ -164,7 +164,7 @@ func TestParseEnvConfig(t *testing.T) {
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
t.Setenv("APP_URL", "https://prod.example.com")
- t.Setenv("APP_ENV", "STAGING")
+ t.Setenv("APP_ENV", "PRODUCTION")
t.Setenv("UPLOAD_PATH", "/custom/uploads")
t.Setenv("PORT", "8080")
t.Setenv("HOST", "LOCALHOST")
@@ -174,7 +174,7 @@ func TestParseEnvConfig(t *testing.T) {
err := parseAndValidateEnvConfig(t)
require.NoError(t, err)
- assert.Equal(t, "staging", EnvConfig.AppEnv) // lowercased
+ assert.Equal(t, AppEnvProduction, EnvConfig.AppEnv) // lowercased
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
assert.Equal(t, "8080", EnvConfig.Port)
assert.Equal(t, "localhost", EnvConfig.Host) // lowercased
@@ -238,7 +238,7 @@ func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) {
err := prepareEnvConfig(&config)
require.NoError(t, err)
- assert.Equal(t, "staging", config.AppEnv)
+ assert.Equal(t, AppEnv("staging"), config.AppEnv)
assert.Equal(t, "localhost", config.Host)
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
assert.Equal(t, dbConnContent, config.DbConnectionString)
diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go
index 607be22e..86d26e14 100644
--- a/backend/internal/controller/oidc_controller.go
+++ b/backend/internal/controller/oidc_controller.go
@@ -587,7 +587,6 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
}
c.Status(http.StatusNoContent)
-
}
// deleteClientLogoHandler godoc
@@ -614,7 +613,6 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
}
c.Status(http.StatusNoContent)
-
}
// updateAllowedUserGroupsHandler godoc
diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go
index 51ffb587..7dee5602 100644
--- a/backend/internal/controller/webauthn_controller.go
+++ b/backend/internal/controller/webauthn_controller.go
@@ -57,7 +57,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
}
userID := c.GetString("userID")
- credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request)
+ credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request, c.ClientIP())
if err != nil {
_ = c.Error(err)
return
@@ -134,8 +134,10 @@ func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
userID := c.GetString("userID")
credentialID := c.Param("id")
+ clientIP := c.ClientIP()
+ userAgent := c.Request.UserAgent()
- err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID)
+ err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID, clientIP, userAgent)
if err != nil {
_ = c.Error(err)
return
diff --git a/backend/internal/job/analytics_job.go b/backend/internal/job/analytics_job.go
index 468d45f0..6cf2b944 100644
--- a/backend/internal/job/analytics_job.go
+++ b/backend/internal/job/analytics_job.go
@@ -19,7 +19,7 @@ const heartbeatUrl = "https://analytics.pocket-id.org/heartbeat"
func (s *Scheduler) RegisterAnalyticsJob(ctx context.Context, appConfig *service.AppConfigService, httpClient *http.Client) error {
// Skip if analytics are disabled or not in production environment
- if common.EnvConfig.AnalyticsDisabled || common.EnvConfig.AppEnv != "production" {
+ if common.EnvConfig.AnalyticsDisabled || !common.EnvConfig.AppEnv.IsProduction() {
return nil
}
@@ -39,7 +39,7 @@ type AnalyticsJob struct {
// sendHeartbeat sends a heartbeat to the analytics service
func (j *AnalyticsJob) sendHeartbeat(parentCtx context.Context) error {
// Skip if analytics are disabled or not in production environment
- if common.EnvConfig.AnalyticsDisabled || common.EnvConfig.AppEnv != "production" {
+ if common.EnvConfig.AnalyticsDisabled || !common.EnvConfig.AppEnv.IsProduction() {
return nil
}
diff --git a/backend/internal/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go
index 910b5865..210ea273 100644
--- a/backend/internal/middleware/rate_limit.go
+++ b/backend/internal/middleware/rate_limit.go
@@ -29,7 +29,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
// Skip rate limiting for localhost and test environment
// If the client ip is localhost the request comes from the frontend
- if ip == "" || ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv == "test" {
+ if ip == "" || ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv.IsTest() {
c.Next()
return
}
diff --git a/backend/internal/model/audit_log.go b/backend/internal/model/audit_log.go
index 46b6a76a..c7b0505b 100644
--- a/backend/internal/model/audit_log.go
+++ b/backend/internal/model/audit_log.go
@@ -34,6 +34,8 @@ const (
AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION"
AuditLogEventDeviceCodeAuthorization AuditLogEvent = "DEVICE_CODE_AUTHORIZATION"
AuditLogEventNewDeviceCodeAuthorization AuditLogEvent = "NEW_DEVICE_CODE_AUTHORIZATION"
+ AuditLogEventPasskeyAdded AuditLogEvent = "PASSKEY_ADDED"
+ AuditLogEventPasskeyRemoved AuditLogEvent = "PASSKEY_REMOVED"
)
// Scan and Value methods for GORM to handle the custom type
diff --git a/backend/internal/model/storage.go b/backend/internal/model/storage.go
new file mode 100644
index 00000000..f5668b87
--- /dev/null
+++ b/backend/internal/model/storage.go
@@ -0,0 +1,17 @@
+package model
+
+import (
+ datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
+)
+
+type Storage struct {
+ Path string `gorm:"primaryKey"`
+ Data []byte
+ Size int64
+ ModTime datatype.DateTime
+ CreatedAt datatype.DateTime
+}
+
+func (Storage) TableName() string {
+ return "storage"
+}
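
The new `storage` table above backs the database file backend. As a rough orientation, a database-backed FileStorage built on this model could look like the sketch below; it assumes GORM upserts keyed on `Path` and maps a missing row to `fs.ErrNotExist` (which the callers in `user_service.go` check for). The actual `storage.NewDatabaseStorage` in this PR may differ.

```go
package storage

import (
	"bytes"
	"context"
	"errors"
	"io"
	"io/fs"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// storageRow mirrors the Storage model (table "storage"); defined locally so
// this sketch stays self-contained.
type storageRow struct {
	Path    string `gorm:"primaryKey"`
	Data    []byte
	Size    int64
	ModTime time.Time
}

func (storageRow) TableName() string { return "storage" }

// DatabaseStorage is a hypothetical FileStorage backed by the storage table.
type DatabaseStorage struct {
	db *gorm.DB
}

func NewDatabaseStorage(db *gorm.DB) (*DatabaseStorage, error) {
	return &DatabaseStorage{db: db}, nil
}

// Save buffers the reader and upserts the row keyed on the path.
func (s *DatabaseStorage) Save(ctx context.Context, p string, r io.Reader) error {
	data, err := io.ReadAll(r)
	if err != nil {
		return err
	}
	row := storageRow{Path: p, Data: data, Size: int64(len(data)), ModTime: time.Now()}
	return s.db.WithContext(ctx).
		Clauses(clause.OnConflict{UpdateAll: true}).
		Create(&row).Error
}

// Open returns the stored bytes as a ReadCloser plus the size, translating a
// missing row into fs.ErrNotExist so callers can branch on "not found".
func (s *DatabaseStorage) Open(ctx context.Context, p string) (io.ReadCloser, int64, error) {
	var row storageRow
	err := s.db.WithContext(ctx).Take(&row, "path = ?", p).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, 0, fs.ErrNotExist
	} else if err != nil {
		return nil, 0, err
	}
	return io.NopCloser(bytes.NewReader(row.Data)), row.Size, nil
}

// Delete removes a single row; deleting a missing path is not an error.
func (s *DatabaseStorage) Delete(ctx context.Context, p string) error {
	return s.db.WithContext(ctx).Delete(&storageRow{}, "path = ?", p).Error
}
```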
diff --git a/backend/internal/service/audit_log_service.go b/backend/internal/service/audit_log_service.go
index c19e3560..abb2a23b 100644
--- a/backend/internal/service/audit_log_service.go
+++ b/backend/internal/service/audit_log_service.go
@@ -34,7 +34,7 @@ func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent,
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil {
// Log the error but don't interrupt the operation
- slog.Warn("Failed to get IP location", "error", err)
+ slog.Warn("Failed to get IP location", slog.String("ip", ipAddress), slog.Any("error", err))
}
auditLog := model.AuditLog{
@@ -201,8 +201,8 @@ func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[s
WithContext(ctx).
Joins("User").
Model(&model.AuditLog{}).
- Select("DISTINCT \"User\".id, \"User\".username").
- Where("\"User\".username IS NOT NULL")
+ Select(`DISTINCT "User".id, "User".username`).
+ Where(`"User".username IS NOT NULL`)
type Result struct {
ID string `gorm:"column:id"`
@@ -210,7 +210,8 @@ func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[s
}
var results []Result
- if err := query.Find(&results).Error; err != nil {
+ err = query.Find(&results).Error
+ if err != nil {
return nil, fmt.Errorf("failed to query user IDs: %w", err)
}
@@ -246,7 +247,8 @@ func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []st
}
var results []Result
- if err := query.Find(&results).Error; err != nil {
+ err = query.Find(&results).Error
+ if err != nil {
return nil, fmt.Errorf("failed to query client IDs: %w", err)
}
diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go
index e6765a33..5b9549e1 100644
--- a/backend/internal/service/e2etest_service.go
+++ b/backend/internal/service/e2etest_service.go
@@ -426,7 +426,8 @@ func (s *TestService) ResetDatabase() error {
}
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
- if err := s.fileStorage.DeleteAll(ctx, "/"); err != nil {
+ err := s.fileStorage.DeleteAll(ctx, "/")
+ if err != nil {
slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err))
return err
}
@@ -445,7 +446,8 @@ func (s *TestService) ResetApplicationImages(ctx context.Context) error {
if err != nil {
return err
}
- if err := s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile); err != nil {
+ err = s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile)
+ if err != nil {
srcFile.Close()
return err
}
diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go
index cef4e050..93778e96 100644
--- a/backend/internal/service/ldap_service.go
+++ b/backend/internal/service/ldap_service.go
@@ -11,12 +11,14 @@ import (
"log/slog"
"net/http"
"net/url"
+ "path"
"strings"
"time"
"unicode/utf8"
"github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
+ "github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"golang.org/x/text/unicode/norm"
"gorm.io/gorm"
@@ -32,15 +34,23 @@ type LdapService struct {
appConfigService *AppConfigService
userService *UserService
groupService *UserGroupService
+ fileStorage storage.FileStorage
}
-func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
+type savePicture struct {
+ userID string
+ username string
+ picture string
+}
+
+func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
return &LdapService{
db: db,
httpClient: httpClient,
appConfigService: appConfigService,
userService: userService,
groupService: groupService,
+ fileStorage: fileStorage,
}
}
@@ -68,12 +78,6 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
}
func (s *LdapService) SyncAll(ctx context.Context) error {
- // Start a transaction
- tx := s.db.Begin()
- defer func() {
- tx.Rollback()
- }()
-
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -81,7 +85,13 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
}
defer client.Close()
- err = s.SyncUsers(ctx, tx, client)
+ // Start a transaction
+ tx := s.db.Begin()
+ defer func() {
+ tx.Rollback()
+ }()
+
+ savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
if err != nil {
return fmt.Errorf("failed to sync users: %w", err)
}
@@ -97,6 +107,25 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
return fmt.Errorf("failed to commit changes to database: %w", err)
}
+ // Now that we've committed the transaction, we can perform operations on the storage layer
+ // First, save all new pictures
+ for _, sp := range savePictures {
+ err = s.saveProfilePicture(ctx, sp.userID, sp.picture)
+ if err != nil {
+ // This is not a fatal error
+ slog.Warn("Error saving profile picture for LDAP user", slog.String("username", sp.username), slog.Any("error", err))
+ }
+ }
+
+ // Delete all old files
+ for _, path := range deleteFiles {
+ err = s.fileStorage.Delete(ctx, path)
+ if err != nil {
+ // This is not a fatal error
+ slog.Error("Failed to delete file after LDAP sync", slog.String("path", path), slog.Any("error", err))
+ }
+ }
+
return nil
}
@@ -266,7 +295,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
}
//nolint:gocognit
-func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
+func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
dbConfig := s.appConfigService.GetDbConfig()
searchAttrs := []string{
@@ -294,11 +323,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
result, err := client.Search(searchReq)
if err != nil {
- return fmt.Errorf("failed to query LDAP: %w", err)
+ return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
}
// Create a mapping for users that exist
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
+ savePictures = make([]savePicture, 0, len(result.Entries))
for _, value := range result.Entries {
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
@@ -329,13 +359,13 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
Error
if err != nil {
- return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
+ return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
}
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
- return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
+ return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
}
// Check if user is admin by checking if they are in the admin group
@@ -369,32 +399,35 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
continue
}
+ userID := databaseUser.ID
if databaseUser.ID == "" {
- _, err = s.userService.createUserInternal(ctx, newUser, true, tx)
+ createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
continue
} else if err != nil {
- return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
+ return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
}
+ userID = createdUser.ID
} else {
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
continue
} else if err != nil {
- return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
+ return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
}
}
// Save profile picture
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
if pictureString != "" {
- err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
- if err != nil {
- // This is not a fatal error
- slog.Warn("Error saving profile picture for user", slog.String("username", newUser.Username), slog.Any("error", err))
- }
+ // Storage operations must be executed outside of a transaction
+ savePictures = append(savePictures, savePicture{
+				userID:   userID,
+				username: newUser.Username,
+ picture: pictureString,
+ })
}
}
@@ -406,10 +439,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
Select("id, username, ldap_id, disabled").
Error
if err != nil {
- return fmt.Errorf("failed to fetch users from database: %w", err)
+ return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
}
// Mark users as disabled or delete users that no longer exist in LDAP
+ deleteFiles = make([]string, 0, len(ldapUserIDs))
for _, user := range ldapUsersInDb {
// Skip if the user ID exists in the fetched LDAP results
if _, exists := ldapUserIDs[*user.LdapID]; exists {
@@ -417,26 +451,30 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
}
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
- err = s.userService.disableUserInternal(ctx, user.ID, tx)
+ err = s.userService.disableUserInternal(ctx, tx, user.ID)
if err != nil {
- return fmt.Errorf("failed to disable user %s: %w", user.Username, err)
+ return nil, nil, fmt.Errorf("failed to disable user %s: %w", user.Username, err)
}
slog.Info("Disabled user", slog.String("username", user.Username))
} else {
- err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
- target := &common.LdapUserUpdateError{}
- if errors.As(err, &target) {
- return fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
- } else if err != nil {
- return fmt.Errorf("failed to delete user %s: %w", user.Username, err)
+ err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
+ if err != nil {
+ target := &common.LdapUserUpdateError{}
+ if errors.As(err, &target) {
+ return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
+ }
+ return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
}
slog.Info("Deleted user", slog.String("username", user.Username))
+
+ // Storage operations must be executed outside of a transaction
+ deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
}
}
- return nil
+ return savePictures, deleteFiles, nil
}
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
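
The SyncAll/SyncUsers refactor above queues storage side effects (`savePictures`, `deleteFiles`) while the transaction is open and applies them only after Commit. The sketch below shows the same collect-then-apply pattern in isolation; the function name and example path are illustrative, not from the codebase.

```go
package ldapsync

import (
	"context"
	"log/slog"

	"gorm.io/gorm"

	"github.com/pocket-id/pocket-id/backend/internal/storage"
)

// syncWithDeferredStorage is an illustrative (not real) helper showing the
// collect-then-apply pattern SyncAll now follows: decide on storage side
// effects while the transaction is open, but execute them only after the
// commit succeeds so a rollback can never leave orphaned or missing files.
func syncWithDeferredStorage(ctx context.Context, db *gorm.DB, files storage.FileStorage) error {
	var pendingDeletes []string

	tx := db.Begin()
	defer func() {
		tx.Rollback() // harmless after a successful Commit
	}()

	// ... database work goes here; it only records which files became obsolete ...
	pendingDeletes = append(pendingDeletes, "profile-pictures/example-user.png")

	if err := tx.Commit().Error; err != nil {
		return err
	}

	// Storage operations run strictly after the commit and are non-fatal.
	for _, p := range pendingDeletes {
		if err := files.Delete(ctx, p); err != nil {
			slog.Error("Failed to delete file after sync", slog.String("path", p), slog.Any("error", err))
		}
	}
	return nil
}
```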
diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go
index 08b9ecae..2d42f452 100644
--- a/backend/internal/service/oidc_service.go
+++ b/backend/internal/service/oidc_service.go
@@ -12,7 +12,6 @@ import (
"io"
"log/slog"
"mime/multipart"
- "net"
"net/http"
"net/url"
"path"
@@ -679,19 +678,21 @@ func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID strin
}
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
- return s.getClientInternal(ctx, clientID, s.db)
+ return s.getClientInternal(ctx, clientID, s.db, false)
}
-func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
+func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
var client model.OidcClient
- err := tx.
+ q := tx.
WithContext(ctx).
Preload("CreatedBy").
- Preload("AllowedUserGroups").
- First(&client, "id = ?", clientID).
- Error
- if err != nil {
- return model.OidcClient{}, err
+ Preload("AllowedUserGroups")
+ if forUpdate {
+ q = q.Clauses(clause.Locking{Strength: "UPDATE"})
+ }
+ q = q.First(&client, "id = ?", clientID)
+ if q.Error != nil {
+ return model.OidcClient{}, q.Error
}
return client, nil
}
@@ -724,11 +725,6 @@ func (s *OidcService) ListClients(ctx context.Context, name string, listRequestO
}
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
- tx := s.db.Begin()
- defer func() {
- tx.Rollback()
- }()
-
client := model.OidcClient{
Base: model.Base{
ID: input.ID,
@@ -737,7 +733,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
}
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
- err := tx.
+ err := s.db.
WithContext(ctx).
Create(&client).
Error
@@ -748,62 +744,65 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
return model.OidcClient{}, err
}
+ // All storage operations must be executed outside of a transaction
if input.LogoURL != nil {
- err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
+ err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
}
}
if input.DarkLogoURL != nil {
- err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
+ err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
}
}
- err = tx.Commit().Error
- if err != nil {
- return model.OidcClient{}, err
- }
-
return client, nil
}
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
tx := s.db.Begin()
- defer func() { tx.Rollback() }()
+ defer func() {
+ tx.Rollback()
+ }()
var client model.OidcClient
- if err := tx.WithContext(ctx).
+ err := tx.WithContext(ctx).
Preload("CreatedBy").
- First(&client, "id = ?", clientID).Error; err != nil {
+ First(&client, "id = ?", clientID).Error
+ if err != nil {
return model.OidcClient{}, err
}
updateOIDCClientModelFromDto(&client, &input)
- if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
+ err = tx.WithContext(ctx).Save(&client).Error
+ if err != nil {
return model.OidcClient{}, err
}
+ err = tx.Commit().Error
+ if err != nil {
+ return model.OidcClient{}, err
+ }
+
+ // All storage operations must be executed outside of a transaction
if input.LogoURL != nil {
- err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
+ err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
}
}
if input.DarkLogoURL != nil {
- err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
+ err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
if err != nil {
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
}
}
- if err := tx.Commit().Error; err != nil {
- return model.OidcClient{}, err
- }
return client, nil
}
@@ -836,12 +835,24 @@ func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
err := s.db.
WithContext(ctx).
Where("id = ?", clientID).
+ Clauses(clause.Returning{}).
Delete(&client).
Error
if err != nil {
return err
}
+ // Delete images if present
+ // Note that storage operations must be done outside of a transaction
+ if client.ImageType != nil && *client.ImageType != "" {
+ old := path.Join("oidc-client-images", client.ID+"."+*client.ImageType)
+ _ = s.fileStorage.Delete(ctx, old)
+ }
+ if client.DarkImageType != nil && *client.DarkImageType != "" {
+ old := path.Join("oidc-client-images", client.ID+"-dark."+*client.DarkImageType)
+ _ = s.fileStorage.Delete(ctx, old)
+ }
+
return nil
}
@@ -941,57 +952,12 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
return err
}
defer reader.Close()
- if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil {
- return err
- }
-
- tx := s.db.Begin()
-
- err = s.updateClientLogoType(ctx, tx, clientID, fileType, light)
- if err != nil {
- tx.Rollback()
- return err
- }
-
- return tx.Commit().Error
-}
-
-func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
- tx := s.db.Begin()
- defer func() {
- tx.Rollback()
- }()
-
- var client model.OidcClient
- err := tx.
- WithContext(ctx).
- First(&client, "id = ?", clientID).
- Error
+ err = s.fileStorage.Save(ctx, imagePath, reader)
if err != nil {
return err
}
- if client.ImageType == nil {
- return errors.New("image not found")
- }
-
- oldImageType := *client.ImageType
- client.ImageType = nil
-
- err = tx.
- WithContext(ctx).
- Save(&client).
- Error
- if err != nil {
- return err
- }
-
- imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType)
- if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
- return err
- }
-
- err = tx.Commit().Error
+ err = s.updateClientLogoType(ctx, clientID, fileType, light)
if err != nil {
return err
}
@@ -999,7 +965,31 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
return nil
}
+func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
+ return s.deleteClientLogoInternal(ctx, clientID, "", func(client *model.OidcClient) (string, error) {
+ if client.ImageType == nil {
+ return "", errors.New("image not found")
+ }
+
+ oldImageType := *client.ImageType
+ client.ImageType = nil
+ return oldImageType, nil
+ })
+}
+
func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) error {
+ return s.deleteClientLogoInternal(ctx, clientID, "-dark", func(client *model.OidcClient) (string, error) {
+ if client.DarkImageType == nil {
+ return "", errors.New("image not found")
+ }
+
+ oldImageType := *client.DarkImageType
+ client.DarkImageType = nil
+ return oldImageType, nil
+ })
+}
+
+func (s *OidcService) deleteClientLogoInternal(ctx context.Context, clientID string, imagePathSuffix string, setClientImage func(*model.OidcClient) (string, error)) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -1014,13 +1004,11 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
return err
}
- if client.DarkImageType == nil {
- return errors.New("image not found")
+ oldImageType, err := setClientImage(&client)
+ if err != nil {
+ return err
}
- oldImageType := *client.DarkImageType
- client.DarkImageType = nil
-
err = tx.
WithContext(ctx).
Save(&client).
@@ -1029,12 +1017,14 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
return err
}
- imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType)
- if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
+ err = tx.Commit().Error
+ if err != nil {
return err
}
- err = tx.Commit().Error
+ // All storage operations must be performed outside of a database transaction
+ imagePath := path.Join("oidc-client-images", client.ID+imagePathSuffix+"."+oldImageType)
+ err = s.fileStorage.Delete(ctx, imagePath)
if err != nil {
return err
}
@@ -1048,7 +1038,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in
tx.Rollback()
}()
- client, err = s.getClientInternal(ctx, id, tx)
+ client, err = s.getClientInternal(ctx, id, tx, true)
if err != nil {
return model.OidcClient{}, err
}
@@ -1831,7 +1821,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
tx.Rollback()
}()
- client, err := s.getClientInternal(ctx, clientID, tx)
+ client, err := s.getClientInternal(ctx, clientID, tx, false)
if err != nil {
return nil, err
}
@@ -1976,7 +1966,25 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
return s.IsUserGroupAllowedToAuthorize(user, client), nil
}
-func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string, light bool) error {
+var errLogoTooLarge = errors.New("logo is too large")
+
+func httpClientWithCheckRedirect(source *http.Client, checkRedirect func(req *http.Request, via []*http.Request) error) *http.Client {
+ if source == nil {
+ source = http.DefaultClient
+ }
+
+ // Create a new client that clones the transport
+ client := &http.Client{
+ Transport: source.Transport,
+ }
+
+ // Assign the CheckRedirect function
+ client.CheckRedirect = checkRedirect
+
+ return client
+}
+
+func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, clientID string, raw string, light bool) error {
u, err := url.Parse(raw)
if err != nil {
return err
@@ -1985,18 +1993,29 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
defer cancel()
- r := net.Resolver{}
- ips, err := r.LookupIPAddr(ctx, u.Hostname())
- if err != nil || len(ips) == 0 {
- return fmt.Errorf("cannot resolve hostname")
+ // Prevents SSRF by allowing only public IPs
+ ok, err := utils.IsURLPrivate(ctx, u)
+ if err != nil {
+ return err
+ } else if ok {
+ return errors.New("private IP addresses are not allowed")
}
- // Prevents SSRF by allowing only public IPs
- for _, addr := range ips {
- if utils.IsPrivateIP(addr.IP) {
- return fmt.Errorf("private IP addresses are not allowed")
+ // We need to check this on redirects too
+ client := httpClientWithCheckRedirect(s.httpClient, func(r *http.Request, via []*http.Request) error {
+ if len(via) >= 10 {
+ return errors.New("stopped after 10 redirects")
}
- }
+
+ ok, err := utils.IsURLPrivate(r.Context(), r.URL)
+ if err != nil {
+ return err
+ } else if ok {
+ return errors.New("private IP addresses are not allowed")
+ }
+
+ return nil
+ })
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
if err != nil {
@@ -2005,7 +2024,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
req.Header.Set("Accept", "image/*")
- resp, err := s.httpClient.Do(req)
+ resp, err := client.Do(req)
if err != nil {
return err
}
@@ -2017,7 +2036,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
if resp.ContentLength > maxLogoSize {
- return fmt.Errorf("logo is too large")
+ return errLogoTooLarge
}
// Prefer extension in path if supported
@@ -2037,48 +2056,70 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
}
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
- if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil {
+ err = s.fileStorage.Save(ctx, imagePath, utils.NewLimitReader(resp.Body, maxLogoSize+1))
+ if errors.Is(err, utils.ErrSizeExceeded) {
+ return errLogoTooLarge
+ } else if err != nil {
return err
}
- if err := s.updateClientLogoType(ctx, tx, clientID, ext, light); err != nil {
+ err = s.updateClientLogoType(ctx, clientID, ext, light)
+ if err != nil {
return err
}
return nil
}
-func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error {
+func (s *OidcService) updateClientLogoType(ctx context.Context, clientID string, ext string, light bool) error {
var darkSuffix string
if !light {
darkSuffix = "-dark"
}
+ tx := s.db.Begin()
+ defer func() {
+ tx.Rollback()
+ }()
+
+	// Acquire a row-level UPDATE lock, since we modify and save this row below
var client model.OidcClient
- if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
- return err
+ err := tx.
+ WithContext(ctx).
+ Clauses(clause.Locking{Strength: "UPDATE"}).
+ First(&client, "id = ?", clientID).
+ Error
+ if err != nil {
+ return fmt.Errorf("failed to look up client: %w", err)
}
+
var currentType *string
if light {
currentType = client.ImageType
+ client.ImageType = &ext
} else {
currentType = client.DarkImageType
+ client.DarkImageType = &ext
}
+
+ err = tx.
+ WithContext(ctx).
+ Save(&client).
+ Error
+ if err != nil {
+ return fmt.Errorf("failed to save updated client: %w", err)
+ }
+
+ err = tx.Commit().Error
+ if err != nil {
+ return fmt.Errorf("failed to commit transaction: %w", err)
+ }
+
+ // Storage operations must be executed outside of a transaction
if currentType != nil && *currentType != ext {
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
_ = s.fileStorage.Delete(ctx, old)
}
- var column string
- if light {
- column = "image_type"
- } else {
- column = "dark_image_type"
- }
-
- return tx.WithContext(ctx).
- Model(&model.OidcClient{}).
- Where("id = ?", clientID).
- Update(column, ext).
- Error
+ return nil
}
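
downloadAndSaveLogoFromURL now depends on `utils.NewLimitReader` and `utils.ErrSizeExceeded`, which are introduced elsewhere in this PR and not shown here. Assuming they behave like `io.LimitReader` with a sentinel error once the cap is crossed, a minimal sketch could look like this (the real helper may differ):

```go
package utils

import (
	"errors"
	"io"
)

// ErrSizeExceeded is returned once a reader tries to go past its byte cap.
var ErrSizeExceeded = errors.New("size limit exceeded")

type limitReader struct {
	r io.Reader
	n int64 // bytes remaining
}

// NewLimitReader reads at most n bytes from r and then returns
// ErrSizeExceeded, so callers can tell "too large" apart from a normal EOF.
func NewLimitReader(r io.Reader, n int64) io.Reader {
	return &limitReader{r: r, n: n}
}

func (l *limitReader) Read(p []byte) (int, error) {
	if l.n <= 0 {
		return 0, ErrSizeExceeded
	}
	if int64(len(p)) > l.n {
		p = p[:l.n]
	}
	n, err := l.r.Read(p)
	l.n -= int64(n)
	return n, err
}
```

Passing `maxLogoSize+1` as the cap, as the service does, means any body larger than `maxLogoSize` trips the sentinel during `fileStorage.Save`, which the caller maps to `errLogoTooLarge`.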
diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go
index df73edab..7c8ddfc5 100644
--- a/backend/internal/service/oidc_service_test.go
+++ b/backend/internal/service/oidc_service_test.go
@@ -8,7 +8,10 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
+ "io"
"net/http"
+ "strconv"
+ "strings"
"testing"
"time"
@@ -21,6 +24,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
+ "github.com/pocket-id/pocket-id/backend/internal/storage"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
@@ -538,3 +542,435 @@ func TestValidateCodeVerifier_Plain(t *testing.T) {
require.False(t, validateCodeVerifier("NOT!VALID", codeChallenge, true))
})
}
+
+func TestOidcService_updateClientLogoType(t *testing.T) {
+ // Create a test database
+ db := testutils.NewDatabaseForTest(t)
+
+ // Create database storage
+ dbStorage, err := storage.NewDatabaseStorage(db)
+ require.NoError(t, err)
+
+ // Init the OidcService
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ }
+
+ // Create a test client
+ client := model.OidcClient{
+ Name: "Test Client",
+ CallbackURLs: model.UrlList{"https://example.com/callback"},
+ }
+ err = db.Create(&client).Error
+ require.NoError(t, err)
+
+ // Helper function to check if a file exists in storage
+ fileExists := func(t *testing.T, path string) bool {
+ t.Helper()
+ _, _, err := dbStorage.Open(t.Context(), path)
+ return err == nil
+ }
+
+ // Helper function to create a dummy file in storage
+ createDummyFile := func(t *testing.T, path string) {
+ t.Helper()
+ err := dbStorage.Save(t.Context(), path, strings.NewReader("dummy content"))
+ require.NoError(t, err)
+ }
+
+ t.Run("Updates light logo type for client without previous logo", func(t *testing.T) {
+ // Update the logo type
+ err := s.updateClientLogoType(t.Context(), client.ID, "png", true)
+ require.NoError(t, err)
+
+ // Verify the client was updated
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.ImageType)
+ assert.Equal(t, "png", *updatedClient.ImageType)
+ })
+
+ t.Run("Updates dark logo type for client without previous dark logo", func(t *testing.T) {
+ // Update the dark logo type
+ err := s.updateClientLogoType(t.Context(), client.ID, "jpg", false)
+ require.NoError(t, err)
+
+ // Verify the client was updated
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.DarkImageType)
+ assert.Equal(t, "jpg", *updatedClient.DarkImageType)
+ })
+
+ t.Run("Updates light logo type and deletes old file when type changes", func(t *testing.T) {
+ // Create the old PNG file in storage
+ oldPath := "oidc-client-images/" + client.ID + ".png"
+ createDummyFile(t, oldPath)
+ require.True(t, fileExists(t, oldPath), "Old file should exist before update")
+
+ // Client currently has a PNG logo, update to WEBP
+ err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
+ require.NoError(t, err)
+
+ // Verify the client was updated
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.ImageType)
+ assert.Equal(t, "webp", *updatedClient.ImageType)
+
+ // Old PNG file should be deleted
+ assert.False(t, fileExists(t, oldPath), "Old PNG file should have been deleted")
+ })
+
+ t.Run("Updates dark logo type and deletes old file when type changes", func(t *testing.T) {
+ // Create the old JPG dark file in storage
+ oldPath := "oidc-client-images/" + client.ID + "-dark.jpg"
+ createDummyFile(t, oldPath)
+ require.True(t, fileExists(t, oldPath), "Old dark file should exist before update")
+
+ // Client currently has a JPG dark logo, update to WEBP
+ err := s.updateClientLogoType(t.Context(), client.ID, "webp", false)
+ require.NoError(t, err)
+
+ // Verify the client was updated
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.DarkImageType)
+ assert.Equal(t, "webp", *updatedClient.DarkImageType)
+
+ // Old JPG dark file should be deleted
+ assert.False(t, fileExists(t, oldPath), "Old JPG dark file should have been deleted")
+ })
+
+ t.Run("Does not delete file when type remains the same", func(t *testing.T) {
+ // Create the WEBP file in storage
+ webpPath := "oidc-client-images/" + client.ID + ".webp"
+ createDummyFile(t, webpPath)
+ require.True(t, fileExists(t, webpPath), "WEBP file should exist before update")
+
+ // Update to the same type (WEBP)
+ err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
+ require.NoError(t, err)
+
+ // Verify the client still has WEBP
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.ImageType)
+ assert.Equal(t, "webp", *updatedClient.ImageType)
+
+ // WEBP file should still exist since type didn't change
+ assert.True(t, fileExists(t, webpPath), "WEBP file should still exist")
+ })
+
+ t.Run("Returns error for non-existent client", func(t *testing.T) {
+ err := s.updateClientLogoType(t.Context(), "non-existent-client-id", "png", true)
+ require.Error(t, err)
+ require.ErrorContains(t, err, "failed to look up client")
+ })
+}
+
+func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) {
+ // Create a test database
+ db := testutils.NewDatabaseForTest(t)
+
+ // Create database storage
+ dbStorage, err := storage.NewDatabaseStorage(db)
+ require.NoError(t, err)
+
+ // Create a test client
+ client := model.OidcClient{
+ Name: "Test Client",
+ CallbackURLs: model.UrlList{"https://example.com/callback"},
+ }
+ err = db.Create(&client).Error
+ require.NoError(t, err)
+
+ // Helper function to check if a file exists in storage
+ fileExists := func(t *testing.T, path string) bool {
+ t.Helper()
+ _, _, err := dbStorage.Open(t.Context(), path)
+ return err == nil
+ }
+
+ // Helper function to get file content from storage
+ getFileContent := func(t *testing.T, path string) []byte {
+ t.Helper()
+ reader, _, err := dbStorage.Open(t.Context(), path)
+ require.NoError(t, err)
+ defer reader.Close()
+ content, err := io.ReadAll(reader)
+ require.NoError(t, err)
+ return content
+ }
+
+ t.Run("Successfully downloads and saves PNG logo from URL", func(t *testing.T) {
+ // Create mock PNG content
+ pngContent := []byte("fake-png-content")
+
+ // Create a mock HTTP response with headers
+ //nolint:bodyclose
+ pngResponse := testutils.NewMockResponse(http.StatusOK, string(pngContent))
+ pngResponse.Header.Set("Content-Type", "image/png")
+
+ // Create a mock HTTP client with responses
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/logo.png": pngResponse,
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ // Init the OidcService with mock HTTP client
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ // Download and save the logo
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo.png", true)
+ require.NoError(t, err)
+
+ // Verify the file was saved
+ logoPath := "oidc-client-images/" + client.ID + ".png"
+ require.True(t, fileExists(t, logoPath), "Logo file should exist in storage")
+
+ // Verify the content
+ savedContent := getFileContent(t, logoPath)
+ assert.Equal(t, pngContent, savedContent)
+
+ // Verify the client was updated
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.ImageType)
+ assert.Equal(t, "png", *updatedClient.ImageType)
+ })
+
+ t.Run("Successfully downloads and saves dark logo", func(t *testing.T) {
+ // Create mock WEBP content
+ webpContent := []byte("fake-webp-content")
+
+ //nolint:bodyclose
+ webpResponse := testutils.NewMockResponse(http.StatusOK, string(webpContent))
+ webpResponse.Header.Set("Content-Type", "image/webp")
+
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/dark-logo.webp": webpResponse,
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ // Download and save the dark logo
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/dark-logo.webp", false)
+ require.NoError(t, err)
+
+ // Verify the dark logo file was saved
+ darkLogoPath := "oidc-client-images/" + client.ID + "-dark.webp"
+ require.True(t, fileExists(t, darkLogoPath), "Dark logo file should exist in storage")
+
+ // Verify the content
+ savedContent := getFileContent(t, darkLogoPath)
+ assert.Equal(t, webpContent, savedContent)
+
+ // Verify the client was updated
+ var updatedClient model.OidcClient
+ err = db.First(&updatedClient, "id = ?", client.ID).Error
+ require.NoError(t, err)
+ require.NotNil(t, updatedClient.DarkImageType)
+ assert.Equal(t, "webp", *updatedClient.DarkImageType)
+ })
+
+ t.Run("Detects extension from URL path", func(t *testing.T) {
+		svgContent := []byte("<svg></svg>")
+
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)),
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/icon.svg", true)
+ require.NoError(t, err)
+
+ // Verify SVG file was saved
+ logoPath := "oidc-client-images/" + client.ID + ".svg"
+ require.True(t, fileExists(t, logoPath), "SVG logo should exist")
+ })
+
+ t.Run("Detects extension from Content-Type when path has no extension", func(t *testing.T) {
+ jpgContent := []byte("fake-jpg-content")
+
+ //nolint:bodyclose
+ jpgResponse := testutils.NewMockResponse(http.StatusOK, string(jpgContent))
+ jpgResponse.Header.Set("Content-Type", "image/jpeg")
+
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/logo": jpgResponse,
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo", true)
+ require.NoError(t, err)
+
+ // Verify JPG file was saved (jpeg extension is normalized to jpg)
+ logoPath := "oidc-client-images/" + client.ID + ".jpg"
+ require.True(t, fileExists(t, logoPath), "JPG logo should exist")
+ })
+
+ t.Run("Returns error for invalid URL", func(t *testing.T) {
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: &http.Client{},
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "://invalid-url", true)
+ require.Error(t, err)
+ })
+
+ t.Run("Returns error for non-200 status code", func(t *testing.T) {
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"),
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/not-found.png", true)
+ require.Error(t, err)
+ require.ErrorContains(t, err, "failed to fetch logo")
+ })
+
+ t.Run("Returns error for too large content", func(t *testing.T) {
+ // Create content larger than 2MB (maxLogoSize)
+		largeContent := strings.Repeat("x", 2<<20+100) // just over the 2MB maxLogoSize limit
+
+ //nolint:bodyclose
+ largeResponse := testutils.NewMockResponse(http.StatusOK, largeContent)
+ largeResponse.Header.Set("Content-Type", "image/png")
+ largeResponse.Header.Set("Content-Length", strconv.Itoa(len(largeContent)))
+
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/large.png": largeResponse,
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/large.png", true)
+ require.Error(t, err)
+ require.ErrorIs(t, err, errLogoTooLarge)
+ })
+
+ t.Run("Returns error for unsupported file type", func(t *testing.T) {
+ //nolint:bodyclose
+ textResponse := testutils.NewMockResponse(http.StatusOK, "text content")
+ textResponse.Header.Set("Content-Type", "text/plain")
+
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/file.txt": textResponse,
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/file.txt", true)
+ require.Error(t, err)
+ var fileTypeErr *common.FileTypeNotSupportedError
+ require.ErrorAs(t, err, &fileTypeErr)
+ })
+
+ t.Run("Returns error for non-existent client", func(t *testing.T) {
+ //nolint:bodyclose
+ pngResponse := testutils.NewMockResponse(http.StatusOK, "content")
+ pngResponse.Header.Set("Content-Type", "image/png")
+
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ "https://example.com/logo.png": pngResponse,
+ }
+ httpClient := &http.Client{
+ Transport: &testutils.MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ s := &OidcService{
+ db: db,
+ fileStorage: dbStorage,
+ httpClient: httpClient,
+ }
+
+ err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", "https://example.com/logo.png", true)
+ require.Error(t, err)
+ require.ErrorContains(t, err, "failed to look up client")
+ })
+}
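
These tests rely on `testutils.NewMockResponse` and `testutils.MockRoundTripper`, which are not part of this diff. The sketch below shows one plausible shape for them, a RoundTripper that serves canned responses keyed by request URL; the real helpers in the testutils package may be implemented differently.

```go
package testutils

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// MockRoundTripper serves canned responses keyed by the full request URL.
type MockRoundTripper struct {
	Responses map[string]*http.Response
}

func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, ok := m.Responses[req.URL.String()]
	if !ok {
		return nil, fmt.Errorf("no mock response for %s", req.URL)
	}
	resp.Request = req
	return resp, nil
}

// NewMockResponse builds a minimal *http.Response with the given status and
// body. ContentLength is populated so size checks (e.g. the 2MB logo cap)
// behave as they would with a real response.
func NewMockResponse(status int, body string) *http.Response {
	return &http.Response{
		StatusCode:    status,
		Header:        make(http.Header),
		Body:          io.NopCloser(strings.NewReader(body)),
		ContentLength: int64(len(body)),
	}
}
```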
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 6cef3176..fa92a27d 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -17,6 +17,7 @@ import (
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
+ "gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -101,9 +102,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
profilePicturePath := path.Join("profile-pictures", userID+".png")
// Try custom profile picture
- if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil {
+ file, size, err := s.fileStorage.Open(ctx, profilePicturePath)
+ if err == nil {
return file, size, nil
- } else if err != nil && !errors.Is(err, fs.ErrNotExist) {
+ } else if !errors.Is(err, fs.ErrNotExist) {
return nil, 0, err
}
@@ -120,9 +122,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
// Try cached default for initials
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
- if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil {
+ file, size, err = s.fileStorage.Open(ctx, defaultPicturePath)
+ if err == nil {
return file, size, nil
- } else if err != nil && !errors.Is(err, fs.ErrNotExist) {
+ } else if !errors.Is(err, fs.ErrNotExist) {
return nil, 0, err
}
@@ -133,12 +136,13 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
}
// Save the default picture for future use (in a goroutine to avoid blocking)
- //nolint:contextcheck
defaultPictureBytes := defaultPicture.Bytes()
//nolint:contextcheck
go func() {
- if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil {
- slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err))
+ // Use bytes.NewReader because we need an io.ReadSeeker
+ rErr := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes))
+ if rErr != nil {
+ slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", rErr))
}
}()
@@ -182,17 +186,30 @@ func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, f
}
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
- return s.db.Transaction(func(tx *gorm.DB) error {
- return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
+ err := s.db.Transaction(func(tx *gorm.DB) error {
+ return s.deleteUserInternal(ctx, tx, userID, allowLdapDelete)
})
+ if err != nil {
+ return fmt.Errorf("failed to delete user '%s': %w", userID, err)
+ }
+
+ // Storage operations must be executed outside of a transaction
+ profilePicturePath := path.Join("profile-pictures", userID+".png")
+ err = s.fileStorage.Delete(ctx, profilePicturePath)
+ if err != nil && !storage.IsNotExist(err) {
+ return fmt.Errorf("failed to delete profile picture for user '%s': %w", userID, err)
+ }
+
+ return nil
}
-func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
+func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userID string, allowLdapDelete bool) error {
var user model.User
err := tx.
WithContext(ctx).
Where("id = ?", userID).
+ Clauses(clause.Locking{Strength: "UPDATE"}).
First(&user).
Error
if err != nil {
@@ -204,11 +221,6 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
return &common.LdapUserUpdateError{}
}
- profilePicturePath := path.Join("profile-pictures", userID+".png")
- if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
- return err
- }
-
err = tx.WithContext(ctx).Delete(&user).Error
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
@@ -286,16 +298,27 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
// Apply default user groups
var groupIDs []string
- if v := config.SignupDefaultUserGroupIDs.Value; v != "" && v != "[]" {
- if err := json.Unmarshal([]byte(v), &groupIDs); err != nil {
+ v := config.SignupDefaultUserGroupIDs.Value
+ if v != "" && v != "[]" {
+ err := json.Unmarshal([]byte(v), &groupIDs)
+ if err != nil {
return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err)
}
if len(groupIDs) > 0 {
var groups []model.UserGroup
- if err := tx.WithContext(ctx).Where("id IN ?", groupIDs).Find(&groups).Error; err != nil {
+ err = tx.WithContext(ctx).
+ Where("id IN ?", groupIDs).
+ Find(&groups).
+ Error
+ if err != nil {
return fmt.Errorf("failed to find default user groups: %w", err)
}
- if err := tx.WithContext(ctx).Model(user).Association("UserGroups").Replace(groups); err != nil {
+
+ err = tx.WithContext(ctx).
+ Model(user).
+ Association("UserGroups").
+ Replace(groups)
+ if err != nil {
return fmt.Errorf("failed to associate default user groups: %w", err)
}
}
@@ -303,12 +326,15 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
// Apply default custom claims
var claims []dto.CustomClaimCreateDto
- if v := config.SignupDefaultCustomClaims.Value; v != "" && v != "[]" {
- if err := json.Unmarshal([]byte(v), &claims); err != nil {
+ v = config.SignupDefaultCustomClaims.Value
+ if v != "" && v != "[]" {
+ err := json.Unmarshal([]byte(v), &claims)
+ if err != nil {
return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err)
}
if len(claims) > 0 {
- if _, err := s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx); err != nil {
+ _, err = s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx)
+ if err != nil {
return fmt.Errorf("failed to apply default custom claims: %w", err)
}
}
@@ -345,6 +371,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
err := tx.
WithContext(ctx).
Where("id = ?", userID).
+ Clauses(clause.Locking{Strength: "UPDATE"}).
First(&user).
Error
if err != nil {
@@ -416,13 +443,11 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context
var userId string
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
- if err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
// Do not return error if user not found to prevent email enumeration
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return nil
- } else {
- return err
- }
+ return nil
+ } else if err != nil {
+ return err
}
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
@@ -513,7 +538,9 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
var oneTimeAccessToken model.OneTimeAccessToken
err := tx.
WithContext(ctx).
- Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
+ Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).
+ Preload("User").
+ Clauses(clause.Locking{Strength: "UPDATE"}).
First(&oneTimeAccessToken).
Error
if err != nil {
@@ -679,7 +706,7 @@ func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) er
return nil
}
-func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx *gorm.DB) error {
+func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
return tx.
WithContext(ctx).
Model(&model.User{}).
@@ -720,6 +747,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
err := tx.
WithContext(ctx).
Where("token = ?", signupData.Token).
+ Clauses(clause.Locking{Strength: "UPDATE"}).
First(&signupToken).
Error
if err != nil {
diff --git a/backend/internal/service/version_service.go b/backend/internal/service/version_service.go
index 2a12930a..d52fd6f9 100644
--- a/backend/internal/service/version_service.go
+++ b/backend/internal/service/version_service.go
@@ -58,7 +58,7 @@ func (s *VersionService) GetLatestVersion(ctx context.Context) (string, error) {
}
if payload.TagName == "" {
- return "", fmt.Errorf("GitHub API returned empty tag name")
+ return "", errors.New("GitHub API returned empty tag name")
}
return strings.TrimPrefix(payload.TagName, "v"), nil
diff --git a/backend/internal/service/webauthn_service.go b/backend/internal/service/webauthn_service.go
index d32e149b..79aa2e5e 100644
--- a/backend/internal/service/webauthn_service.go
+++ b/backend/internal/service/webauthn_service.go
@@ -2,6 +2,8 @@ package service
import (
"context"
+ "encoding/hex"
+ "errors"
"fmt"
"net/http"
"time"
@@ -114,7 +116,7 @@ func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string)
}, nil
}
-func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
+func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID string, userID string, r *http.Request, ipAddress string) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -173,6 +175,9 @@ func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, use
return model.WebauthnCredential{}, fmt.Errorf("failed to store WebAuthn credential: %w", err)
}
+ auditLogData := model.AuditLogData{"credentialID": hex.EncodeToString(credential.ID), "passkeyName": passkeyName}
+ s.auditLogService.Create(ctx, model.AuditLogEventPasskeyAdded, ipAddress, r.UserAgent(), userID, auditLogData, tx)
+
err = tx.Commit().Error
if err != nil {
return model.WebauthnCredential{}, fmt.Errorf("failed to commit transaction: %w", err)
@@ -288,16 +293,30 @@ func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([
return credentials, nil
}
-func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
- err := s.db.
+func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID string, credentialID string, ipAddress string, userAgent string) error {
+ tx := s.db.Begin()
+ defer func() {
+ tx.Rollback()
+ }()
+
+ credential := &model.WebauthnCredential{}
+ err := tx.
WithContext(ctx).
- Where("id = ? AND user_id = ?", credentialID, userID).
- Delete(&model.WebauthnCredential{}).
+ Clauses(clause.Returning{}).
+ Delete(credential, "id = ? AND user_id = ?", credentialID, userID).
Error
if err != nil {
return fmt.Errorf("failed to delete record: %w", err)
}
+ auditLogData := model.AuditLogData{"credentialID": hex.EncodeToString(credential.CredentialID), "passkeyName": credential.Name}
+ s.auditLogService.Create(ctx, model.AuditLogEventPasskeyRemoved, ipAddress, userAgent, userID, auditLogData, tx)
+
+ err = tx.Commit().Error
+ if err != nil {
+ return fmt.Errorf("failed to commit transaction: %w", err)
+ }
+
return nil
}
@@ -353,7 +372,7 @@ func (s *WebAuthnService) CreateReauthenticationTokenWithAccessToken(ctx context
userID, ok := token.Subject()
if !ok {
- return "", fmt.Errorf("access token does not contain user ID")
+ return "", errors.New("access token does not contain user ID")
}
// Check if token is issued less than a minute ago
diff --git a/backend/internal/storage/database.go b/backend/internal/storage/database.go
new file mode 100644
index 00000000..2c8779dc
--- /dev/null
+++ b/backend/internal/storage/database.go
@@ -0,0 +1,226 @@
+package storage
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/pocket-id/pocket-id/backend/internal/model"
+ datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+)
+
+var TypeDatabase = "database"
+
+type databaseStorage struct {
+ db *gorm.DB
+}
+
+// NewDatabaseStorage creates a new database storage provider
+func NewDatabaseStorage(db *gorm.DB) (FileStorage, error) {
+ if db == nil {
+ return nil, errors.New("database connection is required")
+ }
+ return &databaseStorage{db: db}, nil
+}
+
+func (s *databaseStorage) Type() string {
+ return TypeDatabase
+}
+
+func (s *databaseStorage) Save(ctx context.Context, relativePath string, data io.Reader) error {
+ // Normalize the path
+ relativePath = filepath.ToSlash(filepath.Clean(relativePath))
+
+ // Read all data into memory
+ b, err := io.ReadAll(data)
+ if err != nil {
+ return fmt.Errorf("failed to read data: %w", err)
+ }
+
+ now := datatype.DateTime(time.Now())
+ storage := model.Storage{
+ Path: relativePath,
+ Data: b,
+ Size: int64(len(b)),
+ ModTime: now,
+ CreatedAt: now,
+ }
+
+ // Use upsert: insert or update on conflict
+ result := s.db.
+ WithContext(ctx).
+ Clauses(clause.OnConflict{
+ Columns: []clause.Column{{Name: "path"}},
+ DoUpdates: clause.AssignmentColumns([]string{"data", "size", "mod_time"}),
+ }).
+ Create(&storage)
+
+ if result.Error != nil {
+ return fmt.Errorf("failed to save file to database: %w", result.Error)
+ }
+
+ return nil
+}
+
+func (s *databaseStorage) Open(ctx context.Context, relativePath string) (io.ReadCloser, int64, error) {
+ relativePath = filepath.ToSlash(filepath.Clean(relativePath))
+
+ var storage model.Storage
+ result := s.db.
+ WithContext(ctx).
+ Where("path = ?", relativePath).
+ First(&storage)
+
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, 0, os.ErrNotExist
+ }
+ return nil, 0, fmt.Errorf("failed to read file from database: %w", result.Error)
+ }
+
+ reader := io.NopCloser(bytes.NewReader(storage.Data))
+ return reader, storage.Size, nil
+}
+
+func (s *databaseStorage) Delete(ctx context.Context, relativePath string) error {
+ relativePath = filepath.ToSlash(filepath.Clean(relativePath))
+
+ result := s.db.
+ WithContext(ctx).
+ Where("path = ?", relativePath).
+ Delete(&model.Storage{})
+ if result.Error != nil {
+ return fmt.Errorf("failed to delete file from database: %w", result.Error)
+ }
+
+ return nil
+}
+
+func (s *databaseStorage) DeleteAll(ctx context.Context, prefix string) error {
+ prefix = filepath.ToSlash(filepath.Clean(prefix))
+
+ // If empty prefix, delete all
+ if isRootPath(prefix) {
+ result := s.db.
+ WithContext(ctx).
+ Where("1 = 1"). // Delete everything
+ Delete(&model.Storage{})
+ if result.Error != nil {
+ return fmt.Errorf("failed to delete all files from database: %w", result.Error)
+ }
+ return nil
+ }
+
+ // Ensure prefix ends with / for proper prefix matching
+ if !strings.HasSuffix(prefix, "/") {
+ prefix += "/"
+ }
+
+ query := s.db.WithContext(ctx)
+ query = addPathPrefixClause(s.db.Name(), query, prefix)
+ result := query.Delete(&model.Storage{})
+ if result.Error != nil {
+ return fmt.Errorf("failed to delete files with prefix '%s' from database: %w", prefix, result.Error)
+ }
+
+ return nil
+}
+
+func (s *databaseStorage) List(ctx context.Context, prefix string) ([]ObjectInfo, error) {
+ prefix = filepath.ToSlash(filepath.Clean(prefix))
+
+ var storageItems []model.Storage
+ query := s.db.WithContext(ctx)
+
+ if !isRootPath(prefix) {
+ // Ensure prefix matching
+ if !strings.HasSuffix(prefix, "/") {
+ prefix += "/"
+ }
+ query = addPathPrefixClause(s.db.Name(), query, prefix)
+ }
+
+ result := query.
+ Select("path", "size", "mod_time").
+ Find(&storageItems)
+ if result.Error != nil {
+ return nil, fmt.Errorf("failed to list files from database: %w", result.Error)
+ }
+
+ objects := make([]ObjectInfo, 0, len(storageItems))
+ for _, item := range storageItems {
+ // Filter out directory-like paths (those that contain additional slashes after the prefix)
+ relativePath := strings.TrimPrefix(item.Path, prefix)
+ if strings.ContainsRune(relativePath, '/') {
+ continue
+ }
+
+ objects = append(objects, ObjectInfo{
+ Path: item.Path,
+ Size: item.Size,
+ ModTime: time.Time(item.ModTime),
+ })
+ }
+
+ return objects, nil
+}
+
+func (s *databaseStorage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
+ root = filepath.ToSlash(filepath.Clean(root))
+
+ var storageItems []model.Storage
+ query := s.db.WithContext(ctx)
+
+ if !isRootPath(root) {
+ // Ensure root matching
+ if !strings.HasSuffix(root, "/") {
+ root += "/"
+ }
+ query = addPathPrefixClause(s.db.Name(), query, root)
+ }
+
+ result := query.
+ Select("path", "size", "mod_time").
+ Find(&storageItems)
+ if result.Error != nil {
+ return fmt.Errorf("failed to walk files from database: %w", result.Error)
+ }
+
+ for _, item := range storageItems {
+ err := fn(ObjectInfo{
+ Path: item.Path,
+ Size: item.Size,
+ ModTime: time.Time(item.ModTime),
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func isRootPath(path string) bool {
+ return path == "" || path == "/" || path == "."
+}
+
+func addPathPrefixClause(dialect string, query *gorm.DB, prefix string) *gorm.DB {
+ // In SQLite, we use "GLOB" which can use the index
+ switch dialect {
+ case "sqlite":
+ return query.Where("path GLOB ?", prefix+"*")
+ case "postgres":
+ return query.Where("path LIKE ?", prefix+"%")
+ default:
+ // Indicates a development-time error
+ panic(fmt.Errorf("unsupported database dialect: %s", dialect))
+ }
+}
diff --git a/backend/internal/storage/database_test.go b/backend/internal/storage/database_test.go
new file mode 100644
index 00000000..208fb7b7
--- /dev/null
+++ b/backend/internal/storage/database_test.go
@@ -0,0 +1,148 @@
+package storage
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ testingutil "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
+)
+
+func TestDatabaseStorageOperations(t *testing.T) {
+ ctx := context.Background()
+ db := testingutil.NewDatabaseForTest(t)
+ store, err := NewDatabaseStorage(db)
+ require.NoError(t, err)
+
+ t.Run("type should be database", func(t *testing.T) {
+ assert.Equal(t, TypeDatabase, store.Type())
+ })
+
+ t.Run("save, open and list files", func(t *testing.T) {
+ err := store.Save(ctx, "images/logo.png", bytes.NewBufferString("logo-data"))
+ require.NoError(t, err)
+
+ reader, size, err := store.Open(ctx, "images/logo.png")
+ require.NoError(t, err)
+ defer reader.Close()
+
+ contents, err := io.ReadAll(reader)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("logo-data"), contents)
+ assert.Equal(t, int64(len(contents)), size)
+
+ err = store.Save(ctx, "images/nested/child.txt", bytes.NewBufferString("child"))
+ require.NoError(t, err)
+
+ files, err := store.List(ctx, "images")
+ require.NoError(t, err)
+ require.Len(t, files, 1)
+ assert.Equal(t, "images/logo.png", files[0].Path)
+ assert.Equal(t, int64(len("logo-data")), files[0].Size)
+ })
+
+ t.Run("save should update existing file", func(t *testing.T) {
+ err := store.Save(ctx, "test/update.txt", bytes.NewBufferString("original"))
+ require.NoError(t, err)
+
+ err = store.Save(ctx, "test/update.txt", bytes.NewBufferString("updated"))
+ require.NoError(t, err)
+
+ reader, size, err := store.Open(ctx, "test/update.txt")
+ require.NoError(t, err)
+ defer reader.Close()
+
+ contents, err := io.ReadAll(reader)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("updated"), contents)
+ assert.Equal(t, int64(len("updated")), size)
+ })
+
+ t.Run("delete files individually", func(t *testing.T) {
+ err := store.Save(ctx, "images/delete-me.txt", bytes.NewBufferString("temp"))
+ require.NoError(t, err)
+
+ require.NoError(t, store.Delete(ctx, "images/delete-me.txt"))
+ _, _, err = store.Open(ctx, "images/delete-me.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+ })
+
+ t.Run("delete missing file should not error", func(t *testing.T) {
+ require.NoError(t, store.Delete(ctx, "images/missing.txt"))
+ })
+
+ t.Run("delete all files", func(t *testing.T) {
+ require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a")))
+ require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b")))
+ require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c")))
+ require.NoError(t, store.DeleteAll(ctx, "/"))
+
+ _, _, err := store.Open(ctx, "cleanup/a.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+
+ _, _, err = store.Open(ctx, "cleanup/b.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+
+ _, _, err = store.Open(ctx, "cleanup/nested/c.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+ })
+
+ t.Run("delete all files under a prefix", func(t *testing.T) {
+ require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a")))
+ require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b")))
+ require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c")))
+ require.NoError(t, store.DeleteAll(ctx, "cleanup"))
+
+ _, _, err := store.Open(ctx, "cleanup/a.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+
+ _, _, err = store.Open(ctx, "cleanup/b.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+
+ _, _, err = store.Open(ctx, "cleanup/nested/c.txt")
+ require.Error(t, err)
+ assert.True(t, IsNotExist(err))
+ })
+
+ t.Run("walk files", func(t *testing.T) {
+ require.NoError(t, store.Save(ctx, "walk/file1.txt", bytes.NewBufferString("1")))
+ require.NoError(t, store.Save(ctx, "walk/file2.txt", bytes.NewBufferString("2")))
+ require.NoError(t, store.Save(ctx, "walk/nested/file3.txt", bytes.NewBufferString("3")))
+
+ var paths []string
+ err := store.Walk(ctx, "walk", func(info ObjectInfo) error {
+ paths = append(paths, info.Path)
+ return nil
+ })
+ require.NoError(t, err)
+ assert.Len(t, paths, 3)
+ assert.Contains(t, paths, "walk/file1.txt")
+ assert.Contains(t, paths, "walk/file2.txt")
+ assert.Contains(t, paths, "walk/nested/file3.txt")
+ })
+}
+
+func TestNewDatabaseStorage(t *testing.T) {
+ t.Run("should return error with nil database", func(t *testing.T) {
+ _, err := NewDatabaseStorage(nil)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "database connection is required")
+ })
+
+ t.Run("should create storage with valid database", func(t *testing.T) {
+ db := testingutil.NewDatabaseForTest(t)
+ store, err := NewDatabaseStorage(db)
+ require.NoError(t, err)
+ assert.NotNil(t, store)
+ })
+}
diff --git a/backend/internal/utils/ip_util.go b/backend/internal/utils/ip_util.go
index 9832046b..be2ab83c 100644
--- a/backend/internal/utils/ip_util.go
+++ b/backend/internal/utils/ip_util.go
@@ -1,7 +1,10 @@
package utils
import (
+ "context"
+ "errors"
"net"
+ "net/url"
"strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
@@ -56,6 +59,23 @@ func IsPrivateIP(ip net.IP) bool {
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
}
+func IsURLPrivate(ctx context.Context, u *url.URL) (bool, error) {
+ var r net.Resolver
+ ips, err := r.LookupIPAddr(ctx, u.Hostname())
+ if err != nil || len(ips) == 0 {
+ return false, errors.New("cannot resolve hostname")
+ }
+
+ // Prevents SSRF by allowing only public IPs
+ for _, addr := range ips {
+ if IsPrivateIP(addr.IP) {
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
+
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
for _, ipNet := range ipNets {
if ipNet.Contains(ip) {
diff --git a/backend/internal/utils/ip_util_test.go b/backend/internal/utils/ip_util_test.go
index 01c7bf68..5da1eb68 100644
--- a/backend/internal/utils/ip_util_test.go
+++ b/backend/internal/utils/ip_util_test.go
@@ -1,8 +1,14 @@
package utils
import (
+ "context"
"net"
+ "net/url"
"testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
@@ -20,9 +26,8 @@ func TestIsLocalhostIP(t *testing.T) {
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
- if got := IsLocalhostIP(ip); got != tt.expected {
- t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
- }
+ got := IsLocalhostIP(ip)
+ assert.Equal(t, tt.expected, got)
}
}
@@ -40,9 +45,8 @@ func TestIsPrivateLanIP(t *testing.T) {
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
- if got := IsPrivateLanIP(ip); got != tt.expected {
- t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
- }
+ got := IsPrivateLanIP(ip)
+ assert.Equal(t, tt.expected, got)
}
}
@@ -59,9 +63,9 @@ func TestIsTailscaleIP(t *testing.T) {
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
- if got := IsTailscaleIP(ip); got != tt.expected {
- t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
- }
+
+ got := IsTailscaleIP(ip)
+ assert.Equal(t, tt.expected, got)
}
}
@@ -86,16 +90,17 @@ func TestIsLocalIPv6(t *testing.T) {
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
- if got := IsLocalIPv6(ip); got != tt.expected {
- t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
- }
+ got := IsLocalIPv6(ip)
+ assert.Equal(t, tt.expected, got)
}
}
func TestIsPrivateIP(t *testing.T) {
// Save and restore env config
origRanges := common.EnvConfig.LocalIPv6Ranges
- defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
+ t.Cleanup(func() {
+ common.EnvConfig.LocalIPv6Ranges = origRanges
+ })
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
localIPv6Ranges = nil // reset
@@ -115,9 +120,8 @@ func TestIsPrivateIP(t *testing.T) {
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
- if got := IsPrivateIP(ip); got != tt.expected {
- t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
- }
+ got := IsPrivateIP(ip)
+ assert.Equal(t, tt.expected, got)
}
}
@@ -138,22 +142,202 @@ func TestListContainsIP(t *testing.T) {
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
- if got := listContainsIP(list, ip); got != tt.expected {
- t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
- }
+ got := listContainsIP(list, ip)
+ assert.Equal(t, tt.expected, got)
}
}
func TestInit_LocalIPv6Ranges(t *testing.T) {
// Save and restore env config
origRanges := common.EnvConfig.LocalIPv6Ranges
- defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
+ t.Cleanup(func() {
+ common.EnvConfig.LocalIPv6Ranges = origRanges
+ })
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
localIPv6Ranges = nil
loadLocalIPv6Ranges()
- if len(localIPv6Ranges) != 2 {
- t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
+ assert.Len(t, localIPv6Ranges, 2)
+}
+
+func TestIsURLPrivate(t *testing.T) {
+ ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+ defer cancel()
+
+ tests := []struct {
+ name string
+ urlStr string
+ expectPriv bool
+ expectError bool
+ }{
+ {
+ name: "localhost by name",
+ urlStr: "http://localhost",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "localhost with port",
+ urlStr: "http://localhost:8080",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "127.0.0.1 IP",
+ urlStr: "http://127.0.0.1",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "127.0.0.1 with port",
+ urlStr: "http://127.0.0.1:3000",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "IPv6 loopback",
+ urlStr: "http://[::1]",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "IPv6 loopback with port",
+ urlStr: "http://[::1]:8080",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "private IP 10.x.x.x",
+ urlStr: "http://10.0.0.1",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "private IP 192.168.x.x",
+ urlStr: "http://192.168.1.1",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "private IP 172.16.x.x",
+ urlStr: "http://172.16.0.1",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "Tailscale IP",
+ urlStr: "http://100.64.0.1",
+ expectPriv: true,
+ expectError: false,
+ },
+ {
+ name: "public IP - Google DNS",
+ urlStr: "http://8.8.8.8",
+ expectPriv: false,
+ expectError: false,
+ },
+ {
+ name: "public IP - Cloudflare DNS",
+ urlStr: "http://1.1.1.1",
+ expectPriv: false,
+ expectError: false,
+ },
+ {
+ name: "invalid hostname",
+ urlStr: "http://this-should-not-resolve-ever-123456789.invalid",
+ expectPriv: false,
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ u, err := url.Parse(tt.urlStr)
+ require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
+
+ isPriv, err := IsURLPrivate(ctx, u)
+
+ if tt.expectError {
+ require.Error(t, err, "IsURLPrivate(%s) expected error but got none", tt.urlStr)
+ } else {
+ require.NoError(t, err, "IsURLPrivate(%s) unexpected error", tt.urlStr)
+ assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
+ }
+ })
}
}
+
+func TestIsURLPrivate_WithDomainName(t *testing.T) {
+ // Note: These tests rely on actual DNS resolution
+ // They test real public domains to ensure they are not flagged as private
+ ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
+ defer cancel()
+
+ tests := []struct {
+ name string
+ urlStr string
+ expectPriv bool
+ }{
+ {
+ name: "Google public domain",
+ urlStr: "https://www.google.com",
+ expectPriv: false,
+ },
+ {
+ name: "GitHub public domain",
+ urlStr: "https://github.com",
+ expectPriv: false,
+ },
+ {
+ // localhost.localtest.me is a well-known domain that resolves to 127.0.0.1
+ name: "localhost.localtest.me resolves to 127.0.0.1",
+ urlStr: "http://localhost.localtest.me",
+ expectPriv: true,
+ },
+ {
+ // 10.0.0.1.nip.io resolves to 10.0.0.1 (private IP)
+ name: "nip.io domain resolving to private 10.x IP",
+ urlStr: "http://10.0.0.1.nip.io",
+ expectPriv: true,
+ },
+ {
+ // 192.168.1.1.nip.io resolves to 192.168.1.1 (private IP)
+ name: "nip.io domain resolving to private 192.168.x IP",
+ urlStr: "http://192.168.1.1.nip.io",
+ expectPriv: true,
+ },
+ {
+ // 127.0.0.1.nip.io resolves to 127.0.0.1 (localhost)
+ name: "nip.io domain resolving to localhost",
+ urlStr: "http://127.0.0.1.nip.io",
+ expectPriv: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ u, err := url.Parse(tt.urlStr)
+ require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
+
+ isPriv, err := IsURLPrivate(ctx, u)
+ if err != nil {
+ t.Skipf("DNS resolution failed for %s (network issue?): %v", tt.urlStr, err)
+ return
+ }
+
+ assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
+ })
+ }
+}
+
+func TestIsURLPrivate_ContextCancellation(t *testing.T) {
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel() // Cancel immediately
+
+ u, err := url.Parse("http://example.com")
+ require.NoError(t, err, "Failed to parse URL")
+
+ _, err = IsURLPrivate(ctx, u)
+ assert.Error(t, err, "IsURLPrivate with cancelled context expected error but got none")
+}
diff --git a/backend/internal/utils/stream_util.go b/backend/internal/utils/stream_util.go
new file mode 100644
index 00000000..ce77ddf6
--- /dev/null
+++ b/backend/internal/utils/stream_util.go
@@ -0,0 +1,34 @@
+package utils
+
+import (
+ "errors"
+ "io"
+)
+
+var ErrSizeExceeded = errors.New("stream size exceeded")
+
+// LimitReader is like io.LimitReader but throws an error if the stream exceeds the max size
+// io.LimitReader instead just returns io.EOF
+// Adapted from https://github.com/golang/go/issues/51115#issuecomment-1079761212
+type LimitReader struct {
+ io.ReadCloser
+ N int64
+}
+
+func NewLimitReader(r io.ReadCloser, limit int64) *LimitReader {
+ return &LimitReader{r, limit}
+}
+
+func (r *LimitReader) Read(p []byte) (n int, err error) {
+ if r.N <= 0 {
+ return 0, ErrSizeExceeded
+ }
+
+ if int64(len(p)) > r.N {
+ p = p[0:r.N]
+ }
+
+ n, err = r.ReadCloser.Read(p)
+ r.N -= int64(n)
+ return
+}
diff --git a/backend/resources/migrations/postgres/20251110000000_storage_table.down.sql b/backend/resources/migrations/postgres/20251110000000_storage_table.down.sql
new file mode 100644
index 00000000..c914273a
--- /dev/null
+++ b/backend/resources/migrations/postgres/20251110000000_storage_table.down.sql
@@ -0,0 +1 @@
+DROP TABLE storage;
diff --git a/backend/resources/migrations/postgres/20251110000000_storage_table.up.sql b/backend/resources/migrations/postgres/20251110000000_storage_table.up.sql
new file mode 100644
index 00000000..337d4daa
--- /dev/null
+++ b/backend/resources/migrations/postgres/20251110000000_storage_table.up.sql
@@ -0,0 +1,9 @@
+-- The "storage" table contains file data stored in the database
+CREATE TABLE storage
+(
+ path TEXT NOT NULL PRIMARY KEY,
+ data BYTEA NOT NULL,
+ size BIGINT NOT NULL,
+ mod_time TIMESTAMPTZ NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL
+);
diff --git a/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.down.sql b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.down.sql
new file mode 100644
index 00000000..2098d8dc
--- /dev/null
+++ b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.down.sql
@@ -0,0 +1 @@
+DROP INDEX idx_api_keys_expires_at;
diff --git a/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.up.sql b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.up.sql
new file mode 100644
index 00000000..af1e6cad
--- /dev/null
+++ b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.up.sql
@@ -0,0 +1 @@
+CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at);
diff --git a/backend/resources/migrations/sqlite/20251110000000_storage_table.down.sql b/backend/resources/migrations/sqlite/20251110000000_storage_table.down.sql
new file mode 100644
index 00000000..2135db93
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20251110000000_storage_table.down.sql
@@ -0,0 +1,6 @@
+PRAGMA foreign_keys=OFF;
+BEGIN;
+DROP TABLE storage;
+
+COMMIT;
+PRAGMA foreign_keys=ON;
diff --git a/backend/resources/migrations/sqlite/20251110000000_storage_table.up.sql b/backend/resources/migrations/sqlite/20251110000000_storage_table.up.sql
new file mode 100644
index 00000000..12dd4dcb
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20251110000000_storage_table.up.sql
@@ -0,0 +1,14 @@
+PRAGMA foreign_keys=OFF;
+BEGIN;
+-- The "storage" table contains file data stored in the database
+CREATE TABLE storage
+(
+ path TEXT NOT NULL PRIMARY KEY,
+ data BLOB NOT NULL,
+ size INTEGER NOT NULL,
+ mod_time DATETIME NOT NULL,
+ created_at DATETIME NOT NULL
+);
+
+COMMIT;
+PRAGMA foreign_keys=ON;
diff --git a/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.down.sql b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.down.sql
new file mode 100644
index 00000000..ec8c10cb
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.down.sql
@@ -0,0 +1,5 @@
+PRAGMA foreign_keys=OFF;
+BEGIN;
+DROP INDEX idx_api_keys_expires_at;
+COMMIT;
+PRAGMA foreign_keys=ON;
diff --git a/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.up.sql b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.up.sql
new file mode 100644
index 00000000..899b4d02
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.up.sql
@@ -0,0 +1,5 @@
+PRAGMA foreign_keys=OFF;
+BEGIN;
+CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at);
+COMMIT;
+PRAGMA foreign_keys=ON;
diff --git a/frontend/messages/en.json b/frontend/messages/en.json
index 21e184be..02e3b77a 100644
--- a/frontend/messages/en.json
+++ b/frontend/messages/en.json
@@ -331,6 +331,10 @@
"token_sign_in": "Token Sign In",
"client_authorization": "Client Authorization",
"new_client_authorization": "New Client Authorization",
+ "device_code_authorization": "Device Code Authorization",
+ "new_device_code_authorization": "New Device Code Authorization",
+ "passkey_added": "Passkey Added",
+ "passkey_removed": "Passkey Removed",
"disable_animations": "Disable Animations",
"turn_off_ui_animations": "Turn off animations throughout the UI.",
"user_disabled": "Account Disabled",
diff --git a/frontend/src/lib/utils/audit-log-translator.ts b/frontend/src/lib/utils/audit-log-translator.ts
index d5e41798..50d6c35d 100644
--- a/frontend/src/lib/utils/audit-log-translator.ts
+++ b/frontend/src/lib/utils/audit-log-translator.ts
@@ -5,7 +5,11 @@ export const eventTypes: Record = {
TOKEN_SIGN_IN: m.token_sign_in(),
CLIENT_AUTHORIZATION: m.client_authorization(),
NEW_CLIENT_AUTHORIZATION: m.new_client_authorization(),
- ACCOUNT_CREATED: m.account_created()
+ ACCOUNT_CREATED: m.account_created(),
+ DEVICE_CODE_AUTHORIZATION: m.device_code_authorization(),
+ NEW_DEVICE_CODE_AUTHORIZATION: m.new_device_code_authorization(),
+ PASSKEY_ADDED: m.passkey_added(),
+ PASSKEY_REMOVED: m.passkey_removed(),
}
/**
diff --git a/tests/setup/docker-compose-postgres.yml b/tests/setup/docker-compose-postgres.yml
index 0171a91b..09539b75 100644
--- a/tests/setup/docker-compose-postgres.yml
+++ b/tests/setup/docker-compose-postgres.yml
@@ -20,8 +20,10 @@ services:
file: docker-compose.yml
service: pocket-id
environment:
+ - APP_ENV=test
- DB_PROVIDER=postgres
- DB_CONNECTION_STRING=postgres://postgres:postgres@postgres:5432/pocket-id
+ - FILE_BACKEND=${FILE_BACKEND}
depends_on:
postgres:
condition: service_healthy
diff --git a/tests/setup/docker-compose-s3.yml b/tests/setup/docker-compose-s3.yml
index 7ec008bf..34bcb590 100644
--- a/tests/setup/docker-compose-s3.yml
+++ b/tests/setup/docker-compose-s3.yml
@@ -13,7 +13,7 @@ services:
retries: 10
create-bucket:
- image: amazon/aws-cli
+ image: amazon/aws-cli:latest
environment:
AWS_ACCESS_KEY_ID: test
AWS_SECRET_ACCESS_KEY: test
diff --git a/tests/setup/docker-compose.yml b/tests/setup/docker-compose.yml
index 8f551387..8ac7d80f 100644
--- a/tests/setup/docker-compose.yml
+++ b/tests/setup/docker-compose.yml
@@ -15,6 +15,7 @@ services:
environment:
APP_ENV: test
ENCRYPTION_KEY: test-encryption-key
+ FILE_BACKEND: ${FILE_BACKEND}
build:
args:
- BUILD_TAGS=e2etest