initial commit

This commit is contained in:
Shamil Nunhuck
2025-11-08 10:18:19 +00:00
commit 920a79b2e9
25 changed files with 1523 additions and 0 deletions

70
internal/saml/idp.go Normal file
View File

@@ -0,0 +1,70 @@
package saml
import (
"encoding/base64"
"time"
"shamilnunhuck/saml-oidc-bridge/internal/config"
"shamilnunhuck/saml-oidc-bridge/internal/crypto"
"github.com/crewjam/saml"
)
type IdP struct {
keys *crypto.KeyStore
sec config.Security
entityID string
ssoURL string
}
func NewIdP(cfg *config.Config, ks *crypto.KeyStore) *IdP {
return &IdP{
keys: ks,
sec: cfg.Security,
entityID: cfg.Server.ExternalURL,
ssoURL: cfg.Server.ExternalURL + "/saml/sso",
}
}
func (i *IdP) Metadata() *saml.EntityDescriptor {
// we need to publish all certs, to allow safe rotation
keyDescriptors := []saml.KeyDescriptor{}
for _, der := range i.keys.AllCertsDER() {
keyDescriptors = append(keyDescriptors, saml.KeyDescriptor{
Use: "signing",
KeyInfo: saml.KeyInfo{
X509Data: saml.X509Data{
X509Certificates: []saml.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(der),
}},
},
},
})
}
entityDescriptor := &saml.EntityDescriptor{
EntityID: i.entityID,
// crewjam expects time.Time for ValidUntil and time.Duration for CacheDuration
ValidUntil: time.Now().UTC().Add(time.Duration(i.sec.MetadataValidUntilDays) * 24 * time.Hour),
CacheDuration: time.Duration(i.sec.MetadataCacheDurationSeconds) * time.Second,
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
SSODescriptor: saml.SSODescriptor{
RoleDescriptor: saml.RoleDescriptor{
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
KeyDescriptors: keyDescriptors,
},
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPPostBinding,
Location: i.ssoURL,
},
},
},
},
}
return entityDescriptor
}

204
internal/saml/sign.go Normal file
View File

@@ -0,0 +1,204 @@
package saml
import (
"crypto"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
"shamilnunhuck/saml-oidc-bridge/internal/config"
)
const (
subjectConfirmationMethodBearer = "urn:oasis:names:tc:SAML:2.0:cm:bearer"
nameIDFormatEntity = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"
authnContextPasswordProtectedTransport = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
defaultAssertionTTL = 5 * time.Minute
)
func (i *IdP) BuildResponse(sp config.SP, nameID string, attrs map[string][]string, inResponseTo string) (*saml.Response, error) {
now := saml.TimeNow()
ttl := time.Duration(i.sec.AssertionTTLSec) * time.Second
if ttl <= 0 {
ttl = defaultAssertionTTL
}
skew := time.Duration(i.sec.SkewSeconds) * time.Second
assertionID, err := newSAMLID()
if err != nil {
return nil, fmt.Errorf("generate assertion id: %w", err)
}
responseID, err := newSAMLID()
if err != nil {
return nil, fmt.Errorf("generate response id: %w", err)
}
audience := sp.Audience
if audience == "" {
audience = sp.EntityID
}
nameIDFormat := sp.NameIDFormat
if nameIDFormat == "" {
nameIDFormat = string(saml.UnspecifiedNameIDFormat)
}
assertion := &saml.Assertion{
ID: assertionID,
IssueInstant: now,
Version: "2.0",
Issuer: saml.Issuer{
Format: nameIDFormatEntity,
Value: i.entityID,
},
Subject: &saml.Subject{
NameID: &saml.NameID{
Format: nameIDFormat,
Value: nameID,
},
SubjectConfirmations: []saml.SubjectConfirmation{
{
Method: subjectConfirmationMethodBearer,
SubjectConfirmationData: &saml.SubjectConfirmationData{
InResponseTo: inResponseTo,
NotOnOrAfter: now.Add(ttl),
Recipient: sp.ACSURL,
},
},
},
},
Conditions: &saml.Conditions{
NotBefore: now.Add(-skew),
NotOnOrAfter: now.Add(ttl),
AudienceRestrictions: []saml.AudienceRestriction{
{
Audience: saml.Audience{Value: audience},
},
},
},
AuthnStatements: []saml.AuthnStatement{
{
AuthnInstant: now,
SessionIndex: responseID,
AuthnContext: saml.AuthnContext{
AuthnContextClassRef: &saml.AuthnContextClassRef{Value: authnContextPasswordProtectedTransport},
},
},
},
AttributeStatements: []saml.AttributeStatement{
{
Attributes: toSAMLAttributes(attrs),
},
},
}
resp := &saml.Response{
ID: responseID,
Version: "2.0",
IssueInstant: now,
InResponseTo: inResponseTo,
Destination: sp.ACSURL,
Issuer: &saml.Issuer{
Format: nameIDFormatEntity,
Value: i.entityID,
},
Status: saml.Status{
StatusCode: saml.StatusCode{Value: saml.StatusSuccess},
},
Assertion: assertion,
}
if err := i.signResponse(resp); err != nil {
return nil, err
}
return resp, nil
}
func (i *IdP) signingContext() (*dsig.SigningContext, error) {
keyPair := i.keys.Active()
if len(keyPair.Certificate) == 0 || keyPair.PrivateKey == nil {
return nil, errors.New("active key missing certificate or private key")
}
ctx := dsig.NewDefaultSigningContext(dsig.TLSCertKeyStore(keyPair))
ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("")
if err := ctx.SetSignatureMethod(dsig.RSASHA256SignatureMethod); err != nil {
return nil, err
}
ctx.Hash = crypto.SHA256
return ctx, nil
}
func (i *IdP) signResponse(resp *saml.Response) error {
if resp.Assertion == nil {
return errors.New("response missing assertion")
}
assertionCtx, err := i.signingContext()
if err != nil {
return err
}
assertionEl := resp.Assertion.Element()
signedAssertionEl, err := assertionCtx.SignEnveloped(assertionEl)
if err != nil {
return fmt.Errorf("sign assertion: %w", err)
}
sigEl, err := lastChildElement(signedAssertionEl)
if err != nil {
return fmt.Errorf("sign assertion: %w", err)
}
resp.Assertion.Signature = sigEl
responseCtx, err := i.signingContext()
if err != nil {
return err
}
responseEl := resp.Element()
signedResponseEl, err := responseCtx.SignEnveloped(responseEl)
if err != nil {
return fmt.Errorf("sign response: %w", err)
}
sigEl, err = lastChildElement(signedResponseEl)
if err != nil {
return fmt.Errorf("sign response: %w", err)
}
resp.Signature = sigEl
return nil
}
func MarshalSignedResponse(resp *saml.Response) ([]byte, error) {
if resp == nil {
return nil, errors.New("nil response")
}
if resp.Signature == nil {
return nil, errors.New("response not signed")
}
doc := etree.NewDocument()
doc.SetRoot(resp.Element())
return doc.WriteToBytes()
}
func lastChildElement(parent *etree.Element) (*etree.Element, error) {
children := parent.ChildElements()
if len(children) == 0 {
return nil, errors.New("no child elements found")
}
return children[len(children)-1], nil
}
func newSAMLID() (string, error) {
var b [20]byte
if _, err := rand.Read(b[:]); err != nil {
return "", err
}
return "_" + hex.EncodeToString(b[:]), nil
}

26
internal/saml/util.go Normal file
View File

@@ -0,0 +1,26 @@
package saml
import "github.com/crewjam/saml"
const attrNameFormatURI = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
func toSAMLAttributes(attrs map[string][]string) []saml.Attribute {
out := make([]saml.Attribute, 0, len(attrs))
for name, vals := range attrs {
vs := make([]saml.AttributeValue, 0, len(vals))
for _, s := range vals {
vs = append(vs, saml.AttributeValue{
// explicitly mark string type so we never emit xsi:type=""
Type: "xs:string",
Value: s,
})
}
out = append(out, saml.Attribute{
FriendlyName: name,
Name: name,
NameFormat: attrNameFormatURI,
Values: vs,
})
}
return out
}