diff --git a/backend/internal/bootstrap/bootstrap.go b/backend/internal/bootstrap/bootstrap.go index c53e10e6..16b2fcaa 100644 --- a/backend/internal/bootstrap/bootstrap.go +++ b/backend/internal/bootstrap/bootstrap.go @@ -1,19 +1,23 @@ package bootstrap import ( + "context" + _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/pocket-id/pocket-id/backend/internal/service" ) func Bootstrap() { + ctx := context.TODO() + initApplicationImages() migrateConfigDBConnstring() db := newDatabase() - appConfigService := service.NewAppConfigService(db) + appConfigService := service.NewAppConfigService(ctx, db) migrateKey() - initRouter(db, appConfigService) + initRouter(ctx, db, appConfigService) } diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index f87d4022..4cec75b5 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -1,6 +1,7 @@ package bootstrap import ( + "context" "log" "net" "time" @@ -19,7 +20,7 @@ import ( // This is used to register additional controllers for tests var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService) -func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { +func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) { // Set the appropriate Gin mode based on the environment switch common.EnvConfig.AppEnv { case "production": @@ -36,10 +37,10 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { // Initialize services emailService, err := service.NewEmailService(appConfigService, db) if err != nil { - log.Fatalf("Unable to create email service: %s", err) + log.Fatalf("Unable to create email service: %v", err) } - geoLiteService := service.NewGeoLiteService() + geoLiteService := service.NewGeoLiteService(ctx) auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService) jwtService := service.NewJwtService(appConfigService) webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) @@ -57,9 +58,9 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60)) - job.RegisterLdapJobs(ldapService, appConfigService) - job.RegisterDbCleanupJobs(db) - job.RegisterFileCleanupJobs(db) + job.RegisterLdapJobs(ctx, ldapService, appConfigService) + job.RegisterDbCleanupJobs(ctx, db) + job.RegisterFileCleanupJobs(ctx, db) // Initialize middleware for specific routes authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService) diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index b593795f..5307431c 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -17,7 +17,7 @@ type AlreadyInUseError struct { } func (e *AlreadyInUseError) Error() string { - return fmt.Sprintf("%s is already in use", e.Property) + return e.Property + " is already in use" } func (e *AlreadyInUseError) HttpStatusCode() int { return 400 } diff --git a/backend/internal/controller/api_key_controller.go b/backend/internal/controller/api_key_controller.go index d0e1223d..66786e12 100644 --- a/backend/internal/controller/api_key_controller.go +++ b/backend/internal/controller/api_key_controller.go @@ -53,7 +53,7 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) { return } - apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest) + apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, sortedPaginationRequest) if err != nil { _ = ctx.Error(err) return @@ -87,7 +87,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) { return } - apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input) + apiKey, token, err := c.apiKeyService.CreateApiKey(ctx.Request.Context(), userID, input) if err != nil { _ = ctx.Error(err) return @@ -116,7 +116,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) { userID := ctx.GetString("userID") apiKeyID := ctx.Param("id") - if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil { + if err := c.apiKeyService.RevokeApiKey(ctx.Request.Context(), userID, apiKeyID); err != nil { _ = ctx.Error(err) return } diff --git a/backend/internal/controller/app_config_controller.go b/backend/internal/controller/app_config_controller.go index deb3a1e1..8e3dc89f 100644 --- a/backend/internal/controller/app_config_controller.go +++ b/backend/internal/controller/app_config_controller.go @@ -1,7 +1,6 @@ package controller import ( - "fmt" "net/http" "strconv" @@ -61,7 +60,7 @@ type AppConfigController struct { // @Failure 500 {object} object "{"error": "error message"}" // @Router /application-configuration [get] func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { - configuration, err := acc.appConfigService.ListAppConfig(false) + configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), false) if err != nil { _ = c.Error(err) return @@ -73,7 +72,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { return } - c.JSON(200, configVariablesDto) + c.JSON(http.StatusOK, configVariablesDto) } // listAllAppConfigHandler godoc @@ -86,7 +85,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { // @Security BearerAuth // @Router /application-configuration/all [get] func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { - configuration, err := acc.appConfigService.ListAppConfig(true) + configuration, err := acc.appConfigService.ListAppConfig(c.Request.Context(), true) if err != nil { _ = c.Error(err) return @@ -98,7 +97,7 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { return } - c.JSON(200, configVariablesDto) + c.JSON(http.StatusOK, configVariablesDto) } // updateAppConfigHandler godoc @@ -118,7 +117,7 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) { return } - savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input) + savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(c.Request.Context(), input) if err != nil { _ = c.Error(err) return @@ -253,7 +252,7 @@ func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) { // getImage is a helper function to serve image files func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) { - imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) + imagePath := common.EnvConfig.UploadPath + "/application-images/" + name + "." + imageType mimeType := utils.GetImageMimeType(imageType) c.Header("Content-Type", mimeType) @@ -268,7 +267,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol return } - err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) + err = acc.appConfigService.UpdateImage(c.Request.Context(), file, imageName, oldImageType) if err != nil { _ = c.Error(err) return @@ -285,7 +284,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol // @Security BearerAuth // @Router /api/application-configuration/sync-ldap [post] func (acc *AppConfigController) syncLdapHandler(c *gin.Context) { - err := acc.ldapService.SyncAll() + err := acc.ldapService.SyncAll(c.Request.Context()) if err != nil { _ = c.Error(err) return @@ -304,7 +303,7 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) { func (acc *AppConfigController) testEmailHandler(c *gin.Context) { userID := c.GetString("userID") - err := acc.emailService.SendTestEmail(userID) + err := acc.emailService.SendTestEmail(c.Request.Context(), userID) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/controller/audit_log_controller.go b/backend/internal/controller/audit_log_controller.go index a9465a4e..28c03f57 100644 --- a/backend/internal/controller/audit_log_controller.go +++ b/backend/internal/controller/audit_log_controller.go @@ -42,7 +42,9 @@ type AuditLogController struct { // @Router /api/audit-logs [get] func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { var sortedPaginationRequest utils.SortedPaginationRequest - if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil { + + err := c.ShouldBindQuery(&sortedPaginationRequest) + if err != nil { _ = c.Error(err) return } @@ -50,7 +52,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { userID := c.GetString("userID") // Fetch audit logs for the user - logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, sortedPaginationRequest) + logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, sortedPaginationRequest) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/controller/custom_claim_controller.go b/backend/internal/controller/custom_claim_controller.go index 40168c3a..e2f67c47 100644 --- a/backend/internal/controller/custom_claim_controller.go +++ b/backend/internal/controller/custom_claim_controller.go @@ -41,7 +41,7 @@ type CustomClaimController struct { // @Security BearerAuth // @Router /api/custom-claims/suggestions [get] func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) { - claims, err := ccc.customClaimService.GetSuggestions() + claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context()) if err != nil { _ = c.Error(err) return @@ -69,7 +69,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex } userId := c.Param("userId") - claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input) + claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(c.Request.Context(), userId, input) if err != nil { _ = c.Error(err) return @@ -104,7 +104,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.C } userGroupId := c.Param("userGroupId") - claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userGroupId, input) + claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(c.Request.Context(), userGroupId, input) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 201ec39e..d4b44022 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -69,7 +69,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { return } - code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) + code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) if err != nil { _ = c.Error(err) return @@ -100,7 +100,7 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex return } - hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope) + hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(c.Request.Context(), input.ClientID, c.GetString("userID"), input.Scope) if err != nil { _ = c.Error(err) return @@ -153,6 +153,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { } idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens( + c.Request.Context(), input.Code, input.GrantType, clientID, @@ -216,7 +217,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) { _ = c.Error(&common.TokenInvalidError{}) return } - claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientID[0]) + claims, err := oc.oidcService.GetUserClaimsForClient(c.Request.Context(), userID, clientID[0]) if err != nil { _ = c.Error(err) return @@ -254,7 +255,7 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) { } } - callbackURL, err := oc.oidcService.ValidateEndSession(input, c.GetString("userID")) + callbackURL, err := oc.oidcService.ValidateEndSession(c.Request.Context(), input, c.GetString("userID")) if err != nil { // If the validation fails, the user has to confirm the logout manually and doesn't get redirected log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err) @@ -300,7 +301,7 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) { // @Router /api/oidc/clients/{id}/meta [get] func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) { clientId := c.Param("id") - client, err := oc.oidcService.GetClient(clientId) + client, err := oc.oidcService.GetClient(c.Request.Context(), clientId) if err != nil { _ = c.Error(err) return @@ -327,7 +328,7 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) { // @Router /api/oidc/clients/{id} [get] func (oc *OidcController) getClientHandler(c *gin.Context) { clientId := c.Param("id") - client, err := oc.oidcService.GetClient(clientId) + client, err := oc.oidcService.GetClient(c.Request.Context(), clientId) if err != nil { _ = c.Error(err) return @@ -363,7 +364,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { return } - clients, pagination, err := oc.oidcService.ListClients(searchTerm, sortedPaginationRequest) + clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, sortedPaginationRequest) if err != nil { _ = c.Error(err) return @@ -398,7 +399,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) { return } - client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) + client, err := oc.oidcService.CreateClient(c.Request.Context(), input, c.GetString("userID")) if err != nil { _ = c.Error(err) return @@ -422,7 +423,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) { // @Security BearerAuth // @Router /api/oidc/clients/{id} [delete] func (oc *OidcController) deleteClientHandler(c *gin.Context) { - err := oc.oidcService.DeleteClient(c.Param("id")) + err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id")) if err != nil { _ = c.Error(err) return @@ -449,7 +450,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) { return } - client, err := oc.oidcService.UpdateClient(c.Param("id"), input) + client, err := oc.oidcService.UpdateClient(c.Request.Context(), c.Param("id"), input) if err != nil { _ = c.Error(err) return @@ -474,7 +475,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) { // @Security BearerAuth // @Router /api/oidc/clients/{id}/secret [post] func (oc *OidcController) createClientSecretHandler(c *gin.Context) { - secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) + secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id")) if err != nil { _ = c.Error(err) return @@ -494,7 +495,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) { // @Success 200 {file} binary "Logo image" // @Router /api/oidc/clients/{id}/logo [get] func (oc *OidcController) getClientLogoHandler(c *gin.Context) { - imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) + imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id")) if err != nil { _ = c.Error(err) return @@ -521,7 +522,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { return } - err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) + err = oc.oidcService.UpdateClientLogo(c.Request.Context(), c.Param("id"), file) if err != nil { _ = c.Error(err) return @@ -539,7 +540,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { // @Security BearerAuth // @Router /api/oidc/clients/{id}/logo [delete] func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { - err := oc.oidcService.DeleteClientLogo(c.Param("id")) + err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id")) if err != nil { _ = c.Error(err) return @@ -566,7 +567,7 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) { return } - oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input) + oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Request.Context(), c.Param("id"), input) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index e6c2538c..fe62d326 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -65,7 +65,7 @@ type UserController struct { // @Router /api/users/{id}/groups [get] func (uc *UserController) getUserGroupsHandler(c *gin.Context) { userID := c.Param("id") - groups, err := uc.userService.GetUserGroups(userID) + groups, err := uc.userService.GetUserGroups(c.Request.Context(), userID) if err != nil { _ = c.Error(err) return @@ -99,7 +99,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { return } - users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest) + users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, sortedPaginationRequest) if err != nil { _ = c.Error(err) return @@ -125,7 +125,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { // @Success 200 {object} dto.UserDto // @Router /api/users/{id} [get] func (uc *UserController) getUserHandler(c *gin.Context) { - user, err := uc.userService.GetUser(c.Param("id")) + user, err := uc.userService.GetUser(c.Request.Context(), c.Param("id")) if err != nil { _ = c.Error(err) return @@ -147,7 +147,7 @@ func (uc *UserController) getUserHandler(c *gin.Context) { // @Success 200 {object} dto.UserDto // @Router /api/users/me [get] func (uc *UserController) getCurrentUserHandler(c *gin.Context) { - user, err := uc.userService.GetUser(c.GetString("userID")) + user, err := uc.userService.GetUser(c.Request.Context(), c.GetString("userID")) if err != nil { _ = c.Error(err) return @@ -170,7 +170,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) { // @Success 204 "No Content" // @Router /api/users/{id} [delete] func (uc *UserController) deleteUserHandler(c *gin.Context) { - if err := uc.userService.DeleteUser(c.Param("id"), false); err != nil { + if err := uc.userService.DeleteUser(c.Request.Context(), c.Param("id"), false); err != nil { _ = c.Error(err) return } @@ -192,7 +192,7 @@ func (uc *UserController) createUserHandler(c *gin.Context) { return } - user, err := uc.userService.CreateUser(input) + user, err := uc.userService.CreateUser(c.Request.Context(), input) if err != nil { _ = c.Error(err) return @@ -245,7 +245,7 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) { userID := c.Param("id") - picture, size, err := uc.userService.GetProfilePicture(userID) + picture, size, err := uc.userService.GetProfilePicture(c.Request.Context(), userID) if err != nil { _ = c.Error(err) return @@ -332,7 +332,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo if own { input.UserID = c.GetString("userID") } - token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) + token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt) if err != nil { _ = c.Error(err) return @@ -364,7 +364,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) { return } - err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath) + err := uc.userService.RequestOneTimeAccessEmail(c.Request.Context(), input.Email, input.RedirectPath) if err != nil { _ = c.Error(err) return @@ -381,7 +381,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) { // @Success 200 {object} dto.UserDto // @Router /api/one-time-access-token/{token} [post] func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { - user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent()) + user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), c.ClientIP(), c.Request.UserAgent()) if err != nil { _ = c.Error(err) return @@ -406,7 +406,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { // @Success 200 {object} dto.UserDto // @Router /api/one-time-access-token/setup [post] func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { - user, token, err := uc.userService.SetupInitialAdmin() + user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context()) if err != nil { _ = c.Error(err) return @@ -439,7 +439,7 @@ func (uc *UserController) updateUserGroups(c *gin.Context) { return } - user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds) + user, err := uc.userService.UpdateUserGroups(c.Request.Context(), c.Param("id"), input.UserGroupIds) if err != nil { _ = c.Error(err) return @@ -469,7 +469,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { userID = c.Param("id") } - user, err := uc.userService.UpdateUser(userID, input, updateOwnUser, false) + user, err := uc.userService.UpdateUser(c.Request.Context(), userID, input, updateOwnUser, false) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/controller/user_group_controller.go b/backend/internal/controller/user_group_controller.go index 04865990..8f47b40b 100644 --- a/backend/internal/controller/user_group_controller.go +++ b/backend/internal/controller/user_group_controller.go @@ -47,6 +47,8 @@ type UserGroupController struct { // @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount] // @Router /api/user-groups [get] func (ugc *UserGroupController) list(c *gin.Context) { + ctx := c.Request.Context() + searchTerm := c.Query("search") var sortedPaginationRequest utils.SortedPaginationRequest if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil { @@ -54,7 +56,7 @@ func (ugc *UserGroupController) list(c *gin.Context) { return } - groups, pagination, err := ugc.UserGroupService.List(searchTerm, sortedPaginationRequest) + groups, pagination, err := ugc.UserGroupService.List(ctx, searchTerm, sortedPaginationRequest) if err != nil { _ = c.Error(err) return @@ -68,7 +70,7 @@ func (ugc *UserGroupController) list(c *gin.Context) { _ = c.Error(err) return } - groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID) + groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(ctx, group.ID) if err != nil { _ = c.Error(err) return @@ -93,7 +95,7 @@ func (ugc *UserGroupController) list(c *gin.Context) { // @Security BearerAuth // @Router /api/user-groups/{id} [get] func (ugc *UserGroupController) get(c *gin.Context) { - group, err := ugc.UserGroupService.Get(c.Param("id")) + group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id")) if err != nil { _ = c.Error(err) return @@ -125,7 +127,7 @@ func (ugc *UserGroupController) create(c *gin.Context) { return } - group, err := ugc.UserGroupService.Create(input) + group, err := ugc.UserGroupService.Create(c.Request.Context(), input) if err != nil { _ = c.Error(err) return @@ -158,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) { return } - group, err := ugc.UserGroupService.Update(c.Param("id"), input, false) + group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input, false) if err != nil { _ = c.Error(err) return @@ -184,7 +186,7 @@ func (ugc *UserGroupController) update(c *gin.Context) { // @Security BearerAuth // @Router /api/user-groups/{id} [delete] func (ugc *UserGroupController) delete(c *gin.Context) { - if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil { + if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil { _ = c.Error(err) return } @@ -210,7 +212,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) { return } - group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs) + group, err := ugc.UserGroupService.UpdateUsers(c.Request.Context(), c.Param("id"), input.UserIDs) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go index 010dd994..75668326 100644 --- a/backend/internal/controller/webauthn_controller.go +++ b/backend/internal/controller/webauthn_controller.go @@ -37,7 +37,7 @@ type WebauthnController struct { func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { userID := c.GetString("userID") - options, err := wc.webAuthnService.BeginRegistration(userID) + options, err := wc.webAuthnService.BeginRegistration(c.Request.Context(), userID) if err != nil { _ = c.Error(err) return @@ -55,7 +55,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { } userID := c.GetString("userID") - credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) + credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request) if err != nil { _ = c.Error(err) return @@ -71,7 +71,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { } func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { - options, err := wc.webAuthnService.BeginLogin() + options, err := wc.webAuthnService.BeginLogin(c.Request.Context()) if err != nil { _ = c.Error(err) return @@ -94,7 +94,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { return } - user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent()) + user, token, err := wc.webAuthnService.VerifyLogin(c.Request.Context(), sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent()) if err != nil { _ = c.Error(err) return @@ -114,7 +114,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { userID := c.GetString("userID") - credentials, err := wc.webAuthnService.ListCredentials(userID) + credentials, err := wc.webAuthnService.ListCredentials(c.Request.Context(), userID) if err != nil { _ = c.Error(err) return @@ -133,7 +133,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { userID := c.GetString("userID") credentialID := c.Param("id") - err := wc.webAuthnService.DeleteCredential(userID, credentialID) + err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID) if err != nil { _ = c.Error(err) return @@ -152,7 +152,7 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) { return } - credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) + credential, err := wc.webAuthnService.UpdateCredential(c.Request.Context(), userID, credentialID, input.Name) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/job/db_cleanup_job.go b/backend/internal/job/db_cleanup_job.go index 0e63bc34..b72d8f34 100644 --- a/backend/internal/job/db_cleanup_job.go +++ b/backend/internal/job/db_cleanup_job.go @@ -1,16 +1,18 @@ package job import ( + "context" "log" "time" "github.com/go-co-op/gocron/v2" + "gorm.io/gorm" + "github.com/pocket-id/pocket-id/backend/internal/model" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" - "gorm.io/gorm" ) -func RegisterDbCleanupJobs(db *gorm.DB) { +func RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) { scheduler, err := gocron.NewScheduler() if err != nil { log.Fatalf("Failed to create a new scheduler: %s", err) @@ -18,11 +20,11 @@ func RegisterDbCleanupJobs(db *gorm.DB) { jobs := &DbCleanupJobs{db: db} - registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions) - registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens) - registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes) - registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens) - registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs) + registerJob(ctx, scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions) + registerJob(ctx, scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens) + registerJob(ctx, scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes) + registerJob(ctx, scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens) + registerJob(ctx, scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs) scheduler.Start() } @@ -31,26 +33,41 @@ type DbCleanupJobs struct { } // ClearWebauthnSessions deletes WebAuthn sessions that have expired -func (j *DbCleanupJobs) clearWebauthnSessions() error { - return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).Error +func (j *DbCleanupJobs) clearWebauthnSessions(ctx context.Context) error { + return j.db. + WithContext(ctx). + Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())). + Error } // ClearOneTimeAccessTokens deletes one-time access tokens that have expired -func (j *DbCleanupJobs) clearOneTimeAccessTokens() error { - return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error +func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error { + return j.db. + WithContext(ctx). + Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())). + Error } // ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired -func (j *DbCleanupJobs) clearOidcAuthorizationCodes() error { - return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error +func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error { + return j.db. + WithContext(ctx). + Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())). + Error } // ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired -func (j *DbCleanupJobs) clearOidcRefreshTokens() error { - return j.db.Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error +func (j *DbCleanupJobs) clearOidcRefreshTokens(ctx context.Context) error { + return j.db. + WithContext(ctx). + Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())). + Error } // ClearAuditLogs deletes audit logs older than 90 days -func (j *DbCleanupJobs) clearAuditLogs() error { - return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error +func (j *DbCleanupJobs) clearAuditLogs(ctx context.Context) error { + return j.db. + WithContext(ctx). + Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))). + Error } diff --git a/backend/internal/job/file_cleanup_job.go b/backend/internal/job/file_cleanup_job.go index d275e6f6..13b62cf9 100644 --- a/backend/internal/job/file_cleanup_job.go +++ b/backend/internal/job/file_cleanup_job.go @@ -1,6 +1,7 @@ package job import ( + "context" "fmt" "log" "os" @@ -8,12 +9,13 @@ import ( "strings" "github.com/go-co-op/gocron/v2" + "gorm.io/gorm" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" - "gorm.io/gorm" ) -func RegisterFileCleanupJobs(db *gorm.DB) { +func RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) { scheduler, err := gocron.NewScheduler() if err != nil { log.Fatalf("Failed to create a new scheduler: %s", err) @@ -21,7 +23,7 @@ func RegisterFileCleanupJobs(db *gorm.DB) { jobs := &FileCleanupJobs{db: db} - registerJob(scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures) + registerJob(ctx, scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures) scheduler.Start() } @@ -31,9 +33,13 @@ type FileCleanupJobs struct { } // ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials -func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures() error { +func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context) error { var users []model.User - if err := j.db.Find(&users).Error; err != nil { + err := j.db. + WithContext(ctx). + Find(&users). + Error + if err != nil { return fmt.Errorf("failed to fetch users: %w", err) } diff --git a/backend/internal/job/job.go b/backend/internal/job/job.go index 628d4926..2bd9a2fe 100644 --- a/backend/internal/job/job.go +++ b/backend/internal/job/job.go @@ -1,16 +1,18 @@ package job import ( + "context" "log" "github.com/go-co-op/gocron/v2" "github.com/google/uuid" ) -func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) { +func registerJob(ctx context.Context, scheduler gocron.Scheduler, name string, interval string, job func(ctx context.Context) error) { _, err := scheduler.NewJob( gocron.CronJob(interval, false), gocron.NewTask(job), + gocron.WithContext(ctx), gocron.WithEventListeners( gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) { log.Printf("Job %q run successfully", name) diff --git a/backend/internal/job/ldap_job.go b/backend/internal/job/ldap_job.go index 2b9dc8f4..c03f7d51 100644 --- a/backend/internal/job/ldap_job.go +++ b/backend/internal/job/ldap_job.go @@ -1,6 +1,7 @@ package job import ( + "context" "log" "github.com/go-co-op/gocron/v2" @@ -12,28 +13,30 @@ type LdapJobs struct { appConfigService *service.AppConfigService } -func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *service.AppConfigService) { +func RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) { jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService} scheduler, err := gocron.NewScheduler() if err != nil { - log.Fatalf("Failed to create a new scheduler: %s", err) + log.Fatalf("Failed to create a new scheduler: %v", err) } // Register the job to run every hour - registerJob(scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap) + registerJob(ctx, scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap) // Run the job immediately on startup - if err := jobs.syncLdap(); err != nil { - log.Printf("Failed to sync LDAP: %s", err) + err = jobs.syncLdap(ctx) + if err != nil { + log.Printf("Failed to sync LDAP: %v", err) } scheduler.Start() } -func (j *LdapJobs) syncLdap() error { - if j.appConfigService.DbConfig.LdapEnabled.IsTrue() { - return j.ldapService.SyncAll() +func (j *LdapJobs) syncLdap(ctx context.Context) error { + if !j.appConfigService.DbConfig.LdapEnabled.IsTrue() { + return nil } - return nil + + return j.ldapService.SyncAll(ctx) } diff --git a/backend/internal/middleware/api_key_auth.go b/backend/internal/middleware/api_key_auth.go index 4464c5f3..0741cf34 100644 --- a/backend/internal/middleware/api_key_auth.go +++ b/backend/internal/middleware/api_key_auth.go @@ -36,7 +36,7 @@ func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc { func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) { apiKey := c.GetHeader("X-API-KEY") - user, err := m.apiKeyService.ValidateApiKey(apiKey) + user, err := m.apiKeyService.ValidateApiKey(c.Request.Context(), apiKey) if err != nil { return "", false, &common.NotSignedInError{} } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index d9bc0179..50b04e5c 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -1,8 +1,8 @@ package service import ( + "context" "errors" - "log" "time" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" @@ -12,6 +12,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/utils" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type ApiKeyService struct { @@ -22,8 +23,11 @@ func NewApiKeyService(db *gorm.DB) *ApiKeyService { return &ApiKeyService{db: db} } -func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) { - query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{}) +func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) { + query := s.db. + WithContext(ctx). + Where("user_id = ?", userID). + Model(&model.ApiKey{}) var apiKeys []model.ApiKey pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys) @@ -34,7 +38,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils return apiKeys, pagination, nil } -func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) { +func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) { // Check if expiration is in the future if !input.ExpiresAt.ToTime().After(time.Now()) { return model.ApiKey{}, "", &common.APIKeyExpirationDateError{} @@ -54,7 +58,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) ( UserID: userID, } - if err := s.db.Create(&apiKey).Error; err != nil { + err = s.db. + WithContext(ctx). + Create(&apiKey). + Error + if err != nil { return model.ApiKey{}, "", err } @@ -62,29 +70,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) ( return apiKey, token, nil } -func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error { +func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error { var apiKey model.ApiKey - if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil { + err := s.db. + WithContext(ctx). + Where("id = ? AND user_id = ?", apiKeyID, userID). + Delete(&apiKey). + Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return &common.APIKeyNotFoundError{} } return err } - return s.db.Delete(&apiKey).Error + return nil } -func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) { +func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) { if apiKey == "" { return model.User{}, &common.NoAPIKeyProvidedError{} } - var key model.ApiKey + now := time.Now() hashedKey := utils.CreateSha256Hash(apiKey) - if err := s.db.Preload("User").Where("key = ? AND expires_at > ?", - hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil { - + var key model.ApiKey + err := s.db. + WithContext(ctx). + Model(&model.ApiKey{}). + Clauses(clause.Returning{}). + Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)). + Updates(&model.ApiKey{ + LastUsedAt: utils.Ptr(datatype.DateTime(now)), + }). + Preload("User"). + First(&key). + Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return model.User{}, &common.InvalidAPIKeyError{} } @@ -92,12 +115,5 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) { return model.User{}, err } - // Update last used time - now := datatype.DateTime(time.Now()) - key.LastUsedAt = &now - if err := s.db.Save(&key).Error; err != nil { - log.Printf("Failed to update last used time: %v", err) - } - return key.User, nil } diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index 1aea718b..b0c72bf7 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -1,7 +1,8 @@ package service import ( - "fmt" + "context" + "errors" "log" "mime/multipart" "os" @@ -19,12 +20,14 @@ type AppConfigService struct { db *gorm.DB } -func NewAppConfigService(db *gorm.DB) *AppConfigService { +func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService { service := &AppConfigService{ DbConfig: &defaultDbConfig, db: db, } - if err := service.InitDbConfig(); err != nil { + + err := service.InitDbConfig(ctx) + if err != nil { log.Fatalf("Failed to initialize app config service: %v", err) } @@ -197,17 +200,24 @@ var defaultDbConfig = model.AppConfig{ }, } -func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { +func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { if common.EnvConfig.UiConfigDisabled { return nil, &common.UiConfigDisabledError{} } tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + + var err error + rt := reflect.ValueOf(input).Type() rv := reflect.ValueOf(input) - var savedConfigVariables []model.AppConfigVariable - for i := 0; i < rt.NumField(); i++ { + savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField()) + for i := range rt.NumField() { field := rt.Field(i) key := field.Tag.Get("json") value := rv.FieldByName(field.Name).String() @@ -220,32 +230,47 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode } var appConfigVariable model.AppConfigVariable - if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil { - tx.Rollback() + err = tx. + WithContext(ctx). + First(&appConfigVariable, "key = ? AND is_internal = false", key). + Error + if err != nil { return nil, err } appConfigVariable.Value = value - if err := tx.Save(&appConfigVariable).Error; err != nil { - tx.Rollback() + err = tx. + WithContext(ctx). + Save(&appConfigVariable). + Error + if err != nil { return nil, err } savedConfigVariables = append(savedConfigVariables, appConfigVariable) } - tx.Commit() + err = tx.Commit().Error + if err != nil { + return nil, err + } - if err := s.LoadDbConfigFromDb(); err != nil { + err = s.LoadDbConfigFromDb() + if err != nil { return nil, err } return savedConfigVariables, nil } -func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error { - key := fmt.Sprintf("%sImageType", imageName) - err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error +func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error { + key := imageName + "ImageType" + err := s.db. + WithContext(ctx). + Model(&model.AppConfigVariable{}). + Where("key = ?", key). + Update("value", fileType). + Error if err != nil { return err } @@ -253,14 +278,17 @@ func (s *AppConfigService) UpdateImageType(imageName string, fileType string) er return s.LoadDbConfigFromDb() } -func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) { - var configuration []model.AppConfigVariable - var err error - +func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) { if showAll { - err = s.db.Find(&configuration).Error + err = s.db. + WithContext(ctx). + Find(&configuration). + Error } else { - err = s.db.Find(&configuration, "is_public = true").Error + err = s.db. + WithContext(ctx). + Find(&configuration, "is_public = true"). + Error } if err != nil { @@ -271,7 +299,6 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl if common.EnvConfig.UiConfigDisabled { // Set the value to the environment variable if the UI config is disabled configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue) - } else if configuration[i].Value == "" && configuration[i].DefaultValue != "" { // Set the value to the default value if it is empty configuration[i].Value = configuration[i].DefaultValue @@ -281,7 +308,7 @@ func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariabl return configuration, nil } -func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error { +func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) { fileType := utils.GetFileExtension(uploadedFile.Filename) mimeType := utils.GetImageMimeType(fileType) if mimeType == "" { @@ -290,19 +317,22 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image // Delete the old image if it has a different file type if fileType != oldImageType { - oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType) - if err := os.Remove(oldImagePath); err != nil { + oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType + err = os.Remove(oldImagePath) + if err != nil { return err } } - imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType) - if err := utils.SaveFile(uploadedFile, imagePath); err != nil { + imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType + err = utils.SaveFile(uploadedFile, imagePath) + if err != nil { return err } // Update the file type in the database - if err := s.UpdateImageType(imageName, fileType); err != nil { + err = s.updateImageType(ctx, imageName, fileType) + if err != nil { return err } @@ -312,33 +342,58 @@ func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, image // InitDbConfig creates the default configuration values in the database if they do not exist, // updates existing configurations if they differ from the default, and deletes any configurations // that are not in the default configuration. -func (s *AppConfigService) InitDbConfig() error { +func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + // Reflect to get the underlying value of DbConfig and its default configuration defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig) defaultKeys := make(map[string]struct{}) // Iterate over the fields of DbConfig - for i := 0; i < defaultConfigReflectValue.NumField(); i++ { + for i := range defaultConfigReflectValue.NumField() { defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable) defaultKeys[defaultConfigVar.Key] = struct{}{} var storedConfigVar model.AppConfigVariable - if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil { + err = tx. + WithContext(ctx). + First(&storedConfigVar, "key = ?", defaultConfigVar.Key). + Error + if errors.Is(err, gorm.ErrRecordNotFound) { // If the configuration does not exist, create it - if err := s.db.Create(&defaultConfigVar).Error; err != nil { + err = tx. + WithContext(ctx). + Create(&defaultConfigVar). + Error + if err != nil { return err } continue + } else if err != nil { + return err } // Update existing configuration if it differs from the default - if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue { + if storedConfigVar.Type != defaultConfigVar.Type || + storedConfigVar.IsPublic != defaultConfigVar.IsPublic || + storedConfigVar.IsInternal != defaultConfigVar.IsInternal || + storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue { + // Set values storedConfigVar.Type = defaultConfigVar.Type storedConfigVar.IsPublic = defaultConfigVar.IsPublic storedConfigVar.IsInternal = defaultConfigVar.IsInternal storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue - if err := s.db.Save(&storedConfigVar).Error; err != nil { + + err = tx. + WithContext(ctx). + Save(&storedConfigVar). + Error + if err != nil { return err } } @@ -346,43 +401,68 @@ func (s *AppConfigService) InitDbConfig() error { // Delete any configurations not in the default keys var allConfigVars []model.AppConfigVariable - if err := s.db.Find(&allConfigVars).Error; err != nil { + err = tx. + WithContext(ctx). + Find(&allConfigVars). + Error + if err != nil { return err } for _, config := range allConfigVars { - if _, exists := defaultKeys[config.Key]; !exists { - if err := s.db.Delete(&config).Error; err != nil { - return err - } + if _, exists := defaultKeys[config.Key]; exists { + continue + } + + err = tx. + WithContext(ctx). + Delete(&config). + Error + if err != nil { + return err } } - return s.LoadDbConfigFromDb() + + // Commit the changes + err = tx.Commit().Error + if err != nil { + return err + } + + // Reload the configuration + err = s.LoadDbConfigFromDb() + if err != nil { + return err + } + + return nil } // LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct. func (s *AppConfigService) LoadDbConfigFromDb() error { - dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem() + return s.db.Transaction(func(tx *gorm.DB) error { + dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem() - for i := 0; i < dbConfigReflectValue.NumField(); i++ { - dbConfigField := dbConfigReflectValue.Field(i) - currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable) - var storedConfigVar model.AppConfigVariable - if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil { - return err + for i := range dbConfigReflectValue.NumField() { + dbConfigField := dbConfigReflectValue.Field(i) + currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable) + var storedConfigVar model.AppConfigVariable + err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error + if err != nil { + return err + } + + if common.EnvConfig.UiConfigDisabled { + storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue) + } else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" { + storedConfigVar.Value = storedConfigVar.DefaultValue + } + + dbConfigField.Set(reflect.ValueOf(storedConfigVar)) } - if common.EnvConfig.UiConfigDisabled { - storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue) - } else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" { - storedConfigVar.Value = storedConfigVar.DefaultValue - } - - dbConfigField.Set(reflect.ValueOf(storedConfigVar)) - - } - - return nil + return nil + }) } func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string { diff --git a/backend/internal/service/audit_log_service.go b/backend/internal/service/audit_log_service.go index aaf4c8b3..3f2848b7 100644 --- a/backend/internal/service/audit_log_service.go +++ b/backend/internal/service/audit_log_service.go @@ -25,10 +25,10 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe } // Create creates a new audit log entry in the database -func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData) model.AuditLog { +func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog { country, city, err := s.geoliteService.GetLocationByIP(ipAddress) if err != nil { - log.Printf("Failed to get IP location: %v\n", err) + log.Printf("Failed to get IP location: %v", err) } auditLog := model.AuditLog{ @@ -42,8 +42,12 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent } // Save the audit log in the database - if err := s.db.Create(&auditLog).Error; err != nil { - log.Printf("Failed to create audit log: %v\n", err) + err = tx. + WithContext(ctx). + Create(&auditLog). + Error + if err != nil { + log.Printf("Failed to create audit log: %v", err) return model.AuditLog{} } @@ -51,12 +55,17 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent } // CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before -func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID string) model.AuditLog { - createdAuditLog := s.Create(model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}) +func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog { + createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx) // Count the number of times the user has logged in from the same device var count int64 - err := s.db.Model(&model.AuditLog{}).Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).Count(&count).Error + err := tx. + WithContext(ctx). + Model(&model.AuditLog{}). + Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent). + Count(&count). + Error if err != nil { log.Printf("Failed to count audit logs: %v\n", err) return createdAuditLog @@ -64,11 +73,23 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID // If the user hasn't logged in from the same device before and email notifications are enabled, send an email if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 { + // We use a background context here as this is running in a goroutine + //nolint:contextcheck go func() { - var user model.User - s.db.Where("id = ?", userID).First(&user) + innerCtx := context.Background() - err := SendEmail(s.emailService, email.Address{ + // Note we don't use the transaction here because this is running in background + var user model.User + innerErr := s.db. + WithContext(innerCtx). + Where("id = ?", userID). + First(&user). + Error + if innerErr != nil { + log.Printf("Failed to load user: %v", innerErr) + } + + innerErr = SendEmail(innerCtx, s.emailService, email.Address{ Name: user.Username, Email: user.Email, }, NewLoginTemplate, &NewLoginTemplateData{ @@ -78,8 +99,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID Device: s.DeviceStringFromUserAgent(userAgent), DateTime: createdAuditLog.CreatedAt.UTC(), }) - if err != nil { - log.Printf("Failed to send email to '%s': %v\n", user.Email, err) + if innerErr != nil { + log.Printf("Failed to send email to '%s': %v", user.Email, innerErr) } }() } @@ -88,9 +109,12 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID } // ListAuditLogsForUser retrieves all audit logs for a given user ID -func (s *AuditLogService) ListAuditLogsForUser(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) { +func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) { var logs []model.AuditLog - query := s.db.Model(&model.AuditLog{}).Where("user_id = ?", userID) + query := s.db. + WithContext(ctx). + Model(&model.AuditLog{}). + Where("user_id = ?", userID) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs) return logs, pagination, err @@ -162,19 +186,19 @@ func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[s } func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) { + dialect := s.db.Name() query := s.db. WithContext(ctx). Model(&model.AuditLog{}) - dialect := s.db.Name() switch dialect { case "sqlite": query = query. - Select("DISTINCT json_extract(data, '$.clientName') as client_name"). + Select("DISTINCT json_extract(data, '$.clientName') AS client_name"). Where("json_extract(data, '$.clientName') IS NOT NULL") case "postgres": query = query. - Select("DISTINCT data->>'clientName' as client_name"). + Select("DISTINCT data->>'clientName' AS client_name"). Where("data->>'clientName' IS NOT NULL") default: return nil, fmt.Errorf("unsupported database dialect: %s", dialect) diff --git a/backend/internal/service/custom_claim_service.go b/backend/internal/service/custom_claim_service.go index b43ebe0a..fe6f4ef8 100644 --- a/backend/internal/service/custom_claim_service.go +++ b/backend/internal/service/custom_claim_service.go @@ -1,34 +1,14 @@ package service import ( + "context" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" "gorm.io/gorm" ) -// Reserved claims -var reservedClaims = map[string]struct{}{ - "given_name": {}, - "family_name": {}, - "name": {}, - "email": {}, - "preferred_username": {}, - "groups": {}, - "sub": {}, - "iss": {}, - "aud": {}, - "exp": {}, - "iat": {}, - "auth_time": {}, - "nonce": {}, - "acr": {}, - "amr": {}, - "azp": {}, - "nbf": {}, - "jti": {}, -} - type CustomClaimService struct { db *gorm.DB } @@ -39,8 +19,29 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService { // isReservedClaim checks if a claim key is reserved e.g. email, preferred_username func isReservedClaim(key string) bool { - _, ok := reservedClaims[key] - return ok + switch key { + case "given_name", + "family_name", + "name", + "email", + "preferred_username", + "groups", + "sub", + "iss", + "aud", + "exp", + "iat", + "auth_time", + "nonce", + "acr", + "amr", + "azp", + "nbf", + "jti": + return true + default: + return false + } } // idType is the type of the id used to identify the user or user group @@ -52,28 +53,38 @@ const ( ) // UpdateCustomClaimsForUser updates the custom claims for a user -func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { - return s.updateCustomClaims(UserID, userID, claims) +func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + return s.updateCustomClaims(ctx, UserID, userID, claims) } // UpdateCustomClaimsForUserGroup updates the custom claims for a user group -func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { - return s.updateCustomClaims(UserGroupID, userGroupID, claims) +func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { + return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims) } // updateCustomClaims updates the custom claims for a user or user group -func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { +func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { // Check for duplicate keys in the claims slice - seenKeys := make(map[string]bool) + seenKeys := make(map[string]struct{}) for _, claim := range claims { - if seenKeys[claim.Key] { + if _, ok := seenKeys[claim.Key]; ok { return nil, &common.DuplicateClaimError{Key: claim.Key} } - seenKeys[claim.Key] = true + seenKeys[claim.Key] = struct{}{} } + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var existingClaims []model.CustomClaim - err := s.db.Where(string(idType), value).Find(&existingClaims).Error + err := tx. + WithContext(ctx). + Where(string(idType), value). + Find(&existingClaims). + Error if err != nil { return nil, err } @@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla break } } + if !found { - err = s.db.Delete(&existingClaim).Error + err = tx. + WithContext(ctx). + Delete(&existingClaim). + Error if err != nil { return nil, err } @@ -113,7 +128,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla } // Update the claim if it already exists or create a new one - err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error + err = tx. + WithContext(ctx). + Where(string(idType)+" = ? AND key = ?", value, claim.Key). + Assign(&customClaim). + FirstOrCreate(&model.CustomClaim{}). + Error if err != nil { return nil, err } @@ -121,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla // Get the updated claims var updatedClaims []model.CustomClaim - err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error + err = tx. + WithContext(ctx). + Where(string(idType)+" = ?", value). + Find(&updatedClaims). + Error + if err != nil { + return nil, err + } + + err = tx.Commit().Error if err != nil { return nil, err } @@ -129,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla return updatedClaims, nil } -func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) { +func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) { var customClaims []model.CustomClaim - err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error + err := tx. + WithContext(ctx). + Where("user_id = ?", userID). + Find(&customClaims). + Error return customClaims, err } -func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) { +func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) { var customClaims []model.CustomClaim - err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error + err := tx. + WithContext(ctx). + Where("user_group_id = ?", userGroupID). + Find(&customClaims). + Error return customClaims, err } // GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of, // prioritizing the user's claims over user group claims with the same key. -func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) { +func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) { // Get the custom claims of the user - customClaims, err := s.GetCustomClaimsForUser(userID) + customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx) if err != nil { return nil, err } @@ -158,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) // Get all user groups of the user var userGroupsOfUser []model.UserGroup - err = s.db.Preload("CustomClaims"). + err = tx. + WithContext(ctx). + Preload("CustomClaims"). Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id"). Where("user_groups_users.user_id = ?", userID). Find(&userGroupsOfUser).Error @@ -186,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) } // GetSuggestions returns a list of custom claim keys that have been used before -func (s *CustomClaimService) GetSuggestions() ([]string, error) { +func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) { var customClaimsKeys []string - err := s.db.Model(&model.CustomClaim{}). + err := s.db. + WithContext(ctx). + Model(&model.CustomClaim{}). Group("key"). Order("COUNT(*) DESC"). Pluck("key", &customClaimsKeys).Error diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index b596195c..fd7b1ee9 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -3,6 +3,7 @@ package service import ( + "context" "crypto/ecdsa" "crypto/x509" "encoding/base64" @@ -34,6 +35,7 @@ func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService} } +//nolint:gocognit func (s *TestService) SeedDatabase() error { return s.db.Transaction(func(tx *gorm.DB) error { users := []model.User{ @@ -187,11 +189,8 @@ func (s *TestService) SeedDatabase() error { // openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \ // openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout) - publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==") - publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==") - if err != nil { - return err - } + publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==") + publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==") webauthnCredentials := []model.WebauthnCredential{ { Name: "Passkey 1", @@ -303,7 +302,7 @@ func (s *TestService) ResetApplicationImages() error { func (s *TestService) ResetAppConfig() error { // Reseed the config variables - if err := s.appConfigService.InitDbConfig(); err != nil { + if err := s.appConfigService.InitDbConfig(context.Background()); err != nil { return err } @@ -320,7 +319,7 @@ func (s *TestService) SetJWTKeys() { const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}` privateKey, _ := jwk.ParseKey([]byte(privateKeyString)) - s.jwtService.SetKey(privateKey) + _ = s.jwtService.SetKey(privateKey) } // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index dfb38777..231ccecb 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -2,10 +2,12 @@ package service import ( "bytes" + "context" "crypto/tls" "errors" "fmt" htemplate "html/template" + "io" "mime/multipart" "mime/quotedprintable" "net/textproto" @@ -17,10 +19,11 @@ import ( "github.com/emersion/go-sasl" "github.com/emersion/go-smtp" "github.com/google/uuid" + "gorm.io/gorm" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/utils/email" - "gorm.io/gorm" ) type EmailService struct { @@ -49,20 +52,24 @@ func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailSer }, nil } -func (srv *EmailService) SendTestEmail(recipientUserId string) error { +func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId string) error { var user model.User - if err := srv.db.First(&user, "id = ?", recipientUserId).Error; err != nil { + err := srv.db. + WithContext(ctx). + First(&user, "id = ?", recipientUserId). + Error + if err != nil { return err } - return SendEmail(srv, + return SendEmail(ctx, srv, email.Address{ Email: user.Email, Name: user.FullName(), }, TestTemplate, nil) } -func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error { +func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error { data := &email.TemplateData[V]{ AppName: srv.appConfigService.DbConfig.AppName.Value, LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo", @@ -112,6 +119,15 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T c.Body(body) + // Check if the context is still valid before attemtping to connect + // We need to do this because the smtp library doesn't have context support + select { + case <-ctx.Done(): + return ctx.Err() + default: + // All good + } + // Connect to the SMTP server client, err := srv.getSmtpClient() if err != nil { @@ -119,6 +135,14 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T } defer client.Close() + // Check if the context is still valid before sending the email + select { + case <-ctx.Done(): + return ctx.Err() + default: + // All good + } + // Send the email if err := srv.sendEmailContent(client, toEmail, c); err != nil { return fmt.Errorf("send email content: %w", err) @@ -215,7 +239,7 @@ func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Add } // Write the email content - _, err = w.Write([]byte(c.String())) + _, err = io.Copy(w, strings.NewReader(c.String())) if err != nil { return fmt.Errorf("failed to write email data: %w", err) } diff --git a/backend/internal/service/geolite_service.go b/backend/internal/service/geolite_service.go index 2c5361f1..6e839e91 100644 --- a/backend/internal/service/geolite_service.go +++ b/backend/internal/service/geolite_service.go @@ -42,7 +42,7 @@ var tailscaleIPNets = []*net.IPNet{ } // NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database. -func NewGeoLiteService() *GeoLiteService { +func NewGeoLiteService(ctx context.Context) *GeoLiteService { service := &GeoLiteService{} if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl { @@ -52,8 +52,9 @@ func NewGeoLiteService() *GeoLiteService { } go func() { - if err := service.updateDatabase(); err != nil { - log.Printf("Failed to update GeoLite2 City database: %v\n", err) + err := service.updateDatabase(ctx) + if err != nil { + log.Printf("Failed to update GeoLite2 City database: %v", err) } }() @@ -111,7 +112,7 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string } // UpdateDatabase checks the age of the database and updates it if it's older than 14 days. -func (s *GeoLiteService) updateDatabase() error { +func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error { if s.disableUpdater { // Avoid updating the GeoLite2 City database. return nil @@ -125,7 +126,7 @@ func (s *GeoLiteService) updateDatabase() error { log.Println("Updating GeoLite2 City database...") downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil) diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index ce8282b4..d7e5df10 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -38,7 +38,9 @@ func (s *LdapService) createClient() (*ldap.Conn, error) { // Setup LDAP connection ldapURL := s.appConfigService.DbConfig.LdapUrl.Value skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue() - client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec + client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: skipTLSVerify, //nolint:gosec + })) if err != nil { return nil, fmt.Errorf("failed to connect to LDAP: %w", err) } @@ -53,22 +55,31 @@ func (s *LdapService) createClient() (*ldap.Conn, error) { return client, nil } -func (s *LdapService) SyncAll() error { - err := s.SyncUsers() +func (s *LdapService) SyncAll(ctx context.Context) error { + // Start a transaction + tx := s.db.Begin() + + err := s.SyncUsers(ctx, tx) if err != nil { return fmt.Errorf("failed to sync users: %w", err) } - err = s.SyncGroups() + err = s.SyncGroups(ctx, tx) if err != nil { return fmt.Errorf("failed to sync groups: %w", err) } + // Commit the changes + err = tx.Commit().Error + if err != nil { + return fmt.Errorf("failed to commit changes to database: %w", err) + } + return nil } //nolint:gocognit -func (s *LdapService) SyncGroups() error { +func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error { // Setup LDAP connection client, err := s.createClient() if err != nil { @@ -112,7 +123,7 @@ func (s *LdapService) SyncGroups() error { // Try to find the group in the database var databaseGroup model.UserGroup - s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup) + tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup) // Get group members and add to the correct Group groupMembers := value.GetAttributeValues(groupMemberOfAttribute) @@ -122,7 +133,7 @@ func (s *LdapService) SyncGroups() error { singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0] var databaseUser model.User - err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error + err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // The user collides with a non-LDAP user, so we skip it @@ -143,39 +154,51 @@ func (s *LdapService) SyncGroups() error { } if databaseGroup.ID == "" { - newGroup, err := s.groupService.Create(syncGroup) + newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx) if err != nil { - log.Printf("Error syncing group %s: %s", syncGroup.Name, err) - } else { - if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil { - log.Printf("Error syncing group %s: %s", syncGroup.Name, err) - } + log.Printf("Error syncing group %s: %v", syncGroup.Name, err) + continue + } + + _, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx) + if err != nil { + log.Printf("Error syncing group %s: %v", syncGroup.Name, err) + continue } } else { - _, err = s.groupService.Update(databaseGroup.ID, syncGroup, true) + _, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx) if err != nil { - log.Printf("Error syncing group %s: %s", syncGroup.Name, err) - } - _, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId) - if err != nil { - log.Printf("Error syncing group %s: %s", syncGroup.Name, err) - return err + log.Printf("Error syncing group %s: %v", syncGroup.Name, err) + continue } + _, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx) + if err != nil { + log.Printf("Error syncing group %s: %v", syncGroup.Name, err) + continue + } } - } // Get all LDAP groups from the database var ldapGroupsInDb []model.UserGroup - if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil { - fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err)) + err = tx. + WithContext(ctx). + Find(&ldapGroupsInDb, "ldap_id IS NOT NULL"). + Select("ldap_id"). + Error + if err != nil { + log.Printf("Failed to fetch groups from database: %v", err) } // Delete groups that no longer exist in LDAP for _, group := range ldapGroupsInDb { if _, exists := ldapGroupIDs[*group.LdapID]; !exists { - if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil { + err = tx. + WithContext(ctx). + Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID). + Error + if err != nil { log.Printf("Failed to delete group %s with: %v", group.Name, err) } else { log.Printf("Deleted group %s", group.Name) @@ -187,7 +210,7 @@ func (s *LdapService) SyncGroups() error { } //nolint:gocognit -func (s *LdapService) SyncUsers() error { +func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error { // Setup LDAP connection client, err := s.createClient() if err != nil { @@ -241,7 +264,7 @@ func (s *LdapService) SyncUsers() error { // Get the user from the database var databaseUser model.User - s.db.Where("ldap_id = ?", ldapId).First(&databaseUser) + tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser) // Check if user is admin by checking if they are in the admin group isAdmin := false @@ -261,68 +284,75 @@ func (s *LdapService) SyncUsers() error { } if databaseUser.ID == "" { - _, err = s.userService.CreateUser(newUser) + _, err = s.userService.createUserInternal(ctx, newUser, tx) if err != nil { - log.Printf("Error syncing user %s: %s", newUser.Username, err) + log.Printf("Error syncing user %s: %v", newUser.Username, err) } } else { - _, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true) + _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx) if err != nil { - log.Printf("Error syncing user %s: %s", newUser.Username, err) + log.Printf("Error syncing user %s: %v", newUser.Username, err) } } // Save profile picture if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" { - if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil { - log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err) + if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil { + log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err) } } } // Get all LDAP users from the database var ldapUsersInDb []model.User - if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil { - fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err)) + err = tx. + WithContext(ctx). + Find(&ldapUsersInDb, "ldap_id IS NOT NULL"). + Select("ldap_id"). + Error + if err != nil { + log.Printf("Failed to fetch users from database: %v", err) } // Delete users that no longer exist in LDAP for _, user := range ldapUsersInDb { if _, exists := ldapUserIDs[*user.LdapID]; !exists { - if err := s.userService.DeleteUser(user.ID, true); err != nil { + if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil { log.Printf("Failed to delete user %s with: %v", user.Username, err) } else { log.Printf("Deleted user %s", user.Username) } } } + return nil } -func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error { +func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error { var reader io.Reader - if _, err := url.ParseRequestURI(pictureString); err == nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _, err := url.ParseRequestURI(pictureString) + if err == nil { + ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil) + var req *http.Request + req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } - response, err := http.DefaultClient.Do(req) + var res *http.Response + res, err = http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("failed to download profile picture: %w", err) } - defer response.Body.Close() - - reader = response.Body + defer res.Body.Close() + reader = res.Body } else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil { // If the photo is a base64 encoded string, decode it reader = bytes.NewReader(decodedPhoto) - } else { // If the photo is a string, we assume that it's a binary string reader = bytes.NewReader([]byte(pictureString)) diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index abc3e1d6..e5dd347d 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/sha256" "encoding/base64" "encoding/json" @@ -9,6 +10,7 @@ import ( "mime/multipart" "os" "regexp" + "slices" "strings" "time" @@ -39,9 +41,20 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo } } -func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) { +func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var client model.OidcClient - if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil { + err := tx. + WithContext(ctx). + Preload("AllowedUserGroups"). + First(&client, "id = ?", input.ClientID). + Error + if err != nil { return "", "", err } @@ -58,7 +71,12 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, // Check if the user group is allowed to authorize the client var user model.User - if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil { + err = tx. + WithContext(ctx). + Preload("UserGroups"). + First(&user, "id = ?", userID). + Error + if err != nil { return "", "", err } @@ -67,7 +85,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, } // Check if the user has already authorized the client with the given scope - hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope) + hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, input.ClientID, userID, input.Scope, tx) if err != nil { return "", "", err } @@ -80,39 +98,55 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, Scope: input.Scope, } - if err := s.db.Create(&userAuthorizedClient).Error; err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - // The client has already been authorized but with a different scope so we need to update the scope - if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil { - return "", "", err - } - } else { + err = tx. + WithContext(ctx). + Create(&userAuthorizedClient). + Error + if errors.Is(err, gorm.ErrDuplicatedKey) { + // The client has already been authorized but with a different scope so we need to update the scope + if err := tx. + WithContext(ctx). + Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil { return "", "", err } + } else if err != nil { + return "", "", err } } // Create the authorization code - code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod) + code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx) if err != nil { return "", "", err } // Log the authorization event if hasAuthorizedClient { - s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}) + s.auditLogService.Create(ctx, model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx) } else { - s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}) + s.auditLogService.Create(ctx, model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx) + } + err = tx.Commit().Error + if err != nil { + return "", "", err } return code, callbackURL, nil } // HasAuthorizedClient checks if the user has already authorized the client with the given scope -func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) { +func (s *OidcService) HasAuthorizedClient(ctx context.Context, clientID, userID, scope string) (bool, error) { + return s.hasAuthorizedClientInternal(ctx, clientID, userID, scope, s.db) +} + +func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID, userID, scope string, tx *gorm.DB) (bool, error) { var userAuthorizedOidcClient model.UserAuthorizedOidcClient - if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil { + err := tx. + WithContext(ctx). + First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID). + Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return false, nil } @@ -145,21 +179,31 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode return isAllowedToAuthorize } -func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) { +func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) { switch grantType { case "authorization_code": - return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier) + return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier) case "refresh_token": - accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret) + accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret) return "", accessToken, newRefreshToken, exp, err default: return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{} } } -func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) { +func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + err = tx. + WithContext(ctx). + First(&client, "id = ?", clientID). + Error + if err != nil { return "", "", "", 0, err } @@ -176,7 +220,11 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec } var authorizationCodeMetaData model.OidcAuthorizationCode - err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error + err = tx. + WithContext(ctx). + Preload("User"). + First(&authorizationCodeMetaData, "code = ?", code). + Error if err != nil { return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} } @@ -192,7 +240,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} } - userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) + userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx) if err != nil { return "", "", "", 0, err } @@ -203,7 +251,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec } // Generate a refresh token - refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope) + refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx) if err != nil { return "", "", "", 0, err } @@ -213,19 +261,40 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec return "", "", "", 0, err } - s.db.Delete(&authorizationCodeMetaData) + err = tx. + WithContext(ctx). + Delete(&authorizationCodeMetaData). + Error + if err != nil { + return "", "", "", 0, err + } + + err = tx.Commit().Error + if err != nil { + return "", "", "", 0, err + } return idToken, accessToken, refreshToken, 3600, nil } -func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) { +func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) { if refreshToken == "" { return "", "", 0, &common.OidcMissingRefreshTokenError{} } + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + // Get the client to check if it's public var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + err = tx. + WithContext(ctx). + First(&client, "id = ?", clientID). + Error + if err != nil { return "", "", 0, err } @@ -243,7 +312,9 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client // Verify refresh token var storedRefreshToken model.OidcRefreshToken - err = s.db.Preload("User"). + err = tx. + WithContext(ctx). + Preload("User"). Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). First(&storedRefreshToken). Error @@ -266,29 +337,53 @@ func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, client } // Generate a new refresh token and invalidate the old one - newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope) + newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx) if err != nil { return "", "", 0, err } // Delete the used refresh token - s.db.Delete(&storedRefreshToken) + err = tx. + WithContext(ctx). + Delete(&storedRefreshToken). + Error + if err != nil { + return "", "", 0, err + } + + err = tx.Commit().Error + if err != nil { + return "", "", 0, err + } return accessToken, newRefreshToken, 3600, nil } -func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) { +func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) { + return s.getClientInternal(ctx, clientID, s.db) +} + +func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) { var client model.OidcClient - if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil { + err := tx. + WithContext(ctx). + Preload("CreatedBy"). + Preload("AllowedUserGroups"). + First(&client, "id = ?", clientID). + Error + if err != nil { return model.OidcClient{}, err } return client, nil } -func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) { +func (s *OidcService) ListClients(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) { var clients []model.OidcClient - query := s.db.Preload("CreatedBy").Model(&model.OidcClient{}) + query := s.db. + WithContext(ctx). + Preload("CreatedBy"). + Model(&model.OidcClient{}) if searchTerm != "" { searchPattern := "%" + searchTerm + "%" query = query.Where("name LIKE ?", searchPattern) @@ -302,7 +397,7 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti return clients, pagination, nil } -func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { +func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { client := model.OidcClient{ Name: input.Name, CallbackURLs: input.CallbackURLs, @@ -312,16 +407,31 @@ func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) PkceEnabled: input.IsPublic || input.PkceEnabled, } - if err := s.db.Create(&client).Error; err != nil { + err := s.db. + WithContext(ctx). + Create(&client). + Error + if err != nil { return model.OidcClient{}, err } return client, nil } -func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) { +func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var client model.OidcClient - if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil { + err := tx. + WithContext(ctx). + Preload("CreatedBy"). + First(&client, "id = ?", clientID). + Error + if err != nil { return model.OidcClient{}, err } @@ -331,29 +441,49 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt client.IsPublic = input.IsPublic client.PkceEnabled = input.IsPublic || input.PkceEnabled - if err := s.db.Save(&client).Error; err != nil { + err = tx. + WithContext(ctx). + Save(&client). + Error + if err != nil { + return model.OidcClient{}, err + } + + err = tx.Commit().Error + if err != nil { return model.OidcClient{}, err } return client, nil } -func (s *OidcService) DeleteClient(clientID string) error { +func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error { var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { - return err - } - - if err := s.db.Delete(&client).Error; err != nil { + err := s.db. + WithContext(ctx). + Where("id = ?", clientID). + Delete(&client). + Error + if err != nil { return err } return nil } -func (s *OidcService) CreateClientSecret(clientID string) (string, error) { +func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (string, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + err := tx. + WithContext(ctx). + First(&client, "id = ?", clientID). + Error + if err != nil { return "", err } @@ -368,16 +498,29 @@ func (s *OidcService) CreateClientSecret(clientID string) (string, error) { } client.Secret = string(hashedSecret) - if err := s.db.Save(&client).Error; err != nil { + err = tx. + WithContext(ctx). + Save(&client). + Error + if err != nil { + return "", err + } + + err = tx.Commit().Error + if err != nil { return "", err } return clientSecret, nil } -func (s *OidcService) GetClientLogo(clientID string) (string, string, error) { +func (s *OidcService) GetClientLogo(ctx context.Context, clientID string) (string, string, error) { var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + err := s.db. + WithContext(ctx). + First(&client, "id = ?", clientID). + Error + if err != nil { return "", "", err } @@ -385,26 +528,36 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) { return "", "", errors.New("image not found") } - imageType := *client.ImageType - imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType) - mimeType := utils.GetImageMimeType(imageType) + imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType + mimeType := utils.GetImageMimeType(*client.ImageType) return imagePath, mimeType, nil } -func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error { +func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader) error { fileType := utils.GetFileExtension(file.Filename) if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { return &common.FileTypeNotSupportedError{} } - imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType) - if err := utils.SaveFile(file, imagePath); err != nil { + imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + fileType + err := utils.SaveFile(file, imagePath) + if err != nil { return err } + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + err = tx. + WithContext(ctx). + First(&client, "id = ?", clientID). + Error + if err != nil { return err } @@ -416,16 +569,35 @@ func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHead } client.ImageType = &fileType - if err := s.db.Save(&client).Error; err != nil { + err = tx. + WithContext(ctx). + Save(&client). + Error + if err != nil { + return err + } + + err = tx.Commit().Error + if err != nil { return err } return nil } -func (s *OidcService) DeleteClientLogo(clientID string) error { +func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + err := tx. + WithContext(ctx). + First(&client, "id = ?", clientID). + Error + if err != nil { return err } @@ -433,38 +605,72 @@ func (s *OidcService) DeleteClientLogo(clientID string) error { return errors.New("image not found") } - imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) + client.ImageType = nil + err = tx. + WithContext(ctx). + Save(&client). + Error + if err != nil { + return err + } + + imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType if err := os.Remove(imagePath); err != nil { return err } - client.ImageType = nil - if err := s.db.Save(&client).Error; err != nil { + err = tx.Commit().Error + if err != nil { return err } return nil } -func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) { +func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + + claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db) + if err != nil { + return nil, err + } + + err = tx.Commit().Error + if err != nil { + return nil, err + } + + return claims, nil +} + +func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) { var authorizedOidcClient model.UserAuthorizedOidcClient - if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil { + err := tx. + WithContext(ctx). + Preload("User.UserGroups"). + First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID). + Error + if err != nil { return nil, err } user := authorizedOidcClient.User - scope := authorizedOidcClient.Scope + scopes := strings.Split(authorizedOidcClient.Scope, " ") claims := map[string]interface{}{ "sub": user.ID, } - if strings.Contains(scope, "email") { + if slices.Contains(scopes, "email") { claims["email"] = user.Email claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue() } - if strings.Contains(scope, "groups") { + if slices.Contains(scopes, "groups") { userGroups := make([]string, len(user.UserGroups)) for i, group := range user.UserGroups { userGroups[i] = group.Name @@ -477,17 +683,17 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma "family_name": user.LastName, "name": user.FullName(), "preferred_username": user.Username, - "picture": fmt.Sprintf("%s/api/users/%s/profile-picture.png", common.EnvConfig.AppURL, user.ID), + "picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png", } - if strings.Contains(scope, "profile") { + if slices.Contains(scopes, "profile") { // Add profile claims for k, v := range profileClaims { claims[k] = v } // Add custom claims - customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID) + customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx) if err != nil { return nil, err } @@ -505,15 +711,22 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma } } } - if strings.Contains(scope, "email") { + + if slices.Contains(scopes, "email") { claims["email"] = user.Email } return claims, nil } -func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) { - client, err = s.GetClient(id) +func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + + client, err = s.getClientInternal(ctx, id, tx) if err != nil { return model.OidcClient{}, err } @@ -521,18 +734,37 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll // Fetch the user groups based on UserGroupIDs in input var groups []model.UserGroup if len(input.UserGroupIDs) > 0 { - if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil { + err = tx. + WithContext(ctx). + Where("id IN (?)", input.UserGroupIDs). + Find(&groups). + Error + if err != nil { return model.OidcClient{}, err } } // Replace the current user groups with the new set of user groups - if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil { + err = tx. + WithContext(ctx). + Model(&client). + Association("AllowedUserGroups"). + Replace(groups) + if err != nil { return model.OidcClient{}, err } // Save the updated client - if err := s.db.Save(&client).Error; err != nil { + err = tx. + WithContext(ctx). + Save(&client). + Error + if err != nil { + return model.OidcClient{}, err + } + + err = tx.Commit().Error + if err != nil { return model.OidcClient{}, err } @@ -540,7 +772,7 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll } // ValidateEndSession returns the logout callback URL for the client if all the validations pass -func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) { +func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (string, error) { // If no ID token hint is provided, return an error if input.IdTokenHint == "" { return "", &common.TokenInvalidError{} @@ -564,7 +796,12 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) // Check if the user has authorized the client before var userAuthorizedOIDCClient model.UserAuthorizedOidcClient - if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).Error; err != nil { + err = s.db. + WithContext(ctx). + Preload("Client"). + First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID). + Error + if err != nil { return "", &common.OidcMissingAuthorizationError{} } @@ -582,7 +819,7 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) } -func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) { +func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) { randomString, err := utils.GenerateRandomAlphanumericString(32) if err != nil { return "", err @@ -601,7 +838,11 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc CodeChallengeMethodSha256: &codeChallengeMethodSha256, } - if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil { + err = tx. + WithContext(ctx). + Create(&oidcAuthorizationCode). + Error + if err != nil { return "", err } @@ -647,7 +888,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca return "", &common.OidcInvalidCallbackURLError{} } -func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) { +func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) { refreshToken, err := utils.GenerateRandomAlphanumericString(40) if err != nil { return "", err @@ -665,7 +906,11 @@ func (s *OidcService) createRefreshToken(clientID string, userID string, scope s Scope: scope, } - if err := s.db.Create(&m).Error; err != nil { + err = tx. + WithContext(ctx). + Create(&m). + Error + if err != nil { return "", err } diff --git a/backend/internal/service/user_group_service.go b/backend/internal/service/user_group_service.go index e9835274..fe7c432c 100644 --- a/backend/internal/service/user_group_service.go +++ b/backend/internal/service/user_group_service.go @@ -1,13 +1,15 @@ package service import ( + "context" "errors" + "gorm.io/gorm" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/utils" - "gorm.io/gorm" ) type UserGroupService struct { @@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG return &UserGroupService{db: db, appConfigService: appConfigService} } -func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) { - query := s.db.Preload("CustomClaims").Model(&model.UserGroup{}) +func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) { + query := s.db. + WithContext(ctx). + Preload("CustomClaims"). + Model(&model.UserGroup{}) if name != "" { query = query.Where("name LIKE ?", "%"+name+"%") @@ -42,26 +47,59 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte return groups, response, err } -func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) { - err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error +func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) { + return s.getInternal(ctx, id, s.db) +} + +func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) { + err = tx. + WithContext(ctx). + Where("id = ?", id). + Preload("CustomClaims"). + Preload("Users"). + First(&group). + Error return group, err } -func (s *UserGroupService) Delete(id string) error { +func (s *UserGroupService) Delete(ctx context.Context, id string) error { + tx := s.db.Begin() + var group model.UserGroup - if err := s.db.Where("id = ?", id).First(&group).Error; err != nil { + err := tx. + WithContext(ctx). + Where("id = ?", id). + First(&group). + Error + if err != nil { return err } // Disallow deleting the group if it is an LDAP group and LDAP is enabled if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { + err = tx.Rollback().Error + if err != nil { + return err + } return &common.LdapUserGroupUpdateError{} } - return s.db.Delete(&group).Error + err = tx. + WithContext(ctx). + Delete(&group). + Error + if err != nil { + return err + } + + return tx.Commit().Error } -func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) { +func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) { + return s.createInternal(ctx, input, s.db) +} + +func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) { group = model.UserGroup{ FriendlyName: input.FriendlyName, Name: input.Name, @@ -71,7 +109,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use group.LdapID = &input.LdapID } - if err := s.db.Preload("Users").Create(&group).Error; err != nil { + err = tx. + WithContext(ctx). + Preload("Users"). + Create(&group). + Error + if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} } @@ -80,8 +123,26 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use return group, nil } -func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) { - group, err = s.Get(id) +func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) { + tx := s.db.Begin() + + group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx) + if err != nil { + tx.Rollback() + return model.UserGroup{}, err + } + + err = tx.Commit().Error + if err != nil { + tx.Rollback() + return model.UserGroup{}, err + } + + return group, nil +} + +func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) { + group, err = s.getInternal(ctx, id, tx) if err != nil { return model.UserGroup{}, err } @@ -94,7 +155,12 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow group.Name = input.Name group.FriendlyName = input.FriendlyName - if err := s.db.Preload("Users").Save(&group).Error; err != nil { + err = tx. + WithContext(ctx). + Preload("Users"). + Save(&group). + Error + if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} } @@ -103,8 +169,26 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow return group, nil } -func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) { - group, err = s.Get(id) +func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) { + tx := s.db.Begin() + + group, err = s.updateUsersInternal(ctx, id, userIds, tx) + if err != nil { + tx.Rollback() + return model.UserGroup{}, err + } + + err = tx.Commit().Error + if err != nil { + tx.Rollback() + return model.UserGroup{}, err + } + + return group, nil +} + +func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) { + group, err = s.getInternal(ctx, id, tx) if err != nil { return model.UserGroup{}, err } @@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model // Fetch the users based on the userIds var users []model.User if len(userIds) > 0 { - if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil { + err := tx. + WithContext(ctx). + Where("id IN (?)", userIds). + Find(&users). + Error + if err != nil { return model.UserGroup{}, err } } // Replace the current users with the new set of users - if err := s.db.Model(&group).Association("Users").Replace(users); err != nil { + err = tx. + WithContext(ctx). + Model(&group). + Association("Users"). + Replace(users) + if err != nil { return model.UserGroup{}, err } // Save the updated group - if err := s.db.Save(&group).Error; err != nil { + err = tx. + WithContext(ctx). + Save(&group). + Error + if err != nil { return model.UserGroup{}, err } return group, nil } -func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) { +func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) { + // We only perform select queries here, so we can rollback in all cases + tx := s.db.Begin() + defer func() { + tx.Rollback() + }() + var group model.UserGroup - if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil { + err := tx. + WithContext(ctx). + Preload("Users"). + Where("id = ?", id). + First(&group). + Error + if err != nil { return 0, err } - return s.db.Model(&group).Association("Users").Count(), nil + count := tx. + WithContext(ctx). + Model(&group). + Association("Users"). + Count() + return count, nil } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3c0c0cb1..42a35c07 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -2,6 +2,7 @@ package service import ( "bytes" + "context" "errors" "fmt" "io" @@ -12,7 +13,7 @@ import ( "time" "github.com/google/uuid" - profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image" + "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" @@ -20,7 +21,7 @@ import ( datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" "github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils/email" - "gorm.io/gorm" + profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image" ) type UserService struct { @@ -35,9 +36,9 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService} } -func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) { +func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) { var users []model.User - query := s.db.Model(&model.User{}) + query := s.db.WithContext(ctx).Model(&model.User{}) if searchTerm != "" { searchPattern := "%" + searchTerm + "%" @@ -48,13 +49,23 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils return users, pagination, err } -func (s *UserService) GetUser(userID string) (model.User, error) { +func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) { + return s.getUserInternal(ctx, userID, s.db) +} + +func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) { var user model.User - err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error + err := tx. + WithContext(ctx). + Preload("UserGroups"). + Preload("CustomClaims"). + Where("id = ?", userID). + First(&user). + Error return user, err } -func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, error) { +func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) { // Validate the user ID to prevent directory traversal if err := uuid.Validate(userID); err != nil { return nil, 0, &common.InvalidUUIDError{} @@ -74,7 +85,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er } // If no custom picture exists, get the user's data for creating initials - user, err := s.GetUser(userID) + user, err := s.GetUser(ctx, userID) if err != nil { return nil, 0, err } @@ -115,9 +126,15 @@ func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, er return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil } -func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) { +func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) { var user model.User - if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil { + err := s.db. + WithContext(ctx). + Preload("UserGroups"). + Where("id = ?", userID). + First(&user). + Error + if err != nil { return nil, err } return user.UserGroups, nil @@ -152,9 +169,21 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error return nil } -func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error { +func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error { + return s.db.Transaction(func(tx *gorm.DB) error { + return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx) + }) +} + +func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error { var user model.User - if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { + + err := tx. + WithContext(ctx). + Where("id = ?", userID). + First(&user). + Error + if err != nil { return err } @@ -165,14 +194,35 @@ func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error { // Delete the profile picture profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png" - if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) { + err = os.Remove(profilePicturePath) + if err != nil && !os.IsNotExist(err) { return err } - return s.db.Delete(&user).Error + return tx.WithContext(ctx).Delete(&user).Error } -func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) { +func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + + user, err := s.createUserInternal(ctx, input, tx) + if err != nil { + return model.User{}, err + } + + err = tx.Commit().Error + if err != nil { + return model.User{}, err + } + + return user, nil +} + +func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) { user := model.User{ FirstName: input.FirstName, LastName: input.LastName, @@ -185,18 +235,47 @@ func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) { user.LdapID = &input.LdapID } - if err := s.db.Create(&user).Error; err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - return model.User{}, s.checkDuplicatedFields(user) - } + err := tx.WithContext(ctx).Create(&user).Error + if errors.Is(err, gorm.ErrDuplicatedKey) { + tx.Rollback() + + // If we are here, the transaction is already aborted due to an error, so we pass s.db + err = s.checkDuplicatedFields(ctx, user, s.db) + return model.User{}, err + } else if err != nil { return model.User{}, err } return user, nil } -func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) { +func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + + user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, allowLdapUpdate, tx) + if err != nil { + return model.User{}, err + } + + err = tx.Commit().Error + if err != nil { + return model.User{}, err + } + + return user, nil +} + +func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) { var user model.User - if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { + err := tx. + WithContext(ctx). + Where("id = ?", userID). + First(&user). + Error + if err != nil { return model.User{}, err } @@ -214,24 +293,42 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u user.IsAdmin = updatedUser.IsAdmin } - if err := s.db.Save(&user).Error; err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - return user, s.checkDuplicatedFields(user) - } + err = tx. + WithContext(ctx). + Save(&user). + Error + if errors.Is(err, gorm.ErrDuplicatedKey) { + tx.Rollback() + + // If we are here, the transaction is already aborted due to an error, so we pass s.db + err = s.checkDuplicatedFields(ctx, user, s.db) + return user, err + } else if err != nil { return user, err } return user, nil } -func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error { +func (s *UserService) RequestOneTimeAccessEmail(ctx context.Context, emailAddress, redirectPath string) error { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue() if isDisabled { return &common.OneTimeAccessDisabledError{} } var user model.User - if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil { + err := tx. + WithContext(ctx). + Where("email = ?", emailAddress). + First(&user). + Error + if err != nil { // Do not return error if user not found to prevent email enumeration if errors.Is(err, gorm.ErrRecordNotFound) { return nil @@ -240,22 +337,31 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin } } - oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute)) + oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, time.Now().Add(15*time.Minute), tx) if err != nil { return err } - link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL) - linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken) - - // Add redirect path to the link - if strings.HasPrefix(redirectPath, "/") { - encodedRedirectPath := url.QueryEscape(redirectPath) - linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath) + err = tx.Commit().Error + if err != nil { + return err } + // We use a background context here as this is running in a goroutine + //nolint:contextcheck go func() { - err := SendEmail(s.emailService, email.Address{ + innerCtx := context.Background() + + link := common.EnvConfig.AppURL + "/lc" + linkWithCode := link + "/" + oneTimeAccessToken + + // Add redirect path to the link + if strings.HasPrefix(redirectPath, "/") { + encodedRedirectPath := url.QueryEscape(redirectPath) + linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath + } + + errInternal := SendEmail(innerCtx, s.emailService, email.Address{ Name: user.Username, Email: user.Email, }, OneTimeAccessTemplate, &OneTimeAccessTemplateData{ @@ -263,18 +369,21 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin LoginLink: link, LoginLinkWithCode: linkWithCode, }) - if err != nil { - log.Printf("Failed to send email to '%s': %v\n", user.Email, err) + if errInternal != nil { + log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal) } }() return nil } -func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) { - tokenLength := 16 +func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) { + return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db) +} +func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) { // If expires at is less than 15 minutes, use an 6 character token instead of 16 + tokenLength := 16 if time.Until(expiresAt) <= 15*time.Minute { tokenLength = 6 } @@ -290,16 +399,27 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim Token: randomString, } - if err := s.db.Create(&oneTimeAccessToken).Error; err != nil { + if err := tx.WithContext(ctx).Create(&oneTimeAccessToken).Error; err != nil { return "", err } return oneTimeAccessToken.Token, nil } -func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) { +func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var oneTimeAccessToken model.OneTimeAccessToken - if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil { + err := tx. + WithContext(ctx). + Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User"). + First(&oneTimeAccessToken). + Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return model.User{}, "", &common.TokenInvalidOrExpiredError{} } @@ -310,19 +430,34 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg return model.User{}, "", err } - if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil { + err = tx. + WithContext(ctx). + Delete(&oneTimeAccessToken). + Error + if err != nil { return model.User{}, "", err } if ipAddress != "" && userAgent != "" { - s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}) + s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx) + } + + err = tx.Commit().Error + if err != nil { + return model.User{}, "", err } return oneTimeAccessToken.User, accessToken, nil } -func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) { - user, err = s.GetUser(id) +func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + + user, err = s.getUserInternal(ctx, id, tx) if err != nil { return model.User{}, err } @@ -330,27 +465,49 @@ func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user m // Fetch the groups based on userGroupIds var groups []model.UserGroup if len(userGroupIds) > 0 { - if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil { + err = tx. + WithContext(ctx). + Where("id IN (?)", userGroupIds). + Find(&groups). + Error + if err != nil { return model.User{}, err } } // Replace the current groups with the new set of groups - if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil { + err = tx. + WithContext(ctx). + Model(&user). + Association("UserGroups"). + Replace(groups) + if err != nil { return model.User{}, err } // Save the updated user - if err := s.db.Save(&user).Error; err != nil { + err = tx.WithContext(ctx).Save(&user).Error + if err != nil { + return model.User{}, err + } + + err = tx.Commit().Error + if err != nil { return model.User{}, err } return user, nil } -func (s *UserService) SetupInitialAdmin() (model.User, string, error) { +func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var userCount int64 - if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil { + if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil { return model.User{}, "", err } if userCount > 1 { @@ -365,7 +522,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { IsAdmin: true, } - if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil { + if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil { return model.User{}, "", err } @@ -378,16 +535,39 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) { return model.User{}, "", err } + err = tx.Commit().Error + if err != nil { + return model.User{}, "", err + } + return user, token, nil } -func (s *UserService) checkDuplicatedFields(user model.User) error { - var existingUser model.User - if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { +func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error { + var result struct { + Found bool + } + err := tx. + WithContext(ctx). + Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email). + First(&result). + Error + if err != nil { + return err + } + if result.Found { return &common.AlreadyInUseError{Property: "email"} } - if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { + err = tx. + WithContext(ctx). + Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username). + First(&result). + Error + if err != nil { + return err + } + if result.Found { return &common.AlreadyInUseError{Property: "username"} } diff --git a/backend/internal/service/webauthn_service.go b/backend/internal/service/webauthn_service.go index 29f27de5..fd1cccd0 100644 --- a/backend/internal/service/webauthn_service.go +++ b/backend/internal/service/webauthn_service.go @@ -1,16 +1,19 @@ package service import ( + "context" + "fmt" "net/http" "time" "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" + "gorm.io/gorm" + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" "github.com/pocket-id/pocket-id/backend/internal/utils" - "gorm.io/gorm" ) type WebAuthnService struct { @@ -43,15 +46,31 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService} } -func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) { +func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + s.updateWebAuthnConfig() var user model.User - if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil { + err := tx. + WithContext(ctx). + Preload("Credentials"). + Find(&user, "id = ?", userID). + Error + if err != nil { + tx.Rollback() return nil, err } - options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors())) + options, session, err := s.webAuthn.BeginRegistration( + &user, + webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), + webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()), + ) if err != nil { return nil, err } @@ -62,7 +81,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred UserVerification: string(session.UserVerification), } - if err := s.db.Create(&sessionToStore).Error; err != nil { + err = tx. + WithContext(ctx). + Create(&sessionToStore). + Error + if err != nil { + return nil, err + } + + err = tx.Commit().Error + if err != nil { return nil, err } @@ -73,9 +101,19 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred }, nil } -func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) { +func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var storedSession model.WebauthnSession - if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { + err := tx. + WithContext(ctx). + First(&storedSession, "id = ?", sessionID). + Error + if err != nil { return model.WebauthnCredential{}, err } @@ -86,7 +124,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R } var user model.User - if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { + err = tx. + WithContext(ctx). + Find(&user, "id = ?", userID). + Error + if err != nil { return model.WebauthnCredential{}, err } @@ -108,7 +150,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R BackupEligible: credential.Flags.BackupEligible, BackupState: credential.Flags.BackupState, } - if err := s.db.Create(&credentialToStore).Error; err != nil { + err = tx. + WithContext(ctx). + Create(&credentialToStore). + Error + if err != nil { + return model.WebauthnCredential{}, err + } + + err = tx.Commit().Error + if err != nil { return model.WebauthnCredential{}, err } @@ -125,7 +176,7 @@ func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string { return "New Passkey" // Default fallback } -func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) { +func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) { options, session, err := s.webAuthn.BeginDiscoverableLogin() if err != nil { return nil, err @@ -137,7 +188,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions UserVerification: string(session.UserVerification), } - if err := s.db.Create(&sessionToStore).Error; err != nil { + err = s.db. + WithContext(ctx). + Create(&sessionToStore). + Error + if err != nil { return nil, err } @@ -148,9 +203,19 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions }, nil } -func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) { +func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var storedSession model.WebauthnSession - if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { + err := tx. + WithContext(ctx). + First(&storedSession, "id = ?", sessionID). + Error + if err != nil { return model.User{}, "", err } @@ -160,9 +225,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData } var user *model.User - _, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { - if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil { - return nil, err + _, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { + innerErr := tx. + WithContext(ctx). + Preload("Credentials"). + First(&user, "id = ?", string(userHandle)). + Error + if innerErr != nil { + return nil, innerErr } return user, nil }, session, credentialAssertionData) @@ -176,41 +246,70 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData return model.User{}, "", err } - s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID) + s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx) + + err = tx.Commit().Error + if err != nil { + return model.User{}, "", err + } return *user, token, nil } -func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) { +func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) { var credentials []model.WebauthnCredential - if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil { + err := s.db. + WithContext(ctx). + Find(&credentials, "user_id = ?", userID). + Error + if err != nil { return nil, err } return credentials, nil } -func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error { - var credential model.WebauthnCredential - if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil { - return err - } - - if err := s.db.Delete(&credential).Error; err != nil { - return err +func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error { + err := s.db. + WithContext(ctx). + Where("id = ? AND user_id = ?", credentialID, userID). + Delete(&model.WebauthnCredential{}). + Error + if err != nil { + return fmt.Errorf("failed to delete record: %w", err) } return nil } -func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) { +func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) { + tx := s.db.Begin() + defer func() { + // This is a no-op if the transaction has been committed already + tx.Rollback() + }() + var credential model.WebauthnCredential - if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil { + err := tx. + WithContext(ctx). + Where("id = ? AND user_id = ?", credentialID, userID). + First(&credential). + Error + if err != nil { return credential, err } credential.Name = name - if err := s.db.Save(&credential).Error; err != nil { + err = tx. + WithContext(ctx). + Save(&credential). + Error + if err != nil { + return credential, err + } + + err = tx.Commit().Error + if err != nil { return credential, err } diff --git a/backend/internal/utils/aaguid_util.go b/backend/internal/utils/aaguid_util.go index e8eb1539..e2611a70 100644 --- a/backend/internal/utils/aaguid_util.go +++ b/backend/internal/utils/aaguid_util.go @@ -12,9 +12,13 @@ import ( var ( aaguidMap map[string]string - aaguidMapOnce sync.Once + aaguidMapOnce *sync.Once ) +func init() { + aaguidMapOnce = &sync.Once{} +} + // FormatAAGUID converts an AAGUID byte slice to UUID string format func FormatAAGUID(aaguid []byte) string { if len(aaguid) == 0 { diff --git a/backend/internal/utils/aaguid_util_test.go b/backend/internal/utils/aaguid_util_test.go index 8433ccf1..f4a6ee13 100644 --- a/backend/internal/utils/aaguid_util_test.go +++ b/backend/internal/utils/aaguid_util_test.go @@ -47,8 +47,10 @@ func TestFormatAAGUID(t *testing.T) { func TestGetAuthenticatorName(t *testing.T) { // Reset the aaguidMap for testing originalMap := aaguidMap + originalOnce := aaguidMapOnce defer func() { aaguidMap = originalMap + aaguidMapOnce = originalOnce }() // Inject a test AAGUID map @@ -56,7 +58,7 @@ func TestGetAuthenticatorName(t *testing.T) { "adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator", "00000000-0000-0000-0000-000000000000": "Zero Authenticator", } - aaguidMapOnce = sync.Once{} + aaguidMapOnce = &sync.Once{} aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file tests := []struct { @@ -99,7 +101,7 @@ func TestGetAuthenticatorName(t *testing.T) { func TestLoadAAGUIDsFromFile(t *testing.T) { // Reset the map and once flag for clean testing aaguidMap = nil - aaguidMapOnce = sync.Once{} + aaguidMapOnce = &sync.Once{} // Trigger loading of AAGUIDs by calling GetAuthenticatorName GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04}) diff --git a/backend/internal/utils/paging_util.go b/backend/internal/utils/paging_util.go index 7ef94318..af85115d 100644 --- a/backend/internal/utils/paging_util.go +++ b/backend/internal/utils/paging_util.go @@ -4,9 +4,8 @@ import ( "reflect" "strconv" - "gorm.io/gorm/clause" - "gorm.io/gorm" + "gorm.io/gorm/clause" ) type PaginationResponse struct { @@ -47,7 +46,6 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor } return Paginate(pagination.Page, pagination.Limit, query, result) - } func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) { diff --git a/backend/internal/utils/ptr_util.go b/backend/internal/utils/ptr_util.go new file mode 100644 index 00000000..947538cd --- /dev/null +++ b/backend/internal/utils/ptr_util.go @@ -0,0 +1,5 @@ +package utils + +func Ptr[T any](v T) *T { + return &v +} diff --git a/backend/main b/backend/main deleted file mode 100755 index 9e03a9c6..00000000 Binary files a/backend/main and /dev/null differ