1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-15 20:25:05 +00:00

fix: restrict email one time sign in token to same browser (#1144)

This commit is contained in:
Elias Schneider
2025-12-12 14:51:07 +01:00
committed by GitHub
parent 0a6ff6f84b
commit 3eaf36aae7
11 changed files with 76 additions and 32 deletions

View File

@@ -51,7 +51,7 @@ var oneTimeAccessTokenCmd = &cobra.Command{
} }
// Create a new access token that expires in 1 hour // Create a new access token that expires in 1 hour
oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Hour) oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Hour, false)
if txErr != nil { if txErr != nil {
return fmt.Errorf("failed to generate access token: %w", txErr) return fmt.Errorf("failed to generate access token: %w", txErr)
} }

View File

@@ -38,6 +38,13 @@ type TokenInvalidOrExpiredError struct{}
func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" } func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" }
func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 } func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 }
type DeviceCodeInvalid struct{}
func (e *DeviceCodeInvalid) Error() string {
return "one time access code must be used on the device it was generated for"
}
func (e *DeviceCodeInvalid) HttpStatusCode() int { return 400 }
type TokenInvalidError struct{} type TokenInvalidError struct{}
func (e *TokenInvalidError) Error() string { func (e *TokenInvalidError) Error() string {

View File

@@ -391,12 +391,13 @@ func (uc *UserController) RequestOneTimeAccessEmailAsUnauthenticatedUserHandler(
return return
} }
err := uc.userService.RequestOneTimeAccessEmailAsUnauthenticatedUser(c.Request.Context(), input.Email, input.RedirectPath) deviceToken, err := uc.userService.RequestOneTimeAccessEmailAsUnauthenticatedUser(c.Request.Context(), input.Email, input.RedirectPath)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
cookie.AddDeviceTokenCookie(c, deviceToken)
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
} }
@@ -440,7 +441,8 @@ func (uc *UserController) RequestOneTimeAccessEmailAsAdminHandler(c *gin.Context
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/{token} [post] // @Router /api/one-time-access-token/{token} [post]
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), c.ClientIP(), c.Request.UserAgent()) deviceToken, _ := c.Cookie(cookie.DeviceTokenCookieName)
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), deviceToken, c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -88,6 +88,7 @@ func (u User) Initials() string {
type OneTimeAccessToken struct { type OneTimeAccessToken struct {
Base Base
Token string Token string
DeviceToken *string
ExpiresAt datatype.DateTime ExpiresAt datatype.DateTime
UserID string UserID string

View File

@@ -432,28 +432,36 @@ func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, user
return &common.OneTimeAccessDisabledError{} return &common.OneTimeAccessDisabledError{}
} }
return s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl) _, err := s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl, true)
return err
} }
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error { func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) (string, error) {
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsUnauthenticatedEnabled.IsTrue() isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsUnauthenticatedEnabled.IsTrue()
if isDisabled { if isDisabled {
return &common.OneTimeAccessDisabledError{} return "", &common.OneTimeAccessDisabledError{}
} }
var userId string var userId string
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
// Do not return error if user not found to prevent email enumeration // Do not return error if user not found to prevent email enumeration
return nil return "", nil
} else if err != nil { } else if err != nil {
return err return "", err
} }
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute) deviceToken, err := s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute, true)
if err != nil {
return "", err
} else if deviceToken == nil {
return "", errors.New("device token expected but not returned")
} }
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration) error { return *deviceToken, nil
}
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration, withDeviceToken bool) (*string, error) {
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {
tx.Rollback() tx.Rollback()
@@ -461,21 +469,20 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
user, err := s.GetUser(ctx, userID) user, err := s.GetUser(ctx, userID)
if err != nil { if err != nil {
return err return nil, err
} }
if user.Email == nil { if user.Email == nil {
return &common.UserEmailNotSetError{} return nil, &common.UserEmailNotSetError{}
} }
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx) oneTimeAccessToken, deviceToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, withDeviceToken, tx)
if err != nil { if err != nil {
return err return nil, err
} }
err = tx.Commit().Error err = tx.Commit().Error
if err != nil { if err != nil {
return err return nil, err
} }
// We use a background context here as this is running in a goroutine // We use a background context here as this is running in a goroutine
@@ -508,28 +515,29 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
} }
}() }()
return nil return deviceToken, nil
} }
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (string, error) { func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (token string, err error) {
return s.createOneTimeAccessTokenInternal(ctx, userID, ttl, s.db) token, _, err = s.createOneTimeAccessTokenInternal(ctx, userID, ttl, false, s.db)
return token, err
} }
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, tx *gorm.DB) (string, error) { func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, withDeviceToken bool, tx *gorm.DB) (token string, deviceToken *string, err error) {
oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl) oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl, withDeviceToken)
if err != nil { if err != nil {
return "", err return "", nil, err
} }
err = tx.WithContext(ctx).Create(oneTimeAccessToken).Error err = tx.WithContext(ctx).Create(oneTimeAccessToken).Error
if err != nil { if err != nil {
return "", err return "", nil, err
} }
return oneTimeAccessToken.Token, nil return oneTimeAccessToken.Token, oneTimeAccessToken.DeviceToken, nil
} }
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) { func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token, deviceToken, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {
tx.Rollback() tx.Rollback()
@@ -549,6 +557,10 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
} }
return model.User{}, "", err return model.User{}, "", err
} }
if oneTimeAccessToken.DeviceToken != nil && deviceToken != *oneTimeAccessToken.DeviceToken {
return model.User{}, "", &common.DeviceCodeInvalid{}
}
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User) accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
if err != nil { if err != nil {
return model.User{}, "", err return model.User{}, "", err
@@ -818,23 +830,33 @@ func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) err
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
} }
func NewOneTimeAccessToken(userID string, ttl time.Duration) (*model.OneTimeAccessToken, error) { func NewOneTimeAccessToken(userID string, ttl time.Duration, withDeviceToken bool) (*model.OneTimeAccessToken, error) {
// If expires at is less than 15 minutes, use a 6-character token instead of 16 // If expires at is less than 15 minutes, use a 6-character token instead of 16
tokenLength := 16 tokenLength := 16
if ttl <= 15*time.Minute { if ttl <= 15*time.Minute {
tokenLength = 6 tokenLength = 6
} }
randomString, err := utils.GenerateRandomAlphanumericString(tokenLength) token, err := utils.GenerateRandomAlphanumericString(tokenLength)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var deviceToken *string
if withDeviceToken {
dt, err := utils.GenerateRandomAlphanumericString(16)
if err != nil {
return nil, err
}
deviceToken = &dt
}
now := time.Now().Round(time.Second) now := time.Now().Round(time.Second)
o := &model.OneTimeAccessToken{ o := &model.OneTimeAccessToken{
UserID: userID, UserID: userID,
ExpiresAt: datatype.DateTime(now.Add(ttl)), ExpiresAt: datatype.DateTime(now.Add(ttl)),
Token: randomString, Token: token,
DeviceToken: deviceToken,
} }
return o, nil return o, nil

View File

@@ -1,6 +1,8 @@
package cookie package cookie
import ( import (
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -11,3 +13,7 @@ func AddAccessTokenCookie(c *gin.Context, maxAgeInSeconds int, token string) {
func AddSessionIdCookie(c *gin.Context, maxAgeInSeconds int, sessionID string) { func AddSessionIdCookie(c *gin.Context, maxAgeInSeconds int, sessionID string) {
c.SetCookie(SessionIdCookieName, sessionID, maxAgeInSeconds, "/", "", true, true) c.SetCookie(SessionIdCookieName, sessionID, maxAgeInSeconds, "/", "", true, true)
} }
func AddDeviceTokenCookie(c *gin.Context, deviceToken string) {
c.SetCookie(DeviceTokenCookieName, deviceToken, int(15*time.Minute.Seconds()), "/api/one-time-access-token", "", true, true)
}

View File

@@ -8,10 +8,12 @@ import (
var AccessTokenCookieName = "__Host-access_token" var AccessTokenCookieName = "__Host-access_token"
var SessionIdCookieName = "__Host-session" var SessionIdCookieName = "__Host-session"
var DeviceTokenCookieName = "__Host-device_token" //nolint:gosec
func init() { func init() {
if strings.HasPrefix(common.EnvConfig.AppURL, "http://") { if strings.HasPrefix(common.EnvConfig.AppURL, "http://") {
AccessTokenCookieName = "access_token" AccessTokenCookieName = "access_token"
SessionIdCookieName = "session" SessionIdCookieName = "session"
DeviceTokenCookieName = "device_token"
} }
} }

View File

@@ -0,0 +1 @@
ALTER TABLE one_time_access_tokens DROP COLUMN device_token;

View File

@@ -0,0 +1 @@
ALTER TABLE one_time_access_tokens ADD COLUMN device_token VARCHAR(16);

View File

@@ -0,0 +1 @@
ALTER TABLE one_time_access_tokens DROP COLUMN device_token;

View File

@@ -0,0 +1 @@
ALTER TABLE one_time_access_tokens ADD COLUMN device_token TEXT;