diff --git a/config/endpoint/endpoint.go b/config/endpoint/endpoint.go index b049052e..91fc712e 100644 --- a/config/endpoint/endpoint.go +++ b/config/endpoint/endpoint.go @@ -221,12 +221,12 @@ func (e *Endpoint) ValidateAndSetDefaults() error { e.Headers = make(map[string]string) } // Automatically add user agent header if there isn't one specified in the endpoint configuration - if _, userAgentHeaderExists := e.Headers[UserAgentHeader]; !userAgentHeaderExists { + if !hasHeader(e.Headers, UserAgentHeader) { e.Headers[UserAgentHeader] = GatusUserAgent } // Automatically add "Content-Type: application/json" header if there's no Content-Type set // and endpoint.GraphQL is set to true - if _, contentTypeHeaderExists := e.Headers[ContentTypeHeader]; !contentTypeHeaderExists && e.GraphQL { + if !hasHeader(e.Headers, ContentTypeHeader) && e.GraphQL { e.Headers[ContentTypeHeader] = "application/json" } if len(e.Conditions) == 0 { @@ -493,8 +493,8 @@ func (e *Endpoint) call(result *Result) { wsHeaders[k] = v } } - if _, exists := wsHeaders["User-Agent"]; !exists { - wsHeaders["User-Agent"] = GatusUserAgent + if !hasHeader(wsHeaders, UserAgentHeader) { + wsHeaders[UserAgentHeader] = GatusUserAgent } result.Connected, result.Body, err = client.QueryWebSocket(e.URL, e.getParsedBody(), wsHeaders, e.ClientConfig) if err != nil { @@ -582,7 +582,7 @@ func (e *Endpoint) buildHTTPRequest() *http.Request { request, _ := http.NewRequest(e.Method, e.URL, bodyBuffer) for k, v := range e.Headers { request.Header.Set(k, v) - if k == HostHeader { + if strings.EqualFold(k, HostHeader) { request.Host = v } } @@ -626,3 +626,13 @@ func (e *Endpoint) needsToRetrieveIP() bool { } return false } + +// hasHeader checks if a header exists in the map using a case-insensitive lookup +func hasHeader(headers map[string]string, name string) bool { + for k := range headers { + if strings.EqualFold(k, name) { + return true + } + } + return false +} diff --git a/config/endpoint/endpoint_test.go b/config/endpoint/endpoint_test.go index c6aee630..d9b60f75 100644 --- a/config/endpoint/endpoint_test.go +++ b/config/endpoint/endpoint_test.go @@ -21,6 +21,30 @@ import ( "github.com/TwiN/gatus/v5/test" ) +func TestHasHeader(t *testing.T) { + scenarios := []struct { + name string + headers map[string]string + lookup string + expected bool + }{ + {name: "exact-match", headers: map[string]string{"User-Agent": "test"}, lookup: "User-Agent", expected: true}, + {name: "lowercase-lookup", headers: map[string]string{"User-Agent": "test"}, lookup: "user-agent", expected: true}, + {name: "uppercase-lookup", headers: map[string]string{"user-agent": "test"}, lookup: "USER-AGENT", expected: true}, + {name: "mixed-case", headers: map[string]string{"UsEr-AgEnT": "test"}, lookup: "uSeR-aGeNt", expected: true}, + {name: "not-found", headers: map[string]string{"Content-Type": "test"}, lookup: "User-Agent", expected: false}, + {name: "empty-headers", headers: map[string]string{}, lookup: "User-Agent", expected: false}, + {name: "nil-headers", headers: nil, lookup: "User-Agent", expected: false}, + } + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + if result := hasHeader(scenario.headers, scenario.lookup); result != scenario.expected { + t.Errorf("expected %v, got %v", scenario.expected, result) + } + }) + } +} + func TestEndpoint(t *testing.T) { defer client.InjectHTTPClient(nil) scenarios := []struct { @@ -722,6 +746,76 @@ func TestEndpoint_buildHTTPRequestWithHostHeader(t *testing.T) { } } +func TestEndpoint_buildHTTPRequestWithLowercaseUserAgent(t *testing.T) { + condition := Condition("[STATUS] == 200") + endpoint := Endpoint{ + Name: "website-health", + URL: "https://twin.sh/health", + Conditions: []Condition{condition}, + Headers: map[string]string{ + "user-agent": "CustomAgent/1.0", + }, + } + err := endpoint.ValidateAndSetDefaults() + if err != nil { + t.Fatal("did not expect an error, got", err) + } + if _, exists := endpoint.Headers[UserAgentHeader]; exists { + t.Error("User-Agent header should not have been added since user-agent was already specified") + } + request := endpoint.buildHTTPRequest() + if userAgent := request.Header.Get("User-Agent"); userAgent != "CustomAgent/1.0" { + t.Errorf("request.Header.Get(User-Agent) should've been CustomAgent/1.0, but was %s", userAgent) + } +} + +func TestEndpoint_buildHTTPRequestWithLowercaseContentType(t *testing.T) { + condition := Condition("[STATUS] == 200") + endpoint := Endpoint{ + Name: "website-graphql", + URL: "https://twin.sh/graphql", + Method: "POST", + Conditions: []Condition{condition}, + GraphQL: true, + Headers: map[string]string{ + "content-type": "application/graphql", + }, + Body: `{ users { id } }`, + } + err := endpoint.ValidateAndSetDefaults() + if err != nil { + t.Fatal("did not expect an error, got", err) + } + if _, exists := endpoint.Headers[ContentTypeHeader]; exists { + t.Error("Content-Type header should not have been added since content-type was already specified") + } + request := endpoint.buildHTTPRequest() + if contentType := request.Header.Get("Content-Type"); contentType != "application/graphql" { + t.Errorf("request.Header.Get(Content-Type) should've been application/graphql, but was %s", contentType) + } +} + +func TestEndpoint_buildHTTPRequestWithLowercaseHostHeader(t *testing.T) { + condition := Condition("[STATUS] == 200") + endpoint := Endpoint{ + Name: "website-health", + URL: "https://twin.sh/health", + Method: "POST", + Conditions: []Condition{condition}, + Headers: map[string]string{ + "host": "example.com", + }, + } + err := endpoint.ValidateAndSetDefaults() + if err != nil { + t.Fatal("did not expect an error, got", err) + } + request := endpoint.buildHTTPRequest() + if request.Host != "example.com" { + t.Error("request.Host should've been example.com, but was", request.Host) + } +} + func TestEndpoint_buildHTTPRequestWithGraphQLEnabled(t *testing.T) { condition := Condition("[STATUS] == 200") endpoint := Endpoint{