initial commit
This commit is contained in:
206
internal/cli/cert.go
Normal file
206
internal/cli/cert.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
||||
)
|
||||
|
||||
type rotateOpts struct {
|
||||
ConfigPath string
|
||||
ID string
|
||||
Algo string
|
||||
Days int
|
||||
CN string
|
||||
Org string
|
||||
OutK8s string
|
||||
ActiveOnly bool
|
||||
}
|
||||
|
||||
func RunCert(args []string) error {
|
||||
fs := flag.NewFlagSet("cert", flag.ContinueOnError)
|
||||
var ro rotateOpts
|
||||
fs.StringVar(&ro.ConfigPath, "config", "example.config.yaml", "path to config yaml")
|
||||
fs.StringVar(&ro.ID, "id", "", "key id (e.g. k-2025-10)")
|
||||
fs.StringVar(&ro.Algo, "algo", "rsa3072", "rsa2048|rsa3072|rsa4096|p256|p384")
|
||||
fs.IntVar(&ro.Days, "days", 825, "validity in days")
|
||||
fs.StringVar(&ro.CN, "cn", "id.example.com", "certificate CN")
|
||||
fs.StringVar(&ro.Org, "org", "YourOrg", "certificate O")
|
||||
fs.StringVar(&ro.OutK8s, "k8s-secret-out", "", "write a Kubernetes Secret manifest to this path")
|
||||
fs.BoolVar(&ro.ActiveOnly, "active-only", false, "only set active_key to -id (no new cert)")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
if ro.ID == "" {
|
||||
return errors.New("missing -id")
|
||||
}
|
||||
|
||||
cfg, raw, err := loadConfig(ro.ConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ro.ActiveOnly {
|
||||
cfg.Crypto.ActiveKey = ro.ID
|
||||
return saveConfig(ro.ConfigPath, raw, cfg)
|
||||
}
|
||||
|
||||
certPEM, keyPEM, notAfter, err := genSelfSigned(ro)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Crypto.Keys = append(cfg.Crypto.Keys, config.KeyPair{
|
||||
ID: ro.ID,
|
||||
CertPEM: string(certPEM),
|
||||
KeyPEM: string(keyPEM),
|
||||
NotAfter: notAfter.UTC(),
|
||||
})
|
||||
cfg.Crypto.ActiveKey = ro.ID
|
||||
|
||||
if err := saveConfig(ro.ConfigPath, raw, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if ro.OutK8s != "" {
|
||||
if err := os.WriteFile(ro.OutK8s, []byte(k8sSecretYAML(ro, certPEM, keyPEM)), 0o600); err != nil {
|
||||
return fmt.Errorf("write secret: %w", err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("OK: generated %s (algo=%s, not_after=%s) and set active_key\n", ro.ID, ro.Algo, notAfter.UTC().Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
func genSelfSigned(ro rotateOpts) (certPEM, keyPEM []byte, notAfter time.Time, err error) {
|
||||
nb := time.Now().Add(-5 * time.Minute)
|
||||
na := nb.Add(time.Duration(ro.Days) * 24 * time.Hour)
|
||||
|
||||
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
CommonName: ro.CN,
|
||||
Organization: []string{ro.Org},
|
||||
OrganizationalUnit: []string{"SAML Signing"},
|
||||
},
|
||||
NotBefore: nb,
|
||||
NotAfter: na,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
var der []byte
|
||||
var keyPKCS8 []byte
|
||||
|
||||
switch strings.ToLower(ro.Algo) {
|
||||
case "rsa2048", "rsa3072", "rsa4096":
|
||||
bits := 2048
|
||||
if ro.Algo == "rsa3072" {
|
||||
bits = 3072
|
||||
}
|
||||
if ro.Algo == "rsa4096" {
|
||||
bits = 4096
|
||||
}
|
||||
priv, e := rsa.GenerateKey(rand.Reader, bits)
|
||||
if e != nil {
|
||||
return nil, nil, time.Time{}, e
|
||||
}
|
||||
der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, time.Time{}, err
|
||||
}
|
||||
keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
return nil, nil, time.Time{}, err
|
||||
}
|
||||
case "p256", "p384":
|
||||
curve := elliptic.P256()
|
||||
if ro.Algo == "p384" {
|
||||
curve = elliptic.P384()
|
||||
}
|
||||
priv, e := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if e != nil {
|
||||
return nil, nil, time.Time{}, e
|
||||
}
|
||||
der, err = x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, time.Time{}, err
|
||||
}
|
||||
keyPKCS8, err = x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
return nil, nil, time.Time{}, err
|
||||
}
|
||||
default:
|
||||
return nil, nil, time.Time{}, fmt.Errorf("unknown -algo %q", ro.Algo)
|
||||
}
|
||||
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyPKCS8})
|
||||
return certPEM, keyPEM, na, nil
|
||||
}
|
||||
|
||||
func loadConfig(path string) (*config.Config, *yaml.Node, error) {
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(b, &root); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var c config.Config
|
||||
if err := root.Decode(&c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &c, &root, nil
|
||||
}
|
||||
|
||||
func saveConfig(path string, _ *yaml.Node, c *config.Config) error {
|
||||
var buf bytes.Buffer
|
||||
enc := yaml.NewEncoder(&buf)
|
||||
enc.SetIndent(2)
|
||||
if err := enc.Encode(c); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = enc.Close()
|
||||
return os.WriteFile(path, buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
func k8sSecretYAML(ro rotateOpts, certPEM, keyPEM []byte) string {
|
||||
name := strings.ToLower(strings.ReplaceAll(ro.ID, "_", "-"))
|
||||
return fmt.Sprintf(`apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: saml-signing-%s
|
||||
type: Opaque
|
||||
stringData:
|
||||
cert.pem: |-
|
||||
%s
|
||||
key.pem: |-
|
||||
%s
|
||||
`, name, indent(string(certPEM), 4), indent(string(keyPEM), 4))
|
||||
}
|
||||
|
||||
func indent(s string, n int) string {
|
||||
pad := strings.Repeat(" ", n)
|
||||
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
|
||||
for i := range lines {
|
||||
lines[i] = pad + lines[i]
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
36
internal/config/config.go
Normal file
36
internal/config/config.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var c Config
|
||||
if err := yaml.Unmarshal(b, &c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.ExternalURL == "" || c.Server.Listen == "" {
|
||||
return fmt.Errorf("server.external_url and server.listen required")
|
||||
}
|
||||
if len(c.SPs) == 0 {
|
||||
return fmt.Errorf("at least one SP required")
|
||||
}
|
||||
if c.OIDC.Issuer == "" || c.OIDC.ClientID == "" || c.OIDC.RedirectPath == "" {
|
||||
return fmt.Errorf("oidc issuer/client_id/redirect_path required")
|
||||
}
|
||||
if c.Crypto.ActiveKey == "" || len(c.Crypto.Keys) == 0 {
|
||||
return fmt.Errorf("crypto.active_key and at least one key required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
68
internal/config/types.go
Normal file
68
internal/config/types.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
type Server struct {
|
||||
Listen string `yaml:"listen"`
|
||||
ExternalURL string `yaml:"external_url"`
|
||||
}
|
||||
|
||||
type KeyPair struct {
|
||||
ID string `yaml:"id"`
|
||||
CertPEM string `yaml:"cert_pem"`
|
||||
KeyPEM string `yaml:"key_pem"`
|
||||
NotAfter time.Time `yaml:"not_after"`
|
||||
}
|
||||
type Crypto struct {
|
||||
ActiveKey string `yaml:"active_key"`
|
||||
Keys []KeyPair `yaml:"keys"`
|
||||
}
|
||||
|
||||
type OIDC struct {
|
||||
Issuer string `yaml:"issuer"`
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"-"`
|
||||
RedirectPath string `yaml:"redirect_path"`
|
||||
Scopes []string `yaml:"scopes"`
|
||||
}
|
||||
|
||||
type SP struct {
|
||||
Name string `yaml:"name"`
|
||||
EntityID string `yaml:"entity_id"`
|
||||
ACSURL string `yaml:"acs_url"`
|
||||
Audience string `yaml:"audience"`
|
||||
NameIDFormat string `yaml:"nameid_format"`
|
||||
AttributeMapping map[string]string `yaml:"attribute_mapping"`
|
||||
RoleMapping map[string]string `yaml:"role_mapping"`
|
||||
AttributeRules []AttributeRule `yaml:"attribute_rules"`
|
||||
}
|
||||
|
||||
type Security struct {
|
||||
SkewSeconds int `yaml:"skew_seconds"`
|
||||
AssertionTTLSec int `yaml:"assertion_ttl_seconds"`
|
||||
RequireSignedAuthnRequest bool `yaml:"require_signed_authn_request"`
|
||||
MetadataValidUntilDays int `yaml:"metadata_valid_until_days"`
|
||||
MetadataCacheDurationSeconds int `yaml:"metadata_cache_duration_seconds"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
CookieName string `yaml:"cookie_name"`
|
||||
CookieSecure bool `yaml:"cookie_secure"`
|
||||
CookieDomain string `yaml:"cookie_domain"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Server Server `yaml:"server"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
OIDC OIDC `yaml:"oidc_upstream"`
|
||||
SPs []SP `yaml:"sps"`
|
||||
Security Security `yaml:"security"`
|
||||
Session Session `yaml:"session"`
|
||||
}
|
||||
|
||||
type AttributeRule struct {
|
||||
Name string `yaml:"name"`
|
||||
Value string `yaml:"value"`
|
||||
IfGroupsAny []string `yaml:"if_groups_any"`
|
||||
EmitWhenFalse bool `yaml:"emit_when_false"`
|
||||
}
|
||||
74
internal/crypto/keystore.go
Normal file
74
internal/crypto/keystore.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
||||
)
|
||||
|
||||
type KeyStore struct {
|
||||
activeID string
|
||||
signers map[string]tls.Certificate
|
||||
certsDER map[string][]byte
|
||||
}
|
||||
|
||||
func NewKeyStore(c config.Crypto) (*KeyStore, error) {
|
||||
ks := &KeyStore{
|
||||
activeID: c.ActiveKey,
|
||||
signers: map[string]tls.Certificate{},
|
||||
certsDER: map[string][]byte{},
|
||||
}
|
||||
for _, k := range c.Keys {
|
||||
cert, priv, err := parseKeypair(k.CertPEM, k.KeyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("key %s: %w", k.ID, err)
|
||||
}
|
||||
ks.certsDER[k.ID] = cert.Raw
|
||||
if priv != nil {
|
||||
ks.signers[k.ID] = tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: priv}
|
||||
}
|
||||
}
|
||||
if _, ok := ks.signers[ks.activeID]; !ok {
|
||||
return nil, errors.New("active signing key not available (missing or no private key)")
|
||||
}
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (ks *KeyStore) Active() tls.Certificate { return ks.signers[ks.activeID] }
|
||||
func (ks *KeyStore) AllCertsDER() [][]byte {
|
||||
out := make([][]byte, 0, len(ks.certsDER))
|
||||
for _, der := range ks.certsDER {
|
||||
out = append(out, der)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseKeypair(certPEM, keyPEM string) (*x509.Certificate, interface{}, error) {
|
||||
cb, _ := pem.Decode([]byte(certPEM))
|
||||
if cb == nil {
|
||||
return nil, nil, errors.New("invalid cert pem")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(cb.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var priv interface{}
|
||||
if keyPEM != "" {
|
||||
kb, _ := pem.Decode([]byte(keyPEM))
|
||||
if kb == nil {
|
||||
return nil, nil, errors.New("invalid key pem")
|
||||
}
|
||||
priv, err = x509.ParsePKCS8PrivateKey(kb.Bytes)
|
||||
if err != nil {
|
||||
priv, err = x509.ParsePKCS1PrivateKey(kb.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return cert, priv, nil
|
||||
}
|
||||
261
internal/http/handlers.go
Normal file
261
internal/http/handlers.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/oidc"
|
||||
idsaml "shamilnunhuck/saml-oidc-bridge/internal/saml"
|
||||
)
|
||||
|
||||
type IdP interface {
|
||||
Metadata() *saml.EntityDescriptor
|
||||
BuildResponse(sp config.SP, nameID string, attrs map[string][]string) (*saml.Response, error)
|
||||
}
|
||||
|
||||
func Register(
|
||||
mux *http.ServeMux,
|
||||
getCfg func() *config.Config,
|
||||
getIdP func() *idsaml.IdP,
|
||||
getOIDC func() *oidc.Client,
|
||||
) {
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
mux.HandleFunc("/saml/metadata", func(w http.ResponseWriter, r *http.Request) {
|
||||
meta := getIdP().Metadata()
|
||||
buf, err := xml.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
http.Error(w, "marshal metadata: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/samlmetadata+xml")
|
||||
_, _ = w.Write([]byte(xml.Header))
|
||||
_, _ = w.Write(buf)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/saml/sso", func(w http.ResponseWriter, r *http.Request) {
|
||||
req, spEntityID, relay, err := parseAuthnRequest(r)
|
||||
log.Printf("AuthnRequest from SP=%s requestID=%s relay=%q", spEntityID, req.ID, relay)
|
||||
if err != nil {
|
||||
http.Error(w, "bad authn request: "+err.Error(), 400)
|
||||
return
|
||||
}
|
||||
setStateCookie(w, getCfg().Session, spEntityID, relay, req.ID)
|
||||
state := randomState()
|
||||
http.Redirect(w, r, getOIDC().AuthCodeURL(state, url.Values{}), http.StatusFound)
|
||||
})
|
||||
|
||||
mux.HandleFunc(getCfg().OIDC.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.Background()
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", 400)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := getOIDC().ExchangeAndVerify(ctx, code)
|
||||
if err != nil {
|
||||
http.Error(w, "oidc: "+err.Error(), 400)
|
||||
return
|
||||
}
|
||||
|
||||
s, err := readStateCookie(r, getCfg().Session)
|
||||
if err != nil {
|
||||
http.Error(w, "state missing", 400)
|
||||
return
|
||||
}
|
||||
|
||||
sp := lookupSP(getCfg(), s.SPEntityID)
|
||||
if sp == nil {
|
||||
http.Error(w, "unknown SP", 400)
|
||||
return
|
||||
}
|
||||
|
||||
attrs := map[string][]string{}
|
||||
for samlAttr, oidcClaim := range sp.AttributeMapping {
|
||||
switch oidcClaim {
|
||||
case "email":
|
||||
attrs[samlAttr] = []string{claims.Email}
|
||||
case "name":
|
||||
attrs[samlAttr] = []string{claims.Name}
|
||||
case "role":
|
||||
attrs[samlAttr] = []string{mapRole(claims.Groups, sp)}
|
||||
}
|
||||
}
|
||||
|
||||
nameID := claims.Email
|
||||
groups := claims.Groups
|
||||
|
||||
// Apply generic conditional attribute rules
|
||||
userGroups := toSet(groups) // []string -> set
|
||||
for _, rule := range sp.AttributeRules {
|
||||
if hasAnyGroup(userGroups, rule.IfGroupsAny) {
|
||||
if rule.Value == "" { // safe default
|
||||
attrs[rule.Name] = []string{"true"}
|
||||
} else {
|
||||
attrs[rule.Name] = []string{rule.Value}
|
||||
}
|
||||
} else if rule.EmitWhenFalse {
|
||||
attrs[rule.Name] = []string{"false"}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := getIdP().BuildResponse(*sp, nameID, attrs, s.RequestID)
|
||||
if err != nil {
|
||||
http.Error(w, "saml: "+err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
xmlBytes, err := idsaml.MarshalSignedResponse(resp)
|
||||
if err != nil {
|
||||
http.Error(w, "sign: "+err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
postToACS(w, sp.ACSURL, base64.StdEncoding.EncodeToString(xmlBytes), s.RelayState)
|
||||
})
|
||||
}
|
||||
|
||||
/*** helpers ***/
|
||||
|
||||
type state struct {
|
||||
SPEntityID string
|
||||
RelayState string
|
||||
RequestID string
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
func setStateCookie(w http.ResponseWriter, s config.Session, spEntityID, relay, reqID string) {
|
||||
v := url.Values{}
|
||||
v.Set("sp", spEntityID)
|
||||
v.Set("rs", relay)
|
||||
v.Set("rid", reqID)
|
||||
c := &http.Cookie{
|
||||
Name: s.CookieName,
|
||||
Value: base64.RawURLEncoding.EncodeToString([]byte(v.Encode())),
|
||||
Path: "/",
|
||||
Domain: s.CookieDomain,
|
||||
HttpOnly: true,
|
||||
Secure: s.CookieSecure,
|
||||
MaxAge: 600,
|
||||
}
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
func readStateCookie(r *http.Request, s config.Session) (*state, error) {
|
||||
c, err := r.Cookie(s.CookieName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, err := base64.RawURLEncoding.DecodeString(c.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v, err := url.ParseQuery(string(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &state{
|
||||
SPEntityID: v.Get("sp"),
|
||||
RelayState: v.Get("rs"),
|
||||
RequestID: v.Get("rid"),
|
||||
Expiry: time.Now().Add(10 * time.Minute),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func lookupSP(cfg *config.Config, entityID string) *config.SP {
|
||||
for i := range cfg.SPs {
|
||||
if cfg.SPs[i].EntityID == entityID {
|
||||
return &cfg.SPs[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapRole(groups []string, sp *config.SP) string {
|
||||
for _, g := range groups {
|
||||
if v, ok := sp.RoleMapping[g]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if v, ok := sp.RoleMapping["*"]; ok {
|
||||
return v
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
|
||||
func postToACS(w http.ResponseWriter, acsURL string, samlResponseB64 string, relay string) {
|
||||
const tpl = `<!doctype html>
|
||||
<html><body onload="document.forms[0].submit()">
|
||||
<form method="post" action="{{.ACS}}">
|
||||
<input type="hidden" name="SAMLResponse" value="{{.Resp}}">
|
||||
{{if .Relay}}<input type="hidden" name="RelayState" value="{{.Relay}}">{{end}}
|
||||
<noscript><button type="submit">Continue</button></noscript>
|
||||
</form></body></html>`
|
||||
t := template.Must(template.New("post").Parse(tpl))
|
||||
_ = t.Execute(w, map[string]string{"ACS": acsURL, "Resp": samlResponseB64, "Relay": relay})
|
||||
}
|
||||
|
||||
func randomState() string {
|
||||
var b [16]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return base64.RawURLEncoding.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func parseAuthnRequest(r *http.Request) (*saml.AuthnRequest, string, string, error) {
|
||||
relay := r.FormValue("RelayState")
|
||||
if sr := r.URL.Query().Get("SAMLRequest"); sr != "" {
|
||||
xmlBytes, err := base64.StdEncoding.DecodeString(sr)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("b64: %w", err)
|
||||
}
|
||||
reader := flate.NewReader(strings.NewReader(string(xmlBytes)))
|
||||
defer reader.Close()
|
||||
var sb strings.Builder
|
||||
if _, err := io.Copy(&sb, reader); err != nil {
|
||||
return nil, "", "", fmt.Errorf("inflate: %w", err)
|
||||
}
|
||||
var req saml.AuthnRequest
|
||||
if err := xml.Unmarshal([]byte(sb.String()), &req); err != nil {
|
||||
return nil, "", "", fmt.Errorf("xml: %w", err)
|
||||
}
|
||||
sp := ""
|
||||
if req.Issuer != nil {
|
||||
sp = req.Issuer.Value
|
||||
}
|
||||
return &req, sp, relay, nil
|
||||
}
|
||||
if sr := r.FormValue("SAMLRequest"); sr != "" {
|
||||
xmlBytes, err := base64.StdEncoding.DecodeString(sr)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("b64: %w", err)
|
||||
}
|
||||
var req saml.AuthnRequest
|
||||
if err := xml.Unmarshal(xmlBytes, &req); err != nil {
|
||||
return nil, "", "", fmt.Errorf("xml: %w", err)
|
||||
}
|
||||
sp := ""
|
||||
if req.Issuer != nil {
|
||||
sp = req.Issuer.Value
|
||||
}
|
||||
return &req, sp, relay, nil
|
||||
}
|
||||
return nil, "", "", fmt.Errorf("missing SAMLRequest")
|
||||
}
|
||||
21
internal/http/util.go
Normal file
21
internal/http/util.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package http
|
||||
|
||||
func hasAnyGroup(user map[string]struct{}, want []string) bool {
|
||||
if len(want) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, g := range want {
|
||||
if _, ok := user[g]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func toSet(ss []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(ss))
|
||||
for _, s := range ss {
|
||||
m[s] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
70
internal/oidc/client.go
Normal file
70
internal/oidc/client.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
gooidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Verifier *gooidc.IDTokenVerifier
|
||||
OAuth2 *oauth2.Config
|
||||
}
|
||||
|
||||
func NewClient(cfg *config.Config) (*Client, error) {
|
||||
ctx := context.Background()
|
||||
provider, err := gooidc.NewProvider(ctx, cfg.OIDC.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verifier := provider.Verifier(&gooidc.Config{ClientID: cfg.OIDC.ClientID})
|
||||
redirect := cfg.Server.ExternalURL + cfg.OIDC.RedirectPath
|
||||
|
||||
scopes := []string{"openid"}
|
||||
scopes = append(scopes, cfg.OIDC.Scopes...)
|
||||
|
||||
oauth2cfg := &oauth2.Config{
|
||||
ClientID: cfg.OIDC.ClientID,
|
||||
ClientSecret: cfg.OIDC.ClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: scopes,
|
||||
RedirectURL: redirect,
|
||||
}
|
||||
return &Client{Verifier: verifier, OAuth2: oauth2cfg}, nil
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
func (c *Client) AuthCodeURL(state string, extra url.Values) string {
|
||||
return c.OAuth2.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
func (c *Client) ExchangeAndVerify(ctx context.Context, code string) (*Claims, error) {
|
||||
token, err := c.OAuth2.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawID, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no id_token in token response")
|
||||
}
|
||||
idt, err := c.Verifier.Verify(ctx, rawID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var cl Claims
|
||||
if err := idt.Claims(&cl); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cl, nil
|
||||
}
|
||||
70
internal/saml/idp.go
Normal file
70
internal/saml/idp.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/crypto"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
)
|
||||
|
||||
type IdP struct {
|
||||
keys *crypto.KeyStore
|
||||
sec config.Security
|
||||
entityID string
|
||||
ssoURL string
|
||||
}
|
||||
|
||||
func NewIdP(cfg *config.Config, ks *crypto.KeyStore) *IdP {
|
||||
return &IdP{
|
||||
keys: ks,
|
||||
sec: cfg.Security,
|
||||
entityID: cfg.Server.ExternalURL,
|
||||
ssoURL: cfg.Server.ExternalURL + "/saml/sso",
|
||||
}
|
||||
}
|
||||
|
||||
func (i *IdP) Metadata() *saml.EntityDescriptor {
|
||||
// we need to publish all certs, to allow safe rotation
|
||||
keyDescriptors := []saml.KeyDescriptor{}
|
||||
|
||||
for _, der := range i.keys.AllCertsDER() {
|
||||
keyDescriptors = append(keyDescriptors, saml.KeyDescriptor{
|
||||
Use: "signing",
|
||||
KeyInfo: saml.KeyInfo{
|
||||
X509Data: saml.X509Data{
|
||||
X509Certificates: []saml.X509Certificate{{
|
||||
Data: base64.StdEncoding.EncodeToString(der),
|
||||
}},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
entityDescriptor := &saml.EntityDescriptor{
|
||||
EntityID: i.entityID,
|
||||
// crewjam expects time.Time for ValidUntil and time.Duration for CacheDuration
|
||||
ValidUntil: time.Now().UTC().Add(time.Duration(i.sec.MetadataValidUntilDays) * 24 * time.Hour),
|
||||
CacheDuration: time.Duration(i.sec.MetadataCacheDurationSeconds) * time.Second,
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
SSODescriptor: saml.SSODescriptor{
|
||||
RoleDescriptor: saml.RoleDescriptor{
|
||||
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
|
||||
KeyDescriptors: keyDescriptors,
|
||||
},
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPPostBinding,
|
||||
Location: i.ssoURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return entityDescriptor
|
||||
}
|
||||
204
internal/saml/sign.go
Normal file
204
internal/saml/sign.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/beevik/etree"
|
||||
"github.com/crewjam/saml"
|
||||
dsig "github.com/russellhaering/goxmldsig"
|
||||
|
||||
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
subjectConfirmationMethodBearer = "urn:oasis:names:tc:SAML:2.0:cm:bearer"
|
||||
nameIDFormatEntity = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"
|
||||
authnContextPasswordProtectedTransport = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
|
||||
defaultAssertionTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
func (i *IdP) BuildResponse(sp config.SP, nameID string, attrs map[string][]string, inResponseTo string) (*saml.Response, error) {
|
||||
now := saml.TimeNow()
|
||||
ttl := time.Duration(i.sec.AssertionTTLSec) * time.Second
|
||||
if ttl <= 0 {
|
||||
ttl = defaultAssertionTTL
|
||||
}
|
||||
skew := time.Duration(i.sec.SkewSeconds) * time.Second
|
||||
|
||||
assertionID, err := newSAMLID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate assertion id: %w", err)
|
||||
}
|
||||
responseID, err := newSAMLID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate response id: %w", err)
|
||||
}
|
||||
|
||||
audience := sp.Audience
|
||||
if audience == "" {
|
||||
audience = sp.EntityID
|
||||
}
|
||||
|
||||
nameIDFormat := sp.NameIDFormat
|
||||
if nameIDFormat == "" {
|
||||
nameIDFormat = string(saml.UnspecifiedNameIDFormat)
|
||||
}
|
||||
|
||||
assertion := &saml.Assertion{
|
||||
ID: assertionID,
|
||||
IssueInstant: now,
|
||||
Version: "2.0",
|
||||
Issuer: saml.Issuer{
|
||||
Format: nameIDFormatEntity,
|
||||
Value: i.entityID,
|
||||
},
|
||||
Subject: &saml.Subject{
|
||||
NameID: &saml.NameID{
|
||||
Format: nameIDFormat,
|
||||
Value: nameID,
|
||||
},
|
||||
SubjectConfirmations: []saml.SubjectConfirmation{
|
||||
{
|
||||
Method: subjectConfirmationMethodBearer,
|
||||
SubjectConfirmationData: &saml.SubjectConfirmationData{
|
||||
InResponseTo: inResponseTo,
|
||||
NotOnOrAfter: now.Add(ttl),
|
||||
Recipient: sp.ACSURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Conditions: &saml.Conditions{
|
||||
NotBefore: now.Add(-skew),
|
||||
NotOnOrAfter: now.Add(ttl),
|
||||
AudienceRestrictions: []saml.AudienceRestriction{
|
||||
{
|
||||
Audience: saml.Audience{Value: audience},
|
||||
},
|
||||
},
|
||||
},
|
||||
AuthnStatements: []saml.AuthnStatement{
|
||||
{
|
||||
AuthnInstant: now,
|
||||
SessionIndex: responseID,
|
||||
AuthnContext: saml.AuthnContext{
|
||||
AuthnContextClassRef: &saml.AuthnContextClassRef{Value: authnContextPasswordProtectedTransport},
|
||||
},
|
||||
},
|
||||
},
|
||||
AttributeStatements: []saml.AttributeStatement{
|
||||
{
|
||||
Attributes: toSAMLAttributes(attrs),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := &saml.Response{
|
||||
ID: responseID,
|
||||
Version: "2.0",
|
||||
IssueInstant: now,
|
||||
InResponseTo: inResponseTo,
|
||||
Destination: sp.ACSURL,
|
||||
Issuer: &saml.Issuer{
|
||||
Format: nameIDFormatEntity,
|
||||
Value: i.entityID,
|
||||
},
|
||||
Status: saml.Status{
|
||||
StatusCode: saml.StatusCode{Value: saml.StatusSuccess},
|
||||
},
|
||||
Assertion: assertion,
|
||||
}
|
||||
|
||||
if err := i.signResponse(resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (i *IdP) signingContext() (*dsig.SigningContext, error) {
|
||||
keyPair := i.keys.Active()
|
||||
if len(keyPair.Certificate) == 0 || keyPair.PrivateKey == nil {
|
||||
return nil, errors.New("active key missing certificate or private key")
|
||||
}
|
||||
|
||||
ctx := dsig.NewDefaultSigningContext(dsig.TLSCertKeyStore(keyPair))
|
||||
ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("")
|
||||
if err := ctx.SetSignatureMethod(dsig.RSASHA256SignatureMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx.Hash = crypto.SHA256
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (i *IdP) signResponse(resp *saml.Response) error {
|
||||
if resp.Assertion == nil {
|
||||
return errors.New("response missing assertion")
|
||||
}
|
||||
|
||||
assertionCtx, err := i.signingContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
assertionEl := resp.Assertion.Element()
|
||||
signedAssertionEl, err := assertionCtx.SignEnveloped(assertionEl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign assertion: %w", err)
|
||||
}
|
||||
sigEl, err := lastChildElement(signedAssertionEl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign assertion: %w", err)
|
||||
}
|
||||
resp.Assertion.Signature = sigEl
|
||||
|
||||
responseCtx, err := i.signingContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseEl := resp.Element()
|
||||
signedResponseEl, err := responseCtx.SignEnveloped(responseEl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign response: %w", err)
|
||||
}
|
||||
sigEl, err = lastChildElement(signedResponseEl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign response: %w", err)
|
||||
}
|
||||
resp.Signature = sigEl
|
||||
return nil
|
||||
}
|
||||
|
||||
func MarshalSignedResponse(resp *saml.Response) ([]byte, error) {
|
||||
if resp == nil {
|
||||
return nil, errors.New("nil response")
|
||||
}
|
||||
if resp.Signature == nil {
|
||||
return nil, errors.New("response not signed")
|
||||
}
|
||||
|
||||
doc := etree.NewDocument()
|
||||
doc.SetRoot(resp.Element())
|
||||
return doc.WriteToBytes()
|
||||
}
|
||||
|
||||
func lastChildElement(parent *etree.Element) (*etree.Element, error) {
|
||||
children := parent.ChildElements()
|
||||
if len(children) == 0 {
|
||||
return nil, errors.New("no child elements found")
|
||||
}
|
||||
return children[len(children)-1], nil
|
||||
}
|
||||
|
||||
func newSAMLID() (string, error) {
|
||||
var b [20]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "_" + hex.EncodeToString(b[:]), nil
|
||||
}
|
||||
26
internal/saml/util.go
Normal file
26
internal/saml/util.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package saml
|
||||
|
||||
import "github.com/crewjam/saml"
|
||||
|
||||
const attrNameFormatURI = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
|
||||
|
||||
func toSAMLAttributes(attrs map[string][]string) []saml.Attribute {
|
||||
out := make([]saml.Attribute, 0, len(attrs))
|
||||
for name, vals := range attrs {
|
||||
vs := make([]saml.AttributeValue, 0, len(vals))
|
||||
for _, s := range vals {
|
||||
vs = append(vs, saml.AttributeValue{
|
||||
// explicitly mark string type so we never emit xsi:type=""
|
||||
Type: "xs:string",
|
||||
Value: s,
|
||||
})
|
||||
}
|
||||
out = append(out, saml.Attribute{
|
||||
FriendlyName: name,
|
||||
Name: name,
|
||||
NameFormat: attrNameFormatURI,
|
||||
Values: vs,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user