initial commit

This commit is contained in:
Shamil Nunhuck
2025-11-08 10:18:19 +00:00
commit 920a79b2e9
25 changed files with 1523 additions and 0 deletions

206
internal/cli/cert.go Normal file
View File

@@ -0,0 +1,206 @@
package cli
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"flag"
"fmt"
"math/big"
"os"
"strings"
"time"
"gopkg.in/yaml.v3"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
type rotateOpts struct {
ConfigPath string
ID string
Algo string
Days int
CN string
Org string
OutK8s string
ActiveOnly bool
}
func RunCert(args []string) error {
fs := flag.NewFlagSet("cert", flag.ContinueOnError)
var ro rotateOpts
fs.StringVar(&ro.ConfigPath, "config", "example.config.yaml", "path to config yaml")
fs.StringVar(&ro.ID, "id", "", "key id (e.g. k-2025-10)")
fs.StringVar(&ro.Algo, "algo", "rsa3072", "rsa2048|rsa3072|rsa4096|p256|p384")
fs.IntVar(&ro.Days, "days", 825, "validity in days")
fs.StringVar(&ro.CN, "cn", "id.example.com", "certificate CN")
fs.StringVar(&ro.Org, "org", "YourOrg", "certificate O")
fs.StringVar(&ro.OutK8s, "k8s-secret-out", "", "write a Kubernetes Secret manifest to this path")
fs.BoolVar(&ro.ActiveOnly, "active-only", false, "only set active_key to -id (no new cert)")
if err := fs.Parse(args); err != nil {
return err
}
if ro.ID == "" {
return errors.New("missing -id")
}
cfg, raw, err := loadConfig(ro.ConfigPath)
if err != nil {
return err
}
if ro.ActiveOnly {
cfg.Crypto.ActiveKey = ro.ID
return saveConfig(ro.ConfigPath, raw, cfg)
}
certPEM, keyPEM, notAfter, err := genSelfSigned(ro)
if err != nil {
return err
}
cfg.Crypto.Keys = append(cfg.Crypto.Keys, config.KeyPair{
ID: ro.ID,
CertPEM: string(certPEM),
KeyPEM: string(keyPEM),
NotAfter: notAfter.UTC(),
})
cfg.Crypto.ActiveKey = ro.ID
if err := saveConfig(ro.ConfigPath, raw, cfg); err != nil {
return err
}
if ro.OutK8s != "" {
if err := os.WriteFile(ro.OutK8s, []byte(k8sSecretYAML(ro, certPEM, keyPEM)), 0o600); err != nil {
return fmt.Errorf("write secret: %w", err)
}
}
fmt.Printf("OK: generated %s (algo=%s, not_after=%s) and set active_key\n", ro.ID, ro.Algo, notAfter.UTC().Format(time.RFC3339))
return nil
}
func genSelfSigned(ro rotateOpts) (certPEM, keyPEM []byte, notAfter time.Time, err error) {
nb := time.Now().Add(-5 * time.Minute)
na := nb.Add(time.Duration(ro.Days) * 24 * time.Hour)
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
template := x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: ro.CN,
Organization: []string{ro.Org},
OrganizationalUnit: []string{"SAML Signing"},
},
NotBefore: nb,
NotAfter: na,
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}
var der []byte
var keyPKCS8 []byte
switch strings.ToLower(ro.Algo) {
case "rsa2048", "rsa3072", "rsa4096":
bits := 2048
if ro.Algo == "rsa3072" {
bits = 3072
}
if ro.Algo == "rsa4096" {
bits = 4096
}
priv, e := rsa.GenerateKey(rand.Reader, bits)
if e != nil {
return nil, nil, time.Time{}, e
}
der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, time.Time{}, err
}
keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, nil, time.Time{}, err
}
case "p256", "p384":
curve := elliptic.P256()
if ro.Algo == "p384" {
curve = elliptic.P384()
}
priv, e := ecdsa.GenerateKey(curve, rand.Reader)
if e != nil {
return nil, nil, time.Time{}, e
}
der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, time.Time{}, err
}
keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, nil, time.Time{}, err
}
default:
return nil, nil, time.Time{}, fmt.Errorf("unknown -algo %q", ro.Algo)
}
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyPKCS8})
return certPEM, keyPEM, na, nil
}
func loadConfig(path string) (*config.Config, *yaml.Node, error) {
b, err := os.ReadFile(path)
if err != nil {
return nil, nil, err
}
var root yaml.Node
if err := yaml.Unmarshal(b, &root); err != nil {
return nil, nil, err
}
var c config.Config
if err := root.Decode(&c); err != nil {
return nil, nil, err
}
return &c, &root, nil
}
func saveConfig(path string, _ *yaml.Node, c *config.Config) error {
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
enc.SetIndent(2)
if err := enc.Encode(c); err != nil {
return err
}
_ = enc.Close()
return os.WriteFile(path, buf.Bytes(), 0o644)
}
func k8sSecretYAML(ro rotateOpts, certPEM, keyPEM []byte) string {
name := strings.ToLower(strings.ReplaceAll(ro.ID, "_", "-"))
return fmt.Sprintf(`apiVersion: v1
kind: Secret
metadata:
name: saml-signing-%s
type: Opaque
stringData:
cert.pem: |-
%s
key.pem: |-
%s
`, name, indent(string(certPEM), 4), indent(string(keyPEM), 4))
}
func indent(s string, n int) string {
pad := strings.Repeat(" ", n)
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
for i := range lines {
lines[i] = pad + lines[i]
}
return strings.Join(lines, "\n")
}

36
internal/config/config.go Normal file
View File

@@ -0,0 +1,36 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
func Load(path string) (*Config, error) {
b, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var c Config
if err := yaml.Unmarshal(b, &c); err != nil {
return nil, err
}
return &c, nil
}
func (c *Config) Validate() error {
if c.Server.ExternalURL == "" || c.Server.Listen == "" {
return fmt.Errorf("server.external_url and server.listen required")
}
if len(c.SPs) == 0 {
return fmt.Errorf("at least one SP required")
}
if c.OIDC.Issuer == "" || c.OIDC.ClientID == "" || c.OIDC.RedirectPath == "" {
return fmt.Errorf("oidc issuer/client_id/redirect_path required")
}
if c.Crypto.ActiveKey == "" || len(c.Crypto.Keys) == 0 {
return fmt.Errorf("crypto.active_key and at least one key required")
}
return nil
}

68
internal/config/types.go Normal file
View File

@@ -0,0 +1,68 @@
package config
import "time"
type Server struct {
Listen string `yaml:"listen"`
ExternalURL string `yaml:"external_url"`
}
type KeyPair struct {
ID string `yaml:"id"`
CertPEM string `yaml:"cert_pem"`
KeyPEM string `yaml:"key_pem"`
NotAfter time.Time `yaml:"not_after"`
}
type Crypto struct {
ActiveKey string `yaml:"active_key"`
Keys []KeyPair `yaml:"keys"`
}
type OIDC struct {
Issuer string `yaml:"issuer"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"-"`
RedirectPath string `yaml:"redirect_path"`
Scopes []string `yaml:"scopes"`
}
type SP struct {
Name string `yaml:"name"`
EntityID string `yaml:"entity_id"`
ACSURL string `yaml:"acs_url"`
Audience string `yaml:"audience"`
NameIDFormat string `yaml:"nameid_format"`
AttributeMapping map[string]string `yaml:"attribute_mapping"`
RoleMapping map[string]string `yaml:"role_mapping"`
AttributeRules []AttributeRule `yaml:"attribute_rules"`
}
type Security struct {
SkewSeconds int `yaml:"skew_seconds"`
AssertionTTLSec int `yaml:"assertion_ttl_seconds"`
RequireSignedAuthnRequest bool `yaml:"require_signed_authn_request"`
MetadataValidUntilDays int `yaml:"metadata_valid_until_days"`
MetadataCacheDurationSeconds int `yaml:"metadata_cache_duration_seconds"`
}
type Session struct {
CookieName string `yaml:"cookie_name"`
CookieSecure bool `yaml:"cookie_secure"`
CookieDomain string `yaml:"cookie_domain"`
}
type Config struct {
Server Server `yaml:"server"`
Crypto Crypto `yaml:"crypto"`
OIDC OIDC `yaml:"oidc_upstream"`
SPs []SP `yaml:"sps"`
Security Security `yaml:"security"`
Session Session `yaml:"session"`
}
type AttributeRule struct {
Name string `yaml:"name"`
Value string `yaml:"value"`
IfGroupsAny []string `yaml:"if_groups_any"`
EmitWhenFalse bool `yaml:"emit_when_false"`
}

View File

@@ -0,0 +1,74 @@
package crypto
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
type KeyStore struct {
activeID string
signers map[string]tls.Certificate
certsDER map[string][]byte
}
func NewKeyStore(c config.Crypto) (*KeyStore, error) {
ks := &KeyStore{
activeID: c.ActiveKey,
signers: map[string]tls.Certificate{},
certsDER: map[string][]byte{},
}
for _, k := range c.Keys {
cert, priv, err := parseKeypair(k.CertPEM, k.KeyPEM)
if err != nil {
return nil, fmt.Errorf("key %s: %w", k.ID, err)
}
ks.certsDER[k.ID] = cert.Raw
if priv != nil {
ks.signers[k.ID] = tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: priv}
}
}
if _, ok := ks.signers[ks.activeID]; !ok {
return nil, errors.New("active signing key not available (missing or no private key)")
}
return ks, nil
}
func (ks *KeyStore) Active() tls.Certificate { return ks.signers[ks.activeID] }
func (ks *KeyStore) AllCertsDER() [][]byte {
out := make([][]byte, 0, len(ks.certsDER))
for _, der := range ks.certsDER {
out = append(out, der)
}
return out
}
func parseKeypair(certPEM, keyPEM string) (*x509.Certificate, interface{}, error) {
cb, _ := pem.Decode([]byte(certPEM))
if cb == nil {
return nil, nil, errors.New("invalid cert pem")
}
cert, err := x509.ParseCertificate(cb.Bytes)
if err != nil {
return nil, nil, err
}
var priv interface{}
if keyPEM != "" {
kb, _ := pem.Decode([]byte(keyPEM))
if kb == nil {
return nil, nil, errors.New("invalid key pem")
}
priv, err = x509.ParsePKCS8PrivateKey(kb.Bytes)
if err != nil {
priv, err = x509.ParsePKCS1PrivateKey(kb.Bytes)
if err != nil {
return nil, nil, err
}
}
}
return cert, priv, nil
}

261
internal/http/handlers.go Normal file
View File

@@ -0,0 +1,261 @@
package http
import (
"compress/flate"
"context"
"crypto/rand"
"encoding/base64"
"encoding/xml"
"fmt"
"html/template"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/crewjam/saml"
"shamilnunhuck/saml-oidc-bridge/internal/config"
"shamilnunhuck/saml-oidc-bridge/internal/oidc"
idsaml "shamilnunhuck/saml-oidc-bridge/internal/saml"
)
type IdP interface {
Metadata() *saml.EntityDescriptor
BuildResponse(sp config.SP, nameID string, attrs map[string][]string) (*saml.Response, error)
}
func Register(
mux *http.ServeMux,
getCfg func() *config.Config,
getIdP func() *idsaml.IdP,
getOIDC func() *oidc.Client,
) {
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
_, _ = w.Write([]byte("ok"))
})
mux.HandleFunc("/saml/metadata", func(w http.ResponseWriter, r *http.Request) {
meta := getIdP().Metadata()
buf, err := xml.MarshalIndent(meta, "", " ")
if err != nil {
http.Error(w, "marshal metadata: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/samlmetadata+xml")
_, _ = w.Write([]byte(xml.Header))
_, _ = w.Write(buf)
})
mux.HandleFunc("/saml/sso", func(w http.ResponseWriter, r *http.Request) {
req, spEntityID, relay, err := parseAuthnRequest(r)
log.Printf("AuthnRequest from SP=%s requestID=%s relay=%q", spEntityID, req.ID, relay)
if err != nil {
http.Error(w, "bad authn request: "+err.Error(), 400)
return
}
setStateCookie(w, getCfg().Session, spEntityID, relay, req.ID)
state := randomState()
http.Redirect(w, r, getOIDC().AuthCodeURL(state, url.Values{}), http.StatusFound)
})
mux.HandleFunc(getCfg().OIDC.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
code := r.URL.Query().Get("code")
if code == "" {
http.Error(w, "missing code", 400)
return
}
claims, err := getOIDC().ExchangeAndVerify(ctx, code)
if err != nil {
http.Error(w, "oidc: "+err.Error(), 400)
return
}
s, err := readStateCookie(r, getCfg().Session)
if err != nil {
http.Error(w, "state missing", 400)
return
}
sp := lookupSP(getCfg(), s.SPEntityID)
if sp == nil {
http.Error(w, "unknown SP", 400)
return
}
attrs := map[string][]string{}
for samlAttr, oidcClaim := range sp.AttributeMapping {
switch oidcClaim {
case "email":
attrs[samlAttr] = []string{claims.Email}
case "name":
attrs[samlAttr] = []string{claims.Name}
case "role":
attrs[samlAttr] = []string{mapRole(claims.Groups, sp)}
}
}
nameID := claims.Email
groups := claims.Groups
// Apply generic conditional attribute rules
userGroups := toSet(groups) // []string -> set
for _, rule := range sp.AttributeRules {
if hasAnyGroup(userGroups, rule.IfGroupsAny) {
if rule.Value == "" { // safe default
attrs[rule.Name] = []string{"true"}
} else {
attrs[rule.Name] = []string{rule.Value}
}
} else if rule.EmitWhenFalse {
attrs[rule.Name] = []string{"false"}
}
}
resp, err := getIdP().BuildResponse(*sp, nameID, attrs, s.RequestID)
if err != nil {
http.Error(w, "saml: "+err.Error(), 500)
return
}
xmlBytes, err := idsaml.MarshalSignedResponse(resp)
if err != nil {
http.Error(w, "sign: "+err.Error(), 500)
return
}
postToACS(w, sp.ACSURL, base64.StdEncoding.EncodeToString(xmlBytes), s.RelayState)
})
}
/*** helpers ***/
type state struct {
SPEntityID string
RelayState string
RequestID string
Expiry time.Time
}
func setStateCookie(w http.ResponseWriter, s config.Session, spEntityID, relay, reqID string) {
v := url.Values{}
v.Set("sp", spEntityID)
v.Set("rs", relay)
v.Set("rid", reqID)
c := &http.Cookie{
Name: s.CookieName,
Value: base64.RawURLEncoding.EncodeToString([]byte(v.Encode())),
Path: "/",
Domain: s.CookieDomain,
HttpOnly: true,
Secure: s.CookieSecure,
MaxAge: 600,
}
http.SetCookie(w, c)
}
func readStateCookie(r *http.Request, s config.Session) (*state, error) {
c, err := r.Cookie(s.CookieName)
if err != nil {
return nil, err
}
b, err := base64.RawURLEncoding.DecodeString(c.Value)
if err != nil {
return nil, err
}
v, err := url.ParseQuery(string(b))
if err != nil {
return nil, err
}
return &state{
SPEntityID: v.Get("sp"),
RelayState: v.Get("rs"),
RequestID: v.Get("rid"),
Expiry: time.Now().Add(10 * time.Minute),
}, nil
}
func lookupSP(cfg *config.Config, entityID string) *config.SP {
for i := range cfg.SPs {
if cfg.SPs[i].EntityID == entityID {
return &cfg.SPs[i]
}
}
return nil
}
func mapRole(groups []string, sp *config.SP) string {
for _, g := range groups {
if v, ok := sp.RoleMapping[g]; ok {
return v
}
}
if v, ok := sp.RoleMapping["*"]; ok {
return v
}
return "user"
}
func postToACS(w http.ResponseWriter, acsURL string, samlResponseB64 string, relay string) {
const tpl = `<!doctype html>
<html><body onload="document.forms[0].submit()">
<form method="post" action="{{.ACS}}">
<input type="hidden" name="SAMLResponse" value="{{.Resp}}">
{{if .Relay}}<input type="hidden" name="RelayState" value="{{.Relay}}">{{end}}
<noscript><button type="submit">Continue</button></noscript>
</form></body></html>`
t := template.Must(template.New("post").Parse(tpl))
_ = t.Execute(w, map[string]string{"ACS": acsURL, "Resp": samlResponseB64, "Relay": relay})
}
func randomState() string {
var b [16]byte
_, _ = rand.Read(b[:])
return base64.RawURLEncoding.EncodeToString(b[:])
}
func parseAuthnRequest(r *http.Request) (*saml.AuthnRequest, string, string, error) {
relay := r.FormValue("RelayState")
if sr := r.URL.Query().Get("SAMLRequest"); sr != "" {
xmlBytes, err := base64.StdEncoding.DecodeString(sr)
if err != nil {
return nil, "", "", fmt.Errorf("b64: %w", err)
}
reader := flate.NewReader(strings.NewReader(string(xmlBytes)))
defer reader.Close()
var sb strings.Builder
if _, err := io.Copy(&sb, reader); err != nil {
return nil, "", "", fmt.Errorf("inflate: %w", err)
}
var req saml.AuthnRequest
if err := xml.Unmarshal([]byte(sb.String()), &req); err != nil {
return nil, "", "", fmt.Errorf("xml: %w", err)
}
sp := ""
if req.Issuer != nil {
sp = req.Issuer.Value
}
return &req, sp, relay, nil
}
if sr := r.FormValue("SAMLRequest"); sr != "" {
xmlBytes, err := base64.StdEncoding.DecodeString(sr)
if err != nil {
return nil, "", "", fmt.Errorf("b64: %w", err)
}
var req saml.AuthnRequest
if err := xml.Unmarshal(xmlBytes, &req); err != nil {
return nil, "", "", fmt.Errorf("xml: %w", err)
}
sp := ""
if req.Issuer != nil {
sp = req.Issuer.Value
}
return &req, sp, relay, nil
}
return nil, "", "", fmt.Errorf("missing SAMLRequest")
}

21
internal/http/util.go Normal file
View File

@@ -0,0 +1,21 @@
package http
func hasAnyGroup(user map[string]struct{}, want []string) bool {
if len(want) == 0 {
return false
}
for _, g := range want {
if _, ok := user[g]; ok {
return true
}
}
return false
}
func toSet(ss []string) map[string]struct{} {
m := make(map[string]struct{}, len(ss))
for _, s := range ss {
m[s] = struct{}{}
}
return m
}

70
internal/oidc/client.go Normal file
View File

@@ -0,0 +1,70 @@
package oidc
import (
"context"
"fmt"
"net/url"
gooidc "github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
type Client struct {
Verifier *gooidc.IDTokenVerifier
OAuth2 *oauth2.Config
}
func NewClient(cfg *config.Config) (*Client, error) {
ctx := context.Background()
provider, err := gooidc.NewProvider(ctx, cfg.OIDC.Issuer)
if err != nil {
return nil, err
}
verifier := provider.Verifier(&gooidc.Config{ClientID: cfg.OIDC.ClientID})
redirect := cfg.Server.ExternalURL + cfg.OIDC.RedirectPath
scopes := []string{"openid"}
scopes = append(scopes, cfg.OIDC.Scopes...)
oauth2cfg := &oauth2.Config{
ClientID: cfg.OIDC.ClientID,
ClientSecret: cfg.OIDC.ClientSecret,
Endpoint: provider.Endpoint(),
Scopes: scopes,
RedirectURL: redirect,
}
return &Client{Verifier: verifier, OAuth2: oauth2cfg}, nil
}
type Claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
Groups []string `json:"groups"`
}
func (c *Client) AuthCodeURL(state string, extra url.Values) string {
return c.OAuth2.AuthCodeURL(state)
}
func (c *Client) ExchangeAndVerify(ctx context.Context, code string) (*Claims, error) {
token, err := c.OAuth2.Exchange(ctx, code)
if err != nil {
return nil, err
}
rawID, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("no id_token in token response")
}
idt, err := c.Verifier.Verify(ctx, rawID)
if err != nil {
return nil, err
}
var cl Claims
if err := idt.Claims(&cl); err != nil {
return nil, err
}
return &cl, nil
}

70
internal/saml/idp.go Normal file
View File

@@ -0,0 +1,70 @@
package saml
import (
"encoding/base64"
"time"
"shamilnunhuck/saml-oidc-bridge/internal/config"
"shamilnunhuck/saml-oidc-bridge/internal/crypto"
"github.com/crewjam/saml"
)
type IdP struct {
keys *crypto.KeyStore
sec config.Security
entityID string
ssoURL string
}
func NewIdP(cfg *config.Config, ks *crypto.KeyStore) *IdP {
return &IdP{
keys: ks,
sec: cfg.Security,
entityID: cfg.Server.ExternalURL,
ssoURL: cfg.Server.ExternalURL + "/saml/sso",
}
}
func (i *IdP) Metadata() *saml.EntityDescriptor {
// we need to publish all certs, to allow safe rotation
keyDescriptors := []saml.KeyDescriptor{}
for _, der := range i.keys.AllCertsDER() {
keyDescriptors = append(keyDescriptors, saml.KeyDescriptor{
Use: "signing",
KeyInfo: saml.KeyInfo{
X509Data: saml.X509Data{
X509Certificates: []saml.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(der),
}},
},
},
})
}
entityDescriptor := &saml.EntityDescriptor{
EntityID: i.entityID,
// crewjam expects time.Time for ValidUntil and time.Duration for CacheDuration
ValidUntil: time.Now().UTC().Add(time.Duration(i.sec.MetadataValidUntilDays) * 24 * time.Hour),
CacheDuration: time.Duration(i.sec.MetadataCacheDurationSeconds) * time.Second,
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
SSODescriptor: saml.SSODescriptor{
RoleDescriptor: saml.RoleDescriptor{
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
KeyDescriptors: keyDescriptors,
},
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPPostBinding,
Location: i.ssoURL,
},
},
},
},
}
return entityDescriptor
}

204
internal/saml/sign.go Normal file
View File

@@ -0,0 +1,204 @@
package saml
import (
"crypto"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
const (
subjectConfirmationMethodBearer = "urn:oasis:names:tc:SAML:2.0:cm:bearer"
nameIDFormatEntity = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"
authnContextPasswordProtectedTransport = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
defaultAssertionTTL = 5 * time.Minute
)
func (i *IdP) BuildResponse(sp config.SP, nameID string, attrs map[string][]string, inResponseTo string) (*saml.Response, error) {
now := saml.TimeNow()
ttl := time.Duration(i.sec.AssertionTTLSec) * time.Second
if ttl <= 0 {
ttl = defaultAssertionTTL
}
skew := time.Duration(i.sec.SkewSeconds) * time.Second
assertionID, err := newSAMLID()
if err != nil {
return nil, fmt.Errorf("generate assertion id: %w", err)
}
responseID, err := newSAMLID()
if err != nil {
return nil, fmt.Errorf("generate response id: %w", err)
}
audience := sp.Audience
if audience == "" {
audience = sp.EntityID
}
nameIDFormat := sp.NameIDFormat
if nameIDFormat == "" {
nameIDFormat = string(saml.UnspecifiedNameIDFormat)
}
assertion := &saml.Assertion{
ID: assertionID,
IssueInstant: now,
Version: "2.0",
Issuer: saml.Issuer{
Format: nameIDFormatEntity,
Value: i.entityID,
},
Subject: &saml.Subject{
NameID: &saml.NameID{
Format: nameIDFormat,
Value: nameID,
},
SubjectConfirmations: []saml.SubjectConfirmation{
{
Method: subjectConfirmationMethodBearer,
SubjectConfirmationData: &saml.SubjectConfirmationData{
InResponseTo: inResponseTo,
NotOnOrAfter: now.Add(ttl),
Recipient: sp.ACSURL,
},
},
},
},
Conditions: &saml.Conditions{
NotBefore: now.Add(-skew),
NotOnOrAfter: now.Add(ttl),
AudienceRestrictions: []saml.AudienceRestriction{
{
Audience: saml.Audience{Value: audience},
},
},
},
AuthnStatements: []saml.AuthnStatement{
{
AuthnInstant: now,
SessionIndex: responseID,
AuthnContext: saml.AuthnContext{
AuthnContextClassRef: &saml.AuthnContextClassRef{Value: authnContextPasswordProtectedTransport},
},
},
},
AttributeStatements: []saml.AttributeStatement{
{
Attributes: toSAMLAttributes(attrs),
},
},
}
resp := &saml.Response{
ID: responseID,
Version: "2.0",
IssueInstant: now,
InResponseTo: inResponseTo,
Destination: sp.ACSURL,
Issuer: &saml.Issuer{
Format: nameIDFormatEntity,
Value: i.entityID,
},
Status: saml.Status{
StatusCode: saml.StatusCode{Value: saml.StatusSuccess},
},
Assertion: assertion,
}
if err := i.signResponse(resp); err != nil {
return nil, err
}
return resp, nil
}
func (i *IdP) signingContext() (*dsig.SigningContext, error) {
keyPair := i.keys.Active()
if len(keyPair.Certificate) == 0 || keyPair.PrivateKey == nil {
return nil, errors.New("active key missing certificate or private key")
}
ctx := dsig.NewDefaultSigningContext(dsig.TLSCertKeyStore(keyPair))
ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("")
if err := ctx.SetSignatureMethod(dsig.RSASHA256SignatureMethod); err != nil {
return nil, err
}
ctx.Hash = crypto.SHA256
return ctx, nil
}
func (i *IdP) signResponse(resp *saml.Response) error {
if resp.Assertion == nil {
return errors.New("response missing assertion")
}
assertionCtx, err := i.signingContext()
if err != nil {
return err
}
assertionEl := resp.Assertion.Element()
signedAssertionEl, err := assertionCtx.SignEnveloped(assertionEl)
if err != nil {
return fmt.Errorf("sign assertion: %w", err)
}
sigEl, err := lastChildElement(signedAssertionEl)
if err != nil {
return fmt.Errorf("sign assertion: %w", err)
}
resp.Assertion.Signature = sigEl
responseCtx, err := i.signingContext()
if err != nil {
return err
}
responseEl := resp.Element()
signedResponseEl, err := responseCtx.SignEnveloped(responseEl)
if err != nil {
return fmt.Errorf("sign response: %w", err)
}
sigEl, err = lastChildElement(signedResponseEl)
if err != nil {
return fmt.Errorf("sign response: %w", err)
}
resp.Signature = sigEl
return nil
}
func MarshalSignedResponse(resp *saml.Response) ([]byte, error) {
if resp == nil {
return nil, errors.New("nil response")
}
if resp.Signature == nil {
return nil, errors.New("response not signed")
}
doc := etree.NewDocument()
doc.SetRoot(resp.Element())
return doc.WriteToBytes()
}
func lastChildElement(parent *etree.Element) (*etree.Element, error) {
children := parent.ChildElements()
if len(children) == 0 {
return nil, errors.New("no child elements found")
}
return children[len(children)-1], nil
}
func newSAMLID() (string, error) {
var b [20]byte
if _, err := rand.Read(b[:]); err != nil {
return "", err
}
return "_" + hex.EncodeToString(b[:]), nil
}

26
internal/saml/util.go Normal file
View File

@@ -0,0 +1,26 @@
package saml
import "github.com/crewjam/saml"
const attrNameFormatURI = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
func toSAMLAttributes(attrs map[string][]string) []saml.Attribute {
out := make([]saml.Attribute, 0, len(attrs))
for name, vals := range attrs {
vs := make([]saml.AttributeValue, 0, len(vals))
for _, s := range vals {
vs = append(vs, saml.AttributeValue{
// explicitly mark string type so we never emit xsi:type=""
Type: "xs:string",
Value: s,
})
}
out = append(out, saml.Attribute{
FriendlyName: name,
Name: name,
NameFormat: attrNameFormatURI,
Values: vs,
})
}
return out
}