discovery: properly check context on chan send

This commit is contained in:
Fabian Reinartz 2016-04-25 16:32:04 +02:00
parent 9f8feb9ff6
commit b5bfb502df
2 changed files with 20 additions and 12 deletions

View file

@ -48,7 +48,6 @@ const appListPath string = "/v2/apps/?embed=apps.tasks"
type Discovery struct { type Discovery struct {
Servers []string Servers []string
RefreshInterval time.Duration RefreshInterval time.Duration
Done chan struct{}
lastRefresh map[string]*config.TargetGroup lastRefresh map[string]*config.TargetGroup
Client AppListClient Client AppListClient
} }
@ -62,7 +61,7 @@ func (md *Discovery) Run(ctx context.Context, ch chan<- []*config.TargetGroup) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-time.After(md.RefreshInterval): case <-time.After(md.RefreshInterval):
err := md.updateServices(ch) err := md.updateServices(ctx, ch)
if err != nil { if err != nil {
log.Errorf("Error while updating services: %s", err) log.Errorf("Error while updating services: %s", err)
} }
@ -70,7 +69,7 @@ func (md *Discovery) Run(ctx context.Context, ch chan<- []*config.TargetGroup) {
} }
} }
func (md *Discovery) updateServices(ch chan<- []*config.TargetGroup) error { func (md *Discovery) updateServices(ctx context.Context, ch chan<- []*config.TargetGroup) error {
targetMap, err := md.fetchTargetGroups() targetMap, err := md.fetchTargetGroups()
if err != nil { if err != nil {
return err return err
@ -80,14 +79,23 @@ func (md *Discovery) updateServices(ch chan<- []*config.TargetGroup) error {
for _, tg := range targetMap { for _, tg := range targetMap {
all = append(all, tg) all = append(all, tg)
} }
ch <- all
// Remove services which did disappear select {
case <-ctx.Done():
return ctx.Err()
case ch <- all:
}
// Remove services which did disappear.
for source := range md.lastRefresh { for source := range md.lastRefresh {
_, ok := targetMap[source] _, ok := targetMap[source]
if !ok { if !ok {
select {
case <-ctx.Done():
return ctx.Err()
case ch <- []*config.TargetGroup{{Source: source}}:
log.Debugf("Removing group for %s", source) log.Debugf("Removing group for %s", source)
ch <- []*config.TargetGroup{{Source: source}} }
} }
} }

View file

@ -47,7 +47,7 @@ func TestMarathonSDHandleError(t *testing.T) {
default: default:
} }
}() }()
err := md.updateServices(ch) err := md.updateServices(context.Background(), ch)
if err != errTesting { if err != errTesting {
t.Fatalf("Expected error: %s", err) t.Fatalf("Expected error: %s", err)
} }
@ -66,7 +66,7 @@ func TestMarathonSDEmptyList(t *testing.T) {
default: default:
} }
}() }()
err := md.updateServices(ch) err := md.updateServices(context.Background(), ch)
if err != nil { if err != nil {
t.Fatalf("Got error: %s", err) t.Fatalf("Got error: %s", err)
} }
@ -115,7 +115,7 @@ func TestMarathonSDSendGroup(t *testing.T) {
t.Fatal("Did not get a target group.") t.Fatal("Did not get a target group.")
} }
}() }()
err := md.updateServices(ch) err := md.updateServices(context.Background(), ch)
if err != nil { if err != nil {
t.Fatalf("Got error: %s", err) t.Fatalf("Got error: %s", err)
} }
@ -136,7 +136,7 @@ func TestMarathonSDRemoveApp(t *testing.T) {
} }
} }
}() }()
err := md.updateServices(ch) err := md.updateServices(context.Background(), ch)
if err != nil { if err != nil {
t.Fatalf("Got error on first update: %s", err) t.Fatalf("Got error on first update: %s", err)
} }
@ -144,7 +144,7 @@ func TestMarathonSDRemoveApp(t *testing.T) {
md.Client = func(url string) (*AppList, error) { md.Client = func(url string) (*AppList, error) {
return marathonTestAppList(marathonValidLabel, 0), nil return marathonTestAppList(marathonValidLabel, 0), nil
} }
err = md.updateServices(ch) err = md.updateServices(context.Background(), ch)
if err != nil { if err != nil {
t.Fatalf("Got error on second update: %s", err) t.Fatalf("Got error on second update: %s", err)
} }