diff --git a/util/stats/query_stats.go b/util/stats/query_stats.go index 7a6c0f18ea..fc31445808 100644 --- a/util/stats/query_stats.go +++ b/util/stats/query_stats.go @@ -110,15 +110,25 @@ type querySamples struct { TotalQueryableSamples int `json:"totalQueryableSamples"` } -// QueryStats currently only holding query timings. -type QueryStats struct { +// BuiltinStats holds the statistics that Prometheus's core gathers. +type BuiltinStats struct { Timings queryTimings `json:"timings,omitempty"` Samples *querySamples `json:"samples,omitempty"` } +// QueryStats holds BuiltinStats and any other stats the particular +// implementation wants to collect. +type QueryStats interface { + Builtin() BuiltinStats +} + +func (s *BuiltinStats) Builtin() BuiltinStats { + return *s +} + // NewQueryStats makes a QueryStats struct with all QueryTimings found in the // given TimerGroup. -func NewQueryStats(s *Statistics) *QueryStats { +func NewQueryStats(s *Statistics) QueryStats { var ( qt queryTimings samples *querySamples @@ -150,7 +160,7 @@ func NewQueryStats(s *Statistics) *QueryStats { samples.TotalQueryableSamplesPerStep = sp.totalSamplesPerStepPoints() } - qs := QueryStats{Timings: qt, Samples: samples} + qs := BuiltinStats{Timings: qt, Samples: samples} return &qs } diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 67e01a777d..2c50297f1d 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -104,6 +104,15 @@ type RulesRetriever interface { AlertingRules() []*rules.AlertingRule } +type StatsRenderer func(context.Context, *stats.Statistics, string) stats.QueryStats + +func defaultStatsRenderer(ctx context.Context, s *stats.Statistics, param string) stats.QueryStats { + if param != "" { + return stats.NewQueryStats(s) + } + return nil +} + // PrometheusVersion contains build information about Prometheus. type PrometheusVersion struct { Version string `json:"version"` @@ -177,15 +186,16 @@ type API struct { ready func(http.HandlerFunc) http.HandlerFunc globalURLOptions GlobalURLOptions - db TSDBAdminStats - dbDir string - enableAdmin bool - logger log.Logger - CORSOrigin *regexp.Regexp - buildInfo *PrometheusVersion - runtimeInfo func() (RuntimeInfo, error) - gatherer prometheus.Gatherer - isAgent bool + db TSDBAdminStats + dbDir string + enableAdmin bool + logger log.Logger + CORSOrigin *regexp.Regexp + buildInfo *PrometheusVersion + runtimeInfo func() (RuntimeInfo, error) + gatherer prometheus.Gatherer + isAgent bool + statsRenderer StatsRenderer remoteWriteHandler http.Handler remoteReadHandler http.Handler @@ -222,6 +232,7 @@ func NewAPI( buildInfo *PrometheusVersion, gatherer prometheus.Gatherer, registerer prometheus.Registerer, + statsRenderer StatsRenderer, ) *API { a := &API{ QueryEngine: qe, @@ -246,10 +257,15 @@ func NewAPI( buildInfo: buildInfo, gatherer: gatherer, isAgent: isAgent, + statsRenderer: defaultStatsRenderer, remoteReadHandler: remote.NewReadHandler(logger, registerer, q, configFunc, remoteReadSampleLimit, remoteReadConcurrencyLimit, remoteReadMaxBytesInFrame), } + if statsRenderer != nil { + a.statsRenderer = statsRenderer + } + if ap != nil { a.remoteWriteHandler = remote.NewWriteHandler(logger, ap) } @@ -344,9 +360,9 @@ func (api *API) Register(r *route.Router) { } type queryData struct { - ResultType parser.ValueType `json:"resultType"` - Result parser.Value `json:"result"` - Stats *stats.QueryStats `json:"stats,omitempty"` + ResultType parser.ValueType `json:"resultType"` + Result parser.Value `json:"result"` + Stats stats.QueryStats `json:"stats,omitempty"` } func invalidParamError(err error, parameter string) apiFuncResult { @@ -399,10 +415,11 @@ func (api *API) query(r *http.Request) (result apiFuncResult) { } // Optional stats field in response if parameter "stats" is not empty. - var qs *stats.QueryStats - if r.FormValue("stats") != "" { - qs = stats.NewQueryStats(qry.Stats()) + sr := api.statsRenderer + if sr == nil { + sr = defaultStatsRenderer } + qs := sr(ctx, qry.Stats(), r.FormValue("stats")) return apiFuncResult{&queryData{ ResultType: res.Value.Type(), @@ -480,10 +497,11 @@ func (api *API) queryRange(r *http.Request) (result apiFuncResult) { } // Optional stats field in response if parameter "stats" is not empty. - var qs *stats.QueryStats - if r.FormValue("stats") != "" { - qs = stats.NewQueryStats(qry.Stats()) + sr := api.statsRenderer + if sr == nil { + sr = defaultStatsRenderer } + qs := sr(ctx, qry.Stats(), r.FormValue("stats")) return apiFuncResult{&queryData{ ResultType: res.Value.Type(), diff --git a/web/api/v1/api_test.go b/web/api/v1/api_test.go index 6db0495929..55155f8bcf 100644 --- a/web/api/v1/api_test.go +++ b/web/api/v1/api_test.go @@ -30,6 +30,8 @@ import ( "testing" "time" + "github.com/prometheus/prometheus/util/stats" + "github.com/go-kit/log" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -522,6 +524,14 @@ func TestLabelNames(t *testing.T) { } } +type testStats struct { + Custom string `json:"custom"` +} + +func (testStats) Builtin() (_ stats.BuiltinStats) { + return +} + func TestStats(t *testing.T) { suite, err := promql.NewTest(t, ``) require.NoError(t, err) @@ -535,7 +545,7 @@ func TestStats(t *testing.T) { return time.Unix(123, 0) }, } - request := func(method string, param string) (*http.Request, error) { + request := func(method, param string) (*http.Request, error) { u, err := url.Parse("http://example.com") require.NoError(t, err) q := u.Query() @@ -555,6 +565,7 @@ func TestStats(t *testing.T) { for _, tc := range []struct { name string + renderer StatsRenderer param string expected func(*testing.T, interface{}) }{ @@ -574,7 +585,7 @@ func TestStats(t *testing.T) { require.IsType(t, i, &queryData{}) qd := i.(*queryData) require.NotNil(t, qd.Stats) - qs := qd.Stats + qs := qd.Stats.Builtin() require.NotNil(t, qs.Timings) require.Greater(t, qs.Timings.EvalTotalTime, float64(0)) require.NotNil(t, qs.Samples) @@ -589,7 +600,7 @@ func TestStats(t *testing.T) { require.IsType(t, i, &queryData{}) qd := i.(*queryData) require.NotNil(t, qd.Stats) - qs := qd.Stats + qs := qd.Stats.Builtin() require.NotNil(t, qs.Timings) require.Greater(t, qs.Timings.EvalTotalTime, float64(0)) require.NotNil(t, qs.Samples) @@ -597,8 +608,30 @@ func TestStats(t *testing.T) { require.NotNil(t, qs.Samples.TotalQueryableSamplesPerStep) }, }, + { + name: "custom handler with known value", + renderer: func(ctx context.Context, s *stats.Statistics, p string) stats.QueryStats { + if p == "known" { + return testStats{"Custom Value"} + } + return nil + }, + param: "known", + expected: func(t *testing.T, i interface{}) { + require.IsType(t, i, &queryData{}) + qd := i.(*queryData) + require.NotNil(t, qd.Stats) + j, err := json.Marshal(qd.Stats) + require.NoError(t, err) + require.JSONEq(t, string(j), `{"custom":"Custom Value"}`) + }, + }, } { t.Run(tc.name, func(t *testing.T) { + before := api.statsRenderer + defer func() { api.statsRenderer = before }() + api.statsRenderer = tc.renderer + for _, method := range []string{http.MethodGet, http.MethodPost} { ctx := context.Background() req, err := request(method, tc.param) diff --git a/web/web.go b/web/web.go index 47d76590ec..0574549e5a 100644 --- a/web/web.go +++ b/web/web.go @@ -340,6 +340,7 @@ func New(logger log.Logger, o *Options) *Handler { h.versionInfo, o.Gatherer, o.Registerer, + nil, ) if o.RoutePrefix != "/" {