diff --git a/cmd/frontend/internal/httpapi/BUILD.bazel b/cmd/frontend/internal/httpapi/BUILD.bazel index ea42517f932..683e1b2a3bd 100644 --- a/cmd/frontend/internal/httpapi/BUILD.bazel +++ b/cmd/frontend/internal/httpapi/BUILD.bazel @@ -69,8 +69,10 @@ go_library( "//internal/types", "//internal/updatecheck", "//lib/errors", + "//lib/limitedgzip", "//lib/pointers", "//schema", + "@com_github_alecthomas_units//:units", "@com_github_derision_test_glock//:glock", "@com_github_gorilla_mux//:mux", "@com_github_graph_gophers_graphql_go//:graphql-go", diff --git a/cmd/frontend/internal/httpapi/graphql.go b/cmd/frontend/internal/httpapi/graphql.go index 35ff08390b8..b2326f799e8 100644 --- a/cmd/frontend/internal/httpapi/graphql.go +++ b/cmd/frontend/internal/httpapi/graphql.go @@ -1,7 +1,6 @@ package httpapi import ( - "compress/gzip" "context" "encoding/json" "net/http" @@ -9,6 +8,7 @@ import ( "strings" "time" + "github.com/alecthomas/units" "github.com/graph-gophers/graphql-go" gqlerrors "github.com/graph-gophers/graphql-go/errors" "github.com/prometheus/client_golang/prometheus" @@ -22,12 +22,16 @@ import ( "github.com/sourcegraph/sourcegraph/internal/audit" "github.com/sourcegraph/sourcegraph/internal/conf" "github.com/sourcegraph/sourcegraph/internal/cookie" + "github.com/sourcegraph/sourcegraph/internal/env" "github.com/sourcegraph/sourcegraph/internal/trace" "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/limitedgzip" ) const costEstimationMetricActorTypeLabel = "actor_type" +var gzipFileSizeLimit = env.MustGetInt("HTTAPI_GZIP_FILE_SIZE_LIMIT", 500*int(units.Megabyte), "Maximum size of gzipped request bodies to read") + var ( costHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{ Name: "src_graphql_cost_distribution", @@ -85,14 +89,12 @@ func serveGraphQL(logger log.Logger, schema *graphql.Schema, rlw graphqlbackend. r = r.WithContext(trace.WithRequestSource(r.Context(), requestSource)) if r.Header.Get("Content-Encoding") == "gzip" { - gzipReader, err := gzip.NewReader(r.Body) + r.Body, err = limitedgzip.WithReader(r.Body, int64(gzipFileSizeLimit)) if err != nil { return errors.Wrap(err, "failed to decompress request body") } - r.Body = gzipReader - - defer gzipReader.Close() + defer r.Body.Close() } var params graphQLQueryParams diff --git a/cmd/frontend/internal/httpapi/opencodegraph.go b/cmd/frontend/internal/httpapi/opencodegraph.go index b232e3d02da..4582f7884b7 100644 --- a/cmd/frontend/internal/httpapi/opencodegraph.go +++ b/cmd/frontend/internal/httpapi/opencodegraph.go @@ -1,16 +1,17 @@ package httpapi import ( - "compress/gzip" "encoding/json" "net/http" "github.com/sourcegraph/log" + "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search" "github.com/sourcegraph/sourcegraph/internal/featureflag" "github.com/sourcegraph/sourcegraph/internal/opencodegraph" "github.com/sourcegraph/sourcegraph/internal/trace" "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/limitedgzip" "github.com/sourcegraph/sourcegraph/schema" ) @@ -39,12 +40,12 @@ func serveOpenCodeGraph(logger log.Logger) func(w http.ResponseWriter, r *http.R r = r.WithContext(trace.WithRequestSource(r.Context(), requestSource)) if r.Header.Get("Content-Encoding") == "gzip" { - gzipReader, err := gzip.NewReader(r.Body) + r.Body, err = limitedgzip.WithReader(r.Body, int64(gzipFileSizeLimit)) if err != nil { return errors.Wrap(err, "failed to decompress request body") } - r.Body = gzipReader - defer gzipReader.Close() + + defer r.Body.Close() } method, cap, ann, err := opencodegraph.DecodeRequestMessage(json.NewDecoder(r.Body)) diff --git a/lib/limitedgzip/BUILD.bazel b/lib/limitedgzip/BUILD.bazel new file mode 100644 index 00000000000..aceaffbc02d --- /dev/null +++ b/lib/limitedgzip/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "limitedgzip", + srcs = ["gzip.go"], + importpath = "github.com/sourcegraph/sourcegraph/lib/limitedgzip", + visibility = ["//visibility:public"], +) diff --git a/lib/limitedgzip/gzip.go b/lib/limitedgzip/gzip.go new file mode 100644 index 00000000000..cea46cdeedc --- /dev/null +++ b/lib/limitedgzip/gzip.go @@ -0,0 +1,25 @@ +package limitedgzip + +import ( + "compress/gzip" + "io" +) + +// WithReader returns a new io.ReadCloser that reads and decompresses the body +// it reads until io.EOF or the specified limit is reached. +func WithReader(body io.ReadCloser, limit int64) (io.ReadCloser, error) { + gzipReader, err := gzip.NewReader(body) + if err != nil { + return nil, err + } + + body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(gzipReader, limit), + Closer: gzipReader, + } + + return body, nil +}