Orchestrating signal and wait in Go

A common pattern in Go is to start a few goroutines to do some work. These goroutines block on a channel, waiting for more work to arrive. At some point, you want to signal them to stop accepting work and exit, so you can cleanly shut down the program.

This is how the code might look:

func doWork(chanWork chan *work) {
  for w := range chanWork {
    do(w)
  }
}

func main() {
  chanWork := make(chan *work, 100)
  for i := 0; i < N; i++ {
    go doWork(chanWork)
  }

  // Push work to chanWork.
  chanWork <- w

  // No more work to push; signal the workers to stop.
  close(chanWork)
}

This code looks reasonable, with one caveat: once you close the chanWork channel, main exits immediately. Closing the channel only acts as a signal; you also want the program to wait for the goroutines to finish.

Using sync.WaitGroup lets the goroutines exit cleanly before the main program exits, like so:

func doWork(chanWork chan *work, wg *sync.WaitGroup) {
  defer wg.Done()
  ...
}

func main() {
  var wg sync.WaitGroup
  chanWork := make(chan *work, 100)
  for i := 0; i < N; i++ {
    wg.Add(1)
    go doWork(chanWork, &wg)
  }

  // Push work to chanWork.
  chanWork <- w

  // No more work to push; signal the workers to stop.
  close(chanWork)

  // Now wait.
  wg.Wait()
}

For a Go programmer, this is pretty basic so far. And that’s because we had access to the channel, which controls when the goroutine should exit.

What happens when we don't have access to this channel? For example, if we want to run some work periodically, we'll have this situation:

func doWorkPeriodically(wg *sync.WaitGroup) {
  defer wg.Done()
  timeChan := time.Tick(time.Second)

  for range timeChan {
    do()  // some work
  }
}

In this case, we need a way to signal the goroutine to stop doing the work. Say we use a signal channel like so:

func doWorkPeriodically(wg *sync.WaitGroup, signal chan struct{}) {
  defer wg.Done()
  timeChan := time.Tick(time.Second)

  for {
    select {
      case <- timeChan:
        do() // some work
      case <- signal:
        return
    }
  }
}

func main() {
  signal := make(chan struct{}, 1)

  ...

  signal <- struct{}{}  // To signal the goroutine to stop.
  wg.Wait() // Wait for goroutine to exit.
}

The code above signals the goroutine to stop doing the work, and then waits for it so main can exit cleanly.
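
For completeness, a full main for this pattern might look like the sketch below; the time.Sleep is just a stand-in for whatever the real program does before shutting down.

func main() {
  var wg sync.WaitGroup
  signal := make(chan struct{}, 1)

  wg.Add(1)
  go doWorkPeriodically(&wg, signal)

  // Stand-in for the rest of the program.
  time.Sleep(5 * time.Second)

  signal <- struct{}{}  // To signal the goroutine to stop.
  wg.Wait()             // Wait for the goroutine to exit.
}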

All nice and good so far. Now, what if we have multiple such goroutines, and we need to signal and wait on them in some order? For example, we might have a pipeline of sorts, with multiple stages, each dependent on the previous: A -> B -> C. In this case, we need to signal and wait on A before we do the same for B, and then C.
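
Done by hand, that ordered shutdown might look like the sketch below, where runA, runB and runC are hypothetical workers written in the same style as doWorkPeriodically above:

func main() {
  var wgA, wgB, wgC sync.WaitGroup
  stopA := make(chan struct{}, 1)
  stopB := make(chan struct{}, 1)
  stopC := make(chan struct{}, 1)

  wgA.Add(1)
  go runA(&wgA, stopA)
  wgB.Add(1)
  go runB(&wgB, stopB)
  wgC.Add(1)
  go runC(&wgC, stopC)

  // ...

  // Shut down in dependency order: A first, then B, then C.
  stopA <- struct{}{}
  wgA.Wait()
  stopB <- struct{}{}
  wgB.Wait()
  stopC <- struct{}{}
  wgC.Wait()
}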

With the above code, this would be cumbersome: you'd need a signal channel for each stage and, similarly, a wait for each stage. It would be nice to encapsulate this in a class. That's what we have done in Badger.

type LevelCloser struct {
    Name    string
    running int32
    nomore  int32
    closed  chan struct{}
    waiting sync.WaitGroup
}

func (lc *LevelCloser) Signal() {
    if !atomic.CompareAndSwapInt32(&lc.nomore, 0, 1) {
        // fmt.Printf("Level %q already got signal\n", lc.Name)
        return
    }
    running := int(atomic.LoadInt32(&lc.running))
    // fmt.Printf("Sending signal to %d registered with name %q\n",
                  running, lc.Name)
    for i := 0; i < running; i++ {
        lc.closed <- struct{}{}
    }
}

func (lc *LevelCloser) HasBeenClosed() <-chan struct{} {
    return lc.closed
}

func (lc *LevelCloser) Done() {
    if atomic.LoadInt32(&lc.running) <= 0 {
        return
    }

    running := atomic.AddInt32(&lc.running, -1)
    if running == 0 {
        lc.waiting.Done()
    }
}

func (lc *LevelCloser) Wait() {
    lc.waiting.Wait()
}
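
The final example below also calls a SignalAndWait convenience method; a minimal version simply chains the two calls:

// SignalAndWait signals the closer and then waits for all its goroutines to exit.
func (lc *LevelCloser) SignalAndWait() {
    lc.Signal()
    lc.Wait()
}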

It’s a simple class with some basic APIs. The way you’d use it is like this:

func doWorkPeriodically(lc *LevelCloser) {
  defer lc.Done()
  timeChan := time.Tick(time.Second)

  for {
    select {
      case <- timeChan:
        do() // some work
      case <- lc.HasBeenClosed():
        return
    }
  }
}

func main() {
  lc := &LevelCloser{
    Name: name,
    closed: make(chan struct{}, 10),
    running: 1,
  }
  lc.waiting.Add(1)

  go doWorkPeriodically(lc)

  lc.Signal()
  lc.Wait()
}
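
Because Signal sends one message per running goroutine, and Done only releases the WaitGroup once the last goroutine exits, a single LevelCloser can also cover several workers of the same stage. A sketch, reusing doWorkPeriodically from above (the count and the Name here are arbitrary):

func main() {
  const numWorkers = 3
  lc := &LevelCloser{
    Name: "periodic",
    closed: make(chan struct{}, numWorkers), // Buffer must hold one signal per worker.
    running: numWorkers,
  }
  lc.waiting.Add(1) // Released by the last worker's Done().

  for i := 0; i < numWorkers; i++ {
    go doWorkPeriodically(lc)
  }

  // ...

  lc.Signal() // Sends one message per running worker.
  lc.Wait()   // Returns once every worker has called Done().
}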

Finally, to make it work for multiple stages, dependent or not, we wrap it up into one Closer class.

type Closer struct {
    sync.RWMutex
    levels map[string]*LevelCloser
}

func NewCloser() *Closer {
    return &Closer{
        levels: make(map[string]*LevelCloser),
    }
}

func (c *Closer) Register(name string) *LevelCloser {
    c.Lock()
    defer c.Unlock()

    lc, has := c.levels[name]
    if !has {
        lc = &LevelCloser{Name: name, closed: make(chan struct{}, 10)}
        lc.waiting.Add(1)
        c.levels[name] = lc
    }

    AssertTruef(atomic.LoadInt32(&lc.nomore) == 0, "Can't register with closer after signal.")
    atomic.AddInt32(&lc.running, 1)
    return lc
}

func (c *Closer) Get(name string) *LevelCloser {
    c.RLock()
    defer c.RUnlock()

    lc, has := c.levels[name]
    if !has {
        log.Fatalf("%q not present in Closer", name)
        return nil
    }
    return lc
}

Using this wrapper class, you can create one Closer object and use it to create and track all the LevelClosers. You can then retrieve and signal each LevelCloser individually, in order, or have the Closer signal all of them and then wait for all of them.

func (c *Closer) SignalAll() {
    c.RLock()
    defer c.RUnlock()

    for _, l := range c.levels {
        l.Signal()
    }
}

func (c *Closer) WaitForAll() {
    c.RLock()
    defer c.RUnlock()

    for _, l := range c.levels {
        l.Wait()
    }
}

This is how you’d use this class:

func stageA(lc *LevelCloser) {
  defer lc.Done()
  for {
    select {
      case <- someChan:
      case <- lc.HasBeenClosed():
        return
    }
  }
}

func stageB(lc *LevelCloser) {
  ...
}

func main() {
  closer := NewCloser()
  lc := closer.Register("stage-a")
  go stageA(lc)

  lc := closer.Register("stage-b")
  go stageB(lc)

  ...

  lc = closer.Get("stage-b")
  lc.SignalAndWait()

  closer.SignalAll()
  closer.WaitForAll()
}

This class is used by Badger, and it significantly simplifies the various asynchronous activities going on internally. We can ensure that all our writes are committed before in-memory tables are flushed, before we close the value log, and so on.

You can see the entire code here. The code is under the Apache 2.0 license, so feel free to copy it and use it in your project. You can see the class in action here. Look for closer.Register and closer.Get to track how we create multiple such LevelClosers and use them to maintain a strict opening and closing order between the various stages.

Hope you found this useful! Check out other posts to see how Dgraph and Badger can add value to your projects.