diff --git a/storage/remote/queue_manager_test.go b/storage/remote/queue_manager_test.go
index 9e2522d0b1..afe23178f2 100644
--- a/storage/remote/queue_manager_test.go
+++ b/storage/remote/queue_manager_test.go
@@ -15,7 +15,9 @@ package remote
 
 import (
 	"sync"
+	"sync/atomic"
 	"testing"
+	"time"
 
 	"github.com/prometheus/common/model"
 )
@@ -50,33 +52,6 @@ func (c *TestStorageClient) Name() string {
 	return "teststorageclient"
 }
 
-type TestBlockingStorageClient struct {
-	block   chan bool
-	getData chan bool
-}
-
-func NewTestBlockedStorageClient() *TestBlockingStorageClient {
-	return &TestBlockingStorageClient{
-		block:   make(chan bool),
-		getData: make(chan bool),
-	}
-}
-
-func (c *TestBlockingStorageClient) Store(s model.Samples) error {
-	<-c.getData
-	<-c.block
-	return nil
-}
-
-func (c *TestBlockingStorageClient) unlock() {
-	close(c.getData)
-	close(c.block)
-}
-
-func (c *TestBlockingStorageClient) Name() string {
-	return "testblockingstorageclient"
-}
-
 func TestSampleDelivery(t *testing.T) {
 	// Let's create an even number of send batches so we don't run into the
 	// batch timeout case.
@@ -110,10 +85,47 @@ func TestSampleDelivery(t *testing.T) {
 	c.waitForExpectedSamples(t)
 }
 
+// TestBlockingStorageClient is a queue_manager StorageClient which will block
+// on any calls to Store(), until the `block` channel is closed, at which point
+// the `numCalls` property will contain a count of how many times Store() was
+// called.
+type TestBlockingStorageClient struct {
+	block    chan bool
+	numCalls uint64
+}
+
+func NewTestBlockedStorageClient() *TestBlockingStorageClient {
+	return &TestBlockingStorageClient{
+		block:    make(chan bool),
+		numCalls: 0,
+	}
+}
+
+func (c *TestBlockingStorageClient) Store(s model.Samples) error {
+	atomic.AddUint64(&c.numCalls, 1)
+	<-c.block
+	return nil
+}
+
+func (c *TestBlockingStorageClient) NumCalls() uint64 {
+	return atomic.LoadUint64(&c.numCalls)
+}
+
+func (c *TestBlockingStorageClient) unlock() {
+	close(c.block)
+}
+
+func (c *TestBlockingStorageClient) Name() string {
+	return "testblockingstorageclient"
+}
+
 func TestSpawnNotMoreThanMaxConcurrentSendsGoroutines(t *testing.T) {
-	// `maxSamplesPerSend*maxConcurrentSends + 1` samples should be consumed by
-	// goroutines, `maxSamplesPerSend` should be still in the queue.
-	n := maxSamplesPerSend*maxConcurrentSends + maxSamplesPerSend*2
+	// Our goal is to fully empty the queue:
+	// `maxSamplesPerSend*maxConcurrentSends` samples should be consumed by the
+	// semaphore-controlled goroutines, and then another `maxSamplesPerSend`
+	// should be consumed by the Run() loop calling sendSamples() and
+	// immediately blocking.
+	n := maxSamplesPerSend*maxConcurrentSends + maxSamplesPerSend
 
 	samples := make(model.Samples, 0, n)
 	for i := 0; i < n; i++ {
@@ -130,19 +142,40 @@ func TestSpawnNotMoreThanMaxConcurrentSendsGoroutines(t *testing.T) {
 
 	go m.Run()
 
+	defer func() {
+		c.unlock()
+		m.Stop()
+	}()
+
 	for _, s := range samples {
 		m.Append(s)
 	}
 
-	for i := 0; i < maxConcurrentSends; i++ {
-		c.getData <- true // Wait while all goroutines are spawned.
+	// Wait until the Run() loop drains the queue. If things went right, it
+	// should then immediately block in sendSamples(), but, in case of error,
+	// it would spawn too many goroutines, and thus we'd see more calls to
+	// client.Store().
+	//
+	// The timed wait is maybe non-ideal, but, in order to verify that we're
+	// not spawning too many concurrent goroutines, we have to wait on the
+	// Run() loop to consume a specific number of elements from the
+	// queue... and it doesn't signal that in any obvious way, except by
+	// draining the queue. We cap the waiting at 1 second -- that should give
+	// plenty of time, and keeps the failure fairly quick if we're not draining
+	// the queue properly.
+	for i := 0; i < 100 && len(m.queue) > 0; i++ {
+		time.Sleep(10 * time.Millisecond)
 	}
 
-	if len(m.queue) != maxSamplesPerSend {
-		t.Errorf("Queue should contain %d samples, it contains 0.", maxSamplesPerSend)
+	if len(m.queue) > 0 {
+		t.Fatalf("Failed to drain StorageQueueManager queue, %d elements left",
+			len(m.queue),
+		)
 	}
 
-	c.unlock()
+	numCalls := c.NumCalls()
+	if numCalls != maxConcurrentSends {
+		t.Errorf("Saw %d concurrent sends, expected %d", numCalls, maxConcurrentSends)
+	}
 
-	defer m.Stop()
 }
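
Note for reviewers: the send path this test exercises (a Run() loop batching samples off the queue, with a semaphore channel capping in-flight sends at maxConcurrentSends) is roughly the pattern below. This is a minimal, self-contained sketch, not the actual queue_manager.go implementation; sketchQueueManager, store(), and the constant values are invented for illustration.

package main

import (
	"fmt"
	"time"
)

// Illustrative stand-ins for the constants the test references; the real
// values live in queue_manager.go and may differ.
const (
	maxSamplesPerSend  = 10
	maxConcurrentSends = 3
)

// sketchQueueManager is a hypothetical, simplified stand-in for
// StorageQueueManager: a buffered queue of samples, plus a semaphore
// channel that caps the number of concurrent sends.
type sketchQueueManager struct {
	queue         chan int
	sendSemaphore chan struct{}
}

func newSketchQueueManager() *sketchQueueManager {
	return &sketchQueueManager{
		queue:         make(chan int, maxSamplesPerSend*maxConcurrentSends),
		sendSemaphore: make(chan struct{}, maxConcurrentSends),
	}
}

// Run drains the queue into batches of maxSamplesPerSend and hands each
// batch to sendSamples. Once maxConcurrentSends sends are blocked in
// store(), the next sendSamples call blocks on the semaphore -- the exact
// state the test waits for: queue empty, Run() parked in sendSamples().
func (m *sketchQueueManager) Run() {
	batch := make([]int, 0, maxSamplesPerSend)
	for s := range m.queue {
		batch = append(batch, s)
		if len(batch) == maxSamplesPerSend {
			m.sendSamples(batch)
			batch = make([]int, 0, maxSamplesPerSend)
		}
	}
}

// sendSamples acquires the semaphore, then stores the batch on a new
// goroutine, so at most maxConcurrentSends store() calls run at once.
func (m *sketchQueueManager) sendSamples(batch []int) {
	m.sendSemaphore <- struct{}{}
	go func() {
		defer func() { <-m.sendSemaphore }()
		store(batch)
	}()
}

// store stands in for StorageClient.Store().
func store(batch []int) {
	time.Sleep(100 * time.Millisecond)
	fmt.Printf("stored %d samples\n", len(batch))
}

func main() {
	m := newSketchQueueManager()
	go m.Run()
	// Queue exactly enough samples to fill every in-flight send plus one
	// more batch, mirroring the n computed in the test.
	for i := 0; i < maxSamplesPerSend*(maxConcurrentSends+1); i++ {
		m.queue <- i
	}
	time.Sleep(time.Second) // crude wait so the sketch's sends finish
}

Under this reading, the test's expectation that NumCalls() equals maxConcurrentSends follows directly: the first maxConcurrentSends batches each reach store() and block, and the final batch leaves Run() blocked on the semaphore before store() is ever called.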