Added CORS Origin flag (#5011)

Signed-off-by: Hrishikesh Barman <hrishikeshbman@gmail.com>
This commit is contained in:
Hrishikesh Barman 2019-01-17 20:31:06 +05:30 committed by Brian Brazil
parent c44cd7e166
commit a1f34bec2e
6 changed files with 158 additions and 37 deletions

View file

@ -26,6 +26,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -50,6 +51,7 @@ import (
"github.com/prometheus/prometheus/discovery" "github.com/prometheus/prometheus/discovery"
sd_config "github.com/prometheus/prometheus/discovery/config" sd_config "github.com/prometheus/prometheus/discovery/config"
"github.com/prometheus/prometheus/notifier" "github.com/prometheus/prometheus/notifier"
"github.com/prometheus/prometheus/pkg/relabel"
"github.com/prometheus/prometheus/promql" "github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/rules" "github.com/prometheus/prometheus/rules"
"github.com/prometheus/prometheus/scrape" "github.com/prometheus/prometheus/scrape"
@ -99,7 +101,8 @@ func main() {
queryMaxSamples int queryMaxSamples int
RemoteFlushDeadline model.Duration RemoteFlushDeadline model.Duration
prometheusURL string prometheusURL string
corsRegexString string
promlogConfig promlog.Config promlogConfig promlog.Config
}{ }{
@ -209,6 +212,9 @@ func main() {
a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return."). a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return.").
Default("50000000").IntVar(&cfg.queryMaxSamples) Default("50000000").IntVar(&cfg.queryMaxSamples)
a.Flag("web.cors.origin", `Regex for CORS origin. It is fully anchored. Eg. 'https?://(domain1|domain2)\.com'`).
Default(".*").StringVar(&cfg.corsRegexString)
promlogflag.AddFlags(a, &cfg.promlogConfig) promlogflag.AddFlags(a, &cfg.promlogConfig)
_, err := a.Parse(os.Args[1:]) _, err := a.Parse(os.Args[1:])
@ -224,6 +230,12 @@ func main() {
os.Exit(2) os.Exit(2)
} }
cfg.web.CORSOrigin, err = compileCORSRegexString(cfg.corsRegexString)
if err != nil {
fmt.Fprintln(os.Stderr, errors.Wrapf(err, "could not compile CORS regex string %q", cfg.corsRegexString))
os.Exit(2)
}
cfg.web.ReadTimeout = time.Duration(cfg.webTimeout) cfg.web.ReadTimeout = time.Duration(cfg.webTimeout)
// Default -web.route-prefix to path of -web.external-url. // Default -web.route-prefix to path of -web.external-url.
if cfg.web.RoutePrefix == "" { if cfg.web.RoutePrefix == "" {
@ -674,6 +686,15 @@ func startsOrEndsWithQuote(s string) bool {
strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'") strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'")
} }
// compileCORSRegexString compiles given string and adds anchors
func compileCORSRegexString(s string) (*regexp.Regexp, error) {
r, err := relabel.NewRegexp(s)
if err != nil {
return nil, err
}
return r.Regexp, nil
}
// computeExternalURL computes a sanitized external URL from a raw input. It infers unset // computeExternalURL computes a sanitized external URL from a raw input. It infers unset
// URL parts from the OS and the given listen address. // URL parts from the OS and the given listen address.
func computeExternalURL(u, listenAddr string) (*url.URL, error) { func computeExternalURL(u, listenAddr string) (*url.URL, error) {

47
util/httputil/cors.go Normal file
View file

@ -0,0 +1,47 @@
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"net/http"
"regexp"
)
var corsHeaders = map[string]string{
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Expose-Headers": "Date",
"Vary": "Origin",
}
// Enables cross-site script calls.
func SetCORS(w http.ResponseWriter, o *regexp.Regexp, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
return
}
for k, v := range corsHeaders {
w.Header().Set(k, v)
}
if o.String() == ".*" {
w.Header().Set("Access-Control-Allow-Origin", "*")
return
}
if o.MatchString(origin) {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
}

View file

@ -0,0 +1,80 @@
// Copyright 2016 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"net/http"
"regexp"
"testing"
)
func getCORSHandlerFunc() http.Handler {
hf := func(w http.ResponseWriter, r *http.Request) {
reg := regexp.MustCompile(`^https://foo\.com$`)
SetCORS(w, reg, r)
w.WriteHeader(http.StatusOK)
}
return http.HandlerFunc(hf)
}
func TestCORSHandler(t *testing.T) {
tearDown := setup()
defer tearDown()
client := &http.Client{}
ch := getCORSHandlerFunc()
mux.Handle("/any_path", ch)
dummyOrigin := "https://foo.com"
// OPTIONS with legit origin
req, err := http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
if err != nil {
t.Error("could not create request")
}
req.Header.Set("Origin", dummyOrigin)
resp, err := client.Do(req)
if err != nil {
t.Error("client get failed with unexpected error")
}
AccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin")
if AccessControlAllowOrigin != dummyOrigin {
t.Fatalf("%q does not match %q", dummyOrigin, AccessControlAllowOrigin)
}
// OPTIONS with bad origin
req, err = http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
if err != nil {
t.Error("could not create request")
}
req.Header.Set("Origin", "https://not-foo.com")
resp, err = client.Do(req)
if err != nil {
t.Error("client get failed with unexpected error")
}
AccessControlAllowOrigin = resp.Header.Get("Access-Control-Allow-Origin")
if AccessControlAllowOrigin != "" {
t.Fatalf("Access-Control-Allow-Origin should not exist but it was set to: %q", AccessControlAllowOrigin)
}
}

View file

@ -23,6 +23,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strconv" "strconv"
"time" "time"
@ -76,13 +77,6 @@ const (
errorNotFound errorType = "not_found" errorNotFound errorType = "not_found"
) )
var corsHeaders = map[string]string{
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Origin": "*",
"Access-Control-Expose-Headers": "Date",
}
var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{ var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: subsystem, Subsystem: subsystem,
@ -129,13 +123,6 @@ type apiFuncResult struct {
finalizer func() finalizer func()
} }
// Enables cross-site script calls.
func setCORS(w http.ResponseWriter) {
for h, v := range corsHeaders {
w.Header().Set(h, v)
}
}
type apiFunc func(r *http.Request) apiFuncResult type apiFunc func(r *http.Request) apiFuncResult
// TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations. // TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations.
@ -165,6 +152,7 @@ type API struct {
logger log.Logger logger log.Logger
remoteReadSampleLimit int remoteReadSampleLimit int
remoteReadGate *gate.Gate remoteReadGate *gate.Gate
CORSOrigin *regexp.Regexp
} }
func init() { func init() {
@ -187,6 +175,7 @@ func NewAPI(
rr rulesRetriever, rr rulesRetriever,
remoteReadSampleLimit int, remoteReadSampleLimit int,
remoteReadConcurrencyLimit int, remoteReadConcurrencyLimit int,
CORSOrigin *regexp.Regexp,
) *API { ) *API {
return &API{ return &API{
QueryEngine: qe, QueryEngine: qe,
@ -204,6 +193,7 @@ func NewAPI(
remoteReadSampleLimit: remoteReadSampleLimit, remoteReadSampleLimit: remoteReadSampleLimit,
remoteReadGate: gate.New(remoteReadConcurrencyLimit), remoteReadGate: gate.New(remoteReadConcurrencyLimit),
logger: logger, logger: logger,
CORSOrigin: CORSOrigin,
} }
} }
@ -211,7 +201,7 @@ func NewAPI(
func (api *API) Register(r *route.Router) { func (api *API) Register(r *route.Router) {
wrap := func(f apiFunc) http.HandlerFunc { wrap := func(f apiFunc) http.HandlerFunc {
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
setCORS(w) httputil.SetCORS(w, api.CORSOrigin, r)
result := f(r) result := f(r)
if result.err != nil { if result.err != nil {
api.respondError(w, result.err, result.data) api.respondError(w, result.err, result.data)

View file

@ -1416,12 +1416,6 @@ func TestOptionsMethod(t *testing.T) {
if resp.StatusCode != http.StatusNoContent { if resp.StatusCode != http.StatusNoContent {
t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode) t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode)
} }
for h, v := range corsHeaders {
if resp.Header.Get(h) != v {
t.Fatalf("Expected %q for header %q, got %q", v, h, resp.Header.Get(h))
}
}
} }
func TestRespond(t *testing.T) { func TestRespond(t *testing.T) {

View file

@ -28,6 +28,7 @@ import (
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
@ -173,6 +174,7 @@ type Options struct {
Flags map[string]string Flags map[string]string
ListenAddress string ListenAddress string
CORSOrigin *regexp.Regexp
ReadTimeout time.Duration ReadTimeout time.Duration
MaxConnections int MaxConnections int
ExternalURL *url.URL ExternalURL *url.URL
@ -259,6 +261,7 @@ func New(logger log.Logger, o *Options) *Handler {
h.ruleManager, h.ruleManager,
h.options.RemoteReadSampleLimit, h.options.RemoteReadSampleLimit,
h.options.RemoteReadConcurrencyLimit, h.options.RemoteReadConcurrencyLimit,
h.options.CORSOrigin,
) )
if o.RoutePrefix != "/" { if o.RoutePrefix != "/" {
@ -340,20 +343,6 @@ func New(logger log.Logger, o *Options) *Handler {
return h return h
} }
var corsHeaders = map[string]string{
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Origin": "*",
"Access-Control-Expose-Headers": "Date",
}
// Enables cross-site script calls.
func setCORS(w http.ResponseWriter) {
for h, v := range corsHeaders {
w.Header().Set(h, v)
}
}
func serveDebug(w http.ResponseWriter, req *http.Request) { func serveDebug(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
subpath := route.Param(ctx, "subpath") subpath := route.Param(ctx, "subpath")
@ -474,7 +463,7 @@ func (h *Handler) Run(ctx context.Context) error {
mux.Handle(apiPath+"/", http.StripPrefix(apiPath, mux.Handle(apiPath+"/", http.StripPrefix(apiPath,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
setCORS(w) httputil.SetCORS(w, h.options.CORSOrigin, r)
hhFunc(w, r) hhFunc(w, r)
}), }),
)) ))