initial commit

This commit is contained in:
Shamil Nunhuck
2025-11-08 10:18:19 +00:00
commit 920a79b2e9
25 changed files with 1523 additions and 0 deletions

41
.gitignore vendored Normal file
View File

@@ -0,0 +1,41 @@
# Binaries and build artifacts
bin/
build/
dist/
out/
tmp/
# Go compiler outputs
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
# Dependency/vendor directories
vendor/
# Coverage and profiling data
coverage.*
*.cov
*.coverprofile
*.pprof
*.prof
profile.out
# Local configs and secrets
.env
.env.*
*.local
*.secret.yaml
# IDE/editor state
.idea/
.vscode/
*.code-workspace
# OS junk
.DS_Store
Thumbs.db

13
Dockerfile Normal file
View File

@@ -0,0 +1,13 @@
FROM golang:1.22 AS build
WORKDIR /src
COPY go.mod ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 go build -o /out/broker ./cmd/broker
FROM gcr.io/distroless/static:nonroot
ENV CONFIG_PATH=/config/config.yaml
WORKDIR /
USER nonroot:nonroot
COPY --from=build /out/broker /broker
ENTRYPOINT ["/broker"]

19
Makefile Normal file
View File

@@ -0,0 +1,19 @@
APP?=saml-oidc-broker
PKG?=shamilnunhuck/saml-oidc-bridge
CONFIG?=example.config.yaml
KEY_ID?=k-$(shell date +%Y-%m)
build:
GO111MODULE=on CGO_ENABLED=0 go build -o bin/$(APP) ./cmd/broker
run:
CONFIG_PATH=$(CONFIG) bin/$(APP)
rotate-key:
bin/$(APP) cert -config $(CONFIG) -id $(KEY_ID) -algo rsa3072 -days 825 -cn id.example.com -org "YourOrg" -k8s-secret-out build/$(KEY_ID).secret.yaml
@echo "Wrote build/$(KEY_ID).secret.yaml"
docker:
docker build --platform linux/amd64 -t shamilnunhuck/$(APP):dev .
.PHONY: build run rotate-key docker

View File

@@ -0,0 +1,6 @@
apiVersion: v2
name: saml-broker
description: Minimal SAML IdP brokering to OIDC (Pocket ID) for Splunk
type: application
version: 0.1.0
appVersion: "0.1.0"

View File

@@ -0,0 +1,4 @@
1. Get the service URL by running these commands:
export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "saml-broker.name" . }}" -o jsonpath="{.items[0].metadata.name}")
kubectl port-forward $POD_NAME 8080:8080 &
echo "Visit http://127.0.0.1:8080/saml/metadata"

View File

@@ -0,0 +1,10 @@
{{- define "saml-broker.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" -}}
{{- end -}}
{{- define "saml-broker.fullname" -}}
{{- if .Values.fullnameOverride -}}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" -}}
{{- else -}}
{{- printf "%s" (include "saml-broker.name" .) | trunc 63 | trimSuffix "-" -}}
{{- end -}}
{{- end -}}

View File

@@ -0,0 +1,7 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ include "saml-broker.fullname" . }}-config
data:
config.yaml: |
{{ toYaml .Values.config | indent 4 }}

View File

@@ -0,0 +1,49 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "saml-broker.fullname" . }}
labels:
app.kubernetes.io/name: {{ include "saml-broker.name" . }}
spec:
replicas: 1
selector:
matchLabels:
app.kubernetes.io/name: {{ include "saml-broker.name" . }}
template:
metadata:
labels:
app.kubernetes.io/name: {{ include "saml-broker.name" . }}
spec:
containers:
- name: broker
image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
imagePullPolicy: {{ .Values.image.pullPolicy }}
env:
- name: CONFIG_PATH
value: /config/config.yaml
- name: OIDC_CLIENT_SECRET
valueFrom:
secretKeyRef:
name: {{ .Values.env.OIDC_CLIENT_SECRET_SECRET_NAME }}
key: {{ .Values.env.OIDC_CLIENT_SECRET_KEY }}
ports:
- name: http
containerPort: 8080
volumeMounts:
- name: cfg
mountPath: /config
readOnly: true
readinessProbe:
httpGet:
path: /healthz
port: http
livenessProbe:
httpGet:
path: /healthz
port: http
resources:
{{ toYaml .Values.resources | indent 12 }}
volumes:
- name: cfg
configMap:
name: {{ include "saml-broker.fullname" . }}-config

View File

@@ -0,0 +1,30 @@
{{- if .Values.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "saml-broker.fullname" . }}
{{- with .Values.ingress.className }}
annotations:
kubernetes.io/ingress.class: {{ . }}
{{- end }}
spec:
rules:
{{- range .Values.ingress.hosts }}
- host: {{ .host }}
http:
paths:
{{- range .paths }}
- path: {{ .path }}
pathType: {{ .pathType }}
backend:
service:
name: {{ include "saml-broker.fullname" $ }}
port:
number: {{ $.Values.service.port }}
{{- end }}
{{- end }}
{{- if .Values.ingress.tls }}
tls:
{{ toYaml .Values.ingress.tls | indent 4 }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,13 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "saml-broker.fullname" . }}
spec:
type: {{ .Values.service.type }}
ports:
- port: {{ .Values.service.port }}
targetPort: http
protocol: TCP
name: http
selector:
app.kubernetes.io/name: {{ include "saml-broker.name" . }}

View File

@@ -0,0 +1,63 @@
image:
repository: ghcr.io/your-org/broker
tag: dev
pullPolicy: IfNotPresent
service:
type: ClusterIP
port: 80
ingress:
enabled: false
className: ""
hosts:
- host: id.example.com
paths:
- path: /
pathType: Prefix
tls: []
resources: {}
env:
# OIDC client secret comes from a Secret
OIDC_CLIENT_SECRET_SECRET_NAME: oidc-secret
OIDC_CLIENT_SECRET_KEY: OIDC_CLIENT_SECRET
config:
# Paste example.config.yaml here (without private key if you mount keys via secret)
server:
listen: ":8080"
external_url: "https://id.example.com"
crypto:
active_key: "k-2025-09"
keys: []
oidc_upstream:
issuer: "https://pocket-id.example"
client_id: "your-client-id"
redirect_path: "/oidc/callback"
scopes: ["email","profile"]
sps:
- name: "splunk"
entity_id: "https://splunk.example"
acs_url: "https://splunk.example/saml/acs"
audience: "https://splunk.example"
nameid_format: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
attribute_mapping:
mail: "email"
realName: "name"
role: "role"
role_mapping:
admins: "admin"
power: "power"
"*": "user"
security:
skew_seconds: 120
assertion_ttl_seconds: 300
require_signed_authn_request: false
metadata_valid_until_days: 7
metadata_cache_duration_seconds: 86400
session:
cookie_name: "_saml_broker"
cookie_secure: true
cookie_domain: "id.example.com"

114
cmd/broker/main.go Normal file
View File

@@ -0,0 +1,114 @@
package main
import (
"log"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"github.com/fsnotify/fsnotify"
"shamilnunhuck/saml-oidc-bridge/internal/cli"
"shamilnunhuck/saml-oidc-bridge/internal/config"
"shamilnunhuck/saml-oidc-bridge/internal/crypto"
h "shamilnunhuck/saml-oidc-bridge/internal/http"
"shamilnunhuck/saml-oidc-bridge/internal/oidc"
"shamilnunhuck/saml-oidc-bridge/internal/saml"
)
type runtimeState struct {
mu sync.RWMutex
cfg *config.Config
ks *crypto.KeyStore
idp *saml.IdP
oidc *oidc.Client
}
func main() {
if len(os.Args) > 1 && os.Args[1] == "cert" {
if err := cli.RunCert(os.Args[2:]); err != nil {
log.Fatal(err)
}
return
}
cfgPath := os.Getenv("CONFIG_PATH")
if cfgPath == "" {
cfgPath = "example.config.yaml"
}
state := &runtimeState{}
load := func() {
cfg, err := config.Load(cfgPath)
if err != nil {
log.Fatalf("load config: %v", err)
}
if v := os.Getenv("OIDC_CLIENT_SECRET"); v != "" {
cfg.OIDC.ClientSecret = v
}
if err := cfg.Validate(); err != nil {
log.Fatalf("invalid config: %v", err)
}
ks, err := crypto.NewKeyStore(cfg.Crypto)
if err != nil {
log.Fatalf("keystore: %v", err)
}
idp := saml.NewIdP(cfg, ks)
oc, err := oidc.NewClient(cfg)
if err != nil {
log.Fatalf("oidc: %v", err)
}
state.mu.Lock()
state.cfg, state.ks, state.idp, state.oidc = cfg, ks, idp, oc
state.mu.Unlock()
log.Printf("loaded config; active signing key=%s", cfg.Crypto.ActiveKey)
}
load()
mux := http.NewServeMux()
h.Register(
mux,
func() *config.Config { state.mu.RLock(); defer state.mu.RUnlock(); return state.cfg },
func() *saml.IdP { state.mu.RLock(); defer state.mu.RUnlock(); return state.idp },
func() *oidc.Client { state.mu.RLock(); defer state.mu.RUnlock(); return state.oidc },
)
go func() {
log.Printf("listening on %s", state.cfg.Server.Listen)
log.Fatal(http.ListenAndServe(state.cfg.Server.Listen, mux))
}()
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGHUP)
go func() {
for range sigc {
load()
}
}()
w, err := fsnotify.NewWatcher()
if err == nil {
defer w.Close()
_ = w.Add(cfgPath)
go func() {
for {
select {
case e := <-w.Events:
if e.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 {
load()
}
case err := <-w.Errors:
if err != nil {
log.Printf("watch error: %v", err)
}
}
}
}()
}
select {}
}

43
example.config.yaml Normal file
View File

@@ -0,0 +1,43 @@
server:
listen: :8080
external_url: https://saml-v.ttt.net
crypto:
active_key: k-2025-12
keys:
- id: k-2025-12
cert_pem: |
...
key_pem: |
...
not_after: 2028-01-06T12:27:11.670644Z
oidc_upstream:
issuer: https://id.tt.net
client_id: 1ec56384
redirect_path: /oidc/callback
scopes:
- email
- profile
sps:
- name: splunk
entity_id: https://splunk.example
acs_url: https://splunk.example/saml/acs
audience: https://splunk.example
nameid_format: urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress
attribute_mapping:
mail: email
realName: name
role: role
role_mapping:
'*': user
admins: admin
power: power
security:
skew_seconds: 120
assertion_ttl_seconds: 300
require_signed_authn_request: false
metadata_valid_until_days: 7
metadata_cache_duration_seconds: 86400
session:
cookie_name: _saml_broker
cookie_secure: true
cookie_domain: saml-v.ttt.net

21
go.mod Normal file
View File

@@ -0,0 +1,21 @@
module shamilnunhuck/saml-oidc-bridge
go 1.22
require (
github.com/coreos/go-oidc/v3 v3.11.0
github.com/crewjam/saml v0.5.1
github.com/fsnotify/fsnotify v1.7.0
golang.org/x/oauth2 v0.23.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/beevik/etree v1.5.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/russellhaering/goxmldsig v1.4.0 // indirect
golang.org/x/crypto v0.33.0 // indirect
golang.org/x/sys v0.30.0 // indirect
)

54
go.sum Normal file
View File

@@ -0,0 +1,54 @@
github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs=
github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=
github.com/beevik/etree v1.5.0 h1:iaQZFSDS+3kYZiGoc9uKeOkUY3nYMXOKLl6KIJxiJWs=
github.com/beevik/etree v1.5.0/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs=
github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI=
github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/crewjam/saml v0.4.14 h1:g9FBNx62osKusnFzs3QTN5L9CVA/Egfgm+stJShzw/c=
github.com/crewjam/saml v0.4.14/go.mod h1:UVSZCf18jJkk6GpWNVqcyQJMD5HsRugBPf4I1nl2mME=
github.com/crewjam/saml v0.5.1 h1:g+mfp0CrLuLRZCK793PgJcZeg5dS/0CDwoeAX2zcwNI=
github.com/crewjam/saml v0.5.1/go.mod h1:r0fDkmFe5URDgPrmtH0IYokva6fac3AUdstiPhyEolQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ=
github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM=
github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw=
github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys=
github.com/russellhaering/goxmldsig v1.4.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

206
internal/cli/cert.go Normal file
View File

@@ -0,0 +1,206 @@
package cli
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"flag"
"fmt"
"math/big"
"os"
"strings"
"time"
"gopkg.in/yaml.v3"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
type rotateOpts struct {
ConfigPath string
ID string
Algo string
Days int
CN string
Org string
OutK8s string
ActiveOnly bool
}
func RunCert(args []string) error {
fs := flag.NewFlagSet("cert", flag.ContinueOnError)
var ro rotateOpts
fs.StringVar(&ro.ConfigPath, "config", "example.config.yaml", "path to config yaml")
fs.StringVar(&ro.ID, "id", "", "key id (e.g. k-2025-10)")
fs.StringVar(&ro.Algo, "algo", "rsa3072", "rsa2048|rsa3072|rsa4096|p256|p384")
fs.IntVar(&ro.Days, "days", 825, "validity in days")
fs.StringVar(&ro.CN, "cn", "id.example.com", "certificate CN")
fs.StringVar(&ro.Org, "org", "YourOrg", "certificate O")
fs.StringVar(&ro.OutK8s, "k8s-secret-out", "", "write a Kubernetes Secret manifest to this path")
fs.BoolVar(&ro.ActiveOnly, "active-only", false, "only set active_key to -id (no new cert)")
if err := fs.Parse(args); err != nil {
return err
}
if ro.ID == "" {
return errors.New("missing -id")
}
cfg, raw, err := loadConfig(ro.ConfigPath)
if err != nil {
return err
}
if ro.ActiveOnly {
cfg.Crypto.ActiveKey = ro.ID
return saveConfig(ro.ConfigPath, raw, cfg)
}
certPEM, keyPEM, notAfter, err := genSelfSigned(ro)
if err != nil {
return err
}
cfg.Crypto.Keys = append(cfg.Crypto.Keys, config.KeyPair{
ID: ro.ID,
CertPEM: string(certPEM),
KeyPEM: string(keyPEM),
NotAfter: notAfter.UTC(),
})
cfg.Crypto.ActiveKey = ro.ID
if err := saveConfig(ro.ConfigPath, raw, cfg); err != nil {
return err
}
if ro.OutK8s != "" {
if err := os.WriteFile(ro.OutK8s, []byte(k8sSecretYAML(ro, certPEM, keyPEM)), 0o600); err != nil {
return fmt.Errorf("write secret: %w", err)
}
}
fmt.Printf("OK: generated %s (algo=%s, not_after=%s) and set active_key\n", ro.ID, ro.Algo, notAfter.UTC().Format(time.RFC3339))
return nil
}
func genSelfSigned(ro rotateOpts) (certPEM, keyPEM []byte, notAfter time.Time, err error) {
nb := time.Now().Add(-5 * time.Minute)
na := nb.Add(time.Duration(ro.Days) * 24 * time.Hour)
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
template := x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: ro.CN,
Organization: []string{ro.Org},
OrganizationalUnit: []string{"SAML Signing"},
},
NotBefore: nb,
NotAfter: na,
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}
var der []byte
var keyPKCS8 []byte
switch strings.ToLower(ro.Algo) {
case "rsa2048", "rsa3072", "rsa4096":
bits := 2048
if ro.Algo == "rsa3072" {
bits = 3072
}
if ro.Algo == "rsa4096" {
bits = 4096
}
priv, e := rsa.GenerateKey(rand.Reader, bits)
if e != nil {
return nil, nil, time.Time{}, e
}
der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, time.Time{}, err
}
keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, nil, time.Time{}, err
}
case "p256", "p384":
curve := elliptic.P256()
if ro.Algo == "p384" {
curve = elliptic.P384()
}
priv, e := ecdsa.GenerateKey(curve, rand.Reader)
if e != nil {
return nil, nil, time.Time{}, e
}
der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, time.Time{}, err
}
keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, nil, time.Time{}, err
}
default:
return nil, nil, time.Time{}, fmt.Errorf("unknown -algo %q", ro.Algo)
}
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyPKCS8})
return certPEM, keyPEM, na, nil
}
func loadConfig(path string) (*config.Config, *yaml.Node, error) {
b, err := os.ReadFile(path)
if err != nil {
return nil, nil, err
}
var root yaml.Node
if err := yaml.Unmarshal(b, &root); err != nil {
return nil, nil, err
}
var c config.Config
if err := root.Decode(&c); err != nil {
return nil, nil, err
}
return &c, &root, nil
}
func saveConfig(path string, _ *yaml.Node, c *config.Config) error {
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
enc.SetIndent(2)
if err := enc.Encode(c); err != nil {
return err
}
_ = enc.Close()
return os.WriteFile(path, buf.Bytes(), 0o644)
}
func k8sSecretYAML(ro rotateOpts, certPEM, keyPEM []byte) string {
name := strings.ToLower(strings.ReplaceAll(ro.ID, "_", "-"))
return fmt.Sprintf(`apiVersion: v1
kind: Secret
metadata:
name: saml-signing-%s
type: Opaque
stringData:
cert.pem: |-
%s
key.pem: |-
%s
`, name, indent(string(certPEM), 4), indent(string(keyPEM), 4))
}
func indent(s string, n int) string {
pad := strings.Repeat(" ", n)
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
for i := range lines {
lines[i] = pad + lines[i]
}
return strings.Join(lines, "\n")
}

36
internal/config/config.go Normal file
View File

@@ -0,0 +1,36 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
func Load(path string) (*Config, error) {
b, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var c Config
if err := yaml.Unmarshal(b, &c); err != nil {
return nil, err
}
return &c, nil
}
func (c *Config) Validate() error {
if c.Server.ExternalURL == "" || c.Server.Listen == "" {
return fmt.Errorf("server.external_url and server.listen required")
}
if len(c.SPs) == 0 {
return fmt.Errorf("at least one SP required")
}
if c.OIDC.Issuer == "" || c.OIDC.ClientID == "" || c.OIDC.RedirectPath == "" {
return fmt.Errorf("oidc issuer/client_id/redirect_path required")
}
if c.Crypto.ActiveKey == "" || len(c.Crypto.Keys) == 0 {
return fmt.Errorf("crypto.active_key and at least one key required")
}
return nil
}

68
internal/config/types.go Normal file
View File

@@ -0,0 +1,68 @@
package config
import "time"
type Server struct {
Listen string `yaml:"listen"`
ExternalURL string `yaml:"external_url"`
}
type KeyPair struct {
ID string `yaml:"id"`
CertPEM string `yaml:"cert_pem"`
KeyPEM string `yaml:"key_pem"`
NotAfter time.Time `yaml:"not_after"`
}
type Crypto struct {
ActiveKey string `yaml:"active_key"`
Keys []KeyPair `yaml:"keys"`
}
type OIDC struct {
Issuer string `yaml:"issuer"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"-"`
RedirectPath string `yaml:"redirect_path"`
Scopes []string `yaml:"scopes"`
}
type SP struct {
Name string `yaml:"name"`
EntityID string `yaml:"entity_id"`
ACSURL string `yaml:"acs_url"`
Audience string `yaml:"audience"`
NameIDFormat string `yaml:"nameid_format"`
AttributeMapping map[string]string `yaml:"attribute_mapping"`
RoleMapping map[string]string `yaml:"role_mapping"`
AttributeRules []AttributeRule `yaml:"attribute_rules"`
}
type Security struct {
SkewSeconds int `yaml:"skew_seconds"`
AssertionTTLSec int `yaml:"assertion_ttl_seconds"`
RequireSignedAuthnRequest bool `yaml:"require_signed_authn_request"`
MetadataValidUntilDays int `yaml:"metadata_valid_until_days"`
MetadataCacheDurationSeconds int `yaml:"metadata_cache_duration_seconds"`
}
type Session struct {
CookieName string `yaml:"cookie_name"`
CookieSecure bool `yaml:"cookie_secure"`
CookieDomain string `yaml:"cookie_domain"`
}
type Config struct {
Server Server `yaml:"server"`
Crypto Crypto `yaml:"crypto"`
OIDC OIDC `yaml:"oidc_upstream"`
SPs []SP `yaml:"sps"`
Security Security `yaml:"security"`
Session Session `yaml:"session"`
}
type AttributeRule struct {
Name string `yaml:"name"`
Value string `yaml:"value"`
IfGroupsAny []string `yaml:"if_groups_any"`
EmitWhenFalse bool `yaml:"emit_when_false"`
}

View File

@@ -0,0 +1,74 @@
package crypto
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
type KeyStore struct {
activeID string
signers map[string]tls.Certificate
certsDER map[string][]byte
}
func NewKeyStore(c config.Crypto) (*KeyStore, error) {
ks := &KeyStore{
activeID: c.ActiveKey,
signers: map[string]tls.Certificate{},
certsDER: map[string][]byte{},
}
for _, k := range c.Keys {
cert, priv, err := parseKeypair(k.CertPEM, k.KeyPEM)
if err != nil {
return nil, fmt.Errorf("key %s: %w", k.ID, err)
}
ks.certsDER[k.ID] = cert.Raw
if priv != nil {
ks.signers[k.ID] = tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: priv}
}
}
if _, ok := ks.signers[ks.activeID]; !ok {
return nil, errors.New("active signing key not available (missing or no private key)")
}
return ks, nil
}
func (ks *KeyStore) Active() tls.Certificate { return ks.signers[ks.activeID] }
func (ks *KeyStore) AllCertsDER() [][]byte {
out := make([][]byte, 0, len(ks.certsDER))
for _, der := range ks.certsDER {
out = append(out, der)
}
return out
}
func parseKeypair(certPEM, keyPEM string) (*x509.Certificate, interface{}, error) {
cb, _ := pem.Decode([]byte(certPEM))
if cb == nil {
return nil, nil, errors.New("invalid cert pem")
}
cert, err := x509.ParseCertificate(cb.Bytes)
if err != nil {
return nil, nil, err
}
var priv interface{}
if keyPEM != "" {
kb, _ := pem.Decode([]byte(keyPEM))
if kb == nil {
return nil, nil, errors.New("invalid key pem")
}
priv, err = x509.ParsePKCS8PrivateKey(kb.Bytes)
if err != nil {
priv, err = x509.ParsePKCS1PrivateKey(kb.Bytes)
if err != nil {
return nil, nil, err
}
}
}
return cert, priv, nil
}

261
internal/http/handlers.go Normal file
View File

@@ -0,0 +1,261 @@
package http
import (
"compress/flate"
"context"
"crypto/rand"
"encoding/base64"
"encoding/xml"
"fmt"
"html/template"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/crewjam/saml"
"shamilnunhuck/saml-oidc-bridge/internal/config"
"shamilnunhuck/saml-oidc-bridge/internal/oidc"
idsaml "shamilnunhuck/saml-oidc-bridge/internal/saml"
)
type IdP interface {
Metadata() *saml.EntityDescriptor
BuildResponse(sp config.SP, nameID string, attrs map[string][]string) (*saml.Response, error)
}
func Register(
mux *http.ServeMux,
getCfg func() *config.Config,
getIdP func() *idsaml.IdP,
getOIDC func() *oidc.Client,
) {
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
_, _ = w.Write([]byte("ok"))
})
mux.HandleFunc("/saml/metadata", func(w http.ResponseWriter, r *http.Request) {
meta := getIdP().Metadata()
buf, err := xml.MarshalIndent(meta, "", " ")
if err != nil {
http.Error(w, "marshal metadata: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/samlmetadata+xml")
_, _ = w.Write([]byte(xml.Header))
_, _ = w.Write(buf)
})
mux.HandleFunc("/saml/sso", func(w http.ResponseWriter, r *http.Request) {
req, spEntityID, relay, err := parseAuthnRequest(r)
log.Printf("AuthnRequest from SP=%s requestID=%s relay=%q", spEntityID, req.ID, relay)
if err != nil {
http.Error(w, "bad authn request: "+err.Error(), 400)
return
}
setStateCookie(w, getCfg().Session, spEntityID, relay, req.ID)
state := randomState()
http.Redirect(w, r, getOIDC().AuthCodeURL(state, url.Values{}), http.StatusFound)
})
mux.HandleFunc(getCfg().OIDC.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
code := r.URL.Query().Get("code")
if code == "" {
http.Error(w, "missing code", 400)
return
}
claims, err := getOIDC().ExchangeAndVerify(ctx, code)
if err != nil {
http.Error(w, "oidc: "+err.Error(), 400)
return
}
s, err := readStateCookie(r, getCfg().Session)
if err != nil {
http.Error(w, "state missing", 400)
return
}
sp := lookupSP(getCfg(), s.SPEntityID)
if sp == nil {
http.Error(w, "unknown SP", 400)
return
}
attrs := map[string][]string{}
for samlAttr, oidcClaim := range sp.AttributeMapping {
switch oidcClaim {
case "email":
attrs[samlAttr] = []string{claims.Email}
case "name":
attrs[samlAttr] = []string{claims.Name}
case "role":
attrs[samlAttr] = []string{mapRole(claims.Groups, sp)}
}
}
nameID := claims.Email
groups := claims.Groups
// Apply generic conditional attribute rules
userGroups := toSet(groups) // []string -> set
for _, rule := range sp.AttributeRules {
if hasAnyGroup(userGroups, rule.IfGroupsAny) {
if rule.Value == "" { // safe default
attrs[rule.Name] = []string{"true"}
} else {
attrs[rule.Name] = []string{rule.Value}
}
} else if rule.EmitWhenFalse {
attrs[rule.Name] = []string{"false"}
}
}
resp, err := getIdP().BuildResponse(*sp, nameID, attrs, s.RequestID)
if err != nil {
http.Error(w, "saml: "+err.Error(), 500)
return
}
xmlBytes, err := idsaml.MarshalSignedResponse(resp)
if err != nil {
http.Error(w, "sign: "+err.Error(), 500)
return
}
postToACS(w, sp.ACSURL, base64.StdEncoding.EncodeToString(xmlBytes), s.RelayState)
})
}
/*** helpers ***/
type state struct {
SPEntityID string
RelayState string
RequestID string
Expiry time.Time
}
func setStateCookie(w http.ResponseWriter, s config.Session, spEntityID, relay, reqID string) {
v := url.Values{}
v.Set("sp", spEntityID)
v.Set("rs", relay)
v.Set("rid", reqID)
c := &http.Cookie{
Name: s.CookieName,
Value: base64.RawURLEncoding.EncodeToString([]byte(v.Encode())),
Path: "/",
Domain: s.CookieDomain,
HttpOnly: true,
Secure: s.CookieSecure,
MaxAge: 600,
}
http.SetCookie(w, c)
}
func readStateCookie(r *http.Request, s config.Session) (*state, error) {
c, err := r.Cookie(s.CookieName)
if err != nil {
return nil, err
}
b, err := base64.RawURLEncoding.DecodeString(c.Value)
if err != nil {
return nil, err
}
v, err := url.ParseQuery(string(b))
if err != nil {
return nil, err
}
return &state{
SPEntityID: v.Get("sp"),
RelayState: v.Get("rs"),
RequestID: v.Get("rid"),
Expiry: time.Now().Add(10 * time.Minute),
}, nil
}
func lookupSP(cfg *config.Config, entityID string) *config.SP {
for i := range cfg.SPs {
if cfg.SPs[i].EntityID == entityID {
return &cfg.SPs[i]
}
}
return nil
}
func mapRole(groups []string, sp *config.SP) string {
for _, g := range groups {
if v, ok := sp.RoleMapping[g]; ok {
return v
}
}
if v, ok := sp.RoleMapping["*"]; ok {
return v
}
return "user"
}
func postToACS(w http.ResponseWriter, acsURL string, samlResponseB64 string, relay string) {
const tpl = `<!doctype html>
<html><body onload="document.forms[0].submit()">
<form method="post" action="{{.ACS}}">
<input type="hidden" name="SAMLResponse" value="{{.Resp}}">
{{if .Relay}}<input type="hidden" name="RelayState" value="{{.Relay}}">{{end}}
<noscript><button type="submit">Continue</button></noscript>
</form></body></html>`
t := template.Must(template.New("post").Parse(tpl))
_ = t.Execute(w, map[string]string{"ACS": acsURL, "Resp": samlResponseB64, "Relay": relay})
}
func randomState() string {
var b [16]byte
_, _ = rand.Read(b[:])
return base64.RawURLEncoding.EncodeToString(b[:])
}
func parseAuthnRequest(r *http.Request) (*saml.AuthnRequest, string, string, error) {
relay := r.FormValue("RelayState")
if sr := r.URL.Query().Get("SAMLRequest"); sr != "" {
xmlBytes, err := base64.StdEncoding.DecodeString(sr)
if err != nil {
return nil, "", "", fmt.Errorf("b64: %w", err)
}
reader := flate.NewReader(strings.NewReader(string(xmlBytes)))
defer reader.Close()
var sb strings.Builder
if _, err := io.Copy(&sb, reader); err != nil {
return nil, "", "", fmt.Errorf("inflate: %w", err)
}
var req saml.AuthnRequest
if err := xml.Unmarshal([]byte(sb.String()), &req); err != nil {
return nil, "", "", fmt.Errorf("xml: %w", err)
}
sp := ""
if req.Issuer != nil {
sp = req.Issuer.Value
}
return &req, sp, relay, nil
}
if sr := r.FormValue("SAMLRequest"); sr != "" {
xmlBytes, err := base64.StdEncoding.DecodeString(sr)
if err != nil {
return nil, "", "", fmt.Errorf("b64: %w", err)
}
var req saml.AuthnRequest
if err := xml.Unmarshal(xmlBytes, &req); err != nil {
return nil, "", "", fmt.Errorf("xml: %w", err)
}
sp := ""
if req.Issuer != nil {
sp = req.Issuer.Value
}
return &req, sp, relay, nil
}
return nil, "", "", fmt.Errorf("missing SAMLRequest")
}

21
internal/http/util.go Normal file
View File

@@ -0,0 +1,21 @@
package http
func hasAnyGroup(user map[string]struct{}, want []string) bool {
if len(want) == 0 {
return false
}
for _, g := range want {
if _, ok := user[g]; ok {
return true
}
}
return false
}
func toSet(ss []string) map[string]struct{} {
m := make(map[string]struct{}, len(ss))
for _, s := range ss {
m[s] = struct{}{}
}
return m
}

70
internal/oidc/client.go Normal file
View File

@@ -0,0 +1,70 @@
package oidc
import (
"context"
"fmt"
"net/url"
gooidc "github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
type Client struct {
Verifier *gooidc.IDTokenVerifier
OAuth2 *oauth2.Config
}
func NewClient(cfg *config.Config) (*Client, error) {
ctx := context.Background()
provider, err := gooidc.NewProvider(ctx, cfg.OIDC.Issuer)
if err != nil {
return nil, err
}
verifier := provider.Verifier(&gooidc.Config{ClientID: cfg.OIDC.ClientID})
redirect := cfg.Server.ExternalURL + cfg.OIDC.RedirectPath
scopes := []string{"openid"}
scopes = append(scopes, cfg.OIDC.Scopes...)
oauth2cfg := &oauth2.Config{
ClientID: cfg.OIDC.ClientID,
ClientSecret: cfg.OIDC.ClientSecret,
Endpoint: provider.Endpoint(),
Scopes: scopes,
RedirectURL: redirect,
}
return &Client{Verifier: verifier, OAuth2: oauth2cfg}, nil
}
type Claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
Groups []string `json:"groups"`
}
func (c *Client) AuthCodeURL(state string, extra url.Values) string {
return c.OAuth2.AuthCodeURL(state)
}
func (c *Client) ExchangeAndVerify(ctx context.Context, code string) (*Claims, error) {
token, err := c.OAuth2.Exchange(ctx, code)
if err != nil {
return nil, err
}
rawID, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("no id_token in token response")
}
idt, err := c.Verifier.Verify(ctx, rawID)
if err != nil {
return nil, err
}
var cl Claims
if err := idt.Claims(&cl); err != nil {
return nil, err
}
return &cl, nil
}

70
internal/saml/idp.go Normal file
View File

@@ -0,0 +1,70 @@
package saml
import (
"encoding/base64"
"time"
"shamilnunhuck/saml-oidc-bridge/internal/config"
"shamilnunhuck/saml-oidc-bridge/internal/crypto"
"github.com/crewjam/saml"
)
type IdP struct {
keys *crypto.KeyStore
sec config.Security
entityID string
ssoURL string
}
func NewIdP(cfg *config.Config, ks *crypto.KeyStore) *IdP {
return &IdP{
keys: ks,
sec: cfg.Security,
entityID: cfg.Server.ExternalURL,
ssoURL: cfg.Server.ExternalURL + "/saml/sso",
}
}
func (i *IdP) Metadata() *saml.EntityDescriptor {
// we need to publish all certs, to allow safe rotation
keyDescriptors := []saml.KeyDescriptor{}
for _, der := range i.keys.AllCertsDER() {
keyDescriptors = append(keyDescriptors, saml.KeyDescriptor{
Use: "signing",
KeyInfo: saml.KeyInfo{
X509Data: saml.X509Data{
X509Certificates: []saml.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(der),
}},
},
},
})
}
entityDescriptor := &saml.EntityDescriptor{
EntityID: i.entityID,
// crewjam expects time.Time for ValidUntil and time.Duration for CacheDuration
ValidUntil: time.Now().UTC().Add(time.Duration(i.sec.MetadataValidUntilDays) * 24 * time.Hour),
CacheDuration: time.Duration(i.sec.MetadataCacheDurationSeconds) * time.Second,
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
SSODescriptor: saml.SSODescriptor{
RoleDescriptor: saml.RoleDescriptor{
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
KeyDescriptors: keyDescriptors,
},
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPPostBinding,
Location: i.ssoURL,
},
},
},
},
}
return entityDescriptor
}

204
internal/saml/sign.go Normal file
View File

@@ -0,0 +1,204 @@
package saml
import (
"crypto"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
const (
subjectConfirmationMethodBearer = "urn:oasis:names:tc:SAML:2.0:cm:bearer"
nameIDFormatEntity = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"
authnContextPasswordProtectedTransport = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
defaultAssertionTTL = 5 * time.Minute
)
func (i *IdP) BuildResponse(sp config.SP, nameID string, attrs map[string][]string, inResponseTo string) (*saml.Response, error) {
now := saml.TimeNow()
ttl := time.Duration(i.sec.AssertionTTLSec) * time.Second
if ttl <= 0 {
ttl = defaultAssertionTTL
}
skew := time.Duration(i.sec.SkewSeconds) * time.Second
assertionID, err := newSAMLID()
if err != nil {
return nil, fmt.Errorf("generate assertion id: %w", err)
}
responseID, err := newSAMLID()
if err != nil {
return nil, fmt.Errorf("generate response id: %w", err)
}
audience := sp.Audience
if audience == "" {
audience = sp.EntityID
}
nameIDFormat := sp.NameIDFormat
if nameIDFormat == "" {
nameIDFormat = string(saml.UnspecifiedNameIDFormat)
}
assertion := &saml.Assertion{
ID: assertionID,
IssueInstant: now,
Version: "2.0",
Issuer: saml.Issuer{
Format: nameIDFormatEntity,
Value: i.entityID,
},
Subject: &saml.Subject{
NameID: &saml.NameID{
Format: nameIDFormat,
Value: nameID,
},
SubjectConfirmations: []saml.SubjectConfirmation{
{
Method: subjectConfirmationMethodBearer,
SubjectConfirmationData: &saml.SubjectConfirmationData{
InResponseTo: inResponseTo,
NotOnOrAfter: now.Add(ttl),
Recipient: sp.ACSURL,
},
},
},
},
Conditions: &saml.Conditions{
NotBefore: now.Add(-skew),
NotOnOrAfter: now.Add(ttl),
AudienceRestrictions: []saml.AudienceRestriction{
{
Audience: saml.Audience{Value: audience},
},
},
},
AuthnStatements: []saml.AuthnStatement{
{
AuthnInstant: now,
SessionIndex: responseID,
AuthnContext: saml.AuthnContext{
AuthnContextClassRef: &saml.AuthnContextClassRef{Value: authnContextPasswordProtectedTransport},
},
},
},
AttributeStatements: []saml.AttributeStatement{
{
Attributes: toSAMLAttributes(attrs),
},
},
}
resp := &saml.Response{
ID: responseID,
Version: "2.0",
IssueInstant: now,
InResponseTo: inResponseTo,
Destination: sp.ACSURL,
Issuer: &saml.Issuer{
Format: nameIDFormatEntity,
Value: i.entityID,
},
Status: saml.Status{
StatusCode: saml.StatusCode{Value: saml.StatusSuccess},
},
Assertion: assertion,
}
if err := i.signResponse(resp); err != nil {
return nil, err
}
return resp, nil
}
func (i *IdP) signingContext() (*dsig.SigningContext, error) {
keyPair := i.keys.Active()
if len(keyPair.Certificate) == 0 || keyPair.PrivateKey == nil {
return nil, errors.New("active key missing certificate or private key")
}
ctx := dsig.NewDefaultSigningContext(dsig.TLSCertKeyStore(keyPair))
ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("")
if err := ctx.SetSignatureMethod(dsig.RSASHA256SignatureMethod); err != nil {
return nil, err
}
ctx.Hash = crypto.SHA256
return ctx, nil
}
func (i *IdP) signResponse(resp *saml.Response) error {
if resp.Assertion == nil {
return errors.New("response missing assertion")
}
assertionCtx, err := i.signingContext()
if err != nil {
return err
}
assertionEl := resp.Assertion.Element()
signedAssertionEl, err := assertionCtx.SignEnveloped(assertionEl)
if err != nil {
return fmt.Errorf("sign assertion: %w", err)
}
sigEl, err := lastChildElement(signedAssertionEl)
if err != nil {
return fmt.Errorf("sign assertion: %w", err)
}
resp.Assertion.Signature = sigEl
responseCtx, err := i.signingContext()
if err != nil {
return err
}
responseEl := resp.Element()
signedResponseEl, err := responseCtx.SignEnveloped(responseEl)
if err != nil {
return fmt.Errorf("sign response: %w", err)
}
sigEl, err = lastChildElement(signedResponseEl)
if err != nil {
return fmt.Errorf("sign response: %w", err)
}
resp.Signature = sigEl
return nil
}
func MarshalSignedResponse(resp *saml.Response) ([]byte, error) {
if resp == nil {
return nil, errors.New("nil response")
}
if resp.Signature == nil {
return nil, errors.New("response not signed")
}
doc := etree.NewDocument()
doc.SetRoot(resp.Element())
return doc.WriteToBytes()
}
func lastChildElement(parent *etree.Element) (*etree.Element, error) {
children := parent.ChildElements()
if len(children) == 0 {
return nil, errors.New("no child elements found")
}
return children[len(children)-1], nil
}
func newSAMLID() (string, error) {
var b [20]byte
if _, err := rand.Read(b[:]); err != nil {
return "", err
}
return "_" + hex.EncodeToString(b[:]), nil
}

26
internal/saml/util.go Normal file
View File

@@ -0,0 +1,26 @@
package saml
import "github.com/crewjam/saml"
const attrNameFormatURI = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
func toSAMLAttributes(attrs map[string][]string) []saml.Attribute {
out := make([]saml.Attribute, 0, len(attrs))
for name, vals := range attrs {
vs := make([]saml.AttributeValue, 0, len(vals))
for _, s := range vals {
vs = append(vs, saml.AttributeValue{
// explicitly mark string type so we never emit xsi:type=""
Type: "xs:string",
Value: s,
})
}
out = append(out, saml.Attribute{
FriendlyName: name,
Name: name,
NameFormat: attrNameFormatURI,
Values: vs,
})
}
return out
}