1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 16:49:42 +00:00

fix: ensure user inputs are normalized (#724)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2025-07-13 18:15:57 +02:00
committed by GitHub
parent f145903eb0
commit 7b4ccd1f30
23 changed files with 351 additions and 59 deletions

View File

@@ -20,6 +20,7 @@ import (
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
"github.com/pocket-id/pocket-id/backend/resources"
)
@@ -88,6 +89,7 @@ func connectDatabase() (db *gorm.DB, err error) {
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
}
sqliteutil.RegisterSqliteFunctions()
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
if err != nil {
return nil, err

View File

@@ -82,7 +82,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID")
var input dto.ApiKeyCreateDto
if err := ctx.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(ctx, &input); err != nil {
_ = ctx.Error(err)
return
}

View File

@@ -109,7 +109,7 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
// @Router /api/application-configuration [put]
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}

View File

@@ -59,7 +59,7 @@ func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -93,7 +93,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}

View File

@@ -193,7 +193,7 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
// @Router /api/users [post]
func (uc *UserController) createUserHandler(c *gin.Context) {
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -378,7 +378,7 @@ func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) {
// @Router /api/one-time-access-email [post]
func (uc *UserController) RequestOneTimeAccessEmailAsUnauthenticatedUserHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailAsUnauthenticatedUserDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -457,7 +457,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
// @Router /api/signup/setup [post]
func (uc *UserController) signUpInitialAdmin(c *gin.Context) {
var input dto.SignUpDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -606,7 +606,7 @@ func (uc *UserController) deleteSignupTokenHandler(c *gin.Context) {
// @Router /api/signup [post]
func (uc *UserController) signupHandler(c *gin.Context) {
var input dto.SignUpDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -635,7 +635,7 @@ func (uc *UserController) signupHandler(c *gin.Context) {
// updateUser is an internal helper method, not exposed as an API endpoint
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}

View File

@@ -120,7 +120,7 @@ func (ugc *UserGroupController) get(c *gin.Context) {
// @Router /api/user-groups [post]
func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -152,7 +152,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
// @Router /api/user-groups/{id} [put]
func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}

View File

@@ -5,8 +5,8 @@ import (
)
type ApiKeyCreateDto struct {
Name string `json:"name" binding:"required,min=3,max=50"`
Description string `json:"description"`
Name string `json:"name" binding:"required,min=3,max=50" unorm:"nfc"`
Description string `json:"description" unorm:"nfc"`
ExpiresAt datatype.DateTime `json:"expiresAt" binding:"required"`
}

View File

@@ -12,7 +12,7 @@ type AppConfigVariableDto struct {
}
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"`
AppName string `json:"appName" binding:"required,min=1,max=30" unorm:"nfc"`
SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"`
DisableAnimations string `json:"disableAnimations" binding:"required"`

View File

@@ -6,6 +6,6 @@ type CustomClaimDto struct {
}
type CustomClaimCreateDto struct {
Key string `json:"key" binding:"required"`
Value string `json:"value" binding:"required"`
Key string `json:"key" binding:"required" unorm:"nfc"`
Value string `json:"value" binding:"required" unorm:"nfc"`
}

View File

@@ -0,0 +1,94 @@
package dto
import (
"net/http"
"reflect"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"golang.org/x/text/unicode/norm"
)
// Normalize iterates through an object and performs Unicode normalization on all string fields with the `unorm` tag.
func Normalize(obj any) {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr || v.IsNil() {
return
}
v = v.Elem()
// Handle case where obj is a slice of models
if v.Kind() == reflect.Slice {
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if elem.Kind() == reflect.Ptr && !elem.IsNil() && elem.Elem().Kind() == reflect.Struct {
Normalize(elem.Interface())
} else if elem.Kind() == reflect.Struct && elem.CanAddr() {
Normalize(elem.Addr().Interface())
}
}
return
}
if v.Kind() != reflect.Struct {
return
}
// Iterate through all fields looking for those with the "unorm" tag
t := v.Type()
loop:
for i := range t.NumField() {
field := t.Field(i)
unormTag := field.Tag.Get("unorm")
if unormTag == "" {
continue
}
fv := v.Field(i)
if !fv.CanSet() || fv.Kind() != reflect.String {
continue
}
var form norm.Form
switch unormTag {
case "nfc":
form = norm.NFC
case "nfkc":
form = norm.NFKC
case "nfd":
form = norm.NFD
case "nfkd":
form = norm.NFKD
default:
continue loop
}
val := fv.String()
val = form.String(val)
fv.SetString(val)
}
}
func ShouldBindWithNormalizedJSON(ctx *gin.Context, obj any) error {
return ctx.ShouldBindWith(obj, binding.JSON)
}
type NormalizerJSONBinding struct{}
func (NormalizerJSONBinding) Name() string {
return "json"
}
func (NormalizerJSONBinding) Bind(req *http.Request, obj any) error {
// Use the default JSON binder
err := binding.JSON.Bind(req, obj)
if err != nil {
return err
}
// Perform normalization
Normalize(obj)
return nil
}

View File

@@ -0,0 +1,84 @@
package dto
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/text/unicode/norm"
)
type testDto struct {
Name string `unorm:"nfc"`
Description string `unorm:"nfd"`
Other string
BadForm string `unorm:"bad"`
}
func TestNormalize(t *testing.T) {
input := testDto{
// Is in NFC form already
Name: norm.NFC.String("Café"),
// NFC form will be normalized to NFD
Description: norm.NFC.String("vërø"),
// Should be unchanged
Other: "NöTag",
// Should be unchanged
BadForm: "BåD",
}
Normalize(&input)
assert.Equal(t, norm.NFC.String("Café"), input.Name)
assert.Equal(t, norm.NFD.String("vërø"), input.Description)
assert.Equal(t, "NöTag", input.Other)
assert.Equal(t, "BåD", input.BadForm)
}
func TestNormalizeSlice(t *testing.T) {
obj1 := testDto{
Name: norm.NFC.String("Café1"),
Description: norm.NFC.String("vërø1"),
Other: "NöTag1",
BadForm: "BåD1",
}
obj2 := testDto{
Name: norm.NFD.String("Résumé2"),
Description: norm.NFD.String("accéléré2"),
Other: "NöTag2",
BadForm: "BåD2",
}
t.Run("slice of structs", func(t *testing.T) {
slice := []testDto{obj1, obj2}
Normalize(&slice)
// Verify first element
assert.Equal(t, norm.NFC.String("Café1"), slice[0].Name)
assert.Equal(t, norm.NFD.String("vërø1"), slice[0].Description)
assert.Equal(t, "NöTag1", slice[0].Other)
assert.Equal(t, "BåD1", slice[0].BadForm)
// Verify second element
assert.Equal(t, norm.NFC.String("Résumé2"), slice[1].Name)
assert.Equal(t, norm.NFD.String("accéléré2"), slice[1].Description)
assert.Equal(t, "NöTag2", slice[1].Other)
assert.Equal(t, "BåD2", slice[1].BadForm)
})
t.Run("slice of pointers to structs", func(t *testing.T) {
slice := []*testDto{&obj1, &obj2}
Normalize(&slice)
// Verify first element
assert.Equal(t, norm.NFC.String("Café1"), slice[0].Name)
assert.Equal(t, norm.NFD.String("vërø1"), slice[0].Description)
assert.Equal(t, "NöTag1", slice[0].Other)
assert.Equal(t, "BåD1", slice[0].BadForm)
// Verify second element
assert.Equal(t, norm.NFC.String("Résumé2"), slice[1].Name)
assert.Equal(t, norm.NFD.String("accéléré2"), slice[1].Description)
assert.Equal(t, "NöTag2", slice[1].Other)
assert.Equal(t, "BåD2", slice[1].BadForm)
})
}

View File

@@ -26,7 +26,7 @@ type OidcClientWithAllowedGroupsCountDto struct {
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`

View File

@@ -1,6 +1,8 @@
package dto
import "time"
import (
"time"
)
type UserDto struct {
ID string `json:"id"`
@@ -17,10 +19,10 @@ type UserDto struct {
}
type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"max=50"`
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
Email string `json:"email" binding:"required,email" unorm:"nfc"`
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
Disabled bool `json:"disabled"`
@@ -33,7 +35,7 @@ type OneTimeAccessTokenCreateDto struct {
}
type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
Email string `json:"email" binding:"required,email"`
Email string `json:"email" binding:"required,email" unorm:"nfc"`
RedirectPath string `json:"redirectPath"`
}
@@ -46,9 +48,9 @@ type UserUpdateUserGroupDto struct {
}
type SignUpDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"max=50"`
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
Email string `json:"email" binding:"required,email" unorm:"nfc"`
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
Token string `json:"token"`
}

View File

@@ -34,8 +34,8 @@ type UserGroupDtoWithUserCount struct {
}
type UserGroupCreateDto struct {
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50"`
Name string `json:"name" binding:"required,min=2,max=255"`
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50" unorm:"nfc"`
Name string `json:"name" binding:"required,min=2,max=255" unorm:"nfc"`
LdapID string `json:"-"`
}

View File

@@ -15,13 +15,14 @@ import (
"time"
"unicode/utf8"
"github.com/google/uuid"
"github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
"golang.org/x/text/unicode/norm"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm"
)
type LdapService struct {
@@ -181,7 +182,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
var databaseUser model.User
err = tx.
WithContext(ctx).
Where("username = ? AND ldap_id IS NOT NULL", username).
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
First(&databaseUser).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -199,6 +200,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: ldapId,
}
dto.Normalize(syncGroup)
if databaseGroup.ID == "" {
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
@@ -309,7 +311,6 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
// If a user is found (even if disabled), enable them since they're now back in LDAP
if databaseUser.ID != "" && databaseUser.Disabled {
// Use the transaction instead of the direct context
err = tx.
WithContext(ctx).
Model(&model.User{}).
@@ -318,7 +319,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
Error
if err != nil {
log.Printf("Failed to enable user %s: %v", databaseUser.Username, err)
return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
}
}
@@ -344,6 +345,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
IsAdmin: isAdmin,
LdapID: ldapId,
}
dto.Normalize(newUser)
if databaseUser.ID == "" {
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
@@ -476,7 +478,7 @@ func getDNProperty(property string, str string) string {
// LDAP servers may return binary UUIDs (16 bytes) or other non-UTF-8 data.
func convertLdapIdToString(ldapId string) string {
if utf8.ValidString(ldapId) {
return ldapId
return norm.NFC.String(ldapId)
}
// Try to parse as binary UUID (16 bytes)

View File

@@ -0,0 +1,51 @@
package sqlite
import (
"database/sql/driver"
"errors"
"fmt"
"strings"
sqlitelib "github.com/glebarez/go-sqlite"
"golang.org/x/text/unicode/norm"
)
func RegisterSqliteFunctions() {
// Register the `normalize(text, form)` function, which performs Unicode normalization on the text
// This is currently only used in migration functions
sqlitelib.MustRegisterDeterministicScalarFunction("normalize", 2, func(ctx *sqlitelib.FunctionContext, args []driver.Value) (driver.Value, error) {
if len(args) != 2 {
return nil, errors.New("normalize requires 2 arguments")
}
arg0, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("first argument for normalize is not a string: %T", args[0])
}
arg1, ok := args[1].(string)
if !ok {
return nil, fmt.Errorf("second argument for normalize is not a string: %T", args[1])
}
var form norm.Form
switch strings.ToLower(arg1) {
case "nfc":
form = norm.NFC
case "nfd":
form = norm.NFD
case "nfkc":
form = norm.NFKC
case "nfkd":
form = norm.NFKD
default:
return nil, fmt.Errorf("unsupported form: %s", arg1)
}
if len(arg0) == 0 {
return arg0, nil
}
return form.String(arg0), nil
})
}

View File

@@ -17,9 +17,14 @@ import (
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/utils"
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
"github.com/pocket-id/pocket-id/backend/resources"
)
func init() {
sqliteutil.RegisterSqliteFunctions()
}
// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database.
// Each database connection is unique for the test.
// All migrations are automatically performed.