1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-16 19:20:17 +00:00

fix: use transactions when operations involve multiple database queries (#392)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-06 06:04:08 -07:00
committed by GitHub
parent c810fec8c4
commit ec626ee797
33 changed files with 1401 additions and 501 deletions

View File

@@ -1,7 +1,8 @@
package service
import (
"fmt"
"context"
"errors"
"log"
"mime/multipart"
"os"
@@ -19,12 +20,14 @@ type AppConfigService struct {
db *gorm.DB
}
func NewAppConfigService(db *gorm.DB) *AppConfigService {
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{
DbConfig: &defaultDbConfig,
db: db,
}
if err := service.InitDbConfig(); err != nil {
err := service.InitDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
@@ -197,17 +200,24 @@ var defaultDbConfig = model.AppConfig{
},
}
func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
if common.EnvConfig.UiConfigDisabled {
return nil, &common.UiConfigDisabledError{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var err error
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
var savedConfigVariables []model.AppConfigVariable
for i := 0; i < rt.NumField(); i++ {
savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField())
for i := range rt.NumField() {
field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()
@@ -220,32 +230,47 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
}
var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
err = tx.
WithContext(ctx).
First(&appConfigVariable, "key = ? AND is_internal = false", key).
Error
if err != nil {
return nil, err
}
appConfigVariable.Value = value
if err := tx.Save(&appConfigVariable).Error; err != nil {
tx.Rollback()
err = tx.
WithContext(ctx).
Save(&appConfigVariable).
Error
if err != nil {
return nil, err
}
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
}
tx.Commit()
err = tx.Commit().Error
if err != nil {
return nil, err
}
if err := s.LoadDbConfigFromDb(); err != nil {
err = s.LoadDbConfigFromDb()
if err != nil {
return nil, err
}
return savedConfigVariables, nil
}
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
key := fmt.Sprintf("%sImageType", imageName)
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error {
key := imageName + "ImageType"
err := s.db.
WithContext(ctx).
Model(&model.AppConfigVariable{}).
Where("key = ?", key).
Update("value", fileType).
Error
if err != nil {
return err
}
@@ -253,14 +278,17 @@ func (s *AppConfigService) UpdateImageType(imageName string, fileType string) er
return s.LoadDbConfigFromDb()
}
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
var configuration []model.AppConfigVariable
var err error
func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) {
if showAll {
err = s.db.Find(&configuration).Error
err = s.db.
WithContext(ctx).
Find(&configuration).
Error
} else {
err = s.db.Find(&configuration, "is_public = true").Error
err = s.db.
WithContext(ctx).
Find(&configuration, "is_public = true").
Error
}
if err != nil {
@@ -271,7 +299,6 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
@@ -281,7 +308,7 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
return configuration, nil
}
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
fileType := utils.GetFileExtension(uploadedFile.Filename)
mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" {
@@ -290,19 +317,22 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
// Delete the old image if it has a different file type
if fileType != oldImageType {
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
if err := os.Remove(oldImagePath); err != nil {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err
}
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
err = utils.SaveFile(uploadedFile, imagePath)
if err != nil {
return err
}
// Update the file type in the database
if err := s.UpdateImageType(imageName, fileType); err != nil {
err = s.updateImageType(ctx, imageName, fileType)
if err != nil {
return err
}
@@ -312,33 +342,58 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func (s *AppConfigService) InitDbConfig() error {
func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
for i := range defaultConfigReflectValue.NumField() {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
defaultKeys[defaultConfigVar.Key] = struct{}{}
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
err = tx.
WithContext(ctx).
First(&storedConfigVar, "key = ?", defaultConfigVar.Key).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the configuration does not exist, create it
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&defaultConfigVar).
Error
if err != nil {
return err
}
continue
} else if err != nil {
return err
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
if storedConfigVar.Type != defaultConfigVar.Type ||
storedConfigVar.IsPublic != defaultConfigVar.IsPublic ||
storedConfigVar.IsInternal != defaultConfigVar.IsInternal ||
storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
// Set values
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
if err := s.db.Save(&storedConfigVar).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&storedConfigVar).
Error
if err != nil {
return err
}
}
@@ -346,43 +401,68 @@ func (s *AppConfigService) InitDbConfig() error {
// Delete any configurations not in the default keys
var allConfigVars []model.AppConfigVariable
if err := s.db.Find(&allConfigVars).Error; err != nil {
err = tx.
WithContext(ctx).
Find(&allConfigVars).
Error
if err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := s.db.Delete(&config).Error; err != nil {
return err
}
if _, exists := defaultKeys[config.Key]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&config).
Error
if err != nil {
return err
}
}
return s.LoadDbConfigFromDb()
// Commit the changes
err = tx.Commit().Error
if err != nil {
return err
}
// Reload the configuration
err = s.LoadDbConfigFromDb()
if err != nil {
return err
}
return nil
}
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
return s.db.Transaction(func(tx *gorm.DB) error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
return err
for i := range dbConfigReflectValue.NumField() {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error
if err != nil {
return err
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
return nil
})
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {