From d7f3e542cf523db5164c28e176f0b174d6edc31e Mon Sep 17 00:00:00 2001 From: Varun Gandhi Date: Tue, 4 Jun 2024 09:56:55 +0800 Subject: [PATCH] chore: Replace errors.As with generic As and AsInterface (#63047) Splits the signature of errors.As into two more specialized functions which catch more errors at compile-time using generics. --- internal/errcode/code.go | 26 +++++++-------- lib/errors/cockroach.go | 59 ++++++++++++++++++++++++----------- lib/errors/invariants_test.go | 26 ++++++--------- lib/errors/multi_error.go | 7 ++++- 4 files changed, 70 insertions(+), 48 deletions(-) diff --git a/internal/errcode/code.go b/internal/errcode/code.go index 800793551ef..8e1dcd7fc41 100644 --- a/internal/errcode/code.go +++ b/internal/errcode/code.go @@ -37,7 +37,7 @@ func HTTP(err error) int { } var e interface{ HTTPStatusCode() int } - if errors.As(err, &e) { + if errors.AsInterface(err, &e) { return e.HTTPStatusCode() } @@ -103,40 +103,40 @@ func (e *Mock) NotFound() bool { // HTTPStatusCode into not found. func IsNotFound(err error) bool { var e interface{ NotFound() bool } - return errors.As(err, &e) && e.NotFound() + return errors.AsInterface(err, &e) && e.NotFound() } // IsUnauthorized will check if err or one of its causes is an unauthorized // error. func IsUnauthorized(err error) bool { var e interface{ Unauthorized() bool } - return errors.As(err, &e) && e.Unauthorized() + return errors.AsInterface(err, &e) && e.Unauthorized() } // IsForbidden will check if err or one of its causes is a forbidden error. func IsForbidden(err error) bool { var e interface{ Forbidden() bool } - return errors.As(err, &e) && e.Forbidden() + return errors.AsInterface(err, &e) && e.Forbidden() } // IsAccountSuspended will check if err or one of its causes was due to the // account being suspended func IsAccountSuspended(err error) bool { var e interface{ AccountSuspended() bool } - return errors.As(err, &e) && e.AccountSuspended() + return errors.AsInterface(err, &e) && e.AccountSuspended() } // IsUnavailableForLegalReasons will check if err or one of its causes was due to // legal reasons. func IsUnavailableForLegalReasons(err error) bool { var e interface{ UnavailableForLegalReasons() bool } - return errors.As(err, &e) && e.UnavailableForLegalReasons() + return errors.AsInterface(err, &e) && e.UnavailableForLegalReasons() } // IsBadRequest will check if err or one of its causes is a bad request. func IsBadRequest(err error) bool { var e interface{ BadRequest() bool } - return errors.As(err, &e) && e.BadRequest() + return errors.AsInterface(err, &e) && e.BadRequest() } // IsTemporary will check if err or one of its causes is temporary. A @@ -144,12 +144,12 @@ func IsBadRequest(err error) bool { // temporary interface. func IsTemporary(err error) bool { var e interface{ Temporary() bool } - return errors.As(err, &e) && e.Temporary() + return errors.AsInterface(err, &e) && e.Temporary() } func IsRepoDenied(err error) bool { var e interface{ IsRepoDenied() bool } - return errors.As(err, &e) && e.IsRepoDenied() + return errors.AsInterface(err, &e) && e.IsRepoDenied() } // IsArchived will check if err or one of its causes is an archived error. @@ -157,26 +157,26 @@ func IsRepoDenied(err error) bool { // archived.) func IsArchived(err error) bool { var e interface{ Archived() bool } - return errors.As(err, &e) && e.Archived() + return errors.AsInterface(err, &e) && e.Archived() } // IsBlocked will check if err or one of its causes is a blocked error. func IsBlocked(err error) bool { var e interface{ Blocked() bool } - return errors.As(err, &e) && e.Blocked() + return errors.AsInterface(err, &e) && e.Blocked() } // IsTimeout will check if err or one of its causes is a timeout. Many errors // in the go stdlib implement the timeout interface. func IsTimeout(err error) bool { var e interface{ Timeout() bool } - return errors.As(err, &e) && e.Timeout() + return errors.AsInterface(err, &e) && e.Timeout() } // IsNonRetryable will check if err or one of its causes is a error that cannot be retried. func IsNonRetryable(err error) bool { var e interface{ NonRetryable() bool } - return errors.As(err, &e) && e.NonRetryable() + return errors.AsInterface(err, &e) && e.NonRetryable() } // MakeNonRetryable makes any error non-retryable. diff --git a/lib/errors/cockroach.go b/lib/errors/cockroach.go index 3b703613ab6..2b3c39568c1 100644 --- a/lib/errors/cockroach.go +++ b/lib/errors/cockroach.go @@ -60,24 +60,8 @@ var ( // returns true for a value other than the one returned by As, // since an error tree can contain multiple errors of the same // concrete type but with different data. - Is = errors.Is - IsAny = errors.IsAny - // As checks if the error tree err is of type target, and if so, - // sets target to the value of the error. This can be used in two ways: - // - // 1. If looking for an error of concrete type T, then the second - // argument must be a non-nil pointer of type *T. This implies that - // if the error interface is implemented with a pointer receiver, - // then target must be of type **MyConcreteType. - // 2. If looking for an error satisfying an interface I (with a value - // or pointer receiver), then the second argument must be of type I. - // - // For error types which do not contain any data, As is equivalent to Is. - // - // For error types which contain data, As will return an arbitrary - // error of the target type, in case there are multiple errors of the - // same concrete type in the error tree. - As = errors.As + Is = errors.Is + IsAny = errors.IsAny HasType = errors.HasType Cause = errors.Cause Unwrap = errors.Unwrap @@ -86,6 +70,45 @@ var ( BuildSentryReport = errors.BuildSentryReport ) +// As checks if the error tree err is of type target, and if so, +// sets target to the value of the error. +// +// If looking for an error of concrete type T, then the second +// argument must be a non-nil pointer of type *T. This implies that +// if the error interface is implemented with a pointer receiver, +// then target must be of type **MyConcreteType. +// +// For error types which do not contain any data, As is equivalent to Is. +// +// For error types which contain data, As will return an arbitrary +// error of the target type, in case there are multiple errors of the +// same concrete type in the error tree. +// +// Compared to errors.As, this method uses a generic argument to prevent +// a runtime panic when target is not a pointer to an error type. +// +// Use AsInterface over this function for interface targets. +func As[T error](err error, target *T) bool { + return errors.As(err, target) +} + +// AsInterface checks if the error tree err is of type target (which must be +// an interface type), and if so, sets target to the value of the error. +// +// In general, 'I' may be any interface, not just an error interface. +// See internal/errcode/code.go for some examples. +// +// Use As over this function for concrete types. +func AsInterface[I any](err error, target *I) bool { + if target == nil { + panic("Expected non-nil pointer to interface") + } + if typ := reflect.TypeOf(target); typ.Elem().Kind() != reflect.Interface { + panic("Expected pointer to interface") + } + return errors.As(err, target) +} + // Extend multiError to work with cockroachdb errors. Implement here to keep imports in // one place. diff --git a/lib/errors/invariants_test.go b/lib/errors/invariants_test.go index a91b25b647c..02e00bf3067 100644 --- a/lib/errors/invariants_test.go +++ b/lib/errors/invariants_test.go @@ -120,20 +120,12 @@ func TestInvariants(t *testing.T) { if Is(err, &payloadLessPtrError{}) { // This can be false, see Counter-example 1 //require.True(t, HasType(err, &payloadLessPtrError{})) - require.Panics(t, func() { - var check payloadLessPtrError - require.True(t, As(err, &check)) - }) var check *payloadLessPtrError require.True(t, As(err, &check)) } // HasType implies Is and As for errors without data if HasType(err, &payloadLessPtrError{}) { require.True(t, Is(err, &payloadLessPtrError{})) - require.Panics(t, func() { - var check payloadLessPtrError - require.True(t, As(err, &check)) - }) var check *payloadLessPtrError require.True(t, As(err, &check)) } @@ -166,10 +158,6 @@ func TestInvariants(t *testing.T) { //require.True(t, HasType(err, errorOfInterest)) //require.True(t, HasType(err, errorWithOtherData)) //require.True(t, HasType(err, withPayloadStructError{})) - require.Panics(t, func() { - var check withPayloadPtrError - _ = As(err, &check) - }) var check *withPayloadPtrError require.True(t, As(err, &check)) // This can be false, see Counter-example 6 @@ -182,10 +170,6 @@ func TestInvariants(t *testing.T) { require.True(t, HasType(err, &withPayloadPtrError{})) //This can be false, see Counter-example 3 //require.True(t, Is(err, errorOfInterest)) - require.Panics(t, func() { - var check withPayloadPtrError - _ = As(err, &check) - }) var check *withPayloadPtrError require.True(t, As(err, &check)) require.True(t, *check == *errorOfInterest || *check == *errorWithOtherData) @@ -295,3 +279,13 @@ var _ error = ¬TheErrorOfInterest{} func (p *notTheErrorOfInterest) Error() string { return "notTheErrorOfInterest{}" } + +func TestAsInterface(t *testing.T) { + require.Panics(t, func() { + p := &payloadLessPtrError{} + err := error(&payloadLessPtrError{}) + AsInterface(err, &p) + }) + var e error + require.True(t, AsInterface(error(&payloadLessPtrError{}), &e)) +} diff --git a/lib/errors/multi_error.go b/lib/errors/multi_error.go index 7725850a086..452061473e3 100644 --- a/lib/errors/multi_error.go +++ b/lib/errors/multi_error.go @@ -2,6 +2,8 @@ package errors import ( "fmt" + + "github.com/cockroachdb/errors" //nolint:depguard // needed for implementation of multiError.As ) // MultiError is a container for groups of errors. @@ -100,7 +102,10 @@ func (e *multiError) As(target any) bool { return true } for _, err := range e.errs { - if As(err, target) { + // To conform to the Typed interface, 'target' has to be of type + // any. This means we cannot use our custom As wrapper which has + // a generic argument, so use cockroachdb's As instead. + if errors.As(err, target) { return true } }