mirror of
https://github.com/TwiN/gatus.git
synced 2026-02-15 02:10:06 +00:00
fix(client): Switch websocket library (#1423)
* fix(websocket): switch to gorilla/websocket * fix(client): add missing t.Parallel() in tests --------- Co-authored-by: TwiN <twin@linux.com>
This commit is contained in:
@@ -21,13 +21,13 @@ import (
|
||||
"github.com/TwiN/gocache/v2"
|
||||
"github.com/TwiN/logr"
|
||||
"github.com/TwiN/whois"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/ishidawataru/sctp"
|
||||
"github.com/miekg/dns"
|
||||
ping "github.com/prometheus-community/pro-bing"
|
||||
"github.com/registrobr/rdap"
|
||||
"github.com/registrobr/rdap/protocol"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -394,48 +394,53 @@ func ShouldRunPingerAsPrivileged() bool {
|
||||
// QueryWebSocket opens a websocket connection, write `body` and return a message from the server
|
||||
func QueryWebSocket(address, body string, headers map[string]string, config *Config) (bool, []byte, error) {
|
||||
const (
|
||||
Origin = "http://localhost/"
|
||||
MaximumMessageSize = 1024 // in bytes
|
||||
Origin = "http://localhost/"
|
||||
)
|
||||
wsConfig, err := websocket.NewConfig(address, Origin)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("error configuring websocket connection: %w", err)
|
||||
}
|
||||
if headers != nil {
|
||||
if wsConfig.Header == nil {
|
||||
wsConfig.Header = make(http.Header)
|
||||
}
|
||||
for name, value := range headers {
|
||||
wsConfig.Header.Set(name, value)
|
||||
var (
|
||||
dialer = websocket.Dialer{
|
||||
EnableCompression: true,
|
||||
}
|
||||
wsHeaders = make(http.Header)
|
||||
)
|
||||
|
||||
wsHeaders.Set("Origin", Origin)
|
||||
for name, value := range headers {
|
||||
wsHeaders.Set(name, value)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if config != nil {
|
||||
wsConfig.Dialer = &net.Dialer{Timeout: config.Timeout}
|
||||
wsConfig.TlsConfig = &tls.Config{
|
||||
if config.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
dialer.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: config.Insecure,
|
||||
}
|
||||
if config.HasTLSConfig() && config.TLS.isValid() == nil {
|
||||
wsConfig.TlsConfig = configureTLS(wsConfig.TlsConfig, *config.TLS)
|
||||
dialer.TLSClientConfig = configureTLS(dialer.TLSClientConfig, *config.TLS)
|
||||
}
|
||||
}
|
||||
// Dial URL
|
||||
ws, err := websocket.DialConfig(wsConfig)
|
||||
ws, _, err := dialer.DialContext(ctx, address, wsHeaders)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("error dialing websocket: %w", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
body = parseLocalAddressPlaceholder(body, ws.LocalAddr())
|
||||
// Write message
|
||||
if _, err := ws.Write([]byte(body)); err != nil {
|
||||
if err := ws.WriteMessage(websocket.TextMessage, []byte(body)); err != nil {
|
||||
return false, nil, fmt.Errorf("error writing websocket body: %w", err)
|
||||
}
|
||||
// Read message
|
||||
var n int
|
||||
msg := make([]byte, MaximumMessageSize)
|
||||
if n, err = ws.Read(msg); err != nil {
|
||||
msgType, msg, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("error reading websocket message: %w", err)
|
||||
} else if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
return false, nil, fmt.Errorf("unexpected websocket message type: %d, expected %d or %d", msgType, websocket.TextMessage, websocket.BinaryMessage)
|
||||
}
|
||||
return true, msg[:n], nil
|
||||
return true, msg, nil
|
||||
}
|
||||
|
||||
func QueryDNS(queryType, queryName, url string) (connected bool, dnsRcode string, body []byte, err error) {
|
||||
|
||||
Reference in New Issue
Block a user