diff --git a/.gitignore b/.gitignore
index efb0a768..5978887f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ node_modules
/frontend/build
/backend/bin
pocket-id
+/tests/test-results/*.json
# OS
.DS_Store
diff --git a/backend/go.mod b/backend/go.mod
index 4dd4ef0f..0adeaa41 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -20,7 +20,8 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/go-uuid v1.0.3
github.com/joho/godotenv v1.5.1
- github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
+ github.com/lestrrat-go/httprc/v3 v3.0.0-beta2
+ github.com/lestrrat-go/jwx/v3 v3.0.1
github.com/mileusna/useragent v1.3.5
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
github.com/stretchr/testify v1.10.0
@@ -32,7 +33,7 @@ require (
go.opentelemetry.io/otel/sdk v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0
go.opentelemetry.io/otel/trace v1.35.0
- golang.org/x/crypto v0.36.0
+ golang.org/x/crypto v0.37.0
golang.org/x/image v0.24.0
golang.org/x/time v0.9.0
gorm.io/driver/postgres v1.5.11
@@ -77,9 +78,8 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
- github.com/lestrrat-go/blackmagic v1.0.2 // indirect
+ github.com/lestrrat-go/blackmagic v1.0.3 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
- github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@@ -123,7 +123,7 @@ require (
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
- golang.org/x/text v0.23.0 // indirect
+ golang.org/x/text v0.24.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
google.golang.org/grpc v1.71.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index 7042fe85..7e3a5370 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -164,14 +164,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
-github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
-github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
+github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
+github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
-github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
-github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
-github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
-github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
+github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY=
+github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE=
+github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY=
+github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@@ -309,8 +309,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
-golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
-golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
+golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -377,8 +377,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
-golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
-golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
+golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
+golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/backend/internal/bootstrap/e2etest_router_bootstrap.go b/backend/internal/bootstrap/e2etest_router_bootstrap.go
index bda3360a..e16bdc9f 100644
--- a/backend/internal/bootstrap/e2etest_router_bootstrap.go
+++ b/backend/internal/bootstrap/e2etest_router_bootstrap.go
@@ -3,6 +3,8 @@
package bootstrap
import (
+ "log"
+
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -14,7 +16,12 @@ import (
func init() {
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
- testService := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
+ testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
+ if err != nil {
+ log.Fatalf("failed to initialize test service: %v", err)
+ return
+ }
+
controller.NewTestController(apiGroup, testService)
},
}
diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go
index 1ba817cd..892f47d0 100644
--- a/backend/internal/bootstrap/services_bootstrap.go
+++ b/backend/internal/bootstrap/services_bootstrap.go
@@ -26,15 +26,14 @@ type services struct {
}
// Initializes all services
-// The context should be used by services only for initialization, and not for running
-func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
+func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
svc = &services{}
- svc.appConfigService = service.NewAppConfigService(initCtx, db)
+ svc.appConfigService = service.NewAppConfigService(ctx, db)
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
if err != nil {
- return nil, fmt.Errorf("unable to create email service: %w", err)
+ return nil, fmt.Errorf("failed to create email service: %w", err)
}
svc.geoLiteService = service.NewGeoLiteService(httpClient)
@@ -42,7 +41,12 @@ func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client)
svc.jwtService = service.NewJwtService(svc.appConfigService)
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
svc.customClaimService = service.NewCustomClaimService(db)
- svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
+
+ svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create OIDC service: %w", err)
+ }
+
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go
index d385d11c..86f90e93 100644
--- a/backend/internal/common/errors.go
+++ b/backend/internal/common/errors.go
@@ -65,6 +65,11 @@ type OidcClientSecretInvalidError struct{}
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
+type OidcClientAssertionInvalidError struct{}
+
+func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
+func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
+
type OidcInvalidAuthorizationCodeError struct{}
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
diff --git a/backend/internal/controller/e2etest_controller.go b/backend/internal/controller/e2etest_controller.go
index 179285a3..d5ecc989 100644
--- a/backend/internal/controller/e2etest_controller.go
+++ b/backend/internal/controller/e2etest_controller.go
@@ -14,6 +14,9 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
testController := &TestController{TestService: testService}
group.POST("/test/reset", testController.resetAndSeedHandler)
+
+ group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
+ group.POST("/externalidp/sign", testController.externalIdPSignToken)
}
type TestController struct {
@@ -21,6 +24,15 @@ type TestController struct {
}
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
+ var baseURL string
+ if c.Request.TLS != nil {
+ baseURL = "https://" + c.Request.Host
+ } else {
+ baseURL = "http://" + c.Request.Host
+ }
+
+ skipLdap := c.Query("skip-ldap") == "true"
+
if err := tc.TestService.ResetDatabase(); err != nil {
_ = c.Error(err)
return
@@ -31,7 +43,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return
}
- if err := tc.TestService.SeedDatabase(); err != nil {
+ if err := tc.TestService.SeedDatabase(baseURL); err != nil {
_ = c.Error(err)
return
}
@@ -41,17 +53,50 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return
}
- if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
- _ = c.Error(err)
- return
- }
+ if !skipLdap {
+ if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
+ _ = c.Error(err)
+ return
+ }
- if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
- _ = c.Error(err)
- return
+ if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
+ _ = c.Error(err)
+ return
+ }
}
tc.TestService.SetJWTKeys()
c.Status(http.StatusNoContent)
}
+
+func (tc *TestController) externalIdPJWKS(c *gin.Context) {
+ jwks, err := tc.TestService.GetExternalIdPJWKS()
+ if err != nil {
+ _ = c.Error(err)
+ return
+ }
+
+ c.JSON(http.StatusOK, jwks)
+}
+
+func (tc *TestController) externalIdPSignToken(c *gin.Context) {
+ var input struct {
+ Aud string `json:"aud"`
+ Iss string `json:"iss"`
+ Sub string `json:"sub"`
+ }
+ err := c.ShouldBindJSON(&input)
+ if err != nil {
+ _ = c.Error(err)
+ return
+ }
+
+ token, err := tc.TestService.SignExternalIdPToken(input.Iss, input.Sub, input.Aud)
+ if err != nil {
+ _ = c.Error(err)
+ return
+ }
+
+ c.Writer.WriteString(token)
+}
diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go
index d6cb3fa2..68d7300d 100644
--- a/backend/internal/controller/oidc_controller.go
+++ b/backend/internal/controller/oidc_controller.go
@@ -7,14 +7,14 @@ import (
"net/url"
"strings"
- "github.com/pocket-id/pocket-id/backend/internal/common"
- "github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
-
"github.com/gin-gonic/gin"
+
+ "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
+ "github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
)
// NewOidcController creates a new controller for OIDC related endpoints
@@ -124,11 +124,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
// @Tags OIDC
// @Produce json
// @Param client_id formData string false "Client ID (if not using Basic Auth)"
-// @Param client_secret formData string false "Client secret (if not using Basic Auth)"
+// @Param client_secret formData string false "Client secret (if not using Basic Auth or client assertions)"
// @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
// @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
// @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
// @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
+// @Param client_assertion formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
+// @Param client_assertion_type formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
// @Router /api/oidc/token [post]
func (oc *OidcController) createTokensHandler(c *gin.Context) {
@@ -363,12 +365,12 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
err = dto.MapStruct(client, &clientDto)
- if err == nil {
- c.JSON(http.StatusOK, clientDto)
+ if err != nil {
+ _ = c.Error(err)
return
}
- _ = c.Error(err)
+ c.JSON(http.StatusOK, clientDto)
}
// listClientsHandler godoc
diff --git a/backend/internal/dto/dto_mapper.go b/backend/internal/dto/dto_mapper.go
index 8c027d0b..f727dfc9 100644
--- a/backend/internal/dto/dto_mapper.go
+++ b/backend/internal/dto/dto_mapper.go
@@ -62,7 +62,60 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
return nil
}
+//nolint:gocognit
func mapField(sourceField reflect.Value, destField reflect.Value) error {
+ // Handle pointer to struct in source
+ if sourceField.Kind() == reflect.Ptr && !sourceField.IsNil() {
+ switch {
+ case sourceField.Elem().Kind() == reflect.Struct:
+ switch {
+ case destField.Kind() == reflect.Struct:
+ // Map from pointer to struct -> struct
+ return mapStructInternal(sourceField.Elem(), destField)
+ case destField.Kind() == reflect.Ptr && destField.CanSet():
+ // Map from pointer to struct -> pointer to struct
+ if destField.IsNil() {
+ destField.Set(reflect.New(destField.Type().Elem()))
+ }
+ return mapStructInternal(sourceField.Elem(), destField.Elem())
+ }
+ case destField.Kind() == reflect.Ptr &&
+ destField.CanSet() &&
+ sourceField.Elem().Type().AssignableTo(destField.Type().Elem()):
+ // Handle primitive pointer types (e.g., *string to *string)
+ if destField.IsNil() {
+ destField.Set(reflect.New(destField.Type().Elem()))
+ }
+ destField.Elem().Set(sourceField.Elem())
+ return nil
+ case destField.Kind() != reflect.Ptr &&
+ destField.CanSet() &&
+ sourceField.Elem().Type().AssignableTo(destField.Type()):
+ // Handle *T to T conversion for primitive types
+ destField.Set(sourceField.Elem())
+ return nil
+ }
+ }
+
+ // Handle pointer to struct in destination
+ if destField.Kind() == reflect.Ptr && destField.CanSet() {
+ switch {
+ case sourceField.Kind() == reflect.Struct:
+ // Map from struct -> pointer to struct
+ if destField.IsNil() {
+ destField.Set(reflect.New(destField.Type().Elem()))
+ }
+ return mapStructInternal(sourceField, destField.Elem())
+ case !sourceField.IsZero() && sourceField.Type().AssignableTo(destField.Type().Elem()):
+ // Handle T to *T conversion for primitive types
+ if destField.IsNil() {
+ destField.Set(reflect.New(destField.Type().Elem()))
+ }
+ destField.Elem().Set(sourceField)
+ return nil
+ }
+ }
+
switch {
case sourceField.Type() == destField.Type():
destField.Set(sourceField)
diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go
index 9c0b62c7..df317787 100644
--- a/backend/internal/dto/oidc_dto.go
+++ b/backend/internal/dto/oidc_dto.go
@@ -8,10 +8,11 @@ type OidcClientMetaDataDto struct {
type OidcClientDto struct {
OidcClientMetaDataDto
- CallbackURLs []string `json:"callbackURLs"`
- LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
- IsPublic bool `json:"isPublic"`
- PkceEnabled bool `json:"pkceEnabled"`
+ CallbackURLs []string `json:"callbackURLs"`
+ LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
+ IsPublic bool `json:"isPublic"`
+ PkceEnabled bool `json:"pkceEnabled"`
+ Credentials OidcClientCredentialsDto `json:"credentials"`
}
type OidcClientWithAllowedUserGroupsDto struct {
@@ -25,11 +26,23 @@ type OidcClientWithAllowedGroupsCountDto struct {
}
type OidcClientCreateDto struct {
- Name string `json:"name" binding:"required,max=50"`
- CallbackURLs []string `json:"callbackURLs"`
- LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
- IsPublic bool `json:"isPublic"`
- PkceEnabled bool `json:"pkceEnabled"`
+ Name string `json:"name" binding:"required,max=50"`
+ CallbackURLs []string `json:"callbackURLs"`
+ LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
+ IsPublic bool `json:"isPublic"`
+ PkceEnabled bool `json:"pkceEnabled"`
+ Credentials OidcClientCredentialsDto `json:"credentials"`
+}
+
+type OidcClientCredentialsDto struct {
+ FederatedIdentities []OidcClientFederatedIdentityDto `json:"federatedIdentities,omitempty"`
+}
+
+type OidcClientFederatedIdentityDto struct {
+ Issuer string `json:"issuer"`
+ Subject string `json:"subject,omitempty"`
+ Audience string `json:"audience,omitempty"`
+ JWKS string `json:"jwks,omitempty"`
}
type AuthorizeOidcClientRequestDto struct {
@@ -52,13 +65,15 @@ type AuthorizationRequiredDto struct {
}
type OidcCreateTokensDto struct {
- GrantType string `form:"grant_type" binding:"required"`
- Code string `form:"code"`
- DeviceCode string `form:"device_code"`
- ClientID string `form:"client_id"`
- ClientSecret string `form:"client_secret"`
- CodeVerifier string `form:"code_verifier"`
- RefreshToken string `form:"refresh_token"`
+ GrantType string `form:"grant_type" binding:"required"`
+ Code string `form:"code"`
+ DeviceCode string `form:"device_code"`
+ ClientID string `form:"client_id"`
+ ClientSecret string `form:"client_secret"`
+ CodeVerifier string `form:"code_verifier"`
+ RefreshToken string `form:"refresh_token"`
+ ClientAssertion string `form:"client_assertion"`
+ ClientAssertionType string `form:"client_assertion_type"`
}
type OidcIntrospectDto struct {
diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go
index 7958215c..490015bf 100644
--- a/backend/internal/model/oidc.go
+++ b/backend/internal/model/oidc.go
@@ -5,8 +5,9 @@ import (
"encoding/json"
"fmt"
- datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
+
+ datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type UserAuthorizedOidcClient struct {
@@ -45,6 +46,7 @@ type OidcClient struct {
HasLogo bool `gorm:"-"`
IsPublic bool
PkceEnabled bool
+ Credentials OidcClientCredentials
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
CreatedByID string
@@ -71,9 +73,49 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
return nil
}
+type OidcClientCredentials struct { //nolint:recvcheck
+ FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
+}
+
+type OidcClientFederatedIdentity struct {
+ Issuer string `json:"issuer"`
+ Subject string `json:"subject,omitempty"`
+ Audience string `json:"audience,omitempty"`
+ JWKS string `json:"jwks,omitempty"` // URL of the JWKS
+}
+
+func (occ OidcClientCredentials) FederatedIdentityForIssuer(issuer string) (OidcClientFederatedIdentity, bool) {
+ if issuer == "" {
+ return OidcClientFederatedIdentity{}, false
+ }
+
+ for _, fi := range occ.FederatedIdentities {
+ if fi.Issuer == issuer {
+ return fi, true
+ }
+ }
+
+ return OidcClientFederatedIdentity{}, false
+}
+
+func (occ *OidcClientCredentials) Scan(value any) error {
+ switch v := value.(type) {
+ case []byte:
+ return json.Unmarshal(v, occ)
+ case string:
+ return json.Unmarshal([]byte(v), occ)
+ default:
+ return fmt.Errorf("unsupported type: %T", value)
+ }
+}
+
+func (occ OidcClientCredentials) Value() (driver.Value, error) {
+ return json.Marshal(occ)
+}
+
type UrlList []string //nolint:recvcheck
-func (cu *UrlList) Scan(value interface{}) error {
+func (cu *UrlList) Scan(value any) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, cu)
diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go
index 0abd4558..5213d059 100644
--- a/backend/internal/service/app_config_service.go
+++ b/backend/internal/service/app_config_service.go
@@ -29,17 +29,17 @@ type AppConfigService struct {
db *gorm.DB
}
-func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService {
+func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{
db: db,
}
- err := service.LoadDbConfig(initCtx)
+ err := service.LoadDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
- err = service.initInstanceID(initCtx)
+ err = service.initInstanceID(ctx)
if err != nil {
log.Fatalf("Failed to initialize instance ID: %v", err)
}
diff --git a/backend/internal/service/app_config_service_test.go b/backend/internal/service/app_config_service_test.go
index c986162b..5b538919 100644
--- a/backend/internal/service/app_config_service_test.go
+++ b/backend/internal/service/app_config_service_test.go
@@ -3,16 +3,10 @@ package service
import (
"sync/atomic"
"testing"
- "time"
-
- "github.com/glebarez/sqlite"
- "gorm.io/gorm"
- "gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
- "github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
)
@@ -28,7 +22,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -42,7 +36,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loads value from config table", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
@@ -72,7 +66,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("ignores unknown config keys", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
@@ -93,7 +87,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loading config multiple times", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
@@ -135,7 +129,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -171,7 +165,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -195,7 +189,7 @@ func TestLoadDbConfig(t *testing.T) {
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -220,7 +214,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("update multiple values", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -264,7 +258,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("empty value resets to default", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -285,7 +279,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with odd number of arguments", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -301,7 +295,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with invalid key", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -319,7 +313,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -392,7 +386,7 @@ func TestUpdateAppConfig(t *testing.T) {
})
t.Run("empty values reset to defaults", func(t *testing.T) {
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
@@ -457,7 +451,7 @@ func TestUpdateAppConfig(t *testing.T) {
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
- db := newAppConfigTestDatabaseForTest(t)
+ db := newDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -475,49 +469,3 @@ func TestUpdateAppConfig(t *testing.T) {
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
-
-// Implements gorm's logger.Writer interface
-type testLoggerAdapter struct {
- t *testing.T
-}
-
-func (l testLoggerAdapter) Printf(format string, args ...any) {
- l.t.Logf(format, args...)
-}
-
-func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
- t.Helper()
-
- // Get a name for this in-memory database that is specific to the test
- dbName := utils.CreateSha256Hash(t.Name())
-
- // Connect to a new in-memory SQL database
- db, err := gorm.Open(
- sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
- &gorm.Config{
- TranslateError: true,
- Logger: logger.New(
- testLoggerAdapter{t: t},
- logger.Config{
- SlowThreshold: 200 * time.Millisecond,
- LogLevel: logger.Info,
- IgnoreRecordNotFoundError: false,
- ParameterizedQueries: false,
- Colorful: false,
- },
- ),
- })
- require.NoError(t, err, "Failed to connect to test database")
-
- // Create the app_config_variables table
- err = db.Exec(`
-CREATE TABLE app_config_variables
-(
- key VARCHAR(100) NOT NULL PRIMARY KEY,
- value TEXT NOT NULL
-)
-`).Error
- require.NoError(t, err, "Failed to create test config table")
-
- return db
-}
diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go
index dcee2aac..96234a99 100644
--- a/backend/internal/service/e2etest_service.go
+++ b/backend/internal/service/e2etest_service.go
@@ -5,6 +5,8 @@ package service
import (
"context"
"crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
"crypto/x509"
"encoding/base64"
"fmt"
@@ -16,6 +18,7 @@ import (
"github.com/fxamacker/cbor/v2"
"github.com/go-webauthn/webauthn/protocol"
"github.com/lestrrat-go/jwx/v3/jwk"
+ "github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
@@ -30,14 +33,43 @@ type TestService struct {
jwtService *JwtService
appConfigService *AppConfigService
ldapService *LdapService
+ externalIdPKey jwk.Key
}
-func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) *TestService {
- return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService, ldapService: ldapService}
+func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
+ s := &TestService{
+ db: db,
+ appConfigService: appConfigService,
+ jwtService: jwtService,
+ ldapService: ldapService,
+ }
+ err := s.initExternalIdP()
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize external IdP: %w", err)
+ }
+ return s, nil
+}
+
+// Initializes the "external IdP"
+// This creates a new "issuing authority" containing a public JWKS
+// It also stores the private key internally that will be used to issue JWTs
+func (s *TestService) initExternalIdP() error {
+ // Generate a new ECDSA key
+ rawKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return fmt.Errorf("failed to generate private key: %w", err)
+ }
+
+ s.externalIdPKey, err = utils.ImportRawKey(rawKey)
+ if err != nil {
+ return fmt.Errorf("failed to import private key: %w", err)
+ }
+
+ return nil
}
//nolint:gocognit
-func (s *TestService) SeedDatabase() error {
+func (s *TestService) SeedDatabase(baseURL string) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{
{
@@ -138,6 +170,26 @@ func (s *TestService) SeedDatabase() error {
userGroups[1],
},
},
+ {
+ Base: model.Base{
+ ID: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
+ },
+ Name: "Federated",
+ Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
+ CallbackURLs: model.UrlList{"http://federated/auth/callback"},
+ CreatedByID: users[1].ID,
+ AllowedUserGroups: []model.UserGroup{},
+ Credentials: model.OidcClientCredentials{
+ FederatedIdentities: []model.OidcClientFederatedIdentity{
+ {
+ Issuer: "https://external-idp.local",
+ Audience: "api://PocketID",
+ Subject: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
+ JWKS: baseURL + "/api/externalidp/jwks.json",
+ },
+ },
+ },
+ },
}
for _, client := range oidcClients {
if err := tx.Create(&client).Error; err != nil {
@@ -145,16 +197,28 @@ func (s *TestService) SeedDatabase() error {
}
}
- authCode := model.OidcAuthorizationCode{
- Code: "auth-code",
- Scope: "openid profile",
- Nonce: "nonce",
- ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
- UserID: users[0].ID,
- ClientID: oidcClients[0].ID,
+ authCodes := []model.OidcAuthorizationCode{
+ {
+ Code: "auth-code",
+ Scope: "openid profile",
+ Nonce: "nonce",
+ ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
+ UserID: users[0].ID,
+ ClientID: oidcClients[0].ID,
+ },
+ {
+ Code: "federated",
+ Scope: "openid profile",
+ Nonce: "nonce",
+ ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
+ UserID: users[1].ID,
+ ClientID: oidcClients[2].ID,
+ },
}
- if err := tx.Create(&authCode).Error; err != nil {
- return err
+ for _, authCode := range authCodes {
+ if err := tx.Create(&authCode).Error; err != nil {
+ return err
+ }
}
refreshToken := model.OidcRefreshToken{
@@ -177,13 +241,22 @@ func (s *TestService) SeedDatabase() error {
return err
}
- userAuthorizedClient := model.UserAuthorizedOidcClient{
- Scope: "openid profile email",
- UserID: users[0].ID,
- ClientID: oidcClients[0].ID,
+ userAuthorizedClients := []model.UserAuthorizedOidcClient{
+ {
+ Scope: "openid profile email",
+ UserID: users[0].ID,
+ ClientID: oidcClients[0].ID,
+ },
+ {
+ Scope: "openid profile email",
+ UserID: users[1].ID,
+ ClientID: oidcClients[2].ID,
+ },
}
- if err := tx.Create(&userAuthorizedClient).Error; err != nil {
- return err
+ for _, userAuthorizedClient := range userAuthorizedClients {
+ if err := tx.Create(&userAuthorizedClient).Error; err != nil {
+ return err
+ }
}
// To generate a new key pair, run the following command:
@@ -405,3 +478,41 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
return nil
}
+
+// GetExternalIdPJWKS returns the JWKS for the "external IdP".
+func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
+ pubKey, err := s.externalIdPKey.PublicKey()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get public key: %w", err)
+ }
+
+ set := jwk.NewSet()
+ err = set.AddKey(pubKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to add public key to set: %w", err)
+ }
+
+ return set, nil
+}
+
+func (s *TestService) SignExternalIdPToken(iss, sub, aud string) (string, error) {
+ now := time.Now()
+ token, err := jwt.NewBuilder().
+ Subject(sub).
+ Expiration(now.Add(time.Hour)).
+ IssuedAt(now).
+ Issuer(iss).
+ Audience([]string{aud}).
+ Build()
+ if err != nil {
+ return "", fmt.Errorf("failed to build token: %w", err)
+ }
+
+ alg, _ := s.externalIdPKey.Algorithm()
+ signed, err := jwt.Sign(token, jwt.WithKey(alg, s.externalIdPKey))
+ if err != nil {
+ return "", fmt.Errorf("failed to sign token: %w", err)
+ }
+
+ return string(signed), nil
+}
diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go
index 35dfd848..389cd608 100644
--- a/backend/internal/service/jwt_service.go
+++ b/backend/internal/service/jwt_service.go
@@ -4,11 +4,9 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
- "encoding/base64"
"encoding/json"
"errors"
"fmt"
- "io"
"log"
"os"
"path/filepath"
@@ -372,7 +370,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
- EnsureAlgInKey(pubKey)
+ utils.EnsureAlgInKey(pubKey)
return pubKey, nil
}
@@ -415,27 +413,6 @@ func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
return key, nil
}
-// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
-func EnsureAlgInKey(key jwk.Key) {
- _, ok := key.Algorithm()
- if ok {
- // Algorithm is already set
- return
- }
-
- switch key.KeyType() {
- case jwa.RSA():
- // Default to RS256 for RSA keys
- _ = key.Set(jwk.AlgorithmKey, jwa.RS256())
- case jwa.EC():
- // Default to ES256 for ECDSA keys
- _ = key.Set(jwk.AlgorithmKey, jwa.ES256())
- case jwa.OKP():
- // Default to EdDSA for OKP keys
- _ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
- }
-}
-
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
@@ -444,27 +421,7 @@ func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
}
// Import the raw key
- return importRawKey(rawKey)
-}
-
-func importRawKey(rawKey any) (jwk.Key, error) {
- key, err := jwk.Import(rawKey)
- if err != nil {
- return nil, fmt.Errorf("failed to import generated private key: %w", err)
- }
-
- // Generate the key ID
- kid, err := generateRandomKeyID()
- if err != nil {
- return nil, fmt.Errorf("failed to generate key ID: %w", err)
- }
- _ = key.Set(jwk.KeyIDKey, kid)
-
- // Set other required fields
- _ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
- EnsureAlgInKey(key)
-
- return key, err
+ return utils.ImportRawKey(rawKey)
}
// SaveKeyJWK saves a JWK to a file
@@ -492,16 +449,6 @@ func SaveKeyJWK(key jwk.Key, path string) error {
return nil
}
-// generateRandomKeyID generates a random key ID.
-func generateRandomKeyID() (string, error) {
- buf := make([]byte, 8)
- _, err := io.ReadFull(rand.Reader, buf)
- if err != nil {
- return "", fmt.Errorf("failed to read random bytes: %w", err)
- }
- return base64.RawURLEncoding.EncodeToString(buf), nil
-}
-
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {
diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go
index e4f7babf..43c7520c 100644
--- a/backend/internal/service/jwt_service_test.go
+++ b/backend/internal/service/jwt_service_test.go
@@ -21,6 +21,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
+ "github.com/pocket-id/pocket-id/backend/internal/utils"
)
func TestJwtService_Init(t *testing.T) {
@@ -1218,7 +1219,7 @@ func TestTokenTypeValidator(t *testing.T) {
func importKey(t *testing.T, privateKeyRaw any, path string) string {
t.Helper()
- privateKey, err := importRawKey(privateKeyRaw)
+ privateKey, err := utils.ImportRawKey(privateKeyRaw)
require.NoError(t, err, "Failed to import private key")
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go
index ba7070c0..0a355917 100644
--- a/backend/internal/service/oidc_service.go
+++ b/backend/internal/service/oidc_service.go
@@ -3,18 +3,25 @@ package service
import (
"context"
"crypto/sha256"
+ "crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
+ "log/slog"
"mime/multipart"
+ "net/http"
"os"
"regexp"
"slices"
"strings"
"time"
+ "github.com/lestrrat-go/httprc/v3"
+ "github.com/lestrrat-go/httprc/v3/errsink"
+ "github.com/lestrrat-go/jwx/v3/jwk"
+ "github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
@@ -31,6 +38,8 @@ const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeRefreshToken = "refresh_token"
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
+
+ ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
)
type OidcService struct {
@@ -39,16 +48,61 @@ type OidcService struct {
appConfigService *AppConfigService
auditLogService *AuditLogService
customClaimService *CustomClaimService
+
+ httpClient *http.Client
+ jwkCache *jwk.Cache
}
-func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
- return &OidcService{
+func NewOidcService(
+ ctx context.Context,
+ db *gorm.DB,
+ jwtService *JwtService,
+ appConfigService *AppConfigService,
+ auditLogService *AuditLogService,
+ customClaimService *CustomClaimService,
+) (s *OidcService, err error) {
+ s = &OidcService{
db: db,
jwtService: jwtService,
appConfigService: appConfigService,
auditLogService: auditLogService,
customClaimService: customClaimService,
}
+
+ // Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
+ s.jwkCache, err = s.getJWKCache(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return s, nil
+}
+
+func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
+ // We need to create a custom HTTP client to set a timeout.
+ client := s.httpClient
+ if client == nil {
+ client = &http.Client{
+ Timeout: 20 * time.Second,
+ }
+
+ defaultTransport, ok := http.DefaultTransport.(*http.Transport)
+ if !ok {
+ // Indicates a development-time error
+ panic("Default transport is not of type *http.Transport")
+ }
+ transport := defaultTransport.Clone()
+ transport.TLSClientConfig.MinVersion = tls.VersionTLS12
+ client.Transport = transport
+ }
+
+ // Create the JWKS cache
+ return jwk.NewCache(ctx,
+ httprc.NewClient(
+ httprc.WithErrorSink(errsink.NewSlog(slog.Default())),
+ httprc.WithHTTPClient(client),
+ ),
+ )
}
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
@@ -198,7 +252,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback()
}()
- _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
+ _, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -279,7 +333,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback()
}()
- client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
+ client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -357,7 +411,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
tx.Rollback()
}()
- _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
+ _, err := s.verifyClientCredentialsInternal(ctx, tx, input)
if err != nil {
return CreatedTokens{}, err
}
@@ -420,7 +474,10 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
- _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
+ _, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ })
if err != nil {
return introspectDto, err
}
@@ -440,33 +497,35 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
introspectDto.Active = true
introspectDto.TokenType = "access_token"
if token.Has("scope") {
- var asString string
- var asStrings []string
+ var (
+ asString string
+ asStrings []string
+ )
if err := token.Get("scope", &asString); err == nil {
introspectDto.Scope = asString
} else if err := token.Get("scope", &asStrings); err == nil {
introspectDto.Scope = strings.Join(asStrings, " ")
}
}
- if expiration, hasExpiration := token.Expiration(); hasExpiration {
+ if expiration, ok := token.Expiration(); ok {
introspectDto.Expiration = expiration.Unix()
}
- if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
+ if issuedAt, ok := token.IssuedAt(); ok {
introspectDto.IssuedAt = issuedAt.Unix()
}
- if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
+ if notBefore, ok := token.NotBefore(); ok {
introspectDto.NotBefore = notBefore.Unix()
}
- if subject, hasSubject := token.Subject(); hasSubject {
+ if subject, ok := token.Subject(); ok {
introspectDto.Subject = subject
}
- if audience, hasAudience := token.Audience(); hasAudience {
+ if audience, ok := token.Audience(); ok {
introspectDto.Audience = audience
}
- if issuer, hasIssuer := token.Issuer(); hasIssuer {
+ if issuer, ok := token.Issuer(); ok {
introspectDto.Issuer = issuer
}
- if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
+ if identifier, ok := token.JwtID(); ok {
introspectDto.Identifier = identifier
}
@@ -542,13 +601,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{
- Name: input.Name,
- CallbackURLs: input.CallbackURLs,
- LogoutCallbackURLs: input.LogoutCallbackURLs,
- CreatedByID: userID,
- IsPublic: input.IsPublic,
- PkceEnabled: input.PkceEnabled,
+ CreatedByID: userID,
}
+ updateOIDCClientModelFromDto(&client, &input)
err := s.db.
WithContext(ctx).
@@ -577,11 +632,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return model.OidcClient{}, err
}
- client.Name = input.Name
- client.CallbackURLs = input.CallbackURLs
- client.LogoutCallbackURLs = input.LogoutCallbackURLs
- client.IsPublic = input.IsPublic
- client.PkceEnabled = input.IsPublic || input.PkceEnabled
+ updateOIDCClientModelFromDto(&client, &input)
err = tx.
WithContext(ctx).
@@ -599,6 +650,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
return client, nil
}
+func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
+ // Base fields
+ client.Name = input.Name
+ client.CallbackURLs = input.CallbackURLs
+ client.LogoutCallbackURLs = input.LogoutCallbackURLs
+ client.IsPublic = input.IsPublic
+ // PKCE is required for public clients
+ client.PkceEnabled = input.IsPublic || input.PkceEnabled
+
+ // Credentials
+ if len(input.Credentials.FederatedIdentities) > 0 {
+ client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
+ for i, fi := range input.Credentials.FederatedIdentities {
+ client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{
+ Issuer: fi.Issuer,
+ Audience: fi.Audience,
+ Subject: fi.Subject,
+ JWKS: fi.JWKS,
+ }
+ }
+ }
+}
+
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
var client model.OidcClient
err := s.db.
@@ -1079,7 +1153,10 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
}
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
- client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
+ client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
+ ClientID: input.ClientID,
+ ClientSecret: input.ClientSecret,
+ })
if err != nil {
return nil, err
}
@@ -1305,33 +1382,140 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
return err
}
-func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
+func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
// First, ensure we have a valid client ID
- if clientID == "" {
- return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
+ if input.ClientID == "" {
+ return nil, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
var client model.OidcClient
err := tx.
WithContext(ctx).
- First(&client, "id = ?", clientID).
+ First(&client, "id = ?", input.ClientID).
Error
if err != nil {
- return model.OidcClient{}, err
+ return nil, err
}
- // If we have a client secret, we validate it
- // Otherwise, we require the client to be public
- if clientSecret != "" {
- err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
+ // We have 3 options
+ // If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
+ switch {
+ // First, if we have a client secret, we validate it
+ case input.ClientSecret != "":
+ err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret))
if err != nil {
- return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
+ return nil, &common.OidcClientSecretInvalidError{}
+ }
+ return &client, nil
+
+ // Next, check if we want to use client assertions from federated identities
+ case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
+ err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
+ if err != nil {
+ log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
+ return nil, &common.OidcClientAssertionInvalidError{}
+ }
+ return &client, nil
+
+ // There's no credentials
+ // This is allowed only if the client is public
+ case client.IsPublic:
+ return &client, nil
+
+ // If we're here, we have no credentials AND the client is not public, so credentials are required
+ default:
+ return nil, &common.OidcMissingClientCredentialsError{}
+ }
+}
+
+func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
+ // Check if we have already registered the URL
+ if !s.jwkCache.IsRegistered(ctx, url) {
+ // We set a timeout because otherwise Register will keep trying in case of errors
+ registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second)
+ defer registerCancel()
+ // We need to register the URL
+ err = s.jwkCache.Register(
+ registerCtx,
+ url,
+ jwk.WithMaxInterval(24*time.Hour),
+ jwk.WithMinInterval(15*time.Minute),
+ jwk.WithWaitReady(true),
+ )
+ // In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error
+ if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) {
+ return nil, fmt.Errorf("failed to register JWK set: %w", err)
}
- return client, nil
- } else if !client.IsPublic {
- return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
}
- return client, nil
+ jwks, err := s.jwkCache.CachedSet(url)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
+ }
+
+ return jwks, nil
+}
+
+func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
+ // First, parse the assertion JWT, without validating it, to check the issuer
+ assertion := []byte(input.ClientAssertion)
+ insecureToken, err := jwt.ParseInsecure(assertion)
+ if err != nil {
+ return fmt.Errorf("failed to parse client assertion JWT: %w", err)
+ }
+
+ issuer, _ := insecureToken.Issuer()
+ if issuer == "" {
+ return errors.New("client assertion does not contain an issuer claim")
+ }
+
+ // Ensure that this client is federated with the one that issued the token
+ ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer)
+ if !ok {
+ return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer)
+ }
+
+ // Get the JWK set for the issuer
+ jwksURL := ocfi.JWKS
+ if jwksURL == "" {
+ // Default URL is from the issuer
+ if strings.HasSuffix(issuer, "/") {
+ jwksURL = issuer + ".well-known/jwks.json"
+ } else {
+ jwksURL = issuer + "/.well-known/jwks.json"
+ }
+ }
+ jwks, err := s.jwkSetForURL(ctx, jwksURL)
+ if err != nil {
+ return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err)
+ }
+
+ // Set default audience and subject if missing
+ audience := ocfi.Audience
+ if audience == "" {
+ // Default to the Pocket ID's URL
+ audience = common.EnvConfig.AppURL
+ }
+ subject := ocfi.Subject
+ if subject == "" {
+ // Default to the client ID, per RFC 7523
+ subject = client.ID
+ }
+
+ // Now re-parse the token with proper validation
+ // (Note: we don't use jwt.WithIssuer() because that would be redundant)
+ _, err = jwt.Parse(assertion,
+ jwt.WithValidate(true),
+ jwt.WithAcceptableSkew(clockSkew),
+ jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
+ jwt.WithAudience(audience),
+ jwt.WithSubject(subject),
+ )
+ if err != nil {
+ return fmt.Errorf("client assertion is not valid: %w", err)
+ }
+
+ // If we're here, the assertion is valid
+ return nil
}
diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go
new file mode 100644
index 00000000..3a230f49
--- /dev/null
+++ b/backend/internal/service/oidc_service_test.go
@@ -0,0 +1,365 @@
+package service
+
+import (
+ "context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "encoding/json"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/lestrrat-go/jwx/v3/jwa"
+ "github.com/lestrrat-go/jwx/v3/jwk"
+ "github.com/lestrrat-go/jwx/v3/jwt"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/pocket-id/pocket-id/backend/internal/common"
+ "github.com/pocket-id/pocket-id/backend/internal/dto"
+)
+
+// generateTestECDSAKey creates an ECDSA key for testing
+func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
+ t.Helper()
+
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ require.NoError(t, err)
+
+ privateJwk, err := jwk.Import(privateKey)
+ require.NoError(t, err)
+
+ err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
+ require.NoError(t, err)
+ err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
+ require.NoError(t, err)
+ err = privateJwk.Set("use", "sig")
+ require.NoError(t, err)
+
+ publicJwk, err := jwk.PublicKeyOf(privateJwk)
+ require.NoError(t, err)
+
+ // Create a JWK Set with the public key
+ jwkSet := jwk.NewSet()
+ err = jwkSet.AddKey(publicJwk)
+ require.NoError(t, err)
+ jwkSetJSON, err := json.Marshal(jwkSet)
+ require.NoError(t, err)
+
+ return privateJwk, jwkSetJSON
+}
+
+func TestOidcService_jwkSetForURL(t *testing.T) {
+ // Generate a test key for JWKS
+ _, jwkSetJSON1 := generateTestECDSAKey(t)
+ _, jwkSetJSON2 := generateTestECDSAKey(t)
+
+ // Create a mock HTTP client with responses for different URLs
+ const (
+ url1 = "https://example.com/.well-known/jwks.json"
+ url2 = "https://other-issuer.com/jwks"
+ )
+ mockResponses := map[string]*http.Response{
+ //nolint:bodyclose
+ url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
+ //nolint:bodyclose
+ url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
+ }
+ httpClient := &http.Client{
+ Transport: &MockRoundTripper{
+ Responses: mockResponses,
+ },
+ }
+
+ // Create the OidcService with our mock client
+ s := &OidcService{
+ httpClient: httpClient,
+ }
+
+ var err error
+ s.jwkCache, err = s.getJWKCache(t.Context())
+ require.NoError(t, err)
+
+ t.Run("Fetches and caches JWK set", func(t *testing.T) {
+ jwks, err := s.jwkSetForURL(t.Context(), url1)
+ require.NoError(t, err)
+ require.NotNil(t, jwks)
+
+ // Verify the JWK set contains our key
+ require.Equal(t, 1, jwks.Len())
+ })
+
+ t.Run("Fails with invalid URL", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
+ defer cancel()
+ _, err := s.jwkSetForURL(ctx, "https://bad-url.com")
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+ })
+
+ t.Run("Safe for concurrent use", func(t *testing.T) {
+ const concurrency = 20
+
+ // Channel to collect errors
+ errChan := make(chan error, concurrency)
+
+ // Start concurrent requests
+ for range concurrency {
+ go func() {
+ jwks, err := s.jwkSetForURL(t.Context(), url2)
+ if err != nil {
+ errChan <- err
+ return
+ }
+
+ // Verify the JWK set is valid
+ if jwks == nil || jwks.Len() != 1 {
+ errChan <- assert.AnError
+ return
+ }
+
+ errChan <- nil
+ }()
+ }
+
+ // Check for errors
+ for range concurrency {
+ assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
+ }
+ })
+}
+
+func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
+ const (
+ federatedClientIssuer = "https://external-idp.com"
+ federatedClientAudience = "https://pocket-id.com"
+ federatedClientSubject = "123456abcdef"
+ federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
+ )
+
+ var err error
+ // Create a test database
+ db := newDatabaseForTest(t)
+
+ // Create two JWKs for testing
+ privateJWK, jwkSetJSON := generateTestECDSAKey(t)
+ require.NoError(t, err)
+ privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
+ require.NoError(t, err)
+
+ // Create a mock HTTP client with custom transport to return the JWKS
+ httpClient := &http.Client{
+ Transport: &MockRoundTripper{
+ Responses: map[string]*http.Response{
+ //nolint:bodyclose
+ federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
+ //nolint:bodyclose
+ federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
+ },
+ },
+ }
+
+ // Init the OidcService
+ s := &OidcService{
+ db: db,
+ httpClient: httpClient,
+ }
+ s.jwkCache, err = s.getJWKCache(t.Context())
+ require.NoError(t, err)
+
+ // Create the test clients
+ // 1. Confidential client
+ confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
+ Name: "Confidential Client",
+ CallbackURLs: []string{"https://example.com/callback"},
+ }, "test-user-id")
+ require.NoError(t, err)
+
+ // Create a client secret for the confidential client
+ confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
+ require.NoError(t, err)
+
+ // 2. Public client
+ publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
+ Name: "Public Client",
+ CallbackURLs: []string{"https://example.com/callback"},
+ IsPublic: true,
+ }, "test-user-id")
+ require.NoError(t, err)
+
+ // 3. Confidential client with federated identity
+ federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
+ Name: "Federated Client",
+ CallbackURLs: []string{"https://example.com/callback"},
+ Credentials: dto.OidcClientCredentialsDto{
+ FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
+ {
+ Issuer: federatedClientIssuer,
+ Audience: federatedClientAudience,
+ Subject: federatedClientSubject,
+ JWKS: federatedClientIssuer + "/jwks.json",
+ },
+ {Issuer: federatedClientIssuerDefaults},
+ },
+ },
+ }, "test-user-id")
+ require.NoError(t, err)
+
+ // Test cases for confidential client (using client secret)
+ t.Run("Confidential client", func(t *testing.T) {
+ t.Run("Succeeds with valid secret", func(t *testing.T) {
+ // Test with valid client credentials
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: confidentialClient.ID,
+ ClientSecret: confidentialSecret,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ assert.Equal(t, confidentialClient.ID, client.ID)
+ })
+
+ t.Run("Fails with invalid secret", func(t *testing.T) {
+ // Test with invalid client secret
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: confidentialClient.ID,
+ ClientSecret: "invalid-secret",
+ })
+ require.Error(t, err)
+ require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
+ assert.Nil(t, client)
+ })
+
+ t.Run("Fails with missing secret", func(t *testing.T) {
+ // Test with missing client secret
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: confidentialClient.ID,
+ })
+ require.Error(t, err)
+ require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
+ assert.Nil(t, client)
+ })
+ })
+
+ // Test cases for public client
+ t.Run("Public client", func(t *testing.T) {
+ t.Run("Succeeds with no credentials", func(t *testing.T) {
+ // Public clients don't require client secret
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: publicClient.ID,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ assert.Equal(t, publicClient.ID, client.ID)
+ })
+ })
+
+ // Test cases for federated client using JWT assertion
+ t.Run("Federated client", func(t *testing.T) {
+ t.Run("Succeeds with valid JWT", func(t *testing.T) {
+ // Create JWT for federated identity
+ token, err := jwt.NewBuilder().
+ Issuer(federatedClientIssuer).
+ Audience([]string{federatedClientAudience}).
+ Subject(federatedClientSubject).
+ IssuedAt(time.Now()).
+ Expiration(time.Now().Add(10 * time.Minute)).
+ Build()
+ require.NoError(t, err)
+ signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
+ require.NoError(t, err)
+
+ // Test with valid JWT assertion
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: federatedClient.ID,
+ ClientAssertionType: ClientAssertionTypeJWTBearer,
+ ClientAssertion: string(signedToken),
+ })
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ assert.Equal(t, federatedClient.ID, client.ID)
+ })
+
+ t.Run("Fails with malformed JWT", func(t *testing.T) {
+ // Test with invalid JWT assertion (just a random string)
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: federatedClient.ID,
+ ClientAssertionType: ClientAssertionTypeJWTBearer,
+ ClientAssertion: "invalid.jwt.token",
+ })
+ require.Error(t, err)
+ require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
+ assert.Nil(t, client)
+ })
+
+ testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
+ return func(t *testing.T) {
+ // Populate all claims with valid values
+ builder := jwt.NewBuilder().
+ Issuer(federatedClientIssuer).
+ Audience([]string{federatedClientAudience}).
+ Subject(federatedClientSubject).
+ IssuedAt(time.Now()).
+ Expiration(time.Now().Add(10 * time.Minute))
+
+ // Call builderFn to override the claims
+ builderFn(builder)
+
+ token, err := builder.Build()
+ require.NoError(t, err)
+ signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
+ require.NoError(t, err)
+
+ // Test with invalid JWT assertion
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: federatedClient.ID,
+ ClientAssertionType: ClientAssertionTypeJWTBearer,
+ ClientAssertion: string(signedToken),
+ })
+ require.Error(t, err)
+ require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
+ require.Nil(t, client)
+ }
+ }
+
+ t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
+ builder.Expiration(time.Now().Add(-30 * time.Minute))
+ }))
+
+ t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
+ builder.Issuer("https://bad-issuer.com")
+ }))
+
+ t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
+ builder.Audience([]string{"bad-audience"})
+ }))
+
+ t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
+ builder.Subject("bad-subject")
+ }))
+
+ t.Run("Uses default values for audience and subject", func(t *testing.T) {
+ // Create JWT for federated identity
+ token, err := jwt.NewBuilder().
+ Issuer(federatedClientIssuerDefaults).
+ Audience([]string{common.EnvConfig.AppURL}).
+ Subject(federatedClient.ID).
+ IssuedAt(time.Now()).
+ Expiration(time.Now().Add(10 * time.Minute)).
+ Build()
+ require.NoError(t, err)
+ signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
+ require.NoError(t, err)
+
+ // Test with valid JWT assertion
+ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
+ ClientID: federatedClient.ID,
+ ClientAssertionType: ClientAssertionTypeJWTBearer,
+ ClientAssertion: string(signedToken),
+ })
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ assert.Equal(t, federatedClient.ID, client.ID)
+ })
+ })
+}
diff --git a/backend/internal/service/testutils_test.go b/backend/internal/service/testutils_test.go
new file mode 100644
index 00000000..59cdd9fe
--- /dev/null
+++ b/backend/internal/service/testutils_test.go
@@ -0,0 +1,97 @@
+package service
+
+import (
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ _ "github.com/golang-migrate/migrate/v4/source/file"
+
+ "github.com/glebarez/sqlite"
+ "github.com/golang-migrate/migrate/v4"
+ sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
+ "github.com/golang-migrate/migrate/v4/source/iofs"
+ "github.com/stretchr/testify/require"
+ "gorm.io/gorm"
+ "gorm.io/gorm/logger"
+
+ "github.com/pocket-id/pocket-id/backend/internal/utils"
+ "github.com/pocket-id/pocket-id/backend/resources"
+)
+
+func newDatabaseForTest(t *testing.T) *gorm.DB {
+ t.Helper()
+
+ // Get a name for this in-memory database that is specific to the test
+ dbName := utils.CreateSha256Hash(t.Name())
+
+ // Connect to a new in-memory SQL database
+ db, err := gorm.Open(
+ sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
+ &gorm.Config{
+ TranslateError: true,
+ Logger: logger.New(
+ testLoggerAdapter{t: t},
+ logger.Config{
+ SlowThreshold: 200 * time.Millisecond,
+ LogLevel: logger.Info,
+ IgnoreRecordNotFoundError: false,
+ ParameterizedQueries: false,
+ Colorful: false,
+ },
+ ),
+ })
+ require.NoError(t, err, "Failed to connect to test database")
+
+ // Perform migrations with the embedded migrations
+ sqlDB, err := db.DB()
+ require.NoError(t, err, "Failed to get sql.DB")
+ driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
+ require.NoError(t, err, "Failed to create migration driver")
+ source, err := iofs.New(resources.FS, "migrations/sqlite")
+ require.NoError(t, err, "Failed to create embedded migration source")
+ m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
+ require.NoError(t, err, "Failed to create migration instance")
+ err = m.Up()
+ require.NoError(t, err, "Failed to perform migrations")
+
+ return db
+}
+
+// Implements gorm's logger.Writer interface
+type testLoggerAdapter struct {
+ t *testing.T
+}
+
+func (l testLoggerAdapter) Printf(format string, args ...any) {
+ l.t.Logf(format, args...)
+}
+
+// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
+type MockRoundTripper struct {
+ Err error
+ Responses map[string]*http.Response
+}
+
+// RoundTrip implements the http.RoundTripper interface
+func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Check if we have a specific response for this URL
+ for url, resp := range m.Responses {
+ if req.URL.String() == url {
+ return resp, nil
+ }
+ }
+
+ return NewMockResponse(http.StatusNotFound, ""), nil
+}
+
+// NewMockResponse creates an http.Response with the given status code and body
+func NewMockResponse(statusCode int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: statusCode,
+ Body: io.NopCloser(strings.NewReader(body)),
+ Header: make(http.Header),
+ }
+}
diff --git a/backend/internal/utils/jwk_util.go b/backend/internal/utils/jwk_util.go
new file mode 100644
index 00000000..2571b514
--- /dev/null
+++ b/backend/internal/utils/jwk_util.go
@@ -0,0 +1,69 @@
+package utils
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "fmt"
+ "io"
+
+ "github.com/lestrrat-go/jwx/v3/jwa"
+ "github.com/lestrrat-go/jwx/v3/jwk"
+)
+
+const (
+ // KeyUsageSigning is the usage for the private keys, for the "use" property
+ KeyUsageSigning = "sig"
+)
+
+// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
+// It also populates additional fields such as the key ID, usage, and alg.
+func ImportRawKey(rawKey any) (jwk.Key, error) {
+ key, err := jwk.Import(rawKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to import generated private key: %w", err)
+ }
+
+ // Generate the key ID
+ kid, err := generateRandomKeyID()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate key ID: %w", err)
+ }
+ _ = key.Set(jwk.KeyIDKey, kid)
+
+ // Set other required fields
+ _ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
+ EnsureAlgInKey(key)
+
+ return key, nil
+}
+
+// generateRandomKeyID generates a random key ID.
+func generateRandomKeyID() (string, error) {
+ buf := make([]byte, 8)
+ _, err := io.ReadFull(rand.Reader, buf)
+ if err != nil {
+ return "", fmt.Errorf("failed to read random bytes: %w", err)
+ }
+ return base64.RawURLEncoding.EncodeToString(buf), nil
+}
+
+// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
+func EnsureAlgInKey(key jwk.Key) {
+ _, ok := key.Algorithm()
+ if ok {
+ // Algorithm is already set
+ return
+ }
+
+ switch key.KeyType() {
+ case jwa.RSA():
+ // Default to RS256 for RSA keys
+ _ = key.Set(jwk.AlgorithmKey, jwa.RS256())
+ case jwa.EC():
+ // Default to ES256 for ECDSA keys
+ _ = key.Set(jwk.AlgorithmKey, jwa.ES256())
+ case jwa.OKP():
+ // Default to EdDSA for OKP keys
+ _ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
+ }
+}
diff --git a/backend/resources/migrations/postgres/20250524202611_client_credentials.down.sql b/backend/resources/migrations/postgres/20250524202611_client_credentials.down.sql
new file mode 100644
index 00000000..41aab182
--- /dev/null
+++ b/backend/resources/migrations/postgres/20250524202611_client_credentials.down.sql
@@ -0,0 +1 @@
+ALTER TABLE oidc_clients DROP COLUMN credentials;
diff --git a/backend/resources/migrations/postgres/20250524202611_client_credentials.up.sql b/backend/resources/migrations/postgres/20250524202611_client_credentials.up.sql
new file mode 100644
index 00000000..3aa11798
--- /dev/null
+++ b/backend/resources/migrations/postgres/20250524202611_client_credentials.up.sql
@@ -0,0 +1 @@
+ALTER TABLE oidc_clients ADD COLUMN credentials JSONB NULL;
diff --git a/backend/resources/migrations/sqlite/20250524202611_client_credentials.down.sql b/backend/resources/migrations/sqlite/20250524202611_client_credentials.down.sql
new file mode 100644
index 00000000..41aab182
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20250524202611_client_credentials.down.sql
@@ -0,0 +1 @@
+ALTER TABLE oidc_clients DROP COLUMN credentials;
diff --git a/backend/resources/migrations/sqlite/20250524202611_client_credentials.up.sql b/backend/resources/migrations/sqlite/20250524202611_client_credentials.up.sql
new file mode 100644
index 00000000..26089724
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20250524202611_client_credentials.up.sql
@@ -0,0 +1 @@
+ALTER TABLE oidc_clients ADD COLUMN credentials TEXT NULL;
diff --git a/frontend/messages/en.json b/frontend/messages/en.json
index d798c524..627b058e 100644
--- a/frontend/messages/en.json
+++ b/frontend/messages/en.json
@@ -348,6 +348,12 @@
"the_device_has_been_authorized": "The device has been authorized.",
"enter_code_displayed_in_previous_step": "Enter the code that was displayed in the previous step.",
"authorize": "Authorize",
+ "federated_identities": "Federated Identities",
+ "federated_identities_description": "Using federated identities, you can authenticate OIDC clients using JWT tokens issued by third-party authorities.",
+ "add_federated_identity": "Add Federated Identity",
+ "add_another_federated_identity": "Add another federated identity",
"oidc_allowed_group_count": "Allowed Group Count",
- "unrestricted": "Unrestricted"
+ "unrestricted": "Unrestricted",
+ "show_advanced_options": "Show Advanced Options",
+ "hide_advanced_options": "Hide Advanced Options"
}
diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts
index dfffa89c..fdaebd65 100644
--- a/frontend/src/lib/types/oidc.type.ts
+++ b/frontend/src/lib/types/oidc.type.ts
@@ -6,11 +6,23 @@ export type OidcClientMetaData = {
hasLogo: boolean;
};
+export type OidcClientFederatedIdentity = {
+ issuer: string;
+ subject?: string;
+ audience?: string;
+ jwks?: string;
+};
+
+export type OidcClientCredentials = {
+ federatedIdentities: OidcClientFederatedIdentity[];
+};
+
export type OidcClient = OidcClientMetaData & {
callbackURLs: string[]; // No longer requires at least one URL
logoutCallbackURLs: string[];
isPublic: boolean;
pkceEnabled: boolean;
+ credentials?: OidcClientCredentials;
};
export type OidcClientWithAllowedUserGroups = OidcClient & {
diff --git a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte
index 1b451a50..e55be368 100644
--- a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte
+++ b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte
@@ -8,6 +8,7 @@
import * as Card from '$lib/components/ui/card';
import Label from '$lib/components/ui/label/label.svelte';
import UserGroupSelection from '$lib/components/user-group-selection.svelte';
+ import { m } from '$lib/paraglide/messages';
import OidcService from '$lib/services/oidc-service';
import clientSecretStore from '$lib/stores/client-secret-store';
import type { OidcClientCreateWithLogo } from '$lib/types/oidc.type';
@@ -16,7 +17,6 @@
import { toast } from 'svelte-sonner';
import { slide } from 'svelte/transition';
import OidcForm from '../oidc-client-form.svelte';
- import { m } from '$lib/paraglide/messages';
let { data } = $props();
let client = $state({
@@ -166,7 +166,7 @@
{error}
+ {/if} + + +