diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 2da69ed67..4d990eb64 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -28,6 +28,8 @@ import ( "time" "unsafe" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" jsoniter "github.com/json-iterator/go" "github.com/prometheus/common/model" "github.com/prometheus/common/route" @@ -125,6 +127,7 @@ type API struct { db func() *tsdb.DB enableAdmin bool + logger log.Logger } // NewAPI returns an initialized API type. @@ -138,6 +141,7 @@ func NewAPI( readyFunc func(http.HandlerFunc) http.HandlerFunc, db func() *tsdb.DB, enableAdmin bool, + logger log.Logger, ) *API { return &API{ QueryEngine: qe, @@ -160,9 +164,9 @@ func (api *API) Register(r *route.Router) { setCORS(w) data, err, finalizer := f(r) if err != nil { - respondError(w, err, data) + api.respondError(w, err, data) } else if data != nil { - respond(w, data) + api.respond(w, data) } else { w.WriteHeader(http.StatusNoContent) } @@ -819,23 +823,38 @@ func mergeLabels(primary, secondary []*prompb.Label) []*prompb.Label { return result } -func respond(w http.ResponseWriter, data interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - +func (api *API) respond(w http.ResponseWriter, data interface{}) { json := jsoniter.ConfigCompatibleWithStandardLibrary b, err := json.Marshal(&response{ Status: statusSuccess, Data: data, }) if err != nil { + level.Error(api.logger).Log("msg", "error marshalling json response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) return } - w.Write(b) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if n, err := w.Write(b); err != nil { + level.Error(api.logger).Log("msg", "error writing response", "bytesWritten", n, "err", err) + } } -func respondError(w http.ResponseWriter, apiErr *apiError, data interface{}) { - w.Header().Set("Content-Type", "application/json") +func (api *API) respondError(w http.ResponseWriter, apiErr *apiError, data interface{}) { + json := jsoniter.ConfigCompatibleWithStandardLibrary + b, err := json.Marshal(&response{ + Status: statusError, + ErrorType: apiErr.typ, + Error: apiErr.err.Error(), + Data: data, + }) + if err != nil { + level.Error(api.logger).Log("msg", "error marshalling json response", "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } var code int switch apiErr.typ { @@ -852,19 +871,12 @@ func respondError(w http.ResponseWriter, apiErr *apiError, data interface{}) { default: code = http.StatusInternalServerError } - w.WriteHeader(code) - json := jsoniter.ConfigCompatibleWithStandardLibrary - b, err := json.Marshal(&response{ - Status: statusError, - ErrorType: apiErr.typ, - Error: apiErr.err.Error(), - Data: data, - }) - if err != nil { - return + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if n, err := w.Write(b); err != nil { + level.Error(api.logger).Log("msg", "error writing response", "bytesWritten", n, "err", err) } - w.Write(b) } func parseTime(s string) (time.Time, error) { diff --git a/web/api/v1/api_test.go b/web/api/v1/api_test.go index ac00cd577..a4807daf2 100644 --- a/web/api/v1/api_test.go +++ b/web/api/v1/api_test.go @@ -750,7 +750,8 @@ func TestReadEndpoint(t *testing.T) { func TestRespondSuccess(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - respond(w, "test") + api := API{} + api.respond(w, "test") })) defer s.Close() @@ -787,7 +788,8 @@ func TestRespondSuccess(t *testing.T) { func TestRespondError(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - respondError(w, &apiError{errorTimeout, errors.New("message")}, "test") + api := API{} + api.respondError(w, &apiError{errorTimeout, errors.New("message")}, "test") })) defer s.Close() @@ -1039,7 +1041,8 @@ func TestRespond(t *testing.T) { for _, c := range cases { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - respond(w, c.response) + api := API{} + api.respond(w, c.response) })) defer s.Close() @@ -1078,7 +1081,8 @@ func BenchmarkRespond(b *testing.B) { }, } b.ResetTimer() + api := API{} for n := 0; n < b.N; n++ { - respond(&testResponseWriter, response) + api.respond(&testResponseWriter, response) } } diff --git a/web/web.go b/web/web.go index 4a4c30078..af75a0827 100644 --- a/web/web.go +++ b/web/web.go @@ -227,6 +227,7 @@ func New(logger log.Logger, o *Options) *Handler { h.testReady, h.options.TSDB, h.options.EnableAdminAPI, + logger, ) if o.RoutePrefix != "/" {