115 lines
2.4 KiB
Go
115 lines
2.4 KiB
Go
package main
|
|
|
|
import (
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"github.com/fsnotify/fsnotify"
|
|
|
|
"shamilnunhuck/saml-oidc-bridge/internal/cli"
|
|
"shamilnunhuck/saml-oidc-bridge/internal/config"
|
|
"shamilnunhuck/saml-oidc-bridge/internal/crypto"
|
|
h "shamilnunhuck/saml-oidc-bridge/internal/http"
|
|
"shamilnunhuck/saml-oidc-bridge/internal/oidc"
|
|
"shamilnunhuck/saml-oidc-bridge/internal/saml"
|
|
)
|
|
|
|
type runtimeState struct {
|
|
mu sync.RWMutex
|
|
cfg *config.Config
|
|
ks *crypto.KeyStore
|
|
idp *saml.IdP
|
|
oidc *oidc.Client
|
|
}
|
|
|
|
func main() {
|
|
if len(os.Args) > 1 && os.Args[1] == "cert" {
|
|
if err := cli.RunCert(os.Args[2:]); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
cfgPath := os.Getenv("CONFIG_PATH")
|
|
if cfgPath == "" {
|
|
cfgPath = "example.config.yaml"
|
|
}
|
|
|
|
state := &runtimeState{}
|
|
load := func() {
|
|
cfg, err := config.Load(cfgPath)
|
|
if err != nil {
|
|
log.Fatalf("load config: %v", err)
|
|
}
|
|
if v := os.Getenv("OIDC_CLIENT_SECRET"); v != "" {
|
|
cfg.OIDC.ClientSecret = v
|
|
}
|
|
if err := cfg.Validate(); err != nil {
|
|
log.Fatalf("invalid config: %v", err)
|
|
}
|
|
|
|
ks, err := crypto.NewKeyStore(cfg.Crypto)
|
|
if err != nil {
|
|
log.Fatalf("keystore: %v", err)
|
|
}
|
|
idp := saml.NewIdP(cfg, ks)
|
|
oc, err := oidc.NewClient(cfg)
|
|
if err != nil {
|
|
log.Fatalf("oidc: %v", err)
|
|
}
|
|
|
|
state.mu.Lock()
|
|
state.cfg, state.ks, state.idp, state.oidc = cfg, ks, idp, oc
|
|
state.mu.Unlock()
|
|
log.Printf("loaded config; active signing key=%s", cfg.Crypto.ActiveKey)
|
|
}
|
|
load()
|
|
|
|
mux := http.NewServeMux()
|
|
h.Register(
|
|
mux,
|
|
func() *config.Config { state.mu.RLock(); defer state.mu.RUnlock(); return state.cfg },
|
|
func() *saml.IdP { state.mu.RLock(); defer state.mu.RUnlock(); return state.idp },
|
|
func() *oidc.Client { state.mu.RLock(); defer state.mu.RUnlock(); return state.oidc },
|
|
)
|
|
|
|
go func() {
|
|
log.Printf("listening on %s", state.cfg.Server.Listen)
|
|
log.Fatal(http.ListenAndServe(state.cfg.Server.Listen, mux))
|
|
}()
|
|
|
|
sigc := make(chan os.Signal, 1)
|
|
signal.Notify(sigc, syscall.SIGHUP)
|
|
go func() {
|
|
for range sigc {
|
|
load()
|
|
}
|
|
}()
|
|
|
|
w, err := fsnotify.NewWatcher()
|
|
if err == nil {
|
|
defer w.Close()
|
|
_ = w.Add(cfgPath)
|
|
go func() {
|
|
for {
|
|
select {
|
|
case e := <-w.Events:
|
|
if e.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 {
|
|
load()
|
|
}
|
|
case err := <-w.Errors:
|
|
if err != nil {
|
|
log.Printf("watch error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
select {}
|
|
}
|