mirror of
https://github.com/prometheus/prometheus.git
synced 2024-12-24 05:04:05 -08:00
Added CORS Origin flag (#5011)
Signed-off-by: Hrishikesh Barman <hrishikeshbman@gmail.com>
This commit is contained in:
parent
c44cd7e166
commit
a1f34bec2e
|
@ -26,6 +26,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -50,6 +51,7 @@ import (
|
|||
"github.com/prometheus/prometheus/discovery"
|
||||
sd_config "github.com/prometheus/prometheus/discovery/config"
|
||||
"github.com/prometheus/prometheus/notifier"
|
||||
"github.com/prometheus/prometheus/pkg/relabel"
|
||||
"github.com/prometheus/prometheus/promql"
|
||||
"github.com/prometheus/prometheus/rules"
|
||||
"github.com/prometheus/prometheus/scrape"
|
||||
|
@ -99,7 +101,8 @@ func main() {
|
|||
queryMaxSamples int
|
||||
RemoteFlushDeadline model.Duration
|
||||
|
||||
prometheusURL string
|
||||
prometheusURL string
|
||||
corsRegexString string
|
||||
|
||||
promlogConfig promlog.Config
|
||||
}{
|
||||
|
@ -209,6 +212,9 @@ func main() {
|
|||
a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return.").
|
||||
Default("50000000").IntVar(&cfg.queryMaxSamples)
|
||||
|
||||
a.Flag("web.cors.origin", `Regex for CORS origin. It is fully anchored. Eg. 'https?://(domain1|domain2)\.com'`).
|
||||
Default(".*").StringVar(&cfg.corsRegexString)
|
||||
|
||||
promlogflag.AddFlags(a, &cfg.promlogConfig)
|
||||
|
||||
_, err := a.Parse(os.Args[1:])
|
||||
|
@ -224,6 +230,12 @@ func main() {
|
|||
os.Exit(2)
|
||||
}
|
||||
|
||||
cfg.web.CORSOrigin, err = compileCORSRegexString(cfg.corsRegexString)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, errors.Wrapf(err, "could not compile CORS regex string %q", cfg.corsRegexString))
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
cfg.web.ReadTimeout = time.Duration(cfg.webTimeout)
|
||||
// Default -web.route-prefix to path of -web.external-url.
|
||||
if cfg.web.RoutePrefix == "" {
|
||||
|
@ -674,6 +686,15 @@ func startsOrEndsWithQuote(s string) bool {
|
|||
strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'")
|
||||
}
|
||||
|
||||
// compileCORSRegexString compiles given string and adds anchors
|
||||
func compileCORSRegexString(s string) (*regexp.Regexp, error) {
|
||||
r, err := relabel.NewRegexp(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Regexp, nil
|
||||
}
|
||||
|
||||
// computeExternalURL computes a sanitized external URL from a raw input. It infers unset
|
||||
// URL parts from the OS and the given listen address.
|
||||
func computeExternalURL(u, listenAddr string) (*url.URL, error) {
|
||||
|
|
47
util/httputil/cors.go
Normal file
47
util/httputil/cors.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2013 The Prometheus Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var corsHeaders = map[string]string{
|
||||
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
|
||||
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||
"Access-Control-Expose-Headers": "Date",
|
||||
"Vary": "Origin",
|
||||
}
|
||||
|
||||
// Enables cross-site script calls.
|
||||
func SetCORS(w http.ResponseWriter, o *regexp.Regexp, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range corsHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
if o.String() == ".*" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
return
|
||||
}
|
||||
|
||||
if o.MatchString(origin) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
}
|
80
util/httputil/cors_test.go
Normal file
80
util/httputil/cors_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
// Copyright 2016 The Prometheus Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getCORSHandlerFunc() http.Handler {
|
||||
hf := func(w http.ResponseWriter, r *http.Request) {
|
||||
reg := regexp.MustCompile(`^https://foo\.com$`)
|
||||
SetCORS(w, reg, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return http.HandlerFunc(hf)
|
||||
}
|
||||
|
||||
func TestCORSHandler(t *testing.T) {
|
||||
tearDown := setup()
|
||||
defer tearDown()
|
||||
client := &http.Client{}
|
||||
|
||||
ch := getCORSHandlerFunc()
|
||||
mux.Handle("/any_path", ch)
|
||||
|
||||
dummyOrigin := "https://foo.com"
|
||||
|
||||
// OPTIONS with legit origin
|
||||
req, err := http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Error("could not create request")
|
||||
}
|
||||
|
||||
req.Header.Set("Origin", dummyOrigin)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
t.Error("client get failed with unexpected error")
|
||||
}
|
||||
|
||||
AccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin")
|
||||
|
||||
if AccessControlAllowOrigin != dummyOrigin {
|
||||
t.Fatalf("%q does not match %q", dummyOrigin, AccessControlAllowOrigin)
|
||||
}
|
||||
|
||||
// OPTIONS with bad origin
|
||||
req, err = http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Error("could not create request")
|
||||
}
|
||||
|
||||
req.Header.Set("Origin", "https://not-foo.com")
|
||||
resp, err = client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
t.Error("client get failed with unexpected error")
|
||||
}
|
||||
|
||||
AccessControlAllowOrigin = resp.Header.Get("Access-Control-Allow-Origin")
|
||||
|
||||
if AccessControlAllowOrigin != "" {
|
||||
t.Fatalf("Access-Control-Allow-Origin should not exist but it was set to: %q", AccessControlAllowOrigin)
|
||||
}
|
||||
}
|
|
@ -23,6 +23,7 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -76,13 +77,6 @@ const (
|
|||
errorNotFound errorType = "not_found"
|
||||
)
|
||||
|
||||
var corsHeaders = map[string]string{
|
||||
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Expose-Headers": "Date",
|
||||
}
|
||||
|
||||
var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
|
@ -129,13 +123,6 @@ type apiFuncResult struct {
|
|||
finalizer func()
|
||||
}
|
||||
|
||||
// Enables cross-site script calls.
|
||||
func setCORS(w http.ResponseWriter) {
|
||||
for h, v := range corsHeaders {
|
||||
w.Header().Set(h, v)
|
||||
}
|
||||
}
|
||||
|
||||
type apiFunc func(r *http.Request) apiFuncResult
|
||||
|
||||
// TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations.
|
||||
|
@ -165,6 +152,7 @@ type API struct {
|
|||
logger log.Logger
|
||||
remoteReadSampleLimit int
|
||||
remoteReadGate *gate.Gate
|
||||
CORSOrigin *regexp.Regexp
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -187,6 +175,7 @@ func NewAPI(
|
|||
rr rulesRetriever,
|
||||
remoteReadSampleLimit int,
|
||||
remoteReadConcurrencyLimit int,
|
||||
CORSOrigin *regexp.Regexp,
|
||||
) *API {
|
||||
return &API{
|
||||
QueryEngine: qe,
|
||||
|
@ -204,6 +193,7 @@ func NewAPI(
|
|||
remoteReadSampleLimit: remoteReadSampleLimit,
|
||||
remoteReadGate: gate.New(remoteReadConcurrencyLimit),
|
||||
logger: logger,
|
||||
CORSOrigin: CORSOrigin,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,7 +201,7 @@ func NewAPI(
|
|||
func (api *API) Register(r *route.Router) {
|
||||
wrap := func(f apiFunc) http.HandlerFunc {
|
||||
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
setCORS(w)
|
||||
httputil.SetCORS(w, api.CORSOrigin, r)
|
||||
result := f(r)
|
||||
if result.err != nil {
|
||||
api.respondError(w, result.err, result.data)
|
||||
|
|
|
@ -1416,12 +1416,6 @@ func TestOptionsMethod(t *testing.T) {
|
|||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode)
|
||||
}
|
||||
|
||||
for h, v := range corsHeaders {
|
||||
if resp.Header.Get(h) != v {
|
||||
t.Fatalf("Expected %q for header %q, got %q", v, h, resp.Header.Get(h))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespond(t *testing.T) {
|
||||
|
|
19
web/web.go
19
web/web.go
|
@ -28,6 +28,7 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -173,6 +174,7 @@ type Options struct {
|
|||
Flags map[string]string
|
||||
|
||||
ListenAddress string
|
||||
CORSOrigin *regexp.Regexp
|
||||
ReadTimeout time.Duration
|
||||
MaxConnections int
|
||||
ExternalURL *url.URL
|
||||
|
@ -259,6 +261,7 @@ func New(logger log.Logger, o *Options) *Handler {
|
|||
h.ruleManager,
|
||||
h.options.RemoteReadSampleLimit,
|
||||
h.options.RemoteReadConcurrencyLimit,
|
||||
h.options.CORSOrigin,
|
||||
)
|
||||
|
||||
if o.RoutePrefix != "/" {
|
||||
|
@ -340,20 +343,6 @@ func New(logger log.Logger, o *Options) *Handler {
|
|||
return h
|
||||
}
|
||||
|
||||
var corsHeaders = map[string]string{
|
||||
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Expose-Headers": "Date",
|
||||
}
|
||||
|
||||
// Enables cross-site script calls.
|
||||
func setCORS(w http.ResponseWriter) {
|
||||
for h, v := range corsHeaders {
|
||||
w.Header().Set(h, v)
|
||||
}
|
||||
}
|
||||
|
||||
func serveDebug(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
subpath := route.Param(ctx, "subpath")
|
||||
|
@ -474,7 +463,7 @@ func (h *Handler) Run(ctx context.Context) error {
|
|||
|
||||
mux.Handle(apiPath+"/", http.StripPrefix(apiPath,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
setCORS(w)
|
||||
httputil.SetCORS(w, h.options.CORSOrigin, r)
|
||||
hhFunc(w, r)
|
||||
}),
|
||||
))
|
||||
|
|
Loading…
Reference in a new issue