1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-15 19:15:04 +00:00

fix: use transactions when operations involve multiple database queries (#392)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-06 06:04:08 -07:00
committed by GitHub
parent c810fec8c4
commit ec626ee797
33 changed files with 1401 additions and 501 deletions

View File

@@ -1,8 +1,8 @@
package service
import (
"context"
"errors"
"log"
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
@@ -12,6 +12,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type ApiKeyService struct {
@@ -22,8 +23,11 @@ func NewApiKeyService(db *gorm.DB) *ApiKeyService {
return &ApiKeyService{db: db}
}
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{})
func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.
WithContext(ctx).
Where("user_id = ?", userID).
Model(&model.ApiKey{})
var apiKeys []model.ApiKey
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
@@ -34,7 +38,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils
return apiKeys, pagination, nil
}
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
// Check if expiration is in the future
if !input.ExpiresAt.ToTime().After(time.Now()) {
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
@@ -54,7 +58,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
UserID: userID,
}
if err := s.db.Create(&apiKey).Error; err != nil {
err = s.db.
WithContext(ctx).
Create(&apiKey).
Error
if err != nil {
return model.ApiKey{}, "", err
}
@@ -62,29 +70,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
return apiKey, token, nil
}
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error {
func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error {
var apiKey model.ApiKey
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil {
err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", apiKeyID, userID).
Delete(&apiKey).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return &common.APIKeyNotFoundError{}
}
return err
}
return s.db.Delete(&apiKey).Error
return nil
}
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) {
if apiKey == "" {
return model.User{}, &common.NoAPIKeyProvidedError{}
}
var key model.ApiKey
now := time.Now()
hashedKey := utils.CreateSha256Hash(apiKey)
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?",
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil {
var key model.ApiKey
err := s.db.
WithContext(ctx).
Model(&model.ApiKey{}).
Clauses(clause.Returning{}).
Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)).
Updates(&model.ApiKey{
LastUsedAt: utils.Ptr(datatype.DateTime(now)),
}).
Preload("User").
First(&key).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, &common.InvalidAPIKeyError{}
}
@@ -92,12 +115,5 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
return model.User{}, err
}
// Update last used time
now := datatype.DateTime(time.Now())
key.LastUsedAt = &now
if err := s.db.Save(&key).Error; err != nil {
log.Printf("Failed to update last used time: %v", err)
}
return key.User, nil
}

View File

@@ -1,7 +1,8 @@
package service
import (
"fmt"
"context"
"errors"
"log"
"mime/multipart"
"os"
@@ -19,12 +20,14 @@ type AppConfigService struct {
db *gorm.DB
}
func NewAppConfigService(db *gorm.DB) *AppConfigService {
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{
DbConfig: &defaultDbConfig,
db: db,
}
if err := service.InitDbConfig(); err != nil {
err := service.InitDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
@@ -197,17 +200,24 @@ var defaultDbConfig = model.AppConfig{
},
}
func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
if common.EnvConfig.UiConfigDisabled {
return nil, &common.UiConfigDisabledError{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var err error
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
var savedConfigVariables []model.AppConfigVariable
for i := 0; i < rt.NumField(); i++ {
savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField())
for i := range rt.NumField() {
field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()
@@ -220,32 +230,47 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
}
var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
err = tx.
WithContext(ctx).
First(&appConfigVariable, "key = ? AND is_internal = false", key).
Error
if err != nil {
return nil, err
}
appConfigVariable.Value = value
if err := tx.Save(&appConfigVariable).Error; err != nil {
tx.Rollback()
err = tx.
WithContext(ctx).
Save(&appConfigVariable).
Error
if err != nil {
return nil, err
}
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
}
tx.Commit()
err = tx.Commit().Error
if err != nil {
return nil, err
}
if err := s.LoadDbConfigFromDb(); err != nil {
err = s.LoadDbConfigFromDb()
if err != nil {
return nil, err
}
return savedConfigVariables, nil
}
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
key := fmt.Sprintf("%sImageType", imageName)
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error {
key := imageName + "ImageType"
err := s.db.
WithContext(ctx).
Model(&model.AppConfigVariable{}).
Where("key = ?", key).
Update("value", fileType).
Error
if err != nil {
return err
}
@@ -253,14 +278,17 @@ func (s *AppConfigService) UpdateImageType(imageName string, fileType string) er
return s.LoadDbConfigFromDb()
}
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
var configuration []model.AppConfigVariable
var err error
func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) {
if showAll {
err = s.db.Find(&configuration).Error
err = s.db.
WithContext(ctx).
Find(&configuration).
Error
} else {
err = s.db.Find(&configuration, "is_public = true").Error
err = s.db.
WithContext(ctx).
Find(&configuration, "is_public = true").
Error
}
if err != nil {
@@ -271,7 +299,6 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
@@ -281,7 +308,7 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl
return configuration, nil
}
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
fileType := utils.GetFileExtension(uploadedFile.Filename)
mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" {
@@ -290,19 +317,22 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
// Delete the old image if it has a different file type
if fileType != oldImageType {
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
if err := os.Remove(oldImagePath); err != nil {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err
}
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
err = utils.SaveFile(uploadedFile, imagePath)
if err != nil {
return err
}
// Update the file type in the database
if err := s.UpdateImageType(imageName, fileType); err != nil {
err = s.updateImageType(ctx, imageName, fileType)
if err != nil {
return err
}
@@ -312,33 +342,58 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func (s *AppConfigService) InitDbConfig() error {
func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
for i := range defaultConfigReflectValue.NumField() {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
defaultKeys[defaultConfigVar.Key] = struct{}{}
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
err = tx.
WithContext(ctx).
First(&storedConfigVar, "key = ?", defaultConfigVar.Key).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the configuration does not exist, create it
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&defaultConfigVar).
Error
if err != nil {
return err
}
continue
} else if err != nil {
return err
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
if storedConfigVar.Type != defaultConfigVar.Type ||
storedConfigVar.IsPublic != defaultConfigVar.IsPublic ||
storedConfigVar.IsInternal != defaultConfigVar.IsInternal ||
storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
// Set values
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
if err := s.db.Save(&storedConfigVar).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&storedConfigVar).
Error
if err != nil {
return err
}
}
@@ -346,43 +401,68 @@ func (s *AppConfigService) InitDbConfig() error {
// Delete any configurations not in the default keys
var allConfigVars []model.AppConfigVariable
if err := s.db.Find(&allConfigVars).Error; err != nil {
err = tx.
WithContext(ctx).
Find(&allConfigVars).
Error
if err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := s.db.Delete(&config).Error; err != nil {
return err
}
if _, exists := defaultKeys[config.Key]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&config).
Error
if err != nil {
return err
}
}
return s.LoadDbConfigFromDb()
// Commit the changes
err = tx.Commit().Error
if err != nil {
return err
}
// Reload the configuration
err = s.LoadDbConfigFromDb()
if err != nil {
return err
}
return nil
}
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
return s.db.Transaction(func(tx *gorm.DB) error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
return err
for i := range dbConfigReflectValue.NumField() {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error
if err != nil {
return err
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
return nil
})
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {

View File

@@ -25,10 +25,10 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
}
// Create creates a new audit log entry in the database
func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData) model.AuditLog {
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil {
log.Printf("Failed to get IP location: %v\n", err)
log.Printf("Failed to get IP location: %v", err)
}
auditLog := model.AuditLog{
@@ -42,8 +42,12 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
}
// Save the audit log in the database
if err := s.db.Create(&auditLog).Error; err != nil {
log.Printf("Failed to create audit log: %v\n", err)
err = tx.
WithContext(ctx).
Create(&auditLog).
Error
if err != nil {
log.Printf("Failed to create audit log: %v", err)
return model.AuditLog{}
}
@@ -51,12 +55,17 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
}
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID string) model.AuditLog {
createdAuditLog := s.Create(model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{})
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
// Count the number of times the user has logged in from the same device
var count int64
err := s.db.Model(&model.AuditLog{}).Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).Count(&count).Error
err := tx.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
Count(&count).
Error
if err != nil {
log.Printf("Failed to count audit logs: %v\n", err)
return createdAuditLog
@@ -64,11 +73,23 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
var user model.User
s.db.Where("id = ?", userID).First(&user)
innerCtx := context.Background()
err := SendEmail(s.emailService, email.Address{
// Note we don't use the transaction here because this is running in background
var user model.User
innerErr := s.db.
WithContext(innerCtx).
Where("id = ?", userID).
First(&user).
Error
if innerErr != nil {
log.Printf("Failed to load user: %v", innerErr)
}
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
Name: user.Username,
Email: user.Email,
}, NewLoginTemplate, &NewLoginTemplateData{
@@ -78,8 +99,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
Device: s.DeviceStringFromUserAgent(userAgent),
DateTime: createdAuditLog.CreatedAt.UTC(),
})
if err != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
if innerErr != nil {
log.Printf("Failed to send email to '%s': %v", user.Email, innerErr)
}
}()
}
@@ -88,9 +109,12 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
}
// ListAuditLogsForUser retrieves all audit logs for a given user ID
func (s *AuditLogService) ListAuditLogsForUser(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog
query := s.db.Model(&model.AuditLog{}).Where("user_id = ?", userID)
query := s.db.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ?", userID)
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
return logs, pagination, err
@@ -162,19 +186,19 @@ func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[s
}
func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) {
dialect := s.db.Name()
query := s.db.
WithContext(ctx).
Model(&model.AuditLog{})
dialect := s.db.Name()
switch dialect {
case "sqlite":
query = query.
Select("DISTINCT json_extract(data, '$.clientName') as client_name").
Select("DISTINCT json_extract(data, '$.clientName') AS client_name").
Where("json_extract(data, '$.clientName') IS NOT NULL")
case "postgres":
query = query.
Select("DISTINCT data->>'clientName' as client_name").
Select("DISTINCT data->>'clientName' AS client_name").
Where("data->>'clientName' IS NOT NULL")
default:
return nil, fmt.Errorf("unsupported database dialect: %s", dialect)

View File

@@ -1,34 +1,14 @@
package service
import (
"context"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm"
)
// Reserved claims
var reservedClaims = map[string]struct{}{
"given_name": {},
"family_name": {},
"name": {},
"email": {},
"preferred_username": {},
"groups": {},
"sub": {},
"iss": {},
"aud": {},
"exp": {},
"iat": {},
"auth_time": {},
"nonce": {},
"acr": {},
"amr": {},
"azp": {},
"nbf": {},
"jti": {},
}
type CustomClaimService struct {
db *gorm.DB
}
@@ -39,8 +19,29 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
func isReservedClaim(key string) bool {
_, ok := reservedClaims[key]
return ok
switch key {
case "given_name",
"family_name",
"name",
"email",
"preferred_username",
"groups",
"sub",
"iss",
"aud",
"exp",
"iat",
"auth_time",
"nonce",
"acr",
"amr",
"azp",
"nbf",
"jti":
return true
default:
return false
}
}
// idType is the type of the id used to identify the user or user group
@@ -52,28 +53,38 @@ const (
)
// UpdateCustomClaimsForUser updates the custom claims for a user
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserID, userID, claims)
func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(ctx, UserID, userID, claims)
}
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserGroupID, userGroupID, claims)
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims)
}
// updateCustomClaims updates the custom claims for a user or user group
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
// Check for duplicate keys in the claims slice
seenKeys := make(map[string]bool)
seenKeys := make(map[string]struct{})
for _, claim := range claims {
if seenKeys[claim.Key] {
if _, ok := seenKeys[claim.Key]; ok {
return nil, &common.DuplicateClaimError{Key: claim.Key}
}
seenKeys[claim.Key] = true
seenKeys[claim.Key] = struct{}{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var existingClaims []model.CustomClaim
err := s.db.Where(string(idType), value).Find(&existingClaims).Error
err := tx.
WithContext(ctx).
Where(string(idType), value).
Find(&existingClaims).
Error
if err != nil {
return nil, err
}
@@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
break
}
}
if !found {
err = s.db.Delete(&existingClaim).Error
err = tx.
WithContext(ctx).
Delete(&existingClaim).
Error
if err != nil {
return nil, err
}
@@ -113,7 +128,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
}
// Update the claim if it already exists or create a new one
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error
err = tx.
WithContext(ctx).
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
Assign(&customClaim).
FirstOrCreate(&model.CustomClaim{}).
Error
if err != nil {
return nil, err
}
@@ -121,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
// Get the updated claims
var updatedClaims []model.CustomClaim
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error
err = tx.
WithContext(ctx).
Where(string(idType)+" = ?", value).
Find(&updatedClaims).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
@@ -129,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
return updatedClaims, nil
}
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) {
func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error
err := tx.
WithContext(ctx).
Where("user_id = ?", userID).
Find(&customClaims).
Error
return customClaims, err
}
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) {
func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error
err := tx.
WithContext(ctx).
Where("user_group_id = ?", userGroupID).
Find(&customClaims).
Error
return customClaims, err
}
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
// prioritizing the user's claims over user group claims with the same key.
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) {
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
// Get the custom claims of the user
customClaims, err := s.GetCustomClaimsForUser(userID)
customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
if err != nil {
return nil, err
}
@@ -158,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
// Get all user groups of the user
var userGroupsOfUser []model.UserGroup
err = s.db.Preload("CustomClaims").
err = tx.
WithContext(ctx).
Preload("CustomClaims").
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
Where("user_groups_users.user_id = ?", userID).
Find(&userGroupsOfUser).Error
@@ -186,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
}
// GetSuggestions returns a list of custom claim keys that have been used before
func (s *CustomClaimService) GetSuggestions() ([]string, error) {
func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
var customClaimsKeys []string
err := s.db.Model(&model.CustomClaim{}).
err := s.db.
WithContext(ctx).
Model(&model.CustomClaim{}).
Group("key").
Order("COUNT(*) DESC").
Pluck("key", &customClaimsKeys).Error

View File

@@ -3,6 +3,7 @@
package service
import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
@@ -34,6 +35,7 @@ func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService}
}
//nolint:gocognit
func (s *TestService) SeedDatabase() error {
return s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{
@@ -187,11 +189,8 @@ func (s *TestService) SeedDatabase() error {
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
if err != nil {
return err
}
publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
webauthnCredentials := []model.WebauthnCredential{
{
Name: "Passkey 1",
@@ -303,7 +302,7 @@ func (s *TestService) ResetApplicationImages() error {
func (s *TestService) ResetAppConfig() error {
// Reseed the config variables
if err := s.appConfigService.InitDbConfig(); err != nil {
if err := s.appConfigService.InitDbConfig(context.Background()); err != nil {
return err
}
@@ -320,7 +319,7 @@ func (s *TestService) SetJWTKeys() {
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
s.jwtService.SetKey(privateKey)
_ = s.jwtService.SetKey(privateKey)
}
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key

View File

@@ -2,10 +2,12 @@ package service
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
htemplate "html/template"
"io"
"mime/multipart"
"mime/quotedprintable"
"net/textproto"
@@ -17,10 +19,11 @@ import (
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/google/uuid"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
)
type EmailService struct {
@@ -49,20 +52,24 @@ func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailSer
}, nil
}
func (srv *EmailService) SendTestEmail(recipientUserId string) error {
func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId string) error {
var user model.User
if err := srv.db.First(&user, "id = ?", recipientUserId).Error; err != nil {
err := srv.db.
WithContext(ctx).
First(&user, "id = ?", recipientUserId).
Error
if err != nil {
return err
}
return SendEmail(srv,
return SendEmail(ctx, srv,
email.Address{
Email: user.Email,
Name: user.FullName(),
}, TestTemplate, nil)
}
func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
@@ -112,6 +119,15 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
c.Body(body)
// Check if the context is still valid before attemtping to connect
// We need to do this because the smtp library doesn't have context support
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Connect to the SMTP server
client, err := srv.getSmtpClient()
if err != nil {
@@ -119,6 +135,14 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
}
defer client.Close()
// Check if the context is still valid before sending the email
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Send the email
if err := srv.sendEmailContent(client, toEmail, c); err != nil {
return fmt.Errorf("send email content: %w", err)
@@ -215,7 +239,7 @@ func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Add
}
// Write the email content
_, err = w.Write([]byte(c.String()))
_, err = io.Copy(w, strings.NewReader(c.String()))
if err != nil {
return fmt.Errorf("failed to write email data: %w", err)
}

View File

@@ -42,7 +42,7 @@ var tailscaleIPNets = []*net.IPNet{
}
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
func NewGeoLiteService() *GeoLiteService {
func NewGeoLiteService(ctx context.Context) *GeoLiteService {
service := &GeoLiteService{}
if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl {
@@ -52,8 +52,9 @@ func NewGeoLiteService() *GeoLiteService {
}
go func() {
if err := service.updateDatabase(); err != nil {
log.Printf("Failed to update GeoLite2 City database: %v\n", err)
err := service.updateDatabase(ctx)
if err != nil {
log.Printf("Failed to update GeoLite2 City database: %v", err)
}
}()
@@ -111,7 +112,7 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
}
// UpdateDatabase checks the age of the database and updates it if it's older than 14 days.
func (s *GeoLiteService) updateDatabase() error {
func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error {
if s.disableUpdater {
// Avoid updating the GeoLite2 City database.
return nil
@@ -125,7 +126,7 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)

View File

@@ -38,7 +38,9 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
}))
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
@@ -53,22 +55,31 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
return client, nil
}
func (s *LdapService) SyncAll() error {
err := s.SyncUsers()
func (s *LdapService) SyncAll(ctx context.Context) error {
// Start a transaction
tx := s.db.Begin()
err := s.SyncUsers(ctx, tx)
if err != nil {
return fmt.Errorf("failed to sync users: %w", err)
}
err = s.SyncGroups()
err = s.SyncGroups(ctx, tx)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit changes to database: %w", err)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -112,7 +123,7 @@ func (s *LdapService) SyncGroups() error {
// Try to find the group in the database
var databaseGroup model.UserGroup
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup)
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
// Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute)
@@ -122,7 +133,7 @@ func (s *LdapService) SyncGroups() error {
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
var databaseUser model.User
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it
@@ -143,39 +154,51 @@ func (s *LdapService) SyncGroups() error {
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup)
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
} else {
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
return err
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
}
}
// Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err))
err = tx.
WithContext(ctx).
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch groups from database: %v", err)
}
// Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil {
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
log.Printf("Failed to delete group %s with: %v", group.Name, err)
} else {
log.Printf("Deleted group %s", group.Name)
@@ -187,7 +210,7 @@ func (s *LdapService) SyncGroups() error {
}
//nolint:gocognit
func (s *LdapService) SyncUsers() error {
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -241,7 +264,7 @@ func (s *LdapService) SyncUsers() error {
// Get the user from the database
var databaseUser model.User
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser)
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
// Check if user is admin by checking if they are in the admin group
isAdmin := false
@@ -261,68 +284,75 @@ func (s *LdapService) SyncUsers() error {
}
if databaseUser.ID == "" {
_, err = s.userService.CreateUser(newUser)
_, err = s.userService.createUserInternal(ctx, newUser, tx)
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
log.Printf("Error syncing user %s: %v", newUser.Username, err)
}
} else {
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true)
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
log.Printf("Error syncing user %s: %v", newUser.Username, err)
}
}
// Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err)
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
}
}
}
// Get all LDAP users from the database
var ldapUsersInDb []model.User
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err))
err = tx.
WithContext(ctx).
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch users from database: %v", err)
}
// Delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
if err := s.userService.DeleteUser(user.ID, true); err != nil {
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
log.Printf("Failed to delete user %s with: %v", user.Username, err)
} else {
log.Printf("Deleted user %s", user.Username)
}
}
}
return nil
}
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error {
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := url.ParseRequestURI(pictureString)
if err == nil {
ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
var req *http.Request
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
response, err := http.DefaultClient.Do(req)
var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err)
}
defer response.Body.Close()
reader = response.Body
defer res.Body.Close()
reader = res.Body
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
// If the photo is a base64 encoded string, decode it
reader = bytes.NewReader(decodedPhoto)
} else {
// If the photo is a string, we assume that it's a binary string
reader = bytes.NewReader([]byte(pictureString))

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
@@ -9,6 +10,7 @@ import (
"mime/multipart"
"os"
"regexp"
"slices"
"strings"
"time"
@@ -39,9 +41,20 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo
}
}
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("AllowedUserGroups").
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
return "", "", err
}
@@ -58,7 +71,12 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
// Check if the user group is allowed to authorize the client
var user model.User
if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("UserGroups").
First(&user, "id = ?", userID).
Error
if err != nil {
return "", "", err
}
@@ -67,7 +85,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
}
// Check if the user has already authorized the client with the given scope
hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope)
hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, input.ClientID, userID, input.Scope, tx)
if err != nil {
return "", "", err
}
@@ -80,39 +98,55 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
Scope: input.Scope,
}
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
// The client has already been authorized but with a different scope so we need to update the scope
if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
return "", "", err
}
} else {
err = tx.
WithContext(ctx).
Create(&userAuthorizedClient).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
// The client has already been authorized but with a different scope so we need to update the scope
if err := tx.
WithContext(ctx).
Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
return "", "", err
}
} else if err != nil {
return "", "", err
}
}
// Create the authorization code
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
if err != nil {
return "", "", err
}
// Log the authorization event
if hasAuthorizedClient {
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
s.auditLogService.Create(ctx, model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
} else {
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
s.auditLogService.Create(ctx, model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
}
err = tx.Commit().Error
if err != nil {
return "", "", err
}
return code, callbackURL, nil
}
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) {
func (s *OidcService) HasAuthorizedClient(ctx context.Context, clientID, userID, scope string) (bool, error) {
return s.hasAuthorizedClientInternal(ctx, clientID, userID, scope, s.db)
}
func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID, userID, scope string, tx *gorm.DB) (bool, error) {
var userAuthorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
@@ -145,21 +179,31 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
switch grantType {
case "authorization_code":
return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier)
return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier)
case "refresh_token":
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret)
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret)
return "", accessToken, newRefreshToken, exp, err
default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
}
}
func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", "", 0, err
}
@@ -176,7 +220,11 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err = tx.
WithContext(ctx).
Preload("User").
First(&authorizationCodeMetaData, "code = ?", code).
Error
if err != nil {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
@@ -192,7 +240,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
if err != nil {
return "", "", "", 0, err
}
@@ -203,7 +251,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
}
// Generate a refresh token
refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
if err != nil {
return "", "", "", 0, err
}
@@ -213,19 +261,40 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
return "", "", "", 0, err
}
s.db.Delete(&authorizationCodeMetaData)
err = tx.
WithContext(ctx).
Delete(&authorizationCodeMetaData).
Error
if err != nil {
return "", "", "", 0, err
}
err = tx.Commit().Error
if err != nil {
return "", "", "", 0, err
}
return idToken, accessToken, refreshToken, 3600, nil
}
func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
if refreshToken == "" {
return "", "", 0, &common.OidcMissingRefreshTokenError{}
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
// Get the client to check if it's public
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", 0, err
}
@@ -243,7 +312,9 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = s.db.Preload("User").
err = tx.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
@@ -266,29 +337,53 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
if err != nil {
return "", "", 0, err
}
// Delete the used refresh token
s.db.Delete(&storedRefreshToken)
err = tx.
WithContext(ctx).
Delete(&storedRefreshToken).
Error
if err != nil {
return "", "", 0, err
}
err = tx.Commit().Error
if err != nil {
return "", "", 0, err
}
return accessToken, newRefreshToken, 3600, nil
}
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
return s.getClientInternal(ctx, clientID, s.db)
}
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
var client model.OidcClient
if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("CreatedBy").
Preload("AllowedUserGroups").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
func (s *OidcService) ListClients(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
var clients []model.OidcClient
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
query := s.db.
WithContext(ctx).
Preload("CreatedBy").
Model(&model.OidcClient{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern)
@@ -302,7 +397,7 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti
return clients, pagination, nil
}
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{
Name: input.Name,
CallbackURLs: input.CallbackURLs,
@@ -312,16 +407,31 @@ func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string)
PkceEnabled: input.IsPublic || input.PkceEnabled,
}
if err := s.db.Create(&client).Error; err != nil {
err := s.db.
WithContext(ctx).
Create(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err
}
@@ -331,29 +441,49 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
client.IsPublic = input.IsPublic
client.PkceEnabled = input.IsPublic || input.PkceEnabled
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
if err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) DeleteClient(clientID string) error {
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if err := s.db.Delete(&client).Error; err != nil {
err := s.db.
WithContext(ctx).
Where("id = ?", clientID).
Delete(&client).
Error
if err != nil {
return err
}
return nil
}
func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", err
}
@@ -368,16 +498,29 @@ func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
}
client.Secret = string(hashedSecret)
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return "", err
}
err = tx.Commit().Error
if err != nil {
return "", err
}
return clientSecret, nil
}
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
func (s *OidcService) GetClientLogo(ctx context.Context, clientID string) (string, string, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err := s.db.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", err
}
@@ -385,26 +528,36 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
return "", "", errors.New("image not found")
}
imageType := *client.ImageType
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
mimeType := utils.GetImageMimeType(imageType)
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
mimeType := utils.GetImageMimeType(*client.ImageType)
return imagePath, mimeType, nil
}
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader) error {
fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
return &common.FileTypeNotSupportedError{}
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
if err := utils.SaveFile(file, imagePath); err != nil {
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + fileType
err := utils.SaveFile(file, imagePath)
if err != nil {
return err
}
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err
}
@@ -416,16 +569,35 @@ func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHead
}
client.ImageType = &fileType
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
err = tx.Commit().Error
if err != nil {
return err
}
return nil
}
func (s *OidcService) DeleteClientLogo(clientID string) error {
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err
}
@@ -433,38 +605,72 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
return errors.New("image not found")
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
client.ImageType = nil
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
if err := os.Remove(imagePath); err != nil {
return err
}
client.ImageType = nil
if err := s.db.Save(&client).Error; err != nil {
err = tx.Commit().Error
if err != nil {
return err
}
return nil
}
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("User.UserGroups").
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
Error
if err != nil {
return nil, err
}
user := authorizedOidcClient.User
scope := authorizedOidcClient.Scope
scopes := strings.Split(authorizedOidcClient.Scope, " ")
claims := map[string]interface{}{
"sub": user.ID,
}
if strings.Contains(scope, "email") {
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue()
}
if strings.Contains(scope, "groups") {
if slices.Contains(scopes, "groups") {
userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups {
userGroups[i] = group.Name
@@ -477,17 +683,17 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": fmt.Sprintf("%s/api/users/%s/profile-picture.png", common.EnvConfig.AppURL, user.ID),
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if strings.Contains(scope, "profile") {
if slices.Contains(scopes, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID)
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
if err != nil {
return nil, err
}
@@ -505,15 +711,22 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
}
}
}
if strings.Contains(scope, "email") {
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
}
return claims, nil
}
func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
client, err = s.GetClient(id)
func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
client, err = s.getClientInternal(ctx, id, tx)
if err != nil {
return model.OidcClient{}, err
}
@@ -521,18 +734,37 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
// Fetch the user groups based on UserGroupIDs in input
var groups []model.UserGroup
if len(input.UserGroupIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil {
err = tx.
WithContext(ctx).
Where("id IN (?)", input.UserGroupIDs).
Find(&groups).
Error
if err != nil {
return model.OidcClient{}, err
}
}
// Replace the current user groups with the new set of user groups
if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil {
err = tx.
WithContext(ctx).
Model(&client).
Association("AllowedUserGroups").
Replace(groups)
if err != nil {
return model.OidcClient{}, err
}
// Save the updated client
if err := s.db.Save(&client).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
if err != nil {
return model.OidcClient{}, err
}
@@ -540,7 +772,7 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
}
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) {
func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (string, error) {
// If no ID token hint is provided, return an error
if input.IdTokenHint == "" {
return "", &common.TokenInvalidError{}
@@ -564,7 +796,12 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
// Check if the user has authorized the client before
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).Error; err != nil {
err = s.db.
WithContext(ctx).
Preload("Client").
First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).
Error
if err != nil {
return "", &common.OidcMissingAuthorizationError{}
}
@@ -582,7 +819,7 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
@@ -601,7 +838,11 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&oidcAuthorizationCode).
Error
if err != nil {
return "", err
}
@@ -647,7 +888,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{}
}
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
@@ -665,7 +906,11 @@ func (s *OidcService) createRefreshToken(clientID string, userID string, scope s
Scope: scope,
}
if err := s.db.Create(&m).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&m).
Error
if err != nil {
return "", err
}

View File

@@ -1,13 +1,15 @@
package service
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type UserGroupService struct {
@@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
return &UserGroupService{db: db, appConfigService: appConfigService}
}
func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{})
func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.
WithContext(ctx).
Preload("CustomClaims").
Model(&model.UserGroup{})
if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%")
@@ -42,26 +47,59 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte
return groups, response, err
}
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error
func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) {
return s.getInternal(ctx, id, s.db)
}
func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) {
err = tx.
WithContext(ctx).
Where("id = ?", id).
Preload("CustomClaims").
Preload("Users").
First(&group).
Error
return group, err
}
func (s *UserGroupService) Delete(id string) error {
func (s *UserGroupService) Delete(ctx context.Context, id string) error {
tx := s.db.Begin()
var group model.UserGroup
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", id).
First(&group).
Error
if err != nil {
return err
}
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
err = tx.Rollback().Error
if err != nil {
return err
}
return &common.LdapUserGroupUpdateError{}
}
return s.db.Delete(&group).Error
err = tx.
WithContext(ctx).
Delete(&group).
Error
if err != nil {
return err
}
return tx.Commit().Error
}
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
return s.createInternal(ctx, input, s.db)
}
func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) {
group = model.UserGroup{
FriendlyName: input.FriendlyName,
Name: input.Name,
@@ -71,7 +109,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
group.LdapID = &input.LdapID
}
if err := s.db.Preload("Users").Create(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("Users").
Create(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
@@ -80,8 +123,26 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
return group, nil
}
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
group, err = s.Get(id)
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
tx := s.db.Begin()
group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx)
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -94,7 +155,12 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
group.Name = input.Name
group.FriendlyName = input.FriendlyName
if err := s.db.Preload("Users").Save(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("Users").
Save(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
@@ -103,8 +169,26 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
return group, nil
}
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id)
func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) {
tx := s.db.Begin()
group, err = s.updateUsersInternal(ctx, id, userIds, tx)
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model
// Fetch the users based on the userIds
var users []model.User
if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id IN (?)", userIds).
Find(&users).
Error
if err != nil {
return model.UserGroup{}, err
}
}
// Replace the current users with the new set of users
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil {
err = tx.
WithContext(ctx).
Model(&group).
Association("Users").
Replace(users)
if err != nil {
return model.UserGroup{}, err
}
// Save the updated group
if err := s.db.Save(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&group).
Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) {
func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) {
// We only perform select queries here, so we can rollback in all cases
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("Users").
Where("id = ?", id).
First(&group).
Error
if err != nil {
return 0, err
}
return s.db.Model(&group).Association("Users").Count(), nil
count := tx.
WithContext(ctx).
Model(&group).
Association("Users").
Count()
return count, nil
}

View File

@@ -2,6 +2,7 @@ package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -12,7 +13,7 @@ import (
"time"
"github.com/google/uuid"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -20,7 +21,7 @@ import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
)
type UserService struct {
@@ -35,9 +36,9 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
}
func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
var users []model.User
query := s.db.Model(&model.User{})
query := s.db.WithContext(ctx).Model(&model.User{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
@@ -48,13 +49,23 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils
return users, pagination, err
}
func (s *UserService) GetUser(userID string) (model.User, error) {
func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) {
return s.getUserInternal(ctx, userID, s.db)
}
func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) {
var user model.User
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
err := tx.
WithContext(ctx).
Preload("UserGroups").
Preload("CustomClaims").
Where("id = ?", userID).
First(&user).
Error
return user, err
}
func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, error) {
func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
return nil, 0, &common.InvalidUUIDError{}
@@ -74,7 +85,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er
}
// If no custom picture exists, get the user's data for creating initials
user, err := s.GetUser(userID)
user, err := s.GetUser(ctx, userID)
if err != nil {
return nil, 0, err
}
@@ -115,9 +126,15 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil
}
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) {
var user model.User
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil {
err := s.db.
WithContext(ctx).
Preload("UserGroups").
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return nil, err
}
return user.UserGroups, nil
@@ -152,9 +169,21 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
return nil
}
func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
return s.db.Transaction(func(tx *gorm.DB) error {
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
})
}
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return err
}
@@ -165,14 +194,35 @@ func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
// Delete the profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) {
err = os.Remove(profilePicturePath)
if err != nil && !os.IsNotExist(err) {
return err
}
return s.db.Delete(&user).Error
return tx.WithContext(ctx).Delete(&user).Error
}
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
user, err := s.createUserInternal(ctx, input, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) {
user := model.User{
FirstName: input.FirstName,
LastName: input.LastName,
@@ -185,18 +235,47 @@ func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
user.LdapID = &input.LdapID
}
if err := s.db.Create(&user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.User{}, s.checkDuplicatedFields(user)
}
err := tx.WithContext(ctx).Create(&user).Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
return model.User{}, err
} else if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, allowLdapUpdate, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return model.User{}, err
}
@@ -214,24 +293,42 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
user.IsAdmin = updatedUser.IsAdmin
}
if err := s.db.Save(&user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return user, s.checkDuplicatedFields(user)
}
err = tx.
WithContext(ctx).
Save(&user).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
return user, err
} else if err != nil {
return user, err
}
return user, nil
}
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error {
func (s *UserService) RequestOneTimeAccessEmail(ctx context.Context, emailAddress, redirectPath string) error {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}
var user model.User
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil {
err := tx.
WithContext(ctx).
Where("email = ?", emailAddress).
First(&user).
Error
if err != nil {
// Do not return error if user not found to prevent email enumeration
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
@@ -240,22 +337,31 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
}
}
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute))
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, time.Now().Add(15*time.Minute), tx)
if err != nil {
return err
}
link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL)
linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken)
// Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath)
linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath)
err = tx.Commit().Error
if err != nil {
return err
}
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
err := SendEmail(s.emailService, email.Address{
innerCtx := context.Background()
link := common.EnvConfig.AppURL + "/lc"
linkWithCode := link + "/" + oneTimeAccessToken
// Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath)
linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath
}
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
Name: user.Username,
Email: user.Email,
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
@@ -263,18 +369,21 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
LoginLink: link,
LoginLinkWithCode: linkWithCode,
})
if err != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
if errInternal != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal)
}
}()
return nil
}
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
tokenLength := 16
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) {
return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db)
}
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) {
// If expires at is less than 15 minutes, use an 6 character token instead of 16
tokenLength := 16
if time.Until(expiresAt) <= 15*time.Minute {
tokenLength = 6
}
@@ -290,16 +399,27 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim
Token: randomString,
}
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil {
if err := tx.WithContext(ctx).Create(&oneTimeAccessToken).Error; err != nil {
return "", err
}
return oneTimeAccessToken.Token, nil
}
func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) {
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var oneTimeAccessToken model.OneTimeAccessToken
if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
err := tx.
WithContext(ctx).
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
First(&oneTimeAccessToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
}
@@ -310,19 +430,34 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
return model.User{}, "", err
}
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil {
err = tx.
WithContext(ctx).
Delete(&oneTimeAccessToken).
Error
if err != nil {
return model.User{}, "", err
}
if ipAddress != "" && userAgent != "" {
s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{})
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return oneTimeAccessToken.User, accessToken, nil
}
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) {
user, err = s.GetUser(id)
func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
user, err = s.getUserInternal(ctx, id, tx)
if err != nil {
return model.User{}, err
}
@@ -330,27 +465,49 @@ func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user m
// Fetch the groups based on userGroupIds
var groups []model.UserGroup
if len(userGroupIds) > 0 {
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil {
err = tx.
WithContext(ctx).
Where("id IN (?)", userGroupIds).
Find(&groups).
Error
if err != nil {
return model.User{}, err
}
}
// Replace the current groups with the new set of groups
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil {
err = tx.
WithContext(ctx).
Model(&user).
Association("UserGroups").
Replace(groups)
if err != nil {
return model.User{}, err
}
// Save the updated user
if err := s.db.Save(&user).Error; err != nil {
err = tx.WithContext(ctx).Save(&user).Error
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
return model.User{}, "", err
}
if userCount > 1 {
@@ -365,7 +522,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
IsAdmin: true,
}
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
return model.User{}, "", err
}
@@ -378,16 +535,39 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
return model.User{}, "", err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return user, token, nil
}
func (s *UserService) checkDuplicatedFields(user model.User) error {
var existingUser model.User
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error {
var result struct {
Found bool
}
err := tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "email"}
}
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
err = tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "username"}
}

View File

@@ -1,16 +1,19 @@
package service
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type WebAuthnService struct {
@@ -43,15 +46,31 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService}
}
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
s.updateWebAuthnConfig()
var user model.User
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("Credentials").
Find(&user, "id = ?", userID).
Error
if err != nil {
tx.Rollback()
return nil, err
}
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
options, session, err := s.webAuthn.BeginRegistration(
&user,
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
)
if err != nil {
return nil, err
}
@@ -62,7 +81,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
UserVerification: string(session.UserVerification),
}
if err := s.db.Create(&sessionToStore).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
@@ -73,9 +101,19 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}, nil
}
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
@@ -86,7 +124,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
}
var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
err = tx.
WithContext(ctx).
Find(&user, "id = ?", userID).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
@@ -108,7 +150,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState,
}
if err := s.db.Create(&credentialToStore).Error; err != nil {
err = tx.
WithContext(ctx).
Create(&credentialToStore).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
err = tx.Commit().Error
if err != nil {
return model.WebauthnCredential{}, err
}
@@ -125,7 +176,7 @@ func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
return "New Passkey" // Default fallback
}
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) {
options, session, err := s.webAuthn.BeginDiscoverableLogin()
if err != nil {
return nil, err
@@ -137,7 +188,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
UserVerification: string(session.UserVerification),
}
if err := s.db.Create(&sessionToStore).Error; err != nil {
err = s.db.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err
}
@@ -148,9 +203,19 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil
}
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.User{}, "", err
}
@@ -160,9 +225,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
}
var user *model.User
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
return nil, err
_, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
innerErr := tx.
WithContext(ctx).
Preload("Credentials").
First(&user, "id = ?", string(userHandle)).
Error
if innerErr != nil {
return nil, innerErr
}
return user, nil
}, session, credentialAssertionData)
@@ -176,41 +246,70 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
return model.User{}, "", err
}
s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID)
s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx)
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return *user, token, nil
}
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) {
var credentials []model.WebauthnCredential
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
err := s.db.
WithContext(ctx).
Find(&credentials, "user_id = ?", userID).
Error
if err != nil {
return nil, err
}
return credentials, nil
}
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
var credential model.WebauthnCredential
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
return err
}
if err := s.db.Delete(&credential).Error; err != nil {
return err
func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", credentialID, userID).
Delete(&model.WebauthnCredential{}).
Error
if err != nil {
return fmt.Errorf("failed to delete record: %w", err)
}
return nil
}
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
// This is a no-op if the transaction has been committed already
tx.Rollback()
}()
var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ? AND user_id = ?", credentialID, userID).
First(&credential).
Error
if err != nil {
return credential, err
}
credential.Name = name
if err := s.db.Save(&credential).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&credential).
Error
if err != nil {
return credential, err
}
err = tx.Commit().Error
if err != nil {
return credential, err
}