gRPC: add support for accesslog interceptor in gitserver package (#51437)

Co-authored-by: Jean-Hadrien Chabran <jh@chabran.fr>
Co-authored-by: Alex Ostrikov <alex.ostrikov@sourcegraph.com>
This commit is contained in:
Geoffrey Gilmore 2023-05-15 09:16:42 -07:00 committed by GitHub
parent 1206baaea9
commit aeef1a7d1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1211 additions and 312 deletions

2
BUILD.bazel generated
View File

@ -85,6 +85,8 @@ gazelle(
# Because the current implementation of rules_go uses the old protoc grpc compiler, we have to declare our own, and declare it manually in the build files.
# See https://github.com/bazelbuild/rules_go/issues/3022
# gazelle:go_grpc_compilers //:gen-go-grpc,@io_bazel_rules_go//proto:go_proto
go_proto_compiler(
name = "gen-go-grpc",
plugin = "@org_golang_google_grpc_cmd_protoc_gen_go_grpc//:protoc-gen-go-grpc",

View File

@ -202,6 +202,26 @@ go_repository(
version = "v1.14.1",
)
# Overrides the default provided protobuf dep from rules_go by a more
# recent one.
go_repository(
name = "org_golang_google_protobuf",
build_file_proto_mode = "disable_global",
importpath = "google.golang.org/protobuf",
sum = "h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM=",
version = "v1.29.1",
) # keep
# Pin protoc-gen-go-grpc to 1.3.0
# See also //:gen-go-grpc
go_repository(
name = "org_golang_google_grpc_cmd_protoc_gen_go_grpc",
build_file_proto_mode = "disable_global",
importpath = "google.golang.org/grpc/cmd/protoc-gen-go-grpc",
sum = "h1:rNBFJjBCOgVr9pWD7rs/knKL4FRTKgpZmsRfV214zcA=",
version = "v1.3.0",
) # keep
# gazelle:repository_macro deps.bzl%go_dependencies
go_dependencies()

View File

@ -13,6 +13,10 @@ js_library(
srcs = glob([
"*.graphql",
]),
visibility = [
"//client/backstage-backend/node_modules/@sourcegraph/shared/dev:__pkg__",
"//client/shared/dev:__pkg__",
],
)
linter_bin.graphql_schema_linter_test(

View File

@ -34,7 +34,7 @@ go_library(
importpath = "github.com/sourcegraph/sourcegraph/cmd/gitserver/server",
visibility = ["//visibility:public"],
deps = [
"//cmd/gitserver/server/internal/accesslog",
"//cmd/gitserver/server/accesslog",
"//cmd/gitserver/server/internal/cacert",
"//internal/actor",
"//internal/api",

View File

@ -3,12 +3,13 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "accesslog",
srcs = ["accesslog.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog",
visibility = ["//cmd/gitserver/server:__subpackages__"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog",
visibility = ["//cmd/gitserver:__subpackages__"],
deps = [
"//internal/audit",
"//internal/conf/conftypes",
"@com_github_sourcegraph_log//:log",
"@org_golang_google_grpc//:go_default_library",
"@org_uber_go_atomic//:atomic",
],
)
@ -26,5 +27,6 @@ go_test(
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_golang_google_grpc//:go_default_library",
],
)

View File

@ -0,0 +1,210 @@
// accesslog provides instrumentation to record logs of access made by a given actor to a repo at
// the http handler level.
// access logs may optionally (as per site configuration) be included in the audit log.
package accesslog
import (
"context"
"net/http"
"sync"
"github.com/sourcegraph/log"
"go.uber.org/atomic"
"google.golang.org/grpc"
"github.com/sourcegraph/sourcegraph/internal/audit"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
)
type contextKey struct{}
type paramsContext struct {
mu sync.Mutex
repo string
metadata []log.Field
}
func (pc *paramsContext) Set(repo string, metadata ...log.Field) {
pc.mu.Lock()
defer pc.mu.Unlock()
pc.repo = repo
pc.metadata = metadata
}
func (pc *paramsContext) Get() (repo string, metadata []log.Field) {
pc.mu.Lock()
defer pc.mu.Unlock()
return pc.repo, pc.metadata
}
// Record updates a mutable unexported field stored in the context,
// making it available for Middleware to log at the end of the middleware
// chain.
func Record(ctx context.Context, repo string, meta ...log.Field) {
pc := fromContext(ctx)
if pc == nil {
return
}
pc.Set(repo, meta...)
}
func withContext(ctx context.Context, pc *paramsContext) context.Context {
return context.WithValue(ctx, contextKey{}, pc)
}
func fromContext(ctx context.Context) *paramsContext {
pc, ok := ctx.Value(contextKey{}).(*paramsContext)
if !ok || pc == nil {
return nil
}
return pc
}
// accessLogger watches the site configuration and logs accesses (if enabled).
type accessLogger struct {
logger log.Logger
logEnabled *atomic.Bool
watcher conftypes.WatchableSiteConfig
watchEnabledOnce sync.Once
}
func newAccessLogger(logger log.Logger, watcher conftypes.WatchableSiteConfig) *accessLogger {
return &accessLogger{
logger: logger,
logEnabled: atomic.NewBool(false),
watcher: watcher,
}
}
// messages are defined here to make assertions in testing.
const (
accessEventMessage = "access"
accessLoggingEnabledMessage = "access logging enabled"
)
func (a *accessLogger) maybeLog(ctx context.Context) {
// If access logging is not enabled, we are done
if !a.isEnabled() {
return
}
// Otherwise, log this access
// Now we've gone through the handler, we can get the params that the handler
// got from the request body.
paramsCtx := fromContext(ctx)
if paramsCtx == nil {
return
}
repository, metadata := paramsCtx.Get()
if repository == "" {
return
}
var fields []log.Field
if paramsCtx != nil {
params := append([]log.Field{log.String("repo", repository)}, metadata...)
fields = append(fields, log.Object("params", params...))
} else {
fields = append(fields, log.String("params", "nil"))
}
audit.Log(ctx, a.logger, audit.Record{
Entity: "gitserver",
Action: "access",
Fields: fields,
})
}
func (a *accessLogger) isEnabled() bool {
a.watchEnabledOnce.Do(func() {
// Initialize the logEnabled field with the current value
logEnabled := audit.IsEnabled(a.watcher.SiteConfig(), audit.GitserverAccess)
if logEnabled {
a.logger.Info(accessLoggingEnabledMessage)
}
a.logEnabled.Store(logEnabled)
// Watch for changes to the site config
a.watcher.Watch(func() {
newShouldLog := audit.IsEnabled(a.watcher.SiteConfig(), audit.GitserverAccess)
changed := a.logEnabled.Swap(newShouldLog) != newShouldLog
if changed {
if newShouldLog {
a.logger.Info(accessLoggingEnabledMessage)
} else {
a.logger.Info("access logging disabled")
}
}
})
})
return a.logEnabled.Load()
}
// HTTPMiddleware will extract actor information and params collected by Record that has
// been stored in the context, in order to log a trace of the access.
func HTTPMiddleware(logger log.Logger, watcher conftypes.WatchableSiteConfig, next http.HandlerFunc) http.HandlerFunc {
a := newAccessLogger(logger, watcher)
return func(w http.ResponseWriter, r *http.Request) {
// Prepare the context to hold the params which the handler is going to set.
ctx := withContext(r.Context(), &paramsContext{})
r = r.WithContext(ctx)
// Call the next handler in the chain.
next(w, r)
// Log the access
a.maybeLog(ctx)
}
}
// UnaryServerInterceptor returns a grpc.UnaryServerInterceptor that will extract actor information and params collected by Record that has
// been stored in the context in order to log a trace of the access.
func UnaryServerInterceptor(logger log.Logger, watcher conftypes.WatchableSiteConfig) grpc.UnaryServerInterceptor {
a := newAccessLogger(logger, watcher)
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
ctx = withContext(ctx, &paramsContext{})
resp, err = handler(ctx, req)
a.maybeLog(ctx)
return resp, err
}
}
// StreamServerInterceptor returns a grpc.StreamServerInterceptor that will extract actor information and params collected by Record that has
// been stored in the context in order to log a trace of the access.
func StreamServerInterceptor(logger log.Logger, watcher conftypes.WatchableSiteConfig) grpc.StreamServerInterceptor {
a := newAccessLogger(logger, watcher)
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := withContext(ss.Context(), &paramsContext{})
ss = &wrappedServerStream{ServerStream: ss, ctx: ctx}
err := handler(srv, ss)
a.maybeLog(ctx)
return err
}
}
// wrappedServerStream wraps grpc.ServerStream to override the Context method.
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

View File

@ -0,0 +1,489 @@
package accesslog
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/sourcegraph/log"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/requestclient"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestRecord(t *testing.T) {
t.Run("OK", func(t *testing.T) {
ctx := context.Background()
ctx = withContext(ctx, &paramsContext{})
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
Record(ctx, "github.com/foo/bar", meta...)
pc := fromContext(ctx)
require.NotNil(t, pc)
assert.Equal(t, "github.com/foo/bar", pc.repo)
assert.Equal(t, meta, pc.metadata)
})
t.Run("OK not initialized context", func(t *testing.T) {
ctx := context.Background()
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
Record(ctx, "github.com/foo/bar", meta...)
pc := fromContext(ctx)
assert.Nil(t, pc)
})
}
type accessLogConf struct {
disabled bool
callback func()
}
var _ conftypes.WatchableSiteConfig = &accessLogConf{}
func (a *accessLogConf) Watch(cb func()) { a.callback = cb }
func (a *accessLogConf) SiteConfig() schema.SiteConfiguration {
return schema.SiteConfiguration{
Log: &schema.Log{
AuditLog: &schema.AuditLog{
GitserverAccess: !a.disabled,
GraphQL: false,
InternalTraffic: false,
},
},
}
}
func TestHTTPMiddleware(t *testing.T) {
t.Run("OK for access log setting", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
})
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
ctx := req.Context()
ctx = requestclient.WithClient(ctx, &requestclient.Client{IP: "192.168.1.1"})
req = req.WithContext(ctx)
h.ServeHTTP(rec, req)
logs := exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
assert.Equal(t, "github.com/foo/bar", logs[1].Fields["params"].(map[string]any)["repo"])
auditFields := logs[1].Fields["audit"].(map[string]interface{})
assert.Equal(t, "gitserver", auditFields["entity"])
assert.NotEmpty(t, auditFields["auditId"])
actorFields := auditFields["actor"].(map[string]interface{})
assert.Equal(t, "unknown", actorFields["actorUID"])
assert.Equal(t, "192.168.1.1", actorFields["ip"])
assert.Equal(t, "", actorFields["X-Forwarded-For"])
})
t.Run("handle, no recording", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
var handled bool
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
handled = true
})
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
h.ServeHTTP(rec, req)
// Should have handled but not logged
assert.True(t, handled)
logs := exportLogs()
require.Len(t, logs, 1)
assert.NotEqual(t, accessEventMessage, logs[0].Message)
})
t.Run("disabled, then enabled", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
cfg := &accessLogConf{disabled: true}
var handled bool
h := HTTPMiddleware(logger, cfg, func(w http.ResponseWriter, r *http.Request) {
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
handled = true
})
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// Request with access logging disabled
h.ServeHTTP(rec, req)
// Disabled, should have been handled but without a log message
assert.True(t, handled)
logs := exportLogs()
require.Len(t, logs, 0)
// Now we re-enable
handled = false
cfg.disabled = false
cfg.callback()
h.ServeHTTP(rec, req)
// Enabled, should have handled AND generated a log message
assert.True(t, handled)
logs = exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
})
}
func TestAccessLogGRPC(t *testing.T) {
var (
fakeIP = "192.168.1.1"
fakeRepositoryName = "github.com/foo/bar"
)
t.Run("basic recording and audit fields", func(t *testing.T) {
t.Run("unary", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
configuration := &accessLogConf{}
client := &requestclient.Client{IP: fakeIP}
interceptor := chainUnaryInterceptors(
mockClientUnaryInterceptor(client),
UnaryServerInterceptor(logger, configuration),
)
handlerCalled := false
handler := func(ctx context.Context, req any) (any, error) {
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
handlerCalled = true
return req, nil
}
req := struct{}{}
info := &grpc.UnaryServerInfo{}
_, err := interceptor(context.Background(), req, info, handler)
if err != nil {
t.Fatalf("failed to call interceptor: %v", err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
logs := exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
assert.Equal(t, fakeRepositoryName, logs[1].Fields["params"].(map[string]any)["repo"])
auditFields := logs[1].Fields["audit"].(map[string]interface{})
assert.Equal(t, "gitserver", auditFields["entity"])
assert.NotEmpty(t, auditFields["auditId"])
actorFields := auditFields["actor"].(map[string]interface{})
assert.Equal(t, "unknown", actorFields["actorUID"])
assert.Equal(t, fakeIP, actorFields["ip"])
assert.Equal(t, "", actorFields["X-Forwarded-For"])
})
t.Run("stream", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
configuration := &accessLogConf{}
client := &requestclient.Client{IP: fakeIP}
streamInterceptor := chainStreamInterceptors(
mockClientStreamInterceptor(client),
StreamServerInterceptor(logger, configuration),
)
handlerCalled := false
handler := func(srv interface{}, stream grpc.ServerStream) error {
ctx := stream.Context()
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
handlerCalled = true
return nil
}
srv := struct{}{}
ss := &testServerStream{ctx: context.Background()}
info := &grpc.StreamServerInfo{}
err := streamInterceptor(srv, ss, info, handler)
if err != nil {
t.Fatal(err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
logs := exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
assert.Equal(t, fakeRepositoryName, logs[1].Fields["params"].(map[string]any)["repo"])
auditFields := logs[1].Fields["audit"].(map[string]interface{})
assert.Equal(t, "gitserver", auditFields["entity"])
assert.NotEmpty(t, auditFields["auditId"])
actorFields := auditFields["actor"].(map[string]interface{})
assert.Equal(t, "unknown", actorFields["actorUID"])
assert.Equal(t, fakeIP, actorFields["ip"])
assert.Equal(t, "", actorFields["X-Forwarded-For"])
})
})
t.Run("handler, no recording", func(t *testing.T) {
t.Run("unary", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
configuration := &accessLogConf{}
client := &requestclient.Client{IP: fakeIP}
interceptor := chainUnaryInterceptors(
mockClientUnaryInterceptor(client),
UnaryServerInterceptor(logger, configuration),
)
handlerCalled := false
handler := func(ctx context.Context, req any) (any, error) {
handlerCalled = true
return req, nil
}
req := struct{}{}
info := &grpc.UnaryServerInfo{}
_, err := interceptor(context.Background(), req, info, handler)
if err != nil {
t.Fatalf("failed to call interceptor: %v", err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
logs := exportLogs()
// Should have handled but not logged
require.Len(t, logs, 1)
assert.NotEqual(t, accessEventMessage, logs[0].Message)
})
})
t.Run("stream", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
configuration := &accessLogConf{}
client := &requestclient.Client{IP: fakeIP}
streamInterceptor := chainStreamInterceptors(
mockClientStreamInterceptor(client),
StreamServerInterceptor(logger, configuration),
)
handlerCalled := false
handler := func(srv interface{}, stream grpc.ServerStream) error {
handlerCalled = true
return nil
}
srv := struct{}{}
ss := &testServerStream{ctx: context.Background()}
info := &grpc.StreamServerInfo{}
err := streamInterceptor(srv, ss, info, handler)
if err != nil {
t.Fatalf("failed to call interceptor: %v", err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
logs := exportLogs()
// Should have handled but not logged
require.Len(t, logs, 1)
assert.NotEqual(t, accessEventMessage, logs[0].Message)
})
t.Run("disabled, then enabled", func(t *testing.T) {
t.Run("unary", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
configuration := &accessLogConf{disabled: true}
client := &requestclient.Client{IP: fakeIP}
interceptor := chainUnaryInterceptors(
mockClientUnaryInterceptor(client),
UnaryServerInterceptor(logger, configuration),
)
handlerCalled := false
handler := func(ctx context.Context, req any) (any, error) {
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
handlerCalled = true
return req, nil
}
req := struct{}{}
info := &grpc.UnaryServerInfo{}
_, err := interceptor(context.Background(), req, info, handler)
if err != nil {
t.Fatalf("failed to call interceptor: %v", err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
// Disabled, should have been handled but without a log message
logs := exportLogs()
require.Len(t, logs, 0)
// Now we re-enable
handlerCalled = false
configuration.disabled = false
configuration.callback()
_, err = interceptor(context.Background(), req, info, handler)
if err != nil {
t.Fatalf("failed to call interceptor: %v", err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
// Enabled, should have handled AND generated a log message
logs = exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
})
t.Run("stream", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
configuration := &accessLogConf{disabled: true}
client := &requestclient.Client{IP: fakeIP}
interceptor := chainStreamInterceptors(
mockClientStreamInterceptor(client),
StreamServerInterceptor(logger, configuration),
)
handlerCalled := false
handler := func(srv interface{}, stream grpc.ServerStream) error {
ctx := stream.Context()
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
handlerCalled = true
return nil
}
srv := struct{}{}
ss := &testServerStream{ctx: context.Background()}
info := &grpc.StreamServerInfo{}
err := interceptor(srv, ss, info, handler)
if err != nil {
t.Fatal(err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
// Disabled, should have been handled but without a log message
logs := exportLogs()
require.Len(t, logs, 0)
// Now we re-enable
handlerCalled = false
configuration.disabled = false
configuration.callback()
err = interceptor(srv, ss, info, handler)
if err != nil {
t.Fatalf("failed to call interceptor: %v", err)
}
if !handlerCalled {
t.Fatal("handler not called")
}
// Enabled, should have handled AND generated a log message
logs = exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
})
})
}
func mockClientUnaryInterceptor(client *requestclient.Client) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
ctx = requestclient.WithClient(ctx, client)
return handler(ctx, req)
}
}
func mockClientStreamInterceptor(client *requestclient.Client) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := requestclient.WithClient(ss.Context(), client)
return handler(srv, &wrappedServerStream{ss, ctx})
}
}
func chainUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if len(interceptors) == 0 {
return handler(ctx, req)
}
return interceptors[0](ctx, req, info, func(ctx context.Context, req any) (any, error) {
return chainUnaryInterceptors(interceptors[1:]...)(ctx, req, info, handler)
})
}
}
func chainStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if len(interceptors) == 0 {
return handler(srv, ss)
}
return interceptors[0](srv, ss, info, func(srv any, ss grpc.ServerStream) error {
return chainStreamInterceptors(interceptors[1:]...)(srv, ss, info, handler)
})
}
}
// testServerStream is a mock implementation of grpc.ServerStream for testing.
type testServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (m *testServerStream) Context() context.Context {
return m.ctx
}
var _ grpc.ServerStream = &testServerStream{}

View File

@ -6,7 +6,7 @@ import (
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
"github.com/sourcegraph/sourcegraph/internal/gitserver/gitdomain"
"github.com/sourcegraph/sourcegraph/internal/gitserver/protocol"
)

View File

@ -14,7 +14,7 @@ import (
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/env"
"github.com/sourcegraph/sourcegraph/lib/gitservice"

View File

@ -1,128 +0,0 @@
// accesslog provides instrumentation to record logs of access made by a given actor to a repo at
// the http handler level.
// access logs may optionally (as per site configuration) be included in the audit log.
package accesslog
import (
"context"
"net/http"
"github.com/sourcegraph/log"
"go.uber.org/atomic"
"github.com/sourcegraph/sourcegraph/internal/audit"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
)
type contextKey struct{}
type paramsContext struct {
repo string
metadata []log.Field
}
// Record updates a mutable unexported field stored in the context,
// making it available for Middleware to log at the end of the middleware
// chain.
func Record(ctx context.Context, repo string, meta ...log.Field) {
pc := fromContext(ctx)
if pc == nil {
return
}
pc.repo = repo
pc.metadata = meta
}
func withContext(ctx context.Context, pc *paramsContext) context.Context {
return context.WithValue(ctx, contextKey{}, pc)
}
func fromContext(ctx context.Context) *paramsContext {
pc, ok := ctx.Value(contextKey{}).(*paramsContext)
if !ok || pc == nil {
return nil
}
return pc
}
// accessLogger handles HTTP requests and, if logEnabled, logs accesses.
type accessLogger struct {
logger log.Logger
next http.HandlerFunc
logEnabled *atomic.Bool
}
var _ http.Handler = &accessLogger{}
// messages are defined here to make assertions in testing.
const (
accessEventMessage = "access"
accessLoggingEnabledMessage = "access logging enabled"
)
func (a *accessLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Prepare the context to hold the params which the handler is going to set.
ctx := r.Context()
r = r.WithContext(withContext(ctx, &paramsContext{}))
a.next(w, r)
// If access logging is not enabled, we are done
if !a.logEnabled.Load() {
return
}
// Otherwise, log this access
var fields []log.Field
// Now we've gone through the handler, we can get the params that the handler
// got from the request body.
paramsCtx := fromContext(r.Context())
if paramsCtx == nil {
return
}
if paramsCtx.repo == "" {
return
}
if paramsCtx != nil {
params := append([]log.Field{log.String("repo", paramsCtx.repo)}, paramsCtx.metadata...)
fields = append(fields, log.Object("params", params...))
} else {
fields = append(fields, log.String("params", "nil"))
}
audit.Log(ctx, a.logger, audit.Record{
Entity: "gitserver",
Action: "access",
Fields: fields,
})
}
// HTTPMiddleware will extract actor information and params collected by Record that has
// been stored in the context, in order to log a trace of the access.
func HTTPMiddleware(logger log.Logger, watcher conftypes.WatchableSiteConfig, next http.HandlerFunc) http.HandlerFunc {
handler := &accessLogger{
logger: logger,
next: next,
logEnabled: atomic.NewBool(audit.IsEnabled(watcher.SiteConfig(), audit.GitserverAccess)),
}
if handler.logEnabled.Load() {
logger.Info(accessLoggingEnabledMessage)
}
// Allow live toggling of access logging
watcher.Watch(func() {
newShouldLog := audit.IsEnabled(watcher.SiteConfig(), audit.GitserverAccess)
changed := handler.logEnabled.Swap(newShouldLog) != newShouldLog
if changed {
if newShouldLog {
logger.Info(accessLoggingEnabledMessage)
} else {
logger.Info("access logging disabled")
}
}
})
return handler.ServeHTTP
}

View File

@ -1,145 +0,0 @@
package accesslog
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/sourcegraph/log"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/requestclient"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestRecord(t *testing.T) {
t.Run("OK", func(t *testing.T) {
ctx := context.Background()
ctx = withContext(ctx, &paramsContext{})
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
Record(ctx, "github.com/foo/bar", meta...)
pc := fromContext(ctx)
require.NotNil(t, pc)
assert.Equal(t, "github.com/foo/bar", pc.repo)
assert.Equal(t, meta, pc.metadata)
})
t.Run("OK not initialized context", func(t *testing.T) {
ctx := context.Background()
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
Record(ctx, "github.com/foo/bar", meta...)
pc := fromContext(ctx)
assert.Nil(t, pc)
})
}
type accessLogConf struct {
disabled bool
callback func()
}
var _ conftypes.WatchableSiteConfig = &accessLogConf{}
func (a *accessLogConf) Watch(cb func()) { a.callback = cb }
func (a *accessLogConf) SiteConfig() schema.SiteConfiguration {
return schema.SiteConfiguration{
Log: &schema.Log{
AuditLog: &schema.AuditLog{
GitserverAccess: !a.disabled,
GraphQL: false,
InternalTraffic: false,
},
},
}
}
func TestHTTPMiddleware(t *testing.T) {
t.Run("OK for access log setting", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
})
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
ctx := req.Context()
ctx = requestclient.WithClient(ctx, &requestclient.Client{IP: "192.168.1.1"})
req = req.WithContext(ctx)
h.ServeHTTP(rec, req)
logs := exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
assert.Equal(t, "github.com/foo/bar", logs[1].Fields["params"].(map[string]any)["repo"])
auditFields := logs[1].Fields["audit"].(map[string]interface{})
assert.Equal(t, "gitserver", auditFields["entity"])
assert.NotEmpty(t, auditFields["auditId"])
actorFields := auditFields["actor"].(map[string]interface{})
assert.Equal(t, "unknown", actorFields["actorUID"])
assert.Equal(t, "192.168.1.1", actorFields["ip"])
assert.Equal(t, "", actorFields["X-Forwarded-For"])
})
t.Run("handle, no recording", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
var handled bool
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
handled = true
})
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
h.ServeHTTP(rec, req)
// Should have handled but not logged
assert.True(t, handled)
logs := exportLogs()
require.Len(t, logs, 1)
assert.NotEqual(t, accessEventMessage, logs[0].Message)
})
t.Run("disabled, then enabled", func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
cfg := &accessLogConf{disabled: true}
var handled bool
h := HTTPMiddleware(logger, cfg, func(w http.ResponseWriter, r *http.Request) {
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
handled = true
})
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// Request with access logging disabled
h.ServeHTTP(rec, req)
// Disabled, should have been handled but without a log message
assert.True(t, handled)
logs := exportLogs()
require.Len(t, logs, 0)
// Now we re-enable
handled = false
cfg.disabled = false
cfg.callback()
h.ServeHTTP(rec, req)
// Enabled, should have handled AND generated a log message
assert.True(t, handled)
logs = exportLogs()
require.Len(t, logs, 2)
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
assert.Contains(t, logs[1].Message, accessEventMessage)
})
}

View File

@ -37,7 +37,7 @@ import (
"github.com/sourcegraph/conc"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/conf"
@ -1643,7 +1643,7 @@ func (s *Server) handleExec(w http.ResponseWriter, r *http.Request) {
return
}
// Log which which actor is accessing the repo.
// Log which actor is accessing the repo.
args := req.Args
cmd := ""
if len(req.Args) > 0 {

View File

@ -9,6 +9,7 @@ import (
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/gitserver/gitdomain"
@ -38,18 +39,30 @@ func (gs *GRPCServer) Exec(req *proto.ExecRequest, ss proto.GitserverService_Exe
})
})
// Log which actor is accessing the repo.
args := req.GetArgs()
cmd := ""
if len(args) > 0 {
cmd = args[0]
args = args[1:]
}
accesslog.Record(ss.Context(), req.GetRepo(),
log.String("cmd", cmd),
log.Strings("args", args),
)
// TODO(mucles): set user agent from all grpc clients
return gs.doExec(ss.Context(), gs.Server.Logger, &internalReq, "unknown-grpc-client", w)
}
func (gs *GRPCServer) Archive(req *proto.ArchiveRequest, ss proto.GitserverService_ArchiveServer) error {
//TODO(mucles): re-enable access logging (see server.go handleArchive)
// Log which which actor is accessing the repo.
// accesslog.Record(ctx, req.Repo,
// log.String("treeish", req.Treeish),
// log.String("format", req.Format),
// log.Strings("path", req.Pathspecs),
// )
accesslog.Record(ss.Context(), req.Repo,
log.String("treeish", req.Treeish),
log.String("format", req.Format),
log.Strings("path", req.Pathspecs),
)
if err := checkSpecArgSafety(req.GetTreeish()); err != nil {
return status.Error(codes.InvalidArgument, err.Error())

View File

@ -11,6 +11,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//cmd/gitserver/server",
"//cmd/gitserver/server/accesslog",
"//internal/actor",
"//internal/api",
"//internal/authz",
@ -46,6 +47,7 @@ go_library(
"//schema",
"@com_github_json_iterator_go//:go",
"@com_github_sourcegraph_log//:log",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_x_sync//semaphore",
"@org_golang_x_time//rate",
],
@ -66,7 +68,9 @@ go_test(
"//internal/database",
"//internal/extsvc",
"//internal/types",
"@com_github_google_go_cmp//cmp",
"@com_github_sourcegraph_log//:log",
"@com_github_sourcegraph_log//logtest",
"@org_golang_google_grpc//:go_default_library",
],
)

View File

@ -17,8 +17,10 @@ import (
"github.com/sourcegraph/log"
"golang.org/x/sync/semaphore"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/authz"
@ -148,7 +150,25 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
GlobalBatchLogSemaphore: semaphore.NewWeighted(int64(batchLogGlobalConcurrencyLimit)),
}
grpcServer := defaults.NewServer(logger)
configurationWatcher := conf.DefaultClient()
var additionalServerOptions []grpc.ServerOption
for method, scopedLogger := range map[string]log.Logger{
proto.GitserverService_Exec_FullMethodName: logger.Scoped("exec.accesslog", "exec endpoint access log"),
proto.GitserverService_Archive_FullMethodName: logger.Scoped("archive.accesslog", "archive endpoint access log"),
} {
streamInterceptor := accesslog.StreamServerInterceptor(scopedLogger, configurationWatcher)
unaryInterceptor := accesslog.UnaryServerInterceptor(scopedLogger, configurationWatcher)
additionalServerOptions = append(additionalServerOptions,
grpc.ChainStreamInterceptor(methodSpecificStreamInterceptor(method, streamInterceptor)),
grpc.ChainUnaryInterceptor(methodSpecificUnaryInterceptor(method, unaryInterceptor)),
)
}
grpcServer := defaults.NewServer(logger, additionalServerOptions...)
proto.RegisterGitserverServiceServer(grpcServer, &server.GRPCServer{
Server: &gitserver,
})
@ -538,3 +558,29 @@ func getAddr() string {
}
return addr
}
// methodSpecificStreamInterceptor returns a gRPC stream server interceptor that only calls the next interceptor if the method matches.
//
// The returned interceptor will call next if the invoked gRPC method matches the method parameter. Otherwise, it will call handler directly.
func methodSpecificStreamInterceptor(method string, next grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if method != info.FullMethod {
return handler(srv, ss)
}
return next(srv, ss, info, handler)
}
}
// methodSpecificUnaryInterceptor returns a gRPC unary server interceptor that only calls the next interceptor if the method matches.
//
// The returned interceptor will call next if the invoked gRPC method matches the method parameter. Otherwise, it will call handler directly.
func methodSpecificUnaryInterceptor(method string, next grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if method != info.FullMethod {
return handler(ctx, req)
}
return next(ctx, req, info, handler)
}
}

View File

@ -8,8 +8,10 @@ import (
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/log"
"github.com/sourcegraph/log/logtest"
"google.golang.org/grpc"
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server"
"github.com/sourcegraph/sourcegraph/internal/api"
@ -112,3 +114,118 @@ func TestGetVCSSyncer(t *testing.T) {
t.Fatalf("Want *server.PerforceDepotSyncer, got %T", s)
}
}
func TestMethodSpecificStreamInterceptor(t *testing.T) {
tests := []struct {
name string
matchedMethod string
testMethod string
expectedInterceptorCalled bool
}{
{
name: "allowed method",
matchedMethod: "allowedMethod",
testMethod: "allowedMethod",
expectedInterceptorCalled: true,
},
{
name: "not allowed method",
matchedMethod: "allowedMethod",
testMethod: "otherMethod",
expectedInterceptorCalled: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
interceptorCalled := false
interceptor := methodSpecificStreamInterceptor(test.matchedMethod, func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
interceptorCalled = true
return handler(srv, ss)
})
handlerCalled := false
noopHandler := func(srv any, ss grpc.ServerStream) error {
handlerCalled = true
return nil
}
err := interceptor(nil, nil, &grpc.StreamServerInfo{FullMethod: test.testMethod}, noopHandler)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !handlerCalled {
t.Error("expected handler to be called")
}
if diff := cmp.Diff(test.expectedInterceptorCalled, interceptorCalled); diff != "" {
t.Fatalf("unexpected interceptor called value (-want +got):\n%s", diff)
}
})
}
}
func TestMethodSpecificUnaryInterceptor(t *testing.T) {
tests := []struct {
name string
matchedMethod string
testMethod string
expectedInterceptorCalled bool
}{
{
name: "allowed method",
matchedMethod: "allowedMethod",
testMethod: "allowedMethod",
expectedInterceptorCalled: true,
},
{
name: "not allowed method",
matchedMethod: "allowedMethod",
testMethod: "otherMethod",
expectedInterceptorCalled: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
interceptorCalled := false
interceptor := methodSpecificUnaryInterceptor(test.matchedMethod, func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
interceptorCalled = true
return handler(ctx, req)
})
handlerCalled := false
noopHandler := func(ctx context.Context, req any) (any, error) {
handlerCalled = true
return nil, nil
}
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: test.testMethod}, noopHandler)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !handlerCalled {
t.Error("expected handler to be called")
}
if diff := cmp.Diff(test.expectedInterceptorCalled, interceptorCalled); diff != "" {
t.Fatalf("unexpected interceptor called value (-want +got):\n%s", diff)
}
})
}
}

View File

@ -8978,14 +8978,6 @@ def go_dependencies():
sum = "h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc=",
version = "v1.53.0",
)
go_repository(
name = "org_golang_google_grpc_cmd_protoc_gen_go_grpc",
build_file_proto_mode = "disable_global",
importpath = "google.golang.org/grpc/cmd/protoc-gen-go-grpc",
sum = "h1:TLkBREm4nIsEcexnCjgQd5GQWaHcqMzwQV0TX9pq8S0=",
version = "v1.2.0",
) # keep
go_repository(
name = "org_golang_google_grpc_examples",
build_file_proto_mode = "disable_global",

View File

@ -14,7 +14,10 @@ proto_library(
go_proto_library(
name = "v1_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
compilers = [
"//:gen-go-grpc",
"@io_bazel_rules_go//proto:go_proto",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/gitserver/v1",
proto = ":v1_proto",
visibility = ["//visibility:private"],

View File

@ -5,7 +5,6 @@ go_library(
srcs = [
"grpc.go",
"panics.go",
"propagator.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc",
visibility = ["//:__subpackages__"],
@ -14,7 +13,6 @@ go_library(
"@com_github_sourcegraph_log//:log",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_x_net//http2",
"@org_golang_x_net//http2/h2c",

View File

@ -12,6 +12,8 @@ go_library(
"//internal/actor",
"//internal/env",
"//internal/grpc",
"//internal/grpc/propagator",
"//internal/requestclient",
"//internal/trace/policy",
"//internal/ttlcache",
"//lib/errors",

View File

@ -20,6 +20,8 @@ import (
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/env"
internalgrpc "github.com/sourcegraph/sourcegraph/internal/grpc"
"github.com/sourcegraph/sourcegraph/internal/grpc/propagator"
"github.com/sourcegraph/sourcegraph/internal/requestclient"
"github.com/sourcegraph/sourcegraph/internal/trace/policy"
)
@ -47,14 +49,16 @@ func DialOptions() []grpc.DialOption {
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithChainStreamInterceptor(
grpc_prometheus.StreamClientInterceptor(metrics),
internalgrpc.StreamClientPropagator(actor.ActorPropagator{}),
internalgrpc.StreamClientPropagator(policy.ShouldTracePropagator{}),
propagator.StreamClientPropagator(actor.ActorPropagator{}),
propagator.StreamClientPropagator(policy.ShouldTracePropagator{}),
propagator.StreamClientPropagator(requestclient.Propagator{}),
otelStreamInterceptor,
),
grpc.WithChainUnaryInterceptor(
grpc_prometheus.UnaryClientInterceptor(metrics),
internalgrpc.UnaryClientPropagator(actor.ActorPropagator{}),
internalgrpc.UnaryClientPropagator(policy.ShouldTracePropagator{}),
propagator.UnaryClientPropagator(actor.ActorPropagator{}),
propagator.UnaryClientPropagator(policy.ShouldTracePropagator{}),
propagator.UnaryClientPropagator(requestclient.Propagator{}),
otelUnaryInterceptor,
),
}
@ -87,15 +91,17 @@ func ServerOptions(logger log.Logger) []grpc.ServerOption {
grpc.ChainStreamInterceptor(
internalgrpc.NewStreamPanicCatcher(logger),
grpc_prometheus.StreamServerInterceptor(metrics),
internalgrpc.StreamServerPropagator(actor.ActorPropagator{}),
internalgrpc.StreamServerPropagator(policy.ShouldTracePropagator{}),
propagator.StreamServerPropagator(requestclient.Propagator{}),
propagator.StreamServerPropagator(actor.ActorPropagator{}),
propagator.StreamServerPropagator(policy.ShouldTracePropagator{}),
otelgrpc.StreamServerInterceptor(),
),
grpc.ChainUnaryInterceptor(
internalgrpc.NewUnaryPanicCatcher(logger),
grpc_prometheus.UnaryServerInterceptor(metrics),
internalgrpc.UnaryServerPropagator(actor.ActorPropagator{}),
internalgrpc.UnaryServerPropagator(policy.ShouldTracePropagator{}),
propagator.UnaryServerPropagator(requestclient.Propagator{}),
propagator.UnaryServerPropagator(actor.ActorPropagator{}),
propagator.UnaryServerPropagator(policy.ShouldTracePropagator{}),
otelgrpc.UnaryServerInterceptor(),
),
}

12
internal/grpc/propagator/BUILD.bazel generated Normal file
View File

@ -0,0 +1,12 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "propagator",
srcs = ["propagator.go"],
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc/propagator",
visibility = ["//:__subpackages__"],
deps = [
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//metadata",
],
)

View File

@ -1,4 +1,4 @@
package grpc
package propagator
import (
"context"
@ -86,7 +86,7 @@ func StreamServerPropagator(prop Propagator) grpc.StreamServerInterceptor {
}
}
// StreamServerPropagator returns an interceptor that will use the given propagator
// UnaryServerPropagator returns an interceptor that will use the given propagator
// to translate some metadata back into the context for the RPC handler. The client
// should be configured with an interceptor that uses the same propagator.
func UnaryServerPropagator(prop Propagator) grpc.UnaryServerInterceptor {

View File

@ -14,7 +14,10 @@ proto_library(
go_proto_library(
name = "v1_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
compilers = [
"//:gen-go-grpc",
"@io_bazel_rules_go//proto:go_proto",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/repoupdater/v1",
proto = ":v1_proto",
visibility = ["//visibility:private"],

View File

@ -1,11 +1,27 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "requestclient",
srcs = [
"client.go",
"grpc.go",
"http.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/requestclient",
visibility = ["//:__subpackages__"],
deps = [
"//internal/grpc/propagator",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//peer",
],
)
go_test(
name = "requestclient_test",
srcs = ["grpc_test.go"],
embed = [":requestclient"],
deps = [
"@com_github_google_go_cmp//cmp",
"@org_golang_google_grpc//peer",
],
)

View File

@ -0,0 +1,70 @@
package requestclient
import (
"context"
"net"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
internalgrpc "github.com/sourcegraph/sourcegraph/internal/grpc/propagator"
)
// Propagator is a github.com/sourcegraph/sourcegraph/internal/grpc/propagator.Propagator that Propagates
// the Client in the context across the gRPC client / server request boundary.
//
// If the context does not contain a Client, the server will backfill the Client's IP with the IP of the address
// that the request came from. (see https://pkg.go.dev/google.golang.org/grpc/peer for more information)
type Propagator struct{}
func (Propagator) FromContext(ctx context.Context) metadata.MD {
client := FromContext(ctx)
if client == nil {
return metadata.New(nil)
}
return metadata.Pairs(
headerKeyClientIP, client.IP,
headerKeyForwardedFor, client.ForwardedFor,
)
}
func (Propagator) InjectContext(ctx context.Context, md metadata.MD) context.Context {
var ip string
var forwardedFor string
if vals := md.Get(headerKeyClientIP); len(vals) > 0 {
ip = vals[0]
}
if vals := md.Get(headerKeyForwardedFor); len(vals) > 0 {
forwardedFor = vals[0]
}
if ip == "" {
p, ok := peer.FromContext(ctx)
if ok && p != nil {
ip = baseIP(p.Addr)
}
}
c := Client{
IP: ip,
ForwardedFor: forwardedFor,
}
return WithClient(ctx, &c)
}
var _ internalgrpc.Propagator = Propagator{}
// baseIP returns the base IP address of the given net.Addr
func baseIP(addr net.Addr) string {
switch a := addr.(type) {
case *net.TCPAddr:
return a.IP.String()
case *net.UDPAddr:
return a.IP.String()
default:
return addr.String()
}
}

View File

@ -0,0 +1,157 @@
package requestclient
import (
"context"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/peer"
)
func TestPropagator(t *testing.T) {
tests := []struct {
name string
requestClient *Client
requestPeer *peer.Peer
wantClient *Client
}{
{
name: "no client or peer",
wantClient: &Client{},
},
{
name: "client with no peer",
requestClient: &Client{
IP: "192.168.1.1",
ForwardedFor: "192.168.1.2",
},
wantClient: &Client{
IP: "192.168.1.1",
ForwardedFor: "192.168.1.2",
},
},
{
name: "peer only (nil client)",
requestPeer: &peer.Peer{
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.1")},
},
wantClient: &Client{
IP: "192.168.1.1",
},
},
{
name: "peer only (non-nil empty client)",
requestClient: &Client{},
requestPeer: &peer.Peer{
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.1")},
},
wantClient: &Client{
IP: "192.168.1.1",
},
},
{
name: "client should override peer",
requestClient: &Client{
IP: "192.168.1.1",
ForwardedFor: "192.168.1.2",
},
requestPeer: &peer.Peer{
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.3")},
},
wantClient: &Client{
IP: "192.168.1.1",
ForwardedFor: "192.168.1.2",
},
},
{
name: "client for ForwardedFor, peer for IP",
requestClient: &Client{
ForwardedFor: "192.168.1.2",
},
requestPeer: &peer.Peer{
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.3")},
},
wantClient: &Client{
IP: "192.168.1.3",
ForwardedFor: "192.168.1.2",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
requestCtx := context.Background()
if test.requestClient != nil {
requestCtx = WithClient(requestCtx, test.requestClient)
}
if test.requestPeer != nil {
requestCtx = peer.NewContext(requestCtx, test.requestPeer)
}
propagator := &Propagator{}
md := propagator.FromContext(requestCtx)
resultCtx := propagator.InjectContext(requestCtx, md)
if diff := cmp.Diff(test.wantClient, FromContext(resultCtx)); diff != "" {
t.Errorf("Client mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestBaseIP(t *testing.T) {
tests := []struct {
name string
addr net.Addr
want string
}{
{
name: "TCP address",
addr: &net.TCPAddr{
IP: net.ParseIP("127.0.127.2"),
Port: 448,
},
want: "127.0.127.2",
},
{
name: "UDP address",
addr: &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 448,
},
want: "127.0.0.1",
},
{
name: "Other address",
addr: &net.UnixAddr{
Name: "foobar",
},
want: "foobar",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := baseIP(tt.addr); got != tt.want {
t.Errorf("baseIP() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -16,7 +16,10 @@ proto_library(
go_proto_library(
name = "v1_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
compilers = [
"//:gen-go-grpc",
"@io_bazel_rules_go//proto:go_proto",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/searcher/v1",
proto = ":v1_proto",
visibility = ["//visibility:private"],

View File

@ -16,7 +16,10 @@ proto_library(
go_proto_library(
name = "v1_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
compilers = [
"//:gen-go-grpc",
"@io_bazel_rules_go//proto:go_proto",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/symbols/v1",
proto = ":v1_proto",
visibility = ["//visibility:private"],