From 457e4bb58e0fd3c10e16ecaabfeb04fb2c41bac5 Mon Sep 17 00:00:00 2001 From: Tom Wilkie Date: Wed, 5 Sep 2018 15:50:50 +0200 Subject: [PATCH] Limit the number of samples remote read can return. (#4532) * Limit the number of samples remote read can return. - Return 413 entity too large. - Limit can be set be a flag. Allow 0 to mean no limit. - Include limit in error message. - Set default limit to 50M (* 16 bytes = 800MB). Signed-off-by: Tom Wilkie --- cmd/prometheus/main.go | 3 +++ storage/remote/codec.go | 23 ++++++++++++++++++++++- storage/remote/read_test.go | 2 +- web/api/v1/api.go | 30 +++++++++++++++++++----------- web/api/v1/api_test.go | 7 ++++++- web/web.go | 2 ++ 6 files changed, 53 insertions(+), 14 deletions(-) diff --git a/cmd/prometheus/main.go b/cmd/prometheus/main.go index 5b3048cb14..bc50e678d6 100644 --- a/cmd/prometheus/main.go +++ b/cmd/prometheus/main.go @@ -168,6 +168,9 @@ func main() { a.Flag("storage.remote.flush-deadline", "How long to wait flushing sample on shutdown or config reload."). Default("1m").PlaceHolder("").SetValue(&cfg.RemoteFlushDeadline) + a.Flag("storage.remote.read-sample-limit", "Maximum overall number of samples to return via the remote read interface, in a single query. 0 means no limit."). + Default("5e7").IntVar(&cfg.web.RemoteReadLimit) + a.Flag("rules.alert.for-outage-tolerance", "Max time to tolerate prometheus outage for restoring 'for' state of alert."). Default("1h").SetValue(&cfg.outageTolerance) diff --git a/storage/remote/codec.go b/storage/remote/codec.go index 7a2a9c5bc1..6cde53cf2e 100644 --- a/storage/remote/codec.go +++ b/storage/remote/codec.go @@ -32,6 +32,19 @@ import ( // decodeReadLimit is the maximum size of a read request body in bytes. const decodeReadLimit = 32 * 1024 * 1024 +type HTTPError struct { + msg string + status int +} + +func (e HTTPError) Error() string { + return e.msg +} + +func (e HTTPError) Status() int { + return e.status +} + // DecodeReadRequest reads a remote.Request from a http.Request. func DecodeReadRequest(r *http.Request) (*prompb.ReadRequest, error) { compressed, err := ioutil.ReadAll(io.LimitReader(r.Body, decodeReadLimit)) @@ -134,7 +147,8 @@ func FromQuery(req *prompb.Query) (int64, int64, []*labels.Matcher, *storage.Sel } // ToQueryResult builds a QueryResult proto. -func ToQueryResult(ss storage.SeriesSet) (*prompb.QueryResult, error) { +func ToQueryResult(ss storage.SeriesSet, sampleLimit int) (*prompb.QueryResult, error) { + numSamples := 0 resp := &prompb.QueryResult{} for ss.Next() { series := ss.At() @@ -142,6 +156,13 @@ func ToQueryResult(ss storage.SeriesSet) (*prompb.QueryResult, error) { samples := []*prompb.Sample{} for iter.Next() { + numSamples++ + if sampleLimit > 0 && numSamples > sampleLimit { + return nil, HTTPError{ + msg: fmt.Sprintf("exceeded sample limit (%d)", sampleLimit), + status: http.StatusBadRequest, + } + } ts, val := iter.At() samples = append(samples, &prompb.Sample{ Timestamp: ts, diff --git a/storage/remote/read_test.go b/storage/remote/read_test.go index 537f605a8c..7ada03912b 100644 --- a/storage/remote/read_test.go +++ b/storage/remote/read_test.go @@ -135,7 +135,7 @@ func TestSeriesSetFilter(t *testing.T) { for i, tc := range tests { filtered := newSeriesSetFilter(FromQueryResult(tc.in), tc.toRemove) - have, err := ToQueryResult(filtered) + have, err := ToQueryResult(filtered, 1e6) if err != nil { t.Fatal(err) } diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 12893b0289..14ec736523 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -131,9 +131,10 @@ type API struct { flagsMap map[string]string ready func(http.HandlerFunc) http.HandlerFunc - db func() *tsdb.DB - enableAdmin bool - logger log.Logger + db func() *tsdb.DB + enableAdmin bool + logger log.Logger + remoteReadLimit int } // NewAPI returns an initialized API type. @@ -149,19 +150,22 @@ func NewAPI( enableAdmin bool, logger log.Logger, rr rulesRetriever, + remoteReadLimit int, ) *API { return &API{ QueryEngine: qe, Queryable: q, targetRetriever: tr, alertmanagerRetriever: ar, - now: time.Now, - config: configFunc, - flagsMap: flagsMap, - ready: readyFunc, - db: db, - enableAdmin: enableAdmin, - rulesRetriever: rr, + + now: time.Now, + config: configFunc, + flagsMap: flagsMap, + ready: readyFunc, + db: db, + enableAdmin: enableAdmin, + rulesRetriever: rr, + remoteReadLimit: remoteReadLimit, } } @@ -793,8 +797,12 @@ func (api *API) remoteRead(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - resp.Results[i], err = remote.ToQueryResult(set) + resp.Results[i], err = remote.ToQueryResult(set, api.remoteReadLimit) if err != nil { + if httpErr, ok := err.(remote.HTTPError); ok { + http.Error(w, httpErr.Error(), httpErr.Status()) + return + } http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/web/api/v1/api_test.go b/web/api/v1/api_test.go index b9253a475b..cf174b6b4c 100644 --- a/web/api/v1/api_test.go +++ b/web/api/v1/api_test.go @@ -313,7 +313,7 @@ func setupRemote(s storage.Storage) *httptest.Server { http.Error(w, err.Error(), http.StatusInternalServerError) return } - resp.Results[i], err = remote.ToQueryResult(set) + resp.Results[i], err = remote.ToQueryResult(set, 1e6) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -833,6 +833,7 @@ func TestReadEndpoint(t *testing.T) { }, } }, + remoteReadLimit: 1e6, } // Encode the request. @@ -861,6 +862,10 @@ func TestReadEndpoint(t *testing.T) { recorder := httptest.NewRecorder() api.remoteRead(recorder, request) + if recorder.Code/100 != 2 { + t.Fatal(recorder.Code) + } + // Decode the response. compressed, err = ioutil.ReadAll(recorder.Result().Body) if err != nil { diff --git a/web/web.go b/web/web.go index a278092a4c..d41f62b61f 100644 --- a/web/web.go +++ b/web/web.go @@ -167,6 +167,7 @@ type Options struct { ConsoleLibrariesPath string EnableLifecycle bool EnableAdminAPI bool + RemoteReadLimit int } func instrumentHandler(handlerName string, handler http.HandlerFunc) http.HandlerFunc { @@ -227,6 +228,7 @@ func New(logger log.Logger, o *Options) *Handler { h.options.EnableAdminAPI, logger, h.ruleManager, + h.options.RemoteReadLimit, ) if o.RoutePrefix != "/" {