diff --git a/.gitignore b/.gitignore index efb0a768..5978887f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ node_modules /frontend/build /backend/bin pocket-id +/tests/test-results/*.json # OS .DS_Store diff --git a/backend/go.mod b/backend/go.mod index 4dd4ef0f..0adeaa41 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -20,7 +20,8 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/go-uuid v1.0.3 github.com/joho/godotenv v1.5.1 - github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 + github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 + github.com/lestrrat-go/jwx/v3 v3.0.1 github.com/mileusna/useragent v1.3.5 github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2 github.com/stretchr/testify v1.10.0 @@ -32,7 +33,7 @@ require ( go.opentelemetry.io/otel/sdk v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0 go.opentelemetry.io/otel/trace v1.35.0 - golang.org/x/crypto v0.36.0 + golang.org/x/crypto v0.37.0 golang.org/x/image v0.24.0 golang.org/x/time v0.9.0 gorm.io/driver/postgres v1.5.11 @@ -77,9 +78,8 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/blackmagic v1.0.3 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect - github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -123,7 +123,7 @@ require ( golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.24.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect google.golang.org/grpc v1.71.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 7042fe85..7e3a5370 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -164,14 +164,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= -github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs= +github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= -github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA= -github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms= -github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I= -github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc= +github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY= +github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE= +github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY= +github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -309,8 +309,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -377,8 +377,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/backend/internal/bootstrap/e2etest_router_bootstrap.go b/backend/internal/bootstrap/e2etest_router_bootstrap.go index bda3360a..e16bdc9f 100644 --- a/backend/internal/bootstrap/e2etest_router_bootstrap.go +++ b/backend/internal/bootstrap/e2etest_router_bootstrap.go @@ -3,6 +3,8 @@ package bootstrap import ( + "log" + "github.com/gin-gonic/gin" "gorm.io/gorm" @@ -14,7 +16,12 @@ import ( func init() { registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){ func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) { - testService := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService) + testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService) + if err != nil { + log.Fatalf("failed to initialize test service: %v", err) + return + } + controller.NewTestController(apiGroup, testService) }, } diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go index 1ba817cd..892f47d0 100644 --- a/backend/internal/bootstrap/services_bootstrap.go +++ b/backend/internal/bootstrap/services_bootstrap.go @@ -26,15 +26,14 @@ type services struct { } // Initializes all services -// The context should be used by services only for initialization, and not for running -func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) { +func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) { svc = &services{} - svc.appConfigService = service.NewAppConfigService(initCtx, db) + svc.appConfigService = service.NewAppConfigService(ctx, db) svc.emailService, err = service.NewEmailService(db, svc.appConfigService) if err != nil { - return nil, fmt.Errorf("unable to create email service: %w", err) + return nil, fmt.Errorf("failed to create email service: %w", err) } svc.geoLiteService = service.NewGeoLiteService(httpClient) @@ -42,7 +41,12 @@ func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) svc.jwtService = service.NewJwtService(svc.appConfigService) svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService) svc.customClaimService = service.NewCustomClaimService(db) - svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService) + + svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService) + if err != nil { + return nil, fmt.Errorf("failed to create OIDC service: %w", err) + } + svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService) svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService) svc.apiKeyService = service.NewApiKeyService(db, svc.emailService) diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index d385d11c..86f90e93 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -65,6 +65,11 @@ type OidcClientSecretInvalidError struct{} func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" } func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 } +type OidcClientAssertionInvalidError struct{} + +func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" } +func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 } + type OidcInvalidAuthorizationCodeError struct{} func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" } diff --git a/backend/internal/controller/e2etest_controller.go b/backend/internal/controller/e2etest_controller.go index 179285a3..d5ecc989 100644 --- a/backend/internal/controller/e2etest_controller.go +++ b/backend/internal/controller/e2etest_controller.go @@ -14,6 +14,9 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService) testController := &TestController{TestService: testService} group.POST("/test/reset", testController.resetAndSeedHandler) + + group.GET("/externalidp/jwks.json", testController.externalIdPJWKS) + group.POST("/externalidp/sign", testController.externalIdPSignToken) } type TestController struct { @@ -21,6 +24,15 @@ type TestController struct { } func (tc *TestController) resetAndSeedHandler(c *gin.Context) { + var baseURL string + if c.Request.TLS != nil { + baseURL = "https://" + c.Request.Host + } else { + baseURL = "http://" + c.Request.Host + } + + skipLdap := c.Query("skip-ldap") == "true" + if err := tc.TestService.ResetDatabase(); err != nil { _ = c.Error(err) return @@ -31,7 +43,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) { return } - if err := tc.TestService.SeedDatabase(); err != nil { + if err := tc.TestService.SeedDatabase(baseURL); err != nil { _ = c.Error(err) return } @@ -41,17 +53,50 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) { return } - if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil { - _ = c.Error(err) - return - } + if !skipLdap { + if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil { + _ = c.Error(err) + return + } - if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil { - _ = c.Error(err) - return + if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil { + _ = c.Error(err) + return + } } tc.TestService.SetJWTKeys() c.Status(http.StatusNoContent) } + +func (tc *TestController) externalIdPJWKS(c *gin.Context) { + jwks, err := tc.TestService.GetExternalIdPJWKS() + if err != nil { + _ = c.Error(err) + return + } + + c.JSON(http.StatusOK, jwks) +} + +func (tc *TestController) externalIdPSignToken(c *gin.Context) { + var input struct { + Aud string `json:"aud"` + Iss string `json:"iss"` + Sub string `json:"sub"` + } + err := c.ShouldBindJSON(&input) + if err != nil { + _ = c.Error(err) + return + } + + token, err := tc.TestService.SignExternalIdPToken(input.Iss, input.Sub, input.Aud) + if err != nil { + _ = c.Error(err) + return + } + + c.Writer.WriteString(token) +} diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index d6cb3fa2..68d7300d 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -7,14 +7,14 @@ import ( "net/url" "strings" - "github.com/pocket-id/pocket-id/backend/internal/common" - "github.com/pocket-id/pocket-id/backend/internal/utils/cookie" - "github.com/gin-gonic/gin" + + "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/middleware" "github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/utils" + "github.com/pocket-id/pocket-id/backend/internal/utils/cookie" ) // NewOidcController creates a new controller for OIDC related endpoints @@ -124,11 +124,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex // @Tags OIDC // @Produce json // @Param client_id formData string false "Client ID (if not using Basic Auth)" -// @Param client_secret formData string false "Client secret (if not using Basic Auth)" +// @Param client_secret formData string false "Client secret (if not using Basic Auth or client assertions)" // @Param code formData string false "Authorization code (required for 'authorization_code' grant)" // @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')" // @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)" // @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)" +// @Param client_assertion formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)" +// @Param client_assertion_type formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)" // @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token" // @Router /api/oidc/token [post] func (oc *OidcController) createTokensHandler(c *gin.Context) { @@ -363,12 +365,12 @@ func (oc *OidcController) getClientHandler(c *gin.Context) { clientDto := dto.OidcClientWithAllowedUserGroupsDto{} err = dto.MapStruct(client, &clientDto) - if err == nil { - c.JSON(http.StatusOK, clientDto) + if err != nil { + _ = c.Error(err) return } - _ = c.Error(err) + c.JSON(http.StatusOK, clientDto) } // listClientsHandler godoc diff --git a/backend/internal/dto/dto_mapper.go b/backend/internal/dto/dto_mapper.go index 8c027d0b..f727dfc9 100644 --- a/backend/internal/dto/dto_mapper.go +++ b/backend/internal/dto/dto_mapper.go @@ -62,7 +62,60 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error { return nil } +//nolint:gocognit func mapField(sourceField reflect.Value, destField reflect.Value) error { + // Handle pointer to struct in source + if sourceField.Kind() == reflect.Ptr && !sourceField.IsNil() { + switch { + case sourceField.Elem().Kind() == reflect.Struct: + switch { + case destField.Kind() == reflect.Struct: + // Map from pointer to struct -> struct + return mapStructInternal(sourceField.Elem(), destField) + case destField.Kind() == reflect.Ptr && destField.CanSet(): + // Map from pointer to struct -> pointer to struct + if destField.IsNil() { + destField.Set(reflect.New(destField.Type().Elem())) + } + return mapStructInternal(sourceField.Elem(), destField.Elem()) + } + case destField.Kind() == reflect.Ptr && + destField.CanSet() && + sourceField.Elem().Type().AssignableTo(destField.Type().Elem()): + // Handle primitive pointer types (e.g., *string to *string) + if destField.IsNil() { + destField.Set(reflect.New(destField.Type().Elem())) + } + destField.Elem().Set(sourceField.Elem()) + return nil + case destField.Kind() != reflect.Ptr && + destField.CanSet() && + sourceField.Elem().Type().AssignableTo(destField.Type()): + // Handle *T to T conversion for primitive types + destField.Set(sourceField.Elem()) + return nil + } + } + + // Handle pointer to struct in destination + if destField.Kind() == reflect.Ptr && destField.CanSet() { + switch { + case sourceField.Kind() == reflect.Struct: + // Map from struct -> pointer to struct + if destField.IsNil() { + destField.Set(reflect.New(destField.Type().Elem())) + } + return mapStructInternal(sourceField, destField.Elem()) + case !sourceField.IsZero() && sourceField.Type().AssignableTo(destField.Type().Elem()): + // Handle T to *T conversion for primitive types + if destField.IsNil() { + destField.Set(reflect.New(destField.Type().Elem())) + } + destField.Elem().Set(sourceField) + return nil + } + } + switch { case sourceField.Type() == destField.Type(): destField.Set(sourceField) diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index 9c0b62c7..df317787 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -8,10 +8,11 @@ type OidcClientMetaDataDto struct { type OidcClientDto struct { OidcClientMetaDataDto - CallbackURLs []string `json:"callbackURLs"` - LogoutCallbackURLs []string `json:"logoutCallbackURLs"` - IsPublic bool `json:"isPublic"` - PkceEnabled bool `json:"pkceEnabled"` + CallbackURLs []string `json:"callbackURLs"` + LogoutCallbackURLs []string `json:"logoutCallbackURLs"` + IsPublic bool `json:"isPublic"` + PkceEnabled bool `json:"pkceEnabled"` + Credentials OidcClientCredentialsDto `json:"credentials"` } type OidcClientWithAllowedUserGroupsDto struct { @@ -25,11 +26,23 @@ type OidcClientWithAllowedGroupsCountDto struct { } type OidcClientCreateDto struct { - Name string `json:"name" binding:"required,max=50"` - CallbackURLs []string `json:"callbackURLs"` - LogoutCallbackURLs []string `json:"logoutCallbackURLs"` - IsPublic bool `json:"isPublic"` - PkceEnabled bool `json:"pkceEnabled"` + Name string `json:"name" binding:"required,max=50"` + CallbackURLs []string `json:"callbackURLs"` + LogoutCallbackURLs []string `json:"logoutCallbackURLs"` + IsPublic bool `json:"isPublic"` + PkceEnabled bool `json:"pkceEnabled"` + Credentials OidcClientCredentialsDto `json:"credentials"` +} + +type OidcClientCredentialsDto struct { + FederatedIdentities []OidcClientFederatedIdentityDto `json:"federatedIdentities,omitempty"` +} + +type OidcClientFederatedIdentityDto struct { + Issuer string `json:"issuer"` + Subject string `json:"subject,omitempty"` + Audience string `json:"audience,omitempty"` + JWKS string `json:"jwks,omitempty"` } type AuthorizeOidcClientRequestDto struct { @@ -52,13 +65,15 @@ type AuthorizationRequiredDto struct { } type OidcCreateTokensDto struct { - GrantType string `form:"grant_type" binding:"required"` - Code string `form:"code"` - DeviceCode string `form:"device_code"` - ClientID string `form:"client_id"` - ClientSecret string `form:"client_secret"` - CodeVerifier string `form:"code_verifier"` - RefreshToken string `form:"refresh_token"` + GrantType string `form:"grant_type" binding:"required"` + Code string `form:"code"` + DeviceCode string `form:"device_code"` + ClientID string `form:"client_id"` + ClientSecret string `form:"client_secret"` + CodeVerifier string `form:"code_verifier"` + RefreshToken string `form:"refresh_token"` + ClientAssertion string `form:"client_assertion"` + ClientAssertionType string `form:"client_assertion_type"` } type OidcIntrospectDto struct { diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 7958215c..490015bf 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -5,8 +5,9 @@ import ( "encoding/json" "fmt" - datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" "gorm.io/gorm" + + datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" ) type UserAuthorizedOidcClient struct { @@ -45,6 +46,7 @@ type OidcClient struct { HasLogo bool `gorm:"-"` IsPublic bool PkceEnabled bool + Credentials OidcClientCredentials AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"` CreatedByID string @@ -71,9 +73,49 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { return nil } +type OidcClientCredentials struct { //nolint:recvcheck + FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"` +} + +type OidcClientFederatedIdentity struct { + Issuer string `json:"issuer"` + Subject string `json:"subject,omitempty"` + Audience string `json:"audience,omitempty"` + JWKS string `json:"jwks,omitempty"` // URL of the JWKS +} + +func (occ OidcClientCredentials) FederatedIdentityForIssuer(issuer string) (OidcClientFederatedIdentity, bool) { + if issuer == "" { + return OidcClientFederatedIdentity{}, false + } + + for _, fi := range occ.FederatedIdentities { + if fi.Issuer == issuer { + return fi, true + } + } + + return OidcClientFederatedIdentity{}, false +} + +func (occ *OidcClientCredentials) Scan(value any) error { + switch v := value.(type) { + case []byte: + return json.Unmarshal(v, occ) + case string: + return json.Unmarshal([]byte(v), occ) + default: + return fmt.Errorf("unsupported type: %T", value) + } +} + +func (occ OidcClientCredentials) Value() (driver.Value, error) { + return json.Marshal(occ) +} + type UrlList []string //nolint:recvcheck -func (cu *UrlList) Scan(value interface{}) error { +func (cu *UrlList) Scan(value any) error { switch v := value.(type) { case []byte: return json.Unmarshal(v, cu) diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index 0abd4558..5213d059 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -29,17 +29,17 @@ type AppConfigService struct { db *gorm.DB } -func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService { +func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService { service := &AppConfigService{ db: db, } - err := service.LoadDbConfig(initCtx) + err := service.LoadDbConfig(ctx) if err != nil { log.Fatalf("Failed to initialize app config service: %v", err) } - err = service.initInstanceID(initCtx) + err = service.initInstanceID(ctx) if err != nil { log.Fatalf("Failed to initialize instance ID: %v", err) } diff --git a/backend/internal/service/app_config_service_test.go b/backend/internal/service/app_config_service_test.go index c986162b..5b538919 100644 --- a/backend/internal/service/app_config_service_test.go +++ b/backend/internal/service/app_config_service_test.go @@ -3,16 +3,10 @@ package service import ( "sync/atomic" "testing" - "time" - - "github.com/glebarez/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" - "github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/stretchr/testify/require" ) @@ -28,7 +22,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService { func TestLoadDbConfig(t *testing.T) { t.Run("empty config table", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) service := &AppConfigService{ db: db, } @@ -42,7 +36,7 @@ func TestLoadDbConfig(t *testing.T) { }) t.Run("loads value from config table", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Populate the config table with some initial values err := db. @@ -72,7 +66,7 @@ func TestLoadDbConfig(t *testing.T) { }) t.Run("ignores unknown config keys", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Add an entry with a key that doesn't exist in the config struct err := db.Create([]model.AppConfigVariable{ @@ -93,7 +87,7 @@ func TestLoadDbConfig(t *testing.T) { }) t.Run("loading config multiple times", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Initial state err := db.Create([]model.AppConfigVariable{ @@ -135,7 +129,7 @@ func TestLoadDbConfig(t *testing.T) { common.EnvConfig.UiConfigDisabled = true // Create database with config that should be ignored - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) err := db.Create([]model.AppConfigVariable{ {Key: "appName", Value: "DB App"}, {Key: "sessionDuration", Value: "120"}, @@ -171,7 +165,7 @@ func TestLoadDbConfig(t *testing.T) { common.EnvConfig.UiConfigDisabled = false // Create database with config values that should take precedence - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) err := db.Create([]model.AppConfigVariable{ {Key: "appName", Value: "DB App"}, {Key: "sessionDuration", Value: "120"}, @@ -195,7 +189,7 @@ func TestLoadDbConfig(t *testing.T) { func TestUpdateAppConfigValues(t *testing.T) { t.Run("update single value", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -220,7 +214,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("update multiple values", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -264,7 +258,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("empty value resets to default", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -285,7 +279,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("error with odd number of arguments", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -301,7 +295,7 @@ func TestUpdateAppConfigValues(t *testing.T) { }) t.Run("error with invalid key", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -319,7 +313,7 @@ func TestUpdateAppConfigValues(t *testing.T) { func TestUpdateAppConfig(t *testing.T) { t.Run("updates configuration values from DTO", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config service := &AppConfigService{ @@ -392,7 +386,7 @@ func TestUpdateAppConfig(t *testing.T) { }) t.Run("empty values reset to defaults", func(t *testing.T) { - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) // Create a service with default config and modify some values service := &AppConfigService{ @@ -457,7 +451,7 @@ func TestUpdateAppConfig(t *testing.T) { // Disable UI config common.EnvConfig.UiConfigDisabled = true - db := newAppConfigTestDatabaseForTest(t) + db := newDatabaseForTest(t) service := &AppConfigService{ db: db, } @@ -475,49 +469,3 @@ func TestUpdateAppConfig(t *testing.T) { require.ErrorAs(t, err, &uiConfigDisabledErr) }) } - -// Implements gorm's logger.Writer interface -type testLoggerAdapter struct { - t *testing.T -} - -func (l testLoggerAdapter) Printf(format string, args ...any) { - l.t.Logf(format, args...) -} - -func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB { - t.Helper() - - // Get a name for this in-memory database that is specific to the test - dbName := utils.CreateSha256Hash(t.Name()) - - // Connect to a new in-memory SQL database - db, err := gorm.Open( - sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"), - &gorm.Config{ - TranslateError: true, - Logger: logger.New( - testLoggerAdapter{t: t}, - logger.Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: logger.Info, - IgnoreRecordNotFoundError: false, - ParameterizedQueries: false, - Colorful: false, - }, - ), - }) - require.NoError(t, err, "Failed to connect to test database") - - // Create the app_config_variables table - err = db.Exec(` -CREATE TABLE app_config_variables -( - key VARCHAR(100) NOT NULL PRIMARY KEY, - value TEXT NOT NULL -) -`).Error - require.NoError(t, err, "Failed to create test config table") - - return db -} diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index dcee2aac..96234a99 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -5,6 +5,8 @@ package service import ( "context" "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" "encoding/base64" "fmt" @@ -16,6 +18,7 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/go-webauthn/webauthn/protocol" "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jwt" "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/common" @@ -30,14 +33,43 @@ type TestService struct { jwtService *JwtService appConfigService *AppConfigService ldapService *LdapService + externalIdPKey jwk.Key } -func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) *TestService { - return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService, ldapService: ldapService} +func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) { + s := &TestService{ + db: db, + appConfigService: appConfigService, + jwtService: jwtService, + ldapService: ldapService, + } + err := s.initExternalIdP() + if err != nil { + return nil, fmt.Errorf("failed to initialize external IdP: %w", err) + } + return s, nil +} + +// Initializes the "external IdP" +// This creates a new "issuing authority" containing a public JWKS +// It also stores the private key internally that will be used to issue JWTs +func (s *TestService) initExternalIdP() error { + // Generate a new ECDSA key + rawKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("failed to generate private key: %w", err) + } + + s.externalIdPKey, err = utils.ImportRawKey(rawKey) + if err != nil { + return fmt.Errorf("failed to import private key: %w", err) + } + + return nil } //nolint:gocognit -func (s *TestService) SeedDatabase() error { +func (s *TestService) SeedDatabase(baseURL string) error { err := s.db.Transaction(func(tx *gorm.DB) error { users := []model.User{ { @@ -138,6 +170,26 @@ func (s *TestService) SeedDatabase() error { userGroups[1], }, }, + { + Base: model.Base{ + ID: "c48232ff-ff65-45ed-ae96-7afa8a9b443b", + }, + Name: "Federated", + Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x + CallbackURLs: model.UrlList{"http://federated/auth/callback"}, + CreatedByID: users[1].ID, + AllowedUserGroups: []model.UserGroup{}, + Credentials: model.OidcClientCredentials{ + FederatedIdentities: []model.OidcClientFederatedIdentity{ + { + Issuer: "https://external-idp.local", + Audience: "api://PocketID", + Subject: "c48232ff-ff65-45ed-ae96-7afa8a9b443b", + JWKS: baseURL + "/api/externalidp/jwks.json", + }, + }, + }, + }, } for _, client := range oidcClients { if err := tx.Create(&client).Error; err != nil { @@ -145,16 +197,28 @@ func (s *TestService) SeedDatabase() error { } } - authCode := model.OidcAuthorizationCode{ - Code: "auth-code", - Scope: "openid profile", - Nonce: "nonce", - ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), - UserID: users[0].ID, - ClientID: oidcClients[0].ID, + authCodes := []model.OidcAuthorizationCode{ + { + Code: "auth-code", + Scope: "openid profile", + Nonce: "nonce", + ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), + UserID: users[0].ID, + ClientID: oidcClients[0].ID, + }, + { + Code: "federated", + Scope: "openid profile", + Nonce: "nonce", + ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), + UserID: users[1].ID, + ClientID: oidcClients[2].ID, + }, } - if err := tx.Create(&authCode).Error; err != nil { - return err + for _, authCode := range authCodes { + if err := tx.Create(&authCode).Error; err != nil { + return err + } } refreshToken := model.OidcRefreshToken{ @@ -177,13 +241,22 @@ func (s *TestService) SeedDatabase() error { return err } - userAuthorizedClient := model.UserAuthorizedOidcClient{ - Scope: "openid profile email", - UserID: users[0].ID, - ClientID: oidcClients[0].ID, + userAuthorizedClients := []model.UserAuthorizedOidcClient{ + { + Scope: "openid profile email", + UserID: users[0].ID, + ClientID: oidcClients[0].ID, + }, + { + Scope: "openid profile email", + UserID: users[1].ID, + ClientID: oidcClients[2].ID, + }, } - if err := tx.Create(&userAuthorizedClient).Error; err != nil { - return err + for _, userAuthorizedClient := range userAuthorizedClients { + if err := tx.Create(&userAuthorizedClient).Error; err != nil { + return err + } } // To generate a new key pair, run the following command: @@ -405,3 +478,41 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error { return nil } + +// GetExternalIdPJWKS returns the JWKS for the "external IdP". +func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) { + pubKey, err := s.externalIdPKey.PublicKey() + if err != nil { + return nil, fmt.Errorf("failed to get public key: %w", err) + } + + set := jwk.NewSet() + err = set.AddKey(pubKey) + if err != nil { + return nil, fmt.Errorf("failed to add public key to set: %w", err) + } + + return set, nil +} + +func (s *TestService) SignExternalIdPToken(iss, sub, aud string) (string, error) { + now := time.Now() + token, err := jwt.NewBuilder(). + Subject(sub). + Expiration(now.Add(time.Hour)). + IssuedAt(now). + Issuer(iss). + Audience([]string{aud}). + Build() + if err != nil { + return "", fmt.Errorf("failed to build token: %w", err) + } + + alg, _ := s.externalIdPKey.Algorithm() + signed, err := jwt.Sign(token, jwt.WithKey(alg, s.externalIdPKey)) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + return string(signed), nil +} diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index 35dfd848..389cd608 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -4,11 +4,9 @@ import ( "context" "crypto/rand" "crypto/rsa" - "encoding/base64" "encoding/json" "errors" "fmt" - "io" "log" "os" "path/filepath" @@ -372,7 +370,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) { return nil, fmt.Errorf("failed to get public key: %w", err) } - EnsureAlgInKey(pubKey) + utils.EnsureAlgInKey(pubKey) return pubKey, nil } @@ -415,27 +413,6 @@ func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) { return key, nil } -// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type -func EnsureAlgInKey(key jwk.Key) { - _, ok := key.Algorithm() - if ok { - // Algorithm is already set - return - } - - switch key.KeyType() { - case jwa.RSA(): - // Default to RS256 for RSA keys - _ = key.Set(jwk.AlgorithmKey, jwa.RS256()) - case jwa.EC(): - // Default to ES256 for ECDSA keys - _ = key.Set(jwk.AlgorithmKey, jwa.ES256()) - case jwa.OKP(): - // Default to EdDSA for OKP keys - _ = key.Set(jwk.AlgorithmKey, jwa.EdDSA()) - } -} - func (s *JwtService) generateNewRSAKey() (jwk.Key, error) { // We generate RSA keys only rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize) @@ -444,27 +421,7 @@ func (s *JwtService) generateNewRSAKey() (jwk.Key, error) { } // Import the raw key - return importRawKey(rawKey) -} - -func importRawKey(rawKey any) (jwk.Key, error) { - key, err := jwk.Import(rawKey) - if err != nil { - return nil, fmt.Errorf("failed to import generated private key: %w", err) - } - - // Generate the key ID - kid, err := generateRandomKeyID() - if err != nil { - return nil, fmt.Errorf("failed to generate key ID: %w", err) - } - _ = key.Set(jwk.KeyIDKey, kid) - - // Set other required fields - _ = key.Set(jwk.KeyUsageKey, KeyUsageSigning) - EnsureAlgInKey(key) - - return key, err + return utils.ImportRawKey(rawKey) } // SaveKeyJWK saves a JWK to a file @@ -492,16 +449,6 @@ func SaveKeyJWK(key jwk.Key, path string) error { return nil } -// generateRandomKeyID generates a random key ID. -func generateRandomKeyID() (string, error) { - buf := make([]byte, 8) - _, err := io.ReadFull(rand.Reader, buf) - if err != nil { - return "", fmt.Errorf("failed to read random bytes: %w", err) - } - return base64.RawURLEncoding.EncodeToString(buf), nil -} - // GetIsAdmin returns the value of the "isAdmin" claim in the token func GetIsAdmin(token jwt.Token) (bool, error) { if !token.Has(IsAdminClaim) { diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index e4f7babf..43c7520c 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -21,6 +21,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" + "github.com/pocket-id/pocket-id/backend/internal/utils" ) func TestJwtService_Init(t *testing.T) { @@ -1218,7 +1219,7 @@ func TestTokenTypeValidator(t *testing.T) { func importKey(t *testing.T, privateKeyRaw any, path string) string { t.Helper() - privateKey, err := importRawKey(privateKeyRaw) + privateKey, err := utils.ImportRawKey(privateKeyRaw) require.NoError(t, err, "Failed to import private key") err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile)) diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index ba7070c0..0a355917 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -3,18 +3,25 @@ package service import ( "context" "crypto/sha256" + "crypto/tls" "encoding/base64" "encoding/json" "errors" "fmt" "log" + "log/slog" "mime/multipart" + "net/http" "os" "regexp" "slices" "strings" "time" + "github.com/lestrrat-go/httprc/v3" + "github.com/lestrrat-go/httprc/v3/errsink" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jws" "github.com/lestrrat-go/jwx/v3/jwt" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" @@ -31,6 +38,8 @@ const ( GrantTypeAuthorizationCode = "authorization_code" GrantTypeRefreshToken = "refresh_token" GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" + + ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec ) type OidcService struct { @@ -39,16 +48,61 @@ type OidcService struct { appConfigService *AppConfigService auditLogService *AuditLogService customClaimService *CustomClaimService + + httpClient *http.Client + jwkCache *jwk.Cache } -func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService { - return &OidcService{ +func NewOidcService( + ctx context.Context, + db *gorm.DB, + jwtService *JwtService, + appConfigService *AppConfigService, + auditLogService *AuditLogService, + customClaimService *CustomClaimService, +) (s *OidcService, err error) { + s = &OidcService{ db: db, jwtService: jwtService, appConfigService: appConfigService, auditLogService: auditLogService, customClaimService: customClaimService, } + + // Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace + s.jwkCache, err = s.getJWKCache(ctx) + if err != nil { + return nil, err + } + + return s, nil +} + +func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) { + // We need to create a custom HTTP client to set a timeout. + client := s.httpClient + if client == nil { + client = &http.Client{ + Timeout: 20 * time.Second, + } + + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + // Indicates a development-time error + panic("Default transport is not of type *http.Transport") + } + transport := defaultTransport.Clone() + transport.TLSClientConfig.MinVersion = tls.VersionTLS12 + client.Transport = transport + } + + // Create the JWKS cache + return jwk.NewCache(ctx, + httprc.NewClient( + httprc.WithErrorSink(errsink.NewSlog(slog.Default())), + httprc.WithHTTPClient(client), + ), + ) } func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) { @@ -198,7 +252,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O tx.Rollback() }() - _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx) + _, err := s.verifyClientCredentialsInternal(ctx, tx, input) if err != nil { return CreatedTokens{}, err } @@ -279,7 +333,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu tx.Rollback() }() - client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx) + client, err := s.verifyClientCredentialsInternal(ctx, tx, input) if err != nil { return CreatedTokens{}, err } @@ -357,7 +411,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto tx.Rollback() }() - _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx) + _, err := s.verifyClientCredentialsInternal(ctx, tx, input) if err != nil { return CreatedTokens{}, err } @@ -420,7 +474,10 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre return introspectDto, &common.OidcMissingClientCredentialsError{} } - _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db) + _, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{ + ClientID: clientID, + ClientSecret: clientSecret, + }) if err != nil { return introspectDto, err } @@ -440,33 +497,35 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre introspectDto.Active = true introspectDto.TokenType = "access_token" if token.Has("scope") { - var asString string - var asStrings []string + var ( + asString string + asStrings []string + ) if err := token.Get("scope", &asString); err == nil { introspectDto.Scope = asString } else if err := token.Get("scope", &asStrings); err == nil { introspectDto.Scope = strings.Join(asStrings, " ") } } - if expiration, hasExpiration := token.Expiration(); hasExpiration { + if expiration, ok := token.Expiration(); ok { introspectDto.Expiration = expiration.Unix() } - if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt { + if issuedAt, ok := token.IssuedAt(); ok { introspectDto.IssuedAt = issuedAt.Unix() } - if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore { + if notBefore, ok := token.NotBefore(); ok { introspectDto.NotBefore = notBefore.Unix() } - if subject, hasSubject := token.Subject(); hasSubject { + if subject, ok := token.Subject(); ok { introspectDto.Subject = subject } - if audience, hasAudience := token.Audience(); hasAudience { + if audience, ok := token.Audience(); ok { introspectDto.Audience = audience } - if issuer, hasIssuer := token.Issuer(); hasIssuer { + if issuer, ok := token.Issuer(); ok { introspectDto.Issuer = issuer } - if identifier, hasIdentifier := token.JwtID(); hasIdentifier { + if identifier, ok := token.JwtID(); ok { introspectDto.Identifier = identifier } @@ -542,13 +601,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { client := model.OidcClient{ - Name: input.Name, - CallbackURLs: input.CallbackURLs, - LogoutCallbackURLs: input.LogoutCallbackURLs, - CreatedByID: userID, - IsPublic: input.IsPublic, - PkceEnabled: input.PkceEnabled, + CreatedByID: userID, } + updateOIDCClientModelFromDto(&client, &input) err := s.db. WithContext(ctx). @@ -577,11 +632,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d return model.OidcClient{}, err } - client.Name = input.Name - client.CallbackURLs = input.CallbackURLs - client.LogoutCallbackURLs = input.LogoutCallbackURLs - client.IsPublic = input.IsPublic - client.PkceEnabled = input.IsPublic || input.PkceEnabled + updateOIDCClientModelFromDto(&client, &input) err = tx. WithContext(ctx). @@ -599,6 +650,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d return client, nil } +func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) { + // Base fields + client.Name = input.Name + client.CallbackURLs = input.CallbackURLs + client.LogoutCallbackURLs = input.LogoutCallbackURLs + client.IsPublic = input.IsPublic + // PKCE is required for public clients + client.PkceEnabled = input.IsPublic || input.PkceEnabled + + // Credentials + if len(input.Credentials.FederatedIdentities) > 0 { + client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities)) + for i, fi := range input.Credentials.FederatedIdentities { + client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{ + Issuer: fi.Issuer, + Audience: fi.Audience, + Subject: fi.Subject, + JWKS: fi.JWKS, + } + } + } +} + func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error { var client model.OidcClient err := s.db. @@ -1079,7 +1153,10 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model. } func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { - client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db) + client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{ + ClientID: input.ClientID, + ClientSecret: input.ClientSecret, + }) if err != nil { return nil, err } @@ -1305,33 +1382,140 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID return err } -func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) { +func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) { // First, ensure we have a valid client ID - if clientID == "" { - return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} + if input.ClientID == "" { + return nil, &common.OidcMissingClientCredentialsError{} } // Load the OIDC client's configuration var client model.OidcClient err := tx. WithContext(ctx). - First(&client, "id = ?", clientID). + First(&client, "id = ?", input.ClientID). Error if err != nil { - return model.OidcClient{}, err + return nil, err } - // If we have a client secret, we validate it - // Otherwise, we require the client to be public - if clientSecret != "" { - err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) + // We have 3 options + // If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only + switch { + // First, if we have a client secret, we validate it + case input.ClientSecret != "": + err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret)) if err != nil { - return model.OidcClient{}, &common.OidcClientSecretInvalidError{} + return nil, &common.OidcClientSecretInvalidError{} + } + return &client, nil + + // Next, check if we want to use client assertions from federated identities + case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "": + err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input) + if err != nil { + log.Printf("Invalid assertion for client '%s': %v", client.ID, err) + return nil, &common.OidcClientAssertionInvalidError{} + } + return &client, nil + + // There's no credentials + // This is allowed only if the client is public + case client.IsPublic: + return &client, nil + + // If we're here, we have no credentials AND the client is not public, so credentials are required + default: + return nil, &common.OidcMissingClientCredentialsError{} + } +} + +func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) { + // Check if we have already registered the URL + if !s.jwkCache.IsRegistered(ctx, url) { + // We set a timeout because otherwise Register will keep trying in case of errors + registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second) + defer registerCancel() + // We need to register the URL + err = s.jwkCache.Register( + registerCtx, + url, + jwk.WithMaxInterval(24*time.Hour), + jwk.WithMinInterval(15*time.Minute), + jwk.WithWaitReady(true), + ) + // In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error + if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) { + return nil, fmt.Errorf("failed to register JWK set: %w", err) } - return client, nil - } else if !client.IsPublic { - return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} } - return client, nil + jwks, err := s.jwkCache.CachedSet(url) + if err != nil { + return nil, fmt.Errorf("failed to get cached JWK set: %w", err) + } + + return jwks, nil +} + +func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error { + // First, parse the assertion JWT, without validating it, to check the issuer + assertion := []byte(input.ClientAssertion) + insecureToken, err := jwt.ParseInsecure(assertion) + if err != nil { + return fmt.Errorf("failed to parse client assertion JWT: %w", err) + } + + issuer, _ := insecureToken.Issuer() + if issuer == "" { + return errors.New("client assertion does not contain an issuer claim") + } + + // Ensure that this client is federated with the one that issued the token + ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer) + if !ok { + return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer) + } + + // Get the JWK set for the issuer + jwksURL := ocfi.JWKS + if jwksURL == "" { + // Default URL is from the issuer + if strings.HasSuffix(issuer, "/") { + jwksURL = issuer + ".well-known/jwks.json" + } else { + jwksURL = issuer + "/.well-known/jwks.json" + } + } + jwks, err := s.jwkSetForURL(ctx, jwksURL) + if err != nil { + return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err) + } + + // Set default audience and subject if missing + audience := ocfi.Audience + if audience == "" { + // Default to the Pocket ID's URL + audience = common.EnvConfig.AppURL + } + subject := ocfi.Subject + if subject == "" { + // Default to the client ID, per RFC 7523 + subject = client.ID + } + + // Now re-parse the token with proper validation + // (Note: we don't use jwt.WithIssuer() because that would be redundant) + _, err = jwt.Parse(assertion, + jwt.WithValidate(true), + jwt.WithAcceptableSkew(clockSkew), + jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)), + jwt.WithAudience(audience), + jwt.WithSubject(subject), + ) + if err != nil { + return fmt.Errorf("client assertion is not valid: %w", err) + } + + // If we're here, the assertion is valid + return nil } diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go new file mode 100644 index 00000000..3a230f49 --- /dev/null +++ b/backend/internal/service/oidc_service_test.go @@ -0,0 +1,365 @@ +package service + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/dto" +) + +// generateTestECDSAKey creates an ECDSA key for testing +func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) { + t.Helper() + + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + privateJwk, err := jwk.Import(privateKey) + require.NoError(t, err) + + err = privateJwk.Set(jwk.KeyIDKey, "test-key-1") + require.NoError(t, err) + err = privateJwk.Set(jwk.AlgorithmKey, "ES256") + require.NoError(t, err) + err = privateJwk.Set("use", "sig") + require.NoError(t, err) + + publicJwk, err := jwk.PublicKeyOf(privateJwk) + require.NoError(t, err) + + // Create a JWK Set with the public key + jwkSet := jwk.NewSet() + err = jwkSet.AddKey(publicJwk) + require.NoError(t, err) + jwkSetJSON, err := json.Marshal(jwkSet) + require.NoError(t, err) + + return privateJwk, jwkSetJSON +} + +func TestOidcService_jwkSetForURL(t *testing.T) { + // Generate a test key for JWKS + _, jwkSetJSON1 := generateTestECDSAKey(t) + _, jwkSetJSON2 := generateTestECDSAKey(t) + + // Create a mock HTTP client with responses for different URLs + const ( + url1 = "https://example.com/.well-known/jwks.json" + url2 = "https://other-issuer.com/jwks" + ) + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)), + //nolint:bodyclose + url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)), + } + httpClient := &http.Client{ + Transport: &MockRoundTripper{ + Responses: mockResponses, + }, + } + + // Create the OidcService with our mock client + s := &OidcService{ + httpClient: httpClient, + } + + var err error + s.jwkCache, err = s.getJWKCache(t.Context()) + require.NoError(t, err) + + t.Run("Fetches and caches JWK set", func(t *testing.T) { + jwks, err := s.jwkSetForURL(t.Context(), url1) + require.NoError(t, err) + require.NotNil(t, jwks) + + // Verify the JWK set contains our key + require.Equal(t, 1, jwks.Len()) + }) + + t.Run("Fails with invalid URL", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + _, err := s.jwkSetForURL(ctx, "https://bad-url.com") + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("Safe for concurrent use", func(t *testing.T) { + const concurrency = 20 + + // Channel to collect errors + errChan := make(chan error, concurrency) + + // Start concurrent requests + for range concurrency { + go func() { + jwks, err := s.jwkSetForURL(t.Context(), url2) + if err != nil { + errChan <- err + return + } + + // Verify the JWK set is valid + if jwks == nil || jwks.Len() != 1 { + errChan <- assert.AnError + return + } + + errChan <- nil + }() + } + + // Check for errors + for range concurrency { + assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors") + } + }) +} + +func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { + const ( + federatedClientIssuer = "https://external-idp.com" + federatedClientAudience = "https://pocket-id.com" + federatedClientSubject = "123456abcdef" + federatedClientIssuerDefaults = "https://external-idp-defaults.com/" + ) + + var err error + // Create a test database + db := newDatabaseForTest(t) + + // Create two JWKs for testing + privateJWK, jwkSetJSON := generateTestECDSAKey(t) + require.NoError(t, err) + privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t) + require.NoError(t, err) + + // Create a mock HTTP client with custom transport to return the JWKS + httpClient := &http.Client{ + Transport: &MockRoundTripper{ + Responses: map[string]*http.Response{ + //nolint:bodyclose + federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)), + //nolint:bodyclose + federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)), + }, + }, + } + + // Init the OidcService + s := &OidcService{ + db: db, + httpClient: httpClient, + } + s.jwkCache, err = s.getJWKCache(t.Context()) + require.NoError(t, err) + + // Create the test clients + // 1. Confidential client + confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{ + Name: "Confidential Client", + CallbackURLs: []string{"https://example.com/callback"}, + }, "test-user-id") + require.NoError(t, err) + + // Create a client secret for the confidential client + confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID) + require.NoError(t, err) + + // 2. Public client + publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{ + Name: "Public Client", + CallbackURLs: []string{"https://example.com/callback"}, + IsPublic: true, + }, "test-user-id") + require.NoError(t, err) + + // 3. Confidential client with federated identity + federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{ + Name: "Federated Client", + CallbackURLs: []string{"https://example.com/callback"}, + Credentials: dto.OidcClientCredentialsDto{ + FederatedIdentities: []dto.OidcClientFederatedIdentityDto{ + { + Issuer: federatedClientIssuer, + Audience: federatedClientAudience, + Subject: federatedClientSubject, + JWKS: federatedClientIssuer + "/jwks.json", + }, + {Issuer: federatedClientIssuerDefaults}, + }, + }, + }, "test-user-id") + require.NoError(t, err) + + // Test cases for confidential client (using client secret) + t.Run("Confidential client", func(t *testing.T) { + t.Run("Succeeds with valid secret", func(t *testing.T) { + // Test with valid client credentials + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: confidentialClient.ID, + ClientSecret: confidentialSecret, + }) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, confidentialClient.ID, client.ID) + }) + + t.Run("Fails with invalid secret", func(t *testing.T) { + // Test with invalid client secret + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: confidentialClient.ID, + ClientSecret: "invalid-secret", + }) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{}) + assert.Nil(t, client) + }) + + t.Run("Fails with missing secret", func(t *testing.T) { + // Test with missing client secret + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: confidentialClient.ID, + }) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{}) + assert.Nil(t, client) + }) + }) + + // Test cases for public client + t.Run("Public client", func(t *testing.T) { + t.Run("Succeeds with no credentials", func(t *testing.T) { + // Public clients don't require client secret + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: publicClient.ID, + }) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, publicClient.ID, client.ID) + }) + }) + + // Test cases for federated client using JWT assertion + t.Run("Federated client", func(t *testing.T) { + t.Run("Succeeds with valid JWT", func(t *testing.T) { + // Create JWT for federated identity + token, err := jwt.NewBuilder(). + Issuer(federatedClientIssuer). + Audience([]string{federatedClientAudience}). + Subject(federatedClientSubject). + IssuedAt(time.Now()). + Expiration(time.Now().Add(10 * time.Minute)). + Build() + require.NoError(t, err) + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK)) + require.NoError(t, err) + + // Test with valid JWT assertion + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: federatedClient.ID, + ClientAssertionType: ClientAssertionTypeJWTBearer, + ClientAssertion: string(signedToken), + }) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, federatedClient.ID, client.ID) + }) + + t.Run("Fails with malformed JWT", func(t *testing.T) { + // Test with invalid JWT assertion (just a random string) + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: federatedClient.ID, + ClientAssertionType: ClientAssertionTypeJWTBearer, + ClientAssertion: "invalid.jwt.token", + }) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{}) + assert.Nil(t, client) + }) + + testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) { + return func(t *testing.T) { + // Populate all claims with valid values + builder := jwt.NewBuilder(). + Issuer(federatedClientIssuer). + Audience([]string{federatedClientAudience}). + Subject(federatedClientSubject). + IssuedAt(time.Now()). + Expiration(time.Now().Add(10 * time.Minute)) + + // Call builderFn to override the claims + builderFn(builder) + + token, err := builder.Build() + require.NoError(t, err) + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK)) + require.NoError(t, err) + + // Test with invalid JWT assertion + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: federatedClient.ID, + ClientAssertionType: ClientAssertionTypeJWTBearer, + ClientAssertion: string(signedToken), + }) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{}) + require.Nil(t, client) + } + } + + t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) { + builder.Expiration(time.Now().Add(-30 * time.Minute)) + })) + + t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) { + builder.Issuer("https://bad-issuer.com") + })) + + t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) { + builder.Audience([]string{"bad-audience"}) + })) + + t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) { + builder.Subject("bad-subject") + })) + + t.Run("Uses default values for audience and subject", func(t *testing.T) { + // Create JWT for federated identity + token, err := jwt.NewBuilder(). + Issuer(federatedClientIssuerDefaults). + Audience([]string{common.EnvConfig.AppURL}). + Subject(federatedClient.ID). + IssuedAt(time.Now()). + Expiration(time.Now().Add(10 * time.Minute)). + Build() + require.NoError(t, err) + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults)) + require.NoError(t, err) + + // Test with valid JWT assertion + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + ClientID: federatedClient.ID, + ClientAssertionType: ClientAssertionTypeJWTBearer, + ClientAssertion: string(signedToken), + }) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, federatedClient.ID, client.ID) + }) + }) +} diff --git a/backend/internal/service/testutils_test.go b/backend/internal/service/testutils_test.go new file mode 100644 index 00000000..59cdd9fe --- /dev/null +++ b/backend/internal/service/testutils_test.go @@ -0,0 +1,97 @@ +package service + +import ( + "io" + "net/http" + "strings" + "testing" + "time" + + _ "github.com/golang-migrate/migrate/v4/source/file" + + "github.com/glebarez/sqlite" + "github.com/golang-migrate/migrate/v4" + sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/pocket-id/pocket-id/backend/internal/utils" + "github.com/pocket-id/pocket-id/backend/resources" +) + +func newDatabaseForTest(t *testing.T) *gorm.DB { + t.Helper() + + // Get a name for this in-memory database that is specific to the test + dbName := utils.CreateSha256Hash(t.Name()) + + // Connect to a new in-memory SQL database + db, err := gorm.Open( + sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"), + &gorm.Config{ + TranslateError: true, + Logger: logger.New( + testLoggerAdapter{t: t}, + logger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + ParameterizedQueries: false, + Colorful: false, + }, + ), + }) + require.NoError(t, err, "Failed to connect to test database") + + // Perform migrations with the embedded migrations + sqlDB, err := db.DB() + require.NoError(t, err, "Failed to get sql.DB") + driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{}) + require.NoError(t, err, "Failed to create migration driver") + source, err := iofs.New(resources.FS, "migrations/sqlite") + require.NoError(t, err, "Failed to create embedded migration source") + m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver) + require.NoError(t, err, "Failed to create migration instance") + err = m.Up() + require.NoError(t, err, "Failed to perform migrations") + + return db +} + +// Implements gorm's logger.Writer interface +type testLoggerAdapter struct { + t *testing.T +} + +func (l testLoggerAdapter) Printf(format string, args ...any) { + l.t.Logf(format, args...) +} + +// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL +type MockRoundTripper struct { + Err error + Responses map[string]*http.Response +} + +// RoundTrip implements the http.RoundTripper interface +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Check if we have a specific response for this URL + for url, resp := range m.Responses { + if req.URL.String() == url { + return resp, nil + } + } + + return NewMockResponse(http.StatusNotFound, ""), nil +} + +// NewMockResponse creates an http.Response with the given status code and body +func NewMockResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} diff --git a/backend/internal/utils/jwk_util.go b/backend/internal/utils/jwk_util.go new file mode 100644 index 00000000..2571b514 --- /dev/null +++ b/backend/internal/utils/jwk_util.go @@ -0,0 +1,69 @@ +package utils + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "io" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" +) + +const ( + // KeyUsageSigning is the usage for the private keys, for the "use" property + KeyUsageSigning = "sig" +) + +// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key. +// It also populates additional fields such as the key ID, usage, and alg. +func ImportRawKey(rawKey any) (jwk.Key, error) { + key, err := jwk.Import(rawKey) + if err != nil { + return nil, fmt.Errorf("failed to import generated private key: %w", err) + } + + // Generate the key ID + kid, err := generateRandomKeyID() + if err != nil { + return nil, fmt.Errorf("failed to generate key ID: %w", err) + } + _ = key.Set(jwk.KeyIDKey, kid) + + // Set other required fields + _ = key.Set(jwk.KeyUsageKey, KeyUsageSigning) + EnsureAlgInKey(key) + + return key, nil +} + +// generateRandomKeyID generates a random key ID. +func generateRandomKeyID() (string, error) { + buf := make([]byte, 8) + _, err := io.ReadFull(rand.Reader, buf) + if err != nil { + return "", fmt.Errorf("failed to read random bytes: %w", err) + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type +func EnsureAlgInKey(key jwk.Key) { + _, ok := key.Algorithm() + if ok { + // Algorithm is already set + return + } + + switch key.KeyType() { + case jwa.RSA(): + // Default to RS256 for RSA keys + _ = key.Set(jwk.AlgorithmKey, jwa.RS256()) + case jwa.EC(): + // Default to ES256 for ECDSA keys + _ = key.Set(jwk.AlgorithmKey, jwa.ES256()) + case jwa.OKP(): + // Default to EdDSA for OKP keys + _ = key.Set(jwk.AlgorithmKey, jwa.EdDSA()) + } +} diff --git a/backend/resources/migrations/postgres/20250524202611_client_credentials.down.sql b/backend/resources/migrations/postgres/20250524202611_client_credentials.down.sql new file mode 100644 index 00000000..41aab182 --- /dev/null +++ b/backend/resources/migrations/postgres/20250524202611_client_credentials.down.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients DROP COLUMN credentials; diff --git a/backend/resources/migrations/postgres/20250524202611_client_credentials.up.sql b/backend/resources/migrations/postgres/20250524202611_client_credentials.up.sql new file mode 100644 index 00000000..3aa11798 --- /dev/null +++ b/backend/resources/migrations/postgres/20250524202611_client_credentials.up.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients ADD COLUMN credentials JSONB NULL; diff --git a/backend/resources/migrations/sqlite/20250524202611_client_credentials.down.sql b/backend/resources/migrations/sqlite/20250524202611_client_credentials.down.sql new file mode 100644 index 00000000..41aab182 --- /dev/null +++ b/backend/resources/migrations/sqlite/20250524202611_client_credentials.down.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients DROP COLUMN credentials; diff --git a/backend/resources/migrations/sqlite/20250524202611_client_credentials.up.sql b/backend/resources/migrations/sqlite/20250524202611_client_credentials.up.sql new file mode 100644 index 00000000..26089724 --- /dev/null +++ b/backend/resources/migrations/sqlite/20250524202611_client_credentials.up.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients ADD COLUMN credentials TEXT NULL; diff --git a/frontend/messages/en.json b/frontend/messages/en.json index d798c524..627b058e 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -348,6 +348,12 @@ "the_device_has_been_authorized": "The device has been authorized.", "enter_code_displayed_in_previous_step": "Enter the code that was displayed in the previous step.", "authorize": "Authorize", + "federated_identities": "Federated Identities", + "federated_identities_description": "Using federated identities, you can authenticate OIDC clients using JWT tokens issued by third-party authorities.", + "add_federated_identity": "Add Federated Identity", + "add_another_federated_identity": "Add another federated identity", "oidc_allowed_group_count": "Allowed Group Count", - "unrestricted": "Unrestricted" + "unrestricted": "Unrestricted", + "show_advanced_options": "Show Advanced Options", + "hide_advanced_options": "Hide Advanced Options" } diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts index dfffa89c..fdaebd65 100644 --- a/frontend/src/lib/types/oidc.type.ts +++ b/frontend/src/lib/types/oidc.type.ts @@ -6,11 +6,23 @@ export type OidcClientMetaData = { hasLogo: boolean; }; +export type OidcClientFederatedIdentity = { + issuer: string; + subject?: string; + audience?: string; + jwks?: string; +}; + +export type OidcClientCredentials = { + federatedIdentities: OidcClientFederatedIdentity[]; +}; + export type OidcClient = OidcClientMetaData & { callbackURLs: string[]; // No longer requires at least one URL logoutCallbackURLs: string[]; isPublic: boolean; pkceEnabled: boolean; + credentials?: OidcClientCredentials; }; export type OidcClientWithAllowedUserGroups = OidcClient & { diff --git a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte index 1b451a50..e55be368 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte @@ -8,6 +8,7 @@ import * as Card from '$lib/components/ui/card'; import Label from '$lib/components/ui/label/label.svelte'; import UserGroupSelection from '$lib/components/user-group-selection.svelte'; + import { m } from '$lib/paraglide/messages'; import OidcService from '$lib/services/oidc-service'; import clientSecretStore from '$lib/stores/client-secret-store'; import type { OidcClientCreateWithLogo } from '$lib/types/oidc.type'; @@ -16,7 +17,6 @@ import { toast } from 'svelte-sonner'; import { slide } from 'svelte/transition'; import OidcForm from '../oidc-client-form.svelte'; - import { m } from '$lib/paraglide/messages'; let { data } = $props(); let client = $state({ @@ -166,7 +166,7 @@ - + diff --git a/frontend/src/routes/settings/admin/oidc-clients/federated-identities-input.svelte b/frontend/src/routes/settings/admin/oidc-clients/federated-identities-input.svelte new file mode 100644 index 00000000..4bb770f5 --- /dev/null +++ b/frontend/src/routes/settings/admin/oidc-clients/federated-identities-input.svelte @@ -0,0 +1,128 @@ + + +
+ +
+ {#each federatedIdentities as identity, i} +
+
+ + {#if federatedIdentities.length > 0} + + {/if} +
+ +
+
+ + updateFederatedIdentity(i, 'issuer', e.currentTarget.value)} + required + /> +
+ +
+ + updateFederatedIdentity(i, 'subject', e.currentTarget.value)} + /> +
+ +
+ + updateFederatedIdentity(i, 'audience', e.currentTarget.value)} + /> +
+ +
+ + updateFederatedIdentity(i, 'jwks', e.currentTarget.value)} + /> +
+
+
+ {/each} +
+
+ + {#if error} +

{error}

+ {/if} + + +
diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte index aeb37a7b..a811adaf 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte @@ -5,14 +5,14 @@ import { Button } from '$lib/components/ui/button'; import Label from '$lib/components/ui/label/label.svelte'; import { m } from '$lib/paraglide/messages'; - import type { - OidcClient, - OidcClientCreate, - OidcClientCreateWithLogo - } from '$lib/types/oidc.type'; + import type { OidcClient, OidcClientCreateWithLogo } from '$lib/types/oidc.type'; import { preventDefault } from '$lib/utils/event-util'; import { createForm } from '$lib/utils/form-util'; + import { cn } from '$lib/utils/style'; + import { LucideChevronDown } from '@lucide/svelte'; + import { slide } from 'svelte/transition'; import { z } from 'zod'; + import FederatedIdentitiesInput from './federated-identities-input.svelte'; import OidcCallbackUrlInput from './oidc-callback-url-input.svelte'; let { @@ -24,17 +24,21 @@ } = $props(); let isLoading = $state(false); + let showAdvancedOptions = $state(false); let logo = $state(); let logoDataURL: string | null = $state( existingClient?.hasLogo ? `/api/oidc/clients/${existingClient!.id}/logo` : null ); - const client: OidcClientCreate = { + const client = { name: existingClient?.name || '', callbackURLs: existingClient?.callbackURLs || [], logoutCallbackURLs: existingClient?.logoutCallbackURLs || [], isPublic: existingClient?.isPublic || false, - pkceEnabled: existingClient?.pkceEnabled || false + pkceEnabled: existingClient?.pkceEnabled || false, + credentials: { + federatedIdentities: existingClient?.credentials?.federatedIdentities || [] + } }; const formSchema = z.object({ @@ -42,7 +46,17 @@ callbackURLs: z.array(z.string().nonempty()).default([]), logoutCallbackURLs: z.array(z.string().nonempty()), isPublic: z.boolean(), - pkceEnabled: z.boolean() + pkceEnabled: z.boolean(), + credentials: z.object({ + federatedIdentities: z.array( + z.object({ + issuer: z.string().url(), + subject: z.string().optional(), + audience: z.string().optional(), + jwks: z.string().url().optional().or(z.literal('')) + }) + ) + }) }); type FormSchema = typeof formSchema; @@ -139,8 +153,31 @@ -
-
- + + {#if showAdvancedOptions} +
+ +
+ {/if} + +
+ +
diff --git a/tests/data.ts b/tests/data.ts index 8ffa4f36..51e57d13 100644 --- a/tests/data.ts +++ b/tests/data.ts @@ -35,6 +35,17 @@ export const oidcClients = { callbackUrl: 'http://immich/auth/callback', secret: 'PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x' }, + federated: { + id: "c48232ff-ff65-45ed-ae96-7afa8a9b443b", + name: 'Federated', + callbackUrl: 'http://federated/auth/callback', + federatedJWT: { + issuer: 'https://external-idp.local', + audience: 'api://PocketID', + subject: 'c48232ff-ff65-45ed-ae96-7afa8a9b443b', + }, + accessCodes: ['federated'] + }, pingvinShare: { name: 'Pingvin Share', callbackUrl: 'http://pingvin.share/auth/callback', diff --git a/tests/package-lock.json b/tests/package-lock.json index 6c30b21c..1ddd7698 100644 --- a/tests/package-lock.json +++ b/tests/package-lock.json @@ -7,6 +7,7 @@ "devDependencies": { "@playwright/test": "^1.52.0", "@types/node": "^22.15.21", + "dotenv": "^16.5.0", "jose": "^6.0.11" } }, @@ -36,6 +37,19 @@ "undici-types": "~6.21.0" } }, + "node_modules/dotenv": { + "version": "16.5.0", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.5.0.tgz", + "integrity": "sha512-m/C+AwOAr9/W1UOIZUo232ejMNnJAJtYQjUbHoNTBNTJSvqzzDh7vnrei3o3r3m9blf6ZoDkvcw0VmozNRFJxg==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } + }, "node_modules/fsevents": { "version": "2.3.2", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", diff --git a/tests/package.json b/tests/package.json index 356998ff..88485685 100644 --- a/tests/package.json +++ b/tests/package.json @@ -3,6 +3,7 @@ "devDependencies": { "@playwright/test": "^1.52.0", "@types/node": "^22.15.21", - "jose": "^6.0.11" + "jose": "^6.0.11", + "dotenv": "^16.5.0" } } diff --git a/tests/playwright.config.ts b/tests/playwright.config.ts index 055a9be3..48ddf2ef 100644 --- a/tests/playwright.config.ts +++ b/tests/playwright.config.ts @@ -1,30 +1,31 @@ -import { defineConfig, devices } from '@playwright/test'; +import { defineConfig, devices } from "@playwright/test"; +import "dotenv/config"; /** * See https://playwright.dev/docs/test-configuration. */ export default defineConfig({ - outputDir: './.output', - timeout: 10000, - testDir: './specs', - fullyParallel: false, - forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 0, - workers: 1, - reporter: process.env.CI - ? [['html', { outputFolder: '.report' }], ['github']] - : [['line'], ['html', { open: 'never', outputFolder: '.report' }]], - use: { - baseURL: process.env.APP_URL ?? 'http://localhost:1411', - video: 'retain-on-failure', - trace: 'on-first-retry' - }, - projects: [ - { name: 'setup', testMatch: /.*\.setup\.ts/ }, - { - name: 'chromium', - use: { ...devices['Desktop Chrome'], storageState: '.auth/user.json' }, - dependencies: ['setup'] - } - ] + outputDir: "./.output", + timeout: 10000, + testDir: "./specs", + fullyParallel: false, + forbidOnly: !!process.env.CI, + retries: process.env.CI ? 1 : 0, + workers: 1, + reporter: process.env.CI + ? [["html", { outputFolder: ".report" }], ["github"]] + : [["line"], ["html", { open: "never", outputFolder: ".report" }]], + use: { + baseURL: process.env.APP_URL ?? "http://localhost:1411", + video: "retain-on-failure", + trace: "on-first-retry", + }, + projects: [ + { name: "setup", testMatch: /.*\.setup\.ts/ }, + { + name: "chromium", + use: { ...devices["Desktop Chrome"], storageState: ".auth/user.json" }, + dependencies: ["setup"], + }, + ], }); diff --git a/tests/specs/ldap.spec.ts b/tests/specs/ldap.spec.ts index 49015e01..fefee131 100644 --- a/tests/specs/ldap.spec.ts +++ b/tests/specs/ldap.spec.ts @@ -4,6 +4,8 @@ import { cleanupBackend } from '../utils/cleanup.util'; test.beforeEach(cleanupBackend); test.describe('LDAP Integration', () => { + test.skip(process.env.SKIP_LDAP_TESTS === "true", 'Skipping LDAP tests due to SKIP_LDAP_TESTS environment variable'); + test('LDAP configuration is working properly', async ({ page }) => { await page.goto('/settings/admin/application-configuration'); diff --git a/tests/specs/oidc.spec.ts b/tests/specs/oidc.spec.ts index b0dcff07..52f5e35d 100644 --- a/tests/specs/oidc.spec.ts +++ b/tests/specs/oidc.spec.ts @@ -2,7 +2,7 @@ import test, { expect } from "@playwright/test"; import { oidcClients, refreshTokens, users } from "../data"; import { cleanupBackend } from "../utils/cleanup.util"; import { generateIdToken, generateOauthAccessToken } from "../utils/jwt.util"; -import oidcUtil from "../utils/oidc.util"; +import * as oidcUtil from "../utils/oidc.util"; import passkeyUtil from "../utils/passkey.util"; test.beforeEach(cleanupBackend); @@ -449,3 +449,40 @@ test("Authorize new client with device authorization with user group not allowed .filter({ hasText: "You're not allowed to access this service." }) ).toBeVisible(); }); + +test("Federated identity fails with invalid client assertion", async ({ + page, +}) => { + const client = oidcClients.federated; + + const res = await oidcUtil.exchangeCode(page, { + client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', + grant_type: 'authorization_code', + redirect_uri: client.callbackUrl, + code: client.accessCodes[0], + client_id: client.id, + client_assertion:'not-an-assertion', + }); + + expect(res?.error).toBe('Invalid client assertion'); +}); + +test("Authorize existing client with federated identity", async ({ + page, +}) => { + const client = oidcClients.federated; + const clientAssertion = await oidcUtil.getClientAssertion(page, client.federatedJWT); + + const res = await oidcUtil.exchangeCode(page, { + client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', + grant_type: 'authorization_code', + redirect_uri: client.callbackUrl, + code: client.accessCodes[0], + client_id: client.id, + client_assertion: clientAssertion, + }); + + expect(res.access_token).not.toBeNull; + expect(res.expires_in).not.toBeNull; + expect(res.token_type).toBe('Bearer'); +}); diff --git a/tests/tsconfig.json b/tests/tsconfig.json index 36aa1a4d..2adb063f 100644 --- a/tests/tsconfig.json +++ b/tests/tsconfig.json @@ -1,5 +1,6 @@ { "compilerOptions": { - "baseUrl": "." + "baseUrl": ".", + "lib": ["ES2022"] } } diff --git a/tests/utils/cleanup.util.ts b/tests/utils/cleanup.util.ts index 14785b0f..a2317f95 100644 --- a/tests/utils/cleanup.util.ts +++ b/tests/utils/cleanup.util.ts @@ -1,12 +1,15 @@ import playwrightConfig from "../playwright.config"; export async function cleanupBackend() { - const response = await fetch( - playwrightConfig.use!.baseURL + "/api/test/reset", - { - method: "POST", - } - ); + const url = new URL("/api/test/reset", playwrightConfig.use!.baseURL); + + if (process.env.SKIP_LDAP_TESTS === "true") { + url.searchParams.append("skip-ldap", "true"); + } + + const response = await fetch(url, { + method: "POST", + }); if (!response.ok) { throw new Error( diff --git a/tests/utils/oidc.util.ts b/tests/utils/oidc.util.ts index 6f3763d1..2600c063 100644 --- a/tests/utils/oidc.util.ts +++ b/tests/utils/oidc.util.ts @@ -1,7 +1,7 @@ import type { Page } from '@playwright/test'; -async function getUserCode(page: Page, clientId: string, clientSecret: string) { - const response = await page.request +export async function getUserCode(page: Page, clientId: string, clientSecret: string): Promise { + return page.request .post('/api/oidc/device/authorize', { headers: { 'Content-Type': 'application/x-www-form-urlencoded' @@ -12,11 +12,29 @@ async function getUserCode(page: Page, clientId: string, clientSecret: string) { scope: 'openid profile email' } }) - .then((r) => r.json()); - - return response.user_code; + .then((r) => r.json()) + .then((r) => r.user_code); } -export default { - getUserCode -}; +export async function exchangeCode(page: Page, params: Record): Promise<{access_token?: string, token_type?: string, expires_in?: number, error?: string}> { + return page.request + .post('/api/oidc/token', { + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + form: params, + }) + .then((r) => r.json()); +} + +export async function getClientAssertion(page: Page, data: {issuer: string, audience: string, subject: string}): Promise { + return page.request + .post('/api/externalidp/sign', { + data: { + iss: data.issuer, + aud: data.audience, + sub: data.subject, + }, + }) + .then((r) => r.text()); +} \ No newline at end of file