1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 17:24:48 +00:00

feat: add support for S3 storage backend (#1080)

Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Elias Schneider
2025-11-10 10:02:25 +01:00
committed by GitHub
parent d5e0cfd4a6
commit bfd71d090c
28 changed files with 1084 additions and 616 deletions

View File

@@ -2,68 +2,76 @@ package bootstrap
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"path"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/resources"
)
// initApplicationImages copies the images from the images directory to the application-images directory
// initApplicationImages copies the images from the embedded directory to the storage backend
// and returns a map containing the detected file extensions in the application-images directory.
func initApplicationImages() (map[string]string, error) {
func initApplicationImages(ctx context.Context, fileStorage storage.FileStorage) (map[string]string, error) {
// Previous versions of images
// If these are found, they are deleted
legacyImageHashes := imageHashMap{
"background.jpg": mustDecodeHex("138d510030ed845d1d74de34658acabff562d306476454369a60ab8ade31933f"),
}
dirPath := common.EnvConfig.UploadPath + "/application-images"
sourceFiles, err := resources.FS.ReadDir("images")
if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
destinationFiles, err := os.ReadDir(dirPath)
if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("failed to read directory: %w", err)
destinationFiles, err := fileStorage.List(ctx, "application-images")
if err != nil {
if storage.IsNotExist(err) {
destinationFiles = []storage.ObjectInfo{}
} else {
return nil, fmt.Errorf("failed to list application images: %w", err)
}
}
dstNameToExt := make(map[string]string, len(destinationFiles))
for _, f := range destinationFiles {
if f.IsDir() {
continue
}
name := f.Name()
nameWithoutExt, ext := utils.SplitFileName(name)
destFilePath := path.Join(dirPath, name)
// Skip directories
if f.IsDir() {
_, name := path.Split(f.Path)
if name == "" {
continue
}
h, err := utils.CreateSha256FileHash(destFilePath)
nameWithoutExt, ext := utils.SplitFileName(name)
reader, _, err := fileStorage.Open(ctx, f.Path)
if err != nil {
slog.Warn("Failed to get hash for file", slog.String("name", name), slog.Any("error", err))
if errors.Is(err, fs.ErrNotExist) {
continue
}
slog.Warn("Failed to open application image for hashing", slog.String("name", name), slog.Any("error", err))
continue
}
hash, err := hashStream(reader)
reader.Close()
if err != nil {
slog.Warn("Failed to hash application image", slog.String("name", name), slog.Any("error", err))
continue
}
// Check if the file is a legacy one - if so, delete it
if legacyImageHashes.Contains(h) {
if legacyImageHashes.Contains(hash) {
slog.Info("Found legacy application image that will be removed", slog.String("name", name))
err = os.Remove(destFilePath)
if err != nil {
if err := fileStorage.Delete(ctx, f.Path); err != nil {
return nil, fmt.Errorf("failed to remove legacy file '%s': %w", name, err)
}
continue
}
// Track existing files
dstNameToExt[nameWithoutExt] = ext
}
@@ -76,21 +84,21 @@ func initApplicationImages() (map[string]string, error) {
name := sourceFile.Name()
nameWithoutExt, ext := utils.SplitFileName(name)
srcFilePath := path.Join("images", name)
destFilePath := path.Join(dirPath, name)
// Skip if there's already an image at the path
// We do not check the extension because users could have uploaded a different one
if _, exists := dstNameToExt[nameWithoutExt]; exists {
continue
}
slog.Info("Writing new application image", slog.String("name", name))
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
srcFile, err := resources.FS.Open(srcFilePath)
if err != nil {
return nil, fmt.Errorf("failed to copy file: %w", err)
return nil, fmt.Errorf("failed to open embedded file '%s': %w", name, err)
}
// Track the newly copied file so it can be included in the extensions map later
if err := fileStorage.Save(ctx, path.Join("application-images", name), srcFile); err != nil {
srcFile.Close()
return nil, fmt.Errorf("failed to store application image '%s': %w", name, err)
}
srcFile.Close()
dstNameToExt[nameWithoutExt] = ext
}
@@ -118,3 +126,11 @@ func mustDecodeHex(str string) []byte {
}
return b
}
func hashStream(r io.Reader) ([]byte, error) {
h := sha256.New()
if _, err := io.Copy(h, r); err != nil {
return nil, err
}
return h.Sum(nil), nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/job"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
@@ -21,7 +22,31 @@ func Bootstrap(ctx context.Context) error {
}
slog.InfoContext(ctx, "Pocket ID is starting")
imageExtensions, err := initApplicationImages()
// Initialize the file storage backend
var fileStorage storage.FileStorage
switch common.EnvConfig.FileBackend {
case storage.TypeFileSystem:
fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
case storage.TypeS3:
s3Cfg := storage.S3Config{
Bucket: common.EnvConfig.S3Bucket,
Region: common.EnvConfig.S3Region,
Endpoint: common.EnvConfig.S3Endpoint,
AccessKeyID: common.EnvConfig.S3AccessKeyID,
SecretAccessKey: common.EnvConfig.S3SecretAccessKey,
ForcePathStyle: common.EnvConfig.S3ForcePathStyle,
Root: common.EnvConfig.UploadPath,
}
fileStorage, err = storage.NewS3Storage(ctx, s3Cfg)
default:
err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
}
if err != nil {
return fmt.Errorf("failed to initialize file storage: %w", err)
}
imageExtensions, err := initApplicationImages(ctx, fileStorage)
if err != nil {
return fmt.Errorf("failed to initialize application images: %w", err)
}
@@ -33,7 +58,7 @@ func Bootstrap(ctx context.Context) error {
}
// Create all services
svc, err := initServices(ctx, db, httpClient, imageExtensions)
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
if err != nil {
return fmt.Errorf("failed to initialize services: %w", err)
}

View File

@@ -17,7 +17,7 @@ import (
func init() {
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService, svc.fileStorage)
if err != nil {
slog.Error("Failed to initialize test service", slog.Any("error", err))
os.Exit(1)

View File

@@ -23,7 +23,7 @@ func registerScheduledJobs(ctx context.Context, db *gorm.DB, svc *services, http
if err != nil {
return fmt.Errorf("failed to register DB cleanup jobs in scheduler: %w", err)
}
err = scheduler.RegisterFileCleanupJobs(ctx, db)
err = scheduler.RegisterFileCleanupJobs(ctx, db, svc.fileStorage)
if err != nil {
return fmt.Errorf("failed to register file cleanup jobs in scheduler: %w", err)
}

View File

@@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/storage"
)
type services struct {
@@ -25,10 +26,11 @@ type services struct {
ldapService *service.LdapService
apiKeyService *service.ApiKeyService
versionService *service.VersionService
fileStorage storage.FileStorage
}
// Initializes all services
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string) (svc *services, err error) {
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string, fileStorage storage.FileStorage) (svc *services, err error) {
svc = &services{}
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
@@ -36,7 +38,8 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
return nil, fmt.Errorf("failed to create app config service: %w", err)
}
svc.appImagesService = service.NewAppImagesService(imageExtensions)
svc.fileStorage = fileStorage
svc.appImagesService = service.NewAppImagesService(imageExtensions, fileStorage)
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
if err != nil {
@@ -56,13 +59,13 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
}
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient)
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient, fileStorage)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
}
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService)
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)

View File

@@ -29,6 +29,7 @@ const (
DbProviderPostgres DbProvider = "postgres"
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
defaultSqliteConnString string = "data/pocket-id.db"
defaultFsUploadPath string = "data/uploads"
AppUrl string = "http://localhost:1411"
)
@@ -38,7 +39,14 @@ type EnvConfigSchema struct {
AppURL string `env:"APP_URL" options:"toLower,trimTrailingSlash"`
DbProvider DbProvider `env:"DB_PROVIDER" options:"toLower"`
DbConnectionString string `env:"DB_CONNECTION_STRING" options:"file"`
FileBackend string `env:"FILE_BACKEND" options:"toLower"`
UploadPath string `env:"UPLOAD_PATH"`
S3Bucket string `env:"S3_BUCKET"`
S3Region string `env:"S3_REGION"`
S3Endpoint string `env:"S3_ENDPOINT"`
S3AccessKeyID string `env:"S3_ACCESS_KEY_ID"`
S3SecretAccessKey string `env:"S3_SECRET_ACCESS_KEY"`
S3ForcePathStyle bool `env:"S3_FORCE_PATH_STYLE"`
KeysPath string `env:"KEYS_PATH"`
KeysStorage string `env:"KEYS_STORAGE"`
EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"`
@@ -72,30 +80,16 @@ func init() {
func defaultConfig() EnvConfigSchema {
return EnvConfigSchema{
AppEnv: "production",
LogLevel: "info",
DbProvider: "sqlite",
DbConnectionString: "",
UploadPath: "data/uploads",
KeysPath: "data/keys",
KeysStorage: "", // "database" or "file"
EncryptionKey: nil,
AppURL: AppUrl,
Port: "1411",
Host: "0.0.0.0",
UnixSocket: "",
UnixSocketMode: "",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
LocalIPv6Ranges: "",
UiConfigDisabled: false,
MetricsEnabled: false,
TracingEnabled: false,
TrustProxy: false,
AnalyticsDisabled: false,
AllowDowngrade: false,
InternalAppURL: "",
AppEnv: "production",
LogLevel: "info",
DbProvider: "sqlite",
FileBackend: "fs",
KeysPath: "data/keys",
AppURL: AppUrl,
Port: "1411",
Host: "0.0.0.0",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
}
}
@@ -181,6 +175,19 @@ func validateEnvConfig(config *EnvConfigSchema) error {
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
}
switch config.FileBackend {
case "s3":
if config.KeysStorage == "file" {
return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
}
case "", "fs":
if config.UploadPath == "" {
config.UploadPath = defaultFsUploadPath
}
default:
return errors.New("invalid FILE_BACKEND value. Must be 'fs' or 's3'")
}
// Validate LOCAL_IPV6_RANGES
ranges := strings.Split(config.LocalIPv6Ranges, ",")
for _, rangeStr := range ranges {

View File

@@ -208,6 +208,44 @@ func TestParseEnvConfig(t *testing.T) {
assert.Equal(t, "8080", EnvConfig.Port)
assert.Equal(t, "localhost", EnvConfig.Host) // lowercased
})
t.Run("should normalize file backend and default upload path", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("FILE_BACKEND", "FS")
t.Setenv("UPLOAD_PATH", "")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, "fs", EnvConfig.FileBackend)
assert.Equal(t, defaultFsUploadPath, EnvConfig.UploadPath)
})
t.Run("should fail when FILE_BACKEND is s3 but keys are stored on filesystem", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("FILE_BACKEND", "s3")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
})
t.Run("should fail with invalid FILE_BACKEND value", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("FILE_BACKEND", "invalid")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "invalid FILE_BACKEND value")
})
}
func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) {

View File

@@ -116,7 +116,7 @@ func (c *AppImagesController) updateLogoHandler(ctx *gin.Context) {
imageName = "logoDark"
}
if err := c.appImagesService.UpdateImage(file, imageName); err != nil {
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, imageName); err != nil {
_ = ctx.Error(err)
return
}
@@ -139,7 +139,7 @@ func (c *AppImagesController) updateBackgroundImageHandler(ctx *gin.Context) {
return
}
if err := c.appImagesService.UpdateImage(file, "background"); err != nil {
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "background"); err != nil {
_ = ctx.Error(err)
return
}
@@ -168,7 +168,7 @@ func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
return
}
if err := c.appImagesService.UpdateImage(file, "favicon"); err != nil {
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "favicon"); err != nil {
_ = ctx.Error(err)
return
}
@@ -177,15 +177,16 @@ func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
}
func (c *AppImagesController) getImage(ctx *gin.Context, name string) {
imagePath, mimeType, err := c.appImagesService.GetImage(name)
reader, size, mimeType, err := c.appImagesService.GetImage(ctx.Request.Context(), name)
if err != nil {
_ = ctx.Error(err)
return
}
defer reader.Close()
ctx.Header("Content-Type", mimeType)
utils.SetCacheControlHeader(ctx, 15*time.Minute, 24*time.Hour)
ctx.File(imagePath)
ctx.DataFromReader(http.StatusOK, size, mimeType, reader, nil)
}
// updateDefaultProfilePicture godoc
@@ -203,7 +204,7 @@ func (c *AppImagesController) updateDefaultProfilePicture(ctx *gin.Context) {
return
}
if err := c.appImagesService.UpdateImage(file, "default-profile-picture"); err != nil {
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "default-profile-picture"); err != nil {
_ = ctx.Error(err)
return
}
@@ -218,7 +219,7 @@ func (c *AppImagesController) updateDefaultProfilePicture(ctx *gin.Context) {
// @Success 204 "No Content"
// @Router /api/application-images/default-profile-picture [delete]
func (c *AppImagesController) deleteDefaultProfilePicture(ctx *gin.Context) {
if err := c.appImagesService.DeleteImage("default-profile-picture"); err != nil {
if err := c.appImagesService.DeleteImage(ctx.Request.Context(), "default-profile-picture"); err != nil {
_ = ctx.Error(err)
return
}

View File

@@ -548,16 +548,17 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"), lightLogo)
reader, size, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"), lightLogo)
if err != nil {
_ = c.Error(err)
return
}
defer reader.Close()
utils.SetCacheControlHeader(c, 15*time.Minute, 12*time.Hour)
c.Header("Content-Type", mimeType)
c.File(imagePath)
c.DataFromReader(http.StatusOK, size, mimeType, reader, nil)
}
// updateClientLogoHandler godoc

View File

@@ -286,7 +286,7 @@ func (uc *UserController) updateUserProfilePictureHandler(c *gin.Context) {
}
defer file.Close()
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
if err := uc.userService.UpdateProfilePicture(c.Request.Context(), userID, file); err != nil {
_ = c.Error(err)
return
}
@@ -317,7 +317,7 @@ func (uc *UserController) updateCurrentUserProfilePictureHandler(c *gin.Context)
}
defer file.Close()
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
if err := uc.userService.UpdateProfilePicture(c.Request.Context(), userID, file); err != nil {
_ = c.Error(err)
return
}
@@ -687,7 +687,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id")
if err := uc.userService.ResetProfilePicture(userID); err != nil {
if err := uc.userService.ResetProfilePicture(c.Request.Context(), userID); err != nil {
_ = c.Error(err)
return
}
@@ -705,7 +705,7 @@ func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
func (uc *UserController) resetCurrentUserProfilePictureHandler(c *gin.Context) {
userID := c.GetString("userID")
if err := uc.userService.ResetProfilePicture(userID); err != nil {
if err := uc.userService.ResetProfilePicture(c.Request.Context(), userID); err != nil {
_ = c.Error(err)
return
}

View File

@@ -2,29 +2,36 @@ package job
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"path"
"strings"
"time"
"github.com/go-co-op/gocron/v2"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/storage"
)
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) error {
jobs := &FileCleanupJobs{db: db}
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error {
jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage}
// Run every 24 hours
return s.registerJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
err := s.registerJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
// Only necessary for file system storage
if fileStorage.Type() == storage.TypeFileSystem {
err = errors.Join(err, s.registerJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, true))
}
return err
}
type FileCleanupJobs struct {
db *gorm.DB
db *gorm.DB
fileStorage storage.FileStorage
}
// ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials
@@ -44,29 +51,24 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context)
initialsInUse[user.Initials()] = struct{}{}
}
defaultPicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults"
if _, err := os.Stat(defaultPicturesDir); os.IsNotExist(err) {
return nil
}
files, err := os.ReadDir(defaultPicturesDir)
defaultPicturesDir := path.Join("profile-pictures", "defaults")
files, err := j.fileStorage.List(ctx, defaultPicturesDir)
if err != nil {
return fmt.Errorf("failed to read default profile pictures directory: %w", err)
return fmt.Errorf("failed to list default profile pictures: %w", err)
}
filesDeleted := 0
for _, file := range files {
if file.IsDir() {
continue // Skip directories
_, filename := path.Split(file.Path)
if filename == "" {
continue
}
filename := file.Name()
initials := strings.TrimSuffix(filename, ".png")
// If these initials aren't used by any user, delete the file
if _, ok := initialsInUse[initials]; !ok {
filePath := filepath.Join(defaultPicturesDir, filename)
if err := os.Remove(filePath); err != nil {
filePath := path.Join(defaultPicturesDir, filename)
if err := j.fileStorage.Delete(ctx, filePath); err != nil {
slog.ErrorContext(ctx, "Failed to delete unused default profile picture", slog.String("path", filePath), slog.Any("error", err))
} else {
filesDeleted++
@@ -77,3 +79,34 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context)
slog.Info("Done deleting unused default profile pictures", slog.Int("count", filesDeleted))
return nil
}
// clearOrphanedTempFiles deletes temporary files that are produced by failed atomic writes
func (j *FileCleanupJobs) clearOrphanedTempFiles(ctx context.Context) error {
const minAge = 10 * time.Minute
var deleted int
err := j.fileStorage.Walk(ctx, "/", func(p storage.ObjectInfo) error {
// Only temp files
if !strings.HasSuffix(p.Path, "-tmp") {
return nil
}
if time.Since(p.ModTime) < minAge {
return nil
}
if err := j.fileStorage.Delete(ctx, p.Path); err != nil {
slog.ErrorContext(ctx, "Failed to delete temp file", slog.String("path", p.Path), slog.Any("error", err))
return nil
}
deleted++
return nil
})
if err != nil {
return fmt.Errorf("failed to scan storage: %w", err)
}
slog.Info("Done cleaning orphaned temp files", slog.Int("count", deleted))
return nil
}

View File

@@ -1,42 +1,52 @@
package service
import (
"context"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"path"
"strings"
"sync"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
type AppImagesService struct {
mu sync.RWMutex
extensions map[string]string
storage storage.FileStorage
}
func NewAppImagesService(extensions map[string]string) *AppImagesService {
return &AppImagesService{extensions: extensions}
func NewAppImagesService(extensions map[string]string, storage storage.FileStorage) *AppImagesService {
return &AppImagesService{extensions: extensions, storage: storage}
}
func (s *AppImagesService) GetImage(name string) (string, string, error) {
func (s *AppImagesService) GetImage(ctx context.Context, name string) (io.ReadCloser, int64, string, error) {
ext, err := s.getExtension(name)
if err != nil {
return "", "", err
return nil, 0, "", err
}
mimeType := utils.GetImageMimeType(ext)
if mimeType == "" {
return "", "", fmt.Errorf("unsupported image type '%s'", ext)
return nil, 0, "", fmt.Errorf("unsupported image type '%s'", ext)
}
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", name, ext))
return imagePath, mimeType, nil
imagePath := path.Join("application-images", name+"."+ext)
reader, size, err := s.storage.Open(ctx, imagePath)
if err != nil {
if storage.IsNotExist(err) {
return nil, 0, "", &common.ImageNotFoundError{}
}
return nil, 0, "", err
}
return reader, size, mimeType, nil
}
func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName string) error {
func (s *AppImagesService) UpdateImage(ctx context.Context, file *multipart.FileHeader, imageName string) error {
fileType := strings.ToLower(utils.GetFileExtension(file.Filename))
mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" {
@@ -51,15 +61,20 @@ func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName str
s.extensions[imageName] = fileType
}
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", imageName, fileType))
imagePath := path.Join("application-images", imageName+"."+fileType)
fileReader, err := file.Open()
if err != nil {
return err
}
defer fileReader.Close()
if err := utils.SaveFile(file, imagePath); err != nil {
if err := s.storage.Save(ctx, imagePath, fileReader); err != nil {
return err
}
if currentExt != "" && currentExt != fileType {
oldImagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", imageName, currentExt))
if err := os.Remove(oldImagePath); err != nil && !os.IsNotExist(err) {
oldImagePath := path.Join("application-images", imageName+"."+currentExt)
if err := s.storage.Delete(ctx, oldImagePath); err != nil {
return err
}
}
@@ -69,7 +84,7 @@ func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName str
return nil
}
func (s *AppImagesService) DeleteImage(imageName string) error {
func (s *AppImagesService) DeleteImage(ctx context.Context, imageName string) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -78,8 +93,8 @@ func (s *AppImagesService) DeleteImage(imageName string) error {
return &common.ImageNotFoundError{}
}
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", imageName+"."+ext)
if err := os.Remove(imagePath); err != nil && !os.IsNotExist(err) {
imagePath := path.Join("application-images", imageName+"."+ext)
if err := s.storage.Delete(ctx, imagePath); err != nil {
return err
}

View File

@@ -2,66 +2,92 @@ package service
import (
"bytes"
"context"
"io"
"io/fs"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"path"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/storage"
)
func TestAppImagesService_GetImage(t *testing.T) {
tempDir := t.TempDir()
originalUploadPath := common.EnvConfig.UploadPath
common.EnvConfig.UploadPath = tempDir
t.Cleanup(func() {
common.EnvConfig.UploadPath = originalUploadPath
})
imagesDir := filepath.Join(tempDir, "application-images")
require.NoError(t, os.MkdirAll(imagesDir, 0o755))
filePath := filepath.Join(imagesDir, "background.webp")
require.NoError(t, os.WriteFile(filePath, []byte("data"), fs.FileMode(0o644)))
service := NewAppImagesService(map[string]string{"background": "webp"})
path, mimeType, err := service.GetImage("background")
store, err := storage.NewFilesystemStorage(t.TempDir())
require.NoError(t, err)
require.Equal(t, filePath, path)
require.NoError(t, store.Save(context.Background(), path.Join("application-images", "background.webp"), bytes.NewReader([]byte("data"))))
service := NewAppImagesService(map[string]string{"background": "webp"}, store)
reader, size, mimeType, err := service.GetImage(context.Background(), "background")
require.NoError(t, err)
defer reader.Close()
payload, err := io.ReadAll(reader)
require.NoError(t, err)
require.Equal(t, []byte("data"), payload)
require.Equal(t, int64(len(payload)), size)
require.Equal(t, "image/webp", mimeType)
}
func TestAppImagesService_UpdateImage(t *testing.T) {
tempDir := t.TempDir()
originalUploadPath := common.EnvConfig.UploadPath
common.EnvConfig.UploadPath = tempDir
t.Cleanup(func() {
common.EnvConfig.UploadPath = originalUploadPath
})
store, err := storage.NewFilesystemStorage(t.TempDir())
require.NoError(t, err)
imagesDir := filepath.Join(tempDir, "application-images")
require.NoError(t, os.MkdirAll(imagesDir, 0o755))
require.NoError(t, store.Save(context.Background(), path.Join("application-images", "logoLight.svg"), bytes.NewReader([]byte("old"))))
oldPath := filepath.Join(imagesDir, "logoLight.svg")
require.NoError(t, os.WriteFile(oldPath, []byte("old"), fs.FileMode(0o644)))
service := NewAppImagesService(map[string]string{"logoLight": "svg"})
service := NewAppImagesService(map[string]string{"logoLight": "svg"}, store)
fileHeader := newFileHeader(t, "logoLight.png", []byte("new"))
require.NoError(t, service.UpdateImage(fileHeader, "logoLight"))
require.NoError(t, service.UpdateImage(context.Background(), fileHeader, "logoLight"))
_, err := os.Stat(filepath.Join(imagesDir, "logoLight.png"))
reader, _, err := store.Open(context.Background(), path.Join("application-images", "logoLight.png"))
require.NoError(t, err)
_ = reader.Close()
_, _, err = store.Open(context.Background(), path.Join("application-images", "logoLight.svg"))
require.ErrorIs(t, err, fs.ErrNotExist)
}
func TestAppImagesService_ErrorsAndFlags(t *testing.T) {
store, err := storage.NewFilesystemStorage(t.TempDir())
require.NoError(t, err)
_, err = os.Stat(oldPath)
require.ErrorIs(t, err, os.ErrNotExist)
service := NewAppImagesService(map[string]string{}, store)
t.Run("get missing image returns not found", func(t *testing.T) {
_, _, _, err := service.GetImage(context.Background(), "missing")
require.Error(t, err)
var imageErr *common.ImageNotFoundError
assert.ErrorAs(t, err, &imageErr)
})
t.Run("reject unsupported file types", func(t *testing.T) {
err := service.UpdateImage(context.Background(), newFileHeader(t, "logo.txt", []byte("nope")), "logo")
require.Error(t, err)
var fileTypeErr *common.FileTypeNotSupportedError
assert.ErrorAs(t, err, &fileTypeErr)
})
t.Run("delete and extension tracking", func(t *testing.T) {
require.NoError(t, store.Save(context.Background(), path.Join("application-images", "default-profile-picture.png"), bytes.NewReader([]byte("img"))))
service.extensions["default-profile-picture"] = "png"
require.NoError(t, service.DeleteImage(context.Background(), "default-profile-picture"))
assert.False(t, service.IsDefaultProfilePictureSet())
err := service.DeleteImage(context.Background(), "default-profile-picture")
require.Error(t, err)
var imageErr *common.ImageNotFoundError
assert.ErrorAs(t, err, &imageErr)
})
}
func newFileHeader(t *testing.T, filename string, content []byte) *multipart.FileHeader {

View File

@@ -11,8 +11,7 @@ import (
"encoding/base64"
"fmt"
"log/slog"
"os"
"path/filepath"
"path"
"time"
"github.com/fxamacker/cbor/v2"
@@ -25,6 +24,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
"github.com/pocket-id/pocket-id/backend/resources"
@@ -35,15 +35,17 @@ type TestService struct {
jwtService *JwtService
appConfigService *AppConfigService
ldapService *LdapService
fileStorage storage.FileStorage
externalIdPKey jwk.Key
}
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService, fileStorage storage.FileStorage) (*TestService, error) {
s := &TestService{
db: db,
appConfigService: appConfigService,
jwtService: jwtService,
ldapService: ldapService,
fileStorage: fileStorage,
}
err := s.initExternalIdP()
if err != nil {
@@ -424,8 +426,8 @@ func (s *TestService) ResetDatabase() error {
}
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
slog.ErrorContext(ctx, "Error removing directory", slog.Any("error", err))
if err := s.fileStorage.DeleteAll(ctx, "/"); err != nil {
slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err))
return err
}
@@ -435,13 +437,19 @@ func (s *TestService) ResetApplicationImages(ctx context.Context) error {
}
for _, file := range files {
srcFilePath := filepath.Join("images", file.Name())
destFilePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", file.Name())
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
if file.IsDir() {
continue
}
srcFilePath := path.Join("images", file.Name())
srcFile, err := resources.FS.Open(srcFilePath)
if err != nil {
return err
}
if err := s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile); err != nil {
srcFile.Close()
return err
}
srcFile.Close()
}
return nil

View File

@@ -470,7 +470,7 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
}
// Update the profile picture
err = s.userService.UpdateProfilePicture(userId, reader)
err = s.userService.UpdateProfilePicture(parentCtx, userId, reader)
if err != nil {
return fmt.Errorf("failed to update profile picture: %w", err)
}

View File

@@ -15,8 +15,7 @@ import (
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"path"
"regexp"
"slices"
"strings"
@@ -35,6 +34,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
@@ -59,8 +59,9 @@ type OidcService struct {
customClaimService *CustomClaimService
webAuthnService *WebAuthnService
httpClient *http.Client
jwkCache *jwk.Cache
httpClient *http.Client
jwkCache *jwk.Cache
fileStorage storage.FileStorage
}
func NewOidcService(
@@ -72,6 +73,7 @@ func NewOidcService(
customClaimService *CustomClaimService,
webAuthnService *WebAuthnService,
httpClient *http.Client,
fileStorage storage.FileStorage,
) (s *OidcService, err error) {
s = &OidcService{
db: db,
@@ -81,6 +83,7 @@ func NewOidcService(
customClaimService: customClaimService,
webAuthnService: webAuthnService,
httpClient: httpClient,
fileStorage: fileStorage,
}
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
@@ -884,34 +887,41 @@ func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (
return clientSecret, nil
}
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string, light bool) (string, string, error) {
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string, light bool) (io.ReadCloser, int64, string, error) {
var client model.OidcClient
err := s.db.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", err
return nil, 0, "", err
}
var imagePath, mimeType string
var suffix string
var ext string
switch {
case !light && client.DarkImageType != nil:
// Dark logo if requested and exists
imagePath = common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "-dark." + *client.DarkImageType
mimeType = utils.GetImageMimeType(*client.DarkImageType)
suffix = "-dark"
ext = *client.DarkImageType
case client.ImageType != nil:
// Light logo if requested or no dark logo is available
imagePath = common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
mimeType = utils.GetImageMimeType(*client.ImageType)
ext = *client.ImageType
default:
return "", "", errors.New("image not found")
return nil, 0, "", errors.New("image not found")
}
return imagePath, mimeType, nil
mimeType := utils.GetImageMimeType(ext)
if mimeType == "" {
return nil, 0, "", fmt.Errorf("unsupported image type '%s'", ext)
}
key := path.Join("oidc-client-images", client.ID+suffix+"."+ext)
reader, size, err := s.fileStorage.Open(ctx, key)
if err != nil {
return nil, 0, "", err
}
return reader, size, mimeType, nil
}
func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader, light bool) error {
@@ -925,11 +935,15 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
darkSuffix = "-dark"
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + darkSuffix + "." + fileType
err := utils.SaveFile(file, imagePath)
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+fileType)
reader, err := file.Open()
if err != nil {
return err
}
defer reader.Close()
if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil {
return err
}
tx := s.db.Begin()
@@ -972,8 +986,8 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
return err
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + oldImageType
if err := os.Remove(imagePath); err != nil {
imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType)
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
return err
}
@@ -1015,8 +1029,8 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
return err
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "-dark." + oldImageType
if err := os.Remove(imagePath); err != nil {
imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType)
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
return err
}
@@ -2017,20 +2031,13 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
return &common.FileTypeNotSupportedError{}
}
folderPath := filepath.Join(common.EnvConfig.UploadPath, "oidc-client-images")
err = os.MkdirAll(folderPath, os.ModePerm)
if err != nil {
return err
}
var darkSuffix string
if !light {
darkSuffix = "-dark"
}
imagePath := filepath.Join(folderPath, clientID+darkSuffix+"."+ext)
err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath)
if err != nil {
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil {
return err
}
@@ -2042,8 +2049,6 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
}
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error {
uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images"
var darkSuffix string
if !light {
darkSuffix = "-dark"
@@ -2053,9 +2058,15 @@ func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, cli
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if client.ImageType != nil && *client.ImageType != ext {
old := fmt.Sprintf("%s/%s%s.%s", uploadsDir, client.ID, darkSuffix, *client.ImageType)
_ = os.Remove(old)
var currentType *string
if light {
currentType = client.ImageType
} else {
currentType = client.DarkImageType
}
if currentType != nil && *currentType != ext {
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
_ = s.fileStorage.Delete(ctx, old)
}
var column string

View File

@@ -7,10 +7,10 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/url"
"os"
"path/filepath"
"path"
"strings"
"time"
@@ -22,6 +22,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
@@ -35,9 +36,10 @@ type UserService struct {
appConfigService *AppConfigService
customClaimService *CustomClaimService
appImagesService *AppImagesService
fileStorage storage.FileStorage
}
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService) *UserService {
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService, fileStorage storage.FileStorage) *UserService {
return &UserService{
db: db,
jwtService: jwtService,
@@ -46,6 +48,7 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
appConfigService: appConfigService,
customClaimService: customClaimService,
appImagesService: appImagesService,
fileStorage: fileStorage,
}
}
@@ -95,34 +98,32 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
return nil, 0, err
}
profilePicturePath := filepath.Join(common.EnvConfig.UploadPath, "profile-pictures", userID+".png")
profilePicturePath := path.Join("profile-pictures", userID+".png")
// Try custom profile picture
if file, size, err := utils.OpenFileWithSize(profilePicturePath); err == nil {
if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil {
return file, size, nil
} else if !errors.Is(err, os.ErrNotExist) {
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, 0, err
}
// Try default global profile picture
if s.appImagesService.IsDefaultProfilePictureSet() {
path, _, err := s.appImagesService.GetImage("default-profile-picture")
if err != nil {
return nil, 0, err
reader, size, _, err := s.appImagesService.GetImage(ctx, "default-profile-picture")
if err == nil {
return reader, size, nil
}
if file, size, err := utils.OpenFileWithSize(path); err == nil {
return file, size, nil
} else if !errors.Is(err, os.ErrNotExist) {
if !errors.Is(err, &common.ImageNotFoundError{}) {
return nil, 0, err
}
}
// Try cached default for initials
defaultProfilePicturesDir := filepath.Join(common.EnvConfig.UploadPath, "profile-pictures", "defaults")
defaultPicturePath := filepath.Join(defaultProfilePicturesDir, user.Initials()+".png")
if file, size, err := utils.OpenFileWithSize(defaultPicturePath); err == nil {
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil {
return file, size, nil
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, 0, err
}
// Create and return generated default with initials
@@ -132,13 +133,11 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
}
// Save the default picture for future use (in a goroutine to avoid blocking)
//nolint:contextcheck
defaultPictureBytes := defaultPicture.Bytes()
//nolint:contextcheck
go func() {
if err := os.MkdirAll(defaultProfilePicturesDir, os.ModePerm); err != nil {
slog.Error("Failed to create directory for default profile picture", slog.Any("error", err))
return
}
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil {
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err))
}
}()
@@ -160,7 +159,7 @@ func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model
return user.UserGroups, nil
}
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, file io.Reader) error {
// Validate the user ID to prevent directory traversal
err := uuid.Validate(userID)
if err != nil {
@@ -173,15 +172,8 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
return err
}
// Ensure the directory exists
profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures"
err = os.MkdirAll(profilePictureDir, os.ModePerm)
if err != nil {
return err
}
// Create the profile picture file
err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png")
profilePicturePath := path.Join("profile-pictures", userID+".png")
err = s.fileStorage.Save(ctx, profilePicturePath, profilePicture)
if err != nil {
return err
}
@@ -212,10 +204,8 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
return &common.LdapUserUpdateError{}
}
// Delete the profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
err = os.Remove(profilePicturePath)
if err != nil && !os.IsNotExist(err) {
profilePicturePath := path.Join("profile-pictures", userID+".png")
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
return err
}
@@ -676,26 +666,16 @@ func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User
}
// ResetProfilePicture deletes a user's custom profile picture
func (s *UserService) ResetProfilePicture(userID string) error {
func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) error {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
return &common.InvalidUUIDError{}
}
// Build path to profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
// Check if file exists and delete it
if _, err := os.Stat(profilePicturePath); err == nil {
if err := os.Remove(profilePicturePath); err != nil {
return fmt.Errorf("failed to delete profile picture: %w", err)
}
} else if !os.IsNotExist(err) {
// If any error other than "file not exists"
return fmt.Errorf("failed to check if profile picture exists: %w", err)
profilePicturePath := path.Join("profile-pictures", userID+".png")
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
return fmt.Errorf("failed to delete profile picture: %w", err)
}
// It's okay if the file doesn't exist - just means there's no custom picture to delete
return nil
}

View File

@@ -0,0 +1,193 @@
package storage
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/google/uuid"
)
type filesystemStorage struct {
root *os.Root
absoluteRootPath string
}
func NewFilesystemStorage(rootPath string) (FileStorage, error) {
if err := os.MkdirAll(rootPath, 0700); err != nil {
return nil, fmt.Errorf("failed to create root directory '%s': %w", rootPath, err)
}
root, err := os.OpenRoot(rootPath)
if err != nil {
return nil, fmt.Errorf("failed to open root directory '%s': %w", rootPath, err)
}
absoluteRootPath, err := filepath.Abs(rootPath)
if err != nil {
return nil, fmt.Errorf("failed to get absolute path of root directory '%s': %w", rootPath, err)
}
return &filesystemStorage{root: root, absoluteRootPath: absoluteRootPath}, err
}
func (s *filesystemStorage) Type() string {
return TypeFileSystem
}
func (s *filesystemStorage) Save(_ context.Context, path string, data io.Reader) error {
path = filepath.FromSlash(path)
if err := s.root.MkdirAll(filepath.Dir(path), 0700); err != nil {
return fmt.Errorf("failed to create directories for path '%s': %w", path, err)
}
// Our strategy is to save to a separate file and then rename it to override the original file
tmpName := path + "." + uuid.NewString() + "-tmp"
// Write to the temporary file
tmpFile, err := s.root.Create(tmpName)
if err != nil {
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpName, err)
}
_, err = io.Copy(tmpFile, data)
if err != nil {
tmpFile.Close()
_ = s.root.Remove(tmpName)
return fmt.Errorf("failed to write temporary file: %w", err)
}
if err = tmpFile.Close(); err != nil {
_ = s.root.Remove(tmpName)
return fmt.Errorf("failed to close temporary file: %w", err)
}
// Rename to the final file, which overrides existing files
// This is an atomic operation
if err = s.root.Rename(tmpName, path); err != nil {
_ = s.root.Remove(tmpName)
return fmt.Errorf("failed to move temporary file: %w", err)
}
return nil
}
func (s *filesystemStorage) Open(_ context.Context, path string) (io.ReadCloser, int64, error) {
path = filepath.FromSlash(path)
file, err := s.root.Open(path)
if err != nil {
return nil, 0, err
}
info, err := file.Stat()
if err != nil {
file.Close()
return nil, 0, err
}
return file, info.Size(), nil
}
func (s *filesystemStorage) Delete(_ context.Context, path string) error {
path = filepath.FromSlash(path)
err := s.root.Remove(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
}
return nil
}
func (s *filesystemStorage) DeleteAll(_ context.Context, path string) error {
path = filepath.FromSlash(path)
// If "/", "." or "" is requested, we delete all contents of the root.
if path == "" || path == "/" || path == "." {
dir, err := s.root.Open(".")
if err != nil {
return fmt.Errorf("failed to open root directory: %w", err)
}
defer dir.Close()
entries, err := dir.ReadDir(-1)
if err != nil {
return fmt.Errorf("failed to list root directory: %w", err)
}
for _, entry := range entries {
if err := s.root.RemoveAll(entry.Name()); err != nil {
return fmt.Errorf("failed to delete '%s': %w", entry.Name(), err)
}
}
return nil
}
return s.root.RemoveAll(path)
}
func (s *filesystemStorage) List(_ context.Context, path string) ([]ObjectInfo, error) {
path = filepath.FromSlash(path)
dir, err := s.root.Open(path)
if err != nil {
return nil, err
}
defer dir.Close()
entries, err := dir.ReadDir(-1)
if err != nil {
return nil, err
}
objects := make([]ObjectInfo, 0, len(entries))
for _, entry := range entries {
if entry.IsDir() {
continue
}
info, err := entry.Info()
if err != nil {
return nil, err
}
objects = append(objects, ObjectInfo{
Path: filepath.Join(path, entry.Name()),
Size: info.Size(),
ModTime: info.ModTime(),
})
}
return objects, nil
}
func (s *filesystemStorage) Walk(_ context.Context, root string, fn func(ObjectInfo) error) error {
root = filepath.FromSlash(root)
fullPath := filepath.Clean(filepath.Join(s.absoluteRootPath, root))
// As we can't use os.Root here, we manually ensure that the fullPath is within the root directory
sep := string(filepath.Separator)
if !strings.HasPrefix(fullPath+sep, s.absoluteRootPath+sep) {
return fmt.Errorf("invalid root path: %s", root)
}
return filepath.WalkDir(fullPath, func(full string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
rel, err := filepath.Rel(s.absoluteRootPath, full)
if err != nil {
return err
}
info, err := d.Info()
if err != nil {
return err
}
return fn(ObjectInfo{
Path: filepath.ToSlash(rel),
Size: info.Size(),
ModTime: info.ModTime(),
})
})
}

View File

@@ -0,0 +1,68 @@
package storage
import (
"bytes"
"context"
"io"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFilesystemStorageOperations(t *testing.T) {
ctx := context.Background()
store, err := NewFilesystemStorage(t.TempDir())
require.NoError(t, err)
t.Run("save, open and list files", func(t *testing.T) {
err := store.Save(ctx, "images/logo.png", bytes.NewBufferString("logo-data"))
require.NoError(t, err)
reader, size, err := store.Open(ctx, "images/logo.png")
require.NoError(t, err)
defer reader.Close()
contents, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, []byte("logo-data"), contents)
assert.Equal(t, int64(len(contents)), size)
err = store.Save(ctx, "images/nested/child.txt", bytes.NewBufferString("child"))
require.NoError(t, err)
files, err := store.List(ctx, "images")
require.NoError(t, err)
require.Len(t, files, 1)
assert.Equal(t, filepath.Join("images", "logo.png"), files[0].Path)
assert.Equal(t, int64(len("logo-data")), files[0].Size)
})
t.Run("delete files individually and idempotently", func(t *testing.T) {
err := store.Save(ctx, "images/delete-me.txt", bytes.NewBufferString("temp"))
require.NoError(t, err)
require.NoError(t, store.Delete(ctx, "images/delete-me.txt"))
_, _, err = store.Open(ctx, "images/delete-me.txt")
require.Error(t, err)
assert.True(t, IsNotExist(err))
// Deleting a missing object should be a no-op.
require.NoError(t, store.Delete(ctx, "images/missing.txt"))
})
t.Run("delete all files under a prefix", func(t *testing.T) {
require.NoError(t, store.Save(ctx, "images/a.txt", bytes.NewBufferString("a")))
require.NoError(t, store.Save(ctx, "images/b.txt", bytes.NewBufferString("b")))
require.NoError(t, store.DeleteAll(ctx, "images"))
_, _, err := store.Open(ctx, "images/a.txt")
require.Error(t, err)
assert.True(t, IsNotExist(err))
_, _, err = store.Open(ctx, "images/b.txt")
require.Error(t, err)
assert.True(t, IsNotExist(err))
})
}

View File

@@ -0,0 +1,185 @@
package storage
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
awscfg "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"
)
type S3Config struct {
Bucket string
Region string
Endpoint string
AccessKeyID string
SecretAccessKey string
ForcePathStyle bool
Root string
}
type s3Storage struct {
client *s3.Client
bucket string
prefix string
}
func NewS3Storage(ctx context.Context, cfg S3Config) (FileStorage, error) {
creds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
awsCfg, err := awscfg.LoadDefaultConfig(ctx, awscfg.WithRegion(cfg.Region), awscfg.WithCredentialsProvider(creds))
if err != nil {
return nil, fmt.Errorf("failed to load AWS configuration: %w", err)
}
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.BaseEndpoint = aws.String(cfg.Endpoint)
}
o.UsePathStyle = cfg.ForcePathStyle
})
return &s3Storage{
client: client,
bucket: cfg.Bucket,
prefix: strings.Trim(cfg.Root, "/"),
}, nil
}
func (s *s3Storage) Type() string {
return TypeS3
}
func (s *s3Storage) Save(ctx context.Context, path string, data io.Reader) error {
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.buildObjectKey(path)),
Body: data,
})
return err
}
func (s *s3Storage) Open(ctx context.Context, path string) (io.ReadCloser, int64, error) {
resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.buildObjectKey(path)),
})
if err != nil {
if isS3NotFound(err) {
return nil, 0, fs.ErrNotExist
}
return nil, 0, err
}
return resp.Body, aws.ToInt64(resp.ContentLength), nil
}
func (s *s3Storage) Delete(ctx context.Context, path string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.buildObjectKey(path)),
})
return err
}
func (s *s3Storage) DeleteAll(ctx context.Context, path string) error {
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(s.buildObjectKey(path)),
})
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return err
}
if len(page.Contents) == 0 {
continue
}
objects := make([]s3types.ObjectIdentifier, 0, len(page.Contents))
for _, obj := range page.Contents {
objects = append(objects, s3types.ObjectIdentifier{Key: obj.Key})
}
_, err = s.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
Bucket: aws.String(s.bucket),
Delete: &s3types.Delete{Objects: objects, Quiet: aws.Bool(true)},
})
if err != nil {
return err
}
}
return nil
}
func (s *s3Storage) List(ctx context.Context, path string) ([]ObjectInfo, error) {
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(s.buildObjectKey(path)),
})
var objects []ObjectInfo
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, err
}
for _, obj := range page.Contents {
if obj.Key == nil {
continue
}
objects = append(objects, ObjectInfo{
Path: aws.ToString(obj.Key),
Size: aws.ToInt64(obj.Size),
ModTime: aws.ToTime(obj.LastModified),
})
}
}
return objects, nil
}
func (s *s3Storage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
objects, err := s.List(ctx, root)
if err != nil {
return err
}
for _, obj := range objects {
if err := fn(obj); err != nil {
return err
}
}
return nil
}
func (s *s3Storage) buildObjectKey(p string) string {
p = filepath.Clean(p)
p = filepath.ToSlash(p)
p = strings.Trim(p, "/")
if p == "" || p == "." {
return s.prefix
}
if s.prefix == "" {
return p
}
return s.prefix + "/" + p
}
func isS3NotFound(err error) bool {
var apiErr smithy.APIError
if errors.As(err, &apiErr) {
if apiErr.ErrorCode() == "NotFound" || apiErr.ErrorCode() == "NoSuchKey" {
return true
}
}
var missingKey *s3types.NoSuchKey
return errors.As(err, &missingKey)
}

View File

@@ -0,0 +1,44 @@
package storage
import (
"errors"
"testing"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"
"github.com/stretchr/testify/assert"
)
func TestS3Helpers(t *testing.T) {
t.Run("buildObjectKey trims and joins prefix", func(t *testing.T) {
tests := []struct {
name string
prefix string
path string
expected string
}{
{name: "no prefix no path", prefix: "", path: "", expected: ""},
{name: "prefix no path", prefix: "root", path: "", expected: "root"},
{name: "prefix with nested path", prefix: "root", path: "foo/bar/baz", expected: "root/foo/bar/baz"},
{name: "trimmed path and prefix", prefix: "root", path: "/foo//bar/", expected: "root/foo/bar"},
{name: "no prefix path only", prefix: "", path: "./images/logo.png", expected: "images/logo.png"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s := &s3Storage{
bucket: "bucket",
prefix: tc.prefix,
}
assert.Equal(t, tc.expected, s.buildObjectKey(tc.path))
})
}
})
t.Run("isS3NotFound detects expected errors", func(t *testing.T) {
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NoSuchKey"}))
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NotFound"}))
assert.True(t, isS3NotFound(&s3types.NoSuchKey{}))
assert.False(t, isS3NotFound(errors.New("boom")))
})
}

View File

@@ -0,0 +1,33 @@
package storage
import (
"context"
"io"
"os"
"time"
)
var (
TypeFileSystem = "fs"
TypeS3 = "s3"
)
type ObjectInfo struct {
Path string
Size int64
ModTime time.Time
}
type FileStorage interface {
Save(ctx context.Context, relativePath string, data io.Reader) error
Open(ctx context.Context, relativePath string) (io.ReadCloser, int64, error)
Delete(ctx context.Context, relativePath string) error
DeleteAll(ctx context.Context, prefix string) error
List(ctx context.Context, prefix string) ([]ObjectInfo, error)
Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error
Type() string
}
func IsNotExist(err error) bool {
return os.IsNotExist(err)
}

View File

@@ -2,20 +2,15 @@ package utils
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"os"
"path/filepath"
"strings"
"syscall"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/resources"
)
func GetFileExtension(filename string) string {
@@ -86,110 +81,6 @@ func GetImageExtensionFromMimeType(mimeType string) string {
}
}
func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error {
srcFile, err := resources.FS.Open(srcFilePath)
if err != nil {
return fmt.Errorf("failed to open embedded file: %w", err)
}
defer srcFile.Close()
err = os.MkdirAll(filepath.Dir(destFilePath), os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
destFile, err := os.Create(destFilePath)
if err != nil {
return fmt.Errorf("failed to open destination file: %w", err)
}
defer destFile.Close()
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("failed to write to destination file: %w", err)
}
return nil
}
func EmbeddedFileSha256(filePath string) ([]byte, error) {
f, err := resources.FS.Open(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open embedded file: %w", err)
}
defer f.Close()
h := sha256.New()
_, err = io.Copy(h, f)
if err != nil {
return nil, fmt.Errorf("failed to read embedded file: %w", err)
}
return h.Sum(nil), nil
}
func SaveFile(file *multipart.FileHeader, dst string) error {
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
if err = os.MkdirAll(filepath.Dir(dst), 0o750); err != nil {
return err
}
return SaveFileStream(src, dst)
}
// SaveFileStream saves a stream to a file.
func SaveFileStream(r io.Reader, dstFileName string) error {
// Our strategy is to save to a separate file and then rename it to override the original file
tmpFileName := dstFileName + "." + uuid.NewString() + "-tmp"
// Write to the temporary file
tmpFile, err := os.Create(tmpFileName)
if err != nil {
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err)
}
n, err := io.Copy(tmpFile, r)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = tmpFile.Close()
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err)
}
err = tmpFile.Close()
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err)
}
if n == 0 {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return errors.New("no data written")
}
// Rename to the final file, which overrides existing files
// This is an atomic operation
err = os.Rename(tmpFileName, dstFileName)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err)
}
return nil
}
// FileExists returns true if a file exists on disk and is a regular file
func FileExists(path string) (bool, error) {
s, err := os.Stat(path)
@@ -239,17 +130,3 @@ func IsWritableDir(dir string) (bool, error) {
return true, nil
}
// OpenFileWithSize opens a file and returns its size
func OpenFileWithSize(path string) (io.ReadCloser, int64, error) {
f, err := os.Open(path)
if err != nil {
return nil, 0, err
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, 0, err
}
return f, info.Size(), nil
}

View File

@@ -19,7 +19,7 @@ import (
const profilePictureSize = 300
// CreateProfilePicture resizes the profile picture to a square
func CreateProfilePicture(file io.Reader) (io.Reader, error) {
func CreateProfilePicture(file io.Reader) (io.ReadSeeker, error) {
img, _, err := imageorient.Decode(file)
if err != nil {
return nil, fmt.Errorf("failed to decode image: %w", err)
@@ -27,17 +27,13 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos)
pr, pw := io.Pipe()
go func() {
innerErr := imaging.Encode(pw, img, imaging.PNG)
if innerErr != nil {
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", innerErr))
return
}
pw.Close()
}()
var buf bytes.Buffer
err = imaging.Encode(&buf, img, imaging.PNG)
if err != nil {
return nil, fmt.Errorf("failed to encode image: %w", err)
}
return pr, nil
return bytes.NewReader(buf.Bytes()), nil
}
// CreateDefaultProfilePicture creates a profile picture with the initials