mirror of
https://github.com/TwiN/gatus.git
synced 2026-02-16 12:41:12 +00:00
feat(client): Add support for SSH tunneling (#1298)
* feat(client): Add support for SSH tunneling * Fix test
This commit is contained in:
157
config/tunneling/sshtunnel/sshtunnel.go
Normal file
157
config/tunneling/sshtunnel/sshtunnel.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package sshtunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Config represents the configuration for an SSH tunnel
|
||||
type Config struct {
|
||||
Type string `yaml:"type"`
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port,omitempty"`
|
||||
Username string `yaml:"username"`
|
||||
PrivateKey string `yaml:"private-key,omitempty"`
|
||||
Password string `yaml:"password,omitempty"`
|
||||
}
|
||||
|
||||
// ValidateAndSetDefaults validates the SSH tunnel configuration and sets defaults
|
||||
func (c *Config) ValidateAndSetDefaults() error {
|
||||
if c.Type != "SSH" {
|
||||
return fmt.Errorf("unsupported tunnel type: %s", c.Type)
|
||||
}
|
||||
if c.Host == "" {
|
||||
return fmt.Errorf("host is required")
|
||||
}
|
||||
if c.Username == "" {
|
||||
return fmt.Errorf("username is required")
|
||||
}
|
||||
if c.PrivateKey == "" && c.Password == "" {
|
||||
return fmt.Errorf("either private-key or password is required")
|
||||
}
|
||||
if c.Port == 0 {
|
||||
c.Port = 22
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHTunnel represents an SSH tunnel connection
|
||||
type SSHTunnel struct {
|
||||
config *Config
|
||||
mu sync.RWMutex
|
||||
client *ssh.Client
|
||||
|
||||
// Cached authentication methods to avoid reparsing private keys
|
||||
authMethods []ssh.AuthMethod
|
||||
}
|
||||
|
||||
// New creates a new SSH tunnel with the given configuration
|
||||
func New(config *Config) *SSHTunnel {
|
||||
tunnel := &SSHTunnel{
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Parse authentication methods once during initialization to avoid
|
||||
// expensive cryptographic operations on every connection attempt
|
||||
if config.PrivateKey != "" {
|
||||
if signer, err := ssh.ParsePrivateKey([]byte(config.PrivateKey)); err == nil {
|
||||
tunnel.authMethods = []ssh.AuthMethod{ssh.PublicKeys(signer)}
|
||||
}
|
||||
// Note: We don't return error here to maintain backward compatibility.
|
||||
// Invalid keys will be caught during first connection attempt.
|
||||
} else if config.Password != "" {
|
||||
tunnel.authMethods = []ssh.AuthMethod{ssh.Password(config.Password)}
|
||||
}
|
||||
|
||||
return tunnel
|
||||
}
|
||||
|
||||
// Connect establishes the SSH connection
|
||||
func (t *SSHTunnel) Connect() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.connectUnsafe()
|
||||
}
|
||||
|
||||
// connectUnsafe establishes the SSH connection without acquiring locks
|
||||
// Must be called with t.mu.Lock() already held
|
||||
func (t *SSHTunnel) connectUnsafe() error {
|
||||
// Use cached authentication methods to avoid expensive crypto operations
|
||||
if len(t.authMethods) == 0 {
|
||||
return fmt.Errorf("no authentication method available")
|
||||
}
|
||||
config := &ssh.ClientConfig{
|
||||
User: t.config.Username,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Skip host key verification
|
||||
Auth: t.authMethods, // Use pre-parsed authentication
|
||||
}
|
||||
// Connect to SSH server
|
||||
addr := fmt.Sprintf("%s:%d", t.config.Host, t.config.Port)
|
||||
client, err := ssh.Dial("tcp", addr, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SSH connection failed: %w", err)
|
||||
}
|
||||
t.client = client
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the SSH connection
|
||||
func (t *SSHTunnel) Close() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.client != nil {
|
||||
err := t.client.Close()
|
||||
t.client = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dial creates a connection through the SSH tunnel
|
||||
func (t *SSHTunnel) Dial(network, addr string) (net.Conn, error) {
|
||||
t.mu.RLock()
|
||||
client := t.client
|
||||
t.mu.RUnlock()
|
||||
// Ensure we have an SSH connection
|
||||
if client == nil {
|
||||
// Use write lock to prevent race condition during connection
|
||||
t.mu.Lock()
|
||||
// Double-check client after acquiring lock
|
||||
if t.client == nil {
|
||||
if err := t.connectUnsafe(); err != nil {
|
||||
t.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
client = t.client
|
||||
t.mu.Unlock()
|
||||
}
|
||||
// Create connection through SSH tunnel
|
||||
conn, err := client.Dial(network, addr)
|
||||
if err != nil {
|
||||
// Close stale connection before retry to prevent leak
|
||||
t.mu.Lock()
|
||||
if t.client != nil {
|
||||
t.client.Close()
|
||||
t.client = nil
|
||||
}
|
||||
t.mu.Unlock()
|
||||
// Retry once - connection might be stale
|
||||
if connErr := t.Connect(); connErr != nil {
|
||||
return nil, fmt.Errorf("SSH tunnel dial failed: %w (retry failed: %v)", err, connErr)
|
||||
}
|
||||
t.mu.RLock()
|
||||
client = t.client
|
||||
t.mu.RUnlock()
|
||||
conn, err = client.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH tunnel dial failed after retry: %w", err)
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
158
config/tunneling/sshtunnel/sshtunnel_test.go
Normal file
158
config/tunneling/sshtunnel/sshtunnel_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package sshtunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfig_ValidateAndSetDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid SSH config with private key",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid SSH config with password",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid SSH config with custom port",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Port: 2222,
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sets default port 22",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
config: &Config{
|
||||
Type: "INVALID",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "unsupported tunnel type: INVALID",
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "host is required",
|
||||
},
|
||||
{
|
||||
name: "missing username",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Password: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "username is required",
|
||||
},
|
||||
{
|
||||
name: "missing authentication",
|
||||
config: &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "either private-key or password is required",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalPort := tt.config.Port
|
||||
err := tt.config.ValidateAndSetDefaults()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("ValidateAndSetDefaults() expected error but got none")
|
||||
return
|
||||
}
|
||||
if err.Error() != tt.errMsg {
|
||||
t.Errorf("ValidateAndSetDefaults() error = %v, want %v", err.Error(), tt.errMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("ValidateAndSetDefaults() unexpected error = %v", err)
|
||||
return
|
||||
}
|
||||
// Check that default port is set
|
||||
if originalPort == 0 && tt.config.Port != 22 {
|
||||
t.Errorf("ValidateAndSetDefaults() expected default port 22, got %d", tt.config.Port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
config := &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
}
|
||||
tunnel := New(config)
|
||||
if tunnel == nil {
|
||||
t.Error("New() returned nil")
|
||||
return
|
||||
}
|
||||
if tunnel.config != config {
|
||||
t.Error("New() did not set config correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHTunnel_Close(t *testing.T) {
|
||||
config := &Config{
|
||||
Type: "SSH",
|
||||
Host: "example.com",
|
||||
Username: "test",
|
||||
Password: "secret",
|
||||
}
|
||||
tunnel := New(config)
|
||||
// Test closing when no client is set
|
||||
err := tunnel.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() with no client returned error: %v", err)
|
||||
}
|
||||
// Test closing multiple times
|
||||
err = tunnel.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() called twice returned error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user