1
0
mirror of https://github.com/TwiN/gatus.git synced 2026-02-04 15:14:43 +00:00

fix(client): Switch websocket library (#1423)

* fix(websocket): switch to gorilla/websocket

* fix(client): add missing t.Parallel() in tests

---------

Co-authored-by: TwiN <twin@linux.com>
This commit is contained in:
Yaroslav
2025-12-18 23:44:44 +00:00
committed by GitHub
parent 13184232d1
commit 15a8055617
4 changed files with 39 additions and 23 deletions

View File

@@ -21,13 +21,13 @@ import (
"github.com/TwiN/gocache/v2"
"github.com/TwiN/logr"
"github.com/TwiN/whois"
"github.com/gorilla/websocket"
"github.com/ishidawataru/sctp"
"github.com/miekg/dns"
ping "github.com/prometheus-community/pro-bing"
"github.com/registrobr/rdap"
"github.com/registrobr/rdap/protocol"
"golang.org/x/crypto/ssh"
"golang.org/x/net/websocket"
)
const (
@@ -394,48 +394,53 @@ func ShouldRunPingerAsPrivileged() bool {
// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) {
const (
Origin = "http://localhost/"
MaximumMessageSize = 1024 // in bytes
Origin = "http://localhost/"
)
wsConfig, err := websocket.NewConfig(address, Origin)
if err != nil {
return false, nil, fmt.Errorf("error configuring websocket connection: %w", err)
}
if headers != nil {
if wsConfig.Header == nil {
wsConfig.Header = make(http.Header)
}
for name, value := range headers {
wsConfig.Header.Set(name, value)
var (
dialer = websocket.Dialer{
EnableCompression: true,
}
wsHeaders = make(http.Header)
)
wsHeaders.Set("Origin", Origin)
for name, value := range headers {
wsHeaders.Set(name, value)
}
ctx := context.Background()
if config != nil {
wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout}
wsConfig.TlsConfig = &tls.Config{
if config.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, config.Timeout)
defer cancel()
}
dialer.TLSClientConfig = &tls.Config{
InsecureSkipVerify: config.Insecure,
}
if config.HasTLSConfig() && config.TLS.isValid() == nil {
wsConfig.TlsConfig = configureTLS(wsConfig.TlsConfig, *config.TLS)
dialer.TLSClientConfig = configureTLS(dialer.TLSClientConfig, *config.TLS)
}
}
// Dial URL
ws, err := websocket.DialConfig(wsConfig)
ws, _, err := dialer.DialContext(ctx, address, wsHeaders)
if err != nil {
return false, nil, fmt.Errorf("error dialing websocket: %w", err)
}
defer ws.Close()
body = parseLocalAddressPlaceholder(body, ws.LocalAddr())
// Write message
if _, err := ws.Write([]byte(body)); err != nil {
if err := ws.WriteMessage(websocket.TextMessage, []byte(body)); err != nil {
return false, nil, fmt.Errorf("error writing websocket body: %w", err)
}
// Read message
var n int
msg := make([]byte, MaximumMessageSize)
if n, err = ws.Read(msg); err != nil {
msgType, msg, err := ws.ReadMessage()
if err != nil {
return false, nil, fmt.Errorf("error reading websocket message: %w", err)
} else if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
return false, nil, fmt.Errorf("unexpected websocket message type: %d, expected %d or %d", msgType, websocket.TextMessage, websocket.BinaryMessage)
}
return true, msg[:n], nil
return true, msg, nil
}
func QueryDNS(queryType, queryName, url string) (connected bool, dnsRcode string, body []byte, err error) {

View File

@@ -17,6 +17,7 @@ import (
)
func TestGetHTTPClient(t *testing.T) {
t.Parallel()
cfg := &Config{
Insecure: false,
IgnoreRedirect: false,
@@ -42,6 +43,7 @@ func TestGetHTTPClient(t *testing.T) {
}
func TestRdapQuery(t *testing.T) {
t.Parallel()
if _, err := rdapQuery("1.1.1.1"); err == nil {
t.Error("expected an error due to the invalid domain type")
}
@@ -288,6 +290,7 @@ func TestCanPerformTLS(t *testing.T) {
}
func TestCanCreateConnection(t *testing.T) {
t.Parallel()
connected, _ := CanCreateNetworkConnection("tcp", "127.0.0.1", "", &Config{Timeout: 5 * time.Second})
if connected {
t.Error("should've failed, because there's no port in the address")
@@ -302,6 +305,7 @@ func TestCanCreateConnection(t *testing.T) {
// performs a Client Credentials OAuth2 flow and adds the obtained token as a `Authorization`
// header to all outgoing HTTP calls.
func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
t.Parallel()
defer InjectHTTPClient(nil)
oAuth2Config := &OAuth2Config{
ClientID: "00000000-0000-0000-0000-000000000000",
@@ -357,6 +361,7 @@ func TestHttpClientProvidesOAuth2BearerToken(t *testing.T) {
}
func TestQueryWebSocket(t *testing.T) {
t.Parallel()
_, _, err := QueryWebSocket("", "body", nil, &Config{Timeout: 2 * time.Second})
if err == nil {
t.Error("expected an error due to the address being invalid")
@@ -368,6 +373,7 @@ func TestQueryWebSocket(t *testing.T) {
}
func TestTlsRenegotiation(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
cfg TLSConfig
@@ -411,6 +417,7 @@ func TestTlsRenegotiation(t *testing.T) {
}
func TestQueryDNS(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
inputDNS dns.Config
@@ -540,6 +547,7 @@ func TestQueryDNS(t *testing.T) {
}
func TestCheckSSHBanner(t *testing.T) {
t.Parallel()
cfg := &Config{Timeout: 3}
t.Run("no-auth-ssh", func(t *testing.T) {
connected, status, err := CheckSSHBanner("tty.sdf.org", cfg)