mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 17:24:48 +00:00
feat: add support for S3 storage backend (#1080)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
@@ -2,68 +2,76 @@ package bootstrap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
// initApplicationImages copies the images from the images directory to the application-images directory
|
||||
// initApplicationImages copies the images from the embedded directory to the storage backend
|
||||
// and returns a map containing the detected file extensions in the application-images directory.
|
||||
func initApplicationImages() (map[string]string, error) {
|
||||
func initApplicationImages(ctx context.Context, fileStorage storage.FileStorage) (map[string]string, error) {
|
||||
// Previous versions of images
|
||||
// If these are found, they are deleted
|
||||
legacyImageHashes := imageHashMap{
|
||||
"background.jpg": mustDecodeHex("138d510030ed845d1d74de34658acabff562d306476454369a60ab8ade31933f"),
|
||||
}
|
||||
|
||||
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
||||
|
||||
sourceFiles, err := resources.FS.ReadDir("images")
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
destinationFiles, err := os.ReadDir(dirPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
destinationFiles, err := fileStorage.List(ctx, "application-images")
|
||||
if err != nil {
|
||||
if storage.IsNotExist(err) {
|
||||
destinationFiles = []storage.ObjectInfo{}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to list application images: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
dstNameToExt := make(map[string]string, len(destinationFiles))
|
||||
for _, f := range destinationFiles {
|
||||
if f.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := f.Name()
|
||||
nameWithoutExt, ext := utils.SplitFileName(name)
|
||||
destFilePath := path.Join(dirPath, name)
|
||||
|
||||
// Skip directories
|
||||
if f.IsDir() {
|
||||
_, name := path.Split(f.Path)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
h, err := utils.CreateSha256FileHash(destFilePath)
|
||||
nameWithoutExt, ext := utils.SplitFileName(name)
|
||||
reader, _, err := fileStorage.Open(ctx, f.Path)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to get hash for file", slog.String("name", name), slog.Any("error", err))
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("Failed to open application image for hashing", slog.String("name", name), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
hash, err := hashStream(reader)
|
||||
reader.Close()
|
||||
if err != nil {
|
||||
slog.Warn("Failed to hash application image", slog.String("name", name), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the file is a legacy one - if so, delete it
|
||||
if legacyImageHashes.Contains(h) {
|
||||
if legacyImageHashes.Contains(hash) {
|
||||
slog.Info("Found legacy application image that will be removed", slog.String("name", name))
|
||||
err = os.Remove(destFilePath)
|
||||
if err != nil {
|
||||
if err := fileStorage.Delete(ctx, f.Path); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove legacy file '%s': %w", name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Track existing files
|
||||
dstNameToExt[nameWithoutExt] = ext
|
||||
}
|
||||
|
||||
@@ -76,21 +84,21 @@ func initApplicationImages() (map[string]string, error) {
|
||||
name := sourceFile.Name()
|
||||
nameWithoutExt, ext := utils.SplitFileName(name)
|
||||
srcFilePath := path.Join("images", name)
|
||||
destFilePath := path.Join(dirPath, name)
|
||||
|
||||
// Skip if there's already an image at the path
|
||||
// We do not check the extension because users could have uploaded a different one
|
||||
if _, exists := dstNameToExt[nameWithoutExt]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Writing new application image", slog.String("name", name))
|
||||
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
|
||||
srcFile, err := resources.FS.Open(srcFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to copy file: %w", err)
|
||||
return nil, fmt.Errorf("failed to open embedded file '%s': %w", name, err)
|
||||
}
|
||||
|
||||
// Track the newly copied file so it can be included in the extensions map later
|
||||
if err := fileStorage.Save(ctx, path.Join("application-images", name), srcFile); err != nil {
|
||||
srcFile.Close()
|
||||
return nil, fmt.Errorf("failed to store application image '%s': %w", name, err)
|
||||
}
|
||||
srcFile.Close()
|
||||
dstNameToExt[nameWithoutExt] = ext
|
||||
}
|
||||
|
||||
@@ -118,3 +126,11 @@ func mustDecodeHex(str string) []byte {
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func hashStream(r io.Reader) ([]byte, error) {
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/job"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,31 @@ func Bootstrap(ctx context.Context) error {
|
||||
}
|
||||
slog.InfoContext(ctx, "Pocket ID is starting")
|
||||
|
||||
imageExtensions, err := initApplicationImages()
|
||||
// Initialize the file storage backend
|
||||
var fileStorage storage.FileStorage
|
||||
|
||||
switch common.EnvConfig.FileBackend {
|
||||
case storage.TypeFileSystem:
|
||||
fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
|
||||
case storage.TypeS3:
|
||||
s3Cfg := storage.S3Config{
|
||||
Bucket: common.EnvConfig.S3Bucket,
|
||||
Region: common.EnvConfig.S3Region,
|
||||
Endpoint: common.EnvConfig.S3Endpoint,
|
||||
AccessKeyID: common.EnvConfig.S3AccessKeyID,
|
||||
SecretAccessKey: common.EnvConfig.S3SecretAccessKey,
|
||||
ForcePathStyle: common.EnvConfig.S3ForcePathStyle,
|
||||
Root: common.EnvConfig.UploadPath,
|
||||
}
|
||||
fileStorage, err = storage.NewS3Storage(ctx, s3Cfg)
|
||||
default:
|
||||
err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize file storage: %w", err)
|
||||
}
|
||||
|
||||
imageExtensions, err := initApplicationImages(ctx, fileStorage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize application images: %w", err)
|
||||
}
|
||||
@@ -33,7 +58,7 @@ func Bootstrap(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create all services
|
||||
svc, err := initServices(ctx, db, httpClient, imageExtensions)
|
||||
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize services: %w", err)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func init() {
|
||||
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
|
||||
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
|
||||
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
||||
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService, svc.fileStorage)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize test service", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
|
||||
@@ -23,7 +23,7 @@ func registerScheduledJobs(ctx context.Context, db *gorm.DB, svc *services, http
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register DB cleanup jobs in scheduler: %w", err)
|
||||
}
|
||||
err = scheduler.RegisterFileCleanupJobs(ctx, db)
|
||||
err = scheduler.RegisterFileCleanupJobs(ctx, db, svc.fileStorage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register file cleanup jobs in scheduler: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
)
|
||||
|
||||
type services struct {
|
||||
@@ -25,10 +26,11 @@ type services struct {
|
||||
ldapService *service.LdapService
|
||||
apiKeyService *service.ApiKeyService
|
||||
versionService *service.VersionService
|
||||
fileStorage storage.FileStorage
|
||||
}
|
||||
|
||||
// Initializes all services
|
||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string) (svc *services, err error) {
|
||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string, fileStorage storage.FileStorage) (svc *services, err error) {
|
||||
svc = &services{}
|
||||
|
||||
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
|
||||
@@ -36,7 +38,8 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
||||
return nil, fmt.Errorf("failed to create app config service: %w", err)
|
||||
}
|
||||
|
||||
svc.appImagesService = service.NewAppImagesService(imageExtensions)
|
||||
svc.fileStorage = fileStorage
|
||||
svc.appImagesService = service.NewAppImagesService(imageExtensions, fileStorage)
|
||||
|
||||
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
||||
if err != nil {
|
||||
@@ -56,13 +59,13 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
||||
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
||||
}
|
||||
|
||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient)
|
||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient, fileStorage)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
||||
}
|
||||
|
||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService)
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
|
||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ const (
|
||||
DbProviderPostgres DbProvider = "postgres"
|
||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||
defaultSqliteConnString string = "data/pocket-id.db"
|
||||
defaultFsUploadPath string = "data/uploads"
|
||||
AppUrl string = "http://localhost:1411"
|
||||
)
|
||||
|
||||
@@ -38,7 +39,14 @@ type EnvConfigSchema struct {
|
||||
AppURL string `env:"APP_URL" options:"toLower,trimTrailingSlash"`
|
||||
DbProvider DbProvider `env:"DB_PROVIDER" options:"toLower"`
|
||||
DbConnectionString string `env:"DB_CONNECTION_STRING" options:"file"`
|
||||
FileBackend string `env:"FILE_BACKEND" options:"toLower"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
S3Bucket string `env:"S3_BUCKET"`
|
||||
S3Region string `env:"S3_REGION"`
|
||||
S3Endpoint string `env:"S3_ENDPOINT"`
|
||||
S3AccessKeyID string `env:"S3_ACCESS_KEY_ID"`
|
||||
S3SecretAccessKey string `env:"S3_SECRET_ACCESS_KEY"`
|
||||
S3ForcePathStyle bool `env:"S3_FORCE_PATH_STYLE"`
|
||||
KeysPath string `env:"KEYS_PATH"`
|
||||
KeysStorage string `env:"KEYS_STORAGE"`
|
||||
EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"`
|
||||
@@ -72,30 +80,16 @@ func init() {
|
||||
|
||||
func defaultConfig() EnvConfigSchema {
|
||||
return EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
LogLevel: "info",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
KeysStorage: "", // "database" or "file"
|
||||
EncryptionKey: nil,
|
||||
AppURL: AppUrl,
|
||||
Port: "1411",
|
||||
Host: "0.0.0.0",
|
||||
UnixSocket: "",
|
||||
UnixSocketMode: "",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
LocalIPv6Ranges: "",
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
TrustProxy: false,
|
||||
AnalyticsDisabled: false,
|
||||
AllowDowngrade: false,
|
||||
InternalAppURL: "",
|
||||
AppEnv: "production",
|
||||
LogLevel: "info",
|
||||
DbProvider: "sqlite",
|
||||
FileBackend: "fs",
|
||||
KeysPath: "data/keys",
|
||||
AppURL: AppUrl,
|
||||
Port: "1411",
|
||||
Host: "0.0.0.0",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,6 +175,19 @@ func validateEnvConfig(config *EnvConfigSchema) error {
|
||||
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
|
||||
}
|
||||
|
||||
switch config.FileBackend {
|
||||
case "s3":
|
||||
if config.KeysStorage == "file" {
|
||||
return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
|
||||
}
|
||||
case "", "fs":
|
||||
if config.UploadPath == "" {
|
||||
config.UploadPath = defaultFsUploadPath
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid FILE_BACKEND value. Must be 'fs' or 's3'")
|
||||
}
|
||||
|
||||
// Validate LOCAL_IPV6_RANGES
|
||||
ranges := strings.Split(config.LocalIPv6Ranges, ",")
|
||||
for _, rangeStr := range ranges {
|
||||
|
||||
@@ -208,6 +208,44 @@ func TestParseEnvConfig(t *testing.T) {
|
||||
assert.Equal(t, "8080", EnvConfig.Port)
|
||||
assert.Equal(t, "localhost", EnvConfig.Host) // lowercased
|
||||
})
|
||||
|
||||
t.Run("should normalize file backend and default upload path", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("FILE_BACKEND", "FS")
|
||||
t.Setenv("UPLOAD_PATH", "")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "fs", EnvConfig.FileBackend)
|
||||
assert.Equal(t, defaultFsUploadPath, EnvConfig.UploadPath)
|
||||
})
|
||||
|
||||
t.Run("should fail when FILE_BACKEND is s3 but keys are stored on filesystem", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("FILE_BACKEND", "s3")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid FILE_BACKEND value", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("FILE_BACKEND", "invalid")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid FILE_BACKEND value")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) {
|
||||
|
||||
@@ -116,7 +116,7 @@ func (c *AppImagesController) updateLogoHandler(ctx *gin.Context) {
|
||||
imageName = "logoDark"
|
||||
}
|
||||
|
||||
if err := c.appImagesService.UpdateImage(file, imageName); err != nil {
|
||||
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, imageName); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func (c *AppImagesController) updateBackgroundImageHandler(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.appImagesService.UpdateImage(file, "background"); err != nil {
|
||||
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "background"); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.appImagesService.UpdateImage(file, "favicon"); err != nil {
|
||||
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "favicon"); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -177,15 +177,16 @@ func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
func (c *AppImagesController) getImage(ctx *gin.Context, name string) {
|
||||
imagePath, mimeType, err := c.appImagesService.GetImage(name)
|
||||
reader, size, mimeType, err := c.appImagesService.GetImage(ctx.Request.Context(), name)
|
||||
if err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
ctx.Header("Content-Type", mimeType)
|
||||
utils.SetCacheControlHeader(ctx, 15*time.Minute, 24*time.Hour)
|
||||
ctx.File(imagePath)
|
||||
ctx.DataFromReader(http.StatusOK, size, mimeType, reader, nil)
|
||||
}
|
||||
|
||||
// updateDefaultProfilePicture godoc
|
||||
@@ -203,7 +204,7 @@ func (c *AppImagesController) updateDefaultProfilePicture(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.appImagesService.UpdateImage(file, "default-profile-picture"); err != nil {
|
||||
if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "default-profile-picture"); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -218,7 +219,7 @@ func (c *AppImagesController) updateDefaultProfilePicture(ctx *gin.Context) {
|
||||
// @Success 204 "No Content"
|
||||
// @Router /api/application-images/default-profile-picture [delete]
|
||||
func (c *AppImagesController) deleteDefaultProfilePicture(ctx *gin.Context) {
|
||||
if err := c.appImagesService.DeleteImage("default-profile-picture"); err != nil {
|
||||
if err := c.appImagesService.DeleteImage(ctx.Request.Context(), "default-profile-picture"); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -548,16 +548,17 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
|
||||
|
||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"), lightLogo)
|
||||
reader, size, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"), lightLogo)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
utils.SetCacheControlHeader(c, 15*time.Minute, 12*time.Hour)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
c.File(imagePath)
|
||||
c.DataFromReader(http.StatusOK, size, mimeType, reader, nil)
|
||||
}
|
||||
|
||||
// updateClientLogoHandler godoc
|
||||
|
||||
@@ -286,7 +286,7 @@ func (uc *UserController) updateUserProfilePictureHandler(c *gin.Context) {
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
|
||||
if err := uc.userService.UpdateProfilePicture(c.Request.Context(), userID, file); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -317,7 +317,7 @@ func (uc *UserController) updateCurrentUserProfilePictureHandler(c *gin.Context)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
|
||||
if err := uc.userService.UpdateProfilePicture(c.Request.Context(), userID, file); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -687,7 +687,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||
func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
|
||||
if err := uc.userService.ResetProfilePicture(userID); err != nil {
|
||||
if err := uc.userService.ResetProfilePicture(c.Request.Context(), userID); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -705,7 +705,7 @@ func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
|
||||
func (uc *UserController) resetCurrentUserProfilePictureHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
|
||||
if err := uc.userService.ResetProfilePicture(userID); err != nil {
|
||||
if err := uc.userService.ResetProfilePicture(c.Request.Context(), userID); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,29 +2,36 @@ package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
)
|
||||
|
||||
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) error {
|
||||
jobs := &FileCleanupJobs{db: db}
|
||||
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error {
|
||||
jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage}
|
||||
|
||||
// Run every 24 hours
|
||||
return s.registerJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
|
||||
err := s.registerJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
|
||||
|
||||
// Only necessary for file system storage
|
||||
if fileStorage.Type() == storage.TypeFileSystem {
|
||||
err = errors.Join(err, s.registerJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, true))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type FileCleanupJobs struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
fileStorage storage.FileStorage
|
||||
}
|
||||
|
||||
// ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials
|
||||
@@ -44,29 +51,24 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context)
|
||||
initialsInUse[user.Initials()] = struct{}{}
|
||||
}
|
||||
|
||||
defaultPicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults"
|
||||
if _, err := os.Stat(defaultPicturesDir); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(defaultPicturesDir)
|
||||
defaultPicturesDir := path.Join("profile-pictures", "defaults")
|
||||
files, err := j.fileStorage.List(ctx, defaultPicturesDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read default profile pictures directory: %w", err)
|
||||
return fmt.Errorf("failed to list default profile pictures: %w", err)
|
||||
}
|
||||
|
||||
filesDeleted := 0
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue // Skip directories
|
||||
_, filename := path.Split(file.Path)
|
||||
if filename == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
filename := file.Name()
|
||||
initials := strings.TrimSuffix(filename, ".png")
|
||||
|
||||
// If these initials aren't used by any user, delete the file
|
||||
if _, ok := initialsInUse[initials]; !ok {
|
||||
filePath := filepath.Join(defaultPicturesDir, filename)
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
filePath := path.Join(defaultPicturesDir, filename)
|
||||
if err := j.fileStorage.Delete(ctx, filePath); err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to delete unused default profile picture", slog.String("path", filePath), slog.Any("error", err))
|
||||
} else {
|
||||
filesDeleted++
|
||||
@@ -77,3 +79,34 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context)
|
||||
slog.Info("Done deleting unused default profile pictures", slog.Int("count", filesDeleted))
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearOrphanedTempFiles deletes temporary files that are produced by failed atomic writes
|
||||
func (j *FileCleanupJobs) clearOrphanedTempFiles(ctx context.Context) error {
|
||||
const minAge = 10 * time.Minute
|
||||
|
||||
var deleted int
|
||||
err := j.fileStorage.Walk(ctx, "/", func(p storage.ObjectInfo) error {
|
||||
// Only temp files
|
||||
if !strings.HasSuffix(p.Path, "-tmp") {
|
||||
return nil
|
||||
}
|
||||
|
||||
if time.Since(p.ModTime) < minAge {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := j.fileStorage.Delete(ctx, p.Path); err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to delete temp file", slog.String("path", p.Path), slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
deleted++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan storage: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("Done cleaning orphaned temp files", slog.Int("count", deleted))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,42 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
type AppImagesService struct {
|
||||
mu sync.RWMutex
|
||||
extensions map[string]string
|
||||
storage storage.FileStorage
|
||||
}
|
||||
|
||||
func NewAppImagesService(extensions map[string]string) *AppImagesService {
|
||||
return &AppImagesService{extensions: extensions}
|
||||
func NewAppImagesService(extensions map[string]string, storage storage.FileStorage) *AppImagesService {
|
||||
return &AppImagesService{extensions: extensions, storage: storage}
|
||||
}
|
||||
|
||||
func (s *AppImagesService) GetImage(name string) (string, string, error) {
|
||||
func (s *AppImagesService) GetImage(ctx context.Context, name string) (io.ReadCloser, int64, string, error) {
|
||||
ext, err := s.getExtension(name)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
mimeType := utils.GetImageMimeType(ext)
|
||||
if mimeType == "" {
|
||||
return "", "", fmt.Errorf("unsupported image type '%s'", ext)
|
||||
return nil, 0, "", fmt.Errorf("unsupported image type '%s'", ext)
|
||||
}
|
||||
|
||||
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", name, ext))
|
||||
return imagePath, mimeType, nil
|
||||
imagePath := path.Join("application-images", name+"."+ext)
|
||||
reader, size, err := s.storage.Open(ctx, imagePath)
|
||||
if err != nil {
|
||||
if storage.IsNotExist(err) {
|
||||
return nil, 0, "", &common.ImageNotFoundError{}
|
||||
}
|
||||
return nil, 0, "", err
|
||||
}
|
||||
return reader, size, mimeType, nil
|
||||
}
|
||||
|
||||
func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName string) error {
|
||||
func (s *AppImagesService) UpdateImage(ctx context.Context, file *multipart.FileHeader, imageName string) error {
|
||||
fileType := strings.ToLower(utils.GetFileExtension(file.Filename))
|
||||
mimeType := utils.GetImageMimeType(fileType)
|
||||
if mimeType == "" {
|
||||
@@ -51,15 +61,20 @@ func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName str
|
||||
s.extensions[imageName] = fileType
|
||||
}
|
||||
|
||||
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", imageName, fileType))
|
||||
imagePath := path.Join("application-images", imageName+"."+fileType)
|
||||
fileReader, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fileReader.Close()
|
||||
|
||||
if err := utils.SaveFile(file, imagePath); err != nil {
|
||||
if err := s.storage.Save(ctx, imagePath, fileReader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if currentExt != "" && currentExt != fileType {
|
||||
oldImagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", imageName, currentExt))
|
||||
if err := os.Remove(oldImagePath); err != nil && !os.IsNotExist(err) {
|
||||
oldImagePath := path.Join("application-images", imageName+"."+currentExt)
|
||||
if err := s.storage.Delete(ctx, oldImagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -69,7 +84,7 @@ func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AppImagesService) DeleteImage(imageName string) error {
|
||||
func (s *AppImagesService) DeleteImage(ctx context.Context, imageName string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -78,8 +93,8 @@ func (s *AppImagesService) DeleteImage(imageName string) error {
|
||||
return &common.ImageNotFoundError{}
|
||||
}
|
||||
|
||||
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", imageName+"."+ext)
|
||||
if err := os.Remove(imagePath); err != nil && !os.IsNotExist(err) {
|
||||
imagePath := path.Join("application-images", imageName+"."+ext)
|
||||
if err := s.storage.Delete(ctx, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -2,66 +2,92 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/fs"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
)
|
||||
|
||||
func TestAppImagesService_GetImage(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
originalUploadPath := common.EnvConfig.UploadPath
|
||||
common.EnvConfig.UploadPath = tempDir
|
||||
t.Cleanup(func() {
|
||||
common.EnvConfig.UploadPath = originalUploadPath
|
||||
})
|
||||
|
||||
imagesDir := filepath.Join(tempDir, "application-images")
|
||||
require.NoError(t, os.MkdirAll(imagesDir, 0o755))
|
||||
|
||||
filePath := filepath.Join(imagesDir, "background.webp")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("data"), fs.FileMode(0o644)))
|
||||
|
||||
service := NewAppImagesService(map[string]string{"background": "webp"})
|
||||
|
||||
path, mimeType, err := service.GetImage("background")
|
||||
store, err := storage.NewFilesystemStorage(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, filePath, path)
|
||||
|
||||
require.NoError(t, store.Save(context.Background(), path.Join("application-images", "background.webp"), bytes.NewReader([]byte("data"))))
|
||||
|
||||
service := NewAppImagesService(map[string]string{"background": "webp"}, store)
|
||||
|
||||
reader, size, mimeType, err := service.GetImage(context.Background(), "background")
|
||||
require.NoError(t, err)
|
||||
defer reader.Close()
|
||||
payload, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("data"), payload)
|
||||
require.Equal(t, int64(len(payload)), size)
|
||||
require.Equal(t, "image/webp", mimeType)
|
||||
}
|
||||
|
||||
func TestAppImagesService_UpdateImage(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
originalUploadPath := common.EnvConfig.UploadPath
|
||||
common.EnvConfig.UploadPath = tempDir
|
||||
t.Cleanup(func() {
|
||||
common.EnvConfig.UploadPath = originalUploadPath
|
||||
})
|
||||
store, err := storage.NewFilesystemStorage(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
imagesDir := filepath.Join(tempDir, "application-images")
|
||||
require.NoError(t, os.MkdirAll(imagesDir, 0o755))
|
||||
require.NoError(t, store.Save(context.Background(), path.Join("application-images", "logoLight.svg"), bytes.NewReader([]byte("old"))))
|
||||
|
||||
oldPath := filepath.Join(imagesDir, "logoLight.svg")
|
||||
require.NoError(t, os.WriteFile(oldPath, []byte("old"), fs.FileMode(0o644)))
|
||||
|
||||
service := NewAppImagesService(map[string]string{"logoLight": "svg"})
|
||||
service := NewAppImagesService(map[string]string{"logoLight": "svg"}, store)
|
||||
|
||||
fileHeader := newFileHeader(t, "logoLight.png", []byte("new"))
|
||||
|
||||
require.NoError(t, service.UpdateImage(fileHeader, "logoLight"))
|
||||
require.NoError(t, service.UpdateImage(context.Background(), fileHeader, "logoLight"))
|
||||
|
||||
_, err := os.Stat(filepath.Join(imagesDir, "logoLight.png"))
|
||||
reader, _, err := store.Open(context.Background(), path.Join("application-images", "logoLight.png"))
|
||||
require.NoError(t, err)
|
||||
_ = reader.Close()
|
||||
|
||||
_, _, err = store.Open(context.Background(), path.Join("application-images", "logoLight.svg"))
|
||||
require.ErrorIs(t, err, fs.ErrNotExist)
|
||||
}
|
||||
|
||||
func TestAppImagesService_ErrorsAndFlags(t *testing.T) {
|
||||
store, err := storage.NewFilesystemStorage(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = os.Stat(oldPath)
|
||||
require.ErrorIs(t, err, os.ErrNotExist)
|
||||
service := NewAppImagesService(map[string]string{}, store)
|
||||
|
||||
t.Run("get missing image returns not found", func(t *testing.T) {
|
||||
_, _, _, err := service.GetImage(context.Background(), "missing")
|
||||
require.Error(t, err)
|
||||
var imageErr *common.ImageNotFoundError
|
||||
assert.ErrorAs(t, err, &imageErr)
|
||||
})
|
||||
|
||||
t.Run("reject unsupported file types", func(t *testing.T) {
|
||||
err := service.UpdateImage(context.Background(), newFileHeader(t, "logo.txt", []byte("nope")), "logo")
|
||||
require.Error(t, err)
|
||||
var fileTypeErr *common.FileTypeNotSupportedError
|
||||
assert.ErrorAs(t, err, &fileTypeErr)
|
||||
})
|
||||
|
||||
t.Run("delete and extension tracking", func(t *testing.T) {
|
||||
require.NoError(t, store.Save(context.Background(), path.Join("application-images", "default-profile-picture.png"), bytes.NewReader([]byte("img"))))
|
||||
service.extensions["default-profile-picture"] = "png"
|
||||
|
||||
require.NoError(t, service.DeleteImage(context.Background(), "default-profile-picture"))
|
||||
assert.False(t, service.IsDefaultProfilePictureSet())
|
||||
|
||||
err := service.DeleteImage(context.Background(), "default-profile-picture")
|
||||
require.Error(t, err)
|
||||
var imageErr *common.ImageNotFoundError
|
||||
assert.ErrorAs(t, err, &imageErr)
|
||||
})
|
||||
}
|
||||
|
||||
func newFileHeader(t *testing.T, filename string, content []byte) *multipart.FileHeader {
|
||||
|
||||
@@ -11,8 +11,7 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
@@ -35,15 +35,17 @@ type TestService struct {
|
||||
jwtService *JwtService
|
||||
appConfigService *AppConfigService
|
||||
ldapService *LdapService
|
||||
fileStorage storage.FileStorage
|
||||
externalIdPKey jwk.Key
|
||||
}
|
||||
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService, fileStorage storage.FileStorage) (*TestService, error) {
|
||||
s := &TestService{
|
||||
db: db,
|
||||
appConfigService: appConfigService,
|
||||
jwtService: jwtService,
|
||||
ldapService: ldapService,
|
||||
fileStorage: fileStorage,
|
||||
}
|
||||
err := s.initExternalIdP()
|
||||
if err != nil {
|
||||
@@ -424,8 +426,8 @@ func (s *TestService) ResetDatabase() error {
|
||||
}
|
||||
|
||||
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
||||
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
|
||||
slog.ErrorContext(ctx, "Error removing directory", slog.Any("error", err))
|
||||
if err := s.fileStorage.DeleteAll(ctx, "/"); err != nil {
|
||||
slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -435,13 +437,19 @@ func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
srcFilePath := filepath.Join("images", file.Name())
|
||||
destFilePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", file.Name())
|
||||
|
||||
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
srcFilePath := path.Join("images", file.Name())
|
||||
srcFile, err := resources.FS.Open(srcFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile); err != nil {
|
||||
srcFile.Close()
|
||||
return err
|
||||
}
|
||||
srcFile.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -470,7 +470,7 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
|
||||
}
|
||||
|
||||
// Update the profile picture
|
||||
err = s.userService.UpdateProfilePicture(userId, reader)
|
||||
err = s.userService.UpdateProfilePicture(parentCtx, userId, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update profile picture: %w", err)
|
||||
}
|
||||
|
||||
@@ -15,8 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -35,6 +34,7 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
@@ -59,8 +59,9 @@ type OidcService struct {
|
||||
customClaimService *CustomClaimService
|
||||
webAuthnService *WebAuthnService
|
||||
|
||||
httpClient *http.Client
|
||||
jwkCache *jwk.Cache
|
||||
httpClient *http.Client
|
||||
jwkCache *jwk.Cache
|
||||
fileStorage storage.FileStorage
|
||||
}
|
||||
|
||||
func NewOidcService(
|
||||
@@ -72,6 +73,7 @@ func NewOidcService(
|
||||
customClaimService *CustomClaimService,
|
||||
webAuthnService *WebAuthnService,
|
||||
httpClient *http.Client,
|
||||
fileStorage storage.FileStorage,
|
||||
) (s *OidcService, err error) {
|
||||
s = &OidcService{
|
||||
db: db,
|
||||
@@ -81,6 +83,7 @@ func NewOidcService(
|
||||
customClaimService: customClaimService,
|
||||
webAuthnService: webAuthnService,
|
||||
httpClient: httpClient,
|
||||
fileStorage: fileStorage,
|
||||
}
|
||||
|
||||
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
||||
@@ -884,34 +887,41 @@ func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string, light bool) (string, string, error) {
|
||||
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string, light bool) (io.ReadCloser, int64, string, error) {
|
||||
var client model.OidcClient
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
var imagePath, mimeType string
|
||||
|
||||
var suffix string
|
||||
var ext string
|
||||
switch {
|
||||
case !light && client.DarkImageType != nil:
|
||||
// Dark logo if requested and exists
|
||||
imagePath = common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "-dark." + *client.DarkImageType
|
||||
mimeType = utils.GetImageMimeType(*client.DarkImageType)
|
||||
|
||||
suffix = "-dark"
|
||||
ext = *client.DarkImageType
|
||||
case client.ImageType != nil:
|
||||
// Light logo if requested or no dark logo is available
|
||||
imagePath = common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
|
||||
mimeType = utils.GetImageMimeType(*client.ImageType)
|
||||
|
||||
ext = *client.ImageType
|
||||
default:
|
||||
return "", "", errors.New("image not found")
|
||||
return nil, 0, "", errors.New("image not found")
|
||||
}
|
||||
|
||||
return imagePath, mimeType, nil
|
||||
mimeType := utils.GetImageMimeType(ext)
|
||||
if mimeType == "" {
|
||||
return nil, 0, "", fmt.Errorf("unsupported image type '%s'", ext)
|
||||
}
|
||||
key := path.Join("oidc-client-images", client.ID+suffix+"."+ext)
|
||||
reader, size, err := s.fileStorage.Open(ctx, key)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
return reader, size, mimeType, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader, light bool) error {
|
||||
@@ -925,11 +935,15 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
|
||||
darkSuffix = "-dark"
|
||||
}
|
||||
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + darkSuffix + "." + fileType
|
||||
err := utils.SaveFile(file, imagePath)
|
||||
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+fileType)
|
||||
reader, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
|
||||
@@ -972,8 +986,8 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + oldImageType
|
||||
if err := os.Remove(imagePath); err != nil {
|
||||
imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType)
|
||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1015,8 +1029,8 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "-dark." + oldImageType
|
||||
if err := os.Remove(imagePath); err != nil {
|
||||
imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType)
|
||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2017,20 +2031,13 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
folderPath := filepath.Join(common.EnvConfig.UploadPath, "oidc-client-images")
|
||||
err = os.MkdirAll(folderPath, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var darkSuffix string
|
||||
if !light {
|
||||
darkSuffix = "-dark"
|
||||
}
|
||||
|
||||
imagePath := filepath.Join(folderPath, clientID+darkSuffix+"."+ext)
|
||||
err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath)
|
||||
if err != nil {
|
||||
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
|
||||
if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2042,8 +2049,6 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
}
|
||||
|
||||
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error {
|
||||
uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images"
|
||||
|
||||
var darkSuffix string
|
||||
if !light {
|
||||
darkSuffix = "-dark"
|
||||
@@ -2053,9 +2058,15 @@ func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, cli
|
||||
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if client.ImageType != nil && *client.ImageType != ext {
|
||||
old := fmt.Sprintf("%s/%s%s.%s", uploadsDir, client.ID, darkSuffix, *client.ImageType)
|
||||
_ = os.Remove(old)
|
||||
var currentType *string
|
||||
if light {
|
||||
currentType = client.ImageType
|
||||
} else {
|
||||
currentType = client.DarkImageType
|
||||
}
|
||||
if currentType != nil && *currentType != ext {
|
||||
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
|
||||
var column string
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
||||
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
|
||||
@@ -35,9 +36,10 @@ type UserService struct {
|
||||
appConfigService *AppConfigService
|
||||
customClaimService *CustomClaimService
|
||||
appImagesService *AppImagesService
|
||||
fileStorage storage.FileStorage
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService) *UserService {
|
||||
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService, fileStorage storage.FileStorage) *UserService {
|
||||
return &UserService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
@@ -46,6 +48,7 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
|
||||
appConfigService: appConfigService,
|
||||
customClaimService: customClaimService,
|
||||
appImagesService: appImagesService,
|
||||
fileStorage: fileStorage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,34 +98,32 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
profilePicturePath := filepath.Join(common.EnvConfig.UploadPath, "profile-pictures", userID+".png")
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
|
||||
// Try custom profile picture
|
||||
if file, size, err := utils.OpenFileWithSize(profilePicturePath); err == nil {
|
||||
if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil {
|
||||
return file, size, nil
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Try default global profile picture
|
||||
if s.appImagesService.IsDefaultProfilePictureSet() {
|
||||
path, _, err := s.appImagesService.GetImage("default-profile-picture")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
reader, size, _, err := s.appImagesService.GetImage(ctx, "default-profile-picture")
|
||||
if err == nil {
|
||||
return reader, size, nil
|
||||
}
|
||||
if file, size, err := utils.OpenFileWithSize(path); err == nil {
|
||||
return file, size, nil
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
if !errors.Is(err, &common.ImageNotFoundError{}) {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Try cached default for initials
|
||||
defaultProfilePicturesDir := filepath.Join(common.EnvConfig.UploadPath, "profile-pictures", "defaults")
|
||||
defaultPicturePath := filepath.Join(defaultProfilePicturesDir, user.Initials()+".png")
|
||||
|
||||
if file, size, err := utils.OpenFileWithSize(defaultPicturePath); err == nil {
|
||||
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
|
||||
if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil {
|
||||
return file, size, nil
|
||||
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Create and return generated default with initials
|
||||
@@ -132,13 +133,11 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
||||
}
|
||||
|
||||
// Save the default picture for future use (in a goroutine to avoid blocking)
|
||||
//nolint:contextcheck
|
||||
defaultPictureBytes := defaultPicture.Bytes()
|
||||
//nolint:contextcheck
|
||||
go func() {
|
||||
if err := os.MkdirAll(defaultProfilePicturesDir, os.ModePerm); err != nil {
|
||||
slog.Error("Failed to create directory for default profile picture", slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
|
||||
if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil {
|
||||
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
@@ -160,7 +159,7 @@ func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model
|
||||
return user.UserGroups, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
|
||||
func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, file io.Reader) error {
|
||||
// Validate the user ID to prevent directory traversal
|
||||
err := uuid.Validate(userID)
|
||||
if err != nil {
|
||||
@@ -173,15 +172,8 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures"
|
||||
err = os.MkdirAll(profilePictureDir, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the profile picture file
|
||||
err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png")
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
err = s.fileStorage.Save(ctx, profilePicturePath, profilePicture)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -212,10 +204,8 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
|
||||
return &common.LdapUserUpdateError{}
|
||||
}
|
||||
|
||||
// Delete the profile picture
|
||||
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
|
||||
err = os.Remove(profilePicturePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -676,26 +666,16 @@ func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User
|
||||
}
|
||||
|
||||
// ResetProfilePicture deletes a user's custom profile picture
|
||||
func (s *UserService) ResetProfilePicture(userID string) error {
|
||||
func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) error {
|
||||
// Validate the user ID to prevent directory traversal
|
||||
if err := uuid.Validate(userID); err != nil {
|
||||
return &common.InvalidUUIDError{}
|
||||
}
|
||||
|
||||
// Build path to profile picture
|
||||
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
|
||||
|
||||
// Check if file exists and delete it
|
||||
if _, err := os.Stat(profilePicturePath); err == nil {
|
||||
if err := os.Remove(profilePicturePath); err != nil {
|
||||
return fmt.Errorf("failed to delete profile picture: %w", err)
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
// If any error other than "file not exists"
|
||||
return fmt.Errorf("failed to check if profile picture exists: %w", err)
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
|
||||
return fmt.Errorf("failed to delete profile picture: %w", err)
|
||||
}
|
||||
// It's okay if the file doesn't exist - just means there's no custom picture to delete
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
193
backend/internal/storage/filesystem.go
Normal file
193
backend/internal/storage/filesystem.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type filesystemStorage struct {
|
||||
root *os.Root
|
||||
absoluteRootPath string
|
||||
}
|
||||
|
||||
func NewFilesystemStorage(rootPath string) (FileStorage, error) {
|
||||
if err := os.MkdirAll(rootPath, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create root directory '%s': %w", rootPath, err)
|
||||
}
|
||||
root, err := os.OpenRoot(rootPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open root directory '%s': %w", rootPath, err)
|
||||
}
|
||||
|
||||
absoluteRootPath, err := filepath.Abs(rootPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path of root directory '%s': %w", rootPath, err)
|
||||
}
|
||||
|
||||
return &filesystemStorage{root: root, absoluteRootPath: absoluteRootPath}, err
|
||||
}
|
||||
|
||||
func (s *filesystemStorage) Type() string {
|
||||
return TypeFileSystem
|
||||
}
|
||||
|
||||
func (s *filesystemStorage) Save(_ context.Context, path string, data io.Reader) error {
|
||||
path = filepath.FromSlash(path)
|
||||
|
||||
if err := s.root.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directories for path '%s': %w", path, err)
|
||||
}
|
||||
|
||||
// Our strategy is to save to a separate file and then rename it to override the original file
|
||||
tmpName := path + "." + uuid.NewString() + "-tmp"
|
||||
|
||||
// Write to the temporary file
|
||||
tmpFile, err := s.root.Create(tmpName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpName, err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(tmpFile, data)
|
||||
if err != nil {
|
||||
tmpFile.Close()
|
||||
_ = s.root.Remove(tmpName)
|
||||
return fmt.Errorf("failed to write temporary file: %w", err)
|
||||
}
|
||||
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = s.root.Remove(tmpName)
|
||||
return fmt.Errorf("failed to close temporary file: %w", err)
|
||||
}
|
||||
|
||||
// Rename to the final file, which overrides existing files
|
||||
// This is an atomic operation
|
||||
if err = s.root.Rename(tmpName, path); err != nil {
|
||||
_ = s.root.Remove(tmpName)
|
||||
return fmt.Errorf("failed to move temporary file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *filesystemStorage) Open(_ context.Context, path string) (io.ReadCloser, int64, error) {
|
||||
path = filepath.FromSlash(path)
|
||||
|
||||
file, err := s.root.Open(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
return file, info.Size(), nil
|
||||
}
|
||||
|
||||
func (s *filesystemStorage) Delete(_ context.Context, path string) error {
|
||||
path = filepath.FromSlash(path)
|
||||
|
||||
err := s.root.Remove(path)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *filesystemStorage) DeleteAll(_ context.Context, path string) error {
|
||||
path = filepath.FromSlash(path)
|
||||
|
||||
// If "/", "." or "" is requested, we delete all contents of the root.
|
||||
if path == "" || path == "/" || path == "." {
|
||||
dir, err := s.root.Open(".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open root directory: %w", err)
|
||||
}
|
||||
defer dir.Close()
|
||||
|
||||
entries, err := dir.ReadDir(-1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list root directory: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if err := s.root.RemoveAll(entry.Name()); err != nil {
|
||||
return fmt.Errorf("failed to delete '%s': %w", entry.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.root.RemoveAll(path)
|
||||
}
|
||||
func (s *filesystemStorage) List(_ context.Context, path string) ([]ObjectInfo, error) {
|
||||
path = filepath.FromSlash(path)
|
||||
|
||||
dir, err := s.root.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dir.Close()
|
||||
|
||||
entries, err := dir.ReadDir(-1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
objects := make([]ObjectInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, ObjectInfo{
|
||||
Path: filepath.Join(path, entry.Name()),
|
||||
Size: info.Size(),
|
||||
ModTime: info.ModTime(),
|
||||
})
|
||||
}
|
||||
return objects, nil
|
||||
}
|
||||
func (s *filesystemStorage) Walk(_ context.Context, root string, fn func(ObjectInfo) error) error {
|
||||
root = filepath.FromSlash(root)
|
||||
|
||||
fullPath := filepath.Clean(filepath.Join(s.absoluteRootPath, root))
|
||||
|
||||
// As we can't use os.Root here, we manually ensure that the fullPath is within the root directory
|
||||
sep := string(filepath.Separator)
|
||||
if !strings.HasPrefix(fullPath+sep, s.absoluteRootPath+sep) {
|
||||
return fmt.Errorf("invalid root path: %s", root)
|
||||
}
|
||||
|
||||
return filepath.WalkDir(fullPath, func(full string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(s.absoluteRootPath, full)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fn(ObjectInfo{
|
||||
Path: filepath.ToSlash(rel),
|
||||
Size: info.Size(),
|
||||
ModTime: info.ModTime(),
|
||||
})
|
||||
})
|
||||
}
|
||||
68
backend/internal/storage/filesystem_test.go
Normal file
68
backend/internal/storage/filesystem_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFilesystemStorageOperations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store, err := NewFilesystemStorage(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("save, open and list files", func(t *testing.T) {
|
||||
err := store.Save(ctx, "images/logo.png", bytes.NewBufferString("logo-data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, size, err := store.Open(ctx, "images/logo.png")
|
||||
require.NoError(t, err)
|
||||
defer reader.Close()
|
||||
|
||||
contents, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("logo-data"), contents)
|
||||
assert.Equal(t, int64(len(contents)), size)
|
||||
|
||||
err = store.Save(ctx, "images/nested/child.txt", bytes.NewBufferString("child"))
|
||||
require.NoError(t, err)
|
||||
|
||||
files, err := store.List(ctx, "images")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1)
|
||||
assert.Equal(t, filepath.Join("images", "logo.png"), files[0].Path)
|
||||
assert.Equal(t, int64(len("logo-data")), files[0].Size)
|
||||
})
|
||||
|
||||
t.Run("delete files individually and idempotently", func(t *testing.T) {
|
||||
err := store.Save(ctx, "images/delete-me.txt", bytes.NewBufferString("temp"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, store.Delete(ctx, "images/delete-me.txt"))
|
||||
_, _, err = store.Open(ctx, "images/delete-me.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
|
||||
// Deleting a missing object should be a no-op.
|
||||
require.NoError(t, store.Delete(ctx, "images/missing.txt"))
|
||||
})
|
||||
|
||||
t.Run("delete all files under a prefix", func(t *testing.T) {
|
||||
require.NoError(t, store.Save(ctx, "images/a.txt", bytes.NewBufferString("a")))
|
||||
require.NoError(t, store.Save(ctx, "images/b.txt", bytes.NewBufferString("b")))
|
||||
require.NoError(t, store.DeleteAll(ctx, "images"))
|
||||
|
||||
_, _, err := store.Open(ctx, "images/a.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
|
||||
_, _, err = store.Open(ctx, "images/b.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
})
|
||||
}
|
||||
185
backend/internal/storage/s3.go
Normal file
185
backend/internal/storage/s3.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awscfg "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
type S3Config struct {
|
||||
Bucket string
|
||||
Region string
|
||||
Endpoint string
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
ForcePathStyle bool
|
||||
Root string
|
||||
}
|
||||
|
||||
type s3Storage struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewS3Storage(ctx context.Context, cfg S3Config) (FileStorage, error) {
|
||||
creds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
|
||||
awsCfg, err := awscfg.LoadDefaultConfig(ctx, awscfg.WithRegion(cfg.Region), awscfg.WithCredentialsProvider(creds))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load AWS configuration: %w", err)
|
||||
}
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
||||
}
|
||||
o.UsePathStyle = cfg.ForcePathStyle
|
||||
})
|
||||
|
||||
return &s3Storage{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
prefix: strings.Trim(cfg.Root, "/"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *s3Storage) Type() string {
|
||||
return TypeS3
|
||||
}
|
||||
|
||||
func (s *s3Storage) Save(ctx context.Context, path string, data io.Reader) error {
|
||||
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s.buildObjectKey(path)),
|
||||
Body: data,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *s3Storage) Open(ctx context.Context, path string) (io.ReadCloser, int64, error) {
|
||||
resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s.buildObjectKey(path)),
|
||||
})
|
||||
if err != nil {
|
||||
if isS3NotFound(err) {
|
||||
return nil, 0, fs.ErrNotExist
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
return resp.Body, aws.ToInt64(resp.ContentLength), nil
|
||||
}
|
||||
|
||||
func (s *s3Storage) Delete(ctx context.Context, path string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s.buildObjectKey(path)),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *s3Storage) DeleteAll(ctx context.Context, path string) error {
|
||||
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(s.buildObjectKey(path)),
|
||||
})
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(page.Contents) == 0 {
|
||||
continue
|
||||
}
|
||||
objects := make([]s3types.ObjectIdentifier, 0, len(page.Contents))
|
||||
for _, obj := range page.Contents {
|
||||
objects = append(objects, s3types.ObjectIdentifier{Key: obj.Key})
|
||||
}
|
||||
_, err = s.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Delete: &s3types.Delete{Objects: objects, Quiet: aws.Bool(true)},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *s3Storage) List(ctx context.Context, path string) ([]ObjectInfo, error) {
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(s.buildObjectKey(path)),
|
||||
})
|
||||
var objects []ObjectInfo
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, obj := range page.Contents {
|
||||
if obj.Key == nil {
|
||||
continue
|
||||
}
|
||||
objects = append(objects, ObjectInfo{
|
||||
Path: aws.ToString(obj.Key),
|
||||
Size: aws.ToInt64(obj.Size),
|
||||
ModTime: aws.ToTime(obj.LastModified),
|
||||
})
|
||||
}
|
||||
}
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
func (s *s3Storage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
|
||||
objects, err := s.List(ctx, root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, obj := range objects {
|
||||
if err := fn(obj); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *s3Storage) buildObjectKey(p string) string {
|
||||
p = filepath.Clean(p)
|
||||
p = filepath.ToSlash(p)
|
||||
p = strings.Trim(p, "/")
|
||||
|
||||
if p == "" || p == "." {
|
||||
return s.prefix
|
||||
}
|
||||
|
||||
if s.prefix == "" {
|
||||
return p
|
||||
}
|
||||
|
||||
return s.prefix + "/" + p
|
||||
}
|
||||
|
||||
func isS3NotFound(err error) bool {
|
||||
var apiErr smithy.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
if apiErr.ErrorCode() == "NotFound" || apiErr.ErrorCode() == "NoSuchKey" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
var missingKey *s3types.NoSuchKey
|
||||
return errors.As(err, &missingKey)
|
||||
}
|
||||
44
backend/internal/storage/s3_test.go
Normal file
44
backend/internal/storage/s3_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestS3Helpers(t *testing.T) {
|
||||
t.Run("buildObjectKey trims and joins prefix", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{name: "no prefix no path", prefix: "", path: "", expected: ""},
|
||||
{name: "prefix no path", prefix: "root", path: "", expected: "root"},
|
||||
{name: "prefix with nested path", prefix: "root", path: "foo/bar/baz", expected: "root/foo/bar/baz"},
|
||||
{name: "trimmed path and prefix", prefix: "root", path: "/foo//bar/", expected: "root/foo/bar"},
|
||||
{name: "no prefix path only", prefix: "", path: "./images/logo.png", expected: "images/logo.png"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &s3Storage{
|
||||
bucket: "bucket",
|
||||
prefix: tc.prefix,
|
||||
}
|
||||
assert.Equal(t, tc.expected, s.buildObjectKey(tc.path))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("isS3NotFound detects expected errors", func(t *testing.T) {
|
||||
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NoSuchKey"}))
|
||||
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NotFound"}))
|
||||
assert.True(t, isS3NotFound(&s3types.NoSuchKey{}))
|
||||
assert.False(t, isS3NotFound(errors.New("boom")))
|
||||
})
|
||||
}
|
||||
33
backend/internal/storage/storage.go
Normal file
33
backend/internal/storage/storage.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
TypeFileSystem = "fs"
|
||||
TypeS3 = "s3"
|
||||
)
|
||||
|
||||
type ObjectInfo struct {
|
||||
Path string
|
||||
Size int64
|
||||
ModTime time.Time
|
||||
}
|
||||
|
||||
type FileStorage interface {
|
||||
Save(ctx context.Context, relativePath string, data io.Reader) error
|
||||
Open(ctx context.Context, relativePath string) (io.ReadCloser, int64, error)
|
||||
Delete(ctx context.Context, relativePath string) error
|
||||
DeleteAll(ctx context.Context, prefix string) error
|
||||
List(ctx context.Context, prefix string) ([]ObjectInfo, error)
|
||||
Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error
|
||||
Type() string
|
||||
}
|
||||
|
||||
func IsNotExist(err error) bool {
|
||||
return os.IsNotExist(err)
|
||||
}
|
||||
@@ -2,20 +2,15 @@ package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
func GetFileExtension(filename string) string {
|
||||
@@ -86,110 +81,6 @@ func GetImageExtensionFromMimeType(mimeType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error {
|
||||
srcFile, err := resources.FS.Open(srcFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open embedded file: %w", err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(destFilePath), os.ModePerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination directory: %w", err)
|
||||
}
|
||||
|
||||
destFile, err := os.Create(destFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open destination file: %w", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, srcFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to destination file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func EmbeddedFileSha256(filePath string) ([]byte, error) {
|
||||
f, err := resources.FS.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open embedded file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
_, err = io.Copy(h, f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read embedded file: %w", err)
|
||||
}
|
||||
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
func SaveFile(file *multipart.FileHeader, dst string) error {
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
if err = os.MkdirAll(filepath.Dir(dst), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return SaveFileStream(src, dst)
|
||||
}
|
||||
|
||||
// SaveFileStream saves a stream to a file.
|
||||
func SaveFileStream(r io.Reader, dstFileName string) error {
|
||||
// Our strategy is to save to a separate file and then rename it to override the original file
|
||||
tmpFileName := dstFileName + "." + uuid.NewString() + "-tmp"
|
||||
|
||||
// Write to the temporary file
|
||||
tmpFile, err := os.Create(tmpFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err)
|
||||
}
|
||||
|
||||
n, err := io.Copy(tmpFile, r)
|
||||
if err != nil {
|
||||
// Delete the temporary file; we ignore errors here
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFileName)
|
||||
|
||||
return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err)
|
||||
}
|
||||
|
||||
err = tmpFile.Close()
|
||||
if err != nil {
|
||||
// Delete the temporary file; we ignore errors here
|
||||
_ = os.Remove(tmpFileName)
|
||||
|
||||
return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err)
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
// Delete the temporary file; we ignore errors here
|
||||
_ = os.Remove(tmpFileName)
|
||||
|
||||
return errors.New("no data written")
|
||||
}
|
||||
|
||||
// Rename to the final file, which overrides existing files
|
||||
// This is an atomic operation
|
||||
err = os.Rename(tmpFileName, dstFileName)
|
||||
if err != nil {
|
||||
// Delete the temporary file; we ignore errors here
|
||||
_ = os.Remove(tmpFileName)
|
||||
|
||||
return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileExists returns true if a file exists on disk and is a regular file
|
||||
func FileExists(path string) (bool, error) {
|
||||
s, err := os.Stat(path)
|
||||
@@ -239,17 +130,3 @@ func IsWritableDir(dir string) (bool, error) {
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// OpenFileWithSize opens a file and returns its size
|
||||
func OpenFileWithSize(path string) (io.ReadCloser, int64, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
return f, info.Size(), nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
const profilePictureSize = 300
|
||||
|
||||
// CreateProfilePicture resizes the profile picture to a square
|
||||
func CreateProfilePicture(file io.Reader) (io.Reader, error) {
|
||||
func CreateProfilePicture(file io.Reader) (io.ReadSeeker, error) {
|
||||
img, _, err := imageorient.Decode(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode image: %w", err)
|
||||
@@ -27,17 +27,13 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
|
||||
|
||||
img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
innerErr := imaging.Encode(pw, img, imaging.PNG)
|
||||
if innerErr != nil {
|
||||
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", innerErr))
|
||||
return
|
||||
}
|
||||
pw.Close()
|
||||
}()
|
||||
var buf bytes.Buffer
|
||||
err = imaging.Encode(&buf, img, imaging.PNG)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode image: %w", err)
|
||||
}
|
||||
|
||||
return pr, nil
|
||||
return bytes.NewReader(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
// CreateDefaultProfilePicture creates a profile picture with the initials
|
||||
|
||||
Reference in New Issue
Block a user