mirror of
https://github.com/prometheus/prometheus.git
synced 2024-12-24 21:24:05 -08:00
Added CORS Origin flag (#5011)
Signed-off-by: Hrishikesh Barman <hrishikeshbman@gmail.com>
This commit is contained in:
parent
c44cd7e166
commit
a1f34bec2e
|
@ -26,6 +26,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -50,6 +51,7 @@ import (
|
||||||
"github.com/prometheus/prometheus/discovery"
|
"github.com/prometheus/prometheus/discovery"
|
||||||
sd_config "github.com/prometheus/prometheus/discovery/config"
|
sd_config "github.com/prometheus/prometheus/discovery/config"
|
||||||
"github.com/prometheus/prometheus/notifier"
|
"github.com/prometheus/prometheus/notifier"
|
||||||
|
"github.com/prometheus/prometheus/pkg/relabel"
|
||||||
"github.com/prometheus/prometheus/promql"
|
"github.com/prometheus/prometheus/promql"
|
||||||
"github.com/prometheus/prometheus/rules"
|
"github.com/prometheus/prometheus/rules"
|
||||||
"github.com/prometheus/prometheus/scrape"
|
"github.com/prometheus/prometheus/scrape"
|
||||||
|
@ -100,6 +102,7 @@ func main() {
|
||||||
RemoteFlushDeadline model.Duration
|
RemoteFlushDeadline model.Duration
|
||||||
|
|
||||||
prometheusURL string
|
prometheusURL string
|
||||||
|
corsRegexString string
|
||||||
|
|
||||||
promlogConfig promlog.Config
|
promlogConfig promlog.Config
|
||||||
}{
|
}{
|
||||||
|
@ -209,6 +212,9 @@ func main() {
|
||||||
a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return.").
|
a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return.").
|
||||||
Default("50000000").IntVar(&cfg.queryMaxSamples)
|
Default("50000000").IntVar(&cfg.queryMaxSamples)
|
||||||
|
|
||||||
|
a.Flag("web.cors.origin", `Regex for CORS origin. It is fully anchored. Eg. 'https?://(domain1|domain2)\.com'`).
|
||||||
|
Default(".*").StringVar(&cfg.corsRegexString)
|
||||||
|
|
||||||
promlogflag.AddFlags(a, &cfg.promlogConfig)
|
promlogflag.AddFlags(a, &cfg.promlogConfig)
|
||||||
|
|
||||||
_, err := a.Parse(os.Args[1:])
|
_, err := a.Parse(os.Args[1:])
|
||||||
|
@ -224,6 +230,12 @@ func main() {
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cfg.web.CORSOrigin, err = compileCORSRegexString(cfg.corsRegexString)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, errors.Wrapf(err, "could not compile CORS regex string %q", cfg.corsRegexString))
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
cfg.web.ReadTimeout = time.Duration(cfg.webTimeout)
|
cfg.web.ReadTimeout = time.Duration(cfg.webTimeout)
|
||||||
// Default -web.route-prefix to path of -web.external-url.
|
// Default -web.route-prefix to path of -web.external-url.
|
||||||
if cfg.web.RoutePrefix == "" {
|
if cfg.web.RoutePrefix == "" {
|
||||||
|
@ -674,6 +686,15 @@ func startsOrEndsWithQuote(s string) bool {
|
||||||
strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'")
|
strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compileCORSRegexString compiles given string and adds anchors
|
||||||
|
func compileCORSRegexString(s string) (*regexp.Regexp, error) {
|
||||||
|
r, err := relabel.NewRegexp(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return r.Regexp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// computeExternalURL computes a sanitized external URL from a raw input. It infers unset
|
// computeExternalURL computes a sanitized external URL from a raw input. It infers unset
|
||||||
// URL parts from the OS and the given listen address.
|
// URL parts from the OS and the given listen address.
|
||||||
func computeExternalURL(u, listenAddr string) (*url.URL, error) {
|
func computeExternalURL(u, listenAddr string) (*url.URL, error) {
|
||||||
|
|
47
util/httputil/cors.go
Normal file
47
util/httputil/cors.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
// Copyright 2013 The Prometheus Authors
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package httputil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var corsHeaders = map[string]string{
|
||||||
|
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||||
|
"Access-Control-Expose-Headers": "Date",
|
||||||
|
"Vary": "Origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enables cross-site script calls.
|
||||||
|
func SetCORS(w http.ResponseWriter, o *regexp.Regexp, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range corsHeaders {
|
||||||
|
w.Header().Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.String() == ".*" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.MatchString(origin) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
}
|
80
util/httputil/cors_test.go
Normal file
80
util/httputil/cors_test.go
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
// Copyright 2016 The Prometheus Authors
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package httputil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getCORSHandlerFunc() http.Handler {
|
||||||
|
hf := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reg := regexp.MustCompile(`^https://foo\.com$`)
|
||||||
|
SetCORS(w, reg, r)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
return http.HandlerFunc(hf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSHandler(t *testing.T) {
|
||||||
|
tearDown := setup()
|
||||||
|
defer tearDown()
|
||||||
|
client := &http.Client{}
|
||||||
|
|
||||||
|
ch := getCORSHandlerFunc()
|
||||||
|
mux.Handle("/any_path", ch)
|
||||||
|
|
||||||
|
dummyOrigin := "https://foo.com"
|
||||||
|
|
||||||
|
// OPTIONS with legit origin
|
||||||
|
req, err := http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not create request")
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Origin", dummyOrigin)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("client get failed with unexpected error")
|
||||||
|
}
|
||||||
|
|
||||||
|
AccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin")
|
||||||
|
|
||||||
|
if AccessControlAllowOrigin != dummyOrigin {
|
||||||
|
t.Fatalf("%q does not match %q", dummyOrigin, AccessControlAllowOrigin)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPTIONS with bad origin
|
||||||
|
req, err = http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not create request")
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Origin", "https://not-foo.com")
|
||||||
|
resp, err = client.Do(req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("client get failed with unexpected error")
|
||||||
|
}
|
||||||
|
|
||||||
|
AccessControlAllowOrigin = resp.Header.Get("Access-Control-Allow-Origin")
|
||||||
|
|
||||||
|
if AccessControlAllowOrigin != "" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Origin should not exist but it was set to: %q", AccessControlAllowOrigin)
|
||||||
|
}
|
||||||
|
}
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -76,13 +77,6 @@ const (
|
||||||
errorNotFound errorType = "not_found"
|
errorNotFound errorType = "not_found"
|
||||||
)
|
)
|
||||||
|
|
||||||
var corsHeaders = map[string]string{
|
|
||||||
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
|
|
||||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Expose-Headers": "Date",
|
|
||||||
}
|
|
||||||
|
|
||||||
var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{
|
var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
Subsystem: subsystem,
|
Subsystem: subsystem,
|
||||||
|
@ -129,13 +123,6 @@ type apiFuncResult struct {
|
||||||
finalizer func()
|
finalizer func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enables cross-site script calls.
|
|
||||||
func setCORS(w http.ResponseWriter) {
|
|
||||||
for h, v := range corsHeaders {
|
|
||||||
w.Header().Set(h, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type apiFunc func(r *http.Request) apiFuncResult
|
type apiFunc func(r *http.Request) apiFuncResult
|
||||||
|
|
||||||
// TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations.
|
// TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations.
|
||||||
|
@ -165,6 +152,7 @@ type API struct {
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
remoteReadSampleLimit int
|
remoteReadSampleLimit int
|
||||||
remoteReadGate *gate.Gate
|
remoteReadGate *gate.Gate
|
||||||
|
CORSOrigin *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -187,6 +175,7 @@ func NewAPI(
|
||||||
rr rulesRetriever,
|
rr rulesRetriever,
|
||||||
remoteReadSampleLimit int,
|
remoteReadSampleLimit int,
|
||||||
remoteReadConcurrencyLimit int,
|
remoteReadConcurrencyLimit int,
|
||||||
|
CORSOrigin *regexp.Regexp,
|
||||||
) *API {
|
) *API {
|
||||||
return &API{
|
return &API{
|
||||||
QueryEngine: qe,
|
QueryEngine: qe,
|
||||||
|
@ -204,6 +193,7 @@ func NewAPI(
|
||||||
remoteReadSampleLimit: remoteReadSampleLimit,
|
remoteReadSampleLimit: remoteReadSampleLimit,
|
||||||
remoteReadGate: gate.New(remoteReadConcurrencyLimit),
|
remoteReadGate: gate.New(remoteReadConcurrencyLimit),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
CORSOrigin: CORSOrigin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,7 +201,7 @@ func NewAPI(
|
||||||
func (api *API) Register(r *route.Router) {
|
func (api *API) Register(r *route.Router) {
|
||||||
wrap := func(f apiFunc) http.HandlerFunc {
|
wrap := func(f apiFunc) http.HandlerFunc {
|
||||||
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
setCORS(w)
|
httputil.SetCORS(w, api.CORSOrigin, r)
|
||||||
result := f(r)
|
result := f(r)
|
||||||
if result.err != nil {
|
if result.err != nil {
|
||||||
api.respondError(w, result.err, result.data)
|
api.respondError(w, result.err, result.data)
|
||||||
|
|
|
@ -1416,12 +1416,6 @@ func TestOptionsMethod(t *testing.T) {
|
||||||
if resp.StatusCode != http.StatusNoContent {
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode)
|
t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
for h, v := range corsHeaders {
|
|
||||||
if resp.Header.Get(h) != v {
|
|
||||||
t.Fatalf("Expected %q for header %q, got %q", v, h, resp.Header.Get(h))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRespond(t *testing.T) {
|
func TestRespond(t *testing.T) {
|
||||||
|
|
19
web/web.go
19
web/web.go
|
@ -28,6 +28,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -173,6 +174,7 @@ type Options struct {
|
||||||
Flags map[string]string
|
Flags map[string]string
|
||||||
|
|
||||||
ListenAddress string
|
ListenAddress string
|
||||||
|
CORSOrigin *regexp.Regexp
|
||||||
ReadTimeout time.Duration
|
ReadTimeout time.Duration
|
||||||
MaxConnections int
|
MaxConnections int
|
||||||
ExternalURL *url.URL
|
ExternalURL *url.URL
|
||||||
|
@ -259,6 +261,7 @@ func New(logger log.Logger, o *Options) *Handler {
|
||||||
h.ruleManager,
|
h.ruleManager,
|
||||||
h.options.RemoteReadSampleLimit,
|
h.options.RemoteReadSampleLimit,
|
||||||
h.options.RemoteReadConcurrencyLimit,
|
h.options.RemoteReadConcurrencyLimit,
|
||||||
|
h.options.CORSOrigin,
|
||||||
)
|
)
|
||||||
|
|
||||||
if o.RoutePrefix != "/" {
|
if o.RoutePrefix != "/" {
|
||||||
|
@ -340,20 +343,6 @@ func New(logger log.Logger, o *Options) *Handler {
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
var corsHeaders = map[string]string{
|
|
||||||
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
|
|
||||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Expose-Headers": "Date",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enables cross-site script calls.
|
|
||||||
func setCORS(w http.ResponseWriter) {
|
|
||||||
for h, v := range corsHeaders {
|
|
||||||
w.Header().Set(h, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func serveDebug(w http.ResponseWriter, req *http.Request) {
|
func serveDebug(w http.ResponseWriter, req *http.Request) {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
subpath := route.Param(ctx, "subpath")
|
subpath := route.Param(ctx, "subpath")
|
||||||
|
@ -474,7 +463,7 @@ func (h *Handler) Run(ctx context.Context) error {
|
||||||
|
|
||||||
mux.Handle(apiPath+"/", http.StripPrefix(apiPath,
|
mux.Handle(apiPath+"/", http.StripPrefix(apiPath,
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
setCORS(w)
|
httputil.SetCORS(w, h.options.CORSOrigin, r)
|
||||||
hhFunc(w, r)
|
hhFunc(w, r)
|
||||||
}),
|
}),
|
||||||
))
|
))
|
||||||
|
|
Loading…
Reference in a new issue