Add proper unclean shutdown handling with a cancellable context.

Signed-off-by: Tom Wilkie <tom.wilkie@gmail.com>
This commit is contained in:
Tom Wilkie 2018-05-29 09:51:29 +01:00
parent 8acad5f3cd
commit 3353bbd018
4 changed files with 52 additions and 8 deletions

View file

@ -69,7 +69,7 @@ type recoverableError struct {
} }
// Store sends a batch of samples to the HTTP endpoint. // Store sends a batch of samples to the HTTP endpoint.
func (c *Client) Store(req *prompb.WriteRequest) error { func (c *Client) Store(ctx context.Context, req *prompb.WriteRequest) error {
data, err := proto.Marshal(req) data, err := proto.Marshal(req)
if err != nil { if err != nil {
return err return err
@ -85,6 +85,7 @@ func (c *Client) Store(req *prompb.WriteRequest) error {
httpReq.Header.Add("Content-Encoding", "snappy") httpReq.Header.Add("Content-Encoding", "snappy")
httpReq.Header.Set("Content-Type", "application/x-protobuf") httpReq.Header.Set("Content-Type", "application/x-protobuf")
httpReq.Header.Set("X-Prometheus-Remote-Write-Version", "0.1.0") httpReq.Header.Set("X-Prometheus-Remote-Write-Version", "0.1.0")
httpReq = httpReq.WithContext(ctx)
ctx, cancel := context.WithTimeout(context.Background(), c.timeout) ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel() defer cancel()

View file

@ -14,6 +14,7 @@
package remote package remote
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -73,7 +74,7 @@ func TestStoreHTTPErrorHandling(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = c.Store(&prompb.WriteRequest{}) err = c.Store(context.TODO(), &prompb.WriteRequest{})
if !reflect.DeepEqual(err, test.err) { if !reflect.DeepEqual(err, test.err) {
t.Errorf("%d. Unexpected error; want %v, got %v", i, test.err, err) t.Errorf("%d. Unexpected error; want %v, got %v", i, test.err, err)
} }

View file

@ -14,6 +14,7 @@
package remote package remote
import ( import (
"context"
"math" "math"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -130,7 +131,7 @@ func init() {
// external timeseries database. // external timeseries database.
type StorageClient interface { type StorageClient interface {
// Store stores the given samples in the remote storage. // Store stores the given samples in the remote storage.
Store(*prompb.WriteRequest) error Store(context.Context, *prompb.WriteRequest) error
// Name identifies the remote storage implementation. // Name identifies the remote storage implementation.
Name() string Name() string
} }
@ -376,6 +377,8 @@ type shards struct {
queues []chan *model.Sample queues []chan *model.Sample
done chan struct{} done chan struct{}
running int32 running int32
ctx context.Context
cancel context.CancelFunc
} }
func (t *QueueManager) newShards(numShards int) *shards { func (t *QueueManager) newShards(numShards int) *shards {
@ -383,11 +386,14 @@ func (t *QueueManager) newShards(numShards int) *shards {
for i := 0; i < numShards; i++ { for i := 0; i < numShards; i++ {
queues[i] = make(chan *model.Sample, t.cfg.Capacity) queues[i] = make(chan *model.Sample, t.cfg.Capacity)
} }
ctx, cancel := context.WithCancel(context.Background())
s := &shards{ s := &shards{
qm: t, qm: t,
queues: queues, queues: queues,
done: make(chan struct{}), done: make(chan struct{}),
running: int32(numShards), running: int32(numShards),
ctx: ctx,
cancel: cancel,
} }
return s return s
} }
@ -403,15 +409,21 @@ func (s *shards) start() {
} }
func (s *shards) stop(deadline time.Duration) { func (s *shards) stop(deadline time.Duration) {
// Attempt a clean shutdown.
for _, shard := range s.queues { for _, shard := range s.queues {
close(shard) close(shard)
} }
select { select {
case <-s.done: case <-s.done:
return
case <-time.After(deadline): case <-time.After(deadline):
level.Error(s.qm.logger).Log("msg", "Failed to flush all samples on shutdown") level.Error(s.qm.logger).Log("msg", "Failed to flush all samples on shutdown")
} }
// Force a unclean shutdown.
s.cancel()
<-s.done
return
} }
func (s *shards) enqueue(sample *model.Sample) bool { func (s *shards) enqueue(sample *model.Sample) bool {
@ -455,6 +467,9 @@ func (s *shards) runShard(i int) {
for { for {
select { select {
case <-s.ctx.Done():
return
case sample, ok := <-queue: case sample, ok := <-queue:
if !ok { if !ok {
if len(pendingSamples) > 0 { if len(pendingSamples) > 0 {
@ -502,7 +517,7 @@ func (s *shards) sendSamplesWithBackoff(samples model.Samples) {
for retries := s.qm.cfg.MaxRetries; retries > 0; retries-- { for retries := s.qm.cfg.MaxRetries; retries > 0; retries-- {
begin := time.Now() begin := time.Now()
req := ToWriteRequest(samples) req := ToWriteRequest(samples)
err := s.qm.client.Store(req) err := s.qm.client.Store(s.ctx, req)
sentBatchDuration.WithLabelValues(s.qm.queueName).Observe(time.Since(begin).Seconds()) sentBatchDuration.WithLabelValues(s.qm.queueName).Observe(time.Since(begin).Seconds())
if err == nil { if err == nil {

View file

@ -14,6 +14,7 @@
package remote package remote
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"sync" "sync"
@ -71,7 +72,7 @@ func (c *TestStorageClient) waitForExpectedSamples(t *testing.T) {
} }
} }
func (c *TestStorageClient) Store(req *prompb.WriteRequest) error { func (c *TestStorageClient) Store(_ context.Context, req *prompb.WriteRequest) error {
c.mtx.Lock() c.mtx.Lock()
defer c.mtx.Unlock() defer c.mtx.Unlock()
count := 0 count := 0
@ -211,9 +212,12 @@ func NewTestBlockedStorageClient() *TestBlockingStorageClient {
} }
} }
func (c *TestBlockingStorageClient) Store(_ *prompb.WriteRequest) error { func (c *TestBlockingStorageClient) Store(ctx context.Context, _ *prompb.WriteRequest) error {
atomic.AddUint64(&c.numCalls, 1) atomic.AddUint64(&c.numCalls, 1)
<-c.block select {
case <-c.block:
case <-ctx.Done():
}
return nil return nil
} }
@ -301,3 +305,26 @@ func TestSpawnNotMoreThanMaxConcurrentSendsGoroutines(t *testing.T) {
t.Errorf("Saw %d concurrent sends, expected 1", numCalls) t.Errorf("Saw %d concurrent sends, expected 1", numCalls)
} }
} }
func TestShutdown(t *testing.T) {
deadline := 10 * time.Second
c := NewTestBlockedStorageClient()
m := NewQueueManager(nil, config.DefaultQueueConfig, nil, nil, c, deadline)
for i := 0; i < config.DefaultQueueConfig.MaxSamplesPerSend; i++ {
m.Append(&model.Sample{
Metric: model.Metric{
model.MetricNameLabel: model.LabelValue(fmt.Sprintf("test_metric_%d", i)),
},
Value: model.SampleValue(i),
Timestamp: model.Time(i),
})
}
m.Start()
start := time.Now()
m.Stop()
duration := time.Now().Sub(start)
if duration > deadline+(deadline/10) {
t.Errorf("Took too long to shutdown: %s > %s", duration, deadline)
}
}