mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:51:55 +00:00
gRPC: add support for accesslog interceptor in gitserver package (#51437)
Co-authored-by: Jean-Hadrien Chabran <jh@chabran.fr> Co-authored-by: Alex Ostrikov <alex.ostrikov@sourcegraph.com>
This commit is contained in:
parent
1206baaea9
commit
aeef1a7d1f
2
BUILD.bazel
generated
2
BUILD.bazel
generated
@ -85,6 +85,8 @@ gazelle(
|
||||
|
||||
# Because the current implementation of rules_go uses the old protoc grpc compiler, we have to declare our own, and declare it manually in the build files.
|
||||
# See https://github.com/bazelbuild/rules_go/issues/3022
|
||||
|
||||
# gazelle:go_grpc_compilers //:gen-go-grpc,@io_bazel_rules_go//proto:go_proto
|
||||
go_proto_compiler(
|
||||
name = "gen-go-grpc",
|
||||
plugin = "@org_golang_google_grpc_cmd_protoc_gen_go_grpc//:protoc-gen-go-grpc",
|
||||
|
||||
20
WORKSPACE
20
WORKSPACE
@ -202,6 +202,26 @@ go_repository(
|
||||
version = "v1.14.1",
|
||||
)
|
||||
|
||||
# Overrides the default provided protobuf dep from rules_go by a more
|
||||
# recent one.
|
||||
go_repository(
|
||||
name = "org_golang_google_protobuf",
|
||||
build_file_proto_mode = "disable_global",
|
||||
importpath = "google.golang.org/protobuf",
|
||||
sum = "h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM=",
|
||||
version = "v1.29.1",
|
||||
) # keep
|
||||
|
||||
# Pin protoc-gen-go-grpc to 1.3.0
|
||||
# See also //:gen-go-grpc
|
||||
go_repository(
|
||||
name = "org_golang_google_grpc_cmd_protoc_gen_go_grpc",
|
||||
build_file_proto_mode = "disable_global",
|
||||
importpath = "google.golang.org/grpc/cmd/protoc-gen-go-grpc",
|
||||
sum = "h1:rNBFJjBCOgVr9pWD7rs/knKL4FRTKgpZmsRfV214zcA=",
|
||||
version = "v1.3.0",
|
||||
) # keep
|
||||
|
||||
# gazelle:repository_macro deps.bzl%go_dependencies
|
||||
go_dependencies()
|
||||
|
||||
|
||||
4
cmd/frontend/graphqlbackend/BUILD.bazel
generated
4
cmd/frontend/graphqlbackend/BUILD.bazel
generated
@ -13,6 +13,10 @@ js_library(
|
||||
srcs = glob([
|
||||
"*.graphql",
|
||||
]),
|
||||
visibility = [
|
||||
"//client/backstage-backend/node_modules/@sourcegraph/shared/dev:__pkg__",
|
||||
"//client/shared/dev:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
linter_bin.graphql_schema_linter_test(
|
||||
|
||||
2
cmd/gitserver/server/BUILD.bazel
generated
2
cmd/gitserver/server/BUILD.bazel
generated
@ -34,7 +34,7 @@ go_library(
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/gitserver/server",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//cmd/gitserver/server/internal/accesslog",
|
||||
"//cmd/gitserver/server/accesslog",
|
||||
"//cmd/gitserver/server/internal/cacert",
|
||||
"//internal/actor",
|
||||
"//internal/api",
|
||||
|
||||
@ -3,12 +3,13 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
go_library(
|
||||
name = "accesslog",
|
||||
srcs = ["accesslog.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog",
|
||||
visibility = ["//cmd/gitserver/server:__subpackages__"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog",
|
||||
visibility = ["//cmd/gitserver:__subpackages__"],
|
||||
deps = [
|
||||
"//internal/audit",
|
||||
"//internal/conf/conftypes",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_uber_go_atomic//:atomic",
|
||||
],
|
||||
)
|
||||
@ -26,5 +27,6 @@ go_test(
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
"@com_github_stretchr_testify//require",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
],
|
||||
)
|
||||
210
cmd/gitserver/server/accesslog/accesslog.go
Normal file
210
cmd/gitserver/server/accesslog/accesslog.go
Normal file
@ -0,0 +1,210 @@
|
||||
// accesslog provides instrumentation to record logs of access made by a given actor to a repo at
|
||||
// the http handler level.
|
||||
// access logs may optionally (as per site configuration) be included in the audit log.
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
"go.uber.org/atomic"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/audit"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
type paramsContext struct {
|
||||
mu sync.Mutex
|
||||
|
||||
repo string
|
||||
metadata []log.Field
|
||||
}
|
||||
|
||||
func (pc *paramsContext) Set(repo string, metadata ...log.Field) {
|
||||
pc.mu.Lock()
|
||||
defer pc.mu.Unlock()
|
||||
|
||||
pc.repo = repo
|
||||
pc.metadata = metadata
|
||||
}
|
||||
|
||||
func (pc *paramsContext) Get() (repo string, metadata []log.Field) {
|
||||
pc.mu.Lock()
|
||||
defer pc.mu.Unlock()
|
||||
|
||||
return pc.repo, pc.metadata
|
||||
}
|
||||
|
||||
// Record updates a mutable unexported field stored in the context,
|
||||
// making it available for Middleware to log at the end of the middleware
|
||||
// chain.
|
||||
func Record(ctx context.Context, repo string, meta ...log.Field) {
|
||||
pc := fromContext(ctx)
|
||||
if pc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pc.Set(repo, meta...)
|
||||
}
|
||||
|
||||
func withContext(ctx context.Context, pc *paramsContext) context.Context {
|
||||
return context.WithValue(ctx, contextKey{}, pc)
|
||||
}
|
||||
|
||||
func fromContext(ctx context.Context) *paramsContext {
|
||||
pc, ok := ctx.Value(contextKey{}).(*paramsContext)
|
||||
if !ok || pc == nil {
|
||||
return nil
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// accessLogger watches the site configuration and logs accesses (if enabled).
|
||||
type accessLogger struct {
|
||||
logger log.Logger
|
||||
|
||||
logEnabled *atomic.Bool
|
||||
watcher conftypes.WatchableSiteConfig
|
||||
watchEnabledOnce sync.Once
|
||||
}
|
||||
|
||||
func newAccessLogger(logger log.Logger, watcher conftypes.WatchableSiteConfig) *accessLogger {
|
||||
return &accessLogger{
|
||||
logger: logger,
|
||||
|
||||
logEnabled: atomic.NewBool(false),
|
||||
watcher: watcher,
|
||||
}
|
||||
}
|
||||
|
||||
// messages are defined here to make assertions in testing.
|
||||
const (
|
||||
accessEventMessage = "access"
|
||||
accessLoggingEnabledMessage = "access logging enabled"
|
||||
)
|
||||
|
||||
func (a *accessLogger) maybeLog(ctx context.Context) {
|
||||
// If access logging is not enabled, we are done
|
||||
if !a.isEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, log this access
|
||||
|
||||
// Now we've gone through the handler, we can get the params that the handler
|
||||
// got from the request body.
|
||||
paramsCtx := fromContext(ctx)
|
||||
if paramsCtx == nil {
|
||||
return
|
||||
}
|
||||
repository, metadata := paramsCtx.Get()
|
||||
|
||||
if repository == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var fields []log.Field
|
||||
|
||||
if paramsCtx != nil {
|
||||
params := append([]log.Field{log.String("repo", repository)}, metadata...)
|
||||
fields = append(fields, log.Object("params", params...))
|
||||
} else {
|
||||
fields = append(fields, log.String("params", "nil"))
|
||||
}
|
||||
|
||||
audit.Log(ctx, a.logger, audit.Record{
|
||||
Entity: "gitserver",
|
||||
Action: "access",
|
||||
Fields: fields,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *accessLogger) isEnabled() bool {
|
||||
a.watchEnabledOnce.Do(func() {
|
||||
// Initialize the logEnabled field with the current value
|
||||
logEnabled := audit.IsEnabled(a.watcher.SiteConfig(), audit.GitserverAccess)
|
||||
if logEnabled {
|
||||
a.logger.Info(accessLoggingEnabledMessage)
|
||||
}
|
||||
|
||||
a.logEnabled.Store(logEnabled)
|
||||
|
||||
// Watch for changes to the site config
|
||||
a.watcher.Watch(func() {
|
||||
newShouldLog := audit.IsEnabled(a.watcher.SiteConfig(), audit.GitserverAccess)
|
||||
changed := a.logEnabled.Swap(newShouldLog) != newShouldLog
|
||||
if changed {
|
||||
if newShouldLog {
|
||||
a.logger.Info(accessLoggingEnabledMessage)
|
||||
} else {
|
||||
a.logger.Info("access logging disabled")
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return a.logEnabled.Load()
|
||||
}
|
||||
|
||||
// HTTPMiddleware will extract actor information and params collected by Record that has
|
||||
// been stored in the context, in order to log a trace of the access.
|
||||
func HTTPMiddleware(logger log.Logger, watcher conftypes.WatchableSiteConfig, next http.HandlerFunc) http.HandlerFunc {
|
||||
a := newAccessLogger(logger, watcher)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Prepare the context to hold the params which the handler is going to set.
|
||||
ctx := withContext(r.Context(), ¶msContext{})
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Call the next handler in the chain.
|
||||
next(w, r)
|
||||
|
||||
// Log the access
|
||||
a.maybeLog(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// UnaryServerInterceptor returns a grpc.UnaryServerInterceptor that will extract actor information and params collected by Record that has
|
||||
// been stored in the context in order to log a trace of the access.
|
||||
func UnaryServerInterceptor(logger log.Logger, watcher conftypes.WatchableSiteConfig) grpc.UnaryServerInterceptor {
|
||||
a := newAccessLogger(logger, watcher)
|
||||
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
ctx = withContext(ctx, ¶msContext{})
|
||||
resp, err = handler(ctx, req)
|
||||
|
||||
a.maybeLog(ctx)
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
// StreamServerInterceptor returns a grpc.StreamServerInterceptor that will extract actor information and params collected by Record that has
|
||||
// been stored in the context in order to log a trace of the access.
|
||||
func StreamServerInterceptor(logger log.Logger, watcher conftypes.WatchableSiteConfig) grpc.StreamServerInterceptor {
|
||||
a := newAccessLogger(logger, watcher)
|
||||
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
ctx := withContext(ss.Context(), ¶msContext{})
|
||||
|
||||
ss = &wrappedServerStream{ServerStream: ss, ctx: ctx}
|
||||
err := handler(srv, ss)
|
||||
|
||||
a.maybeLog(ctx)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// wrappedServerStream wraps grpc.ServerStream to override the Context method.
|
||||
type wrappedServerStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *wrappedServerStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
489
cmd/gitserver/server/accesslog/accesslog_test.go
Normal file
489
cmd/gitserver/server/accesslog/accesslog_test.go
Normal file
@ -0,0 +1,489 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
"github.com/sourcegraph/sourcegraph/internal/requestclient"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = withContext(ctx, ¶msContext{})
|
||||
|
||||
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
|
||||
|
||||
Record(ctx, "github.com/foo/bar", meta...)
|
||||
|
||||
pc := fromContext(ctx)
|
||||
require.NotNil(t, pc)
|
||||
assert.Equal(t, "github.com/foo/bar", pc.repo)
|
||||
assert.Equal(t, meta, pc.metadata)
|
||||
})
|
||||
|
||||
t.Run("OK not initialized context", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
|
||||
|
||||
Record(ctx, "github.com/foo/bar", meta...)
|
||||
pc := fromContext(ctx)
|
||||
assert.Nil(t, pc)
|
||||
})
|
||||
}
|
||||
|
||||
type accessLogConf struct {
|
||||
disabled bool
|
||||
callback func()
|
||||
}
|
||||
|
||||
var _ conftypes.WatchableSiteConfig = &accessLogConf{}
|
||||
|
||||
func (a *accessLogConf) Watch(cb func()) { a.callback = cb }
|
||||
func (a *accessLogConf) SiteConfig() schema.SiteConfiguration {
|
||||
return schema.SiteConfiguration{
|
||||
Log: &schema.Log{
|
||||
AuditLog: &schema.AuditLog{
|
||||
GitserverAccess: !a.disabled,
|
||||
GraphQL: false,
|
||||
InternalTraffic: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddleware(t *testing.T) {
|
||||
t.Run("OK for access log setting", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
|
||||
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := req.Context()
|
||||
ctx = requestclient.WithClient(ctx, &requestclient.Client{IP: "192.168.1.1"})
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
h.ServeHTTP(rec, req)
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
assert.Equal(t, "github.com/foo/bar", logs[1].Fields["params"].(map[string]any)["repo"])
|
||||
|
||||
auditFields := logs[1].Fields["audit"].(map[string]interface{})
|
||||
assert.Equal(t, "gitserver", auditFields["entity"])
|
||||
assert.NotEmpty(t, auditFields["auditId"])
|
||||
|
||||
actorFields := auditFields["actor"].(map[string]interface{})
|
||||
assert.Equal(t, "unknown", actorFields["actorUID"])
|
||||
assert.Equal(t, "192.168.1.1", actorFields["ip"])
|
||||
assert.Equal(t, "", actorFields["X-Forwarded-For"])
|
||||
})
|
||||
|
||||
t.Run("handle, no recording", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
var handled bool
|
||||
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
|
||||
handled = true
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
// Should have handled but not logged
|
||||
assert.True(t, handled)
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 1)
|
||||
assert.NotEqual(t, accessEventMessage, logs[0].Message)
|
||||
})
|
||||
|
||||
t.Run("disabled, then enabled", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
cfg := &accessLogConf{disabled: true}
|
||||
var handled bool
|
||||
h := HTTPMiddleware(logger, cfg, func(w http.ResponseWriter, r *http.Request) {
|
||||
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
|
||||
handled = true
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Request with access logging disabled
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
// Disabled, should have been handled but without a log message
|
||||
assert.True(t, handled)
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 0)
|
||||
|
||||
// Now we re-enable
|
||||
handled = false
|
||||
cfg.disabled = false
|
||||
cfg.callback()
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
// Enabled, should have handled AND generated a log message
|
||||
assert.True(t, handled)
|
||||
logs = exportLogs()
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccessLogGRPC(t *testing.T) {
|
||||
var (
|
||||
fakeIP = "192.168.1.1"
|
||||
fakeRepositoryName = "github.com/foo/bar"
|
||||
)
|
||||
|
||||
t.Run("basic recording and audit fields", func(t *testing.T) {
|
||||
t.Run("unary", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
|
||||
configuration := &accessLogConf{}
|
||||
client := &requestclient.Client{IP: fakeIP}
|
||||
|
||||
interceptor := chainUnaryInterceptors(
|
||||
mockClientUnaryInterceptor(client),
|
||||
UnaryServerInterceptor(logger, configuration),
|
||||
)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(ctx context.Context, req any) (any, error) {
|
||||
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
|
||||
handlerCalled = true
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
req := struct{}{}
|
||||
info := &grpc.UnaryServerInfo{}
|
||||
_, err := interceptor(context.Background(), req, info, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call interceptor: %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
logs := exportLogs()
|
||||
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
assert.Equal(t, fakeRepositoryName, logs[1].Fields["params"].(map[string]any)["repo"])
|
||||
|
||||
auditFields := logs[1].Fields["audit"].(map[string]interface{})
|
||||
assert.Equal(t, "gitserver", auditFields["entity"])
|
||||
assert.NotEmpty(t, auditFields["auditId"])
|
||||
|
||||
actorFields := auditFields["actor"].(map[string]interface{})
|
||||
assert.Equal(t, "unknown", actorFields["actorUID"])
|
||||
assert.Equal(t, fakeIP, actorFields["ip"])
|
||||
assert.Equal(t, "", actorFields["X-Forwarded-For"])
|
||||
})
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
|
||||
configuration := &accessLogConf{}
|
||||
client := &requestclient.Client{IP: fakeIP}
|
||||
|
||||
streamInterceptor := chainStreamInterceptors(
|
||||
mockClientStreamInterceptor(client),
|
||||
StreamServerInterceptor(logger, configuration),
|
||||
)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(srv interface{}, stream grpc.ServerStream) error {
|
||||
ctx := stream.Context()
|
||||
|
||||
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
|
||||
handlerCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
srv := struct{}{}
|
||||
ss := &testServerStream{ctx: context.Background()}
|
||||
info := &grpc.StreamServerInfo{}
|
||||
|
||||
err := streamInterceptor(srv, ss, info, handler)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
logs := exportLogs()
|
||||
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
assert.Equal(t, fakeRepositoryName, logs[1].Fields["params"].(map[string]any)["repo"])
|
||||
|
||||
auditFields := logs[1].Fields["audit"].(map[string]interface{})
|
||||
assert.Equal(t, "gitserver", auditFields["entity"])
|
||||
assert.NotEmpty(t, auditFields["auditId"])
|
||||
|
||||
actorFields := auditFields["actor"].(map[string]interface{})
|
||||
assert.Equal(t, "unknown", actorFields["actorUID"])
|
||||
assert.Equal(t, fakeIP, actorFields["ip"])
|
||||
assert.Equal(t, "", actorFields["X-Forwarded-For"])
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("handler, no recording", func(t *testing.T) {
|
||||
t.Run("unary", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
|
||||
configuration := &accessLogConf{}
|
||||
client := &requestclient.Client{IP: fakeIP}
|
||||
|
||||
interceptor := chainUnaryInterceptors(
|
||||
mockClientUnaryInterceptor(client),
|
||||
UnaryServerInterceptor(logger, configuration),
|
||||
)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(ctx context.Context, req any) (any, error) {
|
||||
handlerCalled = true
|
||||
return req, nil
|
||||
}
|
||||
|
||||
req := struct{}{}
|
||||
info := &grpc.UnaryServerInfo{}
|
||||
_, err := interceptor(context.Background(), req, info, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call interceptor: %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
logs := exportLogs()
|
||||
|
||||
// Should have handled but not logged
|
||||
require.Len(t, logs, 1)
|
||||
assert.NotEqual(t, accessEventMessage, logs[0].Message)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
|
||||
configuration := &accessLogConf{}
|
||||
client := &requestclient.Client{IP: fakeIP}
|
||||
|
||||
streamInterceptor := chainStreamInterceptors(
|
||||
mockClientStreamInterceptor(client),
|
||||
StreamServerInterceptor(logger, configuration),
|
||||
)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(srv interface{}, stream grpc.ServerStream) error {
|
||||
handlerCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
srv := struct{}{}
|
||||
ss := &testServerStream{ctx: context.Background()}
|
||||
info := &grpc.StreamServerInfo{}
|
||||
|
||||
err := streamInterceptor(srv, ss, info, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call interceptor: %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
logs := exportLogs()
|
||||
|
||||
// Should have handled but not logged
|
||||
require.Len(t, logs, 1)
|
||||
assert.NotEqual(t, accessEventMessage, logs[0].Message)
|
||||
})
|
||||
|
||||
t.Run("disabled, then enabled", func(t *testing.T) {
|
||||
t.Run("unary", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
|
||||
configuration := &accessLogConf{disabled: true}
|
||||
client := &requestclient.Client{IP: fakeIP}
|
||||
|
||||
interceptor := chainUnaryInterceptors(
|
||||
mockClientUnaryInterceptor(client),
|
||||
UnaryServerInterceptor(logger, configuration),
|
||||
)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(ctx context.Context, req any) (any, error) {
|
||||
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
|
||||
handlerCalled = true
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
req := struct{}{}
|
||||
info := &grpc.UnaryServerInfo{}
|
||||
_, err := interceptor(context.Background(), req, info, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call interceptor: %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
// Disabled, should have been handled but without a log message
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 0)
|
||||
|
||||
// Now we re-enable
|
||||
handlerCalled = false
|
||||
configuration.disabled = false
|
||||
configuration.callback()
|
||||
_, err = interceptor(context.Background(), req, info, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call interceptor: %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
// Enabled, should have handled AND generated a log message
|
||||
logs = exportLogs()
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
})
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
|
||||
configuration := &accessLogConf{disabled: true}
|
||||
client := &requestclient.Client{IP: fakeIP}
|
||||
|
||||
interceptor := chainStreamInterceptors(
|
||||
mockClientStreamInterceptor(client),
|
||||
StreamServerInterceptor(logger, configuration),
|
||||
)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(srv interface{}, stream grpc.ServerStream) error {
|
||||
ctx := stream.Context()
|
||||
|
||||
Record(ctx, fakeRepositoryName, log.String("foo", "bar"))
|
||||
handlerCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
srv := struct{}{}
|
||||
ss := &testServerStream{ctx: context.Background()}
|
||||
info := &grpc.StreamServerInfo{}
|
||||
|
||||
err := interceptor(srv, ss, info, handler)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
// Disabled, should have been handled but without a log message
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 0)
|
||||
|
||||
// Now we re-enable
|
||||
handlerCalled = false
|
||||
configuration.disabled = false
|
||||
configuration.callback()
|
||||
err = interceptor(srv, ss, info, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to call interceptor: %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
// Enabled, should have handled AND generated a log message
|
||||
logs = exportLogs()
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func mockClientUnaryInterceptor(client *requestclient.Client) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
|
||||
ctx = requestclient.WithClient(ctx, client)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func mockClientStreamInterceptor(client *requestclient.Client) grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
ctx := requestclient.WithClient(ss.Context(), client)
|
||||
return handler(srv, &wrappedServerStream{ss, ctx})
|
||||
}
|
||||
}
|
||||
|
||||
func chainUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
if len(interceptors) == 0 {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, req, info, func(ctx context.Context, req any) (any, error) {
|
||||
return chainUnaryInterceptors(interceptors[1:]...)(ctx, req, info, handler)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func chainStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if len(interceptors) == 0 {
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
return interceptors[0](srv, ss, info, func(srv any, ss grpc.ServerStream) error {
|
||||
return chainStreamInterceptors(interceptors[1:]...)(srv, ss, info, handler)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testServerStream is a mock implementation of grpc.ServerStream for testing.
|
||||
type testServerStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *testServerStream) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
var _ grpc.ServerStream = &testServerStream{}
|
||||
@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver/gitdomain"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver/protocol"
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/env"
|
||||
"github.com/sourcegraph/sourcegraph/lib/gitservice"
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
// accesslog provides instrumentation to record logs of access made by a given actor to a repo at
|
||||
// the http handler level.
|
||||
// access logs may optionally (as per site configuration) be included in the audit log.
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/audit"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
type paramsContext struct {
|
||||
repo string
|
||||
metadata []log.Field
|
||||
}
|
||||
|
||||
// Record updates a mutable unexported field stored in the context,
|
||||
// making it available for Middleware to log at the end of the middleware
|
||||
// chain.
|
||||
func Record(ctx context.Context, repo string, meta ...log.Field) {
|
||||
pc := fromContext(ctx)
|
||||
if pc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pc.repo = repo
|
||||
pc.metadata = meta
|
||||
}
|
||||
|
||||
func withContext(ctx context.Context, pc *paramsContext) context.Context {
|
||||
return context.WithValue(ctx, contextKey{}, pc)
|
||||
}
|
||||
|
||||
func fromContext(ctx context.Context) *paramsContext {
|
||||
pc, ok := ctx.Value(contextKey{}).(*paramsContext)
|
||||
if !ok || pc == nil {
|
||||
return nil
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// accessLogger handles HTTP requests and, if logEnabled, logs accesses.
|
||||
type accessLogger struct {
|
||||
logger log.Logger
|
||||
next http.HandlerFunc
|
||||
logEnabled *atomic.Bool
|
||||
}
|
||||
|
||||
var _ http.Handler = &accessLogger{}
|
||||
|
||||
// messages are defined here to make assertions in testing.
|
||||
const (
|
||||
accessEventMessage = "access"
|
||||
accessLoggingEnabledMessage = "access logging enabled"
|
||||
)
|
||||
|
||||
func (a *accessLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Prepare the context to hold the params which the handler is going to set.
|
||||
ctx := r.Context()
|
||||
r = r.WithContext(withContext(ctx, ¶msContext{}))
|
||||
a.next(w, r)
|
||||
|
||||
// If access logging is not enabled, we are done
|
||||
if !a.logEnabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, log this access
|
||||
var fields []log.Field
|
||||
|
||||
// Now we've gone through the handler, we can get the params that the handler
|
||||
// got from the request body.
|
||||
paramsCtx := fromContext(r.Context())
|
||||
if paramsCtx == nil {
|
||||
return
|
||||
}
|
||||
if paramsCtx.repo == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if paramsCtx != nil {
|
||||
params := append([]log.Field{log.String("repo", paramsCtx.repo)}, paramsCtx.metadata...)
|
||||
fields = append(fields, log.Object("params", params...))
|
||||
} else {
|
||||
fields = append(fields, log.String("params", "nil"))
|
||||
}
|
||||
|
||||
audit.Log(ctx, a.logger, audit.Record{
|
||||
Entity: "gitserver",
|
||||
Action: "access",
|
||||
Fields: fields,
|
||||
})
|
||||
}
|
||||
|
||||
// HTTPMiddleware will extract actor information and params collected by Record that has
|
||||
// been stored in the context, in order to log a trace of the access.
|
||||
func HTTPMiddleware(logger log.Logger, watcher conftypes.WatchableSiteConfig, next http.HandlerFunc) http.HandlerFunc {
|
||||
handler := &accessLogger{
|
||||
logger: logger,
|
||||
next: next,
|
||||
logEnabled: atomic.NewBool(audit.IsEnabled(watcher.SiteConfig(), audit.GitserverAccess)),
|
||||
}
|
||||
if handler.logEnabled.Load() {
|
||||
logger.Info(accessLoggingEnabledMessage)
|
||||
}
|
||||
|
||||
// Allow live toggling of access logging
|
||||
watcher.Watch(func() {
|
||||
newShouldLog := audit.IsEnabled(watcher.SiteConfig(), audit.GitserverAccess)
|
||||
changed := handler.logEnabled.Swap(newShouldLog) != newShouldLog
|
||||
if changed {
|
||||
if newShouldLog {
|
||||
logger.Info(accessLoggingEnabledMessage)
|
||||
} else {
|
||||
logger.Info("access logging disabled")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return handler.ServeHTTP
|
||||
}
|
||||
@ -1,145 +0,0 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
"github.com/sourcegraph/sourcegraph/internal/requestclient"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = withContext(ctx, ¶msContext{})
|
||||
|
||||
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
|
||||
|
||||
Record(ctx, "github.com/foo/bar", meta...)
|
||||
|
||||
pc := fromContext(ctx)
|
||||
require.NotNil(t, pc)
|
||||
assert.Equal(t, "github.com/foo/bar", pc.repo)
|
||||
assert.Equal(t, meta, pc.metadata)
|
||||
})
|
||||
|
||||
t.Run("OK not initialized context", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
meta := []log.Field{log.String("cmd", "git"), log.String("args", "grep foo")}
|
||||
|
||||
Record(ctx, "github.com/foo/bar", meta...)
|
||||
pc := fromContext(ctx)
|
||||
assert.Nil(t, pc)
|
||||
})
|
||||
}
|
||||
|
||||
type accessLogConf struct {
|
||||
disabled bool
|
||||
callback func()
|
||||
}
|
||||
|
||||
var _ conftypes.WatchableSiteConfig = &accessLogConf{}
|
||||
|
||||
func (a *accessLogConf) Watch(cb func()) { a.callback = cb }
|
||||
func (a *accessLogConf) SiteConfig() schema.SiteConfiguration {
|
||||
return schema.SiteConfiguration{
|
||||
Log: &schema.Log{
|
||||
AuditLog: &schema.AuditLog{
|
||||
GitserverAccess: !a.disabled,
|
||||
GraphQL: false,
|
||||
InternalTraffic: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMiddleware(t *testing.T) {
|
||||
t.Run("OK for access log setting", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
|
||||
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := req.Context()
|
||||
ctx = requestclient.WithClient(ctx, &requestclient.Client{IP: "192.168.1.1"})
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
h.ServeHTTP(rec, req)
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
assert.Equal(t, "github.com/foo/bar", logs[1].Fields["params"].(map[string]any)["repo"])
|
||||
|
||||
auditFields := logs[1].Fields["audit"].(map[string]interface{})
|
||||
assert.Equal(t, "gitserver", auditFields["entity"])
|
||||
assert.NotEmpty(t, auditFields["auditId"])
|
||||
|
||||
actorFields := auditFields["actor"].(map[string]interface{})
|
||||
assert.Equal(t, "unknown", actorFields["actorUID"])
|
||||
assert.Equal(t, "192.168.1.1", actorFields["ip"])
|
||||
assert.Equal(t, "", actorFields["X-Forwarded-For"])
|
||||
})
|
||||
|
||||
t.Run("handle, no recording", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
var handled bool
|
||||
h := HTTPMiddleware(logger, &accessLogConf{}, func(w http.ResponseWriter, r *http.Request) {
|
||||
handled = true
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
// Should have handled but not logged
|
||||
assert.True(t, handled)
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 1)
|
||||
assert.NotEqual(t, accessEventMessage, logs[0].Message)
|
||||
})
|
||||
|
||||
t.Run("disabled, then enabled", func(t *testing.T) {
|
||||
logger, exportLogs := logtest.Captured(t)
|
||||
cfg := &accessLogConf{disabled: true}
|
||||
var handled bool
|
||||
h := HTTPMiddleware(logger, cfg, func(w http.ResponseWriter, r *http.Request) {
|
||||
Record(r.Context(), "github.com/foo/bar", log.String("cmd", "git"), log.String("args", "grep foo"))
|
||||
handled = true
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Request with access logging disabled
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
// Disabled, should have been handled but without a log message
|
||||
assert.True(t, handled)
|
||||
logs := exportLogs()
|
||||
require.Len(t, logs, 0)
|
||||
|
||||
// Now we re-enable
|
||||
handled = false
|
||||
cfg.disabled = false
|
||||
cfg.callback()
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
// Enabled, should have handled AND generated a log message
|
||||
assert.True(t, handled)
|
||||
logs = exportLogs()
|
||||
require.Len(t, logs, 2)
|
||||
assert.Equal(t, accessLoggingEnabledMessage, logs[0].Message)
|
||||
assert.Contains(t, logs[1].Message, accessEventMessage)
|
||||
})
|
||||
}
|
||||
@ -37,7 +37,7 @@ import (
|
||||
"github.com/sourcegraph/conc"
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/internal/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
@ -1643,7 +1643,7 @@ func (s *Server) handleExec(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log which which actor is accessing the repo.
|
||||
// Log which actor is accessing the repo.
|
||||
args := req.Args
|
||||
cmd := ""
|
||||
if len(req.Args) > 0 {
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver/gitdomain"
|
||||
@ -38,18 +39,30 @@ func (gs *GRPCServer) Exec(req *proto.ExecRequest, ss proto.GitserverService_Exe
|
||||
})
|
||||
})
|
||||
|
||||
// Log which actor is accessing the repo.
|
||||
args := req.GetArgs()
|
||||
cmd := ""
|
||||
if len(args) > 0 {
|
||||
cmd = args[0]
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
accesslog.Record(ss.Context(), req.GetRepo(),
|
||||
log.String("cmd", cmd),
|
||||
log.Strings("args", args),
|
||||
)
|
||||
|
||||
// TODO(mucles): set user agent from all grpc clients
|
||||
return gs.doExec(ss.Context(), gs.Server.Logger, &internalReq, "unknown-grpc-client", w)
|
||||
}
|
||||
|
||||
func (gs *GRPCServer) Archive(req *proto.ArchiveRequest, ss proto.GitserverService_ArchiveServer) error {
|
||||
//TODO(mucles): re-enable access logging (see server.go handleArchive)
|
||||
// Log which which actor is accessing the repo.
|
||||
// accesslog.Record(ctx, req.Repo,
|
||||
// log.String("treeish", req.Treeish),
|
||||
// log.String("format", req.Format),
|
||||
// log.Strings("path", req.Pathspecs),
|
||||
// )
|
||||
accesslog.Record(ss.Context(), req.Repo,
|
||||
log.String("treeish", req.Treeish),
|
||||
log.String("format", req.Format),
|
||||
log.Strings("path", req.Pathspecs),
|
||||
)
|
||||
|
||||
if err := checkSpecArgSafety(req.GetTreeish()); err != nil {
|
||||
return status.Error(codes.InvalidArgument, err.Error())
|
||||
|
||||
4
cmd/gitserver/shared/BUILD.bazel
generated
4
cmd/gitserver/shared/BUILD.bazel
generated
@ -11,6 +11,7 @@ go_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//cmd/gitserver/server",
|
||||
"//cmd/gitserver/server/accesslog",
|
||||
"//internal/actor",
|
||||
"//internal/api",
|
||||
"//internal/authz",
|
||||
@ -46,6 +47,7 @@ go_library(
|
||||
"//schema",
|
||||
"@com_github_json_iterator_go//:go",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_x_sync//semaphore",
|
||||
"@org_golang_x_time//rate",
|
||||
],
|
||||
@ -66,7 +68,9 @@ go_test(
|
||||
"//internal/database",
|
||||
"//internal/extsvc",
|
||||
"//internal/types",
|
||||
"@com_github_google_go_cmp//cmp",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@ -17,8 +17,10 @@ import (
|
||||
"github.com/sourcegraph/log"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"golang.org/x/time/rate"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server/accesslog"
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/authz"
|
||||
@ -148,7 +150,25 @@ func Main(ctx context.Context, observationCtx *observation.Context, ready servic
|
||||
GlobalBatchLogSemaphore: semaphore.NewWeighted(int64(batchLogGlobalConcurrencyLimit)),
|
||||
}
|
||||
|
||||
grpcServer := defaults.NewServer(logger)
|
||||
configurationWatcher := conf.DefaultClient()
|
||||
|
||||
var additionalServerOptions []grpc.ServerOption
|
||||
|
||||
for method, scopedLogger := range map[string]log.Logger{
|
||||
proto.GitserverService_Exec_FullMethodName: logger.Scoped("exec.accesslog", "exec endpoint access log"),
|
||||
proto.GitserverService_Archive_FullMethodName: logger.Scoped("archive.accesslog", "archive endpoint access log"),
|
||||
} {
|
||||
streamInterceptor := accesslog.StreamServerInterceptor(scopedLogger, configurationWatcher)
|
||||
unaryInterceptor := accesslog.UnaryServerInterceptor(scopedLogger, configurationWatcher)
|
||||
|
||||
additionalServerOptions = append(additionalServerOptions,
|
||||
grpc.ChainStreamInterceptor(methodSpecificStreamInterceptor(method, streamInterceptor)),
|
||||
grpc.ChainUnaryInterceptor(methodSpecificUnaryInterceptor(method, unaryInterceptor)),
|
||||
)
|
||||
}
|
||||
|
||||
grpcServer := defaults.NewServer(logger, additionalServerOptions...)
|
||||
|
||||
proto.RegisterGitserverServiceServer(grpcServer, &server.GRPCServer{
|
||||
Server: &gitserver,
|
||||
})
|
||||
@ -538,3 +558,29 @@ func getAddr() string {
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// methodSpecificStreamInterceptor returns a gRPC stream server interceptor that only calls the next interceptor if the method matches.
|
||||
//
|
||||
// The returned interceptor will call next if the invoked gRPC method matches the method parameter. Otherwise, it will call handler directly.
|
||||
func methodSpecificStreamInterceptor(method string, next grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if method != info.FullMethod {
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
return next(srv, ss, info, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// methodSpecificUnaryInterceptor returns a gRPC unary server interceptor that only calls the next interceptor if the method matches.
|
||||
//
|
||||
// The returned interceptor will call next if the invoked gRPC method matches the method parameter. Otherwise, it will call handler directly.
|
||||
func methodSpecificUnaryInterceptor(method string, next grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
if method != info.FullMethod {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
return next(ctx, req, info, handler)
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,8 +8,10 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/sourcegraph/log"
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/gitserver/server"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
@ -112,3 +114,118 @@ func TestGetVCSSyncer(t *testing.T) {
|
||||
t.Fatalf("Want *server.PerforceDepotSyncer, got %T", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodSpecificStreamInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
matchedMethod string
|
||||
testMethod string
|
||||
|
||||
expectedInterceptorCalled bool
|
||||
}{
|
||||
{
|
||||
name: "allowed method",
|
||||
|
||||
matchedMethod: "allowedMethod",
|
||||
testMethod: "allowedMethod",
|
||||
|
||||
expectedInterceptorCalled: true,
|
||||
},
|
||||
{
|
||||
name: "not allowed method",
|
||||
|
||||
matchedMethod: "allowedMethod",
|
||||
testMethod: "otherMethod",
|
||||
|
||||
expectedInterceptorCalled: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
interceptorCalled := false
|
||||
interceptor := methodSpecificStreamInterceptor(test.matchedMethod, func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
interceptorCalled = true
|
||||
return handler(srv, ss)
|
||||
})
|
||||
|
||||
handlerCalled := false
|
||||
noopHandler := func(srv any, ss grpc.ServerStream) error {
|
||||
handlerCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
err := interceptor(nil, nil, &grpc.StreamServerInfo{FullMethod: test.testMethod}, noopHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("expected handler to be called")
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(test.expectedInterceptorCalled, interceptorCalled); diff != "" {
|
||||
t.Fatalf("unexpected interceptor called value (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodSpecificUnaryInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
matchedMethod string
|
||||
testMethod string
|
||||
|
||||
expectedInterceptorCalled bool
|
||||
}{
|
||||
{
|
||||
name: "allowed method",
|
||||
|
||||
matchedMethod: "allowedMethod",
|
||||
testMethod: "allowedMethod",
|
||||
|
||||
expectedInterceptorCalled: true,
|
||||
},
|
||||
{
|
||||
name: "not allowed method",
|
||||
|
||||
matchedMethod: "allowedMethod",
|
||||
testMethod: "otherMethod",
|
||||
|
||||
expectedInterceptorCalled: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
interceptorCalled := false
|
||||
interceptor := methodSpecificUnaryInterceptor(test.matchedMethod, func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
interceptorCalled = true
|
||||
return handler(ctx, req)
|
||||
})
|
||||
|
||||
handlerCalled := false
|
||||
noopHandler := func(ctx context.Context, req any) (any, error) {
|
||||
handlerCalled = true
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: test.testMethod}, noopHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("expected handler to be called")
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(test.expectedInterceptorCalled, interceptorCalled); diff != "" {
|
||||
t.Fatalf("unexpected interceptor called value (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
8
deps.bzl
8
deps.bzl
@ -8978,14 +8978,6 @@ def go_dependencies():
|
||||
sum = "h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc=",
|
||||
version = "v1.53.0",
|
||||
)
|
||||
go_repository(
|
||||
name = "org_golang_google_grpc_cmd_protoc_gen_go_grpc",
|
||||
build_file_proto_mode = "disable_global",
|
||||
importpath = "google.golang.org/grpc/cmd/protoc-gen-go-grpc",
|
||||
sum = "h1:TLkBREm4nIsEcexnCjgQd5GQWaHcqMzwQV0TX9pq8S0=",
|
||||
version = "v1.2.0",
|
||||
) # keep
|
||||
|
||||
go_repository(
|
||||
name = "org_golang_google_grpc_examples",
|
||||
build_file_proto_mode = "disable_global",
|
||||
|
||||
5
internal/gitserver/v1/BUILD.bazel
generated
5
internal/gitserver/v1/BUILD.bazel
generated
@ -14,7 +14,10 @@ proto_library(
|
||||
|
||||
go_proto_library(
|
||||
name = "v1_go_proto",
|
||||
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
|
||||
compilers = [
|
||||
"//:gen-go-grpc",
|
||||
"@io_bazel_rules_go//proto:go_proto",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/gitserver/v1",
|
||||
proto = ":v1_proto",
|
||||
visibility = ["//visibility:private"],
|
||||
|
||||
2
internal/grpc/BUILD.bazel
generated
2
internal/grpc/BUILD.bazel
generated
@ -5,7 +5,6 @@ go_library(
|
||||
srcs = [
|
||||
"grpc.go",
|
||||
"panics.go",
|
||||
"propagator.go",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc",
|
||||
visibility = ["//:__subpackages__"],
|
||||
@ -14,7 +13,6 @@ go_library(
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//codes",
|
||||
"@org_golang_google_grpc//metadata",
|
||||
"@org_golang_google_grpc//status",
|
||||
"@org_golang_x_net//http2",
|
||||
"@org_golang_x_net//http2/h2c",
|
||||
|
||||
2
internal/grpc/defaults/BUILD.bazel
generated
2
internal/grpc/defaults/BUILD.bazel
generated
@ -12,6 +12,8 @@ go_library(
|
||||
"//internal/actor",
|
||||
"//internal/env",
|
||||
"//internal/grpc",
|
||||
"//internal/grpc/propagator",
|
||||
"//internal/requestclient",
|
||||
"//internal/trace/policy",
|
||||
"//internal/ttlcache",
|
||||
"//lib/errors",
|
||||
|
||||
@ -20,6 +20,8 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/env"
|
||||
internalgrpc "github.com/sourcegraph/sourcegraph/internal/grpc"
|
||||
"github.com/sourcegraph/sourcegraph/internal/grpc/propagator"
|
||||
"github.com/sourcegraph/sourcegraph/internal/requestclient"
|
||||
"github.com/sourcegraph/sourcegraph/internal/trace/policy"
|
||||
)
|
||||
|
||||
@ -47,14 +49,16 @@ func DialOptions() []grpc.DialOption {
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithChainStreamInterceptor(
|
||||
grpc_prometheus.StreamClientInterceptor(metrics),
|
||||
internalgrpc.StreamClientPropagator(actor.ActorPropagator{}),
|
||||
internalgrpc.StreamClientPropagator(policy.ShouldTracePropagator{}),
|
||||
propagator.StreamClientPropagator(actor.ActorPropagator{}),
|
||||
propagator.StreamClientPropagator(policy.ShouldTracePropagator{}),
|
||||
propagator.StreamClientPropagator(requestclient.Propagator{}),
|
||||
otelStreamInterceptor,
|
||||
),
|
||||
grpc.WithChainUnaryInterceptor(
|
||||
grpc_prometheus.UnaryClientInterceptor(metrics),
|
||||
internalgrpc.UnaryClientPropagator(actor.ActorPropagator{}),
|
||||
internalgrpc.UnaryClientPropagator(policy.ShouldTracePropagator{}),
|
||||
propagator.UnaryClientPropagator(actor.ActorPropagator{}),
|
||||
propagator.UnaryClientPropagator(policy.ShouldTracePropagator{}),
|
||||
propagator.UnaryClientPropagator(requestclient.Propagator{}),
|
||||
otelUnaryInterceptor,
|
||||
),
|
||||
}
|
||||
@ -87,15 +91,17 @@ func ServerOptions(logger log.Logger) []grpc.ServerOption {
|
||||
grpc.ChainStreamInterceptor(
|
||||
internalgrpc.NewStreamPanicCatcher(logger),
|
||||
grpc_prometheus.StreamServerInterceptor(metrics),
|
||||
internalgrpc.StreamServerPropagator(actor.ActorPropagator{}),
|
||||
internalgrpc.StreamServerPropagator(policy.ShouldTracePropagator{}),
|
||||
propagator.StreamServerPropagator(requestclient.Propagator{}),
|
||||
propagator.StreamServerPropagator(actor.ActorPropagator{}),
|
||||
propagator.StreamServerPropagator(policy.ShouldTracePropagator{}),
|
||||
otelgrpc.StreamServerInterceptor(),
|
||||
),
|
||||
grpc.ChainUnaryInterceptor(
|
||||
internalgrpc.NewUnaryPanicCatcher(logger),
|
||||
grpc_prometheus.UnaryServerInterceptor(metrics),
|
||||
internalgrpc.UnaryServerPropagator(actor.ActorPropagator{}),
|
||||
internalgrpc.UnaryServerPropagator(policy.ShouldTracePropagator{}),
|
||||
propagator.UnaryServerPropagator(requestclient.Propagator{}),
|
||||
propagator.UnaryServerPropagator(actor.ActorPropagator{}),
|
||||
propagator.UnaryServerPropagator(policy.ShouldTracePropagator{}),
|
||||
otelgrpc.UnaryServerInterceptor(),
|
||||
),
|
||||
}
|
||||
|
||||
12
internal/grpc/propagator/BUILD.bazel
generated
Normal file
12
internal/grpc/propagator/BUILD.bazel
generated
Normal file
@ -0,0 +1,12 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "propagator",
|
||||
srcs = ["propagator.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc/propagator",
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = [
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//metadata",
|
||||
],
|
||||
)
|
||||
@ -1,4 +1,4 @@
|
||||
package grpc
|
||||
package propagator
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -86,7 +86,7 @@ func StreamServerPropagator(prop Propagator) grpc.StreamServerInterceptor {
|
||||
}
|
||||
}
|
||||
|
||||
// StreamServerPropagator returns an interceptor that will use the given propagator
|
||||
// UnaryServerPropagator returns an interceptor that will use the given propagator
|
||||
// to translate some metadata back into the context for the RPC handler. The client
|
||||
// should be configured with an interceptor that uses the same propagator.
|
||||
func UnaryServerPropagator(prop Propagator) grpc.UnaryServerInterceptor {
|
||||
5
internal/repoupdater/v1/BUILD.bazel
generated
5
internal/repoupdater/v1/BUILD.bazel
generated
@ -14,7 +14,10 @@ proto_library(
|
||||
|
||||
go_proto_library(
|
||||
name = "v1_go_proto",
|
||||
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
|
||||
compilers = [
|
||||
"//:gen-go-grpc",
|
||||
"@io_bazel_rules_go//proto:go_proto",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/repoupdater/v1",
|
||||
proto = ":v1_proto",
|
||||
visibility = ["//visibility:private"],
|
||||
|
||||
18
internal/requestclient/BUILD.bazel
generated
18
internal/requestclient/BUILD.bazel
generated
@ -1,11 +1,27 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "requestclient",
|
||||
srcs = [
|
||||
"client.go",
|
||||
"grpc.go",
|
||||
"http.go",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/requestclient",
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = [
|
||||
"//internal/grpc/propagator",
|
||||
"@org_golang_google_grpc//metadata",
|
||||
"@org_golang_google_grpc//peer",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "requestclient_test",
|
||||
srcs = ["grpc_test.go"],
|
||||
embed = [":requestclient"],
|
||||
deps = [
|
||||
"@com_github_google_go_cmp//cmp",
|
||||
"@org_golang_google_grpc//peer",
|
||||
],
|
||||
)
|
||||
|
||||
70
internal/requestclient/grpc.go
Normal file
70
internal/requestclient/grpc.go
Normal file
@ -0,0 +1,70 @@
|
||||
package requestclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
|
||||
internalgrpc "github.com/sourcegraph/sourcegraph/internal/grpc/propagator"
|
||||
)
|
||||
|
||||
// Propagator is a github.com/sourcegraph/sourcegraph/internal/grpc/propagator.Propagator that Propagates
|
||||
// the Client in the context across the gRPC client / server request boundary.
|
||||
//
|
||||
// If the context does not contain a Client, the server will backfill the Client's IP with the IP of the address
|
||||
// that the request came from. (see https://pkg.go.dev/google.golang.org/grpc/peer for more information)
|
||||
type Propagator struct{}
|
||||
|
||||
func (Propagator) FromContext(ctx context.Context) metadata.MD {
|
||||
client := FromContext(ctx)
|
||||
if client == nil {
|
||||
return metadata.New(nil)
|
||||
}
|
||||
|
||||
return metadata.Pairs(
|
||||
headerKeyClientIP, client.IP,
|
||||
headerKeyForwardedFor, client.ForwardedFor,
|
||||
)
|
||||
}
|
||||
|
||||
func (Propagator) InjectContext(ctx context.Context, md metadata.MD) context.Context {
|
||||
var ip string
|
||||
var forwardedFor string
|
||||
|
||||
if vals := md.Get(headerKeyClientIP); len(vals) > 0 {
|
||||
ip = vals[0]
|
||||
}
|
||||
|
||||
if vals := md.Get(headerKeyForwardedFor); len(vals) > 0 {
|
||||
forwardedFor = vals[0]
|
||||
}
|
||||
|
||||
if ip == "" {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if ok && p != nil {
|
||||
ip = baseIP(p.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
c := Client{
|
||||
IP: ip,
|
||||
ForwardedFor: forwardedFor,
|
||||
}
|
||||
return WithClient(ctx, &c)
|
||||
}
|
||||
|
||||
var _ internalgrpc.Propagator = Propagator{}
|
||||
|
||||
// baseIP returns the base IP address of the given net.Addr
|
||||
func baseIP(addr net.Addr) string {
|
||||
switch a := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return a.IP.String()
|
||||
case *net.UDPAddr:
|
||||
return a.IP.String()
|
||||
default:
|
||||
return addr.String()
|
||||
}
|
||||
}
|
||||
157
internal/requestclient/grpc_test.go
Normal file
157
internal/requestclient/grpc_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package requestclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
func TestPropagator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
requestClient *Client
|
||||
requestPeer *peer.Peer
|
||||
|
||||
wantClient *Client
|
||||
}{
|
||||
{
|
||||
name: "no client or peer",
|
||||
|
||||
wantClient: &Client{},
|
||||
},
|
||||
|
||||
{
|
||||
name: "client with no peer",
|
||||
requestClient: &Client{
|
||||
IP: "192.168.1.1",
|
||||
ForwardedFor: "192.168.1.2",
|
||||
},
|
||||
|
||||
wantClient: &Client{
|
||||
IP: "192.168.1.1",
|
||||
ForwardedFor: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "peer only (nil client)",
|
||||
requestPeer: &peer.Peer{
|
||||
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.1")},
|
||||
},
|
||||
|
||||
wantClient: &Client{
|
||||
IP: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer only (non-nil empty client)",
|
||||
|
||||
requestClient: &Client{},
|
||||
requestPeer: &peer.Peer{
|
||||
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.1")},
|
||||
},
|
||||
|
||||
wantClient: &Client{
|
||||
IP: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "client should override peer",
|
||||
|
||||
requestClient: &Client{
|
||||
IP: "192.168.1.1",
|
||||
ForwardedFor: "192.168.1.2",
|
||||
},
|
||||
requestPeer: &peer.Peer{
|
||||
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.3")},
|
||||
},
|
||||
|
||||
wantClient: &Client{
|
||||
IP: "192.168.1.1",
|
||||
ForwardedFor: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "client for ForwardedFor, peer for IP",
|
||||
|
||||
requestClient: &Client{
|
||||
ForwardedFor: "192.168.1.2",
|
||||
},
|
||||
requestPeer: &peer.Peer{
|
||||
Addr: &net.IPAddr{IP: net.ParseIP("192.168.1.3")},
|
||||
},
|
||||
|
||||
wantClient: &Client{
|
||||
IP: "192.168.1.3",
|
||||
ForwardedFor: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
requestCtx := context.Background()
|
||||
if test.requestClient != nil {
|
||||
requestCtx = WithClient(requestCtx, test.requestClient)
|
||||
}
|
||||
|
||||
if test.requestPeer != nil {
|
||||
requestCtx = peer.NewContext(requestCtx, test.requestPeer)
|
||||
}
|
||||
|
||||
propagator := &Propagator{}
|
||||
md := propagator.FromContext(requestCtx)
|
||||
|
||||
resultCtx := propagator.InjectContext(requestCtx, md)
|
||||
if diff := cmp.Diff(test.wantClient, FromContext(resultCtx)); diff != "" {
|
||||
t.Errorf("Client mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "TCP address",
|
||||
addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.127.2"),
|
||||
Port: 448,
|
||||
},
|
||||
want: "127.0.127.2",
|
||||
},
|
||||
{
|
||||
name: "UDP address",
|
||||
addr: &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 448,
|
||||
},
|
||||
want: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "Other address",
|
||||
addr: &net.UnixAddr{
|
||||
Name: "foobar",
|
||||
},
|
||||
want: "foobar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := baseIP(tt.addr); got != tt.want {
|
||||
t.Errorf("baseIP() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
5
internal/searcher/v1/BUILD.bazel
generated
5
internal/searcher/v1/BUILD.bazel
generated
@ -16,7 +16,10 @@ proto_library(
|
||||
|
||||
go_proto_library(
|
||||
name = "v1_go_proto",
|
||||
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
|
||||
compilers = [
|
||||
"//:gen-go-grpc",
|
||||
"@io_bazel_rules_go//proto:go_proto",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/searcher/v1",
|
||||
proto = ":v1_proto",
|
||||
visibility = ["//visibility:private"],
|
||||
|
||||
5
internal/symbols/v1/BUILD.bazel
generated
5
internal/symbols/v1/BUILD.bazel
generated
@ -16,7 +16,10 @@ proto_library(
|
||||
|
||||
go_proto_library(
|
||||
name = "v1_go_proto",
|
||||
compilers = ["@io_bazel_rules_go//proto:go_grpc"],
|
||||
compilers = [
|
||||
"//:gen-go-grpc",
|
||||
"@io_bazel_rules_go//proto:go_proto",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/symbols/v1",
|
||||
proto = ":v1_proto",
|
||||
visibility = ["//visibility:private"],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user