1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-03-22 20:15:07 +00:00

feat: add ability define user groups for sign up tokens (#1155)

This commit is contained in:
Elias Schneider
2025-12-21 18:26:52 +01:00
committed by GitHub
parent f5da11b99b
commit 59ca6b26ac
23 changed files with 391 additions and 162 deletions

View File

@@ -545,7 +545,7 @@ func (uc *UserController) createSignupTokenHandler(c *gin.Context) {
ttl = defaultSignupTokenDuration
}
signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), ttl, input.UsageLimit)
signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), ttl, input.UsageLimit, input.UserGroupIDs)
if err != nil {
_ = c.Error(err)
return

View File

@@ -6,8 +6,9 @@ import (
)
type SignupTokenCreateDto struct {
TTL utils.JSONDuration `json:"ttl" binding:"required,ttl"`
UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"`
TTL utils.JSONDuration `json:"ttl" binding:"required,ttl"`
UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"`
UserGroupIDs []string `json:"userGroupIds"`
}
type SignupTokenDto struct {
@@ -16,5 +17,6 @@ type SignupTokenDto struct {
ExpiresAt datatype.DateTime `json:"expiresAt"`
UsageLimit int `json:"usageLimit"`
UsageCount int `json:"usageCount"`
UserGroups []UserGroupDto `json:"userGroups"`
CreatedAt datatype.DateTime `json:"createdAt"`
}

View File

@@ -23,15 +23,16 @@ type UserDto struct {
}
type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
Email *string `json:"email" binding:"omitempty,email" unorm:"nfc"`
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
DisplayName string `json:"displayName" binding:"required,min=1,max=100" unorm:"nfc"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
Disabled bool `json:"disabled"`
LdapID string `json:"-"`
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
Email *string `json:"email" binding:"omitempty,email" unorm:"nfc"`
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
DisplayName string `json:"displayName" binding:"required,min=1,max=100" unorm:"nfc"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
Disabled bool `json:"disabled"`
UserGroupIds []string `json:"userGroupIds"`
LdapID string `json:"-"`
}
func (u UserCreateDto) Validate() error {

View File

@@ -13,6 +13,7 @@ type SignupToken struct {
ExpiresAt datatype.DateTime `json:"expiresAt" sortable:"true"`
UsageLimit int `json:"usageLimit" sortable:"true"`
UsageCount int `json:"usageCount" sortable:"true"`
UserGroups []UserGroup `gorm:"many2many:signup_tokens_user_groups;"`
}
func (st *SignupToken) IsExpired() bool {

View File

@@ -344,6 +344,9 @@ func (s *TestService) SeedDatabase(baseURL string) error {
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
UsageLimit: 1,
UsageCount: 0,
UserGroups: []model.UserGroup{
userGroups[0],
},
},
{
Base: model.Base{

View File

@@ -253,6 +253,18 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
return model.User{}, &common.UserEmailNotSetError{}
}
var userGroups []model.UserGroup
if len(input.UserGroupIds) > 0 {
err := tx.
WithContext(ctx).
Where("id IN ?", input.UserGroupIds).
Find(&userGroups).
Error
if err != nil {
return model.User{}, err
}
}
user := model.User{
FirstName: input.FirstName,
LastName: input.LastName,
@@ -262,6 +274,7 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
IsAdmin: input.IsAdmin,
Locale: input.Locale,
Disabled: input.Disabled,
UserGroups: userGroups,
}
if input.LdapID != "" {
user.LdapID = &input.LdapID
@@ -285,7 +298,13 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
// Apply default groups and claims for new non-LDAP users
if !isLdapSync {
if err := s.applySignupDefaults(ctx, &user, tx); err != nil {
if len(input.UserGroupIds) == 0 {
if err := s.applyDefaultGroups(ctx, &user, tx); err != nil {
return model.User{}, err
}
}
if err := s.applyDefaultCustomClaims(ctx, &user, tx); err != nil {
return model.User{}, err
}
}
@@ -293,10 +312,9 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
return user, nil
}
func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User, tx *gorm.DB) error {
func (s *UserService) applyDefaultGroups(ctx context.Context, user *model.User, tx *gorm.DB) error {
config := s.appConfigService.GetDbConfig()
// Apply default user groups
var groupIDs []string
v := config.SignupDefaultUserGroupIDs.Value
if v != "" && v != "[]" {
@@ -323,10 +341,14 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
}
}
}
return nil
}
func (s *UserService) applyDefaultCustomClaims(ctx context.Context, user *model.User, tx *gorm.DB) error {
config := s.appConfigService.GetDbConfig()
// Apply default custom claims
var claims []dto.CustomClaimCreateDto
v = config.SignupDefaultCustomClaims.Value
v := config.SignupDefaultCustomClaims.Value
if v != "" && v != "[]" {
err := json.Unmarshal([]byte(v), &claims)
if err != nil {
@@ -727,12 +749,22 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user
Error
}
func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int) (model.SignupToken, error) {
func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int, userGroupIDs []string) (model.SignupToken, error) {
signupToken, err := NewSignupToken(ttl, usageLimit)
if err != nil {
return model.SignupToken{}, err
}
var userGroups []model.UserGroup
err = s.db.WithContext(ctx).
Where("id IN ?", userGroupIDs).
Find(&userGroups).
Error
if err != nil {
return model.SignupToken{}, err
}
signupToken.UserGroups = userGroups
err = s.db.WithContext(ctx).Create(signupToken).Error
if err != nil {
return model.SignupToken{}, err
@@ -755,9 +787,11 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
}
var signupToken model.SignupToken
var userGroupIDs []string
if tokenProvided {
err := tx.
WithContext(ctx).
Preload("UserGroups").
Where("token = ?", signupData.Token).
Clauses(clause.Locking{Strength: "UPDATE"}).
First(&signupToken).
@@ -772,14 +806,19 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
if !signupToken.IsValid() {
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
}
for _, group := range signupToken.UserGroups {
userGroupIDs = append(userGroupIDs, group.ID)
}
}
userToCreate := dto.UserCreateDto{
Username: signupData.Username,
Email: signupData.Email,
FirstName: signupData.FirstName,
LastName: signupData.LastName,
DisplayName: strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
Username: signupData.Username,
Email: signupData.Email,
FirstName: signupData.FirstName,
LastName: signupData.LastName,
DisplayName: strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
UserGroupIds: userGroupIDs,
}
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
@@ -820,7 +859,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
func (s *UserService) ListSignupTokens(ctx context.Context, listRequestOptions utils.ListRequestOptions) ([]model.SignupToken, utils.PaginationResponse, error) {
var tokens []model.SignupToken
query := s.db.WithContext(ctx).Model(&model.SignupToken{})
query := s.db.WithContext(ctx).Preload("UserGroups").Model(&model.SignupToken{})
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &tokens)
return tokens, pagination, err

View File

@@ -0,0 +1 @@
DROP TABLE signup_tokens_user_groups;

View File

@@ -0,0 +1,8 @@
CREATE TABLE signup_tokens_user_groups
(
signup_token_id UUID NOT NULL,
user_group_id UUID NOT NULL,
PRIMARY KEY (signup_token_id, user_group_id),
FOREIGN KEY (signup_token_id) REFERENCES signup_tokens (id) ON DELETE CASCADE,
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
);

View File

@@ -1 +1,7 @@
ALTER TABLE one_time_access_tokens DROP COLUMN device_token;
PRAGMA foreign_keys=OFF;
BEGIN;
ALTER TABLE one_time_access_tokens DROP COLUMN device_token;
COMMIT;
PRAGMA foreign_keys=ON;

View File

@@ -1 +1,7 @@
ALTER TABLE one_time_access_tokens ADD COLUMN device_token TEXT;
PRAGMA foreign_keys=OFF;
BEGIN;
ALTER TABLE one_time_access_tokens ADD COLUMN device_token TEXT;
COMMIT;
PRAGMA foreign_keys=ON;

View File

@@ -0,0 +1,7 @@
PRAGMA foreign_keys=OFF;
BEGIN;
DROP TABLE signup_tokens_user_groups;
COMMIT;
PRAGMA foreign_keys=ON;

View File

@@ -0,0 +1,14 @@
PRAGMA foreign_keys=OFF;
BEGIN;
CREATE TABLE signup_tokens_user_groups
(
signup_token_id TEXT NOT NULL,
user_group_id TEXT NOT NULL,
PRIMARY KEY (signup_token_id, user_group_id),
FOREIGN KEY (signup_token_id) REFERENCES signup_tokens (id) ON DELETE CASCADE,
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
);
COMMIT;
PRAGMA foreign_keys=ON;