From 38d4e83e59ba8cad0daea47aabcc2ae072337426 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Mon, 29 Jul 2024 14:17:25 -0700 Subject: [PATCH] feat/requestclient: propagate original User-Agent as X-Forwarded-For-User-Agent (#64113) Propagates a for-reference-only record of the first `User-Agent` seen when a request gets into Sourcegraph across services and contexts. This allows telemetry to try and indicate where a request originates from (https://github.com/sourcegraph/sourcegraph/pull/64112), rather than only having the most recent user-agent. A new header and `requestclient.Client` property `X-Forwarded-For-User-Agent` and `ForwardedForUserAgent` is used to explicitly forward this. Strictly speaking I think we're supposed to just forward `User-Agent` but it looks like in multiple places we add/clobber the `User-Agent` ourselves. The gRPC propagator currently sets user-agent on outgoing requests, this change also makes that consistent with the HTTP transport, such that both only explicitly propagate `X-Forwarded-For-User-Agent` ## Test plan Unit tests --- internal/audit/audit.go | 8 +++ internal/audit/audit_test.go | 45 ++++++++------- internal/requestclient/BUILD.bazel | 1 + internal/requestclient/client.go | 12 +++- internal/requestclient/grpc.go | 26 +++++++-- internal/requestclient/grpc_test.go | 24 ++++++++ internal/requestclient/http.go | 20 ++++++- internal/requestclient/http_test.go | 87 +++++++++++++++++++++++++++++ 8 files changed, 194 insertions(+), 29 deletions(-) create mode 100644 internal/requestclient/http_test.go diff --git a/internal/audit/audit.go b/internal/audit/audit.go index f638050483b..975bec06b62 100644 --- a/internal/audit/audit.go +++ b/internal/audit/audit.go @@ -49,6 +49,7 @@ func Log(ctx context.Context, logger log.Logger, record Record) { log.String("actorUID", actorId(act)), log.String("ip", ip(client)), log.String("userAgent", userAgent(client)), + log.String("forwardedForUserAgent", forwardedForUserAgent(client)), log.String("X-Forwarded-For", forwardedFor(client))))) fields = append(fields, record.Fields...) @@ -81,6 +82,13 @@ func userAgent(client *requestclient.Client) string { return client.UserAgent } +func forwardedForUserAgent(client *requestclient.Client) string { + if client == nil { + return "unknown" + } + return client.ForwardedForUserAgent +} + func forwardedFor(client *requestclient.Client) string { if client == nil { return "unknown" diff --git a/internal/audit/audit_test.go b/internal/audit/audit_test.go index b85ba6734c1..41325c5a6c9 100644 --- a/internal/audit/audit_test.go +++ b/internal/audit/audit_test.go @@ -36,10 +36,11 @@ func TestLog(t *testing.T) { expectedEntry: autogold.Expect(map[string]interface{}{"additional": "stuff", "audit": map[string]interface{}{ "action": "test audit action", "actor": map[string]interface{}{ - "X-Forwarded-For": "192.168.0.1", - "actorUID": "1", - "ip": "192.168.0.1", - "userAgent": "Foobar", + "X-Forwarded-For": "192.168.0.1", + "actorUID": "1", + "forwardedForUserAgent": "", + "ip": "192.168.0.1", + "userAgent": "Foobar", }, "auditId": "test-audit-id-1234", "entity": "test entity", @@ -57,10 +58,11 @@ func TestLog(t *testing.T) { expectedEntry: autogold.Expect(map[string]interface{}{"additional": "stuff", "audit": map[string]interface{}{ "action": "test audit action", "actor": map[string]interface{}{ - "X-Forwarded-For": "192.168.0.1", - "actorUID": "anonymous", - "ip": "192.168.0.1", - "userAgent": "Foobar", + "X-Forwarded-For": "192.168.0.1", + "actorUID": "anonymous", + "forwardedForUserAgent": "", + "ip": "192.168.0.1", + "userAgent": "Foobar", }, "auditId": "test-audit-id-1234", "entity": "test entity", @@ -78,10 +80,11 @@ func TestLog(t *testing.T) { expectedEntry: autogold.Expect(map[string]interface{}{"additional": "stuff", "audit": map[string]interface{}{ "action": "test audit action", "actor": map[string]interface{}{ - "X-Forwarded-For": "192.168.0.1", - "actorUID": "unknown", - "ip": "192.168.0.1", - "userAgent": "Foobar", + "X-Forwarded-For": "192.168.0.1", + "actorUID": "unknown", + "forwardedForUserAgent": "", + "ip": "192.168.0.1", + "userAgent": "Foobar", }, "auditId": "test-audit-id-1234", "entity": "test entity", @@ -95,10 +98,11 @@ func TestLog(t *testing.T) { expectedEntry: autogold.Expect(map[string]interface{}{"additional": "stuff", "audit": map[string]interface{}{ "action": "test audit action", "actor": map[string]interface{}{ - "X-Forwarded-For": "unknown", - "actorUID": "1", - "ip": "unknown", - "userAgent": "unknown", + "X-Forwarded-For": "unknown", + "actorUID": "1", + "forwardedForUserAgent": "unknown", + "ip": "unknown", + "userAgent": "unknown", }, "auditId": "test-audit-id-1234", "entity": "test entity", @@ -115,10 +119,11 @@ func TestLog(t *testing.T) { additionalContext: nil, expectedEntry: autogold.Expect(map[string]interface{}{"audit": map[string]interface{}{ "action": "test audit action", "actor": map[string]interface{}{ - "X-Forwarded-For": "192.168.0.1", - "actorUID": "1", - "ip": "192.168.0.1", - "userAgent": "Foobar", + "X-Forwarded-For": "192.168.0.1", + "actorUID": "1", + "forwardedForUserAgent": "", + "ip": "192.168.0.1", + "userAgent": "Foobar", }, "auditId": "test-audit-id-1234", "entity": "test entity", diff --git a/internal/requestclient/BUILD.bazel b/internal/requestclient/BUILD.bazel index 74bd9c92fe2..2258859ef26 100644 --- a/internal/requestclient/BUILD.bazel +++ b/internal/requestclient/BUILD.bazel @@ -26,6 +26,7 @@ go_test( srcs = [ "client_test.go", "grpc_test.go", + "http_test.go", ], embed = [":requestclient"], deps = [ diff --git a/internal/requestclient/client.go b/internal/requestclient/client.go index 60bce37230c..ea52943aa36 100644 --- a/internal/requestclient/client.go +++ b/internal/requestclient/client.go @@ -23,9 +23,18 @@ type Client struct { // Note: This header can be spoofed and relies on trusted clients/proxies. // For sourcegraph.com we use cloudflare headers to avoid spoofing. ForwardedFor string - // UserAgent is value of the User-Agent header: + // UserAgent is current value of the User-Agent header: // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/User-Agent UserAgent string + // ForwardedForUserAgent is first known value of the User-Agent header + // from the original request: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/User-Agent + // + // It may be equal to UserAgent if no previous ForwardedForUserAgent was + // provided. + // + // Note: This header can be spoofed, and should only be used for reference. + ForwardedForUserAgent string // wafIPCountryCode is a ISO 3166-1 alpha-2 country code for the // request client as provided by a WAF (typically Cloudlfare) behind which @@ -69,6 +78,7 @@ func (c *Client) LogFields() []log.Field { log.String("requestClient.ip", c.IP), log.String("requestClient.forwardedFor", c.ForwardedFor), log.String("requestClient.userAgent", c.UserAgent), + log.String("requestClient.forwardedForUserAgent", c.ForwardedForUserAgent), ccField, } } diff --git a/internal/requestclient/grpc.go b/internal/requestclient/grpc.go index 5364a0d75f5..eaf6238c9d4 100644 --- a/internal/requestclient/grpc.go +++ b/internal/requestclient/grpc.go @@ -23,16 +23,20 @@ func (Propagator) FromContext(ctx context.Context) metadata.MD { return metadata.New(nil) } + forwardedForUserAgent := client.ForwardedForUserAgent + if forwardedForUserAgent == "" { + forwardedForUserAgent = client.UserAgent + } + return metadata.Pairs( headerKeyClientIP, client.IP, headerKeyForwardedFor, client.ForwardedFor, - headerKeyUserAgent, client.UserAgent, + headerKeyForwardedForUserAgent, forwardedForUserAgent, ) } func (Propagator) InjectContext(ctx context.Context, md metadata.MD) context.Context { - var ip string - var forwardedFor string + var ip, forwardedFor, forwardedForUserAgent, currentUserAgent string if vals := md.Get(headerKeyClientIP); len(vals) > 0 { ip = vals[0] @@ -42,6 +46,16 @@ func (Propagator) InjectContext(ctx context.Context, md metadata.MD) context.Con forwardedFor = vals[0] } + if vals := md.Get(headerKeyUserAgent); len(vals) > 0 { + currentUserAgent = vals[0] + } + + if vals := md.Get(headerKeyForwardedForUserAgent); len(vals) > 0 { + forwardedForUserAgent = vals[0] + } else { + forwardedForUserAgent = currentUserAgent + } + if ip == "" { p, ok := peer.FromContext(ctx) if ok && p != nil { @@ -50,8 +64,10 @@ func (Propagator) InjectContext(ctx context.Context, md metadata.MD) context.Con } c := Client{ - IP: ip, - ForwardedFor: forwardedFor, + IP: ip, + ForwardedFor: forwardedFor, + UserAgent: currentUserAgent, + ForwardedForUserAgent: forwardedForUserAgent, } return WithClient(ctx, &c) } diff --git a/internal/requestclient/grpc_test.go b/internal/requestclient/grpc_test.go index 8b93bb1367b..3c7cf6834ad 100644 --- a/internal/requestclient/grpc_test.go +++ b/internal/requestclient/grpc_test.go @@ -93,6 +93,29 @@ func TestPropagator(t *testing.T) { ForwardedFor: "192.168.1.2", }, }, + { + name: "client with user-agent sets forwarded-for-user-agent", + + requestClient: &Client{ + UserAgent: "Sourcegraph-Bot", + }, + + wantClient: &Client{ + ForwardedForUserAgent: "Sourcegraph-Bot", + }, + }, + { + name: "client with forwarded-for-user-agent drops the current user-agent", + + requestClient: &Client{ + UserAgent: "Not-Sourcegraph-Bot", + ForwardedForUserAgent: "Sourcegraph-Bot", + }, + + wantClient: &Client{ + ForwardedForUserAgent: "Sourcegraph-Bot", + }, + }, } for _, test := range tests { @@ -118,6 +141,7 @@ func TestPropagator(t *testing.T) { assert.Equal(t, test.wantClient.IP, rc.IP) assert.Equal(t, test.wantClient.ForwardedFor, rc.ForwardedFor) assert.Equal(t, test.wantClient.UserAgent, rc.UserAgent) + assert.Equal(t, test.wantClient.ForwardedForUserAgent, rc.ForwardedForUserAgent) }) } } diff --git a/internal/requestclient/http.go b/internal/requestclient/http.go index 361b265d412..8293c8f252c 100644 --- a/internal/requestclient/http.go +++ b/internal/requestclient/http.go @@ -15,6 +15,9 @@ const ( // De-facto standard for identifying original IP address of a client: // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For headerKeyForwardedFor = "X-Forwarded-For" + // headerKeyForwardedForUserAgent propagates the first headerKeyUserAgent + // seen. + headerKeyForwardedForUserAgent = "X-Forwarded-For-User-Agent" // Standard for identifyying the application, operating system, vendor, // and/or version of the requesting user agent. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/User-Agent @@ -39,9 +42,14 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { client := FromContext(req.Context()) if client != nil { + forwardedForUserAgent := client.ForwardedForUserAgent + if forwardedForUserAgent == "" { + forwardedForUserAgent = client.UserAgent + } req = req.Clone(req.Context()) // RoundTripper should not modify original request req.Header.Set(headerKeyClientIP, client.IP) req.Header.Set(headerKeyForwardedFor, client.ForwardedFor) + req.Header.Set(headerKeyForwardedForUserAgent, forwardedForUserAgent) } return t.RoundTripper.RoundTrip(req) @@ -106,10 +114,16 @@ func httpMiddleware(next http.Handler, external bool) http.Handler { } } + currentUserAgent := req.Header.Get(headerKeyUserAgent) + forwardedForUserAgent := currentUserAgent + if agent := req.Header.Get(headerKeyForwardedForUserAgent); agent != "" { + forwardedForUserAgent = agent + } ctxWithClient := WithClient(req.Context(), &Client{ - IP: strings.Split(req.RemoteAddr, ":")[0], - ForwardedFor: req.Header.Get(headerKeyForwardedFor), - UserAgent: req.Header.Get(headerKeyUserAgent), + IP: strings.Split(req.RemoteAddr, ":")[0], + ForwardedFor: req.Header.Get(headerKeyForwardedFor), + UserAgent: currentUserAgent, + ForwardedForUserAgent: forwardedForUserAgent, wafIPCountryCode: wafIPCountryCode, }) diff --git a/internal/requestclient/http_test.go b/internal/requestclient/http_test.go new file mode 100644 index 00000000000..8c95e7e0049 --- /dev/null +++ b/internal/requestclient/http_test.go @@ -0,0 +1,87 @@ +package requestclient + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/hexops/autogold/v2" + "github.com/stretchr/testify/require" +) + +type noopRoundTripper struct{ gotRequest *http.Request } + +func (n *noopRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + n.gotRequest = req + return nil, nil +} + +func TestHTTP(t *testing.T) { + tests := []struct { + name string + + requestClient *Client + + wantClient autogold.Value + }{ + { + name: "nil client", + wantClient: autogold.Expect(&Client{IP: "192.0.2.1"}), + }, + { + name: "non-nil empty client", + requestClient: &Client{}, + wantClient: autogold.Expect(&Client{IP: "192.0.2.1"}), + }, + { + name: "forwarded-for", + requestClient: &Client{ + ForwardedFor: "192.168.1.2", + }, + wantClient: autogold.Expect(&Client{IP: "192.0.2.1", ForwardedFor: "192.168.1.2"}), + }, + { + name: "client with user-agent sets forwarded-for-user-agent", + requestClient: &Client{ + UserAgent: "Sourcegraph-Bot", + }, + wantClient: autogold.Expect(&Client{IP: "192.0.2.1", ForwardedForUserAgent: "Sourcegraph-Bot"}), + }, + { + name: "client with forwarded-for-user-agent drops the current user-agent", + requestClient: &Client{ + UserAgent: "Not-Sourcegraph-Bot", + ForwardedForUserAgent: "Sourcegraph-Bot", + }, + wantClient: autogold.Expect(&Client{IP: "192.0.2.1", ForwardedForUserAgent: "Sourcegraph-Bot"}), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + requestCtx := context.Background() + if test.requestClient != nil { + requestCtx = WithClient(requestCtx, test.requestClient) + } + + rt := &noopRoundTripper{} + _, err := (&HTTPTransport{RoundTripper: rt}). + RoundTrip( + httptest.NewRequest(http.MethodGet, "/", nil). + WithContext(requestCtx), + ) + require.NoError(t, err) + + var rc *Client + httpMiddleware( + http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + rc = FromContext(r.Context()) + }), + false, + ).ServeHTTP(httptest.NewRecorder(), rt.gotRequest) + + require.NotNil(t, rc) + test.wantClient.Equal(t, rc, autogold.ExportedOnly()) + }) + } +}