1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-15 07:35:11 +00:00

fix: restrict email one time sign in token to same browser (#1144)

This commit is contained in:
Elias Schneider
2025-12-12 14:51:07 +01:00
committed by GitHub
parent 0a6ff6f84b
commit 3eaf36aae7
11 changed files with 76 additions and 32 deletions

View File

@@ -432,28 +432,36 @@ func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, user
return &common.OneTimeAccessDisabledError{}
}
return s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl)
_, err := s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl, true)
return err
}
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error {
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) (string, error) {
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsUnauthenticatedEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
return "", &common.OneTimeAccessDisabledError{}
}
var userId string
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// Do not return error if user not found to prevent email enumeration
return nil
return "", nil
} else if err != nil {
return err
return "", err
}
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
deviceToken, err := s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute, true)
if err != nil {
return "", err
} else if deviceToken == nil {
return "", errors.New("device token expected but not returned")
}
return *deviceToken, nil
}
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration) error {
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration, withDeviceToken bool) (*string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -461,21 +469,20 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
user, err := s.GetUser(ctx, userID)
if err != nil {
return err
return nil, err
}
if user.Email == nil {
return &common.UserEmailNotSetError{}
return nil, &common.UserEmailNotSetError{}
}
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx)
oneTimeAccessToken, deviceToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, withDeviceToken, tx)
if err != nil {
return err
return nil, err
}
err = tx.Commit().Error
if err != nil {
return err
return nil, err
}
// We use a background context here as this is running in a goroutine
@@ -508,28 +515,29 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
}
}()
return nil
return deviceToken, nil
}
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (string, error) {
return s.createOneTimeAccessTokenInternal(ctx, userID, ttl, s.db)
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (token string, err error) {
token, _, err = s.createOneTimeAccessTokenInternal(ctx, userID, ttl, false, s.db)
return token, err
}
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, tx *gorm.DB) (string, error) {
oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl)
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, withDeviceToken bool, tx *gorm.DB) (token string, deviceToken *string, err error) {
oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl, withDeviceToken)
if err != nil {
return "", err
return "", nil, err
}
err = tx.WithContext(ctx).Create(oneTimeAccessToken).Error
if err != nil {
return "", err
return "", nil, err
}
return oneTimeAccessToken.Token, nil
return oneTimeAccessToken.Token, oneTimeAccessToken.DeviceToken, nil
}
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token, deviceToken, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -549,6 +557,10 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
}
return model.User{}, "", err
}
if oneTimeAccessToken.DeviceToken != nil && deviceToken != *oneTimeAccessToken.DeviceToken {
return model.User{}, "", &common.DeviceCodeInvalid{}
}
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
if err != nil {
return model.User{}, "", err
@@ -818,23 +830,33 @@ func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) err
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
}
func NewOneTimeAccessToken(userID string, ttl time.Duration) (*model.OneTimeAccessToken, error) {
func NewOneTimeAccessToken(userID string, ttl time.Duration, withDeviceToken bool) (*model.OneTimeAccessToken, error) {
// If expires at is less than 15 minutes, use a 6-character token instead of 16
tokenLength := 16
if ttl <= 15*time.Minute {
tokenLength = 6
}
randomString, err := utils.GenerateRandomAlphanumericString(tokenLength)
token, err := utils.GenerateRandomAlphanumericString(tokenLength)
if err != nil {
return nil, err
}
var deviceToken *string
if withDeviceToken {
dt, err := utils.GenerateRandomAlphanumericString(16)
if err != nil {
return nil, err
}
deviceToken = &dt
}
now := time.Now().Round(time.Second)
o := &model.OneTimeAccessToken{
UserID: userID,
ExpiresAt: datatype.DateTime(now.Add(ttl)),
Token: randomString,
UserID: userID,
ExpiresAt: datatype.DateTime(now.Add(ttl)),
Token: token,
DeviceToken: deviceToken,
}
return o, nil