lib/group: propagate panics from child goroutines (#42679)

propagate panics from child goroutines
This commit is contained in:
Camden Cheek 2022-10-07 13:58:44 -06:00 committed by GitHub
parent 09786101a1
commit b89db88a97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 13 deletions

View File

@ -5,10 +5,9 @@ package group
import (
"context"
"runtime/debug"
"sync"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@ -114,6 +113,9 @@ type Errorable[T any] interface {
type group struct {
wg sync.WaitGroup
limiter Limiter // nil limiter means unlimited (default)
recoverMux sync.Mutex
recoveredErr error
}
func (g *group) Go(f func()) {
@ -140,7 +142,7 @@ func (g *group) start(f func()) {
g.wg.Add(1)
go func() {
defer g.wg.Done()
defer recoverPanic()
defer g.recoverPanic()
f()
}()
@ -148,6 +150,11 @@ func (g *group) start(f func()) {
func (g *group) Wait() {
g.wg.Wait()
// Propagate panic from child goroutine
if g.recoveredErr != nil {
panic(g.recoveredErr)
}
}
func (g *group) WithMaxConcurrency(limit int) Group {
@ -173,6 +180,22 @@ func (g *group) WithContext(ctx context.Context) ContextGroup {
}
}
func (g *group) recoverPanic() {
if val := recover(); val != nil {
g.recoverMux.Lock()
defer g.recoverMux.Unlock()
var err error
if valErr, ok := val.(error); ok {
err = valErr
} else {
err = errors.Errorf("%#v", val)
}
g.recoveredErr = errors.Wrapf(err, "recovered from panic in child goroutine with stacktrace:\n%s", debug.Stack())
}
}
// errorGroup wraps a *group with error collection
type errorGroup struct {
group *group
@ -302,13 +325,3 @@ func (g *contextGroup) WithFirstError() ContextGroup {
g.errorGroup.onlyFirst = true
return g
}
func recoverPanic() {
if val := recover(); val != nil {
if err, ok := val.(error); ok {
log.Scoped("internal", "group").Error("recovered from panic", log.Error(err))
} else {
log.Scoped("internal", "group").Error("recovered from panic", log.Error(errors.Errorf("%#v", val)))
}
}
}

View File

@ -52,6 +52,19 @@ func TestGroup(t *testing.T) {
})
}
})
t.Run("propagate panic", func(t *testing.T) {
g := New()
for i := 0; i < 10; i++ {
i := i
g.Go(func() {
if i == 5 {
panic(i)
}
})
}
require.Panics(t, func() { g.Wait() })
})
}
func TestErrorGroup(t *testing.T) {