mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 11:36:46 +00:00
217 lines
5.8 KiB
Go
217 lines
5.8 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
type UserSignUpService struct {
|
|
db *gorm.DB
|
|
userService *UserService
|
|
jwtService *JwtService
|
|
auditLogService *AuditLogService
|
|
appConfigService *AppConfigService
|
|
}
|
|
|
|
func NewUserSignupService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService, userService *UserService) *UserSignUpService {
|
|
return &UserSignUpService{
|
|
db: db,
|
|
jwtService: jwtService,
|
|
auditLogService: auditLogService,
|
|
appConfigService: appConfigService,
|
|
userService: userService,
|
|
}
|
|
}
|
|
|
|
func (s *UserSignUpService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAddress, userAgent string) (model.User, string, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
tokenProvided := signupData.Token != ""
|
|
|
|
config := s.appConfigService.GetDbConfig()
|
|
if config.AllowUserSignups.Value != "open" && !tokenProvided {
|
|
return model.User{}, "", &common.OpenSignupDisabledError{}
|
|
}
|
|
|
|
var signupToken model.SignupToken
|
|
var userGroupIDs []string
|
|
if tokenProvided {
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Preload("UserGroups").
|
|
Where("token = ?", signupData.Token).
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
First(&signupToken).
|
|
Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
|
}
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
if !signupToken.IsValid() {
|
|
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
|
}
|
|
|
|
for _, group := range signupToken.UserGroups {
|
|
userGroupIDs = append(userGroupIDs, group.ID)
|
|
}
|
|
}
|
|
|
|
userToCreate := dto.UserCreateDto{
|
|
Username: signupData.Username,
|
|
Email: signupData.Email,
|
|
FirstName: signupData.FirstName,
|
|
LastName: signupData.LastName,
|
|
DisplayName: strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
|
|
UserGroupIds: userGroupIDs,
|
|
EmailVerified: s.appConfigService.GetDbConfig().EmailsVerified.IsTrue(),
|
|
}
|
|
|
|
user, err := s.userService.createUserInternal(ctx, userToCreate, false, tx)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
accessToken, err := s.jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
if tokenProvided {
|
|
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
|
|
"signupToken": signupToken.Token,
|
|
}, tx)
|
|
|
|
signupToken.UsageCount++
|
|
|
|
err = tx.WithContext(ctx).Save(&signupToken).Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
|
|
}
|
|
} else {
|
|
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
|
|
"method": "open_signup",
|
|
}, tx)
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
return user, accessToken, nil
|
|
}
|
|
|
|
func (s *UserSignUpService) SignUpInitialAdmin(ctx context.Context, signUpData dto.SignUpDto) (model.User, string, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
var userCount int64
|
|
if err := tx.WithContext(ctx).Model(&model.User{}).
|
|
Where("id != ?", staticApiKeyUserID).
|
|
Count(&userCount).Error; err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
if userCount != 0 {
|
|
return model.User{}, "", &common.SetupAlreadyCompletedError{}
|
|
}
|
|
|
|
userToCreate := dto.UserCreateDto{
|
|
FirstName: signUpData.FirstName,
|
|
LastName: signUpData.LastName,
|
|
DisplayName: strings.TrimSpace(signUpData.FirstName + " " + signUpData.LastName),
|
|
Username: signUpData.Username,
|
|
Email: signUpData.Email,
|
|
IsAdmin: true,
|
|
}
|
|
|
|
user, err := s.userService.createUserInternal(ctx, userToCreate, false, tx)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
token, err := s.jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
return user, token, nil
|
|
}
|
|
|
|
func (s *UserSignUpService) ListSignupTokens(ctx context.Context, listRequestOptions utils.ListRequestOptions) ([]model.SignupToken, utils.PaginationResponse, error) {
|
|
var tokens []model.SignupToken
|
|
query := s.db.WithContext(ctx).Preload("UserGroups").Model(&model.SignupToken{})
|
|
|
|
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &tokens)
|
|
return tokens, pagination, err
|
|
}
|
|
|
|
func (s *UserSignUpService) DeleteSignupToken(ctx context.Context, tokenID string) error {
|
|
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
|
|
}
|
|
|
|
func (s *UserSignUpService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int, userGroupIDs []string) (model.SignupToken, error) {
|
|
signupToken, err := NewSignupToken(ttl, usageLimit)
|
|
if err != nil {
|
|
return model.SignupToken{}, err
|
|
}
|
|
|
|
var userGroups []model.UserGroup
|
|
err = s.db.WithContext(ctx).
|
|
Where("id IN ?", userGroupIDs).
|
|
Find(&userGroups).
|
|
Error
|
|
if err != nil {
|
|
return model.SignupToken{}, err
|
|
}
|
|
signupToken.UserGroups = userGroups
|
|
|
|
err = s.db.WithContext(ctx).Create(signupToken).Error
|
|
if err != nil {
|
|
return model.SignupToken{}, err
|
|
}
|
|
|
|
return *signupToken, nil
|
|
}
|
|
|
|
func NewSignupToken(ttl time.Duration, usageLimit int) (*model.SignupToken, error) {
|
|
// Generate a random token
|
|
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().Round(time.Second)
|
|
token := &model.SignupToken{
|
|
Token: randomString,
|
|
ExpiresAt: datatype.DateTime(now.Add(ttl)),
|
|
UsageLimit: usageLimit,
|
|
UsageCount: 0,
|
|
}
|
|
|
|
return token, nil
|
|
}
|