diff --git a/utility/uncertaintygroup.go b/utility/uncertaintygroup.go index b1c909744..2c00833a0 100644 --- a/utility/uncertaintygroup.go +++ b/utility/uncertaintygroup.go @@ -15,6 +15,7 @@ package utility import ( "fmt" + "sync" ) type state int @@ -62,18 +63,19 @@ type uncertaintyGroup struct { successes uint results chan error anomalies []error + sync.Mutex } -func (g uncertaintyGroup) Succeed() { - if g.state == finished { +func (g *uncertaintyGroup) Succeed() { + if g.isFinished() { panic("cannot remark when done") } g.results <- nil } -func (g uncertaintyGroup) Fail(err error) { - if g.state == finished { +func (g *uncertaintyGroup) Fail(err error) { + if g.isFinished() { panic("cannot remark when done") } @@ -84,22 +86,42 @@ func (g uncertaintyGroup) Fail(err error) { g.results <- err } -func (g uncertaintyGroup) MayFail(err error) { - if g.state == finished { +func (g *uncertaintyGroup) MayFail(err error) { + if g.isFinished() { panic("cannot remark when done") } g.results <- err } -func (g *uncertaintyGroup) Wait() bool { +func (g *uncertaintyGroup) isFinished() bool { + g.Lock() + defer g.Unlock() + + return g.state == finished +} + +func (g *uncertaintyGroup) finish() { + g.Lock() + defer g.Unlock() + + g.state = finished +} + +func (g *uncertaintyGroup) start() { + g.Lock() + defer g.Unlock() + if g.state != unstarted { panic("cannot restart") } - defer close(g.results) - g.state = started +} + +func (g *uncertaintyGroup) Wait() bool { + defer close(g.results) + g.start() for g.remaining > 0 { result := <-g.results @@ -113,12 +135,12 @@ func (g *uncertaintyGroup) Wait() bool { g.remaining-- } - g.state = finished + g.finish() return len(g.anomalies) == 0 } -func (g uncertaintyGroup) Errors() []error { +func (g *uncertaintyGroup) Errors() []error { if g.state != finished { panic("cannot provide errors until finished") } @@ -126,7 +148,7 @@ func (g uncertaintyGroup) Errors() []error { return g.anomalies } -func (g uncertaintyGroup) String() string { +func (g *uncertaintyGroup) String() string { return fmt.Sprintf("UncertaintyGroup %s with %s failures", g.state, g.anomalies) }