diff --git a/go.mod b/go.mod index 697b3fa3b0f..de9549c6d90 100644 --- a/go.mod +++ b/go.mod @@ -439,7 +439,7 @@ require ( golang.org/x/text v0.5.0 golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/grpc v1.51.0 // indirect + google.golang.org/grpc v1.51.0 gopkg.in/alexcesaro/statsd.v2 v2.0.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect diff --git a/internal/grpc/grpc.go b/internal/grpc/grpc.go new file mode 100644 index 00000000000..595f2c6d79a --- /dev/null +++ b/internal/grpc/grpc.go @@ -0,0 +1,29 @@ +// Package grpc is a set of shared code for implementing gRPC. +package grpc + +import ( + "net/http" + "strings" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "google.golang.org/grpc" +) + +// MultiplexHandlers takes a gRPC server and a plain HTTP handler and multiplexes the +// request handling. Any requests that declare themselves as gRPC requests are routed +// to the gRPC server, all others are routed to the httpHandler. +func MultiplexHandlers(grpcServer *grpc.Server, httpHandler http.Handler) http.Handler { + newHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + grpcServer.ServeHTTP(w, r) + } else { + httpHandler.ServeHTTP(w, r) + } + }) + + // Until we enable TLS, we need to fall back to the h2c protocol, which is + // basically HTTP2 without TLS. The standard library does not implement the + // h2s protocol, so this hijacks h2s requests and handles them correctly. + return h2c.NewHandler(newHandler, &http2.Server{}) +} diff --git a/internal/grpc/grpc_test.go b/internal/grpc/grpc_test.go new file mode 100644 index 00000000000..8367e6558b2 --- /dev/null +++ b/internal/grpc/grpc_test.go @@ -0,0 +1,39 @@ +package grpc + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestMultiplexHandlers(t *testing.T) { + grpcServer := grpc.NewServer() + called := false + httpHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + called = true + }) + multiplexedHandler := MultiplexHandlers(grpcServer, httpHandler) + + { // Basic HTTP request is routed to HTTP handler + req, err := http.NewRequest("GET", "", bytes.NewReader(nil)) + require.NoError(t, err) + called = false + multiplexedHandler.ServeHTTP(httptest.NewRecorder(), req) + require.True(t, called) + } + + { // Request with HTTP2 and application/grpc header is not routed to HTTP handler + req, err := http.NewRequest("GET", "", bytes.NewReader(nil)) + require.NoError(t, err) + req.Header.Add("content-type", "application/grpc") + req.ProtoMajor = 2 + + called = false + multiplexedHandler.ServeHTTP(httptest.NewRecorder(), req) + require.False(t, called) + } +}