commit 920a79b2e9dd6c886e529a18c9b5420d3d9215fc Author: Shamil Nunhuck Date: Sat Nov 8 10:18:19 2025 +0000 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cc38e83 --- /dev/null +++ b/.gitignore @@ -0,0 +1,41 @@ +# Binaries and build artifacts +bin/ +build/ +dist/ +out/ +tmp/ + +# Go compiler outputs +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out + +# Dependency/vendor directories +vendor/ + +# Coverage and profiling data +coverage.* +*.cov +*.coverprofile +*.pprof +*.prof +profile.out + +# Local configs and secrets +.env +.env.* +*.local +*.secret.yaml + +# IDE/editor state +.idea/ +.vscode/ +*.code-workspace + +# OS junk +.DS_Store +Thumbs.db diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2fc0409 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM golang:1.22 AS build +WORKDIR /src +COPY go.mod ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -o /out/broker ./cmd/broker + +FROM gcr.io/distroless/static:nonroot +ENV CONFIG_PATH=/config/config.yaml +WORKDIR / +USER nonroot:nonroot +COPY --from=build /out/broker /broker +ENTRYPOINT ["/broker"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..96fa64b --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +APP?=saml-oidc-broker +PKG?=shamilnunhuck/saml-oidc-bridge +CONFIG?=example.config.yaml +KEY_ID?=k-$(shell date +%Y-%m) + +build: + GO111MODULE=on CGO_ENABLED=0 go build -o bin/$(APP) ./cmd/broker + +run: + CONFIG_PATH=$(CONFIG) bin/$(APP) + +rotate-key: + bin/$(APP) cert -config $(CONFIG) -id $(KEY_ID) -algo rsa3072 -days 825 -cn id.example.com -org "YourOrg" -k8s-secret-out build/$(KEY_ID).secret.yaml + @echo "Wrote build/$(KEY_ID).secret.yaml" + +docker: + docker build --platform linux/amd64 -t shamilnunhuck/$(APP):dev . + +.PHONY: build run rotate-key docker diff --git a/charts/saml-broker/Chart.yaml b/charts/saml-broker/Chart.yaml new file mode 100644 index 0000000..f179fa2 --- /dev/null +++ b/charts/saml-broker/Chart.yaml @@ -0,0 +1,6 @@ +apiVersion: v2 +name: saml-broker +description: Minimal SAML IdP brokering to OIDC (Pocket ID) for Splunk +type: application +version: 0.1.0 +appVersion: "0.1.0" diff --git a/charts/saml-broker/templates/NOTES.txt b/charts/saml-broker/templates/NOTES.txt new file mode 100644 index 0000000..f9136da --- /dev/null +++ b/charts/saml-broker/templates/NOTES.txt @@ -0,0 +1,4 @@ +1. Get the service URL by running these commands: + export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "saml-broker.name" . }}" -o jsonpath="{.items[0].metadata.name}") + kubectl port-forward $POD_NAME 8080:8080 & + echo "Visit http://127.0.0.1:8080/saml/metadata" diff --git a/charts/saml-broker/templates/_helpers.tpl b/charts/saml-broker/templates/_helpers.tpl new file mode 100644 index 0000000..2a394b4 --- /dev/null +++ b/charts/saml-broker/templates/_helpers.tpl @@ -0,0 +1,10 @@ +{{- define "saml-broker.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" -}} +{{- end -}} +{{- define "saml-broker.fullname" -}} +{{- if .Values.fullnameOverride -}} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" -}} +{{- else -}} +{{- printf "%s" (include "saml-broker.name" .) | trunc 63 | trimSuffix "-" -}} +{{- end -}} +{{- end -}} diff --git a/charts/saml-broker/templates/configmap.yaml b/charts/saml-broker/templates/configmap.yaml new file mode 100644 index 0000000..d1c1b3a --- /dev/null +++ b/charts/saml-broker/templates/configmap.yaml @@ -0,0 +1,7 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "saml-broker.fullname" . }}-config +data: + config.yaml: | +{{ toYaml .Values.config | indent 4 }} diff --git a/charts/saml-broker/templates/deployment.yaml b/charts/saml-broker/templates/deployment.yaml new file mode 100644 index 0000000..9adbd5b --- /dev/null +++ b/charts/saml-broker/templates/deployment.yaml @@ -0,0 +1,49 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "saml-broker.fullname" . }} + labels: + app.kubernetes.io/name: {{ include "saml-broker.name" . }} +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: {{ include "saml-broker.name" . }} + template: + metadata: + labels: + app.kubernetes.io/name: {{ include "saml-broker.name" . }} + spec: + containers: + - name: broker + image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + env: + - name: CONFIG_PATH + value: /config/config.yaml + - name: OIDC_CLIENT_SECRET + valueFrom: + secretKeyRef: + name: {{ .Values.env.OIDC_CLIENT_SECRET_SECRET_NAME }} + key: {{ .Values.env.OIDC_CLIENT_SECRET_KEY }} + ports: + - name: http + containerPort: 8080 + volumeMounts: + - name: cfg + mountPath: /config + readOnly: true + readinessProbe: + httpGet: + path: /healthz + port: http + livenessProbe: + httpGet: + path: /healthz + port: http + resources: +{{ toYaml .Values.resources | indent 12 }} + volumes: + - name: cfg + configMap: + name: {{ include "saml-broker.fullname" . }}-config diff --git a/charts/saml-broker/templates/ingress.yaml b/charts/saml-broker/templates/ingress.yaml new file mode 100644 index 0000000..f8a3d21 --- /dev/null +++ b/charts/saml-broker/templates/ingress.yaml @@ -0,0 +1,30 @@ +{{- if .Values.ingress.enabled }} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: {{ include "saml-broker.fullname" . }} + {{- with .Values.ingress.className }} + annotations: + kubernetes.io/ingress.class: {{ . }} + {{- end }} +spec: + rules: + {{- range .Values.ingress.hosts }} + - host: {{ .host }} + http: + paths: + {{- range .paths }} + - path: {{ .path }} + pathType: {{ .pathType }} + backend: + service: + name: {{ include "saml-broker.fullname" $ }} + port: + number: {{ $.Values.service.port }} + {{- end }} + {{- end }} + {{- if .Values.ingress.tls }} + tls: +{{ toYaml .Values.ingress.tls | indent 4 }} + {{- end }} +{{- end }} diff --git a/charts/saml-broker/templates/service.yaml b/charts/saml-broker/templates/service.yaml new file mode 100644 index 0000000..71e1dbe --- /dev/null +++ b/charts/saml-broker/templates/service.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "saml-broker.fullname" . }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ .Values.service.port }} + targetPort: http + protocol: TCP + name: http + selector: + app.kubernetes.io/name: {{ include "saml-broker.name" . }} diff --git a/charts/saml-broker/values.yaml b/charts/saml-broker/values.yaml new file mode 100644 index 0000000..6ed9b0b --- /dev/null +++ b/charts/saml-broker/values.yaml @@ -0,0 +1,63 @@ +image: + repository: ghcr.io/your-org/broker + tag: dev + pullPolicy: IfNotPresent + +service: + type: ClusterIP + port: 80 + +ingress: + enabled: false + className: "" + hosts: + - host: id.example.com + paths: + - path: / + pathType: Prefix + tls: [] + +resources: {} + +env: + # OIDC client secret comes from a Secret + OIDC_CLIENT_SECRET_SECRET_NAME: oidc-secret + OIDC_CLIENT_SECRET_KEY: OIDC_CLIENT_SECRET + +config: + # Paste example.config.yaml here (without private key if you mount keys via secret) + server: + listen: ":8080" + external_url: "https://id.example.com" + crypto: + active_key: "k-2025-09" + keys: [] + oidc_upstream: + issuer: "https://pocket-id.example" + client_id: "your-client-id" + redirect_path: "/oidc/callback" + scopes: ["email","profile"] + sps: + - name: "splunk" + entity_id: "https://splunk.example" + acs_url: "https://splunk.example/saml/acs" + audience: "https://splunk.example" + nameid_format: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress" + attribute_mapping: + mail: "email" + realName: "name" + role: "role" + role_mapping: + admins: "admin" + power: "power" + "*": "user" + security: + skew_seconds: 120 + assertion_ttl_seconds: 300 + require_signed_authn_request: false + metadata_valid_until_days: 7 + metadata_cache_duration_seconds: 86400 + session: + cookie_name: "_saml_broker" + cookie_secure: true + cookie_domain: "id.example.com" diff --git a/cmd/broker/main.go b/cmd/broker/main.go new file mode 100644 index 0000000..4641852 --- /dev/null +++ b/cmd/broker/main.go @@ -0,0 +1,114 @@ +package main + +import ( + "log" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/fsnotify/fsnotify" + + "shamilnunhuck/saml-oidc-bridge/internal/cli" + "shamilnunhuck/saml-oidc-bridge/internal/config" + "shamilnunhuck/saml-oidc-bridge/internal/crypto" + h "shamilnunhuck/saml-oidc-bridge/internal/http" + "shamilnunhuck/saml-oidc-bridge/internal/oidc" + "shamilnunhuck/saml-oidc-bridge/internal/saml" +) + +type runtimeState struct { + mu sync.RWMutex + cfg *config.Config + ks *crypto.KeyStore + idp *saml.IdP + oidc *oidc.Client +} + +func main() { + if len(os.Args) > 1 && os.Args[1] == "cert" { + if err := cli.RunCert(os.Args[2:]); err != nil { + log.Fatal(err) + } + return + } + + cfgPath := os.Getenv("CONFIG_PATH") + if cfgPath == "" { + cfgPath = "example.config.yaml" + } + + state := &runtimeState{} + load := func() { + cfg, err := config.Load(cfgPath) + if err != nil { + log.Fatalf("load config: %v", err) + } + if v := os.Getenv("OIDC_CLIENT_SECRET"); v != "" { + cfg.OIDC.ClientSecret = v + } + if err := cfg.Validate(); err != nil { + log.Fatalf("invalid config: %v", err) + } + + ks, err := crypto.NewKeyStore(cfg.Crypto) + if err != nil { + log.Fatalf("keystore: %v", err) + } + idp := saml.NewIdP(cfg, ks) + oc, err := oidc.NewClient(cfg) + if err != nil { + log.Fatalf("oidc: %v", err) + } + + state.mu.Lock() + state.cfg, state.ks, state.idp, state.oidc = cfg, ks, idp, oc + state.mu.Unlock() + log.Printf("loaded config; active signing key=%s", cfg.Crypto.ActiveKey) + } + load() + + mux := http.NewServeMux() + h.Register( + mux, + func() *config.Config { state.mu.RLock(); defer state.mu.RUnlock(); return state.cfg }, + func() *saml.IdP { state.mu.RLock(); defer state.mu.RUnlock(); return state.idp }, + func() *oidc.Client { state.mu.RLock(); defer state.mu.RUnlock(); return state.oidc }, + ) + + go func() { + log.Printf("listening on %s", state.cfg.Server.Listen) + log.Fatal(http.ListenAndServe(state.cfg.Server.Listen, mux)) + }() + + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, syscall.SIGHUP) + go func() { + for range sigc { + load() + } + }() + + w, err := fsnotify.NewWatcher() + if err == nil { + defer w.Close() + _ = w.Add(cfgPath) + go func() { + for { + select { + case e := <-w.Events: + if e.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 { + load() + } + case err := <-w.Errors: + if err != nil { + log.Printf("watch error: %v", err) + } + } + } + }() + } + + select {} +} diff --git a/example.config.yaml b/example.config.yaml new file mode 100644 index 0000000..af76b75 --- /dev/null +++ b/example.config.yaml @@ -0,0 +1,43 @@ +server: + listen: :8080 + external_url: https://saml-v.ttt.net +crypto: + active_key: k-2025-12 + keys: + - id: k-2025-12 + cert_pem: | + ... + key_pem: | + ... + not_after: 2028-01-06T12:27:11.670644Z +oidc_upstream: + issuer: https://id.tt.net + client_id: 1ec56384 + redirect_path: /oidc/callback + scopes: + - email + - profile +sps: + - name: splunk + entity_id: https://splunk.example + acs_url: https://splunk.example/saml/acs + audience: https://splunk.example + nameid_format: urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + attribute_mapping: + mail: email + realName: name + role: role + role_mapping: + '*': user + admins: admin + power: power +security: + skew_seconds: 120 + assertion_ttl_seconds: 300 + require_signed_authn_request: false + metadata_valid_until_days: 7 + metadata_cache_duration_seconds: 86400 +session: + cookie_name: _saml_broker + cookie_secure: true + cookie_domain: saml-v.ttt.net diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a37d4b6 --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module shamilnunhuck/saml-oidc-bridge + +go 1.22 + +require ( + github.com/coreos/go-oidc/v3 v3.11.0 + github.com/crewjam/saml v0.5.1 + github.com/fsnotify/fsnotify v1.7.0 + golang.org/x/oauth2 v0.23.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/beevik/etree v1.5.0 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect + github.com/jonboulle/clockwork v0.2.2 // indirect + github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect + github.com/russellhaering/goxmldsig v1.4.0 // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/sys v0.30.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ff498f4 --- /dev/null +++ b/go.sum @@ -0,0 +1,54 @@ +github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= +github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= +github.com/beevik/etree v1.5.0 h1:iaQZFSDS+3kYZiGoc9uKeOkUY3nYMXOKLl6KIJxiJWs= +github.com/beevik/etree v1.5.0/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/crewjam/saml v0.4.14 h1:g9FBNx62osKusnFzs3QTN5L9CVA/Egfgm+stJShzw/c= +github.com/crewjam/saml v0.4.14/go.mod h1:UVSZCf18jJkk6GpWNVqcyQJMD5HsRugBPf4I1nl2mME= +github.com/crewjam/saml v0.5.1 h1:g+mfp0CrLuLRZCK793PgJcZeg5dS/0CDwoeAX2zcwNI= +github.com/crewjam/saml v0.5.1/go.mod h1:r0fDkmFe5URDgPrmtH0IYokva6fac3AUdstiPhyEolQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= +github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= +github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= +github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys= +github.com/russellhaering/goxmldsig v1.4.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cli/cert.go b/internal/cli/cert.go new file mode 100644 index 0000000..d276338 --- /dev/null +++ b/internal/cli/cert.go @@ -0,0 +1,206 @@ +package cli + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "flag" + "fmt" + "math/big" + "os" + "strings" + "time" + + "gopkg.in/yaml.v3" + + "shamilnunhuck/saml-oidc-bridge/internal/config" +) + +type rotateOpts struct { + ConfigPath string + ID string + Algo string + Days int + CN string + Org string + OutK8s string + ActiveOnly bool +} + +func RunCert(args []string) error { + fs := flag.NewFlagSet("cert", flag.ContinueOnError) + var ro rotateOpts + fs.StringVar(&ro.ConfigPath, "config", "example.config.yaml", "path to config yaml") + fs.StringVar(&ro.ID, "id", "", "key id (e.g. k-2025-10)") + fs.StringVar(&ro.Algo, "algo", "rsa3072", "rsa2048|rsa3072|rsa4096|p256|p384") + fs.IntVar(&ro.Days, "days", 825, "validity in days") + fs.StringVar(&ro.CN, "cn", "id.example.com", "certificate CN") + fs.StringVar(&ro.Org, "org", "YourOrg", "certificate O") + fs.StringVar(&ro.OutK8s, "k8s-secret-out", "", "write a Kubernetes Secret manifest to this path") + fs.BoolVar(&ro.ActiveOnly, "active-only", false, "only set active_key to -id (no new cert)") + if err := fs.Parse(args); err != nil { + return err + } + if ro.ID == "" { + return errors.New("missing -id") + } + + cfg, raw, err := loadConfig(ro.ConfigPath) + if err != nil { + return err + } + + if ro.ActiveOnly { + cfg.Crypto.ActiveKey = ro.ID + return saveConfig(ro.ConfigPath, raw, cfg) + } + + certPEM, keyPEM, notAfter, err := genSelfSigned(ro) + if err != nil { + return err + } + + cfg.Crypto.Keys = append(cfg.Crypto.Keys, config.KeyPair{ + ID: ro.ID, + CertPEM: string(certPEM), + KeyPEM: string(keyPEM), + NotAfter: notAfter.UTC(), + }) + cfg.Crypto.ActiveKey = ro.ID + + if err := saveConfig(ro.ConfigPath, raw, cfg); err != nil { + return err + } + if ro.OutK8s != "" { + if err := os.WriteFile(ro.OutK8s, []byte(k8sSecretYAML(ro, certPEM, keyPEM)), 0o600); err != nil { + return fmt.Errorf("write secret: %w", err) + } + } + fmt.Printf("OK: generated %s (algo=%s, not_after=%s) and set active_key\n", ro.ID, ro.Algo, notAfter.UTC().Format(time.RFC3339)) + return nil +} + +func genSelfSigned(ro rotateOpts) (certPEM, keyPEM []byte, notAfter time.Time, err error) { + nb := time.Now().Add(-5 * time.Minute) + na := nb.Add(time.Duration(ro.Days) * 24 * time.Hour) + + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + template := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: ro.CN, + Organization: []string{ro.Org}, + OrganizationalUnit: []string{"SAML Signing"}, + }, + NotBefore: nb, + NotAfter: na, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + var der []byte + var keyPKCS8 []byte + + switch strings.ToLower(ro.Algo) { + case "rsa2048", "rsa3072", "rsa4096": + bits := 2048 + if ro.Algo == "rsa3072" { + bits = 3072 + } + if ro.Algo == "rsa4096" { + bits = 4096 + } + priv, e := rsa.GenerateKey(rand.Reader, bits) + if e != nil { + return nil, nil, time.Time{}, e + } + der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, time.Time{}, err + } + keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, nil, time.Time{}, err + } + case "p256", "p384": + curve := elliptic.P256() + if ro.Algo == "p384" { + curve = elliptic.P384() + } + priv, e := ecdsa.GenerateKey(curve, rand.Reader) + if e != nil { + return nil, nil, time.Time{}, e + } + der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, time.Time{}, err + } + keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, nil, time.Time{}, err + } + default: + return nil, nil, time.Time{}, fmt.Errorf("unknown -algo %q", ro.Algo) + } + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyPKCS8}) + return certPEM, keyPEM, na, nil +} + +func loadConfig(path string) (*config.Config, *yaml.Node, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, nil, err + } + var root yaml.Node + if err := yaml.Unmarshal(b, &root); err != nil { + return nil, nil, err + } + var c config.Config + if err := root.Decode(&c); err != nil { + return nil, nil, err + } + return &c, &root, nil +} + +func saveConfig(path string, _ *yaml.Node, c *config.Config) error { + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + enc.SetIndent(2) + if err := enc.Encode(c); err != nil { + return err + } + _ = enc.Close() + return os.WriteFile(path, buf.Bytes(), 0o644) +} + +func k8sSecretYAML(ro rotateOpts, certPEM, keyPEM []byte) string { + name := strings.ToLower(strings.ReplaceAll(ro.ID, "_", "-")) + return fmt.Sprintf(`apiVersion: v1 +kind: Secret +metadata: + name: saml-signing-%s +type: Opaque +stringData: + cert.pem: |- +%s + key.pem: |- +%s +`, name, indent(string(certPEM), 4), indent(string(keyPEM), 4)) +} + +func indent(s string, n int) string { + pad := strings.Repeat(" ", n) + lines := strings.Split(strings.TrimRight(s, "\n"), "\n") + for i := range lines { + lines[i] = pad + lines[i] + } + return strings.Join(lines, "\n") +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..8f989f7 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,36 @@ +package config + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +func Load(path string) (*Config, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var c Config + if err := yaml.Unmarshal(b, &c); err != nil { + return nil, err + } + return &c, nil +} + +func (c *Config) Validate() error { + if c.Server.ExternalURL == "" || c.Server.Listen == "" { + return fmt.Errorf("server.external_url and server.listen required") + } + if len(c.SPs) == 0 { + return fmt.Errorf("at least one SP required") + } + if c.OIDC.Issuer == "" || c.OIDC.ClientID == "" || c.OIDC.RedirectPath == "" { + return fmt.Errorf("oidc issuer/client_id/redirect_path required") + } + if c.Crypto.ActiveKey == "" || len(c.Crypto.Keys) == 0 { + return fmt.Errorf("crypto.active_key and at least one key required") + } + return nil +} diff --git a/internal/config/types.go b/internal/config/types.go new file mode 100644 index 0000000..48b4a46 --- /dev/null +++ b/internal/config/types.go @@ -0,0 +1,68 @@ +package config + +import "time" + +type Server struct { + Listen string `yaml:"listen"` + ExternalURL string `yaml:"external_url"` +} + +type KeyPair struct { + ID string `yaml:"id"` + CertPEM string `yaml:"cert_pem"` + KeyPEM string `yaml:"key_pem"` + NotAfter time.Time `yaml:"not_after"` +} +type Crypto struct { + ActiveKey string `yaml:"active_key"` + Keys []KeyPair `yaml:"keys"` +} + +type OIDC struct { + Issuer string `yaml:"issuer"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"-"` + RedirectPath string `yaml:"redirect_path"` + Scopes []string `yaml:"scopes"` +} + +type SP struct { + Name string `yaml:"name"` + EntityID string `yaml:"entity_id"` + ACSURL string `yaml:"acs_url"` + Audience string `yaml:"audience"` + NameIDFormat string `yaml:"nameid_format"` + AttributeMapping map[string]string `yaml:"attribute_mapping"` + RoleMapping map[string]string `yaml:"role_mapping"` + AttributeRules []AttributeRule `yaml:"attribute_rules"` +} + +type Security struct { + SkewSeconds int `yaml:"skew_seconds"` + AssertionTTLSec int `yaml:"assertion_ttl_seconds"` + RequireSignedAuthnRequest bool `yaml:"require_signed_authn_request"` + MetadataValidUntilDays int `yaml:"metadata_valid_until_days"` + MetadataCacheDurationSeconds int `yaml:"metadata_cache_duration_seconds"` +} + +type Session struct { + CookieName string `yaml:"cookie_name"` + CookieSecure bool `yaml:"cookie_secure"` + CookieDomain string `yaml:"cookie_domain"` +} + +type Config struct { + Server Server `yaml:"server"` + Crypto Crypto `yaml:"crypto"` + OIDC OIDC `yaml:"oidc_upstream"` + SPs []SP `yaml:"sps"` + Security Security `yaml:"security"` + Session Session `yaml:"session"` +} + +type AttributeRule struct { + Name string `yaml:"name"` + Value string `yaml:"value"` + IfGroupsAny []string `yaml:"if_groups_any"` + EmitWhenFalse bool `yaml:"emit_when_false"` +} diff --git a/internal/crypto/keystore.go b/internal/crypto/keystore.go new file mode 100644 index 0000000..31a8620 --- /dev/null +++ b/internal/crypto/keystore.go @@ -0,0 +1,74 @@ +package crypto + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + + "shamilnunhuck/saml-oidc-bridge/internal/config" +) + +type KeyStore struct { + activeID string + signers map[string]tls.Certificate + certsDER map[string][]byte +} + +func NewKeyStore(c config.Crypto) (*KeyStore, error) { + ks := &KeyStore{ + activeID: c.ActiveKey, + signers: map[string]tls.Certificate{}, + certsDER: map[string][]byte{}, + } + for _, k := range c.Keys { + cert, priv, err := parseKeypair(k.CertPEM, k.KeyPEM) + if err != nil { + return nil, fmt.Errorf("key %s: %w", k.ID, err) + } + ks.certsDER[k.ID] = cert.Raw + if priv != nil { + ks.signers[k.ID] = tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: priv} + } + } + if _, ok := ks.signers[ks.activeID]; !ok { + return nil, errors.New("active signing key not available (missing or no private key)") + } + return ks, nil +} + +func (ks *KeyStore) Active() tls.Certificate { return ks.signers[ks.activeID] } +func (ks *KeyStore) AllCertsDER() [][]byte { + out := make([][]byte, 0, len(ks.certsDER)) + for _, der := range ks.certsDER { + out = append(out, der) + } + return out +} + +func parseKeypair(certPEM, keyPEM string) (*x509.Certificate, interface{}, error) { + cb, _ := pem.Decode([]byte(certPEM)) + if cb == nil { + return nil, nil, errors.New("invalid cert pem") + } + cert, err := x509.ParseCertificate(cb.Bytes) + if err != nil { + return nil, nil, err + } + var priv interface{} + if keyPEM != "" { + kb, _ := pem.Decode([]byte(keyPEM)) + if kb == nil { + return nil, nil, errors.New("invalid key pem") + } + priv, err = x509.ParsePKCS8PrivateKey(kb.Bytes) + if err != nil { + priv, err = x509.ParsePKCS1PrivateKey(kb.Bytes) + if err != nil { + return nil, nil, err + } + } + } + return cert, priv, nil +} diff --git a/internal/http/handlers.go b/internal/http/handlers.go new file mode 100644 index 0000000..0f40991 --- /dev/null +++ b/internal/http/handlers.go @@ -0,0 +1,261 @@ +package http + +import ( + "compress/flate" + "context" + "crypto/rand" + "encoding/base64" + "encoding/xml" + "fmt" + "html/template" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/crewjam/saml" + + "shamilnunhuck/saml-oidc-bridge/internal/config" + "shamilnunhuck/saml-oidc-bridge/internal/oidc" + idsaml "shamilnunhuck/saml-oidc-bridge/internal/saml" +) + +type IdP interface { + Metadata() *saml.EntityDescriptor + BuildResponse(sp config.SP, nameID string, attrs map[string][]string) (*saml.Response, error) +} + +func Register( + mux *http.ServeMux, + getCfg func() *config.Config, + getIdP func() *idsaml.IdP, + getOIDC func() *oidc.Client, +) { + mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, _ = w.Write([]byte("ok")) + }) + + mux.HandleFunc("/saml/metadata", func(w http.ResponseWriter, r *http.Request) { + meta := getIdP().Metadata() + buf, err := xml.MarshalIndent(meta, "", " ") + if err != nil { + http.Error(w, "marshal metadata: "+err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/samlmetadata+xml") + _, _ = w.Write([]byte(xml.Header)) + _, _ = w.Write(buf) + }) + + mux.HandleFunc("/saml/sso", func(w http.ResponseWriter, r *http.Request) { + req, spEntityID, relay, err := parseAuthnRequest(r) + log.Printf("AuthnRequest from SP=%s requestID=%s relay=%q", spEntityID, req.ID, relay) + if err != nil { + http.Error(w, "bad authn request: "+err.Error(), 400) + return + } + setStateCookie(w, getCfg().Session, spEntityID, relay, req.ID) + state := randomState() + http.Redirect(w, r, getOIDC().AuthCodeURL(state, url.Values{}), http.StatusFound) + }) + + mux.HandleFunc(getCfg().OIDC.RedirectPath, func(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, "missing code", 400) + return + } + + claims, err := getOIDC().ExchangeAndVerify(ctx, code) + if err != nil { + http.Error(w, "oidc: "+err.Error(), 400) + return + } + + s, err := readStateCookie(r, getCfg().Session) + if err != nil { + http.Error(w, "state missing", 400) + return + } + + sp := lookupSP(getCfg(), s.SPEntityID) + if sp == nil { + http.Error(w, "unknown SP", 400) + return + } + + attrs := map[string][]string{} + for samlAttr, oidcClaim := range sp.AttributeMapping { + switch oidcClaim { + case "email": + attrs[samlAttr] = []string{claims.Email} + case "name": + attrs[samlAttr] = []string{claims.Name} + case "role": + attrs[samlAttr] = []string{mapRole(claims.Groups, sp)} + } + } + + nameID := claims.Email + groups := claims.Groups + + // Apply generic conditional attribute rules + userGroups := toSet(groups) // []string -> set + for _, rule := range sp.AttributeRules { + if hasAnyGroup(userGroups, rule.IfGroupsAny) { + if rule.Value == "" { // safe default + attrs[rule.Name] = []string{"true"} + } else { + attrs[rule.Name] = []string{rule.Value} + } + } else if rule.EmitWhenFalse { + attrs[rule.Name] = []string{"false"} + } + } + + resp, err := getIdP().BuildResponse(*sp, nameID, attrs, s.RequestID) + if err != nil { + http.Error(w, "saml: "+err.Error(), 500) + return + } + + xmlBytes, err := idsaml.MarshalSignedResponse(resp) + if err != nil { + http.Error(w, "sign: "+err.Error(), 500) + return + } + + postToACS(w, sp.ACSURL, base64.StdEncoding.EncodeToString(xmlBytes), s.RelayState) + }) +} + +/*** helpers ***/ + +type state struct { + SPEntityID string + RelayState string + RequestID string + Expiry time.Time +} + +func setStateCookie(w http.ResponseWriter, s config.Session, spEntityID, relay, reqID string) { + v := url.Values{} + v.Set("sp", spEntityID) + v.Set("rs", relay) + v.Set("rid", reqID) + c := &http.Cookie{ + Name: s.CookieName, + Value: base64.RawURLEncoding.EncodeToString([]byte(v.Encode())), + Path: "/", + Domain: s.CookieDomain, + HttpOnly: true, + Secure: s.CookieSecure, + MaxAge: 600, + } + http.SetCookie(w, c) +} + +func readStateCookie(r *http.Request, s config.Session) (*state, error) { + c, err := r.Cookie(s.CookieName) + if err != nil { + return nil, err + } + b, err := base64.RawURLEncoding.DecodeString(c.Value) + if err != nil { + return nil, err + } + v, err := url.ParseQuery(string(b)) + if err != nil { + return nil, err + } + return &state{ + SPEntityID: v.Get("sp"), + RelayState: v.Get("rs"), + RequestID: v.Get("rid"), + Expiry: time.Now().Add(10 * time.Minute), + }, nil +} + +func lookupSP(cfg *config.Config, entityID string) *config.SP { + for i := range cfg.SPs { + if cfg.SPs[i].EntityID == entityID { + return &cfg.SPs[i] + } + } + return nil +} + +func mapRole(groups []string, sp *config.SP) string { + for _, g := range groups { + if v, ok := sp.RoleMapping[g]; ok { + return v + } + } + if v, ok := sp.RoleMapping["*"]; ok { + return v + } + return "user" +} + +func postToACS(w http.ResponseWriter, acsURL string, samlResponseB64 string, relay string) { + const tpl = ` + +
+ + {{if .Relay}}{{end}} + +
` + t := template.Must(template.New("post").Parse(tpl)) + _ = t.Execute(w, map[string]string{"ACS": acsURL, "Resp": samlResponseB64, "Relay": relay}) +} + +func randomState() string { + var b [16]byte + _, _ = rand.Read(b[:]) + return base64.RawURLEncoding.EncodeToString(b[:]) +} + +func parseAuthnRequest(r *http.Request) (*saml.AuthnRequest, string, string, error) { + relay := r.FormValue("RelayState") + if sr := r.URL.Query().Get("SAMLRequest"); sr != "" { + xmlBytes, err := base64.StdEncoding.DecodeString(sr) + if err != nil { + return nil, "", "", fmt.Errorf("b64: %w", err) + } + reader := flate.NewReader(strings.NewReader(string(xmlBytes))) + defer reader.Close() + var sb strings.Builder + if _, err := io.Copy(&sb, reader); err != nil { + return nil, "", "", fmt.Errorf("inflate: %w", err) + } + var req saml.AuthnRequest + if err := xml.Unmarshal([]byte(sb.String()), &req); err != nil { + return nil, "", "", fmt.Errorf("xml: %w", err) + } + sp := "" + if req.Issuer != nil { + sp = req.Issuer.Value + } + return &req, sp, relay, nil + } + if sr := r.FormValue("SAMLRequest"); sr != "" { + xmlBytes, err := base64.StdEncoding.DecodeString(sr) + if err != nil { + return nil, "", "", fmt.Errorf("b64: %w", err) + } + var req saml.AuthnRequest + if err := xml.Unmarshal(xmlBytes, &req); err != nil { + return nil, "", "", fmt.Errorf("xml: %w", err) + } + sp := "" + if req.Issuer != nil { + sp = req.Issuer.Value + } + return &req, sp, relay, nil + } + return nil, "", "", fmt.Errorf("missing SAMLRequest") +} diff --git a/internal/http/util.go b/internal/http/util.go new file mode 100644 index 0000000..51589e0 --- /dev/null +++ b/internal/http/util.go @@ -0,0 +1,21 @@ +package http + +func hasAnyGroup(user map[string]struct{}, want []string) bool { + if len(want) == 0 { + return false + } + for _, g := range want { + if _, ok := user[g]; ok { + return true + } + } + return false +} + +func toSet(ss []string) map[string]struct{} { + m := make(map[string]struct{}, len(ss)) + for _, s := range ss { + m[s] = struct{}{} + } + return m +} diff --git a/internal/oidc/client.go b/internal/oidc/client.go new file mode 100644 index 0000000..223bb34 --- /dev/null +++ b/internal/oidc/client.go @@ -0,0 +1,70 @@ +package oidc + +import ( + "context" + "fmt" + "net/url" + + gooidc "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + + "shamilnunhuck/saml-oidc-bridge/internal/config" +) + +type Client struct { + Verifier *gooidc.IDTokenVerifier + OAuth2 *oauth2.Config +} + +func NewClient(cfg *config.Config) (*Client, error) { + ctx := context.Background() + provider, err := gooidc.NewProvider(ctx, cfg.OIDC.Issuer) + if err != nil { + return nil, err + } + verifier := provider.Verifier(&gooidc.Config{ClientID: cfg.OIDC.ClientID}) + redirect := cfg.Server.ExternalURL + cfg.OIDC.RedirectPath + + scopes := []string{"openid"} + scopes = append(scopes, cfg.OIDC.Scopes...) + + oauth2cfg := &oauth2.Config{ + ClientID: cfg.OIDC.ClientID, + ClientSecret: cfg.OIDC.ClientSecret, + Endpoint: provider.Endpoint(), + Scopes: scopes, + RedirectURL: redirect, + } + return &Client{Verifier: verifier, OAuth2: oauth2cfg}, nil +} + +type Claims struct { + Subject string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + Groups []string `json:"groups"` +} + +func (c *Client) AuthCodeURL(state string, extra url.Values) string { + return c.OAuth2.AuthCodeURL(state) +} + +func (c *Client) ExchangeAndVerify(ctx context.Context, code string) (*Claims, error) { + token, err := c.OAuth2.Exchange(ctx, code) + if err != nil { + return nil, err + } + rawID, ok := token.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("no id_token in token response") + } + idt, err := c.Verifier.Verify(ctx, rawID) + if err != nil { + return nil, err + } + var cl Claims + if err := idt.Claims(&cl); err != nil { + return nil, err + } + return &cl, nil +} diff --git a/internal/saml/idp.go b/internal/saml/idp.go new file mode 100644 index 0000000..7752f1c --- /dev/null +++ b/internal/saml/idp.go @@ -0,0 +1,70 @@ +package saml + +import ( + "encoding/base64" + "time" + + "shamilnunhuck/saml-oidc-bridge/internal/config" + "shamilnunhuck/saml-oidc-bridge/internal/crypto" + + "github.com/crewjam/saml" +) + +type IdP struct { + keys *crypto.KeyStore + sec config.Security + entityID string + ssoURL string +} + +func NewIdP(cfg *config.Config, ks *crypto.KeyStore) *IdP { + return &IdP{ + keys: ks, + sec: cfg.Security, + entityID: cfg.Server.ExternalURL, + ssoURL: cfg.Server.ExternalURL + "/saml/sso", + } +} + +func (i *IdP) Metadata() *saml.EntityDescriptor { + // we need to publish all certs, to allow safe rotation + keyDescriptors := []saml.KeyDescriptor{} + + for _, der := range i.keys.AllCertsDER() { + keyDescriptors = append(keyDescriptors, saml.KeyDescriptor{ + Use: "signing", + KeyInfo: saml.KeyInfo{ + X509Data: saml.X509Data{ + X509Certificates: []saml.X509Certificate{{ + Data: base64.StdEncoding.EncodeToString(der), + }}, + }, + }, + }) + } + + entityDescriptor := &saml.EntityDescriptor{ + EntityID: i.entityID, + // crewjam expects time.Time for ValidUntil and time.Duration for CacheDuration + ValidUntil: time.Now().UTC().Add(time.Duration(i.sec.MetadataValidUntilDays) * 24 * time.Hour), + CacheDuration: time.Duration(i.sec.MetadataCacheDurationSeconds) * time.Second, + IDPSSODescriptors: []saml.IDPSSODescriptor{ + { + SSODescriptor: saml.SSODescriptor{ + RoleDescriptor: saml.RoleDescriptor{ + ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol", + KeyDescriptors: keyDescriptors, + }, + }, + SingleSignOnServices: []saml.Endpoint{ + { + Binding: saml.HTTPPostBinding, + Location: i.ssoURL, + }, + }, + }, + }, + } + + return entityDescriptor +} diff --git a/internal/saml/sign.go b/internal/saml/sign.go new file mode 100644 index 0000000..29392d9 --- /dev/null +++ b/internal/saml/sign.go @@ -0,0 +1,204 @@ +package saml + +import ( + "crypto" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/beevik/etree" + "github.com/crewjam/saml" + dsig "github.com/russellhaering/goxmldsig" + + "shamilnunhuck/saml-oidc-bridge/internal/config" +) + +const ( + subjectConfirmationMethodBearer = "urn:oasis:names:tc:SAML:2.0:cm:bearer" + nameIDFormatEntity = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity" + authnContextPasswordProtectedTransport = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport" + defaultAssertionTTL = 5 * time.Minute +) + +func (i *IdP) BuildResponse(sp config.SP, nameID string, attrs map[string][]string, inResponseTo string) (*saml.Response, error) { + now := saml.TimeNow() + ttl := time.Duration(i.sec.AssertionTTLSec) * time.Second + if ttl <= 0 { + ttl = defaultAssertionTTL + } + skew := time.Duration(i.sec.SkewSeconds) * time.Second + + assertionID, err := newSAMLID() + if err != nil { + return nil, fmt.Errorf("generate assertion id: %w", err) + } + responseID, err := newSAMLID() + if err != nil { + return nil, fmt.Errorf("generate response id: %w", err) + } + + audience := sp.Audience + if audience == "" { + audience = sp.EntityID + } + + nameIDFormat := sp.NameIDFormat + if nameIDFormat == "" { + nameIDFormat = string(saml.UnspecifiedNameIDFormat) + } + + assertion := &saml.Assertion{ + ID: assertionID, + IssueInstant: now, + Version: "2.0", + Issuer: saml.Issuer{ + Format: nameIDFormatEntity, + Value: i.entityID, + }, + Subject: &saml.Subject{ + NameID: &saml.NameID{ + Format: nameIDFormat, + Value: nameID, + }, + SubjectConfirmations: []saml.SubjectConfirmation{ + { + Method: subjectConfirmationMethodBearer, + SubjectConfirmationData: &saml.SubjectConfirmationData{ + InResponseTo: inResponseTo, + NotOnOrAfter: now.Add(ttl), + Recipient: sp.ACSURL, + }, + }, + }, + }, + Conditions: &saml.Conditions{ + NotBefore: now.Add(-skew), + NotOnOrAfter: now.Add(ttl), + AudienceRestrictions: []saml.AudienceRestriction{ + { + Audience: saml.Audience{Value: audience}, + }, + }, + }, + AuthnStatements: []saml.AuthnStatement{ + { + AuthnInstant: now, + SessionIndex: responseID, + AuthnContext: saml.AuthnContext{ + AuthnContextClassRef: &saml.AuthnContextClassRef{Value: authnContextPasswordProtectedTransport}, + }, + }, + }, + AttributeStatements: []saml.AttributeStatement{ + { + Attributes: toSAMLAttributes(attrs), + }, + }, + } + + resp := &saml.Response{ + ID: responseID, + Version: "2.0", + IssueInstant: now, + InResponseTo: inResponseTo, + Destination: sp.ACSURL, + Issuer: &saml.Issuer{ + Format: nameIDFormatEntity, + Value: i.entityID, + }, + Status: saml.Status{ + StatusCode: saml.StatusCode{Value: saml.StatusSuccess}, + }, + Assertion: assertion, + } + + if err := i.signResponse(resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (i *IdP) signingContext() (*dsig.SigningContext, error) { + keyPair := i.keys.Active() + if len(keyPair.Certificate) == 0 || keyPair.PrivateKey == nil { + return nil, errors.New("active key missing certificate or private key") + } + + ctx := dsig.NewDefaultSigningContext(dsig.TLSCertKeyStore(keyPair)) + ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("") + if err := ctx.SetSignatureMethod(dsig.RSASHA256SignatureMethod); err != nil { + return nil, err + } + ctx.Hash = crypto.SHA256 + return ctx, nil +} + +func (i *IdP) signResponse(resp *saml.Response) error { + if resp.Assertion == nil { + return errors.New("response missing assertion") + } + + assertionCtx, err := i.signingContext() + if err != nil { + return err + } + + assertionEl := resp.Assertion.Element() + signedAssertionEl, err := assertionCtx.SignEnveloped(assertionEl) + if err != nil { + return fmt.Errorf("sign assertion: %w", err) + } + sigEl, err := lastChildElement(signedAssertionEl) + if err != nil { + return fmt.Errorf("sign assertion: %w", err) + } + resp.Assertion.Signature = sigEl + + responseCtx, err := i.signingContext() + if err != nil { + return err + } + responseEl := resp.Element() + signedResponseEl, err := responseCtx.SignEnveloped(responseEl) + if err != nil { + return fmt.Errorf("sign response: %w", err) + } + sigEl, err = lastChildElement(signedResponseEl) + if err != nil { + return fmt.Errorf("sign response: %w", err) + } + resp.Signature = sigEl + return nil +} + +func MarshalSignedResponse(resp *saml.Response) ([]byte, error) { + if resp == nil { + return nil, errors.New("nil response") + } + if resp.Signature == nil { + return nil, errors.New("response not signed") + } + + doc := etree.NewDocument() + doc.SetRoot(resp.Element()) + return doc.WriteToBytes() +} + +func lastChildElement(parent *etree.Element) (*etree.Element, error) { + children := parent.ChildElements() + if len(children) == 0 { + return nil, errors.New("no child elements found") + } + return children[len(children)-1], nil +} + +func newSAMLID() (string, error) { + var b [20]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + return "_" + hex.EncodeToString(b[:]), nil +} diff --git a/internal/saml/util.go b/internal/saml/util.go new file mode 100644 index 0000000..24e0fd3 --- /dev/null +++ b/internal/saml/util.go @@ -0,0 +1,26 @@ +package saml + +import "github.com/crewjam/saml" + +const attrNameFormatURI = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri" + +func toSAMLAttributes(attrs map[string][]string) []saml.Attribute { + out := make([]saml.Attribute, 0, len(attrs)) + for name, vals := range attrs { + vs := make([]saml.AttributeValue, 0, len(vals)) + for _, s := range vals { + vs = append(vs, saml.AttributeValue{ + // explicitly mark string type so we never emit xsi:type="" + Type: "xs:string", + Value: s, + }) + } + out = append(out, saml.Attribute{ + FriendlyName: name, + Name: name, + NameFormat: attrNameFormatURI, + Values: vs, + }) + } + return out +}