diff --git a/discovery/marathon/marathon_test.go b/discovery/marathon/marathon_test.go index 5624858fb4..b545da1725 100644 --- a/discovery/marathon/marathon_test.go +++ b/discovery/marathon/marathon_test.go @@ -125,27 +125,19 @@ func TestMarathonSDSendGroup(t *testing.T) { } func TestMarathonSDRemoveApp(t *testing.T) { - var ch = make(chan []*config.TargetGroup) + var ch = make(chan []*config.TargetGroup, 1) md, err := NewDiscovery(&conf) if err != nil { t.Fatalf("%s", err) } + md.appsClient = func(client *http.Client, url, token string) (*AppList, error) { return marathonTestAppList(marathonValidLabel, 1), nil } - go func() { - up1 := (<-ch)[0] - up2 := (<-ch)[0] - if up2.Source != up1.Source { - t.Fatalf("Source is different: %s", up2) - if len(up2.Targets) > 0 { - t.Fatalf("Got a non-empty target set: %s", up2.Targets) - } - } - }() if err := md.updateServices(context.Background(), ch); err != nil { t.Fatalf("Got error on first update: %s", err) } + up1 := (<-ch)[0] md.appsClient = func(client *http.Client, url, token string) (*AppList, error) { return marathonTestAppList(marathonValidLabel, 0), nil @@ -153,6 +145,14 @@ func TestMarathonSDRemoveApp(t *testing.T) { if err := md.updateServices(context.Background(), ch); err != nil { t.Fatalf("Got error on second update: %s", err) } + up2 := (<-ch)[0] + + if up2.Source != up1.Source { + t.Fatalf("Source is different: %s", up2) + if len(up2.Targets) > 0 { + t.Fatalf("Got a non-empty target set: %s", up2.Targets) + } + } } func TestMarathonSDRunAndStop(t *testing.T) { @@ -160,6 +160,7 @@ func TestMarathonSDRunAndStop(t *testing.T) { refreshInterval = model.Duration(time.Millisecond * 10) conf = config.MarathonSDConfig{Servers: testServers, RefreshInterval: refreshInterval} ch = make(chan []*config.TargetGroup) + doneCh = make(chan error) ) md, err := NewDiscovery(&conf) if err != nil { @@ -171,21 +172,21 @@ func TestMarathonSDRunAndStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - for { - select { - case _, ok := <-ch: - if !ok { - return - } - cancel() - case <-time.After(md.refreshInterval * 3): - cancel() - t.Fatalf("Update took too long.") - } - } + md.Run(ctx, ch) + close(doneCh) }() - md.Run(ctx, ch) + timeout := time.After(md.refreshInterval * 3) + for { + select { + case <-ch: + cancel() + case <-doneCh: + return + case <-timeout: + t.Fatalf("Update took too long.") + } + } } func marathonTestZeroTaskPortAppList(labels map[string]string, runningTasks int) *AppList {