Merge branch 'dev-2.0' into grobie/reduce-noisy-append-errors

This commit is contained in:
Fabian Reinartz 2017-05-24 15:29:30 +02:00 committed by GitHub
commit d3f662f15e
928 changed files with 160019 additions and 201162 deletions

2
.gitignore vendored
View file

@ -25,3 +25,5 @@ benchmark.txt
!/circle.yml
!/.travis.yml
!/.promu.yml
/documentation/examples/remote_storage/remote_storage_adapter/remote_storage_adapter
/documentation/examples/remote_storage/example_write_adapter/example_writer_adapter

View file

@ -1,3 +1,7 @@
## 1.6.2 / 2017-05-11
* [BUGFIX] Fix potential memory leak in Kubernetes service discovery
## 1.6.1 / 2017-04-19
* [BUGFIX] Don't panic if storage has no FPs even after initial wait
@ -67,6 +71,10 @@
* [BUGFIX] Fix deadlock in Zookeeper SD.
* [BUGFIX] Fix fuzzy search problems in the web-UI auto-completion.
## 1.5.3 / 2017-05-11
* [BUGFIX] Fix potential memory leak in Kubernetes service discovery
## 1.5.2 / 2017-02-10
* [BUGFIX] Fix series corruption in a special case of series maintenance where

View file

@ -1 +1 @@
2.0.0-alpha.0
2.0.0-alpha.1

View file

@ -125,17 +125,17 @@ func init() {
&cfg.localStoragePath, "storage.local.path", "data",
"Base path for metrics storage.",
)
cfg.fs.BoolVar(
&cfg.tsdb.NoLockfile, "storage.tsdb.no-lockfile", false,
"Disable lock file usage.",
)
cfg.fs.DurationVar(
&cfg.tsdb.MinBlockDuration, "storage.tsdb.min-block-duration", 2*time.Hour,
"Minimum duration of a data block before being persisted.",
)
cfg.fs.DurationVar(
&cfg.tsdb.MaxBlockDuration, "storage.tsdb.max-block-duration", 36*time.Hour,
"Maximum duration compacted blocks may span.",
)
cfg.fs.IntVar(
&cfg.tsdb.AppendableBlocks, "storage.tsdb.appendable-blocks", 2,
"Number of head blocks that can be appended to.",
&cfg.tsdb.MaxBlockDuration, "storage.tsdb.max-block-duration", 0,
"Maximum duration compacted blocks may span. (Defaults to 10% of the retention period)",
)
cfg.fs.DurationVar(
&cfg.tsdb.Retention, "storage.tsdb.retention", 15*24*time.Hour,
@ -206,6 +206,10 @@ func parse(args []string) error {
}
}
if cfg.tsdb.MaxBlockDuration == 0 {
cfg.tsdb.MaxBlockDuration = cfg.tsdb.Retention / 10
}
return nil
}

View file

@ -72,12 +72,18 @@ func Main() int {
log.Infoln("Starting prometheus", version.Info())
log.Infoln("Build context", version.BuildContext())
log.Infoln("Host details", Uname())
var (
// sampleAppender = storage.Fanout{}
reloadables []Reloadable
)
// Make sure that sighup handler is registered with a redirect to the channel before the potentially
// long and synchronous tsdb init.
hup := make(chan os.Signal)
hupReady := make(chan bool)
signal.Notify(hup, syscall.SIGHUP)
localStorage, err := tsdb.Open(cfg.localStoragePath, prometheus.DefaultRegisterer, &cfg.tsdb)
if err != nil {
log.Errorf("Opening storage failed: %s", err)
@ -136,9 +142,6 @@ func Main() int {
// Wait for reload or termination signals. Start the handler for SIGHUP as
// early as possible, but ignore it until we are ready to handle reloading
// our config.
hup := make(chan os.Signal)
hupReady := make(chan bool)
signal.Notify(hup, syscall.SIGHUP)
go func() {
<-hupReady
for {

View file

@ -0,0 +1,23 @@
// Copyright 2017 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !linux
package main
import "runtime"
// Uname for any platform other than linux.
func Uname() string {
return "(" + runtime.GOOS + ")"
}

View file

@ -0,0 +1,35 @@
// Copyright 2017 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"log"
"syscall"
)
// Uname returns the uname of the host machine.
func Uname() string {
buf := syscall.Utsname{}
err := syscall.Uname(&buf)
if err != nil {
log.Fatal("Error!")
}
str := "(" + charsToString(buf.Sysname[:])
str += " " + charsToString(buf.Release[:])
str += " " + charsToString(buf.Version[:])
str += " " + charsToString(buf.Machine[:])
str += " " + charsToString(buf.Nodename[:])
str += " " + charsToString(buf.Domainname[:]) + ")"
return str
}

View file

@ -0,0 +1,25 @@
// Copyright 2017 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build 386 amd64 arm64 mips64 mips64le mips mipsle
// +build linux
package main
func charsToString(ca []int8) string {
s := make([]byte, len(ca))
for i, c := range ca {
s[i] = byte(c)
}
return string(s[0:len(ca)])
}

View file

@ -0,0 +1,25 @@
// Copyright 2017 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build arm ppc64 ppc64le s390x
// +build linux
package main
func charsToString(ca []uint8) string {
s := make([]byte, len(ca))
for i, c := range ca {
s[i] = byte(c)
}
return string(s[0:len(ca)])
}

View file

@ -987,12 +987,13 @@ func (c *KubernetesRole) UnmarshalYAML(unmarshal func(interface{}) error) error
// KubernetesSDConfig is the configuration for Kubernetes service discovery.
type KubernetesSDConfig struct {
APIServer URL `yaml:"api_server"`
Role KubernetesRole `yaml:"role"`
BasicAuth *BasicAuth `yaml:"basic_auth,omitempty"`
BearerToken string `yaml:"bearer_token,omitempty"`
BearerTokenFile string `yaml:"bearer_token_file,omitempty"`
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
APIServer URL `yaml:"api_server"`
Role KubernetesRole `yaml:"role"`
BasicAuth *BasicAuth `yaml:"basic_auth,omitempty"`
BearerToken string `yaml:"bearer_token,omitempty"`
BearerTokenFile string `yaml:"bearer_token_file,omitempty"`
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
NamespaceDiscovery KubernetesNamespaceDiscovery `yaml:"namespaces"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
@ -1026,6 +1027,28 @@ func (c *KubernetesSDConfig) UnmarshalYAML(unmarshal func(interface{}) error) er
return nil
}
// KubernetesNamespaceDiscovery is the configuration for discovering
// Kubernetes namespaces.
type KubernetesNamespaceDiscovery struct {
Names []string `yaml:"names"`
// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *KubernetesNamespaceDiscovery) UnmarshalYAML(unmarshal func(interface{}) error) error {
*c = KubernetesNamespaceDiscovery{}
type plain KubernetesNamespaceDiscovery
err := unmarshal((*plain)(c))
if err != nil {
return err
}
if err := checkOverflow(c.XXX, "namespaces"); err != nil {
return err
}
return nil
}
// GCESDConfig is the configuration for GCE based service discovery.
type GCESDConfig struct {
// Project: The Google Cloud Project ID

View file

@ -305,6 +305,30 @@ var expectedConf = &Config{
Username: "myusername",
Password: "mypassword",
},
NamespaceDiscovery: KubernetesNamespaceDiscovery{},
},
},
},
},
{
JobName: "service-kubernetes-namespaces",
ScrapeInterval: model.Duration(15 * time.Second),
ScrapeTimeout: DefaultGlobalConfig.ScrapeTimeout,
MetricsPath: DefaultScrapeConfig.MetricsPath,
Scheme: DefaultScrapeConfig.Scheme,
ServiceDiscoveryConfig: ServiceDiscoveryConfig{
KubernetesSDConfigs: []*KubernetesSDConfig{
{
APIServer: kubernetesSDHostURL(),
Role: KubernetesRoleEndpoint,
NamespaceDiscovery: KubernetesNamespaceDiscovery{
Names: []string{
"default",
},
},
},
},
},
@ -592,6 +616,9 @@ var expectedErrors = []struct {
}, {
filename: "kubernetes_role.bad.yml",
errMsg: "role",
}, {
filename: "kubernetes_namespace_discovery.bad.yml",
errMsg: "unknown fields in namespaces",
}, {
filename: "kubernetes_bearertoken_basicauth.bad.yml",
errMsg: "at most one of basic_auth, bearer_token & bearer_token_file must be configured",

View file

@ -146,6 +146,15 @@ scrape_configs:
username: 'myusername'
password: 'mypassword'
- job_name: service-kubernetes-namespaces
kubernetes_sd_configs:
- role: endpoints
api_server: 'https://localhost:1234'
namespaces:
names:
- default
- job_name: service-marathon
marathon_sd_configs:
- servers:

View file

@ -0,0 +1,6 @@
scrape_configs:
- kubernetes_sd_configs:
- api_server: kubernetes:443
role: endpoints
namespaces:
foo: bar

View file

@ -23,8 +23,8 @@ import (
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"golang.org/x/net/context"
apiv1 "k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
apiv1 "k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
// Endpoints discovers new endpoint targets.

View file

@ -19,8 +19,9 @@ import (
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/config"
"k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
func endpointsStoreKeyFunc(obj interface{}) (string, error) {
@ -40,7 +41,7 @@ func makeTestEndpointsDiscovery() (*Endpoints, *fakeInformer, *fakeInformer, *fa
func makeEndpoints() *v1.Endpoints {
return &v1.Endpoints{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testendpoints",
Namespace: "default",
},
@ -123,7 +124,7 @@ func TestEndpointsDiscoveryInitial(t *testing.T) {
func TestEndpointsDiscoveryAdd(t *testing.T) {
n, _, eps, pods := makeTestEndpointsDiscovery()
pods.GetStore().Add(&v1.Pod{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testpod",
Namespace: "default",
},
@ -164,7 +165,7 @@ func TestEndpointsDiscoveryAdd(t *testing.T) {
go func() {
eps.Add(
&v1.Endpoints{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testendpoints",
Namespace: "default",
},
@ -273,7 +274,7 @@ func TestEndpointsDiscoveryUpdate(t *testing.T) {
afterStart: func() {
go func() {
eps.Update(&v1.Endpoints{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testendpoints",
Namespace: "default",
},

View file

@ -15,6 +15,7 @@ package kubernetes
import (
"io/ioutil"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
@ -23,12 +24,12 @@ import (
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"golang.org/x/net/context"
"k8s.io/client-go/1.5/kubernetes"
"k8s.io/client-go/1.5/pkg/api"
apiv1 "k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/pkg/util/runtime"
"k8s.io/client-go/1.5/rest"
"k8s.io/client-go/1.5/tools/cache"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/pkg/api"
apiv1 "k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
)
const (
@ -62,9 +63,10 @@ func init() {
// Discovery implements the TargetProvider interface for discovering
// targets from Kubernetes.
type Discovery struct {
client kubernetes.Interface
role config.KubernetesRole
logger log.Logger
client kubernetes.Interface
role config.KubernetesRole
logger log.Logger
namespaceDiscovery *config.KubernetesNamespaceDiscovery
}
func init() {
@ -75,6 +77,14 @@ func init() {
}
}
func (d *Discovery) getNamespaces() []string {
namespaces := d.namespaceDiscovery.Names
if len(namespaces) == 0 {
namespaces = []string{api.NamespaceAll}
}
return namespaces
}
// New creates a new Kubernetes discovery for the given role.
func New(l log.Logger, conf *config.KubernetesSDConfig) (*Discovery, error) {
var (
@ -111,8 +121,8 @@ func New(l log.Logger, conf *config.KubernetesSDConfig) (*Discovery, error) {
CAFile: conf.TLSConfig.CAFile,
CertFile: conf.TLSConfig.CertFile,
KeyFile: conf.TLSConfig.KeyFile,
Insecure: conf.TLSConfig.InsecureSkipVerify,
},
Insecure: conf.TLSConfig.InsecureSkipVerify,
}
token := conf.BearerToken
if conf.BearerTokenFile != "" {
@ -137,9 +147,10 @@ func New(l log.Logger, conf *config.KubernetesSDConfig) (*Discovery, error) {
return nil, err
}
return &Discovery{
client: c,
logger: l,
role: conf.Role,
client: c,
logger: l,
role: conf.Role,
namespaceDiscovery: &conf.NamespaceDiscovery,
}, nil
}
@ -147,60 +158,84 @@ const resyncPeriod = 10 * time.Minute
// Run implements the TargetProvider interface.
func (d *Discovery) Run(ctx context.Context, ch chan<- []*config.TargetGroup) {
rclient := d.client.Core().GetRESTClient()
rclient := d.client.Core().RESTClient()
namespaces := d.getNamespaces()
switch d.role {
case "endpoints":
elw := cache.NewListWatchFromClient(rclient, "endpoints", api.NamespaceAll, nil)
slw := cache.NewListWatchFromClient(rclient, "services", api.NamespaceAll, nil)
plw := cache.NewListWatchFromClient(rclient, "pods", api.NamespaceAll, nil)
eps := NewEndpoints(
d.logger.With("kubernetes_sd", "endpoint"),
cache.NewSharedInformer(slw, &apiv1.Service{}, resyncPeriod),
cache.NewSharedInformer(elw, &apiv1.Endpoints{}, resyncPeriod),
cache.NewSharedInformer(plw, &apiv1.Pod{}, resyncPeriod),
)
go eps.endpointsInf.Run(ctx.Done())
go eps.serviceInf.Run(ctx.Done())
go eps.podInf.Run(ctx.Done())
var wg sync.WaitGroup
for !eps.serviceInf.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
for !eps.endpointsInf.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
for !eps.podInf.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
eps.Run(ctx, ch)
for _, namespace := range namespaces {
elw := cache.NewListWatchFromClient(rclient, "endpoints", namespace, nil)
slw := cache.NewListWatchFromClient(rclient, "services", namespace, nil)
plw := cache.NewListWatchFromClient(rclient, "pods", namespace, nil)
eps := NewEndpoints(
d.logger.With("kubernetes_sd", "endpoint"),
cache.NewSharedInformer(slw, &apiv1.Service{}, resyncPeriod),
cache.NewSharedInformer(elw, &apiv1.Endpoints{}, resyncPeriod),
cache.NewSharedInformer(plw, &apiv1.Pod{}, resyncPeriod),
)
go eps.endpointsInf.Run(ctx.Done())
go eps.serviceInf.Run(ctx.Done())
go eps.podInf.Run(ctx.Done())
for !eps.serviceInf.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
for !eps.endpointsInf.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
for !eps.podInf.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
wg.Add(1)
go func() {
defer wg.Done()
eps.Run(ctx, ch)
}()
}
wg.Wait()
case "pod":
plw := cache.NewListWatchFromClient(rclient, "pods", api.NamespaceAll, nil)
pod := NewPod(
d.logger.With("kubernetes_sd", "pod"),
cache.NewSharedInformer(plw, &apiv1.Pod{}, resyncPeriod),
)
go pod.informer.Run(ctx.Done())
var wg sync.WaitGroup
for _, namespace := range namespaces {
plw := cache.NewListWatchFromClient(rclient, "pods", namespace, nil)
pod := NewPod(
d.logger.With("kubernetes_sd", "pod"),
cache.NewSharedInformer(plw, &apiv1.Pod{}, resyncPeriod),
)
go pod.informer.Run(ctx.Done())
for !pod.informer.HasSynced() {
time.Sleep(100 * time.Millisecond)
for !pod.informer.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
wg.Add(1)
go func() {
defer wg.Done()
pod.Run(ctx, ch)
}()
}
pod.Run(ctx, ch)
wg.Wait()
case "service":
slw := cache.NewListWatchFromClient(rclient, "services", api.NamespaceAll, nil)
svc := NewService(
d.logger.With("kubernetes_sd", "service"),
cache.NewSharedInformer(slw, &apiv1.Service{}, resyncPeriod),
)
go svc.informer.Run(ctx.Done())
var wg sync.WaitGroup
for _, namespace := range namespaces {
slw := cache.NewListWatchFromClient(rclient, "services", namespace, nil)
svc := NewService(
d.logger.With("kubernetes_sd", "service"),
cache.NewSharedInformer(slw, &apiv1.Service{}, resyncPeriod),
)
go svc.informer.Run(ctx.Done())
for !svc.informer.HasSynced() {
time.Sleep(100 * time.Millisecond)
for !svc.informer.HasSynced() {
time.Sleep(100 * time.Millisecond)
}
wg.Add(1)
go func() {
defer wg.Done()
svc.Run(ctx, ch)
}()
}
svc.Run(ctx, ch)
wg.Wait()
case "node":
nlw := cache.NewListWatchFromClient(rclient, "nodes", api.NamespaceAll, nil)
node := NewNode(

View file

@ -23,9 +23,9 @@ import (
"github.com/prometheus/prometheus/config"
"github.com/prometheus/prometheus/util/strutil"
"golang.org/x/net/context"
"k8s.io/client-go/1.5/pkg/api"
apiv1 "k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
"k8s.io/client-go/pkg/api"
apiv1 "k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
// Node discovers Kubernetes nodes.

View file

@ -25,8 +25,9 @@ import (
"github.com/prometheus/prometheus/config"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
type fakeInformer struct {
@ -46,18 +47,21 @@ func newFakeInformer(f func(obj interface{}) (string, error)) *fakeInformer {
return i
}
func (i *fakeInformer) AddEventHandler(handler cache.ResourceEventHandler) error {
i.handlers = append(i.handlers, handler)
func (i *fakeInformer) AddEventHandler(h cache.ResourceEventHandler) {
i.handlers = append(i.handlers, h)
// Only now that there is a registered handler, we are able to handle deltas.
i.blockDeltas.Unlock()
return nil
}
func (i *fakeInformer) AddEventHandlerWithResyncPeriod(h cache.ResourceEventHandler, _ time.Duration) {
i.AddEventHandler(h)
}
func (i *fakeInformer) GetStore() cache.Store {
return i.store
}
func (i *fakeInformer) GetController() cache.ControllerInterface {
func (i *fakeInformer) GetController() cache.Controller {
return nil
}
@ -160,7 +164,7 @@ func makeTestNodeDiscovery() (*Node, *fakeInformer) {
func makeNode(name, address string, labels map[string]string, annotations map[string]string) *v1.Node {
return &v1.Node{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Labels: labels,
Annotations: annotations,

View file

@ -24,9 +24,9 @@ import (
"github.com/prometheus/prometheus/config"
"github.com/prometheus/prometheus/util/strutil"
"golang.org/x/net/context"
"k8s.io/client-go/1.5/pkg/api"
apiv1 "k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
"k8s.io/client-go/pkg/api"
apiv1 "k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
// Pod discovers new pod targets.

View file

@ -19,8 +19,9 @@ import (
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/config"
"k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
func podStoreKeyFunc(obj interface{}) (string, error) {
@ -38,7 +39,7 @@ func makeTestPodDiscovery() (*Pod, *fakeInformer) {
func makeMultiPortPod() *v1.Pod {
return &v1.Pod{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testpod",
Namespace: "default",
Labels: map[string]string{"testlabel": "testvalue"},
@ -82,7 +83,7 @@ func makeMultiPortPod() *v1.Pod {
func makePod() *v1.Pod {
return &v1.Pod{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testpod",
Namespace: "default",
},
@ -266,7 +267,7 @@ func TestPodDiscoveryDeleteUnknownCacheState(t *testing.T) {
func TestPodDiscoveryUpdate(t *testing.T) {
n, i := makeTestPodDiscovery()
i.GetStore().Add(&v1.Pod{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testpod",
Namespace: "default",
},

View file

@ -23,8 +23,8 @@ import (
"github.com/prometheus/prometheus/config"
"github.com/prometheus/prometheus/util/strutil"
"golang.org/x/net/context"
apiv1 "k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
apiv1 "k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
// Service implements discovery of Kubernetes services.

View file

@ -20,8 +20,9 @@ import (
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/config"
"k8s.io/client-go/1.5/pkg/api/v1"
"k8s.io/client-go/1.5/tools/cache"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/pkg/api/v1"
"k8s.io/client-go/tools/cache"
)
func serviceStoreKeyFunc(obj interface{}) (string, error) {
@ -39,7 +40,7 @@ func makeTestServiceDiscovery() (*Service, *fakeInformer) {
func makeMultiPortService() *v1.Service {
return &v1.Service{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: "testservice",
Namespace: "default",
Labels: map[string]string{"testlabel": "testvalue"},
@ -64,7 +65,7 @@ func makeMultiPortService() *v1.Service {
func makeSuffixedService(suffix string) *v1.Service {
return &v1.Service{
ObjectMeta: v1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("testservice%s", suffix),
Namespace: "default",
},

View file

@ -93,6 +93,7 @@ func (d *Discovery) Run(ctx context.Context, ch chan<- []*config.TargetGroup) {
for {
select {
case <-ctx.Done():
return
case event := <-d.updates:
tg := &config.TargetGroup{
Source: event.Path,

View file

@ -76,6 +76,12 @@ scrape_configs:
relabel_configs:
- action: labelmap
regex: __meta_kubernetes_node_label_(.+)
- target_label: __address__
replacement: kubernetes.default.svc:443
- source_labels: [__meta_kubernetes_node_name]
regex: (.+)
target_label: __metrics_path__
replacement: /api/v1/nodes/${1}/proxy/metrics
# Scrape config for service endpoints.
#
@ -158,7 +164,8 @@ scrape_configs:
#
# * `prometheus.io/scrape`: Only scrape pods that have a value of `true`
# * `prometheus.io/path`: If the metrics path is not `/metrics` override this.
# * `prometheus.io/port`: Scrape the pod on the indicated port instead of the default of `9102`.
# * `prometheus.io/port`: Scrape the pod on the indicated port instead of the
# pod's declared ports (default is a port-free target if none are declared).
- job_name: 'kubernetes-pods'
kubernetes_sd_configs:

View file

@ -0,0 +1,34 @@
apiVersion: rbac.authorization.k8s.io/v1beta1
kind: ClusterRole
metadata:
name: prometheus
rules:
- apiGroups: [""]
resources:
- nodes
- nodes/proxy
- services
- endpoints
- pods
verbs: ["get", "list", "watch"]
- nonResourceURLs: ["/metrics"]
verbs: ["get"]
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: prometheus
namespace: default
---
apiVersion: rbac.authorization.k8s.io/v1beta1
kind: ClusterRoleBinding
metadata:
name: prometheus
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: prometheus
subjects:
- kind: ServiceAccount
name: prometheus
namespace: default

View file

@ -27,7 +27,13 @@ import (
func main() {
http.HandleFunc("/receive", func(w http.ResponseWriter, r *http.Request) {
reqBuf, err := ioutil.ReadAll(snappy.NewReader(r.Body))
compressed, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
reqBuf, err := snappy.Decode(nil, compressed)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return

View file

@ -184,7 +184,13 @@ func buildClients(cfg *config) ([]writer, []reader) {
func serve(addr string, writers []writer, readers []reader) error {
http.HandleFunc("/write", func(w http.ResponseWriter, r *http.Request) {
reqBuf, err := ioutil.ReadAll(snappy.NewReader(r.Body))
compressed, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
reqBuf, err := snappy.Decode(nil, compressed)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -211,7 +217,13 @@ func serve(addr string, writers []writer, readers []reader) error {
})
http.HandleFunc("/read", func(w http.ResponseWriter, r *http.Request) {
reqBuf, err := ioutil.ReadAll(snappy.NewReader(r.Body))
compressed, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
reqBuf, err := snappy.Decode(nil, compressed)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -245,7 +257,10 @@ func serve(addr string, writers []writer, readers []reader) error {
}
w.Header().Set("Content-Type", "application/x-protobuf")
if _, err := snappy.NewWriter(w).Write(data); err != nil {
w.Header().Set("Content-Encoding", "snappy")
compressed = snappy.Encode(nil, data)
if _, err := w.Write(compressed); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

View file

@ -127,15 +127,16 @@ type Options struct {
}
type alertMetrics struct {
latency *prometheus.SummaryVec
errors *prometheus.CounterVec
sent *prometheus.CounterVec
dropped prometheus.Counter
queueLength prometheus.GaugeFunc
queueCapacity prometheus.Gauge
latency *prometheus.SummaryVec
errors *prometheus.CounterVec
sent *prometheus.CounterVec
dropped prometheus.Counter
queueLength prometheus.GaugeFunc
queueCapacity prometheus.Gauge
alertmanagersDiscovered prometheus.GaugeFunc
}
func newAlertMetrics(r prometheus.Registerer, queueCap int, queueLen func() float64) *alertMetrics {
func newAlertMetrics(r prometheus.Registerer, queueCap int, queueLen, alertmanagersDiscovered func() float64) *alertMetrics {
m := &alertMetrics{
latency: prometheus.NewSummaryVec(prometheus.SummaryOpts{
Namespace: namespace,
@ -179,6 +180,10 @@ func newAlertMetrics(r prometheus.Registerer, queueCap int, queueLen func() floa
Name: "queue_capacity",
Help: "The capacity of the alert notifications queue.",
}),
alertmanagersDiscovered: prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "prometheus_notifications_alertmanagers_discovered",
Help: "The number of alertmanagers discovered and active.",
}, alertmanagersDiscovered),
}
m.queueCapacity.Set(float64(queueCap))
@ -191,6 +196,7 @@ func newAlertMetrics(r prometheus.Registerer, queueCap int, queueLen func() floa
m.dropped,
m.queueLength,
m.queueCapacity,
m.alertmanagersDiscovered,
)
}
@ -214,7 +220,8 @@ func New(o *Options) *Notifier {
}
queueLenFunc := func() float64 { return float64(n.queueLen()) }
n.metrics = newAlertMetrics(o.Registerer, o.QueueCapacity, queueLenFunc)
alertmanagersDiscoveredFunc := func() float64 { return float64(len(n.Alertmanagers())) }
n.metrics = newAlertMetrics(o.Registerer, o.QueueCapacity, queueLenFunc, alertmanagersDiscoveredFunc)
return n
}

View file

@ -230,7 +230,7 @@ func (b *Builder) Set(n, v string) *Builder {
}
// Labels returns the labels from the builder. If no modifications
// were made, the originl labels are returned.
// were made, the original labels are returned.
func (b *Builder) Labels() Labels {
if len(b.del) == 0 && len(b.add) == 0 {
return b.base

View file

@ -18,6 +18,8 @@ import (
"fmt"
"math"
"strconv"
"github.com/prometheus/prometheus/pkg/value"
)
@ -46,6 +48,7 @@ func (l *lexer) Lex() int {
%}
D [0-9]
S [a-zA-Z]
L [a-zA-Z_]
M [a-zA-Z_:]
@ -63,29 +66,30 @@ M [a-zA-Z_:]
#[^\r\n]*\n l.mstart = l.i
[\r\n \t]+ l.mstart = l.i
{L}({L}|{D})*\{ s = lstateLabels
{S}({M}|{D})*\{ s = lstateLabels
l.offsets = append(l.offsets, l.i-1)
{L}({L}|{D})* s = lstateValue
{S}({M}|{D})* s = lstateValue
l.mend = l.i
l.offsets = append(l.offsets, l.i)
<lstateLabels>[ \t]+
<lstateLabels>\} s = lstateValue
<lstateLabels>,?\} s = lstateValue
l.mend = l.i
<lstateLabels>,? s = lstateLName
l.offsets = append(l.offsets, l.i)
<lstateLName>{M}({M}|{D})*= s = lstateLValue
<lstateLName>{S}({L}|{D})*= s = lstateLValue
l.offsets = append(l.offsets, l.i-1)
<lstateLValue>\"(\\.|[^\\"])*\" s = lstateLabels
<lstateLValue>\"(\\.|[^\\"]|\0)*\" s = lstateLabels
l.offsets = append(l.offsets, l.i-1)
<lstateLValue>\'(\\.|[^\\'])*\' s = lstateLabels
<lstateLValue>\'(\\.|[^\\']|\0)*\' s = lstateLabels
l.offsets = append(l.offsets, l.i-1)
<lstateValue>[ \t]+ l.vstart = l.i
<lstateValue>(NaN) l.val = math.NaN()
<lstateValue>(NaN) l.val = math.Float64frombits(value.NormalNaN)
s = lstateTimestamp
<lstateValue>[^\n \t\r]+ // We don't parse strictly correct floats as the conversion
// repeats the effort anyway.
l.val, l.err = strconv.ParseFloat(yoloString(l.b[l.vstart:l.i]), 64)

View file

@ -19,6 +19,8 @@ import (
"fmt"
"math"
"strconv"
"github.com/prometheus/prometheus/pkg/value"
)
// Lex is called by the parser generated by "go tool yacc" to obtain each
@ -77,7 +79,7 @@ yystart1:
goto yystate3
case c == '\x00':
goto yystate2
case c >= 'A' && c <= 'Z' || c == '_' || c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z':
goto yystate6
}
@ -116,7 +118,7 @@ yystate6:
goto yyrule5
case c == '{':
goto yystate7
case c >= '0' && c <= '9' || c >= 'A' && c <= 'Z' || c == '_' || c >= 'a' && c <= 'z':
case c >= '0' && c <= ':' || c >= 'A' && c <= 'Z' || c == '_' || c >= 'a' && c <= 'z':
goto yystate6
}
@ -262,7 +264,12 @@ yystate20:
yystate21:
c = l.next()
goto yyrule8
switch {
default:
goto yyrule8
case c == '}':
goto yystate22
}
yystate22:
c = l.next()
@ -275,7 +282,7 @@ yystart23:
switch {
default:
goto yyabort
case c == ':' || c >= 'A' && c <= 'Z' || c == '_' || c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z':
goto yystate24
}
@ -286,7 +293,7 @@ yystate24:
goto yyabort
case c == '=':
goto yystate25
case c >= '0' && c <= ':' || c >= 'A' && c <= 'Z' || c == '_' || c >= 'a' && c <= 'z':
case c >= '0' && c <= '9' || c >= 'A' && c <= 'Z' || c == '_' || c >= 'a' && c <= 'z':
goto yystate24
}
@ -311,13 +318,11 @@ yystate27:
c = l.next()
switch {
default:
goto yyabort
goto yystate27 // c >= '\x00' && c <= '!' || c >= '#' && c <= '[' || c >= ']' && c <= 'ÿ'
case c == '"':
goto yystate28
case c == '\\':
goto yystate29
case c >= '\x01' && c <= '!' || c >= '#' && c <= '[' || c >= ']' && c <= 'ÿ':
goto yystate27
}
yystate28:
@ -337,13 +342,11 @@ yystate30:
c = l.next()
switch {
default:
goto yyabort
goto yystate30 // c >= '\x00' && c <= '&' || c >= '(' && c <= '[' || c >= ']' && c <= 'ÿ'
case c == '\'':
goto yystate31
case c == '\\':
goto yystate32
case c >= '\x01' && c <= '&' || c >= '(' && c <= '[' || c >= ']' && c <= 'ÿ':
goto yystate30
}
yystate31:
@ -373,13 +376,13 @@ yyrule3: // [\r\n \t]+
l.mstart = l.i
goto yystate0
}
yyrule4: // {L}({L}|{D})*\{
yyrule4: // {S}({M}|{D})*\{
{
s = lstateLabels
l.offsets = append(l.offsets, l.i-1)
goto yystate0
}
yyrule5: // {L}({L}|{D})*
yyrule5: // {S}({M}|{D})*
{
s = lstateValue
l.mend = l.i
@ -389,7 +392,7 @@ yyrule5: // {L}({L}|{D})*
yyrule6: // [ \t]+
goto yystate0
yyrule7: // \}
yyrule7: // ,?\}
{
s = lstateValue
l.mend = l.i
@ -401,19 +404,19 @@ yyrule8: // ,?
l.offsets = append(l.offsets, l.i)
goto yystate0
}
yyrule9: // {M}({M}|{D})*=
yyrule9: // {S}({L}|{D})*=
{
s = lstateLValue
l.offsets = append(l.offsets, l.i-1)
goto yystate0
}
yyrule10: // \"(\\.|[^\\"])*\"
yyrule10: // \"(\\.|[^\\"]|\0)*\"
{
s = lstateLabels
l.offsets = append(l.offsets, l.i-1)
goto yystate0
}
yyrule11: // \'(\\.|[^\\'])*\'
yyrule11: // \'(\\.|[^\\']|\0)*\'
{
s = lstateLabels
l.offsets = append(l.offsets, l.i-1)
@ -426,7 +429,7 @@ yyrule12: // [ \t]+
}
yyrule13: // (NaN)
{
l.val = math.NaN()
l.val = math.Float64frombits(value.NormalNaN)
s = lstateTimestamp
goto yystate0
}

View file

@ -31,12 +31,14 @@ func TestParse(t *testing.T) {
input := `# HELP go_gc_duration_seconds A summary of the GC invocation durations.
# TYPE go_gc_duration_seconds summary
go_gc_duration_seconds{quantile="0"} 4.9351e-05
go_gc_duration_seconds{quantile="0.25"} 7.424100000000001e-05
go_gc_duration_seconds{quantile="0.25",} 7.424100000000001e-05
go_gc_duration_seconds{quantile="0.5",a="b"} 8.3835e-05
go_gc_duration_seconds_count 99
some:aggregate:rate5m{a_b="c"} 1
# HELP go_goroutines Number of goroutines that currently exist.
# TYPE go_goroutines gauge
go_goroutines 33 123123`
input += "\nnull_byte_metric{a=\"abc\x00\"} 1"
int64p := func(x int64) *int64 { return &x }
@ -51,7 +53,7 @@ go_goroutines 33 123123`
v: 4.9351e-05,
lset: labels.FromStrings("__name__", "go_gc_duration_seconds", "quantile", "0"),
}, {
m: `go_gc_duration_seconds{quantile="0.25"}`,
m: `go_gc_duration_seconds{quantile="0.25",}`,
v: 7.424100000000001e-05,
lset: labels.FromStrings("__name__", "go_gc_duration_seconds", "quantile", "0.25"),
}, {
@ -62,11 +64,19 @@ go_goroutines 33 123123`
m: `go_gc_duration_seconds_count`,
v: 99,
lset: labels.FromStrings("__name__", "go_gc_duration_seconds_count"),
}, {
m: `some:aggregate:rate5m{a_b="c"}`,
v: 1,
lset: labels.FromStrings("__name__", "some:aggregate:rate5m", "a_b", "c"),
}, {
m: `go_goroutines`,
v: 33,
t: int64p(123123),
lset: labels.FromStrings("__name__", "go_goroutines"),
}, {
m: "null_byte_metric{a=\"abc\x00\"}",
v: 1,
lset: labels.FromStrings("__name__", "null_byte_metric", "a", "abc\x00"),
},
}

33
pkg/value/value.go Normal file
View file

@ -0,0 +1,33 @@
// Copyright 2016 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package value
import (
"math"
)
const (
// A quiet NaN. This is also math.NaN().
NormalNaN uint64 = 0x7ff8000000000001
// A signalling NaN, due to the MSB of the mantissa being 0.
// This value is chosen with many leading 0s, so we have scope to store more
// complicated values in the future. It is 2 rather than 1 to make
// it easier to distinguish from the NormalNaN by a human when debugging.
StaleNaN uint64 = 0x7ff0000000000002
)
func IsStaleNaN(v float64) bool {
return math.Float64bits(v) == StaleNaN
}

View file

@ -23,10 +23,12 @@ import (
"sync"
"time"
opentracing "github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/common/log"
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/pkg/timestamp"
"github.com/prometheus/prometheus/pkg/value"
"github.com/prometheus/prometheus/storage"
"golang.org/x/net/context"
@ -36,6 +38,7 @@ import (
const (
namespace = "prometheus"
subsystem = "engine"
queryTag = "query"
// The largest SampleValue that can be converted to an int64 without overflow.
maxInt64 = 9223372036854774784
@ -168,6 +171,10 @@ func (q *query) Cancel() {
// Exec implements the Query interface.
func (q *query) Exec(ctx context.Context) *Result {
if span := opentracing.SpanFromContext(ctx); span != nil {
span.SetTag(queryTag, q.stmt.String())
}
res, err := q.ng.exec(ctx, q)
return &Result{Err: err, Value: res}
}
@ -363,8 +370,9 @@ func (ng *Engine) execEvalStmt(ctx context.Context, query *query, s *EvalStmt) (
evalTimer := query.stats.GetTimer(stats.InnerEvalTime).Start()
// Instant evaluation.
if s.Start == s.End && s.Interval == 0 {
start := timeMilliseconds(s.Start)
evaluator := &evaluator{
Timestamp: timeMilliseconds(s.Start),
Timestamp: start,
ctx: ctx,
}
val, err := evaluator.Eval(s.Expr)
@ -374,6 +382,16 @@ func (ng *Engine) execEvalStmt(ctx context.Context, query *query, s *EvalStmt) (
evalTimer.Stop()
queryInnerEval.Observe(evalTimer.ElapsedTime().Seconds())
// Point might have a different timestamp, force it to the evaluation
// timestamp as that is when we ran the evaluation.
switch v := val.(type) {
case Scalar:
v.T = start
case Vector:
for i := range v {
v[i].Point.T = start
}
}
return val, nil
}
@ -387,8 +405,9 @@ func (ng *Engine) execEvalStmt(ctx context.Context, query *query, s *EvalStmt) (
return nil, err
}
t := timeMilliseconds(ts)
evaluator := &evaluator{
Timestamp: timeMilliseconds(ts),
Timestamp: t,
ctx: ctx,
}
val, err := evaluator.Eval(s.Expr)
@ -405,7 +424,7 @@ func (ng *Engine) execEvalStmt(ctx context.Context, query *query, s *EvalStmt) (
ss = Series{Points: make([]Point, 0, numSteps)}
Seriess[0] = ss
}
ss.Points = append(ss.Points, Point(v))
ss.Points = append(ss.Points, Point{V: v.V, T: t})
Seriess[0] = ss
case Vector:
for _, sample := range v {
@ -418,6 +437,7 @@ func (ng *Engine) execEvalStmt(ctx context.Context, query *query, s *EvalStmt) (
}
Seriess[h] = ss
}
sample.Point.T = t
ss.Points = append(ss.Points, sample.Point)
Seriess[h] = ss
}
@ -731,15 +751,35 @@ func (ev *evaluator) vectorSelector(node *VectorSelector) Vector {
}
t, v := it.Values()
peek := 1
if !ok || t > refTime {
t, v, ok = it.PeekBack()
t, v, ok = it.PeekBack(peek)
peek += 1
if !ok || t < refTime-durationMilliseconds(StalenessDelta) {
continue
}
}
if value.IsStaleNaN(v) {
continue
}
// Find timestamp before this point, within the staleness delta.
prevT, _, ok := it.PeekBack(peek)
if ok && prevT >= refTime-durationMilliseconds(StalenessDelta) {
interval := t - prevT
if interval*4+interval/10 < refTime-t {
// It is more than 4 (+10% for safety) intervals
// since the last data point, skip as stale.
//
// We need 4 to allow for federation, as with a 10s einterval an eval
// started at t=10 could be ingested at t=20, scraped for federation at
// t=30 and only ingested by federation at t=40.
continue
}
}
vec = append(vec, Sample{
Metric: node.series[i].Labels(),
Point: Point{V: v, T: ev.Timestamp},
Point: Point{V: v, T: t},
})
}
return vec
@ -807,6 +847,9 @@ func (ev *evaluator) matrixSelector(node *MatrixSelector) Matrix {
buf := it.Buffer()
for buf.Next() {
t, v := buf.At()
if value.IsStaleNaN(v) {
continue
}
// Values in the buffer are guaranteed to be smaller than maxt.
if t >= mint {
allPoints = append(allPoints, Point{T: t, V: v})
@ -814,7 +857,7 @@ func (ev *evaluator) matrixSelector(node *MatrixSelector) Matrix {
}
// The seeked sample might also be in the range.
t, v = it.Values()
if t == maxt {
if t == maxt && !value.IsStaleNaN(v) {
allPoints = append(allPoints, Point{T: t, V: v})
}

View file

@ -15,10 +15,13 @@ package promql
import (
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/net/context"
"github.com/prometheus/prometheus/pkg/labels"
)
func TestQueryConcurrency(t *testing.T) {
@ -194,6 +197,93 @@ func TestEngineShutdown(t *testing.T) {
}
}
func TestEngineEvalStmtTimestamps(t *testing.T) {
test, err := NewTest(t, `
load 10s
metric 1 2
`)
if err != nil {
t.Fatalf("unexpected error creating test: %q", err)
}
err = test.Run()
if err != nil {
t.Fatalf("unexpected error initializing test: %q", err)
}
cases := []struct {
Query string
Result Value
Start time.Time
End time.Time
Interval time.Duration
}{
// Instant queries.
{
Query: "1",
Result: Scalar{V: 1, T: 1000},
Start: time.Unix(1, 0),
},
{
Query: "metric",
Result: Vector{
Sample{Point: Point{V: 1, T: 1000},
Metric: labels.FromStrings("__name__", "metric")},
},
Start: time.Unix(1, 0),
},
{
Query: "metric[20s]",
Result: Matrix{Series{
Points: []Point{{V: 1, T: 0}, {V: 2, T: 10000}},
Metric: labels.FromStrings("__name__", "metric")},
},
Start: time.Unix(10, 0),
},
// Range queries.
{
Query: "1",
Result: Matrix{Series{
Points: []Point{{V: 1, T: 0}, {V: 1, T: 1000}, {V: 1, T: 2000}},
Metric: labels.FromStrings()},
},
Start: time.Unix(0, 0),
End: time.Unix(2, 0),
Interval: time.Second,
},
{
Query: "metric",
Result: Matrix{Series{
Points: []Point{{V: 1, T: 0}, {V: 1, T: 1000}, {V: 2, T: 2000}},
Metric: labels.FromStrings("__name__", "metric")},
},
Start: time.Unix(0, 0),
End: time.Unix(2, 0),
Interval: time.Second,
},
}
for _, c := range cases {
var err error
var qry Query
if c.Interval == 0 {
qry, err = test.QueryEngine().NewInstantQuery(c.Query, c.Start)
} else {
qry, err = test.QueryEngine().NewRangeQuery(c.Query, c.Start, c.End, c.Interval)
}
if err != nil {
t.Fatalf("unexpected error creating query: %q", err)
}
res := qry.Exec(test.Context())
if res.Err != nil {
t.Fatalf("unexpected error running query: %q", res.Err)
}
if !reflect.DeepEqual(res.Value, c.Result) {
t.Fatalf("unexpected result for query %q: got %q wanted %q", c.Query, res.Value.String(), c.Result.String())
}
}
}
func TestRecoverEvaluatorRuntime(t *testing.T) {
var ev *evaluator
var err error

View file

@ -650,6 +650,18 @@ func funcLog10(ev *evaluator, args Expressions) Value {
return vec
}
// === timestamp(Vector ValueTypeVector) Vector ===
func funcTimestamp(ev *evaluator, args Expressions) Value {
vec := ev.evalVector(args[0])
for i := range vec {
el := &vec[i]
el.Metric = dropMetricName(el.Metric)
el.V = float64(el.T) / 1000.0
}
return vec
}
// linearRegression performs a least-square linear regression analysis on the
// provided SamplePairs. It returns the slope, and the intercept value at the
// provided time.
@ -1208,6 +1220,12 @@ var functions = map[string]*Function{
ReturnType: ValueTypeScalar,
Call: funcTime,
},
"timestamp": {
Name: "timestamp",
ArgTypes: []ValueType{ValueTypeVector},
ReturnType: ValueTypeVector,
Call: funcTimestamp,
},
"vector": {
Name: "vector",
ArgTypes: []ValueType{ValueTypeScalar},

View file

@ -15,6 +15,7 @@ package promql
import (
"fmt"
"math"
"runtime"
"sort"
"strconv"
@ -24,6 +25,7 @@ import (
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/pkg/value"
"github.com/prometheus/prometheus/util/strutil"
)
@ -201,17 +203,25 @@ func (p *parser) parseSeriesDesc() (m labels.Labels, vals []sequenceValue, err e
sign = -1
}
}
k := sign * p.number(p.expect(itemNumber, ctx).val)
var k float64
if t := p.peek().typ; t == itemNumber {
k = sign * p.number(p.expect(itemNumber, ctx).val)
} else if t == itemIdentifier && p.peek().val == "stale" {
p.next()
k = math.Float64frombits(value.StaleNaN)
} else {
p.errorf("expected number or 'stale' in %s but got %s", ctx, t.desc())
}
vals = append(vals, sequenceValue{
value: k,
})
// If there are no offset repetitions specified, proceed with the next value.
if t := p.peek().typ; t == itemNumber || t == itemBlank {
if t := p.peek(); t.typ == itemNumber || t.typ == itemBlank || t.typ == itemIdentifier && t.val == "stale" {
continue
} else if t == itemEOF {
} else if t.typ == itemEOF {
break
} else if t != itemADD && t != itemSUB {
} else if t.typ != itemADD && t.typ != itemSUB {
p.errorf("expected next value or relative expansion in %s but got %s", ctx, t.desc())
}

View file

@ -223,14 +223,31 @@ eval_fail instant at 0m label_replace(testmetric, "src", "", "", "")
clear
# Tests for vector.
# Tests for vector, time and timestamp.
load 10s
metric 1 1
eval instant at 0s timestamp(metric)
{} 0
eval instant at 5s timestamp(metric)
{} 0
eval instant at 10s timestamp(metric)
{} 10
eval instant at 0m vector(1)
{} 1
eval instant at 0s vector(time())
{} 0
eval instant at 5s vector(time())
{} 5
eval instant at 60m vector(time())
{} 3600
clear
# Tests for clamp_max and clamp_min().
load 5m
@ -436,3 +453,4 @@ eval instant at 0m days_in_month(vector(1454284800))
# Febuary 1st 2017 not in leap year.
eval instant at 0m days_in_month(vector(1485907200))
{} 28

51
promql/testdata/staleness.test vendored Normal file
View file

@ -0,0 +1,51 @@
load 10s
metric 0 1 stale 2
# Instant vector doesn't return series when stale.
eval instant at 10s metric
{__name__="metric"} 1
eval instant at 20s metric
eval instant at 30s metric
{__name__="metric"} 2
eval instant at 40s metric
{__name__="metric"} 2
# It goes stale 4 intervals + 10% after the last sample.
eval instant at 71s metric
{__name__="metric"} 2
eval instant at 72s metric
# Range vector ignores stale sample.
eval instant at 30s count_over_time(metric[1m])
{} 3
eval instant at 10s count_over_time(metric[1s])
{} 1
eval instant at 20s count_over_time(metric[1s])
eval instant at 20s count_over_time(metric[10s])
{} 1
clear
load 10s
metric 0
# Series with single point goes stale after 5 minutes.
eval instant at 0s metric
{__name__="metric"} 0
eval instant at 150s metric
{__name__="metric"} 0
eval instant at 300s metric
{__name__="metric"} 0
eval instant at 301s metric

View file

@ -27,37 +27,27 @@ func (a nopAppendable) Appender() (storage.Appender, error) {
type nopAppender struct{}
func (a nopAppender) Add(labels.Labels, int64, float64) (uint64, error) { return 0, nil }
func (a nopAppender) AddFast(uint64, int64, float64) error { return nil }
func (a nopAppender) Add(labels.Labels, int64, float64) (string, error) { return "", nil }
func (a nopAppender) AddFast(string, int64, float64) error { return nil }
func (a nopAppender) Commit() error { return nil }
func (a nopAppender) Rollback() error { return nil }
type collectResultAppender struct {
refs map[uint64]labels.Labels
result []sample
}
func (a *collectResultAppender) SetSeries(l labels.Labels) (uint64, error) {
if a.refs == nil {
a.refs = map[uint64]labels.Labels{}
}
ref := uint64(len(a.refs))
a.refs[ref] = l
return ref, nil
func (a *collectResultAppender) AddFast(ref string, t int64, v float64) error {
// Not implemented.
return storage.ErrNotFound
}
func (a *collectResultAppender) Add(ref uint64, t int64, v float64) error {
// for ln, lv := range s.Metric {
// if len(lv) == 0 {
// delete(s.Metric, ln)
// }
// }
func (a *collectResultAppender) Add(m labels.Labels, t int64, v float64) (string, error) {
a.result = append(a.result, sample{
metric: a.refs[ref],
metric: m,
t: t,
v: v,
})
return nil
return "", nil
}
func (a *collectResultAppender) Commit() error { return nil }

View file

@ -19,6 +19,7 @@ import (
"compress/gzip"
"fmt"
"io"
"math"
"net/http"
"sync"
"time"
@ -35,6 +36,7 @@ import (
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/pkg/textparse"
"github.com/prometheus/prometheus/pkg/timestamp"
"github.com/prometheus/prometheus/pkg/value"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/util/httputil"
)
@ -55,12 +57,6 @@ var (
},
[]string{"interval"},
)
targetSkippedScrapes = prometheus.NewCounter(
prometheus.CounterOpts{
Name: "prometheus_target_skipped_scrapes_total",
Help: "Total number of scrapes that were skipped because the metric storage was throttled.",
},
)
targetReloadIntervalLength = prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Name: "prometheus_target_reload_length_seconds",
@ -94,7 +90,6 @@ var (
func init() {
prometheus.MustRegister(targetIntervalLength)
prometheus.MustRegister(targetSkippedScrapes)
prometheus.MustRegister(targetReloadIntervalLength)
prometheus.MustRegister(targetSyncIntervalLength)
prometheus.MustRegister(targetScrapePoolSyncsCounter)
@ -116,7 +111,7 @@ type scrapePool struct {
loops map[uint64]loop
// Constructor for new scrape loops. This is settable for testing convenience.
newLoop func(context.Context, scraper, func() storage.Appender, func() storage.Appender) loop
newLoop func(context.Context, scraper, func() storage.Appender, func() storage.Appender, log.Logger) loop
}
func newScrapePool(ctx context.Context, cfg *config.ScrapeConfig, app Appendable) *scrapePool {
@ -160,7 +155,7 @@ func (sp *scrapePool) stop() {
// reload the scrape pool with the given scrape configuration. The target state is preserved
// but all scrape loops are restarted with the new scrape configuration.
// This method returns after all scrape loops that were stopped have fully terminated.
// This method returns after all scrape loops that were stopped have stopped scraping.
func (sp *scrapePool) reload(cfg *config.ScrapeConfig) {
start := time.Now()
@ -192,6 +187,7 @@ func (sp *scrapePool) reload(cfg *config.ScrapeConfig) {
func() storage.Appender {
return sp.reportAppender(t)
},
log.With("target", t.labels.String()),
)
)
wg.Add(1)
@ -261,6 +257,7 @@ func (sp *scrapePool) sync(targets []*Target) {
func() storage.Appender {
return sp.reportAppender(t)
},
log.With("target", t.labels.String()),
)
sp.targets[hash] = t
@ -417,39 +414,54 @@ type loop interface {
stop()
}
type lsetCacheEntry struct {
lset labels.Labels
str string
}
type scrapeLoop struct {
scraper scraper
l log.Logger
appender func() storage.Appender
reportAppender func() storage.Appender
cache map[string]uint64
// TODO: Keep only the values from the last scrape to avoid a memory leak.
refCache map[string]string // Parsed string to ref.
lsetCache map[string]lsetCacheEntry // Ref to labelset and string
seriesInPreviousScrape map[string]labels.Labels
done chan struct{}
ctx context.Context
cancel func()
ctx context.Context
scrapeCtx context.Context
cancel func()
stopped chan struct{}
}
func newScrapeLoop(ctx context.Context, sc scraper, app, reportApp func() storage.Appender) loop {
func newScrapeLoop(ctx context.Context, sc scraper, app, reportApp func() storage.Appender, l log.Logger) loop {
if l == nil {
l = log.Base()
}
sl := &scrapeLoop{
scraper: sc,
appender: app,
reportAppender: reportApp,
cache: map[string]uint64{},
done: make(chan struct{}),
refCache: map[string]string{},
lsetCache: map[string]lsetCacheEntry{},
stopped: make(chan struct{}),
ctx: ctx,
l: l,
}
sl.ctx, sl.cancel = context.WithCancel(ctx)
sl.scrapeCtx, sl.cancel = context.WithCancel(ctx)
return sl
}
func (sl *scrapeLoop) run(interval, timeout time.Duration, errc chan<- error) {
defer close(sl.done)
select {
case <-time.After(sl.scraper.offset(interval)):
// Continue after a scraping offset.
case <-sl.ctx.Done():
case <-sl.scrapeCtx.Done():
close(sl.stopped)
return
}
@ -460,18 +472,22 @@ func (sl *scrapeLoop) run(interval, timeout time.Duration, errc chan<- error) {
buf := bytes.NewBuffer(make([]byte, 0, 16000))
mainLoop:
for {
buf.Reset()
select {
case <-sl.ctx.Done():
close(sl.stopped)
return
case <-sl.scrapeCtx.Done():
break mainLoop
default:
}
var (
total, added int
start = time.Now()
scrapeCtx, _ = context.WithTimeout(sl.ctx, timeout)
total, added int
start = time.Now()
scrapeCtx, cancel = context.WithTimeout(sl.ctx, timeout)
)
// Only record after the first scrape.
@ -482,30 +498,95 @@ func (sl *scrapeLoop) run(interval, timeout time.Duration, errc chan<- error) {
}
err := sl.scraper.scrape(scrapeCtx, buf)
cancel()
var b []byte
if err == nil {
b := buf.Bytes()
if total, added, err = sl.append(b, start); err != nil {
log.With("err", err).Error("append failed")
}
b = buf.Bytes()
} else if errc != nil {
errc <- err
}
// A failed scrape is the same as an empty scrape,
// we still call sl.append to trigger stale markers.
if total, added, err = sl.append(b, start); err != nil {
sl.l.With("err", err).Error("append failed")
// The append failed, probably due to a parse error.
// Call sl.append again with an empty scrape to trigger stale markers.
if _, _, err = sl.append([]byte{}, start); err != nil {
sl.l.With("err", err).Error("append failed")
}
}
sl.report(start, time.Since(start), total, added, err)
last = start
select {
case <-sl.ctx.Done():
close(sl.stopped)
return
case <-sl.scrapeCtx.Done():
break mainLoop
case <-ticker.C:
}
}
close(sl.stopped)
sl.endOfRunStaleness(last, ticker, interval)
}
func (sl *scrapeLoop) endOfRunStaleness(last time.Time, ticker *time.Ticker, interval time.Duration) {
// Scraping has stopped. We want to write stale markers but
// the target may be recreated, so we wait just over 2 scrape intervals
// before creating them.
// If the context is cancelled, we presume the server is shutting down
// and will restart where is was. We do not attempt to write stale markers
// in this case.
if last.IsZero() {
// There never was a scrape, so there will be no stale markers.
return
}
// Wait for when the next scrape would have been, record its timestamp.
var staleTime time.Time
select {
case <-sl.ctx.Done():
return
case <-ticker.C:
staleTime = time.Now()
}
// Wait for when the next scrape would have been, if the target was recreated
// samples should have been ingested by now.
select {
case <-sl.ctx.Done():
return
case <-ticker.C:
}
// Wait for an extra 10% of the interval, just to be safe.
select {
case <-sl.ctx.Done():
return
case <-time.After(interval / 10):
}
// Call sl.append again with an empty scrape to trigger stale markers.
// If the target has since been recreated and scraped, the
// stale markers will be out of order and ignored.
if _, _, err := sl.append([]byte{}, staleTime); err != nil {
sl.l.With("err", err).Error("stale append failed")
}
if err := sl.reportStale(staleTime); err != nil {
sl.l.With("err", err).Error("stale report failed")
}
}
// Stop the scraping. May still write data and stale markers after it has
// returned. Cancel the context to stop all writes.
func (sl *scrapeLoop) stop() {
sl.cancel()
<-sl.done
<-sl.stopped
}
type sample struct {
@ -531,9 +612,12 @@ func (s samples) Less(i, j int) bool {
func (sl *scrapeLoop) append(b []byte, ts time.Time) (total, added int, err error) {
var (
app = sl.appender()
p = textparse.New(b)
defTime = timestamp.FromTime(ts)
app = sl.appender()
p = textparse.New(b)
defTime = timestamp.FromTime(ts)
seriesScraped = make(map[string]labels.Labels, len(sl.seriesInPreviousScrape))
numOutOfOrder = 0
numDuplicates = 0
)
loop:
@ -547,15 +631,24 @@ loop:
}
mets := yoloString(met)
ref, ok := sl.cache[mets]
ref, ok := sl.refCache[mets]
if ok {
switch err = app.AddFast(ref, t, v); err {
case nil:
seriesScraped[sl.lsetCache[ref].str] = sl.lsetCache[ref].lset
case storage.ErrNotFound:
ok = false
case errSeriesDropped:
err = nil
continue
case storage.ErrOutOfOrderSample:
sl.l.With("timeseries", string(met)).Debug("Out of order sample")
numOutOfOrder += 1
continue
case storage.ErrDuplicateSampleForTimestamp:
numDuplicates += 1
sl.l.With("timeseries", string(met)).Debug("Duplicate sample for timestamp")
continue
default:
break loop
}
@ -571,18 +664,61 @@ loop:
case errSeriesDropped:
err = nil
continue
case storage.ErrOutOfOrderSample:
err = nil
sl.l.With("timeseries", string(met)).Debug("Out of order sample")
numOutOfOrder += 1
continue
case storage.ErrDuplicateSampleForTimestamp:
err = nil
numDuplicates += 1
sl.l.With("timeseries", string(met)).Debug("Duplicate sample for timestamp")
continue
default:
break loop
}
// Allocate a real string.
mets = string(met)
sl.cache[mets] = ref
sl.refCache[mets] = ref
str := lset.String()
sl.lsetCache[ref] = lsetCacheEntry{lset: lset, str: str}
if tp == nil {
// Bypass staleness logic if there is an explicit timestamp.
seriesScraped[str] = lset
}
}
added++
}
if err == nil {
err = p.Err()
}
if numOutOfOrder > 0 {
sl.l.With("numDropped", numOutOfOrder).Warn("Error on ingesting out-of-order samples")
}
if numDuplicates > 0 {
sl.l.With("numDropped", numDuplicates).Warn("Error on ingesting samples with different value but same timestamp")
}
if err == nil {
for metric, lset := range sl.seriesInPreviousScrape {
if _, ok := seriesScraped[metric]; !ok {
// Series no longer exposed, mark it stale.
_, err = app.Add(lset, defTime, math.Float64frombits(value.StaleNaN))
switch err {
case nil:
case errSeriesDropped:
err = nil
continue
case storage.ErrOutOfOrderSample, storage.ErrDuplicateSampleForTimestamp:
// Do not count these in logging, as this is expected if a target
// goes away and comes back again with a new scrape loop.
err = nil
continue
default:
break
}
}
}
}
if err != nil {
app.Rollback()
return total, 0, err
@ -590,6 +726,7 @@ loop:
if err := app.Commit(); err != nil {
return total, 0, err
}
sl.seriesInPreviousScrape = seriesScraped
return total, added, nil
}
@ -628,13 +765,45 @@ func (sl *scrapeLoop) report(start time.Time, duration time.Duration, scraped, a
return app.Commit()
}
func (sl *scrapeLoop) reportStale(start time.Time) error {
ts := timestamp.FromTime(start)
app := sl.reportAppender()
stale := math.Float64frombits(value.StaleNaN)
if err := sl.addReportSample(app, scrapeHealthMetricName, ts, stale); err != nil {
app.Rollback()
return err
}
if err := sl.addReportSample(app, scrapeDurationMetricName, ts, stale); err != nil {
app.Rollback()
return err
}
if err := sl.addReportSample(app, scrapeSamplesMetricName, ts, stale); err != nil {
app.Rollback()
return err
}
if err := sl.addReportSample(app, samplesPostRelabelMetricName, ts, stale); err != nil {
app.Rollback()
return err
}
return app.Commit()
}
func (sl *scrapeLoop) addReportSample(app storage.Appender, s string, t int64, v float64) error {
ref, ok := sl.cache[s]
ref, ok := sl.refCache[s]
if ok {
if err := app.AddFast(ref, t, v); err == nil {
err := app.AddFast(ref, t, v)
switch err {
case nil:
return nil
} else if err != storage.ErrNotFound {
case storage.ErrNotFound:
// Try an Add.
case storage.ErrOutOfOrderSample, storage.ErrDuplicateSampleForTimestamp:
// Do not log here, as this is expected if a target goes away and comes back
// again with a new scrape loop.
return nil
default:
return err
}
}
@ -642,10 +811,13 @@ func (sl *scrapeLoop) addReportSample(app storage.Appender, s string, t int64, v
labels.Label{Name: labels.MetricName, Value: s},
}
ref, err := app.Add(met, t, v)
if err != nil {
switch err {
case nil:
sl.refCache[s] = ref
return nil
case storage.ErrOutOfOrderSample, storage.ErrDuplicateSampleForTimestamp:
return nil
default:
return err
}
sl.cache[s] = ref
return nil
}

View file

@ -18,6 +18,7 @@ import (
"fmt"
"io"
"io/ioutil"
"math"
"net/http"
"net/http/httptest"
"net/url"
@ -27,12 +28,15 @@ import (
"testing"
"time"
"github.com/prometheus/common/log"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"github.com/prometheus/prometheus/config"
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/pkg/timestamp"
"github.com/prometheus/prometheus/pkg/value"
"github.com/prometheus/prometheus/storage"
)
@ -141,7 +145,7 @@ func TestScrapePoolReload(t *testing.T) {
}
// On starting to run, new loops created on reload check whether their preceding
// equivalents have been stopped.
newLoop := func(ctx context.Context, s scraper, app, reportApp func() storage.Appender) loop {
newLoop := func(ctx context.Context, s scraper, app, reportApp func() storage.Appender, _ log.Logger) loop {
l := &testLoop{}
l.startFunc = func(interval, timeout time.Duration, errc chan<- error) {
if interval != 3*time.Second {
@ -305,9 +309,9 @@ func TestScrapePoolSampleAppender(t *testing.T) {
}
}
func TestScrapeLoopStop(t *testing.T) {
func TestScrapeLoopStopBeforeRun(t *testing.T) {
scraper := &testScraper{}
sl := newScrapeLoop(context.Background(), scraper, nil, nil)
sl := newScrapeLoop(context.Background(), scraper, nil, nil, nil)
// The scrape pool synchronizes on stopping scrape loops. However, new scrape
// loops are started asynchronously. Thus it's possible, that a loop is stopped
@ -352,6 +356,66 @@ func TestScrapeLoopStop(t *testing.T) {
}
}
func TestScrapeLoopStop(t *testing.T) {
appender := &collectResultAppender{}
reportAppender := &collectResultAppender{}
var (
signal = make(chan struct{})
scraper = &testScraper{}
app = func() storage.Appender { return appender }
reportApp = func() storage.Appender { return reportAppender }
numScrapes = 0
)
defer close(signal)
sl := newScrapeLoop(context.Background(), scraper, app, reportApp, nil)
// Succeed once, several failures, then stop.
scraper.scrapeFunc = func(ctx context.Context, w io.Writer) error {
numScrapes += 1
if numScrapes == 2 {
go func() {
sl.stop()
}()
}
w.Write([]byte("metric_a 42\n"))
return nil
}
go func() {
sl.run(10*time.Millisecond, time.Hour, nil)
signal <- struct{}{}
}()
select {
case <-signal:
case <-time.After(5 * time.Second):
t.Fatalf("Scrape wasn't stopped.")
}
if len(appender.result) < 2 {
t.Fatalf("Appended samples not as expected. Wanted: at least %d samples Got: %d", 2, len(appender.result))
}
if !value.IsStaleNaN(appender.result[len(appender.result)-1].v) {
t.Fatalf("Appended last sample not as expected. Wanted: stale NaN Got: %x", math.Float64bits(appender.result[len(appender.result)].v))
}
if len(reportAppender.result) < 8 {
t.Fatalf("Appended samples not as expected. Wanted: at least %d samples Got: %d", 8, len(reportAppender.result))
}
if len(reportAppender.result)%4 != 0 {
t.Fatalf("Appended samples not as expected. Wanted: samples mod 4 == 0 Got: %d samples", len(reportAppender.result))
}
if !value.IsStaleNaN(reportAppender.result[len(reportAppender.result)-1].v) {
t.Fatalf("Appended last sample not as expected. Wanted: stale NaN Got: %x", math.Float64bits(reportAppender.result[len(reportAppender.result)].v))
}
if reportAppender.result[len(reportAppender.result)-1].t != appender.result[len(appender.result)-1].t {
t.Fatalf("Expected last append and report sample to have same timestamp. Append: stale NaN Report: %x", appender.result[len(appender.result)-1].t, reportAppender.result[len(reportAppender.result)-1].t)
}
}
func TestScrapeLoopRun(t *testing.T) {
var (
signal = make(chan struct{})
@ -364,7 +428,7 @@ func TestScrapeLoopRun(t *testing.T) {
defer close(signal)
ctx, cancel := context.WithCancel(context.Background())
sl := newScrapeLoop(ctx, scraper, app, reportApp)
sl := newScrapeLoop(ctx, scraper, app, reportApp, nil)
// The loop must terminate during the initial offset if the context
// is canceled.
@ -402,7 +466,7 @@ func TestScrapeLoopRun(t *testing.T) {
}
ctx, cancel = context.WithCancel(context.Background())
sl = newScrapeLoop(ctx, scraper, app, reportApp)
sl = newScrapeLoop(ctx, scraper, app, reportApp, nil)
go func() {
sl.run(time.Second, 100*time.Millisecond, errc)
@ -434,6 +498,267 @@ func TestScrapeLoopRun(t *testing.T) {
}
}
func TestScrapeLoopRunCreatesStaleMarkersOnFailedScrape(t *testing.T) {
appender := &collectResultAppender{}
var (
signal = make(chan struct{})
scraper = &testScraper{}
app = func() storage.Appender { return appender }
reportApp = func() storage.Appender { return &nopAppender{} }
numScrapes = 0
)
defer close(signal)
ctx, cancel := context.WithCancel(context.Background())
sl := newScrapeLoop(ctx, scraper, app, reportApp, nil)
// Succeed once, several failures, then stop.
scraper.scrapeFunc = func(ctx context.Context, w io.Writer) error {
numScrapes += 1
if numScrapes == 1 {
w.Write([]byte("metric_a 42\n"))
return nil
} else if numScrapes == 5 {
cancel()
}
return fmt.Errorf("Scrape failed.")
}
go func() {
sl.run(10*time.Millisecond, time.Hour, nil)
signal <- struct{}{}
}()
select {
case <-signal:
case <-time.After(5 * time.Second):
t.Fatalf("Scrape wasn't stopped.")
}
if len(appender.result) != 2 {
t.Fatalf("Appended samples not as expected. Wanted: %d samples Got: %d", 2, len(appender.result))
}
if appender.result[0].v != 42.0 {
t.Fatalf("Appended first sample not as expected. Wanted: %f Got: %f", appender.result[0], 42)
}
if !value.IsStaleNaN(appender.result[1].v) {
t.Fatalf("Appended second sample not as expected. Wanted: stale NaN Got: %x", math.Float64bits(appender.result[1].v))
}
}
func TestScrapeLoopRunCreatesStaleMarkersOnParseFailure(t *testing.T) {
appender := &collectResultAppender{}
var (
signal = make(chan struct{})
scraper = &testScraper{}
app = func() storage.Appender { return appender }
reportApp = func() storage.Appender { return &nopAppender{} }
numScrapes = 0
)
defer close(signal)
ctx, cancel := context.WithCancel(context.Background())
sl := newScrapeLoop(ctx, scraper, app, reportApp, nil)
// Succeed once, several failures, then stop.
scraper.scrapeFunc = func(ctx context.Context, w io.Writer) error {
numScrapes += 1
if numScrapes == 1 {
w.Write([]byte("metric_a 42\n"))
return nil
} else if numScrapes == 2 {
w.Write([]byte("7&-\n"))
return nil
} else if numScrapes == 3 {
cancel()
}
return fmt.Errorf("Scrape failed.")
}
go func() {
sl.run(10*time.Millisecond, time.Hour, nil)
signal <- struct{}{}
}()
select {
case <-signal:
case <-time.After(5 * time.Second):
t.Fatalf("Scrape wasn't stopped.")
}
if len(appender.result) != 2 {
t.Fatalf("Appended samples not as expected. Wanted: %d samples Got: %d", 2, len(appender.result))
}
if appender.result[0].v != 42.0 {
t.Fatalf("Appended first sample not as expected. Wanted: %f Got: %f", appender.result[0], 42)
}
if !value.IsStaleNaN(appender.result[1].v) {
t.Fatalf("Appended second sample not as expected. Wanted: stale NaN Got: %x", math.Float64bits(appender.result[1].v))
}
}
func TestScrapeLoopAppend(t *testing.T) {
app := &collectResultAppender{}
sl := &scrapeLoop{
appender: func() storage.Appender { return app },
reportAppender: func() storage.Appender { return nopAppender{} },
refCache: map[string]string{},
lsetCache: map[string]lsetCacheEntry{},
}
now := time.Now()
_, _, err := sl.append([]byte("metric_a 1\nmetric_b NaN\n"), now)
if err != nil {
t.Fatalf("Unexpected append error: %s", err)
}
ingestedNaN := math.Float64bits(app.result[1].v)
if ingestedNaN != value.NormalNaN {
t.Fatalf("Appended NaN samples wasn't as expected. Wanted: %x Got: %x", value.NormalNaN, ingestedNaN)
}
// DeepEqual will report NaNs as being different, so replace with a different value.
app.result[1].v = 42
want := []sample{
{
metric: labels.FromStrings(model.MetricNameLabel, "metric_a"),
t: timestamp.FromTime(now),
v: 1,
},
{
metric: labels.FromStrings(model.MetricNameLabel, "metric_b"),
t: timestamp.FromTime(now),
v: 42,
},
}
if !reflect.DeepEqual(want, app.result) {
t.Fatalf("Appended samples not as expected. Wanted: %+v Got: %+v", want, app.result)
}
}
func TestScrapeLoopAppendStaleness(t *testing.T) {
app := &collectResultAppender{}
sl := &scrapeLoop{
appender: func() storage.Appender { return app },
reportAppender: func() storage.Appender { return nopAppender{} },
refCache: map[string]string{},
lsetCache: map[string]lsetCacheEntry{},
}
now := time.Now()
_, _, err := sl.append([]byte("metric_a 1\n"), now)
if err != nil {
t.Fatalf("Unexpected append error: %s", err)
}
_, _, err = sl.append([]byte(""), now.Add(time.Second))
if err != nil {
t.Fatalf("Unexpected append error: %s", err)
}
ingestedNaN := math.Float64bits(app.result[1].v)
if ingestedNaN != value.StaleNaN {
t.Fatalf("Appended stale sample wasn't as expected. Wanted: %x Got: %x", value.StaleNaN, ingestedNaN)
}
// DeepEqual will report NaNs as being different, so replace with a different value.
app.result[1].v = 42
want := []sample{
{
metric: labels.FromStrings(model.MetricNameLabel, "metric_a"),
t: timestamp.FromTime(now),
v: 1,
},
{
metric: labels.FromStrings(model.MetricNameLabel, "metric_a"),
t: timestamp.FromTime(now.Add(time.Second)),
v: 42,
},
}
if !reflect.DeepEqual(want, app.result) {
t.Fatalf("Appended samples not as expected. Wanted: %+v Got: %+v", want, app.result)
}
}
func TestScrapeLoopAppendNoStalenessIfTimestamp(t *testing.T) {
app := &collectResultAppender{}
sl := &scrapeLoop{
appender: func() storage.Appender { return app },
reportAppender: func() storage.Appender { return nopAppender{} },
refCache: map[string]string{},
lsetCache: map[string]lsetCacheEntry{},
}
now := time.Now()
_, _, err := sl.append([]byte("metric_a 1 1000\n"), now)
if err != nil {
t.Fatalf("Unexpected append error: %s", err)
}
_, _, err = sl.append([]byte(""), now.Add(time.Second))
if err != nil {
t.Fatalf("Unexpected append error: %s", err)
}
want := []sample{
{
metric: labels.FromStrings(model.MetricNameLabel, "metric_a"),
t: 1000,
v: 1,
},
}
if !reflect.DeepEqual(want, app.result) {
t.Fatalf("Appended samples not as expected. Wanted: %+v Got: %+v", want, app.result)
}
}
type errorAppender struct {
collectResultAppender
}
func (app *errorAppender) Add(lset labels.Labels, t int64, v float64) (string, error) {
if lset.Get(model.MetricNameLabel) == "out_of_order" {
return "", storage.ErrOutOfOrderSample
} else if lset.Get(model.MetricNameLabel) == "amend" {
return "", storage.ErrDuplicateSampleForTimestamp
}
return app.collectResultAppender.Add(lset, t, v)
}
func (app *errorAppender) AddFast(ref string, t int64, v float64) error {
return app.collectResultAppender.AddFast(ref, t, v)
}
func TestScrapeLoopAppendGracefullyIfAmendOrOutOfOrder(t *testing.T) {
app := &errorAppender{}
sl := &scrapeLoop{
appender: func() storage.Appender { return app },
reportAppender: func() storage.Appender { return nopAppender{} },
refCache: map[string]string{},
lsetCache: map[string]lsetCacheEntry{},
l: log.Base(),
}
now := time.Unix(1, 0)
_, _, err := sl.append([]byte("out_of_order 1\namend 1\nnormal 1\n"), now)
if err != nil {
t.Fatalf("Unexpected append error: %s", err)
}
want := []sample{
{
metric: labels.FromStrings(model.MetricNameLabel, "normal"),
t: timestamp.FromTime(now),
v: 1,
},
}
if !reflect.DeepEqual(want, app.result) {
t.Fatalf("Appended samples not as expected. Wanted: %+v Got: %+v", want, app.result)
}
}
func TestTargetScraperScrapeOK(t *testing.T) {
const (
configTimeout = 1500 * time.Millisecond
@ -568,7 +893,6 @@ type testScraper struct {
lastDuration time.Duration
lastError error
samples samples
scrapeErr error
scrapeFunc func(context.Context, io.Writer) error
}

View file

@ -236,19 +236,19 @@ type limitAppender struct {
i int
}
func (app *limitAppender) Add(lset labels.Labels, t int64, v float64) (uint64, error) {
func (app *limitAppender) Add(lset labels.Labels, t int64, v float64) (string, error) {
if app.i+1 > app.limit {
return 0, errors.New("sample limit exceeded")
return "", errors.New("sample limit exceeded")
}
ref, err := app.Appender.Add(lset, t, v)
if err != nil {
return 0, fmt.Errorf("sample limit of %d exceeded", app.limit)
return "", fmt.Errorf("sample limit of %d exceeded", app.limit)
}
app.i++
return ref, nil
}
func (app *limitAppender) AddFast(ref uint64, t int64, v float64) error {
func (app *limitAppender) AddFast(ref string, t int64, v float64) error {
if app.i+1 > app.limit {
return errors.New("sample limit exceeded")
}
@ -267,7 +267,7 @@ type ruleLabelsAppender struct {
labels labels.Labels
}
func (app ruleLabelsAppender) Add(lset labels.Labels, t int64, v float64) (uint64, error) {
func (app ruleLabelsAppender) Add(lset labels.Labels, t int64, v float64) (string, error) {
lb := labels.NewBuilder(lset)
for _, l := range app.labels {
@ -289,7 +289,7 @@ type honorLabelsAppender struct {
// Merges the sample's metric with the given labels if the label is not
// already present in the metric.
// This also considers labels explicitly set to the empty string.
func (app honorLabelsAppender) Add(lset labels.Labels, t int64, v float64) (uint64, error) {
func (app honorLabelsAppender) Add(lset labels.Labels, t int64, v float64) (string, error) {
lb := labels.NewBuilder(lset)
for _, l := range app.labels {
@ -309,10 +309,10 @@ type relabelAppender struct {
var errSeriesDropped = errors.New("series dropped")
func (app relabelAppender) Add(lset labels.Labels, t int64, v float64) (uint64, error) {
func (app relabelAppender) Add(lset labels.Labels, t int64, v float64) (string, error) {
lset = relabel.Process(lset, app.relabelings...)
if lset == nil {
return 0, errSeriesDropped
return "", errSeriesDropped
}
return app.Appender.Add(lset, t, v)
}

View file

@ -15,6 +15,7 @@ package rules
import (
"fmt"
"net/url"
"sync"
"time"
@ -124,7 +125,7 @@ func (r *AlertingRule) equal(o *AlertingRule) bool {
return r.name == o.name && labels.Equal(r.labels, o.labels)
}
func (r *AlertingRule) sample(alert *Alert, ts time.Time, set bool) promql.Sample {
func (r *AlertingRule) sample(alert *Alert, ts time.Time) promql.Sample {
lb := labels.NewBuilder(r.labels)
for _, l := range alert.Labels {
@ -137,10 +138,7 @@ func (r *AlertingRule) sample(alert *Alert, ts time.Time, set bool) promql.Sampl
s := promql.Sample{
Metric: lb.Labels(),
Point: promql.Point{T: timestamp.FromTime(ts), V: 0},
}
if set {
s.V = 1
Point: promql.Point{T: timestamp.FromTime(ts), V: 1},
}
return s
}
@ -151,7 +149,7 @@ const resolvedRetention = 15 * time.Minute
// Eval evaluates the rule expression and then creates pending alerts and fires
// or removes previously pending alerts accordingly.
func (r *AlertingRule) Eval(ctx context.Context, ts time.Time, engine *promql.Engine, externalURLPath string) (promql.Vector, error) {
func (r *AlertingRule) Eval(ctx context.Context, ts time.Time, engine *promql.Engine, externalURL *url.URL) (promql.Vector, error) {
query, err := engine.NewInstantQuery(r.vector.String(), ts)
if err != nil {
return nil, err
@ -194,7 +192,7 @@ func (r *AlertingRule) Eval(ctx context.Context, ts time.Time, engine *promql.En
tmplData,
model.Time(timestamp.FromTime(ts)),
engine,
externalURLPath,
externalURL,
)
result, err := tmpl.Expand()
if err != nil {
@ -240,9 +238,6 @@ func (r *AlertingRule) Eval(ctx context.Context, ts time.Time, engine *promql.En
// Check if any pending alerts should be removed or fire now. Write out alert timeseries.
for fp, a := range r.active {
if _, ok := resultFPs[fp]; !ok {
if a.State != StateInactive {
vec = append(vec, r.sample(a, ts, false))
}
// If the alert was previously firing, keep it around for a given
// retention time so it is reported as resolved to the AlertManager.
if a.State == StatePending || (!a.ResolvedAt.IsZero() && ts.Sub(a.ResolvedAt) > resolvedRetention) {
@ -256,11 +251,10 @@ func (r *AlertingRule) Eval(ctx context.Context, ts time.Time, engine *promql.En
}
if a.State == StatePending && ts.Sub(a.ActiveAt) >= r.holdDuration {
vec = append(vec, r.sample(a, ts, false))
a.State = StateFiring
}
vec = append(vec, r.sample(a, ts, true))
vec = append(vec, r.sample(a, ts))
}
return vec, nil

View file

@ -16,6 +16,7 @@ package rules
import (
"fmt"
"io/ioutil"
"math"
"net/url"
"path/filepath"
"sync"
@ -30,6 +31,9 @@ import (
"github.com/prometheus/prometheus/config"
"github.com/prometheus/prometheus/notifier"
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/pkg/timestamp"
"github.com/prometheus/prometheus/pkg/value"
"github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/util/strutil"
@ -112,7 +116,7 @@ const (
type Rule interface {
Name() string
// eval evaluates the rule, including any associated recording or alerting actions.
Eval(context.Context, time.Time, *promql.Engine, string) (promql.Vector, error)
Eval(context.Context, time.Time, *promql.Engine, *url.URL) (promql.Vector, error)
// String returns a human-readable string representation of the rule.
String() string
// HTMLSnippet returns a human-readable string representation of the rule,
@ -122,10 +126,11 @@ type Rule interface {
// Group is a set of rules that have a logical relation.
type Group struct {
name string
interval time.Duration
rules []Rule
opts *ManagerOptions
name string
interval time.Duration
rules []Rule
seriesInPreviousEval []map[string]labels.Labels // One per Rule.
opts *ManagerOptions
done chan struct{}
terminated chan struct{}
@ -134,12 +139,13 @@ type Group struct {
// NewGroup makes a new Group with the given name, options, and rules.
func NewGroup(name string, interval time.Duration, rules []Rule, opts *ManagerOptions) *Group {
return &Group{
name: name,
interval: interval,
rules: rules,
opts: opts,
done: make(chan struct{}),
terminated: make(chan struct{}),
name: name,
interval: interval,
rules: rules,
opts: opts,
seriesInPreviousEval: make([]map[string]labels.Labels, len(rules)),
done: make(chan struct{}),
terminated: make(chan struct{}),
}
}
@ -157,7 +163,7 @@ func (g *Group) run() {
iterationsScheduled.Inc()
start := time.Now()
g.Eval()
g.Eval(start)
iterationDuration.Observe(time.Since(start).Seconds())
}
@ -214,25 +220,38 @@ func (g *Group) offset() time.Duration {
return time.Duration(next - now)
}
// copyState copies the alerting rule state from the given group.
// copyState copies the alerting rule and staleness related state from the given group.
//
// Rules are matched based on their name. If there are duplicates, the
// first is matched with the first, second with the second etc.
func (g *Group) copyState(from *Group) {
for _, fromRule := range from.rules {
far, ok := fromRule.(*AlertingRule)
ruleMap := make(map[string][]int, len(from.rules))
for fi, fromRule := range from.rules {
l, _ := ruleMap[fromRule.Name()]
ruleMap[fromRule.Name()] = append(l, fi)
}
for i, rule := range g.rules {
indexes, ok := ruleMap[rule.Name()]
if len(indexes) == 0 {
continue
}
fi := indexes[0]
g.seriesInPreviousEval[i] = from.seriesInPreviousEval[fi]
ruleMap[rule.Name()] = indexes[1:]
ar, ok := rule.(*AlertingRule)
if !ok {
continue
}
for _, rule := range g.rules {
ar, ok := rule.(*AlertingRule)
if !ok {
continue
}
// TODO(fabxc): forbid same alert definitions that are not unique by
// at least on static label or alertname?
if far.equal(ar) {
for fp, a := range far.active {
ar.active[fp] = a
}
}
far, ok := from.rules[fi].(*AlertingRule)
if !ok {
continue
}
for fp, a := range far.active {
ar.active[fp] = a
}
}
}
@ -250,18 +269,17 @@ func typeForRule(r Rule) ruleType {
// Eval runs a single evaluation cycle in which all rules are evaluated in parallel.
// In the future a single group will be evaluated sequentially to properly handle
// rule dependency.
func (g *Group) Eval() {
func (g *Group) Eval(ts time.Time) {
var (
now = time.Now()
wg sync.WaitGroup
wg sync.WaitGroup
)
for _, rule := range g.rules {
for i, rule := range g.rules {
rtyp := string(typeForRule(rule))
wg.Add(1)
// BUG(julius): Look at fixing thundering herd.
go func(rule Rule) {
go func(i int, rule Rule) {
defer wg.Done()
defer func(t time.Time) {
@ -270,7 +288,7 @@ func (g *Group) Eval() {
evalTotal.WithLabelValues(rtyp).Inc()
vector, err := rule.Eval(g.opts.Context, now, g.opts.QueryEngine, g.opts.ExternalURL.Path)
vector, err := rule.Eval(g.opts.Context, ts, g.opts.QueryEngine, g.opts.ExternalURL)
if err != nil {
// Canceled queries are intentional termination of queries. This normally
// happens on shutdown and thus we skip logging of any errors here.
@ -295,6 +313,7 @@ func (g *Group) Eval() {
return
}
seriesReturned := make(map[string]labels.Labels, len(g.seriesInPreviousEval[i]))
for _, s := range vector {
if _, err := app.Add(s.Metric, s.T, s.V); err != nil {
switch err {
@ -307,6 +326,8 @@ func (g *Group) Eval() {
default:
log.With("sample", s).With("err", err).Warn("Rule evaluation result discarded")
}
} else {
seriesReturned[s.Metric.String()] = s.Metric
}
}
if numOutOfOrder > 0 {
@ -315,10 +336,27 @@ func (g *Group) Eval() {
if numDuplicates > 0 {
log.With("numDropped", numDuplicates).Warn("Error on ingesting results from rule evaluation with different value but same timestamp")
}
for metric, lset := range g.seriesInPreviousEval[i] {
if _, ok := seriesReturned[metric]; !ok {
// Series no longer exposed, mark it stale.
_, err = app.Add(lset, timestamp.FromTime(ts), math.Float64frombits(value.StaleNaN))
switch err {
case nil:
case storage.ErrOutOfOrderSample, storage.ErrDuplicateSampleForTimestamp:
// Do not count these in logging, as this is expected if series
// is exposed from a different rule.
default:
log.With("sample", metric).With("err", err).Warn("adding stale sample failed")
}
}
}
if err := app.Commit(); err != nil {
log.With("err", err).Warn("rule sample appending failed")
} else {
g.seriesInPreviousEval[i] = seriesReturned
}
}(rule)
}(i, rule)
}
wg.Wait()
}
@ -433,7 +471,7 @@ func (m *Manager) ApplyConfig(conf *config.Config) error {
wg.Add(1)
// If there is an old group with the same identifier, stop it and wait for
// it to finish the current iteration. Then copy its into the new group.
// it to finish the current iteration. Then copy it into the new group.
oldg, ok := m.groups[newg.name]
delete(m.groups, newg.name)

View file

@ -14,7 +14,10 @@
package rules
import (
"context"
"fmt"
"math"
"reflect"
"strings"
"testing"
"time"
@ -23,7 +26,10 @@ import (
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/pkg/timestamp"
"github.com/prometheus/prometheus/pkg/value"
"github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/storage"
"github.com/prometheus/prometheus/util/testutil"
)
func TestAlertingRule(t *testing.T) {
@ -69,23 +75,18 @@ func TestAlertingRule(t *testing.T) {
}, {
time: 5 * time.Minute,
result: []string{
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="pending", group="canary", instance="0", job="app-server", severity="critical"} => 0 @[%v]`,
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="firing", group="canary", instance="0", job="app-server", severity="critical"} => 1 @[%v]`,
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="pending", group="canary", instance="1", job="app-server", severity="critical"} => 0 @[%v]`,
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="firing", group="canary", instance="1", job="app-server", severity="critical"} => 1 @[%v]`,
},
}, {
time: 10 * time.Minute,
result: []string{
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="firing", group="canary", instance="0", job="app-server", severity="critical"} => 1 @[%v]`,
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="firing", group="canary", instance="1", job="app-server", severity="critical"} => 0 @[%v]`,
},
},
{
time: 15 * time.Minute,
result: []string{
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="firing", group="canary", instance="0", job="app-server", severity="critical"} => 0 @[%v]`,
},
time: 15 * time.Minute,
result: []string{},
},
{
time: 20 * time.Minute,
@ -100,7 +101,6 @@ func TestAlertingRule(t *testing.T) {
{
time: 30 * time.Minute,
result: []string{
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="pending", group="canary", instance="0", job="app-server", severity="critical"} => 0 @[%v]`,
`{__name__="ALERTS", alertname="HTTPRequestRateLow", alertstate="firing", group="canary", instance="0", job="app-server", severity="critical"} => 1 @[%v]`,
},
},
@ -109,7 +109,7 @@ func TestAlertingRule(t *testing.T) {
for i, test := range tests {
evalTime := baseTime.Add(test.time)
res, err := rule.Eval(suite.Context(), evalTime, suite.QueryEngine(), "")
res, err := rule.Eval(suite.Context(), evalTime, suite.QueryEngine(), nil)
if err != nil {
t.Fatalf("Error during alerting rule evaluation: %s", err)
}
@ -156,3 +156,131 @@ func annotateWithTime(lines []string, ts time.Time) []string {
}
return annotatedLines
}
func TestStaleness(t *testing.T) {
storage := testutil.NewStorage(t)
defer storage.Close()
engine := promql.NewEngine(storage, nil)
opts := &ManagerOptions{
QueryEngine: engine,
Appendable: storage,
Context: context.Background(),
}
expr, err := promql.ParseExpr("a + 1")
if err != nil {
t.Fatal(err)
}
rule := NewRecordingRule("a_plus_one", expr, labels.Labels{})
group := NewGroup("default", time.Second, []Rule{rule}, opts)
// A time series that has two samples and then goes stale.
app, _ := storage.Appender()
app.Add(labels.FromStrings(model.MetricNameLabel, "a"), 0, 1)
app.Add(labels.FromStrings(model.MetricNameLabel, "a"), 1000, 2)
app.Add(labels.FromStrings(model.MetricNameLabel, "a"), 2000, math.Float64frombits(value.StaleNaN))
if err = app.Commit(); err != nil {
t.Fatal(err)
}
// Execute 3 times, 1 second apart.
group.Eval(time.Unix(0, 0))
group.Eval(time.Unix(1, 0))
group.Eval(time.Unix(2, 0))
querier, err := storage.Querier(0, 2000)
defer querier.Close()
if err != nil {
t.Fatal(err)
}
matcher, _ := labels.NewMatcher(labels.MatchEqual, model.MetricNameLabel, "a_plus_one")
samples, err := readSeriesSet(querier.Select(matcher))
if err != nil {
t.Fatal(err)
}
metric := labels.FromStrings(model.MetricNameLabel, "a_plus_one").String()
metricSample, ok := samples[metric]
if !ok {
t.Fatalf("Series %s not returned.", metric)
}
if !value.IsStaleNaN(metricSample[2].V) {
t.Fatalf("Appended second sample not as expected. Wanted: stale NaN Got: %x", math.Float64bits(metricSample[2].V))
}
metricSample[2].V = 42 // reflect.DeepEqual cannot handle NaN.
want := map[string][]promql.Point{
metric: []promql.Point{{0, 2}, {1000, 3}, {2000, 42}},
}
if !reflect.DeepEqual(want, samples) {
t.Fatalf("Returned samples not as expected. Wanted: %+v Got: %+v", want, samples)
}
}
// Convert a SeriesSet into a form useable with reflect.DeepEqual.
func readSeriesSet(ss storage.SeriesSet) (map[string][]promql.Point, error) {
result := map[string][]promql.Point{}
for ss.Next() {
series := ss.At()
points := []promql.Point{}
it := series.Iterator()
for it.Next() {
t, v := it.At()
points = append(points, promql.Point{T: t, V: v})
}
name := series.Labels().String()
result[name] = points
}
return result, ss.Err()
}
func TestCopyState(t *testing.T) {
oldGroup := &Group{
rules: []Rule{
NewAlertingRule("alert", nil, 0, nil, nil),
NewRecordingRule("rule1", nil, nil),
NewRecordingRule("rule2", nil, nil),
NewRecordingRule("rule3", nil, nil),
NewRecordingRule("rule3", nil, nil),
},
seriesInPreviousEval: []map[string]labels.Labels{
map[string]labels.Labels{"a": nil},
map[string]labels.Labels{"r1": nil},
map[string]labels.Labels{"r2": nil},
map[string]labels.Labels{"r3a": nil},
map[string]labels.Labels{"r3b": nil},
},
}
oldGroup.rules[0].(*AlertingRule).active[42] = nil
newGroup := &Group{
rules: []Rule{
NewRecordingRule("rule3", nil, nil),
NewRecordingRule("rule3", nil, nil),
NewRecordingRule("rule3", nil, nil),
NewAlertingRule("alert", nil, 0, nil, nil),
NewRecordingRule("rule1", nil, nil),
NewRecordingRule("rule4", nil, nil),
},
seriesInPreviousEval: make([]map[string]labels.Labels, 6),
}
newGroup.copyState(oldGroup)
want := []map[string]labels.Labels{
map[string]labels.Labels{"r3a": nil},
map[string]labels.Labels{"r3b": nil},
nil,
map[string]labels.Labels{"a": nil},
map[string]labels.Labels{"r1": nil},
nil,
}
if !reflect.DeepEqual(want, newGroup.seriesInPreviousEval) {
t.Fatalf("seriesInPreviousEval not as expected. Wanted: %+v Got: %+v", want, newGroup.seriesInPreviousEval)
}
if !reflect.DeepEqual(oldGroup.rules[0], newGroup.rules[3]) {
t.Fatalf("Active alerts not as expected. Wanted: %+v Got: %+v", oldGroup.rules[0], oldGroup.rules[3])
}
}

View file

@ -16,6 +16,7 @@ package rules
import (
"fmt"
"html/template"
"net/url"
"time"
"golang.org/x/net/context"
@ -47,7 +48,7 @@ func (rule RecordingRule) Name() string {
}
// Eval evaluates the rule and then overrides the metric names and labels accordingly.
func (rule RecordingRule) Eval(ctx context.Context, ts time.Time, engine *promql.Engine, _ string) (promql.Vector, error) {
func (rule RecordingRule) Eval(ctx context.Context, ts time.Time, engine *promql.Engine, _ *url.URL) (promql.Vector, error) {
query, err := engine.NewInstantQuery(rule.vector.String(), ts)
if err != nil {
return nil, err

View file

@ -64,7 +64,7 @@ func TestRuleEval(t *testing.T) {
for _, test := range suite {
rule := NewRecordingRule(test.name, test.expr, test.labels)
result, err := rule.Eval(ctx, now, engine, "")
result, err := rule.Eval(ctx, now, engine, nil)
if err != nil {
t.Fatalf("Error evaluating %s", test.name)
}

View file

@ -36,10 +36,10 @@ func NewBuffer(it SeriesIterator, delta int64) *BufferedSeriesIterator {
return bit
}
// PeekBack returns the previous element of the iterator. If there is none buffered,
// PeekBack returns the nth previous element of the iterator. If there is none buffered,
// ok is false.
func (b *BufferedSeriesIterator) PeekBack() (t int64, v float64, ok bool) {
return b.buf.last()
func (b *BufferedSeriesIterator) PeekBack(n int) (t int64, v float64, ok bool) {
return b.buf.nthLast(n)
}
// Buffer returns an iterator over the buffered data.
@ -189,13 +189,13 @@ func (r *sampleRing) add(t int64, v float64) {
}
}
// last returns the most recent element added to the ring.
func (r *sampleRing) last() (int64, float64, bool) {
if r.l == 0 {
// nthLast returns the nth most recent element added to the ring.
func (r *sampleRing) nthLast(n int) (int64, float64, bool) {
if n > r.l {
return 0, 0, false
}
s := r.buf[r.i]
return s.t, s.v, true
t, v := r.at(r.l - n)
return t, v, true
}
func (r *sampleRing) samples() []sample {

View file

@ -52,9 +52,9 @@ type Querier interface {
// Appender provides batched appends against a storage.
type Appender interface {
Add(l labels.Labels, t int64, v float64) (uint64, error)
Add(l labels.Labels, t int64, v float64) (string, error)
AddFast(ref uint64, t int64, v float64) error
AddFast(ref string, t int64, v float64) error
// Commit submits the collected samples and purges the batch.
Commit() error

View file

@ -17,6 +17,7 @@ import (
"time"
"unsafe"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/storage"
@ -41,15 +42,11 @@ type Options struct {
// The maximum timestamp range of compacted blocks.
MaxBlockDuration time.Duration
// Number of head blocks that can be appended to.
// Should be two or higher to prevent write errors in general scenarios.
//
// After a new block is started for timestamp t0 or higher, appends with
// timestamps as early as t0 - (n-1) * MinBlockDuration are valid.
AppendableBlocks int
// Duration for how long to retain data.
Retention time.Duration
// Disable creation and consideration of lockfile.
NoLockfile bool
}
// Open returns a new storage backed by a tsdb database.
@ -58,8 +55,8 @@ func Open(path string, r prometheus.Registerer, opts *Options) (storage.Storage,
WALFlushInterval: 10 * time.Second,
MinBlockDuration: uint64(opts.MinBlockDuration.Seconds() * 1000),
MaxBlockDuration: uint64(opts.MaxBlockDuration.Seconds() * 1000),
AppendableBlocks: opts.AppendableBlocks,
RetentionDuration: uint64(opts.Retention.Seconds() * 1000),
NoLockfile: opts.NoLockfile,
})
if err != nil {
return nil, err
@ -117,24 +114,24 @@ type appender struct {
a tsdb.Appender
}
func (a appender) Add(lset labels.Labels, t int64, v float64) (uint64, error) {
func (a appender) Add(lset labels.Labels, t int64, v float64) (string, error) {
ref, err := a.a.Add(toTSDBLabels(lset), t, v)
switch err {
switch errors.Cause(err) {
case tsdb.ErrNotFound:
return 0, storage.ErrNotFound
return "", storage.ErrNotFound
case tsdb.ErrOutOfOrderSample:
return 0, storage.ErrOutOfOrderSample
return "", storage.ErrOutOfOrderSample
case tsdb.ErrAmendSample:
return 0, storage.ErrDuplicateSampleForTimestamp
return "", storage.ErrDuplicateSampleForTimestamp
}
return ref, err
}
func (a appender) AddFast(ref uint64, t int64, v float64) error {
func (a appender) AddFast(ref string, t int64, v float64) error {
err := a.a.AddFast(ref, t, v)
switch err {
switch errors.Cause(err) {
case tsdb.ErrNotFound:
return storage.ErrNotFound
case tsdb.ErrOutOfOrderSample:

View file

@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"math"
"net/url"
"regexp"
"sort"
"strings"
@ -110,7 +111,7 @@ type Expander struct {
}
// NewTemplateExpander returns a template expander ready to use.
func NewTemplateExpander(ctx context.Context, text string, name string, data interface{}, timestamp model.Time, queryEngine *promql.Engine, pathPrefix string) *Expander {
func NewTemplateExpander(ctx context.Context, text string, name string, data interface{}, timestamp model.Time, queryEngine *promql.Engine, externalURL *url.URL) *Expander {
return &Expander{
text: text,
name: name,
@ -246,7 +247,10 @@ func NewTemplateExpander(ctx context.Context, text string, name string, data int
return fmt.Sprint(t)
},
"pathPrefix": func() string {
return pathPrefix
return externalURL.Path
},
"externalURL": func() string {
return externalURL.String()
},
},
}

View file

@ -15,6 +15,7 @@ package template
import (
"math"
"net/url"
"testing"
"github.com/prometheus/common/model"
@ -198,6 +199,16 @@ func TestTemplateExpansion(t *testing.T) {
output: "x",
html: true,
},
{
// pathPrefix.
text: "{{ pathPrefix }}",
output: "/path/prefix",
},
{
// externalURL.
text: "{{ externalURL }}",
output: "http://testhost:9090/path/prefix",
},
}
time := model.Time(0)
@ -221,10 +232,15 @@ func TestTemplateExpansion(t *testing.T) {
engine := promql.NewEngine(storage, nil)
extURL, err := url.Parse("http://testhost:9090/path/prefix")
if err != nil {
panic(err)
}
for i, s := range scenarios {
var result string
var err error
expander := NewTemplateExpander(context.Background(), s.text, "test", s.input, time, engine, "")
expander := NewTemplateExpander(context.Background(), s.text, "test", s.input, time, engine, extURL)
if s.html {
result, err = expander.ExpandHTML(nil)
} else {

View file

@ -33,10 +33,11 @@ func NewStorage(t T) storage.Storage {
log.With("dir", dir).Debugln("opening test storage")
// Tests just load data for a series sequentially. Thus we
// need a long appendable window.
db, err := tsdb.Open(dir, nil, &tsdb.Options{
MinBlockDuration: 2 * time.Hour,
MinBlockDuration: 24 * time.Hour,
MaxBlockDuration: 24 * time.Hour,
AppendableBlocks: 10,
})
if err != nil {
t.Fatalf("Opening test storage failed: %s", err)

View file

@ -168,7 +168,7 @@ func (tc *ZookeeperTreeCache) loop(path string) {
failureMode = false
}
case <-tc.stop:
close(tc.events)
tc.recursiveStop(tc.head)
return
}
}
@ -264,3 +264,13 @@ func (tc *ZookeeperTreeCache) recursiveDelete(path string, node *zookeeperTreeCa
tc.recursiveDelete(path+"/"+name, childNode)
}
}
func (tc *ZookeeperTreeCache) recursiveStop(node *zookeeperTreeCacheNode) {
if !node.stopped {
node.done <- struct{}{}
node.stopped = true
}
for _, childNode := range node.children {
tc.recursiveStop(childNode)
}
}

View file

@ -1,7 +1,7 @@
govalidator
===========
[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/asaskevich/govalidator?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![GoDoc](https://godoc.org/github.com/asaskevich/govalidator?status.png)](https://godoc.org/github.com/asaskevich/govalidator) [![Coverage Status](https://img.shields.io/coveralls/asaskevich/govalidator.svg)](https://coveralls.io/r/asaskevich/govalidator?branch=master) [![wercker status](https://app.wercker.com/status/1ec990b09ea86c910d5f08b0e02c6043/s "wercker status")](https://app.wercker.com/project/bykey/1ec990b09ea86c910d5f08b0e02c6043)
[![Build Status](https://travis-ci.org/asaskevich/govalidator.svg?branch=master)](https://travis-ci.org/asaskevich/govalidator)
[![Build Status](https://travis-ci.org/asaskevich/govalidator.svg?branch=master)](https://travis-ci.org/asaskevich/govalidator) [![Go Report Card](https://goreportcard.com/badge/github.com/asaskevich/govalidator)](https://goreportcard.com/report/github.com/asaskevich/govalidator) [![GoSearch](http://go-search.org/badge?id=github.com%2Fasaskevich%2Fgovalidator)](http://go-search.org/view?id=github.com%2Fasaskevich%2Fgovalidator)
A package of validators and sanitizers for strings, structs and collections. Based on [validator.js](https://github.com/chriso/validator.js).
@ -96,28 +96,27 @@ govalidator.CustomTypeTagMap.Set("customByteArrayValidator", CustomTypeValidator
func Abs(value float64) float64
func BlackList(str, chars string) string
func ByteLength(str string, params ...string) bool
func StringLength(str string, params ...string) bool
func StringMatches(s string, params ...string) bool
func CamelCaseToUnderscore(str string) string
func Contains(str, substring string) bool
func Count(array []interface{}, iterator ConditionIterator) int
func Each(array []interface{}, iterator Iterator)
func ErrorByField(e error, field string) string
func ErrorsByField(e error) map[string]string
func Filter(array []interface{}, iterator ConditionIterator) []interface{}
func Find(array []interface{}, iterator ConditionIterator) interface{}
func GetLine(s string, index int) (string, error)
func GetLines(s string) []string
func IsHost(s string) bool
func InRange(value, left, right float64) bool
func IsASCII(str string) bool
func IsAlpha(str string) bool
func IsAlphanumeric(str string) bool
func IsBase64(str string) bool
func IsByteLength(str string, min, max int) bool
func IsCIDR(str string) bool
func IsCreditCard(str string) bool
func IsDNSName(str string) bool
func IsDataURI(str string) bool
func IsDialString(str string) bool
func IsDNSName(str string) bool
func IsDivisibleBy(str, num string) bool
func IsEmail(str string) bool
func IsFilePath(str string) (bool, int)
@ -126,6 +125,7 @@ func IsFullWidth(str string) bool
func IsHalfWidth(str string) bool
func IsHexadecimal(str string) bool
func IsHexcolor(str string) bool
func IsHost(str string) bool
func IsIP(str string) bool
func IsIPv4(str string) bool
func IsIPv6(str string) bool
@ -134,6 +134,8 @@ func IsISBN10(str string) bool
func IsISBN13(str string) bool
func IsISO3166Alpha2(str string) bool
func IsISO3166Alpha3(str string) bool
func IsISO4217(str string) bool
func IsIn(str string, params ...string) bool
func IsInt(str string) bool
func IsJSON(str string) bool
func IsLatitude(str string) bool
@ -151,11 +153,13 @@ func IsNumeric(str string) bool
func IsPort(str string) bool
func IsPositive(value float64) bool
func IsPrintableASCII(str string) bool
func IsRFC3339(str string) bool
func IsRGBcolor(str string) bool
func IsRequestURI(rawurl string) bool
func IsRequestURL(rawurl string) bool
func IsSSN(str string) bool
func IsSemver(str string) bool
func IsTime(str string, format string) bool
func IsURL(str string) bool
func IsUTFDigit(str string) bool
func IsUTFLetter(str string) bool
@ -172,12 +176,20 @@ func LeftTrim(str, chars string) string
func Map(array []interface{}, iterator ResultIterator) []interface{}
func Matches(str, pattern string) bool
func NormalizeEmail(str string) (string, error)
func PadBoth(str string, padStr string, padLen int) string
func PadLeft(str string, padStr string, padLen int) string
func PadRight(str string, padStr string, padLen int) string
func Range(str string, params ...string) bool
func RemoveTags(s string) string
func ReplacePattern(str, pattern, replace string) string
func Reverse(s string) string
func RightTrim(str, chars string) string
func RuneLength(str string, params ...string) bool
func SafeFileName(str string) string
func SetFieldsRequiredByDefault(value bool)
func Sign(value float64) float64
func StringLength(str string, params ...string) bool
func StringMatches(s string, params ...string) bool
func StripLow(str string, keepNewLines bool) string
func ToBoolean(str string) (bool, error)
func ToFloat(str string) (float64, error)
@ -190,10 +202,12 @@ func UnderscoreToCamelCase(s string) string
func ValidateStruct(s interface{}) (bool, error)
func WhiteList(str, chars string) string
type ConditionIterator
type CustomTypeValidator
type Error
func (e Error) Error() string
type Errors
func (es Errors) Error() string
func (es Errors) Errors() []error
type ISO3166Entry
type Iterator
type ParamValidator
@ -253,59 +267,65 @@ For completely custom validators (interface-based), see below.
Here is a list of available validators for struct fields (validator - used function):
```go
"alpha": IsAlpha,
"alphanum": IsAlphanumeric,
"ascii": IsASCII,
"base64": IsBase64,
"creditcard": IsCreditCard,
"datauri": IsDataURI,
"dialstring": IsDialString,
"dns": IsDNSName,
"email": IsEmail,
"float": IsFloat,
"fullwidth": IsFullWidth,
"halfwidth": IsHalfWidth,
"url": IsURL,
"dialstring": IsDialString,
"requrl": IsRequestURL,
"requri": IsRequestURI,
"alpha": IsAlpha,
"utfletter": IsUTFLetter,
"alphanum": IsAlphanumeric,
"utfletternum": IsUTFLetterNumeric,
"numeric": IsNumeric,
"utfnumeric": IsUTFNumeric,
"utfdigit": IsUTFDigit,
"hexadecimal": IsHexadecimal,
"hexcolor": IsHexcolor,
"host": IsHost,
"int": IsInt,
"ip": IsIP,
"ipv4": IsIPv4,
"ipv6": IsIPv6,
"isbn10": IsISBN10,
"isbn13": IsISBN13,
"json": IsJSON,
"latitude": IsLatitude,
"longitude": IsLongitude,
"lowercase": IsLowerCase,
"mac": IsMAC,
"multibyte": IsMultibyte,
"null": IsNull,
"numeric": IsNumeric,
"port": IsPort,
"printableascii": IsPrintableASCII,
"requri": IsRequestURI,
"requrl": IsRequestURL,
"rgbcolor": IsRGBcolor,
"ssn": IsSSN,
"semver": IsSemver,
"lowercase": IsLowerCase,
"uppercase": IsUpperCase,
"url": IsURL,
"utfdigit": IsUTFDigit,
"utfletter": IsUTFLetter,
"utfletternum": IsUTFLetterNumeric,
"utfnumeric": IsUTFNumeric,
"int": IsInt,
"float": IsFloat,
"null": IsNull,
"uuid": IsUUID,
"uuidv3": IsUUIDv3,
"uuidv4": IsUUIDv4,
"uuidv5": IsUUIDv5,
"creditcard": IsCreditCard,
"isbn10": IsISBN10,
"isbn13": IsISBN13,
"json": IsJSON,
"multibyte": IsMultibyte,
"ascii": IsASCII,
"printableascii": IsPrintableASCII,
"fullwidth": IsFullWidth,
"halfwidth": IsHalfWidth,
"variablewidth": IsVariableWidth,
"base64": IsBase64,
"datauri": IsDataURI,
"ip": IsIP,
"port": IsPort,
"ipv4": IsIPv4,
"ipv6": IsIPv6,
"dns": IsDNSName,
"host": IsHost,
"mac": IsMAC,
"latitude": IsLatitude,
"longitude": IsLongitude,
"ssn": IsSSN,
"semver": IsSemver,
"rfc3339": IsRFC3339,
"ISO3166Alpha2": IsISO3166Alpha2,
"ISO3166Alpha3": IsISO3166Alpha3,
```
Validators with parameters
```go
"range(min|max)": Range,
"length(min|max)": ByteLength,
"runelength(min|max)": RuneLength,
"matches(pattern)": StringMatches,
"in(string1|string2|...|stringN)": IsIn,
```
And here is small example of usage:

View file

@ -4,7 +4,7 @@ import "regexp"
// Basic regular expressions for validating strings
const (
Email string = "^(((([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+(\\.([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+)*)|((\\x22)((((\\x20|\\x09)*(\\x0d\\x0a))?(\\x20|\\x09)+)?(([\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x7f]|\\x21|[\\x23-\\x5b]|[\\x5d-\\x7e]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(\\([\\x01-\\x09\\x0b\\x0c\\x0d-\\x7f]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}]))))*(((\\x20|\\x09)*(\\x0d\\x0a))?(\\x20|\\x09)+)?(\\x22)))@((([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])([a-zA-Z]|\\d|-|\\.|_|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.)+(([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])([a-zA-Z]|\\d|-|\\.|_|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.?$"
Email string = "^(((([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+(\\.([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+)*)|((\\x22)((((\\x20|\\x09)*(\\x0d\\x0a))?(\\x20|\\x09)+)?(([\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x7f]|\\x21|[\\x23-\\x5b]|[\\x5d-\\x7e]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(\\([\\x01-\\x09\\x0b\\x0c\\x0d-\\x7f]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}]))))*(((\\x20|\\x09)*(\\x0d\\x0a))?(\\x20|\\x09)+)?(\\x22)))@((([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])([a-zA-Z]|\\d|-|\\.|_|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.)+(([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])([a-zA-Z]|\\d|-|_|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.?$"
CreditCard string = "^(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|6(?:011|5[0-9][0-9])[0-9]{12}|3[47][0-9]{13}|3(?:0[0-5]|[68][0-9])[0-9]{11}|(?:2131|1800|35\\d{3})\\d{11})$"
ISBN10 string = "^(?:[0-9]{9}X|[0-9]{10})$"
ISBN13 string = "^(?:[0-9]{13})$"
@ -14,7 +14,7 @@ const (
UUID string = "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
Alpha string = "^[a-zA-Z]+$"
Alphanumeric string = "^[a-zA-Z0-9]+$"
Numeric string = "^[-+]?[0-9]+$"
Numeric string = "^[0-9]+$"
Int string = "^(?:[-+]?(?:0|[1-9][0-9]*))$"
Float string = "^(?:[-+]?(?:[0-9]+))?(?:\\.[0-9]*)?(?:[eE][\\+\\-]?(?:[0-9]+))?$"
Hexadecimal string = "^[0-9a-fA-F]+$"
@ -29,7 +29,7 @@ const (
DataURI string = "^data:.+\\/(.+);base64$"
Latitude string = "^[-+]?([1-8]?\\d(\\.\\d+)?|90(\\.0+)?)$"
Longitude string = "^[-+]?(180(\\.0+)?|((1[0-7]\\d)|([1-9]?\\d))(\\.\\d+)?)$"
DNSName string = `^([a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62})*$`
DNSName string = `^([a-zA-Z0-9]{1}[a-zA-Z0-9_-]{0,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9_-]{0,62})*$`
IP string = `(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))`
URLSchema string = `((ftp|tcp|udp|wss?|https?):\/\/)`
URLUsername string = `(\S+(:\S*)?@)`
@ -37,11 +37,11 @@ const (
URLPath string = `((\/|\?|#)[^\s]*)`
URLPort string = `(:(\d{1,5}))`
URLIP string = `([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))`
URLSubdomain string = `((www\.)|([a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*))`
URL string = `^` + URLSchema + `?` + URLUsername + `?` + `((` + URLIP + `|(\[` + IP + `\])|(([a-zA-Z0-9]([a-zA-Z0-9-]+)?[a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*)|(` + URLSubdomain + `?))?(([a-zA-Z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-zA-Z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-zA-Z\x{00a1}-\x{ffff}]{1,}))?))` + URLPort + `?` + URLPath + `?$`
URLSubdomain string = `((www\.)|([a-zA-Z0-9]([-\.][-\._a-zA-Z0-9]+)*))`
URL string = `^` + URLSchema + `?` + URLUsername + `?` + `((` + URLIP + `|(\[` + IP + `\])|(([a-zA-Z0-9]([a-zA-Z0-9-_]+)?[a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*)|(` + URLSubdomain + `?))?(([a-zA-Z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-zA-Z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-zA-Z\x{00a1}-\x{ffff}]{1,}))?))\.?` + URLPort + `?` + URLPath + `?$`
SSN string = `^\d{3}[- ]?\d{2}[- ]?\d{4}$`
WinPath string = `^[a-zA-Z]:\\(?:[^\\/:*?"<>|\r\n]+\\)*[^\\/:*?"<>|\r\n]*$`
UnixPath string = `^((?:\/[a-zA-Z0-9\.\:]+(?:_[a-zA-Z0-9\:\.]+)*(?:\-[\:a-zA-Z0-9\.]+)*)+\/?)$`
UnixPath string = `^(/[^/\x00]*)+/?$`
Semver string = "^v?(?:0|[1-9]\\d*)\\.(?:0|[1-9]\\d*)\\.(?:0|[1-9]\\d*)(-(0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(\\.(0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*)?(\\+[0-9a-zA-Z-]+(\\.[0-9a-zA-Z-]+)*)?$"
tagName string = "valid"
)

View file

@ -29,15 +29,21 @@ type stringValues []reflect.Value
// ParamTagMap is a map of functions accept variants parameters
var ParamTagMap = map[string]ParamValidator{
"length": ByteLength,
"range": Range,
"runelength": RuneLength,
"stringlength": StringLength,
"matches": StringMatches,
"in": isInRaw,
}
// ParamTagRegexMap maps param tags to their respective regexes.
var ParamTagRegexMap = map[string]*regexp.Regexp{
"range": regexp.MustCompile("^range\\((\\d+)\\|(\\d+)\\)$"),
"length": regexp.MustCompile("^length\\((\\d+)\\|(\\d+)\\)$"),
"runelength": regexp.MustCompile("^runelength\\((\\d+)\\|(\\d+)\\)$"),
"stringlength": regexp.MustCompile("^stringlength\\((\\d+)\\|(\\d+)\\)$"),
"matches": regexp.MustCompile(`matches\(([^)]+)\)`),
"in": regexp.MustCompile(`^in\((.*)\)`),
"matches": regexp.MustCompile(`^matches\((.+)\)$`),
}
type customTypeTagMap struct {
@ -113,6 +119,10 @@ var TagMap = map[string]Validator{
"longitude": IsLongitude,
"ssn": IsSSN,
"semver": IsSemver,
"rfc3339": IsRFC3339,
"ISO3166Alpha2": IsISO3166Alpha2,
"ISO3166Alpha3": IsISO3166Alpha3,
"ISO4217": IsISO4217,
}
// ISO3166Entry stores country codes
@ -376,3 +386,33 @@ var ISO3166List = []ISO3166Entry{
{"Yemen", "Yémen (le)", "YE", "YEM", "887"},
{"Zambia", "Zambie (la)", "ZM", "ZMB", "894"},
}
// ISO4217List is the list of ISO currency codes
var ISO4217List = []string{
"AED", "AFN", "ALL", "AMD", "ANG", "AOA", "ARS", "AUD", "AWG", "AZN",
"BAM", "BBD", "BDT", "BGN", "BHD", "BIF", "BMD", "BND", "BOB", "BOV", "BRL", "BSD", "BTN", "BWP", "BYN", "BZD",
"CAD", "CDF", "CHE", "CHF", "CHW", "CLF", "CLP", "CNY", "COP", "COU", "CRC", "CUC", "CUP", "CVE", "CZK",
"DJF", "DKK", "DOP", "DZD",
"EGP", "ERN", "ETB", "EUR",
"FJD", "FKP",
"GBP", "GEL", "GHS", "GIP", "GMD", "GNF", "GTQ", "GYD",
"HKD", "HNL", "HRK", "HTG", "HUF",
"IDR", "ILS", "INR", "IQD", "IRR", "ISK",
"JMD", "JOD", "JPY",
"KES", "KGS", "KHR", "KMF", "KPW", "KRW", "KWD", "KYD", "KZT",
"LAK", "LBP", "LKR", "LRD", "LSL", "LYD",
"MAD", "MDL", "MGA", "MKD", "MMK", "MNT", "MOP", "MRO", "MUR", "MVR", "MWK", "MXN", "MXV", "MYR", "MZN",
"NAD", "NGN", "NIO", "NOK", "NPR", "NZD",
"OMR",
"PAB", "PEN", "PGK", "PHP", "PKR", "PLN", "PYG",
"QAR",
"RON", "RSD", "RUB", "RWF",
"SAR", "SBD", "SCR", "SDG", "SEK", "SGD", "SHP", "SLL", "SOS", "SRD", "SSP", "STD", "SVC", "SYP", "SZL",
"THB", "TJS", "TMT", "TND", "TOP", "TRY", "TTD", "TWD", "TZS",
"UAH", "UGX", "USD", "USN", "UYI", "UYU", "UZS",
"VEF", "VND", "VUV",
"WST",
"XAF", "XAG", "XAU", "XBA", "XBB", "XBC", "XBD", "XCD", "XDR", "XOF", "XPD", "XPF", "XPT", "XSU", "XTS", "XUA", "XXX",
"YER",
"ZAR", "ZMW", "ZWL",
}

View file

@ -4,10 +4,12 @@ import (
"errors"
"fmt"
"html"
"math"
"path"
"regexp"
"strings"
"unicode"
"unicode/utf8"
)
// Contains check if the string contains the substring.
@ -211,3 +213,56 @@ func Truncate(str string, length int, ending string) string {
return str
}
// PadLeft pad left side of string if size of string is less then indicated pad length
func PadLeft(str string, padStr string, padLen int) string {
return buildPadStr(str, padStr, padLen, true, false)
}
// PadRight pad right side of string if size of string is less then indicated pad length
func PadRight(str string, padStr string, padLen int) string {
return buildPadStr(str, padStr, padLen, false, true)
}
// PadBoth pad sides of string if size of string is less then indicated pad length
func PadBoth(str string, padStr string, padLen int) string {
return buildPadStr(str, padStr, padLen, true, true)
}
// PadString either left, right or both sides, not the padding string can be unicode and more then one
// character
func buildPadStr(str string, padStr string, padLen int, padLeft bool, padRight bool) string {
// When padded length is less then the current string size
if padLen < utf8.RuneCountInString(str) {
return str
}
padLen -= utf8.RuneCountInString(str)
targetLen := padLen
targetLenLeft := targetLen
targetLenRight := targetLen
if padLeft && padRight {
targetLenLeft = padLen / 2
targetLenRight = padLen - targetLenLeft
}
strToRepeatLen := utf8.RuneCountInString(padStr)
repeatTimes := int(math.Ceil(float64(targetLen) / float64(strToRepeatLen)))
repeatedString := strings.Repeat(padStr, repeatTimes)
leftSide := ""
if padLeft {
leftSide = repeatedString[0:targetLenLeft]
}
rightSide := ""
if padRight {
rightSide = repeatedString[0:targetLenRight]
}
return leftSide + str + rightSide
}

View file

@ -11,13 +11,16 @@ import (
"sort"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
var fieldsRequiredByDefault bool
const maxURLRuneCount = 2083
const minURLRuneCount = 3
// SetFieldsRequiredByDefault causes validation to fail when struct fields
// do not include validations or are not explicitly marked as exempt (using `valid:"-"` or `valid:"email,optional"`).
// This struct definition will fail govalidator.ValidateStruct() (and the field values do not matter):
@ -44,7 +47,7 @@ func IsEmail(str string) bool {
// IsURL check if the string is an URL.
func IsURL(str string) bool {
if str == "" || len(str) >= 2083 || len(str) <= 3 || strings.HasPrefix(str, ".") {
if str == "" || utf8.RuneCountInString(str) >= maxURLRuneCount || len(str) <= minURLRuneCount || strings.HasPrefix(str, ".") {
return false
}
u, err := url.Parse(str)
@ -62,7 +65,7 @@ func IsURL(str string) bool {
}
// IsRequestURL check if the string rawurl, assuming
// it was recieved in an HTTP request, is a valid
// it was received in an HTTP request, is a valid
// URL confirm to RFC 3986
func IsRequestURL(rawurl string) bool {
url, err := url.ParseRequestURI(rawurl)
@ -76,7 +79,7 @@ func IsRequestURL(rawurl string) bool {
}
// IsRequestURI check if the string rawurl, assuming
// it was recieved in an HTTP request, is an
// it was received in an HTTP request, is an
// absolute URI or an absolute path.
func IsRequestURI(rawurl string) bool {
_, err := url.ParseRequestURI(rawurl)
@ -458,7 +461,7 @@ func IsDNSName(str string) bool {
// constraints already violated
return false
}
return rxDNSName.MatchString(str)
return !IsIP(str) && rxDNSName.MatchString(str)
}
// IsDialString validates the given string for usage with the various Dial() functions
@ -535,6 +538,17 @@ func IsLongitude(str string) bool {
return rxLongitude.MatchString(str)
}
func toJSONName(tag string) string {
if tag == "" {
return ""
}
// JSON name always comes first. If there's no options then split[0] is
// JSON name, if JSON name is not set, then split[0] is an empty string.
split := strings.SplitN(tag, ",", 2)
return split[0]
}
// ValidateStruct use tags for fields.
// result will be equal to `false` if there are any errors.
func ValidateStruct(s interface{}) (bool, error) {
@ -558,11 +572,39 @@ func ValidateStruct(s interface{}) (bool, error) {
if typeField.PkgPath != "" {
continue // Private field
}
resultField, err2 := typeCheck(valueField, typeField, val)
structResult := true
if valueField.Kind() == reflect.Struct {
var err error
structResult, err = ValidateStruct(valueField.Interface())
if err != nil {
errs = append(errs, err)
}
}
resultField, err2 := typeCheck(valueField, typeField, val, nil)
if err2 != nil {
// Replace structure name with JSON name if there is a tag on the variable
jsonTag := toJSONName(typeField.Tag.Get("json"))
if jsonTag != "" {
switch jsonError := err2.(type) {
case Error:
jsonError.Name = jsonTag
err2 = jsonError
case Errors:
for _, e := range jsonError.Errors() {
switch tempErr := e.(type) {
case Error:
tempErr.Name = jsonTag
_ = tempErr
}
}
err2 = jsonError
}
}
errs = append(errs, err2)
}
result = result && resultField
result = result && resultField && structResult
}
if len(errs) > 0 {
err = errs
@ -594,7 +636,7 @@ func isValidTag(s string) bool {
}
for _, c := range s {
switch {
case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c):
case strings.ContainsRune("\\'\"!#$%&()*+-./:<=>?@[]^_{|}~ ", c):
// Backslash and quote chars are reserved, but
// otherwise any punctuation chars are allowed
// in a tag name.
@ -620,6 +662,28 @@ func IsSemver(str string) bool {
return rxSemver.MatchString(str)
}
// IsTime check if string is valid according to given format
func IsTime(str string, format string) bool {
_, err := time.Parse(format, str)
return err == nil
}
// IsRFC3339 check if string is valid timestamp value according to RFC3339
func IsRFC3339(str string) bool {
return IsTime(str, time.RFC3339)
}
// IsISO4217 check if string is valid ISO currency code
func IsISO4217(str string) bool {
for _, currency := range ISO4217List {
if str == currency {
return true
}
}
return false
}
// ByteLength check string's length
func ByteLength(str string, params ...string) bool {
if len(params) == 2 {
@ -631,6 +695,12 @@ func ByteLength(str string, params ...string) bool {
return false
}
// RuneLength check string's length
// Alias for StringLength
func RuneLength(str string, params ...string) bool {
return StringLength(str, params...)
}
// StringMatches checks if a string matches a given pattern.
func StringMatches(s string, params ...string) bool {
if len(params) == 1 {
@ -653,6 +723,41 @@ func StringLength(str string, params ...string) bool {
return false
}
// Range check string's length
func Range(str string, params ...string) bool {
if len(params) == 2 {
value, _ := ToFloat(str)
min, _ := ToFloat(params[0])
max, _ := ToFloat(params[1])
return InRange(value, min, max)
}
return false
}
func isInRaw(str string, params ...string) bool {
if len(params) == 1 {
rawParams := params[0]
parsedParams := strings.Split(rawParams, "|")
return IsIn(str, parsedParams...)
}
return false
}
// IsIn check if string str is a member of the set of strings params
func IsIn(str string, params ...string) bool {
for _, param := range params {
if str == param {
return true
}
}
return false
}
func checkRequired(v reflect.Value, t reflect.StructField, options tagOptionsMap) (bool, error) {
if requiredOption, isRequired := options["required"]; isRequired {
if len(requiredOption) > 0 {
@ -666,7 +771,7 @@ func checkRequired(v reflect.Value, t reflect.StructField, options tagOptionsMap
return true, nil
}
func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, error) {
func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value, options tagOptionsMap) (isValid bool, resultErr error) {
if !v.IsValid() {
return false, nil
}
@ -684,12 +789,22 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
return true, nil
}
options := parseTagIntoMap(tag)
isRootType := false
if options == nil {
isRootType = true
options = parseTagIntoMap(tag)
}
if isEmptyValue(v) {
// an empty value is not validated, check only required
return checkRequired(v, t, options)
}
var customTypeErrors Errors
var customTypeValidatorsExist bool
for validatorName, customErrorMessage := range options {
if validatefunc, ok := CustomTypeTagMap.Get(validatorName); ok {
customTypeValidatorsExist = true
delete(options, validatorName)
if result := validatefunc(v.Interface(), o.Interface()); !result {
if len(customErrorMessage) > 0 {
customTypeErrors = append(customTypeErrors, Error{Name: t.Name, Err: fmt.Errorf(customErrorMessage), CustomErrorMessageExists: true})
@ -699,16 +814,26 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
}
}
}
if customTypeValidatorsExist {
if len(customTypeErrors.Errors()) > 0 {
return false, customTypeErrors
}
return true, nil
if len(customTypeErrors.Errors()) > 0 {
return false, customTypeErrors
}
if isEmptyValue(v) {
// an empty value is not validated, check only required
return checkRequired(v, t, options)
if isRootType {
// Ensure that we've checked the value by all specified validators before report that the value is valid
defer func() {
delete(options, "optional")
delete(options, "required")
if isValid && resultErr == nil && len(options) != 0 {
for validator := range options {
isValid = false
resultErr = Error{t.Name, fmt.Errorf(
"The following validator is invalid or can't be applied to the field: %q", validator), false}
return
}
}
}()
}
switch v.Kind() {
@ -718,10 +843,12 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
reflect.Float32, reflect.Float64,
reflect.String:
// for each tag option check the map of validator functions
for validator, customErrorMessage := range options {
for validatorSpec, customErrorMessage := range options {
var negate bool
validator := validatorSpec
customMsgExists := (len(customErrorMessage) > 0)
// Check wether the tag looks like '!something' or 'something'
// Check whether the tag looks like '!something' or 'something'
if validator[0] == '!' {
validator = string(validator[1:])
negate = true
@ -730,38 +857,47 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
// Check for param validators
for key, value := range ParamTagRegexMap {
ps := value.FindStringSubmatch(validator)
if len(ps) > 0 {
if validatefunc, ok := ParamTagMap[key]; ok {
switch v.Kind() {
case reflect.String:
field := fmt.Sprint(v) // make value into string, then validate with regex
if result := validatefunc(field, ps[1:]...); (!result && !negate) || (result && negate) {
var err error
if !negate {
if customMsgExists {
err = fmt.Errorf(customErrorMessage)
} else {
err = fmt.Errorf("%s does not validate as %s", field, validator)
}
if len(ps) == 0 {
continue
}
} else {
if customMsgExists {
err = fmt.Errorf(customErrorMessage)
} else {
err = fmt.Errorf("%s does validate as %s", field, validator)
}
}
return false, Error{t.Name, err, customMsgExists}
validatefunc, ok := ParamTagMap[key]
if !ok {
continue
}
delete(options, validatorSpec)
switch v.Kind() {
case reflect.String:
field := fmt.Sprint(v) // make value into string, then validate with regex
if result := validatefunc(field, ps[1:]...); (!result && !negate) || (result && negate) {
var err error
if !negate {
if customMsgExists {
err = fmt.Errorf(customErrorMessage)
} else {
err = fmt.Errorf("%s does not validate as %s", field, validator)
}
} else {
if customMsgExists {
err = fmt.Errorf(customErrorMessage)
} else {
err = fmt.Errorf("%s does validate as %s", field, validator)
}
default:
// type not yet supported, fail
return false, Error{t.Name, fmt.Errorf("Validator %s doesn't support kind %s", validator, v.Kind()), false}
}
return false, Error{t.Name, err, customMsgExists}
}
default:
// type not yet supported, fail
return false, Error{t.Name, fmt.Errorf("Validator %s doesn't support kind %s", validator, v.Kind()), false}
}
}
if validatefunc, ok := TagMap[validator]; ok {
delete(options, validatorSpec)
switch v.Kind() {
case reflect.String:
field := fmt.Sprint(v) // make value into string, then validate with regex
@ -813,7 +949,7 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
var resultItem bool
var err error
if v.Index(i).Kind() != reflect.Struct {
resultItem, err = typeCheck(v.Index(i), t, o)
resultItem, err = typeCheck(v.Index(i), t, o, options)
if err != nil {
return false, err
}
@ -832,7 +968,7 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
var resultItem bool
var err error
if v.Index(i).Kind() != reflect.Struct {
resultItem, err = typeCheck(v.Index(i), t, o)
resultItem, err = typeCheck(v.Index(i), t, o, options)
if err != nil {
return false, err
}
@ -856,7 +992,7 @@ func typeCheck(v reflect.Value, t reflect.StructField, o reflect.Value) (bool, e
if v.IsNil() {
return true, nil
}
return typeCheck(v.Elem(), t, o)
return typeCheck(v.Elem(), t, o, options)
case reflect.Struct:
return ValidateStruct(v.Interface())
default:

View file

@ -1,4 +1,4 @@
box: wercker/golang
box: golang
build:
steps:
- setup-go-workspace

View file

@ -1,23 +0,0 @@
package semver
import (
"encoding/json"
)
// MarshalJSON implements the encoding/json.Marshaler interface.
func (v Version) MarshalJSON() ([]byte, error) {
return json.Marshal(v.String())
}
// UnmarshalJSON implements the encoding/json.Unmarshaler interface.
func (v *Version) UnmarshalJSON(data []byte) (err error) {
var versionString string
if err = json.Unmarshal(data, &versionString); err != nil {
return
}
*v, err = Parse(versionString)
return
}

View file

@ -1,395 +0,0 @@
package semver
import (
"errors"
"fmt"
"strconv"
"strings"
)
const (
numbers string = "0123456789"
alphas = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-"
alphanum = alphas + numbers
)
// SpecVersion is the latest fully supported spec version of semver
var SpecVersion = Version{
Major: 2,
Minor: 0,
Patch: 0,
}
// Version represents a semver compatible version
type Version struct {
Major uint64
Minor uint64
Patch uint64
Pre []PRVersion
Build []string //No Precendence
}
// Version to string
func (v Version) String() string {
b := make([]byte, 0, 5)
b = strconv.AppendUint(b, v.Major, 10)
b = append(b, '.')
b = strconv.AppendUint(b, v.Minor, 10)
b = append(b, '.')
b = strconv.AppendUint(b, v.Patch, 10)
if len(v.Pre) > 0 {
b = append(b, '-')
b = append(b, v.Pre[0].String()...)
for _, pre := range v.Pre[1:] {
b = append(b, '.')
b = append(b, pre.String()...)
}
}
if len(v.Build) > 0 {
b = append(b, '+')
b = append(b, v.Build[0]...)
for _, build := range v.Build[1:] {
b = append(b, '.')
b = append(b, build...)
}
}
return string(b)
}
// Equals checks if v is equal to o.
func (v Version) Equals(o Version) bool {
return (v.Compare(o) == 0)
}
// EQ checks if v is equal to o.
func (v Version) EQ(o Version) bool {
return (v.Compare(o) == 0)
}
// NE checks if v is not equal to o.
func (v Version) NE(o Version) bool {
return (v.Compare(o) != 0)
}
// GT checks if v is greater than o.
func (v Version) GT(o Version) bool {
return (v.Compare(o) == 1)
}
// GTE checks if v is greater than or equal to o.
func (v Version) GTE(o Version) bool {
return (v.Compare(o) >= 0)
}
// GE checks if v is greater than or equal to o.
func (v Version) GE(o Version) bool {
return (v.Compare(o) >= 0)
}
// LT checks if v is less than o.
func (v Version) LT(o Version) bool {
return (v.Compare(o) == -1)
}
// LTE checks if v is less than or equal to o.
func (v Version) LTE(o Version) bool {
return (v.Compare(o) <= 0)
}
// LE checks if v is less than or equal to o.
func (v Version) LE(o Version) bool {
return (v.Compare(o) <= 0)
}
// Compare compares Versions v to o:
// -1 == v is less than o
// 0 == v is equal to o
// 1 == v is greater than o
func (v Version) Compare(o Version) int {
if v.Major != o.Major {
if v.Major > o.Major {
return 1
}
return -1
}
if v.Minor != o.Minor {
if v.Minor > o.Minor {
return 1
}
return -1
}
if v.Patch != o.Patch {
if v.Patch > o.Patch {
return 1
}
return -1
}
// Quick comparison if a version has no prerelease versions
if len(v.Pre) == 0 && len(o.Pre) == 0 {
return 0
} else if len(v.Pre) == 0 && len(o.Pre) > 0 {
return 1
} else if len(v.Pre) > 0 && len(o.Pre) == 0 {
return -1
}
i := 0
for ; i < len(v.Pre) && i < len(o.Pre); i++ {
if comp := v.Pre[i].Compare(o.Pre[i]); comp == 0 {
continue
} else if comp == 1 {
return 1
} else {
return -1
}
}
// If all pr versions are the equal but one has further prversion, this one greater
if i == len(v.Pre) && i == len(o.Pre) {
return 0
} else if i == len(v.Pre) && i < len(o.Pre) {
return -1
} else {
return 1
}
}
// Validate validates v and returns error in case
func (v Version) Validate() error {
// Major, Minor, Patch already validated using uint64
for _, pre := range v.Pre {
if !pre.IsNum { //Numeric prerelease versions already uint64
if len(pre.VersionStr) == 0 {
return fmt.Errorf("Prerelease can not be empty %q", pre.VersionStr)
}
if !containsOnly(pre.VersionStr, alphanum) {
return fmt.Errorf("Invalid character(s) found in prerelease %q", pre.VersionStr)
}
}
}
for _, build := range v.Build {
if len(build) == 0 {
return fmt.Errorf("Build meta data can not be empty %q", build)
}
if !containsOnly(build, alphanum) {
return fmt.Errorf("Invalid character(s) found in build meta data %q", build)
}
}
return nil
}
// New is an alias for Parse and returns a pointer, parses version string and returns a validated Version or error
func New(s string) (vp *Version, err error) {
v, err := Parse(s)
vp = &v
return
}
// Make is an alias for Parse, parses version string and returns a validated Version or error
func Make(s string) (Version, error) {
return Parse(s)
}
// Parse parses version string and returns a validated Version or error
func Parse(s string) (Version, error) {
if len(s) == 0 {
return Version{}, errors.New("Version string empty")
}
// Split into major.minor.(patch+pr+meta)
parts := strings.SplitN(s, ".", 3)
if len(parts) != 3 {
return Version{}, errors.New("No Major.Minor.Patch elements found")
}
// Major
if !containsOnly(parts[0], numbers) {
return Version{}, fmt.Errorf("Invalid character(s) found in major number %q", parts[0])
}
if hasLeadingZeroes(parts[0]) {
return Version{}, fmt.Errorf("Major number must not contain leading zeroes %q", parts[0])
}
major, err := strconv.ParseUint(parts[0], 10, 64)
if err != nil {
return Version{}, err
}
// Minor
if !containsOnly(parts[1], numbers) {
return Version{}, fmt.Errorf("Invalid character(s) found in minor number %q", parts[1])
}
if hasLeadingZeroes(parts[1]) {
return Version{}, fmt.Errorf("Minor number must not contain leading zeroes %q", parts[1])
}
minor, err := strconv.ParseUint(parts[1], 10, 64)
if err != nil {
return Version{}, err
}
v := Version{}
v.Major = major
v.Minor = minor
var build, prerelease []string
patchStr := parts[2]
if buildIndex := strings.IndexRune(patchStr, '+'); buildIndex != -1 {
build = strings.Split(patchStr[buildIndex+1:], ".")
patchStr = patchStr[:buildIndex]
}
if preIndex := strings.IndexRune(patchStr, '-'); preIndex != -1 {
prerelease = strings.Split(patchStr[preIndex+1:], ".")
patchStr = patchStr[:preIndex]
}
if !containsOnly(patchStr, numbers) {
return Version{}, fmt.Errorf("Invalid character(s) found in patch number %q", patchStr)
}
if hasLeadingZeroes(patchStr) {
return Version{}, fmt.Errorf("Patch number must not contain leading zeroes %q", patchStr)
}
patch, err := strconv.ParseUint(patchStr, 10, 64)
if err != nil {
return Version{}, err
}
v.Patch = patch
// Prerelease
for _, prstr := range prerelease {
parsedPR, err := NewPRVersion(prstr)
if err != nil {
return Version{}, err
}
v.Pre = append(v.Pre, parsedPR)
}
// Build meta data
for _, str := range build {
if len(str) == 0 {
return Version{}, errors.New("Build meta data is empty")
}
if !containsOnly(str, alphanum) {
return Version{}, fmt.Errorf("Invalid character(s) found in build meta data %q", str)
}
v.Build = append(v.Build, str)
}
return v, nil
}
// MustParse is like Parse but panics if the version cannot be parsed.
func MustParse(s string) Version {
v, err := Parse(s)
if err != nil {
panic(`semver: Parse(` + s + `): ` + err.Error())
}
return v
}
// PRVersion represents a PreRelease Version
type PRVersion struct {
VersionStr string
VersionNum uint64
IsNum bool
}
// NewPRVersion creates a new valid prerelease version
func NewPRVersion(s string) (PRVersion, error) {
if len(s) == 0 {
return PRVersion{}, errors.New("Prerelease is empty")
}
v := PRVersion{}
if containsOnly(s, numbers) {
if hasLeadingZeroes(s) {
return PRVersion{}, fmt.Errorf("Numeric PreRelease version must not contain leading zeroes %q", s)
}
num, err := strconv.ParseUint(s, 10, 64)
// Might never be hit, but just in case
if err != nil {
return PRVersion{}, err
}
v.VersionNum = num
v.IsNum = true
} else if containsOnly(s, alphanum) {
v.VersionStr = s
v.IsNum = false
} else {
return PRVersion{}, fmt.Errorf("Invalid character(s) found in prerelease %q", s)
}
return v, nil
}
// IsNumeric checks if prerelease-version is numeric
func (v PRVersion) IsNumeric() bool {
return v.IsNum
}
// Compare compares two PreRelease Versions v and o:
// -1 == v is less than o
// 0 == v is equal to o
// 1 == v is greater than o
func (v PRVersion) Compare(o PRVersion) int {
if v.IsNum && !o.IsNum {
return -1
} else if !v.IsNum && o.IsNum {
return 1
} else if v.IsNum && o.IsNum {
if v.VersionNum == o.VersionNum {
return 0
} else if v.VersionNum > o.VersionNum {
return 1
} else {
return -1
}
} else { // both are Alphas
if v.VersionStr == o.VersionStr {
return 0
} else if v.VersionStr > o.VersionStr {
return 1
} else {
return -1
}
}
}
// PreRelease version to string
func (v PRVersion) String() string {
if v.IsNum {
return strconv.FormatUint(v.VersionNum, 10)
}
return v.VersionStr
}
func containsOnly(s string, set string) bool {
return strings.IndexFunc(s, func(r rune) bool {
return !strings.ContainsRune(set, r)
}) == -1
}
func hasLeadingZeroes(s string) bool {
return len(s) > 1 && s[0] == '0'
}
// NewBuildVersion creates a new valid build version
func NewBuildVersion(s string) (string, error) {
if len(s) == 0 {
return "", errors.New("Buildversion is empty")
}
if !containsOnly(s, alphanum) {
return "", fmt.Errorf("Invalid character(s) found in build meta data %q", s)
}
return s, nil
}

View file

@ -1,28 +0,0 @@
package semver
import (
"sort"
)
// Versions represents multiple versions.
type Versions []Version
// Len returns length of version collection
func (s Versions) Len() int {
return len(s)
}
// Swap swaps two versions inside the collection by its indices
func (s Versions) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
// Less checks if version at index i is less than version at index j
func (s Versions) Less(i, j int) bool {
return s[i].LT(s[j])
}
// Sort sorts a slice of versions
func Sort(versions []Version) {
sort.Sort(Versions(versions))
}

View file

@ -1,30 +0,0 @@
package semver
import (
"database/sql/driver"
"fmt"
)
// Scan implements the database/sql.Scanner interface.
func (v *Version) Scan(src interface{}) (err error) {
var str string
switch src := src.(type) {
case string:
str = src
case []byte:
str = string(src)
default:
return fmt.Errorf("Version.Scan: cannot convert %T to string.", src)
}
if t, err := Parse(str); err == nil {
*v = t
}
return
}
// Value implements the database/sql/driver.Valuer interface.
func (v Version) Value() (driver.Value, error) {
return v.String(), nil
}

View file

@ -1,5 +0,0 @@
CoreOS Project
Copyright 2014 CoreOS, Inc
This product includes software developed at CoreOS, Inc.
(http://www.coreos.com/).

View file

@ -1,7 +0,0 @@
package http
import "net/http"
type Client interface {
Do(*http.Request) (*http.Response, error)
}

View file

@ -1,156 +0,0 @@
package http
import (
"encoding/base64"
"encoding/json"
"errors"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
)
func WriteError(w http.ResponseWriter, code int, msg string) {
e := struct {
Error string `json:"error"`
}{
Error: msg,
}
b, err := json.Marshal(e)
if err != nil {
log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
code = http.StatusInternalServerError
b = []byte(`{"error":"server_error"}`)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
w.Write(b)
}
// BasicAuth parses a username and password from the request's
// Authorization header. This was pulled from golang master:
// https://codereview.appspot.com/76540043
func BasicAuth(r *http.Request) (username, password string, ok bool) {
auth := r.Header.Get("Authorization")
if auth == "" {
return
}
if !strings.HasPrefix(auth, "Basic ") {
return
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func cacheControlMaxAge(hdr string) (time.Duration, bool, error) {
for _, field := range strings.Split(hdr, ",") {
parts := strings.SplitN(strings.TrimSpace(field), "=", 2)
k := strings.ToLower(strings.TrimSpace(parts[0]))
if k != "max-age" {
continue
}
if len(parts) == 1 {
return 0, false, errors.New("max-age has no value")
}
v := strings.TrimSpace(parts[1])
if v == "" {
return 0, false, errors.New("max-age has empty value")
}
age, err := strconv.Atoi(v)
if err != nil {
return 0, false, err
}
if age <= 0 {
return 0, false, nil
}
return time.Duration(age) * time.Second, true, nil
}
return 0, false, nil
}
func expires(date, expires string) (time.Duration, bool, error) {
if date == "" || expires == "" {
return 0, false, nil
}
te, err := time.Parse(time.RFC1123, expires)
if err != nil {
return 0, false, err
}
td, err := time.Parse(time.RFC1123, date)
if err != nil {
return 0, false, err
}
ttl := te.Sub(td)
// headers indicate data already expired, caller should not
// have to care about this case
if ttl <= 0 {
return 0, false, nil
}
return ttl, true, nil
}
func Cacheable(hdr http.Header) (time.Duration, bool, error) {
ttl, ok, err := cacheControlMaxAge(hdr.Get("Cache-Control"))
if err != nil || ok {
return ttl, ok, err
}
return expires(hdr.Get("Date"), hdr.Get("Expires"))
}
// MergeQuery appends additional query values to an existing URL.
func MergeQuery(u url.URL, q url.Values) url.URL {
uv := u.Query()
for k, vs := range q {
for _, v := range vs {
uv.Add(k, v)
}
}
u.RawQuery = uv.Encode()
return u
}
// NewResourceLocation appends a resource id to the end of the requested URL path.
func NewResourceLocation(reqURL *url.URL, id string) string {
var u url.URL
u = *reqURL
u.Path = path.Join(u.Path, id)
u.RawQuery = ""
u.Fragment = ""
return u.String()
}
// CopyRequest returns a clone of the provided *http.Request.
// The returned object is a shallow copy of the struct and a
// deep copy of its Header field.
func CopyRequest(r *http.Request) *http.Request {
r2 := *r
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return &r2
}

View file

@ -1,29 +0,0 @@
package http
import (
"errors"
"net/url"
)
// ParseNonEmptyURL checks that a string is a parsable URL which is also not empty
// since `url.Parse("")` does not return an error. Must contian a scheme and a host.
func ParseNonEmptyURL(u string) (*url.URL, error) {
if u == "" {
return nil, errors.New("url is empty")
}
ur, err := url.Parse(u)
if err != nil {
return nil, err
}
if ur.Scheme == "" {
return nil, errors.New("url scheme is empty")
}
if ur.Host == "" {
return nil, errors.New("url host is empty")
}
return ur, nil
}

View file

@ -1,126 +0,0 @@
package jose
import (
"encoding/json"
"fmt"
"math"
"time"
)
type Claims map[string]interface{}
func (c Claims) Add(name string, value interface{}) {
c[name] = value
}
func (c Claims) StringClaim(name string) (string, bool, error) {
cl, ok := c[name]
if !ok {
return "", false, nil
}
v, ok := cl.(string)
if !ok {
return "", false, fmt.Errorf("unable to parse claim as string: %v", name)
}
return v, true, nil
}
func (c Claims) StringsClaim(name string) ([]string, bool, error) {
cl, ok := c[name]
if !ok {
return nil, false, nil
}
if v, ok := cl.([]string); ok {
return v, true, nil
}
// When unmarshaled, []string will become []interface{}.
if v, ok := cl.([]interface{}); ok {
var ret []string
for _, vv := range v {
str, ok := vv.(string)
if !ok {
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
ret = append(ret, str)
}
return ret, true, nil
}
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
func (c Claims) Int64Claim(name string) (int64, bool, error) {
cl, ok := c[name]
if !ok {
return 0, false, nil
}
v, ok := cl.(int64)
if !ok {
vf, ok := cl.(float64)
if !ok {
return 0, false, fmt.Errorf("unable to parse claim as int64: %v", name)
}
v = int64(vf)
}
return v, true, nil
}
func (c Claims) Float64Claim(name string) (float64, bool, error) {
cl, ok := c[name]
if !ok {
return 0, false, nil
}
v, ok := cl.(float64)
if !ok {
vi, ok := cl.(int64)
if !ok {
return 0, false, fmt.Errorf("unable to parse claim as float64: %v", name)
}
v = float64(vi)
}
return v, true, nil
}
func (c Claims) TimeClaim(name string) (time.Time, bool, error) {
v, ok, err := c.Float64Claim(name)
if !ok || err != nil {
return time.Time{}, ok, err
}
s := math.Trunc(v)
ns := (v - s) * math.Pow(10, 9)
return time.Unix(int64(s), int64(ns)).UTC(), true, nil
}
func decodeClaims(payload []byte) (Claims, error) {
var c Claims
if err := json.Unmarshal(payload, &c); err != nil {
return nil, fmt.Errorf("malformed JWT claims, unable to decode: %v", err)
}
return c, nil
}
func marshalClaims(c Claims) ([]byte, error) {
b, err := json.Marshal(c)
if err != nil {
return nil, err
}
return b, nil
}
func encodeClaims(c Claims) (string, error) {
b, err := marshalClaims(c)
if err != nil {
return "", err
}
return encodeSegment(b), nil
}

View file

@ -1,112 +0,0 @@
package jose
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)
const (
HeaderMediaType = "typ"
HeaderKeyAlgorithm = "alg"
HeaderKeyID = "kid"
)
const (
// Encryption Algorithm Header Parameter Values for JWS
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#page-6
AlgHS256 = "HS256"
AlgHS384 = "HS384"
AlgHS512 = "HS512"
AlgRS256 = "RS256"
AlgRS384 = "RS384"
AlgRS512 = "RS512"
AlgES256 = "ES256"
AlgES384 = "ES384"
AlgES512 = "ES512"
AlgPS256 = "PS256"
AlgPS384 = "PS384"
AlgPS512 = "PS512"
AlgNone = "none"
)
const (
// Algorithm Header Parameter Values for JWE
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#section-4.1
AlgRSA15 = "RSA1_5"
AlgRSAOAEP = "RSA-OAEP"
AlgRSAOAEP256 = "RSA-OAEP-256"
AlgA128KW = "A128KW"
AlgA192KW = "A192KW"
AlgA256KW = "A256KW"
AlgDir = "dir"
AlgECDHES = "ECDH-ES"
AlgECDHESA128KW = "ECDH-ES+A128KW"
AlgECDHESA192KW = "ECDH-ES+A192KW"
AlgECDHESA256KW = "ECDH-ES+A256KW"
AlgA128GCMKW = "A128GCMKW"
AlgA192GCMKW = "A192GCMKW"
AlgA256GCMKW = "A256GCMKW"
AlgPBES2HS256A128KW = "PBES2-HS256+A128KW"
AlgPBES2HS384A192KW = "PBES2-HS384+A192KW"
AlgPBES2HS512A256KW = "PBES2-HS512+A256KW"
)
const (
// Encryption Algorithm Header Parameter Values for JWE
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#page-22
EncA128CBCHS256 = "A128CBC-HS256"
EncA128CBCHS384 = "A128CBC-HS384"
EncA256CBCHS512 = "A256CBC-HS512"
EncA128GCM = "A128GCM"
EncA192GCM = "A192GCM"
EncA256GCM = "A256GCM"
)
type JOSEHeader map[string]string
func (j JOSEHeader) Validate() error {
if _, exists := j[HeaderKeyAlgorithm]; !exists {
return fmt.Errorf("header missing %q parameter", HeaderKeyAlgorithm)
}
return nil
}
func decodeHeader(seg string) (JOSEHeader, error) {
b, err := decodeSegment(seg)
if err != nil {
return nil, err
}
var h JOSEHeader
err = json.Unmarshal(b, &h)
if err != nil {
return nil, err
}
return h, nil
}
func encodeHeader(h JOSEHeader) (string, error) {
b, err := json.Marshal(h)
if err != nil {
return "", err
}
return encodeSegment(b), nil
}
// Decode JWT specific base64url encoding with padding stripped
func decodeSegment(seg string) ([]byte, error) {
if l := len(seg) % 4; l != 0 {
seg += strings.Repeat("=", 4-l)
}
return base64.URLEncoding.DecodeString(seg)
}
// Encode JWT specific base64url encoding with padding stripped
func encodeSegment(seg []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")
}

View file

@ -1,135 +0,0 @@
package jose
import (
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"math/big"
"strings"
)
// JSON Web Key
// https://tools.ietf.org/html/draft-ietf-jose-json-web-key-36#page-5
type JWK struct {
ID string
Type string
Alg string
Use string
Exponent int
Modulus *big.Int
Secret []byte
}
type jwkJSON struct {
ID string `json:"kid"`
Type string `json:"kty"`
Alg string `json:"alg"`
Use string `json:"use"`
Exponent string `json:"e"`
Modulus string `json:"n"`
}
func (j *JWK) MarshalJSON() ([]byte, error) {
t := jwkJSON{
ID: j.ID,
Type: j.Type,
Alg: j.Alg,
Use: j.Use,
Exponent: encodeExponent(j.Exponent),
Modulus: encodeModulus(j.Modulus),
}
return json.Marshal(&t)
}
func (j *JWK) UnmarshalJSON(data []byte) error {
var t jwkJSON
err := json.Unmarshal(data, &t)
if err != nil {
return err
}
e, err := decodeExponent(t.Exponent)
if err != nil {
return err
}
n, err := decodeModulus(t.Modulus)
if err != nil {
return err
}
j.ID = t.ID
j.Type = t.Type
j.Alg = t.Alg
j.Use = t.Use
j.Exponent = e
j.Modulus = n
return nil
}
type JWKSet struct {
Keys []JWK `json:"keys"`
}
func decodeExponent(e string) (int, error) {
decE, err := decodeBase64URLPaddingOptional(e)
if err != nil {
return 0, err
}
var eBytes []byte
if len(decE) < 8 {
eBytes = make([]byte, 8-len(decE), 8)
eBytes = append(eBytes, decE...)
} else {
eBytes = decE
}
eReader := bytes.NewReader(eBytes)
var E uint64
err = binary.Read(eReader, binary.BigEndian, &E)
if err != nil {
return 0, err
}
return int(E), nil
}
func encodeExponent(e int) string {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(e))
var idx int
for ; idx < 8; idx++ {
if b[idx] != 0x0 {
break
}
}
return base64.URLEncoding.EncodeToString(b[idx:])
}
// Turns a URL encoded modulus of a key into a big int.
func decodeModulus(n string) (*big.Int, error) {
decN, err := decodeBase64URLPaddingOptional(n)
if err != nil {
return nil, err
}
N := big.NewInt(0)
N.SetBytes(decN)
return N, nil
}
func encodeModulus(n *big.Int) string {
return base64.URLEncoding.EncodeToString(n.Bytes())
}
// decodeBase64URLPaddingOptional decodes Base64 whether there is padding or not.
// The stdlib version currently doesn't handle this.
// We can get rid of this is if this bug:
// https://github.com/golang/go/issues/4237
// ever closes.
func decodeBase64URLPaddingOptional(e string) ([]byte, error) {
if m := len(e) % 4; m != 0 {
e += strings.Repeat("=", 4-m)
}
return base64.URLEncoding.DecodeString(e)
}

View file

@ -1,51 +0,0 @@
package jose
import (
"fmt"
"strings"
)
type JWS struct {
RawHeader string
Header JOSEHeader
RawPayload string
Payload []byte
Signature []byte
}
// Given a raw encoded JWS token parses it and verifies the structure.
func ParseJWS(raw string) (JWS, error) {
parts := strings.Split(raw, ".")
if len(parts) != 3 {
return JWS{}, fmt.Errorf("malformed JWS, only %d segments", len(parts))
}
rawSig := parts[2]
jws := JWS{
RawHeader: parts[0],
RawPayload: parts[1],
}
header, err := decodeHeader(jws.RawHeader)
if err != nil {
return JWS{}, fmt.Errorf("malformed JWS, unable to decode header, %s", err)
}
if err = header.Validate(); err != nil {
return JWS{}, fmt.Errorf("malformed JWS, %s", err)
}
jws.Header = header
payload, err := decodeSegment(jws.RawPayload)
if err != nil {
return JWS{}, fmt.Errorf("malformed JWS, unable to decode payload: %s", err)
}
jws.Payload = payload
sig, err := decodeSegment(rawSig)
if err != nil {
return JWS{}, fmt.Errorf("malformed JWS, unable to decode signature: %s", err)
}
jws.Signature = sig
return jws, nil
}

View file

@ -1,82 +0,0 @@
package jose
import "strings"
type JWT JWS
func ParseJWT(token string) (jwt JWT, err error) {
jws, err := ParseJWS(token)
if err != nil {
return
}
return JWT(jws), nil
}
func NewJWT(header JOSEHeader, claims Claims) (jwt JWT, err error) {
jwt = JWT{}
jwt.Header = header
jwt.Header[HeaderMediaType] = "JWT"
claimBytes, err := marshalClaims(claims)
if err != nil {
return
}
jwt.Payload = claimBytes
eh, err := encodeHeader(header)
if err != nil {
return
}
jwt.RawHeader = eh
ec, err := encodeClaims(claims)
if err != nil {
return
}
jwt.RawPayload = ec
return
}
func (j *JWT) KeyID() (string, bool) {
kID, ok := j.Header[HeaderKeyID]
return kID, ok
}
func (j *JWT) Claims() (Claims, error) {
return decodeClaims(j.Payload)
}
// Encoded data part of the token which may be signed.
func (j *JWT) Data() string {
return strings.Join([]string{j.RawHeader, j.RawPayload}, ".")
}
// Full encoded JWT token string in format: header.claims.signature
func (j *JWT) Encode() string {
d := j.Data()
s := encodeSegment(j.Signature)
return strings.Join([]string{d, s}, ".")
}
func NewSignedJWT(claims Claims, s Signer) (*JWT, error) {
header := JOSEHeader{
HeaderKeyAlgorithm: s.Alg(),
HeaderKeyID: s.ID(),
}
jwt, err := NewJWT(header, claims)
if err != nil {
return nil, err
}
sig, err := s.Sign([]byte(jwt.Data()))
if err != nil {
return nil, err
}
jwt.Signature = sig
return &jwt, nil
}

View file

@ -1,24 +0,0 @@
package jose
import (
"fmt"
)
type Verifier interface {
ID() string
Alg() string
Verify(sig []byte, data []byte) error
}
type Signer interface {
Verifier
Sign(data []byte) (sig []byte, err error)
}
func NewVerifier(jwk JWK) (Verifier, error) {
if jwk.Type != "RSA" {
return nil, fmt.Errorf("unsupported key type %q", jwk.Type)
}
return NewVerifierRSA(jwk)
}

View file

@ -1,67 +0,0 @@
package jose
import (
"bytes"
"crypto"
"crypto/hmac"
_ "crypto/sha256"
"errors"
"fmt"
)
type VerifierHMAC struct {
KeyID string
Hash crypto.Hash
Secret []byte
}
type SignerHMAC struct {
VerifierHMAC
}
func NewVerifierHMAC(jwk JWK) (*VerifierHMAC, error) {
if jwk.Alg != "" && jwk.Alg != "HS256" {
return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg)
}
v := VerifierHMAC{
KeyID: jwk.ID,
Secret: jwk.Secret,
Hash: crypto.SHA256,
}
return &v, nil
}
func (v *VerifierHMAC) ID() string {
return v.KeyID
}
func (v *VerifierHMAC) Alg() string {
return "HS256"
}
func (v *VerifierHMAC) Verify(sig []byte, data []byte) error {
h := hmac.New(v.Hash.New, v.Secret)
h.Write(data)
if !bytes.Equal(sig, h.Sum(nil)) {
return errors.New("invalid hmac signature")
}
return nil
}
func NewSignerHMAC(kid string, secret []byte) *SignerHMAC {
return &SignerHMAC{
VerifierHMAC: VerifierHMAC{
KeyID: kid,
Secret: secret,
Hash: crypto.SHA256,
},
}
}
func (s *SignerHMAC) Sign(data []byte) ([]byte, error) {
h := hmac.New(s.Hash.New, s.Secret)
h.Write(data)
return h.Sum(nil), nil
}

View file

@ -1,67 +0,0 @@
package jose
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"fmt"
)
type VerifierRSA struct {
KeyID string
Hash crypto.Hash
PublicKey rsa.PublicKey
}
type SignerRSA struct {
PrivateKey rsa.PrivateKey
VerifierRSA
}
func NewVerifierRSA(jwk JWK) (*VerifierRSA, error) {
if jwk.Alg != "" && jwk.Alg != "RS256" {
return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg)
}
v := VerifierRSA{
KeyID: jwk.ID,
PublicKey: rsa.PublicKey{
N: jwk.Modulus,
E: jwk.Exponent,
},
Hash: crypto.SHA256,
}
return &v, nil
}
func NewSignerRSA(kid string, key rsa.PrivateKey) *SignerRSA {
return &SignerRSA{
PrivateKey: key,
VerifierRSA: VerifierRSA{
KeyID: kid,
PublicKey: key.PublicKey,
Hash: crypto.SHA256,
},
}
}
func (v *VerifierRSA) ID() string {
return v.KeyID
}
func (v *VerifierRSA) Alg() string {
return "RS256"
}
func (v *VerifierRSA) Verify(sig []byte, data []byte) error {
h := v.Hash.New()
h.Write(data)
return rsa.VerifyPKCS1v15(&v.PublicKey, v.Hash, h.Sum(nil), sig)
}
func (s *SignerRSA) Sign(data []byte) ([]byte, error) {
h := s.Hash.New()
h.Write(data)
return rsa.SignPKCS1v15(rand.Reader, &s.PrivateKey, s.Hash, h.Sum(nil))
}

View file

@ -1,153 +0,0 @@
package key
import (
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"encoding/json"
"io"
"time"
"github.com/coreos/go-oidc/jose"
)
func NewPublicKey(jwk jose.JWK) *PublicKey {
return &PublicKey{jwk: jwk}
}
type PublicKey struct {
jwk jose.JWK
}
func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(&k.jwk)
}
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var jwk jose.JWK
if err := json.Unmarshal(data, &jwk); err != nil {
return err
}
k.jwk = jwk
return nil
}
func (k *PublicKey) ID() string {
return k.jwk.ID
}
func (k *PublicKey) Verifier() (jose.Verifier, error) {
return jose.NewVerifierRSA(k.jwk)
}
type PrivateKey struct {
KeyID string
PrivateKey *rsa.PrivateKey
}
func (k *PrivateKey) ID() string {
return k.KeyID
}
func (k *PrivateKey) Signer() jose.Signer {
return jose.NewSignerRSA(k.ID(), *k.PrivateKey)
}
func (k *PrivateKey) JWK() jose.JWK {
return jose.JWK{
ID: k.KeyID,
Type: "RSA",
Alg: "RS256",
Use: "sig",
Exponent: k.PrivateKey.PublicKey.E,
Modulus: k.PrivateKey.PublicKey.N,
}
}
type KeySet interface {
ExpiresAt() time.Time
}
type PublicKeySet struct {
keys []PublicKey
index map[string]*PublicKey
expiresAt time.Time
}
func NewPublicKeySet(jwks []jose.JWK, exp time.Time) *PublicKeySet {
keys := make([]PublicKey, len(jwks))
index := make(map[string]*PublicKey)
for i, jwk := range jwks {
keys[i] = *NewPublicKey(jwk)
index[keys[i].ID()] = &keys[i]
}
return &PublicKeySet{
keys: keys,
index: index,
expiresAt: exp,
}
}
func (s *PublicKeySet) ExpiresAt() time.Time {
return s.expiresAt
}
func (s *PublicKeySet) Keys() []PublicKey {
return s.keys
}
func (s *PublicKeySet) Key(id string) *PublicKey {
return s.index[id]
}
type PrivateKeySet struct {
keys []*PrivateKey
ActiveKeyID string
expiresAt time.Time
}
func NewPrivateKeySet(keys []*PrivateKey, exp time.Time) *PrivateKeySet {
return &PrivateKeySet{
keys: keys,
ActiveKeyID: keys[0].ID(),
expiresAt: exp.UTC(),
}
}
func (s *PrivateKeySet) Keys() []*PrivateKey {
return s.keys
}
func (s *PrivateKeySet) ExpiresAt() time.Time {
return s.expiresAt
}
func (s *PrivateKeySet) Active() *PrivateKey {
for i, k := range s.keys {
if k.ID() == s.ActiveKeyID {
return s.keys[i]
}
}
return nil
}
type GeneratePrivateKeyFunc func() (*PrivateKey, error)
func GeneratePrivateKey() (*PrivateKey, error) {
pk, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
keyID := make([]byte, 20)
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
return nil, err
}
k := PrivateKey{
KeyID: hex.EncodeToString(keyID),
PrivateKey: pk,
}
return &k, nil
}

View file

@ -1,99 +0,0 @@
package key
import (
"errors"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/pkg/health"
)
type PrivateKeyManager interface {
ExpiresAt() time.Time
Signer() (jose.Signer, error)
JWKs() ([]jose.JWK, error)
PublicKeys() ([]PublicKey, error)
WritableKeySetRepo
health.Checkable
}
func NewPrivateKeyManager() PrivateKeyManager {
return &privateKeyManager{
clock: clockwork.NewRealClock(),
}
}
type privateKeyManager struct {
keySet *PrivateKeySet
clock clockwork.Clock
}
func (m *privateKeyManager) ExpiresAt() time.Time {
if m.keySet == nil {
return m.clock.Now().UTC()
}
return m.keySet.ExpiresAt()
}
func (m *privateKeyManager) Signer() (jose.Signer, error) {
if err := m.Healthy(); err != nil {
return nil, err
}
return m.keySet.Active().Signer(), nil
}
func (m *privateKeyManager) JWKs() ([]jose.JWK, error) {
if err := m.Healthy(); err != nil {
return nil, err
}
keys := m.keySet.Keys()
jwks := make([]jose.JWK, len(keys))
for i, k := range keys {
jwks[i] = k.JWK()
}
return jwks, nil
}
func (m *privateKeyManager) PublicKeys() ([]PublicKey, error) {
jwks, err := m.JWKs()
if err != nil {
return nil, err
}
keys := make([]PublicKey, len(jwks))
for i, jwk := range jwks {
keys[i] = *NewPublicKey(jwk)
}
return keys, nil
}
func (m *privateKeyManager) Healthy() error {
if m.keySet == nil {
return errors.New("private key manager uninitialized")
}
if len(m.keySet.Keys()) == 0 {
return errors.New("private key manager zero keys")
}
if m.keySet.ExpiresAt().Before(m.clock.Now().UTC()) {
return errors.New("private key manager keys expired")
}
return nil
}
func (m *privateKeyManager) Set(keySet KeySet) error {
privKeySet, ok := keySet.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
}
m.keySet = privKeySet
return nil
}

View file

@ -1,55 +0,0 @@
package key
import (
"errors"
"sync"
)
var ErrorNoKeys = errors.New("no keys found")
type WritableKeySetRepo interface {
Set(KeySet) error
}
type ReadableKeySetRepo interface {
Get() (KeySet, error)
}
type PrivateKeySetRepo interface {
WritableKeySetRepo
ReadableKeySetRepo
}
func NewPrivateKeySetRepo() PrivateKeySetRepo {
return &memPrivateKeySetRepo{}
}
type memPrivateKeySetRepo struct {
mu sync.RWMutex
pks PrivateKeySet
}
func (r *memPrivateKeySetRepo) Set(ks KeySet) error {
pks, ok := ks.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
} else if pks == nil {
return errors.New("nil KeySet")
}
r.mu.Lock()
defer r.mu.Unlock()
r.pks = *pks
return nil
}
func (r *memPrivateKeySetRepo) Get() (KeySet, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if r.pks.keys == nil {
return nil, ErrorNoKeys
}
return KeySet(&r.pks), nil
}

View file

@ -1,159 +0,0 @@
package key
import (
"errors"
"log"
"time"
ptime "github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork"
)
var (
ErrorPrivateKeysExpired = errors.New("private keys have expired")
)
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
return &PrivateKeyRotator{
repo: repo,
ttl: ttl,
keep: 2,
generateKey: GeneratePrivateKey,
clock: clockwork.NewRealClock(),
}
}
type PrivateKeyRotator struct {
repo PrivateKeySetRepo
generateKey GeneratePrivateKeyFunc
clock clockwork.Clock
keep int
ttl time.Duration
}
func (r *PrivateKeyRotator) expiresAt() time.Time {
return r.clock.Now().UTC().Add(r.ttl)
}
func (r *PrivateKeyRotator) Healthy() error {
pks, err := r.privateKeySet()
if err != nil {
return err
}
if r.clock.Now().After(pks.ExpiresAt()) {
return ErrorPrivateKeysExpired
}
return nil
}
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
ks, err := r.repo.Get()
if err != nil {
return nil, err
}
pks, ok := ks.(*PrivateKeySet)
if !ok {
return nil, errors.New("unable to cast to PrivateKeySet")
}
return pks, nil
}
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
pks, err := r.privateKeySet()
if err == ErrorNoKeys {
return 0, nil
}
if err != nil {
return 0, err
}
now := r.clock.Now()
// Ideally, we want to rotate after half the TTL has elapsed.
idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)
// If we are past the ideal rotation time, rotate immediatly.
return max(0, idealRotationTime.Sub(now)), nil
}
func max(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}
func (r *PrivateKeyRotator) Run() chan struct{} {
attempt := func() {
k, err := r.generateKey()
if err != nil {
log.Printf("go-oidc: failed generating signing key: %v", err)
return
}
exp := r.expiresAt()
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
log.Printf("go-oidc: key rotation failed: %v", err)
return
}
}
stop := make(chan struct{})
go func() {
for {
var nextRotation time.Duration
var sleep time.Duration
var err error
for {
if nextRotation, err = r.nextRotation(); err == nil {
break
}
sleep = ptime.ExpBackoff(sleep, time.Minute)
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
time.Sleep(sleep)
}
select {
case <-r.clock.After(nextRotation):
attempt()
case <-stop:
return
}
}
}()
return stop
}
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
ks, err := repo.Get()
if err != nil && err != ErrorNoKeys {
return err
}
var keys []*PrivateKey
if ks != nil {
pks, ok := ks.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
}
keys = pks.Keys()
}
keys = append([]*PrivateKey{k}, keys...)
if l := len(keys); l > keep {
keys = keys[0:keep]
}
nks := PrivateKeySet{
keys: keys,
ActiveKeyID: k.ID(),
expiresAt: exp,
}
return repo.Set(KeySet(&nks))
}

View file

@ -1,91 +0,0 @@
package key
import (
"errors"
"log"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/pkg/timeutil"
)
func NewKeySetSyncer(r ReadableKeySetRepo, w WritableKeySetRepo) *KeySetSyncer {
return &KeySetSyncer{
readable: r,
writable: w,
clock: clockwork.NewRealClock(),
}
}
type KeySetSyncer struct {
readable ReadableKeySetRepo
writable WritableKeySetRepo
clock clockwork.Clock
}
func (s *KeySetSyncer) Run() chan struct{} {
stop := make(chan struct{})
go func() {
var failing bool
var next time.Duration
for {
exp, err := syncKeySet(s.readable, s.writable, s.clock)
if err != nil || exp == 0 {
if !failing {
failing = true
next = time.Second
} else {
next = timeutil.ExpBackoff(next, time.Minute)
}
if exp == 0 {
log.Printf("Synced to already expired key set, retrying in %v: %v", next, err)
} else {
log.Printf("Failed syncing key set, retrying in %v: %v", next, err)
}
} else {
failing = false
next = exp / 2
}
select {
case <-s.clock.After(next):
continue
case <-stop:
return
}
}
}()
return stop
}
func Sync(r ReadableKeySetRepo, w WritableKeySetRepo) (time.Duration, error) {
return syncKeySet(r, w, clockwork.NewRealClock())
}
// syncKeySet copies the keyset from r to the KeySet at w and returns the duration in which the KeySet will expire.
// If keyset has already expired, returns a zero duration.
func syncKeySet(r ReadableKeySetRepo, w WritableKeySetRepo, clock clockwork.Clock) (exp time.Duration, err error) {
var ks KeySet
ks, err = r.Get()
if err != nil {
return
}
if ks == nil {
err = errors.New("no source KeySet")
return
}
if err = w.Set(ks); err != nil {
return
}
now := clock.Now()
if ks.ExpiresAt().After(now) {
exp = ks.ExpiresAt().Sub(now)
}
return
}

View file

@ -1,29 +0,0 @@
package oauth2
const (
ErrorAccessDenied = "access_denied"
ErrorInvalidClient = "invalid_client"
ErrorInvalidGrant = "invalid_grant"
ErrorInvalidRequest = "invalid_request"
ErrorServerError = "server_error"
ErrorUnauthorizedClient = "unauthorized_client"
ErrorUnsupportedGrantType = "unsupported_grant_type"
ErrorUnsupportedResponseType = "unsupported_response_type"
)
type Error struct {
Type string `json:"error"`
Description string `json:"error_description,omitempty"`
State string `json:"state,omitempty"`
}
func (e *Error) Error() string {
if e.Description != "" {
return e.Type + ": " + e.Description
}
return e.Type
}
func NewError(typ string) *Error {
return &Error{Type: typ}
}

View file

@ -1,416 +0,0 @@
package oauth2
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"mime"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
phttp "github.com/coreos/go-oidc/http"
)
// ResponseTypesEqual compares two response_type values. If either
// contains a space, it is treated as an unordered list. For example,
// comparing "code id_token" and "id_token code" would evaluate to true.
func ResponseTypesEqual(r1, r2 string) bool {
if !strings.Contains(r1, " ") || !strings.Contains(r2, " ") {
// fast route, no split needed
return r1 == r2
}
// split, sort, and compare
r1Fields := strings.Fields(r1)
r2Fields := strings.Fields(r2)
if len(r1Fields) != len(r2Fields) {
return false
}
sort.Strings(r1Fields)
sort.Strings(r2Fields)
for i, r1Field := range r1Fields {
if r1Field != r2Fields[i] {
return false
}
}
return true
}
const (
// OAuth2.0 response types registered by OIDC.
//
// See: https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#RegistryContents
ResponseTypeCode = "code"
ResponseTypeCodeIDToken = "code id_token"
ResponseTypeCodeIDTokenToken = "code id_token token"
ResponseTypeIDToken = "id_token"
ResponseTypeIDTokenToken = "id_token token"
ResponseTypeToken = "token"
ResponseTypeNone = "none"
)
const (
GrantTypeAuthCode = "authorization_code"
GrantTypeClientCreds = "client_credentials"
GrantTypeUserCreds = "password"
GrantTypeImplicit = "implicit"
GrantTypeRefreshToken = "refresh_token"
AuthMethodClientSecretPost = "client_secret_post"
AuthMethodClientSecretBasic = "client_secret_basic"
AuthMethodClientSecretJWT = "client_secret_jwt"
AuthMethodPrivateKeyJWT = "private_key_jwt"
)
type Config struct {
Credentials ClientCredentials
Scope []string
RedirectURL string
AuthURL string
TokenURL string
// Must be one of the AuthMethodXXX methods above. Right now, only
// AuthMethodClientSecretPost and AuthMethodClientSecretBasic are supported.
AuthMethod string
}
type Client struct {
hc phttp.Client
creds ClientCredentials
scope []string
authURL *url.URL
redirectURL *url.URL
tokenURL *url.URL
authMethod string
}
type ClientCredentials struct {
ID string
Secret string
}
func NewClient(hc phttp.Client, cfg Config) (c *Client, err error) {
if len(cfg.Credentials.ID) == 0 {
err = errors.New("missing client id")
return
}
if len(cfg.Credentials.Secret) == 0 {
err = errors.New("missing client secret")
return
}
if cfg.AuthMethod == "" {
cfg.AuthMethod = AuthMethodClientSecretBasic
} else if cfg.AuthMethod != AuthMethodClientSecretPost && cfg.AuthMethod != AuthMethodClientSecretBasic {
err = fmt.Errorf("auth method %q is not supported", cfg.AuthMethod)
return
}
au, err := phttp.ParseNonEmptyURL(cfg.AuthURL)
if err != nil {
return
}
tu, err := phttp.ParseNonEmptyURL(cfg.TokenURL)
if err != nil {
return
}
// Allow empty redirect URL in the case where the client
// only needs to verify a given token.
ru, err := url.Parse(cfg.RedirectURL)
if err != nil {
return
}
c = &Client{
creds: cfg.Credentials,
scope: cfg.Scope,
redirectURL: ru,
authURL: au,
tokenURL: tu,
hc: hc,
authMethod: cfg.AuthMethod,
}
return
}
// Return the embedded HTTP client
func (c *Client) HttpClient() phttp.Client {
return c.hc
}
// Generate the url for initial redirect to oauth provider.
func (c *Client) AuthCodeURL(state, accessType, prompt string) string {
v := c.commonURLValues()
v.Set("state", state)
if strings.ToLower(accessType) == "offline" {
v.Set("access_type", "offline")
}
if prompt != "" {
v.Set("prompt", prompt)
}
v.Set("response_type", "code")
q := v.Encode()
u := *c.authURL
if u.RawQuery == "" {
u.RawQuery = q
} else {
u.RawQuery += "&" + q
}
return u.String()
}
func (c *Client) commonURLValues() url.Values {
return url.Values{
"redirect_uri": {c.redirectURL.String()},
"scope": {strings.Join(c.scope, " ")},
"client_id": {c.creds.ID},
}
}
func (c *Client) newAuthenticatedRequest(urlToken string, values url.Values) (*http.Request, error) {
var req *http.Request
var err error
switch c.authMethod {
case AuthMethodClientSecretPost:
values.Set("client_secret", c.creds.Secret)
req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
case AuthMethodClientSecretBasic:
req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
encodedID := url.QueryEscape(c.creds.ID)
encodedSecret := url.QueryEscape(c.creds.Secret)
req.SetBasicAuth(encodedID, encodedSecret)
default:
panic("misconfigured client: auth method not supported")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil
}
// ClientCredsToken posts the client id and secret to obtain a token scoped to the OAuth2 client via the "client_credentials" grant type.
// May not be supported by all OAuth2 servers.
func (c *Client) ClientCredsToken(scope []string) (result TokenResponse, err error) {
v := url.Values{
"scope": {strings.Join(scope, " ")},
"grant_type": {GrantTypeClientCreds},
}
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
if err != nil {
return
}
resp, err := c.hc.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
return parseTokenResponse(resp)
}
// UserCredsToken posts the username and password to obtain a token scoped to the OAuth2 client via the "password" grant_type
// May not be supported by all OAuth2 servers.
func (c *Client) UserCredsToken(username, password string) (result TokenResponse, err error) {
v := url.Values{
"scope": {strings.Join(c.scope, " ")},
"grant_type": {GrantTypeUserCreds},
"username": {username},
"password": {password},
}
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
if err != nil {
return
}
resp, err := c.hc.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
return parseTokenResponse(resp)
}
// RequestToken requests a token from the Token Endpoint with the specified grantType.
// If 'grantType' == GrantTypeAuthCode, then 'value' should be the authorization code.
// If 'grantType' == GrantTypeRefreshToken, then 'value' should be the refresh token.
func (c *Client) RequestToken(grantType, value string) (result TokenResponse, err error) {
v := c.commonURLValues()
v.Set("grant_type", grantType)
v.Set("client_secret", c.creds.Secret)
switch grantType {
case GrantTypeAuthCode:
v.Set("code", value)
case GrantTypeRefreshToken:
v.Set("refresh_token", value)
default:
err = fmt.Errorf("unsupported grant_type: %v", grantType)
return
}
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
if err != nil {
return
}
resp, err := c.hc.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
return parseTokenResponse(resp)
}
func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
badStatusCode := resp.StatusCode < 200 || resp.StatusCode > 299
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
return
}
result = TokenResponse{
RawBody: body,
}
newError := func(typ, desc, state string) error {
if typ == "" {
return fmt.Errorf("unrecognized error %s", body)
}
return &Error{typ, desc, state}
}
if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" {
var vals url.Values
vals, err = url.ParseQuery(string(body))
if err != nil {
return
}
if error := vals.Get("error"); error != "" || badStatusCode {
err = newError(error, vals.Get("error_description"), vals.Get("state"))
return
}
e := vals.Get("expires_in")
if e == "" {
e = vals.Get("expires")
}
if e != "" {
result.Expires, err = strconv.Atoi(e)
if err != nil {
return
}
}
result.AccessToken = vals.Get("access_token")
result.TokenType = vals.Get("token_type")
result.IDToken = vals.Get("id_token")
result.RefreshToken = vals.Get("refresh_token")
result.Scope = vals.Get("scope")
} else {
var r struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
State string `json:"state"`
ExpiresIn json.Number `json:"expires_in"` // Azure AD returns string
Expires int `json:"expires"`
Error string `json:"error"`
Desc string `json:"error_description"`
}
if err = json.Unmarshal(body, &r); err != nil {
return
}
if r.Error != "" || badStatusCode {
err = newError(r.Error, r.Desc, r.State)
return
}
result.AccessToken = r.AccessToken
result.TokenType = r.TokenType
result.IDToken = r.IDToken
result.RefreshToken = r.RefreshToken
result.Scope = r.Scope
if expiresIn, err := r.ExpiresIn.Int64(); err != nil {
result.Expires = r.Expires
} else {
result.Expires = int(expiresIn)
}
}
return
}
type TokenResponse struct {
AccessToken string
TokenType string
Expires int
IDToken string
RefreshToken string // OPTIONAL.
Scope string // OPTIONAL, if identical to the scope requested by the client, otherwise, REQUIRED.
RawBody []byte // In case callers need some other non-standard info from the token response
}
type AuthCodeRequest struct {
ResponseType string
ClientID string
RedirectURL *url.URL
Scope []string
State string
}
func ParseAuthCodeRequest(q url.Values) (AuthCodeRequest, error) {
acr := AuthCodeRequest{
ResponseType: q.Get("response_type"),
ClientID: q.Get("client_id"),
State: q.Get("state"),
Scope: make([]string, 0),
}
qs := strings.TrimSpace(q.Get("scope"))
if qs != "" {
acr.Scope = strings.Split(qs, " ")
}
err := func() error {
if acr.ClientID == "" {
return NewError(ErrorInvalidRequest)
}
redirectURL := q.Get("redirect_uri")
if redirectURL != "" {
ru, err := url.Parse(redirectURL)
if err != nil {
return NewError(ErrorInvalidRequest)
}
acr.RedirectURL = ru
}
return nil
}()
return acr, err
}

View file

@ -1,846 +0,0 @@
package oidc
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/mail"
"net/url"
"sync"
"time"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oauth2"
)
const (
// amount of time that must pass after the last key sync
// completes before another attempt may begin
keySyncWindow = 5 * time.Second
)
var (
DefaultScope = []string{"openid", "email", "profile"}
supportedAuthMethods = map[string]struct{}{
oauth2.AuthMethodClientSecretBasic: struct{}{},
oauth2.AuthMethodClientSecretPost: struct{}{},
}
)
type ClientCredentials oauth2.ClientCredentials
type ClientIdentity struct {
Credentials ClientCredentials
Metadata ClientMetadata
}
type JWAOptions struct {
// SigningAlg specifies an JWA alg for signing JWTs.
//
// Specifying this field implies different actions depending on the context. It may
// require objects be serialized and signed as a JWT instead of plain JSON, or
// require an existing JWT object use the specified alg.
//
// See: http://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata
SigningAlg string
// EncryptionAlg, if provided, specifies that the returned or sent object be stored
// (or nested) within a JWT object and encrypted with the provided JWA alg.
EncryptionAlg string
// EncryptionEnc specifies the JWA enc algorithm to use with EncryptionAlg. If
// EncryptionAlg is provided and EncryptionEnc is omitted, this field defaults
// to A128CBC-HS256.
//
// If EncryptionEnc is provided EncryptionAlg must also be specified.
EncryptionEnc string
}
func (opt JWAOptions) valid() error {
if opt.EncryptionEnc != "" && opt.EncryptionAlg == "" {
return errors.New("encryption encoding provided with no encryption algorithm")
}
return nil
}
func (opt JWAOptions) defaults() JWAOptions {
if opt.EncryptionAlg != "" && opt.EncryptionEnc == "" {
opt.EncryptionEnc = jose.EncA128CBCHS256
}
return opt
}
var (
// Ensure ClientMetadata satisfies these interfaces.
_ json.Marshaler = &ClientMetadata{}
_ json.Unmarshaler = &ClientMetadata{}
)
// ClientMetadata holds metadata that the authorization server associates
// with a client identifier. The fields range from human-facing display
// strings such as client name, to items that impact the security of the
// protocol, such as the list of valid redirect URIs.
//
// See http://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata
//
// TODO: support language specific claim representations
// http://openid.net/specs/openid-connect-registration-1_0.html#LanguagesAndScripts
type ClientMetadata struct {
RedirectURIs []url.URL // Required
// A list of OAuth 2.0 "response_type" values that the client wishes to restrict
// itself to. Either "code", "token", or another registered extension.
//
// If omitted, only "code" will be used.
ResponseTypes []string
// A list of OAuth 2.0 grant types the client wishes to restrict itself to.
// The grant type values used by OIDC are "authorization_code", "implicit",
// and "refresh_token".
//
// If ommitted, only "authorization_code" will be used.
GrantTypes []string
// "native" or "web". If omitted, "web".
ApplicationType string
// List of email addresses.
Contacts []mail.Address
// Name of client to be presented to the end-user.
ClientName string
// URL that references a logo for the Client application.
LogoURI *url.URL
// URL of the home page of the Client.
ClientURI *url.URL
// Profile data policies and terms of use to be provided to the end user.
PolicyURI *url.URL
TermsOfServiceURI *url.URL
// URL to or the value of the client's JSON Web Key Set document.
JWKSURI *url.URL
JWKS *jose.JWKSet
// URL referencing a flie with a single JSON array of redirect URIs.
SectorIdentifierURI *url.URL
SubjectType string
// Options to restrict the JWS alg and enc values used for server responses and requests.
IDTokenResponseOptions JWAOptions
UserInfoResponseOptions JWAOptions
RequestObjectOptions JWAOptions
// Client requested authorization method and signing options for the token endpoint.
//
// Defaults to "client_secret_basic"
TokenEndpointAuthMethod string
TokenEndpointAuthSigningAlg string
// DefaultMaxAge specifies the maximum amount of time in seconds before an authorized
// user must reauthroize.
//
// If 0, no limitation is placed on the maximum.
DefaultMaxAge int64
// RequireAuthTime specifies if the auth_time claim in the ID token is required.
RequireAuthTime bool
// Default Authentication Context Class Reference values for authentication requests.
DefaultACRValues []string
// URI that a third party can use to initiate a login by the relaying party.
//
// See: http://openid.net/specs/openid-connect-core-1_0.html#ThirdPartyInitiatedLogin
InitiateLoginURI *url.URL
// Pre-registered request_uri values that may be cached by the server.
RequestURIs []url.URL
}
// Defaults returns a shallow copy of ClientMetadata with default
// values replacing omitted fields.
func (m ClientMetadata) Defaults() ClientMetadata {
if len(m.ResponseTypes) == 0 {
m.ResponseTypes = []string{oauth2.ResponseTypeCode}
}
if len(m.GrantTypes) == 0 {
m.GrantTypes = []string{oauth2.GrantTypeAuthCode}
}
if m.ApplicationType == "" {
m.ApplicationType = "web"
}
if m.TokenEndpointAuthMethod == "" {
m.TokenEndpointAuthMethod = oauth2.AuthMethodClientSecretBasic
}
m.IDTokenResponseOptions = m.IDTokenResponseOptions.defaults()
m.UserInfoResponseOptions = m.UserInfoResponseOptions.defaults()
m.RequestObjectOptions = m.RequestObjectOptions.defaults()
return m
}
func (m *ClientMetadata) MarshalJSON() ([]byte, error) {
e := m.toEncodableStruct()
return json.Marshal(&e)
}
func (m *ClientMetadata) UnmarshalJSON(data []byte) error {
var e encodableClientMetadata
if err := json.Unmarshal(data, &e); err != nil {
return err
}
meta, err := e.toStruct()
if err != nil {
return err
}
if err := meta.Valid(); err != nil {
return err
}
*m = meta
return nil
}
type encodableClientMetadata struct {
RedirectURIs []string `json:"redirect_uris"` // Required
ResponseTypes []string `json:"response_types,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
Contacts []string `json:"contacts,omitempty"`
ClientName string `json:"client_name,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
TermsOfServiceURI string `json:"tos_uri,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
JWKS *jose.JWKSet `json:"jwks,omitempty"`
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
SubjectType string `json:"subject_type,omitempty"`
IDTokenSignedResponseAlg string `json:"id_token_signed_response_alg,omitempty"`
IDTokenEncryptedResponseAlg string `json:"id_token_encrypted_response_alg,omitempty"`
IDTokenEncryptedResponseEnc string `json:"id_token_encrypted_response_enc,omitempty"`
UserInfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty"`
UserInfoEncryptedResponseAlg string `json:"userinfo_encrypted_response_alg,omitempty"`
UserInfoEncryptedResponseEnc string `json:"userinfo_encrypted_response_enc,omitempty"`
RequestObjectSigningAlg string `json:"request_object_signing_alg,omitempty"`
RequestObjectEncryptionAlg string `json:"request_object_encryption_alg,omitempty"`
RequestObjectEncryptionEnc string `json:"request_object_encryption_enc,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty"`
DefaultMaxAge int64 `json:"default_max_age,omitempty"`
RequireAuthTime bool `json:"require_auth_time,omitempty"`
DefaultACRValues []string `json:"default_acr_values,omitempty"`
InitiateLoginURI string `json:"initiate_login_uri,omitempty"`
RequestURIs []string `json:"request_uris,omitempty"`
}
func (c *encodableClientMetadata) toStruct() (ClientMetadata, error) {
p := stickyErrParser{}
m := ClientMetadata{
RedirectURIs: p.parseURIs(c.RedirectURIs, "redirect_uris"),
ResponseTypes: c.ResponseTypes,
GrantTypes: c.GrantTypes,
ApplicationType: c.ApplicationType,
Contacts: p.parseEmails(c.Contacts, "contacts"),
ClientName: c.ClientName,
LogoURI: p.parseURI(c.LogoURI, "logo_uri"),
ClientURI: p.parseURI(c.ClientURI, "client_uri"),
PolicyURI: p.parseURI(c.PolicyURI, "policy_uri"),
TermsOfServiceURI: p.parseURI(c.TermsOfServiceURI, "tos_uri"),
JWKSURI: p.parseURI(c.JWKSURI, "jwks_uri"),
JWKS: c.JWKS,
SectorIdentifierURI: p.parseURI(c.SectorIdentifierURI, "sector_identifier_uri"),
SubjectType: c.SubjectType,
TokenEndpointAuthMethod: c.TokenEndpointAuthMethod,
TokenEndpointAuthSigningAlg: c.TokenEndpointAuthSigningAlg,
DefaultMaxAge: c.DefaultMaxAge,
RequireAuthTime: c.RequireAuthTime,
DefaultACRValues: c.DefaultACRValues,
InitiateLoginURI: p.parseURI(c.InitiateLoginURI, "initiate_login_uri"),
RequestURIs: p.parseURIs(c.RequestURIs, "request_uris"),
IDTokenResponseOptions: JWAOptions{
c.IDTokenSignedResponseAlg,
c.IDTokenEncryptedResponseAlg,
c.IDTokenEncryptedResponseEnc,
},
UserInfoResponseOptions: JWAOptions{
c.UserInfoSignedResponseAlg,
c.UserInfoEncryptedResponseAlg,
c.UserInfoEncryptedResponseEnc,
},
RequestObjectOptions: JWAOptions{
c.RequestObjectSigningAlg,
c.RequestObjectEncryptionAlg,
c.RequestObjectEncryptionEnc,
},
}
if p.firstErr != nil {
return ClientMetadata{}, p.firstErr
}
return m, nil
}
// stickyErrParser parses URIs and email addresses. Once it encounters
// a parse error, subsequent calls become no-op.
type stickyErrParser struct {
firstErr error
}
func (p *stickyErrParser) parseURI(s, field string) *url.URL {
if p.firstErr != nil || s == "" {
return nil
}
u, err := url.Parse(s)
if err == nil {
if u.Host == "" {
err = errors.New("no host in URI")
} else if u.Scheme != "http" && u.Scheme != "https" {
err = errors.New("invalid URI scheme")
}
}
if err != nil {
p.firstErr = fmt.Errorf("failed to parse %s: %v", field, err)
return nil
}
return u
}
func (p *stickyErrParser) parseURIs(s []string, field string) []url.URL {
if p.firstErr != nil || len(s) == 0 {
return nil
}
uris := make([]url.URL, len(s))
for i, val := range s {
if val == "" {
p.firstErr = fmt.Errorf("invalid URI in field %s", field)
return nil
}
if u := p.parseURI(val, field); u != nil {
uris[i] = *u
}
}
return uris
}
func (p *stickyErrParser) parseEmails(s []string, field string) []mail.Address {
if p.firstErr != nil || len(s) == 0 {
return nil
}
addrs := make([]mail.Address, len(s))
for i, addr := range s {
if addr == "" {
p.firstErr = fmt.Errorf("invalid email in field %s", field)
return nil
}
a, err := mail.ParseAddress(addr)
if err != nil {
p.firstErr = fmt.Errorf("invalid email in field %s: %v", field, err)
return nil
}
addrs[i] = *a
}
return addrs
}
func (m *ClientMetadata) toEncodableStruct() encodableClientMetadata {
return encodableClientMetadata{
RedirectURIs: urisToStrings(m.RedirectURIs),
ResponseTypes: m.ResponseTypes,
GrantTypes: m.GrantTypes,
ApplicationType: m.ApplicationType,
Contacts: emailsToStrings(m.Contacts),
ClientName: m.ClientName,
LogoURI: uriToString(m.LogoURI),
ClientURI: uriToString(m.ClientURI),
PolicyURI: uriToString(m.PolicyURI),
TermsOfServiceURI: uriToString(m.TermsOfServiceURI),
JWKSURI: uriToString(m.JWKSURI),
JWKS: m.JWKS,
SectorIdentifierURI: uriToString(m.SectorIdentifierURI),
SubjectType: m.SubjectType,
IDTokenSignedResponseAlg: m.IDTokenResponseOptions.SigningAlg,
IDTokenEncryptedResponseAlg: m.IDTokenResponseOptions.EncryptionAlg,
IDTokenEncryptedResponseEnc: m.IDTokenResponseOptions.EncryptionEnc,
UserInfoSignedResponseAlg: m.UserInfoResponseOptions.SigningAlg,
UserInfoEncryptedResponseAlg: m.UserInfoResponseOptions.EncryptionAlg,
UserInfoEncryptedResponseEnc: m.UserInfoResponseOptions.EncryptionEnc,
RequestObjectSigningAlg: m.RequestObjectOptions.SigningAlg,
RequestObjectEncryptionAlg: m.RequestObjectOptions.EncryptionAlg,
RequestObjectEncryptionEnc: m.RequestObjectOptions.EncryptionEnc,
TokenEndpointAuthMethod: m.TokenEndpointAuthMethod,
TokenEndpointAuthSigningAlg: m.TokenEndpointAuthSigningAlg,
DefaultMaxAge: m.DefaultMaxAge,
RequireAuthTime: m.RequireAuthTime,
DefaultACRValues: m.DefaultACRValues,
InitiateLoginURI: uriToString(m.InitiateLoginURI),
RequestURIs: urisToStrings(m.RequestURIs),
}
}
func uriToString(u *url.URL) string {
if u == nil {
return ""
}
return u.String()
}
func urisToStrings(urls []url.URL) []string {
if len(urls) == 0 {
return nil
}
sli := make([]string, len(urls))
for i, u := range urls {
sli[i] = u.String()
}
return sli
}
func emailsToStrings(addrs []mail.Address) []string {
if len(addrs) == 0 {
return nil
}
sli := make([]string, len(addrs))
for i, addr := range addrs {
sli[i] = addr.String()
}
return sli
}
// Valid determines if a ClientMetadata conforms with the OIDC specification.
//
// Valid is called by UnmarshalJSON.
//
// NOTE(ericchiang): For development purposes Valid does not mandate 'https' for
// URLs fields where the OIDC spec requires it. This may change in future releases
// of this package. See: https://github.com/coreos/go-oidc/issues/34
func (m *ClientMetadata) Valid() error {
if len(m.RedirectURIs) == 0 {
return errors.New("zero redirect URLs")
}
validURI := func(u *url.URL, fieldName string) error {
if u.Host == "" {
return fmt.Errorf("no host for uri field %s", fieldName)
}
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("uri field %s scheme is not http or https", fieldName)
}
return nil
}
uris := []struct {
val *url.URL
name string
}{
{m.LogoURI, "logo_uri"},
{m.ClientURI, "client_uri"},
{m.PolicyURI, "policy_uri"},
{m.TermsOfServiceURI, "tos_uri"},
{m.JWKSURI, "jwks_uri"},
{m.SectorIdentifierURI, "sector_identifier_uri"},
{m.InitiateLoginURI, "initiate_login_uri"},
}
for _, uri := range uris {
if uri.val == nil {
continue
}
if err := validURI(uri.val, uri.name); err != nil {
return err
}
}
uriLists := []struct {
vals []url.URL
name string
}{
{m.RedirectURIs, "redirect_uris"},
{m.RequestURIs, "request_uris"},
}
for _, list := range uriLists {
for _, uri := range list.vals {
if err := validURI(&uri, list.name); err != nil {
return err
}
}
}
options := []struct {
option JWAOptions
name string
}{
{m.IDTokenResponseOptions, "id_token response"},
{m.UserInfoResponseOptions, "userinfo response"},
{m.RequestObjectOptions, "request_object"},
}
for _, option := range options {
if err := option.option.valid(); err != nil {
return fmt.Errorf("invalid JWA values for %s: %v", option.name, err)
}
}
return nil
}
type ClientRegistrationResponse struct {
ClientID string // Required
ClientSecret string
RegistrationAccessToken string
RegistrationClientURI string
// If IsZero is true, unspecified.
ClientIDIssuedAt time.Time
// Time at which the client_secret will expire.
// If IsZero is true, it will not expire.
ClientSecretExpiresAt time.Time
ClientMetadata
}
type encodableClientRegistrationResponse struct {
ClientID string `json:"client_id"` // Required
ClientSecret string `json:"client_secret,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
// Time at which the client_secret will expire, in seconds since the epoch.
// If 0 it will not expire.
ClientSecretExpiresAt int64 `json:"client_secret_expires_at"` // Required
encodableClientMetadata
}
func unixToSec(t time.Time) int64 {
if t.IsZero() {
return 0
}
return t.Unix()
}
func (c *ClientRegistrationResponse) MarshalJSON() ([]byte, error) {
e := encodableClientRegistrationResponse{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
RegistrationAccessToken: c.RegistrationAccessToken,
RegistrationClientURI: c.RegistrationClientURI,
ClientIDIssuedAt: unixToSec(c.ClientIDIssuedAt),
ClientSecretExpiresAt: unixToSec(c.ClientSecretExpiresAt),
encodableClientMetadata: c.ClientMetadata.toEncodableStruct(),
}
return json.Marshal(&e)
}
func secToUnix(sec int64) time.Time {
if sec == 0 {
return time.Time{}
}
return time.Unix(sec, 0)
}
func (c *ClientRegistrationResponse) UnmarshalJSON(data []byte) error {
var e encodableClientRegistrationResponse
if err := json.Unmarshal(data, &e); err != nil {
return err
}
if e.ClientID == "" {
return errors.New("no client_id in client registration response")
}
metadata, err := e.encodableClientMetadata.toStruct()
if err != nil {
return err
}
*c = ClientRegistrationResponse{
ClientID: e.ClientID,
ClientSecret: e.ClientSecret,
RegistrationAccessToken: e.RegistrationAccessToken,
RegistrationClientURI: e.RegistrationClientURI,
ClientIDIssuedAt: secToUnix(e.ClientIDIssuedAt),
ClientSecretExpiresAt: secToUnix(e.ClientSecretExpiresAt),
ClientMetadata: metadata,
}
return nil
}
type ClientConfig struct {
HTTPClient phttp.Client
Credentials ClientCredentials
Scope []string
RedirectURL string
ProviderConfig ProviderConfig
KeySet key.PublicKeySet
}
func NewClient(cfg ClientConfig) (*Client, error) {
// Allow empty redirect URL in the case where the client
// only needs to verify a given token.
ru, err := url.Parse(cfg.RedirectURL)
if err != nil {
return nil, fmt.Errorf("invalid redirect URL: %v", err)
}
c := Client{
credentials: cfg.Credentials,
httpClient: cfg.HTTPClient,
scope: cfg.Scope,
redirectURL: ru.String(),
providerConfig: newProviderConfigRepo(cfg.ProviderConfig),
keySet: cfg.KeySet,
}
if c.httpClient == nil {
c.httpClient = http.DefaultClient
}
if c.scope == nil {
c.scope = make([]string, len(DefaultScope))
copy(c.scope, DefaultScope)
}
return &c, nil
}
type Client struct {
httpClient phttp.Client
providerConfig *providerConfigRepo
credentials ClientCredentials
redirectURL string
scope []string
keySet key.PublicKeySet
providerSyncer *ProviderConfigSyncer
keySetSyncMutex sync.RWMutex
lastKeySetSync time.Time
}
func (c *Client) Healthy() error {
now := time.Now().UTC()
cfg := c.providerConfig.Get()
if cfg.Empty() {
return errors.New("oidc client provider config empty")
}
if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) {
return errors.New("oidc client provider config expired")
}
return nil
}
func (c *Client) OAuthClient() (*oauth2.Client, error) {
cfg := c.providerConfig.Get()
authMethod, err := chooseAuthMethod(cfg)
if err != nil {
return nil, err
}
ocfg := oauth2.Config{
Credentials: oauth2.ClientCredentials(c.credentials),
RedirectURL: c.redirectURL,
AuthURL: cfg.AuthEndpoint.String(),
TokenURL: cfg.TokenEndpoint.String(),
Scope: c.scope,
AuthMethod: authMethod,
}
return oauth2.NewClient(c.httpClient, ocfg)
}
func chooseAuthMethod(cfg ProviderConfig) (string, error) {
if len(cfg.TokenEndpointAuthMethodsSupported) == 0 {
return oauth2.AuthMethodClientSecretBasic, nil
}
for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported {
if _, ok := supportedAuthMethods[authMethod]; ok {
return authMethod, nil
}
}
return "", errors.New("no supported auth methods")
}
// SyncProviderConfig starts the provider config syncer
func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} {
r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL)
s := NewProviderConfigSyncer(r, c.providerConfig)
stop := s.Run()
s.WaitUntilInitialSync()
return stop
}
func (c *Client) maybeSyncKeys() error {
tooSoon := func() bool {
return time.Now().UTC().Before(c.lastKeySetSync.Add(keySyncWindow))
}
// ignore request to sync keys if a sync operation has been
// attempted too recently
if tooSoon() {
return nil
}
c.keySetSyncMutex.Lock()
defer c.keySetSyncMutex.Unlock()
// check again, as another goroutine may have been holding
// the lock while updating the keys
if tooSoon() {
return nil
}
cfg := c.providerConfig.Get()
r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint.String())
w := &clientKeyRepo{client: c}
_, err := key.Sync(r, w)
c.lastKeySetSync = time.Now().UTC()
return err
}
type clientKeyRepo struct {
client *Client
}
func (r *clientKeyRepo) Set(ks key.KeySet) error {
pks, ok := ks.(*key.PublicKeySet)
if !ok {
return errors.New("unable to cast to PublicKey")
}
r.client.keySet = *pks
return nil
}
func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) {
cfg := c.providerConfig.Get()
if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) {
return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds)
}
oac, err := c.OAuthClient()
if err != nil {
return jose.JWT{}, err
}
t, err := oac.ClientCredsToken(scope)
if err != nil {
return jose.JWT{}, err
}
jwt, err := jose.ParseJWT(t.IDToken)
if err != nil {
return jose.JWT{}, err
}
return jwt, c.VerifyJWT(jwt)
}
// ExchangeAuthCode exchanges an OAuth2 auth code for an OIDC JWT ID token.
func (c *Client) ExchangeAuthCode(code string) (jose.JWT, error) {
oac, err := c.OAuthClient()
if err != nil {
return jose.JWT{}, err
}
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
if err != nil {
return jose.JWT{}, err
}
jwt, err := jose.ParseJWT(t.IDToken)
if err != nil {
return jose.JWT{}, err
}
return jwt, c.VerifyJWT(jwt)
}
// RefreshToken uses a refresh token to exchange for a new OIDC JWT ID Token.
func (c *Client) RefreshToken(refreshToken string) (jose.JWT, error) {
oac, err := c.OAuthClient()
if err != nil {
return jose.JWT{}, err
}
t, err := oac.RequestToken(oauth2.GrantTypeRefreshToken, refreshToken)
if err != nil {
return jose.JWT{}, err
}
jwt, err := jose.ParseJWT(t.IDToken)
if err != nil {
return jose.JWT{}, err
}
return jwt, c.VerifyJWT(jwt)
}
func (c *Client) VerifyJWT(jwt jose.JWT) error {
var keysFunc func() []key.PublicKey
if kID, ok := jwt.KeyID(); ok {
keysFunc = c.keysFuncWithID(kID)
} else {
keysFunc = c.keysFuncAll()
}
v := NewJWTVerifier(
c.providerConfig.Get().Issuer.String(),
c.credentials.ID,
c.maybeSyncKeys, keysFunc)
return v.Verify(jwt)
}
// keysFuncWithID returns a function that retrieves at most unexpired
// public key from the Client that matches the provided ID
func (c *Client) keysFuncWithID(kID string) func() []key.PublicKey {
return func() []key.PublicKey {
c.keySetSyncMutex.RLock()
defer c.keySetSyncMutex.RUnlock()
if c.keySet.ExpiresAt().Before(time.Now()) {
return []key.PublicKey{}
}
k := c.keySet.Key(kID)
if k == nil {
return []key.PublicKey{}
}
return []key.PublicKey{*k}
}
}
// keysFuncAll returns a function that retrieves all unexpired public
// keys from the Client
func (c *Client) keysFuncAll() func() []key.PublicKey {
return func() []key.PublicKey {
c.keySetSyncMutex.RLock()
defer c.keySetSyncMutex.RUnlock()
if c.keySet.ExpiresAt().Before(time.Now()) {
return []key.PublicKey{}
}
return c.keySet.Keys()
}
}
type providerConfigRepo struct {
mu sync.RWMutex
config ProviderConfig // do not access directly, use Get()
}
func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo {
return &providerConfigRepo{sync.RWMutex{}, pc}
}
// returns an error to implement ProviderConfigSetter
func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
r.mu.Lock()
defer r.mu.Unlock()
r.config = cfg
return nil
}
func (r *providerConfigRepo) Get() ProviderConfig {
r.mu.RLock()
defer r.mu.RUnlock()
return r.config
}

View file

@ -1,44 +0,0 @@
package oidc
import (
"errors"
"time"
"github.com/coreos/go-oidc/jose"
)
type Identity struct {
ID string
Name string
Email string
ExpiresAt time.Time
}
func IdentityFromClaims(claims jose.Claims) (*Identity, error) {
if claims == nil {
return nil, errors.New("nil claim set")
}
var ident Identity
var err error
var ok bool
if ident.ID, ok, err = claims.StringClaim("sub"); err != nil {
return nil, err
} else if !ok {
return nil, errors.New("missing required claim: sub")
}
if ident.Email, _, err = claims.StringClaim("email"); err != nil {
return nil, err
}
exp, ok, err := claims.TimeClaim("exp")
if err != nil {
return nil, err
} else if ok {
ident.ExpiresAt = exp
}
return &ident, nil
}

View file

@ -1,3 +0,0 @@
package oidc
type LoginFunc func(ident Identity, sessionKey string) (redirectURL string, err error)

View file

@ -1,67 +0,0 @@
package oidc
import (
"encoding/json"
"errors"
"net/http"
"time"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
)
// DefaultPublicKeySetTTL is the default TTL set on the PublicKeySet if no
// Cache-Control header is provided by the JWK Set document endpoint.
const DefaultPublicKeySetTTL = 24 * time.Hour
// NewRemotePublicKeyRepo is responsible for fetching the JWK Set document.
func NewRemotePublicKeyRepo(hc phttp.Client, ep string) *remotePublicKeyRepo {
return &remotePublicKeyRepo{hc: hc, ep: ep}
}
type remotePublicKeyRepo struct {
hc phttp.Client
ep string
}
// Get returns a PublicKeySet fetched from the JWK Set document endpoint. A TTL
// is set on the Key Set to avoid it having to be re-retrieved for every
// encryption event. This TTL is typically controlled by the endpoint returning
// a Cache-Control header, but defaults to 24 hours if no Cache-Control header
// is found.
func (r *remotePublicKeyRepo) Get() (key.KeySet, error) {
req, err := http.NewRequest("GET", r.ep, nil)
if err != nil {
return nil, err
}
resp, err := r.hc.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var d struct {
Keys []jose.JWK `json:"keys"`
}
if err := json.NewDecoder(resp.Body).Decode(&d); err != nil {
return nil, err
}
if len(d.Keys) == 0 {
return nil, errors.New("zero keys in response")
}
ttl, ok, err := phttp.Cacheable(resp.Header)
if err != nil {
return nil, err
}
if !ok {
ttl = DefaultPublicKeySetTTL
}
exp := time.Now().UTC().Add(ttl)
ks := key.NewPublicKeySet(d.Keys, exp)
return ks, nil
}

View file

@ -1,690 +0,0 @@
package oidc
import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/oauth2"
)
const (
// Subject Identifier types defined by the OIDC spec. Specifies if the provider
// should provide the same sub claim value to all clients (public) or a unique
// value for each client (pairwise).
//
// See: http://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
SubjectTypePublic = "public"
SubjectTypePairwise = "pairwise"
)
var (
// Default values for omitted provider config fields.
//
// Use ProviderConfig's Defaults method to fill a provider config with these values.
DefaultGrantTypesSupported = []string{oauth2.GrantTypeAuthCode, oauth2.GrantTypeImplicit}
DefaultResponseModesSupported = []string{"query", "fragment"}
DefaultTokenEndpointAuthMethodsSupported = []string{oauth2.AuthMethodClientSecretBasic}
DefaultClaimTypesSupported = []string{"normal"}
)
const (
MaximumProviderConfigSyncInterval = 24 * time.Hour
MinimumProviderConfigSyncInterval = time.Minute
discoveryConfigPath = "/.well-known/openid-configuration"
)
// internally configurable for tests
var minimumProviderConfigSyncInterval = MinimumProviderConfigSyncInterval
var (
// Ensure ProviderConfig satisfies these interfaces.
_ json.Marshaler = &ProviderConfig{}
_ json.Unmarshaler = &ProviderConfig{}
)
// ProviderConfig represents the OpenID Provider Metadata specifying what
// configurations a provider supports.
//
// See: http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
type ProviderConfig struct {
Issuer *url.URL // Required
AuthEndpoint *url.URL // Required
TokenEndpoint *url.URL // Required if grant types other than "implicit" are supported
UserInfoEndpoint *url.URL
KeysEndpoint *url.URL // Required
RegistrationEndpoint *url.URL
EndSessionEndpoint *url.URL
CheckSessionIFrame *url.URL
// Servers MAY choose not to advertise some supported scope values even when this
// parameter is used, although those defined in OpenID Core SHOULD be listed, if supported.
ScopesSupported []string
// OAuth2.0 response types supported.
ResponseTypesSupported []string // Required
// OAuth2.0 response modes supported.
//
// If omitted, defaults to DefaultResponseModesSupported.
ResponseModesSupported []string
// OAuth2.0 grant types supported.
//
// If omitted, defaults to DefaultGrantTypesSupported.
GrantTypesSupported []string
ACRValuesSupported []string
// SubjectTypesSupported specifies strategies for providing values for the sub claim.
SubjectTypesSupported []string // Required
// JWA signing and encryption algorith values supported for ID tokens.
IDTokenSigningAlgValues []string // Required
IDTokenEncryptionAlgValues []string
IDTokenEncryptionEncValues []string
// JWA signing and encryption algorith values supported for user info responses.
UserInfoSigningAlgValues []string
UserInfoEncryptionAlgValues []string
UserInfoEncryptionEncValues []string
// JWA signing and encryption algorith values supported for request objects.
ReqObjSigningAlgValues []string
ReqObjEncryptionAlgValues []string
ReqObjEncryptionEncValues []string
TokenEndpointAuthMethodsSupported []string
TokenEndpointAuthSigningAlgValuesSupported []string
DisplayValuesSupported []string
ClaimTypesSupported []string
ClaimsSupported []string
ServiceDocs *url.URL
ClaimsLocalsSupported []string
UILocalsSupported []string
ClaimsParameterSupported bool
RequestParameterSupported bool
RequestURIParamaterSupported bool
RequireRequestURIRegistration bool
Policy *url.URL
TermsOfService *url.URL
// Not part of the OpenID Provider Metadata
ExpiresAt time.Time
}
// Defaults returns a shallow copy of ProviderConfig with default
// values replacing omitted fields.
//
// var cfg oidc.ProviderConfig
// // Fill provider config with default values for omitted fields.
// cfg = cfg.Defaults()
//
func (p ProviderConfig) Defaults() ProviderConfig {
setDefault := func(val *[]string, defaultVal []string) {
if len(*val) == 0 {
*val = defaultVal
}
}
setDefault(&p.GrantTypesSupported, DefaultGrantTypesSupported)
setDefault(&p.ResponseModesSupported, DefaultResponseModesSupported)
setDefault(&p.TokenEndpointAuthMethodsSupported, DefaultTokenEndpointAuthMethodsSupported)
setDefault(&p.ClaimTypesSupported, DefaultClaimTypesSupported)
return p
}
func (p *ProviderConfig) MarshalJSON() ([]byte, error) {
e := p.toEncodableStruct()
return json.Marshal(&e)
}
func (p *ProviderConfig) UnmarshalJSON(data []byte) error {
var e encodableProviderConfig
if err := json.Unmarshal(data, &e); err != nil {
return err
}
conf, err := e.toStruct()
if err != nil {
return err
}
if err := conf.Valid(); err != nil {
return err
}
*p = conf
return nil
}
type encodableProviderConfig struct {
Issuer string `json:"issuer"`
AuthEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
KeysEndpoint string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
CheckSessionIFrame string `json:"check_session_iframe,omitempty"`
// Use 'omitempty' for all slices as per OIDC spec:
// "Claims that return multiple values are represented as JSON arrays.
// Claims with zero elements MUST be omitted from the response."
// http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
ScopesSupported []string `json:"scopes_supported,omitempty"`
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
ACRValuesSupported []string `json:"acr_values_supported,omitempty"`
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
IDTokenSigningAlgValues []string `json:"id_token_signing_alg_values_supported,omitempty"`
IDTokenEncryptionAlgValues []string `json:"id_token_encryption_alg_values_supported,omitempty"`
IDTokenEncryptionEncValues []string `json:"id_token_encryption_enc_values_supported,omitempty"`
UserInfoSigningAlgValues []string `json:"userinfo_signing_alg_values_supported,omitempty"`
UserInfoEncryptionAlgValues []string `json:"userinfo_encryption_alg_values_supported,omitempty"`
UserInfoEncryptionEncValues []string `json:"userinfo_encryption_enc_values_supported,omitempty"`
ReqObjSigningAlgValues []string `json:"request_object_signing_alg_values_supported,omitempty"`
ReqObjEncryptionAlgValues []string `json:"request_object_encryption_alg_values_supported,omitempty"`
ReqObjEncryptionEncValues []string `json:"request_object_encryption_enc_values_supported,omitempty"`
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"`
DisplayValuesSupported []string `json:"display_values_supported,omitempty"`
ClaimTypesSupported []string `json:"claim_types_supported,omitempty"`
ClaimsSupported []string `json:"claims_supported,omitempty"`
ServiceDocs string `json:"service_documentation,omitempty"`
ClaimsLocalsSupported []string `json:"claims_locales_supported,omitempty"`
UILocalsSupported []string `json:"ui_locales_supported,omitempty"`
ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"`
RequestParameterSupported bool `json:"request_parameter_supported,omitempty"`
RequestURIParamaterSupported bool `json:"request_uri_parameter_supported,omitempty"`
RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"`
Policy string `json:"op_policy_uri,omitempty"`
TermsOfService string `json:"op_tos_uri,omitempty"`
}
func (cfg ProviderConfig) toEncodableStruct() encodableProviderConfig {
return encodableProviderConfig{
Issuer: uriToString(cfg.Issuer),
AuthEndpoint: uriToString(cfg.AuthEndpoint),
TokenEndpoint: uriToString(cfg.TokenEndpoint),
UserInfoEndpoint: uriToString(cfg.UserInfoEndpoint),
KeysEndpoint: uriToString(cfg.KeysEndpoint),
RegistrationEndpoint: uriToString(cfg.RegistrationEndpoint),
EndSessionEndpoint: uriToString(cfg.EndSessionEndpoint),
CheckSessionIFrame: uriToString(cfg.CheckSessionIFrame),
ScopesSupported: cfg.ScopesSupported,
ResponseTypesSupported: cfg.ResponseTypesSupported,
ResponseModesSupported: cfg.ResponseModesSupported,
GrantTypesSupported: cfg.GrantTypesSupported,
ACRValuesSupported: cfg.ACRValuesSupported,
SubjectTypesSupported: cfg.SubjectTypesSupported,
IDTokenSigningAlgValues: cfg.IDTokenSigningAlgValues,
IDTokenEncryptionAlgValues: cfg.IDTokenEncryptionAlgValues,
IDTokenEncryptionEncValues: cfg.IDTokenEncryptionEncValues,
UserInfoSigningAlgValues: cfg.UserInfoSigningAlgValues,
UserInfoEncryptionAlgValues: cfg.UserInfoEncryptionAlgValues,
UserInfoEncryptionEncValues: cfg.UserInfoEncryptionEncValues,
ReqObjSigningAlgValues: cfg.ReqObjSigningAlgValues,
ReqObjEncryptionAlgValues: cfg.ReqObjEncryptionAlgValues,
ReqObjEncryptionEncValues: cfg.ReqObjEncryptionEncValues,
TokenEndpointAuthMethodsSupported: cfg.TokenEndpointAuthMethodsSupported,
TokenEndpointAuthSigningAlgValuesSupported: cfg.TokenEndpointAuthSigningAlgValuesSupported,
DisplayValuesSupported: cfg.DisplayValuesSupported,
ClaimTypesSupported: cfg.ClaimTypesSupported,
ClaimsSupported: cfg.ClaimsSupported,
ServiceDocs: uriToString(cfg.ServiceDocs),
ClaimsLocalsSupported: cfg.ClaimsLocalsSupported,
UILocalsSupported: cfg.UILocalsSupported,
ClaimsParameterSupported: cfg.ClaimsParameterSupported,
RequestParameterSupported: cfg.RequestParameterSupported,
RequestURIParamaterSupported: cfg.RequestURIParamaterSupported,
RequireRequestURIRegistration: cfg.RequireRequestURIRegistration,
Policy: uriToString(cfg.Policy),
TermsOfService: uriToString(cfg.TermsOfService),
}
}
func (e encodableProviderConfig) toStruct() (ProviderConfig, error) {
p := stickyErrParser{}
conf := ProviderConfig{
Issuer: p.parseURI(e.Issuer, "issuer"),
AuthEndpoint: p.parseURI(e.AuthEndpoint, "authorization_endpoint"),
TokenEndpoint: p.parseURI(e.TokenEndpoint, "token_endpoint"),
UserInfoEndpoint: p.parseURI(e.UserInfoEndpoint, "userinfo_endpoint"),
KeysEndpoint: p.parseURI(e.KeysEndpoint, "jwks_uri"),
RegistrationEndpoint: p.parseURI(e.RegistrationEndpoint, "registration_endpoint"),
EndSessionEndpoint: p.parseURI(e.EndSessionEndpoint, "end_session_endpoint"),
CheckSessionIFrame: p.parseURI(e.CheckSessionIFrame, "check_session_iframe"),
ScopesSupported: e.ScopesSupported,
ResponseTypesSupported: e.ResponseTypesSupported,
ResponseModesSupported: e.ResponseModesSupported,
GrantTypesSupported: e.GrantTypesSupported,
ACRValuesSupported: e.ACRValuesSupported,
SubjectTypesSupported: e.SubjectTypesSupported,
IDTokenSigningAlgValues: e.IDTokenSigningAlgValues,
IDTokenEncryptionAlgValues: e.IDTokenEncryptionAlgValues,
IDTokenEncryptionEncValues: e.IDTokenEncryptionEncValues,
UserInfoSigningAlgValues: e.UserInfoSigningAlgValues,
UserInfoEncryptionAlgValues: e.UserInfoEncryptionAlgValues,
UserInfoEncryptionEncValues: e.UserInfoEncryptionEncValues,
ReqObjSigningAlgValues: e.ReqObjSigningAlgValues,
ReqObjEncryptionAlgValues: e.ReqObjEncryptionAlgValues,
ReqObjEncryptionEncValues: e.ReqObjEncryptionEncValues,
TokenEndpointAuthMethodsSupported: e.TokenEndpointAuthMethodsSupported,
TokenEndpointAuthSigningAlgValuesSupported: e.TokenEndpointAuthSigningAlgValuesSupported,
DisplayValuesSupported: e.DisplayValuesSupported,
ClaimTypesSupported: e.ClaimTypesSupported,
ClaimsSupported: e.ClaimsSupported,
ServiceDocs: p.parseURI(e.ServiceDocs, "service_documentation"),
ClaimsLocalsSupported: e.ClaimsLocalsSupported,
UILocalsSupported: e.UILocalsSupported,
ClaimsParameterSupported: e.ClaimsParameterSupported,
RequestParameterSupported: e.RequestParameterSupported,
RequestURIParamaterSupported: e.RequestURIParamaterSupported,
RequireRequestURIRegistration: e.RequireRequestURIRegistration,
Policy: p.parseURI(e.Policy, "op_policy-uri"),
TermsOfService: p.parseURI(e.TermsOfService, "op_tos_uri"),
}
if p.firstErr != nil {
return ProviderConfig{}, p.firstErr
}
return conf, nil
}
// Empty returns if a ProviderConfig holds no information.
//
// This case generally indicates a ProviderConfigGetter has experienced an error
// and has nothing to report.
func (p ProviderConfig) Empty() bool {
return p.Issuer == nil
}
func contains(sli []string, ele string) bool {
for _, s := range sli {
if s == ele {
return true
}
}
return false
}
// Valid determines if a ProviderConfig conforms with the OIDC specification.
// If Valid returns successfully it guarantees required field are non-nil and
// URLs are well formed.
//
// Valid is called by UnmarshalJSON.
//
// NOTE(ericchiang): For development purposes Valid does not mandate 'https' for
// URLs fields where the OIDC spec requires it. This may change in future releases
// of this package. See: https://github.com/coreos/go-oidc/issues/34
func (p ProviderConfig) Valid() error {
grantTypes := p.GrantTypesSupported
if len(grantTypes) == 0 {
grantTypes = DefaultGrantTypesSupported
}
implicitOnly := true
for _, grantType := range grantTypes {
if grantType != oauth2.GrantTypeImplicit {
implicitOnly = false
break
}
}
if len(p.SubjectTypesSupported) == 0 {
return errors.New("missing required field subject_types_supported")
}
if len(p.IDTokenSigningAlgValues) == 0 {
return errors.New("missing required field id_token_signing_alg_values_supported")
}
if len(p.ScopesSupported) != 0 && !contains(p.ScopesSupported, "openid") {
return errors.New("scoped_supported must be unspecified or include 'openid'")
}
if !contains(p.IDTokenSigningAlgValues, "RS256") {
return errors.New("id_token_signing_alg_values_supported must include 'RS256'")
}
if contains(p.TokenEndpointAuthMethodsSupported, "none") {
return errors.New("token_endpoint_auth_signing_alg_values_supported cannot include 'none'")
}
uris := []struct {
val *url.URL
name string
required bool
}{
{p.Issuer, "issuer", true},
{p.AuthEndpoint, "authorization_endpoint", true},
{p.TokenEndpoint, "token_endpoint", !implicitOnly},
{p.UserInfoEndpoint, "userinfo_endpoint", false},
{p.KeysEndpoint, "jwks_uri", true},
{p.RegistrationEndpoint, "registration_endpoint", false},
{p.EndSessionEndpoint, "end_session_endpoint", false},
{p.CheckSessionIFrame, "check_session_iframe", false},
{p.ServiceDocs, "service_documentation", false},
{p.Policy, "op_policy_uri", false},
{p.TermsOfService, "op_tos_uri", false},
}
for _, uri := range uris {
if uri.val == nil {
if !uri.required {
continue
}
return fmt.Errorf("empty value for required uri field %s", uri.name)
}
if uri.val.Host == "" {
return fmt.Errorf("no host for uri field %s", uri.name)
}
if uri.val.Scheme != "http" && uri.val.Scheme != "https" {
return fmt.Errorf("uri field %s schemeis not http or https", uri.name)
}
}
return nil
}
// Supports determines if provider supports a client given their respective metadata.
func (p ProviderConfig) Supports(c ClientMetadata) error {
if err := p.Valid(); err != nil {
return fmt.Errorf("invalid provider config: %v", err)
}
if err := c.Valid(); err != nil {
return fmt.Errorf("invalid client config: %v", err)
}
// Fill default values for omitted fields
c = c.Defaults()
p = p.Defaults()
// Do the supported values list the requested one?
supports := []struct {
supported []string
requested string
name string
}{
{p.IDTokenSigningAlgValues, c.IDTokenResponseOptions.SigningAlg, "id_token_signed_response_alg"},
{p.IDTokenEncryptionAlgValues, c.IDTokenResponseOptions.EncryptionAlg, "id_token_encryption_response_alg"},
{p.IDTokenEncryptionEncValues, c.IDTokenResponseOptions.EncryptionEnc, "id_token_encryption_response_enc"},
{p.UserInfoSigningAlgValues, c.UserInfoResponseOptions.SigningAlg, "userinfo_signed_response_alg"},
{p.UserInfoEncryptionAlgValues, c.UserInfoResponseOptions.EncryptionAlg, "userinfo_encryption_response_alg"},
{p.UserInfoEncryptionEncValues, c.UserInfoResponseOptions.EncryptionEnc, "userinfo_encryption_response_enc"},
{p.ReqObjSigningAlgValues, c.RequestObjectOptions.SigningAlg, "request_object_signing_alg"},
{p.ReqObjEncryptionAlgValues, c.RequestObjectOptions.EncryptionAlg, "request_object_encryption_alg"},
{p.ReqObjEncryptionEncValues, c.RequestObjectOptions.EncryptionEnc, "request_object_encryption_enc"},
}
for _, field := range supports {
if field.requested == "" {
continue
}
if !contains(field.supported, field.requested) {
return fmt.Errorf("provider does not support requested value for field %s", field.name)
}
}
stringsEqual := func(s1, s2 string) bool { return s1 == s2 }
// For lists, are the list of requested values a subset of the supported ones?
supportsAll := []struct {
supported []string
requested []string
name string
// OAuth2.0 response_type can be space separated lists where order doesn't matter.
// For example "id_token token" is the same as "token id_token"
// Support a custom compare method.
comp func(s1, s2 string) bool
}{
{p.GrantTypesSupported, c.GrantTypes, "grant_types", stringsEqual},
{p.ResponseTypesSupported, c.ResponseTypes, "response_type", oauth2.ResponseTypesEqual},
}
for _, field := range supportsAll {
requestLoop:
for _, req := range field.requested {
for _, sup := range field.supported {
if field.comp(req, sup) {
continue requestLoop
}
}
return fmt.Errorf("provider does not support requested value for field %s", field.name)
}
}
// TODO(ericchiang): Are there more checks we feel comfortable with begin strict about?
return nil
}
func (p ProviderConfig) SupportsGrantType(grantType string) bool {
var supported []string
if len(p.GrantTypesSupported) == 0 {
supported = DefaultGrantTypesSupported
} else {
supported = p.GrantTypesSupported
}
for _, t := range supported {
if t == grantType {
return true
}
}
return false
}
type ProviderConfigGetter interface {
Get() (ProviderConfig, error)
}
type ProviderConfigSetter interface {
Set(ProviderConfig) error
}
type ProviderConfigSyncer struct {
from ProviderConfigGetter
to ProviderConfigSetter
clock clockwork.Clock
initialSyncDone bool
initialSyncWait sync.WaitGroup
}
func NewProviderConfigSyncer(from ProviderConfigGetter, to ProviderConfigSetter) *ProviderConfigSyncer {
return &ProviderConfigSyncer{
from: from,
to: to,
clock: clockwork.NewRealClock(),
}
}
func (s *ProviderConfigSyncer) Run() chan struct{} {
stop := make(chan struct{})
var next pcsStepper
next = &pcsStepNext{aft: time.Duration(0)}
s.initialSyncWait.Add(1)
go func() {
for {
select {
case <-s.clock.After(next.after()):
next = next.step(s.sync)
case <-stop:
return
}
}
}()
return stop
}
func (s *ProviderConfigSyncer) WaitUntilInitialSync() {
s.initialSyncWait.Wait()
}
func (s *ProviderConfigSyncer) sync() (time.Duration, error) {
cfg, err := s.from.Get()
if err != nil {
return 0, err
}
if err = s.to.Set(cfg); err != nil {
return 0, fmt.Errorf("error setting provider config: %v", err)
}
if !s.initialSyncDone {
s.initialSyncWait.Done()
s.initialSyncDone = true
}
return nextSyncAfter(cfg.ExpiresAt, s.clock), nil
}
type pcsStepFunc func() (time.Duration, error)
type pcsStepper interface {
after() time.Duration
step(pcsStepFunc) pcsStepper
}
type pcsStepNext struct {
aft time.Duration
}
func (n *pcsStepNext) after() time.Duration {
return n.aft
}
func (n *pcsStepNext) step(fn pcsStepFunc) (next pcsStepper) {
ttl, err := fn()
if err == nil {
next = &pcsStepNext{aft: ttl}
} else {
next = &pcsStepRetry{aft: time.Second}
log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
}
return
}
type pcsStepRetry struct {
aft time.Duration
}
func (r *pcsStepRetry) after() time.Duration {
return r.aft
}
func (r *pcsStepRetry) step(fn pcsStepFunc) (next pcsStepper) {
ttl, err := fn()
if err == nil {
next = &pcsStepNext{aft: ttl}
} else {
next = &pcsStepRetry{aft: timeutil.ExpBackoff(r.aft, time.Minute)}
log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
}
return
}
func nextSyncAfter(exp time.Time, clock clockwork.Clock) time.Duration {
if exp.IsZero() {
return MaximumProviderConfigSyncInterval
}
t := exp.Sub(clock.Now()) / 2
if t > MaximumProviderConfigSyncInterval {
t = MaximumProviderConfigSyncInterval
} else if t < minimumProviderConfigSyncInterval {
t = minimumProviderConfigSyncInterval
}
return t
}
type httpProviderConfigGetter struct {
hc phttp.Client
issuerURL string
clock clockwork.Clock
}
func NewHTTPProviderConfigGetter(hc phttp.Client, issuerURL string) *httpProviderConfigGetter {
return &httpProviderConfigGetter{
hc: hc,
issuerURL: issuerURL,
clock: clockwork.NewRealClock(),
}
}
func (r *httpProviderConfigGetter) Get() (cfg ProviderConfig, err error) {
// If the Issuer value contains a path component, any terminating / MUST be removed before
// appending /.well-known/openid-configuration.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest
discoveryURL := strings.TrimSuffix(r.issuerURL, "/") + discoveryConfigPath
req, err := http.NewRequest("GET", discoveryURL, nil)
if err != nil {
return
}
resp, err := r.hc.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
if err = json.NewDecoder(resp.Body).Decode(&cfg); err != nil {
return
}
var ttl time.Duration
var ok bool
ttl, ok, err = phttp.Cacheable(resp.Header)
if err != nil {
return
} else if ok {
cfg.ExpiresAt = r.clock.Now().UTC().Add(ttl)
}
// The issuer value returned MUST be identical to the Issuer URL that was directly used to retrieve the configuration information.
// http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationValidation
if !urlEqual(cfg.Issuer.String(), r.issuerURL) {
err = fmt.Errorf(`"issuer" in config (%v) does not match provided issuer URL (%v)`, cfg.Issuer, r.issuerURL)
return
}
return
}
func FetchProviderConfig(hc phttp.Client, issuerURL string) (ProviderConfig, error) {
if hc == nil {
hc = http.DefaultClient
}
g := NewHTTPProviderConfigGetter(hc, issuerURL)
return g.Get()
}
func WaitForProviderConfig(hc phttp.Client, issuerURL string) (pcfg ProviderConfig) {
return waitForProviderConfig(hc, issuerURL, clockwork.NewRealClock())
}
func waitForProviderConfig(hc phttp.Client, issuerURL string, clock clockwork.Clock) (pcfg ProviderConfig) {
var sleep time.Duration
var err error
for {
pcfg, err = FetchProviderConfig(hc, issuerURL)
if err == nil {
break
}
sleep = timeutil.ExpBackoff(sleep, time.Minute)
fmt.Printf("Failed fetching provider config, trying again in %v: %v\n", sleep, err)
time.Sleep(sleep)
}
return
}

View file

@ -1,88 +0,0 @@
package oidc
import (
"fmt"
"net/http"
"sync"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose"
)
type TokenRefresher interface {
// Verify checks if the provided token is currently valid or not.
Verify(jose.JWT) error
// Refresh attempts to authenticate and retrieve a new token.
Refresh() (jose.JWT, error)
}
type ClientCredsTokenRefresher struct {
Issuer string
OIDCClient *Client
}
func (c *ClientCredsTokenRefresher) Verify(jwt jose.JWT) (err error) {
_, err = VerifyClientClaims(jwt, c.Issuer)
return
}
func (c *ClientCredsTokenRefresher) Refresh() (jwt jose.JWT, err error) {
if err = c.OIDCClient.Healthy(); err != nil {
err = fmt.Errorf("unable to authenticate, unhealthy OIDC client: %v", err)
return
}
jwt, err = c.OIDCClient.ClientCredsToken([]string{"openid"})
if err != nil {
err = fmt.Errorf("unable to verify auth code with issuer: %v", err)
return
}
return
}
type AuthenticatedTransport struct {
TokenRefresher
http.RoundTripper
mu sync.Mutex
jwt jose.JWT
}
func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.TokenRefresher.Verify(t.jwt) == nil {
return t.jwt, nil
}
jwt, err := t.TokenRefresher.Refresh()
if err != nil {
return jose.JWT{}, fmt.Errorf("unable to acquire valid JWT: %v", err)
}
t.jwt = jwt
return t.jwt, nil
}
// SetJWT sets the JWT held by the Transport.
// This is useful for cases in which you want to set an initial JWT.
func (t *AuthenticatedTransport) SetJWT(jwt jose.JWT) {
t.mu.Lock()
defer t.mu.Unlock()
t.jwt = jwt
}
func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) {
jwt, err := t.verifiedJWT()
if err != nil {
return nil, err
}
req := phttp.CopyRequest(r)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt.Encode()))
return t.RoundTripper.RoundTrip(req)
}

View file

@ -1,109 +0,0 @@
package oidc
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/coreos/go-oidc/jose"
)
// RequestTokenExtractor funcs extract a raw encoded token from a request.
type RequestTokenExtractor func(r *http.Request) (string, error)
// ExtractBearerToken is a RequestTokenExtractor which extracts a bearer token from a request's
// Authorization header.
func ExtractBearerToken(r *http.Request) (string, error) {
ah := r.Header.Get("Authorization")
if ah == "" {
return "", errors.New("missing Authorization header")
}
if len(ah) <= 6 || strings.ToUpper(ah[0:6]) != "BEARER" {
return "", errors.New("should be a bearer token")
}
val := ah[7:]
if len(val) == 0 {
return "", errors.New("bearer token is empty")
}
return val, nil
}
// CookieTokenExtractor returns a RequestTokenExtractor which extracts a token from the named cookie in a request.
func CookieTokenExtractor(cookieName string) RequestTokenExtractor {
return func(r *http.Request) (string, error) {
ck, err := r.Cookie(cookieName)
if err != nil {
return "", fmt.Errorf("token cookie not found in request: %v", err)
}
if ck.Value == "" {
return "", errors.New("token cookie found but is empty")
}
return ck.Value, nil
}
}
func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims {
return jose.Claims{
// required
"iss": iss,
"sub": sub,
"aud": aud,
"iat": iat.Unix(),
"exp": exp.Unix(),
}
}
func GenClientID(hostport string) (string, error) {
b, err := randBytes(32)
if err != nil {
return "", err
}
var host string
if strings.Contains(hostport, ":") {
host, _, err = net.SplitHostPort(hostport)
if err != nil {
return "", err
}
} else {
host = hostport
}
return fmt.Sprintf("%s@%s", base64.URLEncoding.EncodeToString(b), host), nil
}
func randBytes(n int) ([]byte, error) {
b := make([]byte, n)
got, err := rand.Read(b)
if err != nil {
return nil, err
} else if n != got {
return nil, errors.New("unable to generate enough random data")
}
return b, nil
}
// urlEqual checks two urls for equality using only the host and path portions.
func urlEqual(url1, url2 string) bool {
u1, err := url.Parse(url1)
if err != nil {
return false
}
u2, err := url.Parse(url2)
if err != nil {
return false
}
return strings.ToLower(u1.Host+u1.Path) == strings.ToLower(u2.Host+u2.Path)
}

View file

@ -1,190 +0,0 @@
package oidc
import (
"errors"
"fmt"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
)
func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) {
jwtBytes := []byte(jwt.Data())
for _, k := range keys {
v, err := k.Verifier()
if err != nil {
return false, err
}
if v.Verify(jwt.Signature, jwtBytes) == nil {
return true, nil
}
}
return false, nil
}
// containsString returns true if the given string(needle) is found
// in the string array(haystack).
func containsString(needle string, haystack []string) bool {
for _, v := range haystack {
if v == needle {
return true
}
}
return false
}
// Verify claims in accordance with OIDC spec
// http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation
func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
now := time.Now().UTC()
claims, err := jwt.Claims()
if err != nil {
return err
}
ident, err := IdentityFromClaims(claims)
if err != nil {
return err
}
if ident.ExpiresAt.Before(now) {
return errors.New("token is expired")
}
// iss REQUIRED. Issuer Identifier for the Issuer of the response.
// The iss value is a case sensitive URL using the https scheme that contains scheme,
// host, and optionally, port number and path components and no query or fragment components.
if iss, exists := claims["iss"].(string); exists {
if !urlEqual(iss, issuer) {
return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss)
}
} else {
return errors.New("missing claim: 'iss'")
}
// iat REQUIRED. Time at which the JWT was issued.
// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z
// as measured in UTC until the date/time.
if _, exists := claims["iat"].(float64); !exists {
return errors.New("missing claim: 'iat'")
}
// aud REQUIRED. Audience(s) that this ID Token is intended for.
// It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value.
// It MAY also contain identifiers for other audiences. In the general case, the aud
// value is an array of case sensitive strings. In the common special case when there
// is one audience, the aud value MAY be a single case sensitive string.
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if aud != clientID {
return fmt.Errorf("invalid claims, 'aud' claim and 'client_id' do not match, aud=%s, client_id=%s", aud, clientID)
}
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
if !containsString(clientID, aud) {
return fmt.Errorf("invalid claims, cannot find 'client_id' in 'aud' claim, aud=%v, client_id=%s", aud, clientID)
}
} else {
return errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
}
return nil
}
// VerifyClientClaims verifies all the required claims are valid for a "client credentials" JWT.
// Returns the client ID if valid, or an error if invalid.
func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) {
claims, err := jwt.Claims()
if err != nil {
return "", fmt.Errorf("failed to parse JWT claims: %v", err)
}
iss, ok, err := claims.StringClaim("iss")
if err != nil {
return "", fmt.Errorf("failed to parse 'iss' claim: %v", err)
} else if !ok {
return "", errors.New("missing required 'iss' claim")
} else if !urlEqual(iss, issuer) {
return "", fmt.Errorf("'iss' claim does not match expected issuer, iss=%s", iss)
}
sub, ok, err := claims.StringClaim("sub")
if err != nil {
return "", fmt.Errorf("failed to parse 'sub' claim: %v", err)
} else if !ok {
return "", errors.New("missing required 'sub' claim")
}
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if aud != sub {
return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub)
}
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
if !containsString(sub, aud) {
return "", fmt.Errorf("invalid claims, cannot find 'sud' in 'aud' claim, aud=%v, sub=%s", aud, sub)
}
} else {
return "", errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
}
now := time.Now().UTC()
exp, ok, err := claims.TimeClaim("exp")
if err != nil {
return "", fmt.Errorf("failed to parse 'exp' claim: %v", err)
} else if !ok {
return "", errors.New("missing required 'exp' claim")
} else if exp.Before(now) {
return "", fmt.Errorf("token already expired at: %v", exp)
}
return sub, nil
}
type JWTVerifier struct {
issuer string
clientID string
syncFunc func() error
keysFunc func() []key.PublicKey
clock clockwork.Clock
}
func NewJWTVerifier(issuer, clientID string, syncFunc func() error, keysFunc func() []key.PublicKey) JWTVerifier {
return JWTVerifier{
issuer: issuer,
clientID: clientID,
syncFunc: syncFunc,
keysFunc: keysFunc,
clock: clockwork.NewRealClock(),
}
}
func (v *JWTVerifier) Verify(jwt jose.JWT) error {
// Verify claims before verifying the signature. This is an optimization to throw out
// tokens we know are invalid without undergoing an expensive signature check and
// possibly a re-sync event.
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
}
ok, err := VerifySignature(jwt, v.keysFunc())
if err != nil {
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
} else if ok {
return nil
}
if err = v.syncFunc(); err != nil {
return fmt.Errorf("oidc: failed syncing KeySet: %v", err)
}
ok, err = VerifySignature(jwt, v.keysFunc())
if err != nil {
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
} else if !ok {
return errors.New("oidc: unable to verify JWT signature: no matching keys")
}
return nil
}

202
vendor/github.com/coreos/pkg/LICENSE generated vendored
View file

@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -1,5 +0,0 @@
CoreOS Project
Copyright 2014 CoreOS, Inc
This product includes software developed at CoreOS, Inc.
(http://www.coreos.com/).

View file

@ -1,127 +0,0 @@
package health
import (
"expvar"
"fmt"
"log"
"net/http"
"github.com/coreos/pkg/httputil"
)
// Checkables should return nil when the thing they are checking is healthy, and an error otherwise.
type Checkable interface {
Healthy() error
}
// Checker provides a way to make an endpoint which can be probed for system health.
type Checker struct {
// Checks are the Checkables to be checked when probing.
Checks []Checkable
// Unhealthyhandler is called when one or more of the checks are unhealthy.
// If not provided DefaultUnhealthyHandler is called.
UnhealthyHandler UnhealthyHandler
// HealthyHandler is called when all checks are healthy.
// If not provided, DefaultHealthyHandler is called.
HealthyHandler http.HandlerFunc
}
func (c Checker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
unhealthyHandler := c.UnhealthyHandler
if unhealthyHandler == nil {
unhealthyHandler = DefaultUnhealthyHandler
}
successHandler := c.HealthyHandler
if successHandler == nil {
successHandler = DefaultHealthyHandler
}
if r.Method != "GET" {
w.Header().Set("Allow", "GET")
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if err := Check(c.Checks); err != nil {
unhealthyHandler(w, r, err)
return
}
successHandler(w, r)
}
type UnhealthyHandler func(w http.ResponseWriter, r *http.Request, err error)
type StatusResponse struct {
Status string `json:"status"`
Details *StatusResponseDetails `json:"details,omitempty"`
}
type StatusResponseDetails struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
func Check(checks []Checkable) (err error) {
errs := []error{}
for _, c := range checks {
if e := c.Healthy(); e != nil {
errs = append(errs, e)
}
}
switch len(errs) {
case 0:
err = nil
case 1:
err = errs[0]
default:
err = fmt.Errorf("multiple health check failure: %v", errs)
}
return
}
func DefaultHealthyHandler(w http.ResponseWriter, r *http.Request) {
err := httputil.WriteJSONResponse(w, http.StatusOK, StatusResponse{
Status: "ok",
})
if err != nil {
// TODO(bobbyrullo): replace with logging from new logging pkg,
// once it lands.
log.Printf("Failed to write JSON response: %v", err)
}
}
func DefaultUnhealthyHandler(w http.ResponseWriter, r *http.Request, err error) {
writeErr := httputil.WriteJSONResponse(w, http.StatusInternalServerError, StatusResponse{
Status: "error",
Details: &StatusResponseDetails{
Code: http.StatusInternalServerError,
Message: err.Error(),
},
})
if writeErr != nil {
// TODO(bobbyrullo): replace with logging from new logging pkg,
// once it lands.
log.Printf("Failed to write JSON response: %v", err)
}
}
// ExpvarHandler is copied from https://golang.org/src/expvar/expvar.go, where it's sadly unexported.
func ExpvarHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, "{\n")
first := true
expvar.Do(func(kv expvar.KeyValue) {
if !first {
fmt.Fprintf(w, ",\n")
}
first = false
fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
})
fmt.Fprintf(w, "\n}\n")
}

View file

@ -1,21 +0,0 @@
package httputil
import (
"net/http"
"time"
)
// DeleteCookies effectively deletes all named cookies
// by wiping all data and setting to expire immediately.
func DeleteCookies(w http.ResponseWriter, cookieNames ...string) {
for _, n := range cookieNames {
c := &http.Cookie{
Name: n,
Value: "",
Path: "/",
MaxAge: -1,
Expires: time.Time{},
}
http.SetCookie(w, c)
}
}

View file

@ -1,27 +0,0 @@
package httputil
import (
"encoding/json"
"net/http"
)
const (
JSONContentType = "application/json"
)
func WriteJSONResponse(w http.ResponseWriter, code int, resp interface{}) error {
enc, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return err
}
w.Header().Set("Content-Type", JSONContentType)
w.WriteHeader(code)
_, err = w.Write(enc)
if err != nil {
return err
}
return nil
}

View file

@ -1,15 +0,0 @@
package timeutil
import (
"time"
)
func ExpBackoff(prev, max time.Duration) time.Duration {
if prev == 0 {
return time.Second
}
if prev > max/2 {
return max
}
return 2 * prev
}

View file

@ -39,5 +39,5 @@ test: install generate-test-pbs
generate-test-pbs:
make install
make -C testdata
protoc-min-version --version="3.0.0" --proto_path=.:../../../../ --gogo_out=. proto3_proto/proto3.proto
protoc-min-version --version="3.0.0" --proto_path=.:../../../../:../protobuf --gogo_out=Mtestdata/test.proto=github.com/gogo/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. proto3_proto/proto3.proto
make

View file

@ -84,14 +84,20 @@ func mergeStruct(out, in reflect.Value) {
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
}
if emIn, ok := in.Addr().Interface().(extensionsMap); ok {
emOut := out.Addr().Interface().(extensionsMap)
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
} else if emIn, ok := in.Addr().Interface().(extensionsBytes); ok {
if emIn, ok := in.Addr().Interface().(extensionsBytes); ok {
emOut := out.Addr().Interface().(extensionsBytes)
bIn := emIn.GetExtensions()
bOut := emOut.GetExtensions()
*bOut = append(*bOut, *bIn...)
} else if emIn, ok := extendable(in.Addr().Interface()); ok {
emOut, _ := extendable(out.Addr().Interface())
mIn, muIn := emIn.extensionsRead()
if mIn != nil {
mOut := emOut.extensionsWrite()
muIn.Lock()
mergeExtension(mOut, mIn)
muIn.Unlock()
}
}
uf := in.FieldByName("XXX_unrecognized")

Some files were not shown because too many files have changed in this diff Show more