From a1f34bec2e6584a2fee9aec901f3157e3e12cbaa Mon Sep 17 00:00:00 2001 From: Hrishikesh Barman Date: Thu, 17 Jan 2019 20:31:06 +0530 Subject: [PATCH] Added CORS Origin flag (#5011) Signed-off-by: Hrishikesh Barman --- cmd/prometheus/main.go | 23 ++++++++++- util/httputil/cors.go | 47 ++++++++++++++++++++++ util/httputil/cors_test.go | 80 ++++++++++++++++++++++++++++++++++++++ web/api/v1/api.go | 20 +++------- web/api/v1/api_test.go | 6 --- web/web.go | 19 ++------- 6 files changed, 158 insertions(+), 37 deletions(-) create mode 100644 util/httputil/cors.go create mode 100644 util/httputil/cors_test.go diff --git a/cmd/prometheus/main.go b/cmd/prometheus/main.go index 2cf065a62..75efeca80 100644 --- a/cmd/prometheus/main.go +++ b/cmd/prometheus/main.go @@ -26,6 +26,7 @@ import ( "os" "os/signal" "path/filepath" + "regexp" "runtime" "strings" "sync" @@ -50,6 +51,7 @@ import ( "github.com/prometheus/prometheus/discovery" sd_config "github.com/prometheus/prometheus/discovery/config" "github.com/prometheus/prometheus/notifier" + "github.com/prometheus/prometheus/pkg/relabel" "github.com/prometheus/prometheus/promql" "github.com/prometheus/prometheus/rules" "github.com/prometheus/prometheus/scrape" @@ -99,7 +101,8 @@ func main() { queryMaxSamples int RemoteFlushDeadline model.Duration - prometheusURL string + prometheusURL string + corsRegexString string promlogConfig promlog.Config }{ @@ -209,6 +212,9 @@ func main() { a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return."). Default("50000000").IntVar(&cfg.queryMaxSamples) + a.Flag("web.cors.origin", `Regex for CORS origin. It is fully anchored. Eg. 'https?://(domain1|domain2)\.com'`). + Default(".*").StringVar(&cfg.corsRegexString) + promlogflag.AddFlags(a, &cfg.promlogConfig) _, err := a.Parse(os.Args[1:]) @@ -224,6 +230,12 @@ func main() { os.Exit(2) } + cfg.web.CORSOrigin, err = compileCORSRegexString(cfg.corsRegexString) + if err != nil { + fmt.Fprintln(os.Stderr, errors.Wrapf(err, "could not compile CORS regex string %q", cfg.corsRegexString)) + os.Exit(2) + } + cfg.web.ReadTimeout = time.Duration(cfg.webTimeout) // Default -web.route-prefix to path of -web.external-url. if cfg.web.RoutePrefix == "" { @@ -674,6 +686,15 @@ func startsOrEndsWithQuote(s string) bool { strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'") } +// compileCORSRegexString compiles given string and adds anchors +func compileCORSRegexString(s string) (*regexp.Regexp, error) { + r, err := relabel.NewRegexp(s) + if err != nil { + return nil, err + } + return r.Regexp, nil +} + // computeExternalURL computes a sanitized external URL from a raw input. It infers unset // URL parts from the OS and the given listen address. func computeExternalURL(u, listenAddr string) (*url.URL, error) { diff --git a/util/httputil/cors.go b/util/httputil/cors.go new file mode 100644 index 000000000..6e0b5bbfc --- /dev/null +++ b/util/httputil/cors.go @@ -0,0 +1,47 @@ +// Copyright 2013 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil + +import ( + "net/http" + "regexp" +) + +var corsHeaders = map[string]string{ + "Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin", + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Expose-Headers": "Date", + "Vary": "Origin", +} + +// Enables cross-site script calls. +func SetCORS(w http.ResponseWriter, o *regexp.Regexp, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + return + } + + for k, v := range corsHeaders { + w.Header().Set(k, v) + } + + if o.String() == ".*" { + w.Header().Set("Access-Control-Allow-Origin", "*") + return + } + + if o.MatchString(origin) { + w.Header().Set("Access-Control-Allow-Origin", origin) + } +} diff --git a/util/httputil/cors_test.go b/util/httputil/cors_test.go new file mode 100644 index 000000000..33b88ffca --- /dev/null +++ b/util/httputil/cors_test.go @@ -0,0 +1,80 @@ +// Copyright 2016 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil + +import ( + "net/http" + "regexp" + "testing" +) + +func getCORSHandlerFunc() http.Handler { + hf := func(w http.ResponseWriter, r *http.Request) { + reg := regexp.MustCompile(`^https://foo\.com$`) + SetCORS(w, reg, r) + w.WriteHeader(http.StatusOK) + } + return http.HandlerFunc(hf) +} + +func TestCORSHandler(t *testing.T) { + tearDown := setup() + defer tearDown() + client := &http.Client{} + + ch := getCORSHandlerFunc() + mux.Handle("/any_path", ch) + + dummyOrigin := "https://foo.com" + + // OPTIONS with legit origin + req, err := http.NewRequest("OPTIONS", server.URL+"/any_path", nil) + + if err != nil { + t.Error("could not create request") + } + + req.Header.Set("Origin", dummyOrigin) + resp, err := client.Do(req) + + if err != nil { + t.Error("client get failed with unexpected error") + } + + AccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin") + + if AccessControlAllowOrigin != dummyOrigin { + t.Fatalf("%q does not match %q", dummyOrigin, AccessControlAllowOrigin) + } + + // OPTIONS with bad origin + req, err = http.NewRequest("OPTIONS", server.URL+"/any_path", nil) + + if err != nil { + t.Error("could not create request") + } + + req.Header.Set("Origin", "https://not-foo.com") + resp, err = client.Do(req) + + if err != nil { + t.Error("client get failed with unexpected error") + } + + AccessControlAllowOrigin = resp.Header.Get("Access-Control-Allow-Origin") + + if AccessControlAllowOrigin != "" { + t.Fatalf("Access-Control-Allow-Origin should not exist but it was set to: %q", AccessControlAllowOrigin) + } +} diff --git a/web/api/v1/api.go b/web/api/v1/api.go index 8c3ec6df4..2f1772692 100644 --- a/web/api/v1/api.go +++ b/web/api/v1/api.go @@ -23,6 +23,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "sort" "strconv" "time" @@ -76,13 +77,6 @@ const ( errorNotFound errorType = "not_found" ) -var corsHeaders = map[string]string{ - "Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin", - "Access-Control-Allow-Methods": "GET, OPTIONS", - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "Date", -} - var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, Subsystem: subsystem, @@ -129,13 +123,6 @@ type apiFuncResult struct { finalizer func() } -// Enables cross-site script calls. -func setCORS(w http.ResponseWriter) { - for h, v := range corsHeaders { - w.Header().Set(h, v) - } -} - type apiFunc func(r *http.Request) apiFuncResult // TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations. @@ -165,6 +152,7 @@ type API struct { logger log.Logger remoteReadSampleLimit int remoteReadGate *gate.Gate + CORSOrigin *regexp.Regexp } func init() { @@ -187,6 +175,7 @@ func NewAPI( rr rulesRetriever, remoteReadSampleLimit int, remoteReadConcurrencyLimit int, + CORSOrigin *regexp.Regexp, ) *API { return &API{ QueryEngine: qe, @@ -204,6 +193,7 @@ func NewAPI( remoteReadSampleLimit: remoteReadSampleLimit, remoteReadGate: gate.New(remoteReadConcurrencyLimit), logger: logger, + CORSOrigin: CORSOrigin, } } @@ -211,7 +201,7 @@ func NewAPI( func (api *API) Register(r *route.Router) { wrap := func(f apiFunc) http.HandlerFunc { hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - setCORS(w) + httputil.SetCORS(w, api.CORSOrigin, r) result := f(r) if result.err != nil { api.respondError(w, result.err, result.data) diff --git a/web/api/v1/api_test.go b/web/api/v1/api_test.go index 1b98f3afc..0efdd32d3 100644 --- a/web/api/v1/api_test.go +++ b/web/api/v1/api_test.go @@ -1416,12 +1416,6 @@ func TestOptionsMethod(t *testing.T) { if resp.StatusCode != http.StatusNoContent { t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode) } - - for h, v := range corsHeaders { - if resp.Header.Get(h) != v { - t.Fatalf("Expected %q for header %q, got %q", v, h, resp.Header.Get(h)) - } - } } func TestRespond(t *testing.T) { diff --git a/web/web.go b/web/web.go index fe83fb272..ecd09e687 100644 --- a/web/web.go +++ b/web/web.go @@ -28,6 +28,7 @@ import ( "os" "path" "path/filepath" + "regexp" "runtime" "sort" "strings" @@ -173,6 +174,7 @@ type Options struct { Flags map[string]string ListenAddress string + CORSOrigin *regexp.Regexp ReadTimeout time.Duration MaxConnections int ExternalURL *url.URL @@ -259,6 +261,7 @@ func New(logger log.Logger, o *Options) *Handler { h.ruleManager, h.options.RemoteReadSampleLimit, h.options.RemoteReadConcurrencyLimit, + h.options.CORSOrigin, ) if o.RoutePrefix != "/" { @@ -340,20 +343,6 @@ func New(logger log.Logger, o *Options) *Handler { return h } -var corsHeaders = map[string]string{ - "Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin", - "Access-Control-Allow-Methods": "GET, OPTIONS", - "Access-Control-Allow-Origin": "*", - "Access-Control-Expose-Headers": "Date", -} - -// Enables cross-site script calls. -func setCORS(w http.ResponseWriter) { - for h, v := range corsHeaders { - w.Header().Set(h, v) - } -} - func serveDebug(w http.ResponseWriter, req *http.Request) { ctx := req.Context() subpath := route.Param(ctx, "subpath") @@ -474,7 +463,7 @@ func (h *Handler) Run(ctx context.Context) error { mux.Handle(apiPath+"/", http.StripPrefix(apiPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - setCORS(w) + httputil.SetCORS(w, h.options.CORSOrigin, r) hhFunc(w, r) }), ))