diff --git a/storage/remote/client.go b/storage/remote/client.go index 19f409b11..9e6a31c66 100644 --- a/storage/remote/client.go +++ b/storage/remote/client.go @@ -215,7 +215,7 @@ func (c *Client) Store(ctx context.Context, req []byte, attempt int) error { httpReq.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion1HeaderValue) } else { // Set the right header if we're using v1.1 remote write protocol - httpReq.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion11HeaderValue) + httpReq.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion20HeaderValue) } if attempt > 0 { diff --git a/storage/remote/write_handler.go b/storage/remote/write_handler.go index b09fd8d0a..77ac2b5c4 100644 --- a/storage/remote/write_handler.go +++ b/storage/remote/write_handler.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "net/http" + "strings" "github.com/prometheus/prometheus/model/labels" writev2 "github.com/prometheus/prometheus/prompb/write/v2" @@ -36,9 +37,66 @@ import ( const ( RemoteWriteVersionHeader = "X-Prometheus-Remote-Write-Version" RemoteWriteVersion1HeaderValue = "0.1.0" - RemoteWriteVersion11HeaderValue = "1.1" // TODO-RW11: Final value? + RemoteWriteVersion20HeaderValue = "2.0" ) +func RemoteWriteHeaderNameValues(rwFormat RemoteWriteFormat) map[string]string { + // Return the correct remote write header name/values based on provided rwFormat + ret := make(map[string]string, 1) + + switch rwFormat { + case Version1: + ret[RemoteWriteVersionHeader] = RemoteWriteVersion1HeaderValue + case Version2: + // We need to add the supported protocol definitions in order: + tuples := make([]string, 0, 2) + // Add 2.0;snappy; + tuples = append(tuples, RemoteWriteVersion20HeaderValue+";snappy;") + // Add default 0.1.0 + tuples = append(tuples, RemoteWriteVersion1HeaderValue) + ret[RemoteWriteVersionHeader] = strings.Join(tuples, ",") + } + return ret +} + +type writeHeadHandler struct { + logger log.Logger + + remoteWrite20HeadRequests prometheus.Counter + + // Experimental feature, new remote write proto format + // The handler will accept the new format, but it can still accept the old one + rwFormat RemoteWriteFormat +} + +func NewWriteHeadHandler(logger log.Logger, reg prometheus.Registerer, rwFormat RemoteWriteFormat) http.Handler { + h := &writeHeadHandler{ + logger: logger, + rwFormat: rwFormat, + remoteWrite20HeadRequests: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "prometheus", + Subsystem: "api", + Name: "remote_write_20_head_requests", + Help: "The number of remote write 2.0 head requests.", + }), + } + if reg != nil { + reg.MustRegister(h.remoteWrite20HeadRequests) + } + return h +} + +func (h *writeHeadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Send a response to the HEAD request based on the format supported + + // Add appropriate header values for the specific rwFormat + for hName, hValue := range RemoteWriteHeaderNameValues(h.rwFormat) { + w.Header().Set(hName, hValue) + } + + w.WriteHeader(http.StatusOK) +} + type writeHandler struct { logger log.Logger appendable storage.Appendable @@ -47,7 +105,6 @@ type writeHandler struct { // Experimental feature, new remote write proto format // The handler will accept the new format, but it can still accept the old one - // TODO: this should eventually be via content negotiation rwFormat RemoteWriteFormat } @@ -76,25 +133,59 @@ func (h *writeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var req *prompb.WriteRequest var reqMinStr *writev2.WriteRequest - // TODO: this should eventually be done via content negotiation/looking at the header - switch h.rwFormat { - case Version1: - req, err = DecodeWriteRequest(r.Body) - case Version2: - reqMinStr, err = DecodeMinimizedWriteRequestStr(r.Body) + // Set the header(s) in the response based on the rwFormat the server supports + for hName, hValue := range RemoteWriteHeaderNameValues(h.rwFormat) { + w.Header().Set(hName, hValue) } - if err != nil { - level.Error(h.logger).Log("msg", "Error decoding remote write request", "err", err.Error()) - http.Error(w, err.Error(), http.StatusBadRequest) + // Parse the headers to work out how to handle this + contentEncoding := r.Header.Get("Content-Encoding") + protoVer := r.Header.Get(RemoteWriteVersionHeader) + + if protoVer == "" { + // No header provided, assume 0.1.0 as everything that relies on later + // features MUST supply the correct headers + protoVer = RemoteWriteVersion1HeaderValue + } else if protoVer == RemoteWriteVersion20HeaderValue { + // This is a 2.0 request, woo + } else { + // We have a version in the header but it is not one we recognise + // TODO - make a proper error for this + level.Error(h.logger).Log("msg", "Error decoding remote write request", "err", "Unknown remote write version in headers", "ver", protoVer) + http.Error(w, "Unknown remote write version in headers", http.StatusBadRequest) return } - // TODO: this should eventually be done detecting the format version above - switch h.rwFormat { - case Version1: + // At this point we are happy with the version but need to check the encoding + if protoVer == RemoteWriteVersion1HeaderValue { + // If the version is 0.1.0 then we automatically assume Snappy encoding + // so we check that it is "snappy" if specified or unspecified + if contentEncoding != "" && contentEncoding != "snappy" { + level.Error(h.logger).Log("msg", "Error determining remote write request encoding", "contentEncoding", contentEncoding) + http.Error(w, "Error determining remote write encoding", http.StatusBadRequest) + return + } + req, err = DecodeWriteRequest(r.Body) + if err != nil { + level.Error(h.logger).Log("msg", "Error decoding remote write request", "err", err.Error()) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } err = h.write(r.Context(), req) - case Version2: + } else { + // 2.0 request + // MUST be snappy encoded + if contentEncoding != "snappy" { + level.Error(h.logger).Log("msg", "Error determining remote write request encoding", "contentEncoding", contentEncoding) + http.Error(w, "Error determining remote write encoding", http.StatusNotAcceptable) + return + } + reqMinStr, err = DecodeMinimizedWriteRequestStr(r.Body) + if err != nil { + level.Error(h.logger).Log("msg", "Error decoding remote write request", "err", err.Error()) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } err = h.writeMinStr(r.Context(), reqMinStr) } diff --git a/storage/remote/write_handler_test.go b/storage/remote/write_handler_test.go index f30154e97..3e1be395e 100644 --- a/storage/remote/write_handler_test.go +++ b/storage/remote/write_handler_test.go @@ -37,6 +37,86 @@ import ( "github.com/prometheus/prometheus/tsdb" ) +func TestRemoteWriteHeadHandler(t *testing.T) { + handler := NewWriteHeadHandler(log.NewNopLogger(), nil, Version2) + + req, err := http.NewRequest(http.MethodHead, "", nil) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + resp := recorder.Result() + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Check header is expected value + protHeader := resp.Header.Get(RemoteWriteVersionHeader) + require.Equal(t, protHeader, "2.0;snappy;,0.1.0") +} + +func TestRemoteWriteHandlerMinimizedMissingContentEncoding(t *testing.T) { + // Send a v2 request without a "Content-Encoding:" header -> 406 + buf, _, err := buildMinimizedWriteRequestStr(writeRequestMinimizedFixture.Timeseries, writeRequestMinimizedFixture.Symbols, nil, nil) + require.NoError(t, err) + + req, err := http.NewRequest("", "", bytes.NewReader(buf)) + req.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion20HeaderValue) + // Do not provide "Content-Encoding: snappy" header + // req.Header.Set("Content-Encoding", "snappy") + require.NoError(t, err) + + appendable := &mockAppendable{} + handler := NewWriteHandler(log.NewNopLogger(), nil, appendable, Version2) + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + resp := recorder.Result() + // Should give us a 406 + require.Equal(t, http.StatusNotAcceptable, resp.StatusCode) +} + +func TestRemoteWriteHandlerInvalidCompression(t *testing.T) { + // Send a v2 request without an unhandled compression scheme -> 406 + buf, _, err := buildMinimizedWriteRequestStr(writeRequestMinimizedFixture.Timeseries, writeRequestMinimizedFixture.Symbols, nil, nil) + require.NoError(t, err) + + req, err := http.NewRequest("", "", bytes.NewReader(buf)) + req.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion20HeaderValue) + req.Header.Set("Content-Encoding", "zstd") + require.NoError(t, err) + + appendable := &mockAppendable{} + handler := NewWriteHandler(log.NewNopLogger(), nil, appendable, Version2) + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + resp := recorder.Result() + // Expect a 406 + require.Equal(t, http.StatusNotAcceptable, resp.StatusCode) +} + +func TestRemoteWriteHandlerInvalidVersion(t *testing.T) { + // Send a protocol version number that isn't recognised/supported -> 400 + buf, _, err := buildMinimizedWriteRequestStr(writeRequestMinimizedFixture.Timeseries, writeRequestMinimizedFixture.Symbols, nil, nil) + require.NoError(t, err) + + req, err := http.NewRequest("", "", bytes.NewReader(buf)) + req.Header.Set(RemoteWriteVersionHeader, "0.3.0") + require.NoError(t, err) + + appendable := &mockAppendable{} + handler := NewWriteHandler(log.NewNopLogger(), nil, appendable, Version2) + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + resp := recorder.Result() + // Expect a 400 BadRequest + require.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + func TestRemoteWriteHandler(t *testing.T) { buf, _, err := buildWriteRequest(writeRequestFixture.Timeseries, nil, nil, nil) require.NoError(t, err) @@ -45,7 +125,6 @@ func TestRemoteWriteHandler(t *testing.T) { require.NoError(t, err) appendable := &mockAppendable{} - // TODO: test with other proto format(s) handler := NewWriteHandler(log.NewNopLogger(), nil, appendable, Version1) recorder := httptest.NewRecorder() @@ -54,6 +133,10 @@ func TestRemoteWriteHandler(t *testing.T) { resp := recorder.Result() require.Equal(t, http.StatusNoContent, resp.StatusCode) + // Check header is expected value + protHeader := resp.Header.Get(RemoteWriteVersionHeader) + require.Equal(t, protHeader, "0.1.0") + i := 0 j := 0 k := 0 @@ -89,12 +172,13 @@ func TestRemoteWriteHandlerMinimizedFormat(t *testing.T) { require.NoError(t, err) req, err := http.NewRequest("", "", bytes.NewReader(buf)) - req.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion11HeaderValue) + req.Header.Set(RemoteWriteVersionHeader, RemoteWriteVersion20HeaderValue) + // Must provide "Content-Encoding: snappy" header + req.Header.Set("Content-Encoding", "snappy") require.NoError(t, err) appendable := &mockAppendable{} - // TODO: test with other proto format(s) - handler := NewWriteHandler(nil, nil, appendable, Version2) + handler := NewWriteHandler(log.NewNopLogger(), nil, appendable, Version2) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, req) @@ -102,6 +186,10 @@ func TestRemoteWriteHandlerMinimizedFormat(t *testing.T) { resp := recorder.Result() require.Equal(t, http.StatusNoContent, resp.StatusCode) + // Check header is expected value + protHeader := resp.Header.Get(RemoteWriteVersionHeader) + require.Equal(t, protHeader, "2.0;snappy;,0.1.0") + i := 0 j := 0 k := 0 diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 0288eda02..a1412296f 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -217,9 +217,10 @@ type API struct { isAgent bool statsRenderer StatsRenderer - remoteWriteHandler http.Handler - remoteReadHandler http.Handler - otlpWriteHandler http.Handler + remoteWriteHeadHandler http.Handler + remoteWriteHandler http.Handler + remoteReadHandler http.Handler + otlpWriteHandler http.Handler codecs []Codec } @@ -297,6 +298,7 @@ func NewAPI( if rwEnabled { a.remoteWriteHandler = remote.NewWriteHandler(logger, registerer, ap, rwFormat) + a.remoteWriteHeadHandler = remote.NewWriteHeadHandler(logger, registerer, rwFormat) } if otlpEnabled { a.otlpWriteHandler = remote.NewOTLPWriteHandler(logger, ap) @@ -393,6 +395,7 @@ func (api *API) Register(r *route.Router) { r.Get("/status/walreplay", api.serveWALReplayStatus) r.Post("/read", api.ready(api.remoteRead)) r.Post("/write", api.ready(api.remoteWrite)) + r.Head("/write", api.remoteWriteHead) r.Post("/otlp/v1/metrics", api.ready(api.otlpWrite)) r.Get("/alerts", wrapAgent(api.alerts)) @@ -1616,6 +1619,14 @@ func (api *API) remoteRead(w http.ResponseWriter, r *http.Request) { } } +func (api *API) remoteWriteHead(w http.ResponseWriter, r *http.Request) { + if api.remoteWriteHeadHandler != nil { + api.remoteWriteHeadHandler.ServeHTTP(w, r) + } else { + http.Error(w, "remote write receiver needs to be enabled with --web.enable-remote-write-receiver", http.StatusNotFound) + } +} + func (api *API) remoteWrite(w http.ResponseWriter, r *http.Request) { if api.remoteWriteHandler != nil { api.remoteWriteHandler.ServeHTTP(w, r)