diff --git a/cmd/telemetry-gateway/internal/events/BUILD.bazel b/cmd/telemetry-gateway/internal/events/BUILD.bazel index b9da5e6ada4..df0af3981ba 100644 --- a/cmd/telemetry-gateway/internal/events/BUILD.bazel +++ b/cmd/telemetry-gateway/internal/events/BUILD.bazel @@ -38,5 +38,6 @@ go_test( "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", "@org_golang_google_protobuf//types/known/structpb", + "@org_golang_google_protobuf//types/known/timestamppb", ], ) diff --git a/cmd/telemetry-gateway/internal/events/attributes.go b/cmd/telemetry-gateway/internal/events/attributes.go index 04498293386..2a1233f8eca 100644 --- a/cmd/telemetry-gateway/internal/events/attributes.go +++ b/cmd/telemetry-gateway/internal/events/attributes.go @@ -10,10 +10,11 @@ import ( // extractPubSubAttributes extracts attributes from the event for use in the // published pub/sub message as attributes. This makes it easiser to build // routing of events in our data pipelines. -func extractPubSubAttributes(event *telemetrygatewayv1.Event) map[string]string { +func extractPubSubAttributes(publisherSource string, event *telemetrygatewayv1.Event) map[string]string { attributes := map[string]string{ - "event.feature": event.Feature, - "event.action": event.Action, + "publisher.source": publisherSource, + "event.feature": event.Feature, + "event.action": event.Action, "event.hasPrivateMetadata": strconv.FormatBool( event.GetParameters().GetPrivateMetadata() != nil), } diff --git a/cmd/telemetry-gateway/internal/events/attributes_test.go b/cmd/telemetry-gateway/internal/events/attributes_test.go index 42df8c4e029..d4ca1e3b285 100644 --- a/cmd/telemetry-gateway/internal/events/attributes_test.go +++ b/cmd/telemetry-gateway/internal/events/attributes_test.go @@ -25,6 +25,7 @@ func TestExtractPubSubAttributes(t *testing.T) { expect: autogold.Expect(map[string]string{ "event.action": "chat", "event.feature": "cody.feature", "event.hasPrivateMetadata": "false", + 
"publisher.source": "licensed_instance", }), }, { @@ -41,6 +42,7 @@ func TestExtractPubSubAttributes(t *testing.T) { expect: autogold.Expect(map[string]string{ "event.action": "chat", "event.feature": "cody.feature", "event.hasPrivateMetadata": "true", + "publisher.source": "licensed_instance", }), }, { @@ -63,6 +65,7 @@ func TestExtractPubSubAttributes(t *testing.T) { "event.action": "chat", "event.feature": "cody.feature", "event.hasPrivateMetadata": "true", "event.recordsPrivateMetadataTranscript": "true", + "publisher.source": "licensed_instance", }), }, { @@ -80,11 +83,12 @@ func TestExtractPubSubAttributes(t *testing.T) { expect: autogold.Expect(map[string]string{ "event.action": "chat", "event.feature": "cody.feature", "event.hasPrivateMetadata": "false", + "publisher.source": "licensed_instance", }), }, } { t.Run(tc.name, func(t *testing.T) { - tc.expect.Equal(t, extractPubSubAttributes(tc.event)) + tc.expect.Equal(t, extractPubSubAttributes("licensed_instance", tc.event)) }) } } diff --git a/cmd/telemetry-gateway/internal/events/events.go b/cmd/telemetry-gateway/internal/events/events.go index 0fe3d601220..37d681d6276 100644 --- a/cmd/telemetry-gateway/internal/events/events.go +++ b/cmd/telemetry-gateway/internal/events/events.go @@ -21,6 +21,7 @@ import ( type Publisher struct { logger log.Logger + source string topic pubsub.TopicPublisher opts PublishStreamOptions @@ -50,14 +51,39 @@ func NewPublisherForStream( if opts.ConcurrencyLimit <= 0 { opts.ConcurrencyLimit = 250 } + + var source string + switch identifier := metadata.GetIdentifier(); identifier.GetIdentifier().(type) { + case *telemetrygatewayv1.Identifier_LicensedInstance: + source = "licensed_instance" + case *telemetrygatewayv1.Identifier_UnlicensedInstance: + source = "unlicensed_instance" + case *telemetrygatewayv1.Identifier_ManagedService: + // Is a trusted client, so use the service ID directly as the source + source = identifier.GetManagedService().ServiceId + default: + source = 
"unknown" + } + return &Publisher{ - logger: logger, + logger: logger.With(log.String("source", source)), + source: source, topic: eventsTopic, opts: opts, metadataJSON: metadataJSON, }, nil } +// GetSourceName returns a name inferred from metadata provided to +// NewPublisherForStream, for use as a metric label. It is safe to call on a nil +// publisher. +func (p *Publisher) GetSourceName() string { + if p == nil { + return "invalid" + } + return p.source +} + type PublishEventResult struct { // EventID is the ID of the event that was published. EventID string @@ -74,6 +100,21 @@ func (p *Publisher) Publish(ctx context.Context, events []*telemetrygatewayv1.Ev event := event // capture range variable :( doPublish := func(event *telemetrygatewayv1.Event) error { + // Ensure the most important fields are in place + if event.Id == "" { + return errors.New("event ID is required") + } + if event.Feature == "" { + return errors.New("event feature is required") + } + if event.Action == "" { + return errors.New("event action is required") + } + if event.Timestamp == nil { + return errors.New("event timestamp is required") + } + + // Render JSON format for publishing eventJSON, err := protojson.Marshal(event) if err != nil { return errors.Wrap(err, "marshalling event") @@ -120,7 +161,7 @@ func (p *Publisher) Publish(ctx context.Context, events []*telemetrygatewayv1.Ev // Publish a single message in each callback to manage concurrency // ourselves, and attach attributes for ease of routing the pub/sub // message. - if err := p.topic.PublishMessage(ctx, payload, extractPubSubAttributes(event)); err != nil { + if err := p.topic.PublishMessage(ctx, payload, extractPubSubAttributes(p.source, event)); err != nil { // Try to record the cancel cause as the primary error in case // one is recorded. 
if cancelCause := context.Cause(ctx); cancelCause != nil { diff --git a/cmd/telemetry-gateway/internal/events/events_test.go b/cmd/telemetry-gateway/internal/events/events_test.go index 777b9a36d25..ce05a7a1917 100644 --- a/cmd/telemetry-gateway/internal/events/events_test.go +++ b/cmd/telemetry-gateway/internal/events/events_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/sourcegraph/log/logtest" @@ -35,18 +36,28 @@ func TestPublish(t *testing.T) { publisher, err := events.NewPublisherForStream( logtest.Scoped(t), memTopic, - &telemetrygatewayv1.RecordEventsRequestMetadata{}, + &telemetrygatewayv1.RecordEventsRequestMetadata{ + Identifier: &telemetrygatewayv1.Identifier{ + Identifier: &telemetrygatewayv1.Identifier_LicensedInstance{ + LicensedInstance: &telemetrygatewayv1.Identifier_LicensedInstanceIdentifier{}, + }, + }, + }, events.PublishStreamOptions{ ConcurrencyLimit: concurrency, }) require.NoError(t, err) + // Check the evaluated source + assert.Equal(t, "licensed_instance", publisher.GetSourceName()) + events := make([]*telemetrygatewayv1.Event, concurrency) for i := range events { events[i] = &telemetrygatewayv1.Event{ - Id: strconv.Itoa(i), - Feature: t.Name(), - Action: strconv.Itoa(i), + Id: strconv.Itoa(i), + Feature: t.Name(), + Action: strconv.Itoa(i), + Timestamp: timestamppb.Now(), } } @@ -77,9 +88,13 @@ func TestPublish(t *testing.T) { var payload map[string]json.RawMessage require.NoError(t, json.Unmarshal(m.Data, &payload)) - var event telemetrygatewayv1.Event + var event struct { + Id string + Feature string + Action string + } require.NoError(t, json.Unmarshal(payload["event"], &event)) - publishedEvents[event.GetId()] = true + publishedEvents[event.Id] = true assert.Equal(t, event.Feature, m.Attributes["event.feature"]) assert.Equal(t, event.Action, m.Attributes["event.action"]) diff --git 
a/cmd/telemetry-gateway/internal/server/BUILD.bazel b/cmd/telemetry-gateway/internal/server/BUILD.bazel index 96e324dd269..c56688c0135 100644 --- a/cmd/telemetry-gateway/internal/server/BUILD.bazel +++ b/cmd/telemetry-gateway/internal/server/BUILD.bazel @@ -13,8 +13,10 @@ go_library( visibility = ["//cmd/telemetry-gateway:__subpackages__"], deps = [ "//cmd/telemetry-gateway/internal/events", + "//cmd/telemetry-gateway/internal/server/samsm2m", "//internal/licensing", "//internal/pubsub", + "//internal/sams", "//internal/telemetrygateway/v1:telemetrygateway", "//internal/trace", "//lib/errors", diff --git a/cmd/telemetry-gateway/internal/server/metrics.go b/cmd/telemetry-gateway/internal/server/metrics.go index 8c5422acbcc..6b7bb189ca3 100644 --- a/cmd/telemetry-gateway/internal/server/metrics.go +++ b/cmd/telemetry-gateway/internal/server/metrics.go @@ -44,3 +44,19 @@ func newRecordEventsMetrics() (m recordEventsMetrics, err error) { return m, err } + +type recordEventMetrics struct { + // Count of processed events + processedEvents metric.Int64Counter +} + +func newRecordEventMetrics() (m recordEventMetrics, err error) { + m.processedEvents, err = meter.Int64Counter( + "telemetry-gateway.record_event.processed_events", + metric.WithDescription("Number of events processed")) + if err != nil { + return m, err + } + + return m, nil +} diff --git a/cmd/telemetry-gateway/internal/server/publish_events.go b/cmd/telemetry-gateway/internal/server/publish_events.go index 7a8a4b4e823..3edcbf339be 100644 --- a/cmd/telemetry-gateway/internal/server/publish_events.go +++ b/cmd/telemetry-gateway/internal/server/publish_events.go @@ -36,13 +36,14 @@ func handlePublishEvents( // Record the result on the trace and metrics resultAttribute := attribute.String("result", summary.result) - tr.SetAttributes(resultAttribute) + sourceAttribute := attribute.String("source", publisher.GetSourceName()) + tr.SetAttributes(resultAttribute, sourceAttribute) payloadMetrics.length.Record(ctx, 
int64(len(events)), - metric.WithAttributes(resultAttribute)) + metric.WithAttributes(resultAttribute, sourceAttribute)) payloadMetrics.processedEvents.Add(ctx, int64(len(summary.succeededEvents)), - metric.WithAttributes(attribute.Bool("succeeded", true), resultAttribute)) + metric.WithAttributes(attribute.Bool("succeeded", true), resultAttribute, sourceAttribute)) payloadMetrics.processedEvents.Add(ctx, int64(len(summary.failedEvents)), - metric.WithAttributes(attribute.Bool("succeeded", false), resultAttribute)) + metric.WithAttributes(attribute.Bool("succeeded", false), resultAttribute, sourceAttribute)) // Generate a log message for convenience summaryFields := []log.Field{ diff --git a/cmd/telemetry-gateway/internal/server/samsm2m/BUILD.bazel b/cmd/telemetry-gateway/internal/server/samsm2m/BUILD.bazel new file mode 100644 index 00000000000..673a61dbb65 --- /dev/null +++ b/cmd/telemetry-gateway/internal/server/samsm2m/BUILD.bazel @@ -0,0 +1,36 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//dev:go_defs.bzl", "go_test") + +go_library( + name = "samsm2m", + srcs = ["samsm2m.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/telemetry-gateway/internal/server/samsm2m", + visibility = ["//cmd/telemetry-gateway:__subpackages__"], + deps = [ + "//internal/authbearer", + "//internal/sams", + "//lib/errors", + "@com_github_sourcegraph_log//:log", + "@io_opentelemetry_go_otel//:otel", + "@io_opentelemetry_go_otel//codes", + "@io_opentelemetry_go_otel_trace//:trace", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//metadata", + "@org_golang_google_grpc//status", + ], +) + +go_test( + name = "samsm2m_test", + srcs = ["samsm2m_test.go"], + embed = [":samsm2m"], + deps = [ + "//internal/sams", + "//lib/errors", + "@com_github_hexops_autogold_v2//:autogold", + "@com_github_sourcegraph_log//logtest", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//metadata", + ], +) 
diff --git a/cmd/telemetry-gateway/internal/server/samsm2m/samsm2m.go b/cmd/telemetry-gateway/internal/server/samsm2m/samsm2m.go new file mode 100644 index 00000000000..18ed13d4c33 --- /dev/null +++ b/cmd/telemetry-gateway/internal/server/samsm2m/samsm2m.go @@ -0,0 +1,86 @@ +package samsm2m + +import ( + "context" + "slices" + "strings" + + "go.opentelemetry.io/otel" + otelcodes "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/sourcegraph/log" + + "github.com/sourcegraph/sourcegraph/internal/authbearer" + "github.com/sourcegraph/sourcegraph/internal/sams" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const requiredSamsScope = "telemetry_gateway::events::write" + +var tracer = otel.GetTracerProvider().Tracer("telemetry-gateway/samsm2m") + +// CheckWriteEventsScope ensures the request context has a valid SAMS M2M token +// with requiredSamsScope. It returns a gRPC status error suitable to be returned +// directly from an RPC implementation. 
+// +// See: go/sams-m2m +func CheckWriteEventsScope(ctx context.Context, logger log.Logger, samsClient sams.Client) (err error) { + var span trace.Span + ctx, span = tracer.Start(ctx, "CheckWriteEventsScope") + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(otelcodes.Error, "check failed") + } + span.End() + }() + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "no token header") + } + + var token string + if v := md.Get("authorization"); len(v) == 1 && v[0] != "" { + var err error + token, err = authbearer.ExtractBearerContents(v[0]) + if err != nil { + return status.Errorf(codes.Unauthenticated, "invalid token header: %v", err) + } + } else { + return status.Error(codes.Unauthenticated, "no token header value") + } + + // TODO: as part of go/sams-m2m we need to build out a SDK for SAMS M2M + // consumers that has a recommended short-caching mechanism. Avoid doing it + // for now until we have a concerted effort. + result, err := samsClient.IntrospectToken(ctx, token) + if err != nil { + logger.Error("samsClient.IntrospectToken failed", log.Error(err)) + return status.Error(codes.Internal, "unable to validate token") + } + + // Active encapsulates whether the token is active, including expiration. + if !result.Active { + // Record detailed error in span, and return an opaque one + span.RecordError(errors.New("inactive scope")) + return status.Error(codes.PermissionDenied, "permission denied") + } + + // Check for our required scope. 
+ gotScopes := strings.Split(result.Scope, " ") + if !slices.Contains(gotScopes, requiredSamsScope) { + // Record detailed error in span and logs, and return an opaque one + err = errors.Newf("got scopes %q, required: %q", gotScopes, requiredSamsScope) + span.RecordError(err) + logger.Error("attempt to authenticate using SAMS token without required scope", + log.Error(err)) + return status.Error(codes.PermissionDenied, "permission denied") + } + + return nil +} diff --git a/cmd/telemetry-gateway/internal/server/samsm2m/samsm2m_test.go b/cmd/telemetry-gateway/internal/server/samsm2m/samsm2m_test.go new file mode 100644 index 00000000000..cc28d487f25 --- /dev/null +++ b/cmd/telemetry-gateway/internal/server/samsm2m/samsm2m_test.go @@ -0,0 +1,97 @@ +package samsm2m + +import ( + "context" + "testing" + + "github.com/hexops/autogold/v2" + "github.com/sourcegraph/log/logtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + + "github.com/sourcegraph/sourcegraph/internal/sams" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +type mockSAMSClient struct { + result *sams.TokenIntrospection + error error +} + +func (m mockSAMSClient) IntrospectToken(context.Context, string) (*sams.TokenIntrospection, error) { + return m.result, m.error +} + +func TestCheckWriteEventsScope(t *testing.T) { + for _, tc := range []struct { + name string + metadata map[string]string + samsClient sams.Client + wantErr autogold.Value + }{ + { + name: "no metadata", + metadata: nil, + samsClient: nil, // will not be used + wantErr: autogold.Expect("rpc error: code = Unauthenticated desc = no token header"), + }, + { + name: "no authorization header", + metadata: map[string]string{"somethingelse": "foobar"}, + samsClient: nil, // will not be used + wantErr: autogold.Expect("rpc error: code = Unauthenticated desc = no token header value"), + }, + { + name: "malformed authorization header", + metadata: 
map[string]string{"authorization": "bearer"}, + samsClient: nil, // will not be used + wantErr: autogold.Expect("rpc error: code = Unauthenticated desc = invalid token header: token type missing in Authorization header"), + }, + { + name: "token ok, introspect failed", + metadata: map[string]string{"authorization": "bearer foobar"}, + samsClient: mockSAMSClient{error: errors.New("introspection failed")}, + wantErr: autogold.Expect("rpc error: code = Internal desc = unable to validate token"), + }, + { + name: "token ok, but inactive", + metadata: map[string]string{"authorization": "bearer foobar"}, + samsClient: mockSAMSClient{result: &sams.TokenIntrospection{Active: false}}, + wantErr: autogold.Expect("rpc error: code = PermissionDenied desc = permission denied"), + }, + { + name: "token ok and active, but invalid scope", + metadata: map[string]string{"authorization": "bearer foobar"}, + samsClient: mockSAMSClient{result: &sams.TokenIntrospection{Active: true, Scope: "foo bar"}}, + wantErr: autogold.Expect("rpc error: code = PermissionDenied desc = permission denied"), + }, + { + name: "token ok and active and valid scope", + metadata: map[string]string{"authorization": "bearer foobar"}, + samsClient: mockSAMSClient{ + result: &sams.TokenIntrospection{ + Active: true, + Scope: "foo bar " + requiredSamsScope, + }, + }, + wantErr: nil, // success + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + if len(tc.metadata) > 0 { + // we mock the ctx of an incoming context + ctx = metadata.NewIncomingContext(ctx, metadata.New(tc.metadata)) + } + + err := CheckWriteEventsScope(ctx, logtest.Scoped(t), tc.samsClient) + if tc.wantErr == nil { + assert.NoError(t, err) + } else { + require.Error(t, err) + tc.wantErr.Equal(t, err.Error()) + } + }) + } +} diff --git a/cmd/telemetry-gateway/internal/server/server.go b/cmd/telemetry-gateway/internal/server/server.go index 45bec9858a0..17f0d7ab83d 100644 --- 
a/cmd/telemetry-gateway/internal/server/server.go +++ b/cmd/telemetry-gateway/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "io" @@ -11,12 +12,14 @@ import ( "github.com/sourcegraph/log" - "github.com/sourcegraph/sourcegraph/cmd/telemetry-gateway/internal/events" "github.com/sourcegraph/sourcegraph/internal/licensing" "github.com/sourcegraph/sourcegraph/internal/pubsub" + "github.com/sourcegraph/sourcegraph/internal/sams" sgtrace "github.com/sourcegraph/sourcegraph/internal/trace" "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/cmd/telemetry-gateway/internal/events" + "github.com/sourcegraph/sourcegraph/cmd/telemetry-gateway/internal/server/samsm2m" telemetrygatewayv1 "github.com/sourcegraph/sourcegraph/internal/telemetrygateway/v1" ) @@ -25,7 +28,11 @@ type Server struct { eventsTopic pubsub.TopicPublisher publishOpts events.PublishStreamOptions + // samsClient is used for M2M authn/authz: go/sams-m2m + samsClient sams.Client + recordEventsMetrics recordEventsMetrics + recordEventMetrics recordEventMetrics // Fallback unimplemented handler telemetrygatewayv1.UnimplementedTelemeteryGatewayServiceServer @@ -33,8 +40,17 @@ type Server struct { var _ telemetrygatewayv1.TelemeteryGatewayServiceServer = (*Server)(nil) -func New(logger log.Logger, eventsTopic pubsub.TopicPublisher, publishOpts events.PublishStreamOptions) (*Server, error) { - m, err := newRecordEventsMetrics() +func New( + logger log.Logger, + eventsTopic pubsub.TopicPublisher, + samsClient sams.Client, + publishOpts events.PublishStreamOptions, +) (*Server, error) { + recordEventsRPCMetrics, err := newRecordEventsMetrics() + if err != nil { + return nil, err + } + recordEventRPCMetrics, err := newRecordEventMetrics() if err != nil { return nil, err } @@ -44,13 +60,17 @@ func New(logger log.Logger, eventsTopic pubsub.TopicPublisher, publishOpts event eventsTopic: eventsTopic, publishOpts: publishOpts, - recordEventsMetrics: m, + 
samsClient: samsClient, + + recordEventsMetrics: recordEventsRPCMetrics, + recordEventMetrics: recordEventRPCMetrics, }, nil } func (s *Server) RecordEvents(stream telemetrygatewayv1.TelemeteryGatewayService_RecordEventsServer) (err error) { var ( - logger = sgtrace.Logger(stream.Context(), s.logger) + logger = sgtrace.Logger(stream.Context(), s.logger). + Scoped("RecordEvent") // publisher is initialized once for RecordEventsRequestMetadata. publisher *events.Publisher // count of all processed events, collected at the end of a request @@ -60,7 +80,10 @@ func (s *Server) RecordEvents(stream telemetrygatewayv1.TelemeteryGatewayService defer func() { s.recordEventsMetrics.totalLength.Record(stream.Context(), totalProcessedEvents, - metric.WithAttributes(attribute.Bool("error", err != nil))) + metric.WithAttributes( + attribute.Bool("error", err != nil), + attribute.String("source", publisher.GetSourceName()), + )) }() for { @@ -82,14 +105,13 @@ func (s *Server) RecordEvents(stream telemetrygatewayv1.TelemeteryGatewayService logger = logger.With(log.String("requestID", metadata.GetRequestId())) // Validate self-reported instance identifier - switch metadata.GetIdentifier().Identifier.(type) { + switch metadata.GetIdentifier().GetIdentifier().(type) { case *telemetrygatewayv1.Identifier_LicensedInstance: identifier := metadata.Identifier.GetLicensedInstance() licenseInfo, _, err := licensing.ParseProductLicenseKey(identifier.GetLicenseKey()) if err != nil { return status.Errorf(codes.InvalidArgument, "invalid license_key: %s", err) } - // Attach instance ID to all subsequent log messages logger = logger.With(log.String("instanceID", identifier.InstanceId)) // Record start of stream + additional diagnostics details // like salesforce info and external URL once @@ -103,13 +125,29 @@ func (s *Server) RecordEvents(stream telemetrygatewayv1.TelemeteryGatewayService if identifier.InstanceId == "" { return status.Error(codes.InvalidArgument, "instance_id is required for 
unlicensed instance") } - // Attach instance ID to all subsequent log messages logger = logger.With(log.String("instanceID", identifier.InstanceId)) // Record start of stream logger.Info("handling events submission stream for unlicensed instance") + case *telemetrygatewayv1.Identifier_ManagedService: + identifier := metadata.Identifier.GetManagedService() + if identifier.ServiceId == "" { + return status.Error(codes.InvalidArgument, "service_id is required for managed services") + } + logger = logger.With( + log.String("serviceID", identifier.ServiceId), + log.Stringp("serviceEnvironment", identifier.ServiceEnvironment)) + + // 🚨 SECURITY: Only known clients registered in SAMS can submit events + // as a managed service. + if err := samsm2m.CheckWriteEventsScope(stream.Context(), logger, s.samsClient); err != nil { + return err + } + + logger.Info("handling events submission stream for managed service") + default: - logger.Error("unknown identifier type", + logger.Error("identifier not supported for this RPC", log.String("type", fmt.Sprintf("%T", metadata.Identifier.Identifier))) return status.Error(codes.Unimplemented, "unsupported identifier type") } @@ -119,6 +157,7 @@ func (s *Server) RecordEvents(stream telemetrygatewayv1.TelemeteryGatewayService if err != nil { return status.Errorf(codes.Internal, "failed to create publisher: %v", err) } + logger = logger.With(log.String("source", publisher.GetSourceName())) case *telemetrygatewayv1.RecordEventsRequest_Events: events := msg.GetEvents().GetEvents() @@ -156,3 +195,78 @@ func (s *Server) RecordEvents(stream telemetrygatewayv1.TelemeteryGatewayService logger.Info("request done") return nil } + +func (s *Server) RecordEvent(ctx context.Context, req *telemetrygatewayv1.RecordEventRequest) (_ *telemetrygatewayv1.RecordEventResponse, err error) { + var ( + metadata = req.GetMetadata() + event = req.GetEvent() + ) + if event == nil { + return nil, status.Error(codes.InvalidArgument, "event is required") + } + + logger 
:= sgtrace.Logger(ctx, s.logger). + Scoped("RecordEvent"). + With( + log.String("requestID", metadata.GetRequestId()), + // Include more liberal amounts of diagnostics because this RPC + // currently has a more limited audience + log.String("eventID", event.GetId()), + log.String("eventFeature", event.GetFeature()), + log.String("eventAction", event.GetAction())) + + // We only allow a limited set of identifiers to use this RPC for now, as + // Sourcegraph instances should only use RecordEvents. + switch metadata.GetIdentifier().GetIdentifier().(type) { + case *telemetrygatewayv1.Identifier_ManagedService: + identifier := metadata.Identifier.GetManagedService() + if identifier.ServiceId == "" { + return nil, status.Error(codes.InvalidArgument, "service_id is required for managed services") + } + logger = logger.With( + log.String("serviceID", identifier.ServiceId), + log.Stringp("serviceEnvironment", identifier.ServiceEnvironment)) + + // 🚨 SECURITY: Only known clients registered in SAMS can submit events + // as a managed service. 
+ if err := samsm2m.CheckWriteEventsScope(ctx, logger, s.samsClient); err != nil { + return nil, err + } + + default: + logger.Error("identifier not supported for this RPC", + log.String("type", fmt.Sprintf("%T", metadata.Identifier.Identifier))) + return nil, status.Error(codes.Unimplemented, "unsupported identifier type") + } + + // Set up a publisher with the provided metadata + publisher, err := events.NewPublisherForStream(s.logger, s.eventsTopic, metadata, s.publishOpts) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create publisher: %v", err) + } + logger = logger.With(log.String("source", publisher.GetSourceName())) + + defer func() { + s.recordEventMetrics.processedEvents.Add(ctx, + 1, // RPC only accepts 1 event at a time + metric.WithAttributes( + attribute.Bool("error", err != nil), + attribute.String("source", publisher.GetSourceName()))) + }() + + // Submit the single event + results := publisher.Publish(ctx, []*telemetrygatewayv1.Event{event}) + if len(results) != 1 { + logger.Error("unexpected result when publishing", + log.Error(errors.Newf("expected 1 result, got %d", len(results)))) + return nil, status.Errorf(codes.Internal, "unexpected publishing issue") + } + if err := results[0].PublishError; err != nil { + logger.Error("failed to publish event", log.Error(err)) + return nil, status.Errorf(codes.Internal, "failed to publish event: %v", err) + } + + return &telemetrygatewayv1.RecordEventResponse{ + // no properties + }, nil +} diff --git a/cmd/telemetry-gateway/main.go b/cmd/telemetry-gateway/main.go index ea27d87370b..e64f11844ab 100644 --- a/cmd/telemetry-gateway/main.go +++ b/cmd/telemetry-gateway/main.go @@ -6,5 +6,5 @@ import ( ) func main() { - runtime.Start[service.Config](&service.Service{}) + runtime.Start(&service.Service{}) } diff --git a/cmd/telemetry-gateway/service/BUILD.bazel b/cmd/telemetry-gateway/service/BUILD.bazel index 6cfc20f171e..9550e8cd208 100644 --- 
a/cmd/telemetry-gateway/service/BUILD.bazel +++ b/cmd/telemetry-gateway/service/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//internal/grpc/defaults", "//internal/httpserver", "//internal/pubsub", + "//internal/sams", "//internal/telemetrygateway/v1:telemetrygateway", "//internal/trace/policy", "//internal/version", @@ -25,5 +26,6 @@ go_library( "@com_github_sourcegraph_log//:log", "@io_opentelemetry_go_otel//:otel", "@io_opentelemetry_go_otel_metric//:metric", + "@org_golang_x_oauth2//clientcredentials", ], ) diff --git a/cmd/telemetry-gateway/service/config.go b/cmd/telemetry-gateway/service/config.go index eb7c65aec24..58ff53a59a7 100644 --- a/cmd/telemetry-gateway/service/config.go +++ b/cmd/telemetry-gateway/service/config.go @@ -14,6 +14,12 @@ type Config struct { StreamPublishConcurrency int } + + SAMS struct { + ServerURL string + ClientID string + ClientSecret string + } } func (c *Config) Load(env *runtime.Env) { @@ -25,4 +31,11 @@ func (c *Config) Load(env *runtime.Env) { "The topic ID for the Pub/Sub.") c.Events.StreamPublishConcurrency = env.GetInt("TELEMETRY_GATEWAY_EVENTS_STREAM_PUBLISH_CONCURRENCY", "250", "Per-stream concurrent publishing limit.") + + c.SAMS.ServerURL = env.Get("TELEMETRY_GATEWAY_SAMS_SERVER_URL", "https://accounts.sourcegraph.com", + "Sourcegraph Accounts Management System URL") + c.SAMS.ClientID = env.Get("TELEMETRY_GATEWAY_SAMS_CLIENT_ID", "", + "Sourcegraph Accounts Management System client ID") + c.SAMS.ClientSecret = env.Get("TELEMETRY_GATEWAY_SAMS_CLIENT_SECRET", "", + "Sourcegraph Accounts Management System client secret") } diff --git a/cmd/telemetry-gateway/service/service.go b/cmd/telemetry-gateway/service/service.go index d47871ccd4b..d149621430f 100644 --- a/cmd/telemetry-gateway/service/service.go +++ b/cmd/telemetry-gateway/service/service.go @@ -10,12 +10,14 @@ import ( "github.com/sourcegraph/log" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" + "golang.org/x/oauth2/clientcredentials" 
"github.com/sourcegraph/sourcegraph/internal/debugserver" internalgrpc "github.com/sourcegraph/sourcegraph/internal/grpc" "github.com/sourcegraph/sourcegraph/internal/grpc/defaults" "github.com/sourcegraph/sourcegraph/internal/httpserver" "github.com/sourcegraph/sourcegraph/internal/pubsub" + "github.com/sourcegraph/sourcegraph/internal/sams" "github.com/sourcegraph/sourcegraph/internal/trace/policy" "github.com/sourcegraph/sourcegraph/internal/version" @@ -61,11 +63,23 @@ func (Service) Initialize(ctx context.Context, logger log.Logger, contract runti return nil, errors.Wrap(err, "create pubsub.published_message_size metric") } + // Prepare SAMS client, so that we can enforce SAMS-based M2M authz/authn + logger.Debug("using SAMS client", + log.String("samsServer", config.SAMS.ServerURL), + log.String("clientID", config.SAMS.ClientID)) + samsClient := sams.NewClient(config.SAMS.ServerURL, clientcredentials.Config{ + ClientID: config.SAMS.ClientID, + ClientSecret: config.SAMS.ClientSecret, + TokenURL: fmt.Sprintf("%s/oauth/token", config.SAMS.ServerURL), + Scopes: []string{"openid", "profile", "email"}, + }) + // Initialize our gRPC server grpcServer := defaults.NewPublicServer(logger) telemetryGatewayServer, err := server.New( logger, eventsTopic, + samsClient, events.PublishStreamOptions{ ConcurrencyLimit: config.Events.StreamPublishConcurrency, MessageSizeHistogram: publishMessageBytes, @@ -87,6 +101,7 @@ func (Service) Initialize(ctx context.Context, logger log.Logger, contract runti // development! 
grpcUI := debugserver.NewGRPCWebUIEndpoint("telemetry-gateway", listenAddr) diagnosticsServer.Handle(grpcUI.Path, grpcUI.Handler) + logger.Warn("gRPC web UI enabled", log.String("url", fmt.Sprintf("%s%s", listenAddr, grpcUI.Path))) } return background.LIFOStopRoutine{ diff --git a/doc/dev/background-information/telemetry/protocol.md b/doc/dev/background-information/telemetry/protocol.md index 1f752c73724..528e4033c71 100644 --- a/doc/dev/background-information/telemetry/protocol.md +++ b/doc/dev/background-information/telemetry/protocol.md @@ -305,7 +305,7 @@ Sourcegraph.com instance and managed services. | ----- | ---- | ----- | ----------- | | licensed_instance | [Identifier.LicensedInstanceIdentifier](#telemetrygateway-v1-Identifier-LicensedInstanceIdentifier) | |
A licensed Sourcegraph instance.
| | unlicensed_instance | [Identifier.UnlicensedInstanceIdentifier](#telemetrygateway-v1-Identifier-UnlicensedInstanceIdentifier) | |An unlicensed Sourcegraph instance.
| -| managed_service | [Identifier.ManagedServiceIdentifier](#telemetrygateway-v1-Identifier-ManagedServiceIdentifier) | |A service operated and managed by the Sourcegraph team, for example
a service deployed by https://handbook.sourcegraph.com/departments/engineering/teams/core-services/managed-services/platform/
| +| managed_service | [Identifier.ManagedServiceIdentifier](#telemetrygateway-v1-Identifier-ManagedServiceIdentifier) | |A service operated and managed by the Sourcegraph team, for example
a service deployed by MSP: https://handbook.sourcegraph.com/departments/engineering/teams/core-services/managed-services/platform/
Valid SAMS client credentials are required to publish events under a
managed service identifier. The required scope is
'telemetry_gateway::events::publish'. See go/sams-client-credentials and
go/sams-token-scopes for more information.
| @@ -463,7 +463,7 @@ Sourcegraph.com instance and managed services. | Method Name | Request Type | Response Type | Description | | ----------- | ------------ | ------------- | ------------| | RecordEvents | [RecordEventsRequest](#telemetrygateway-v1-RecordEventsRequest) stream | [RecordEventsResponse](#telemetrygateway-v1-RecordEventsResponse) stream |RecordEvents streams telemetry events in batches to the Telemetry Gateway
service. Events should only be considered delivered if recording is
acknowledged in RecordEventsResponse.
This is the preferred mechanism for exporting large volumes of events in
bulk.
🚨 SECURITY: Callers exporting for single-tenant Sourcegraph should check
the attributes of the Event type to ensure that only the appropriate fields
are exported, as some fields should only be exported on an allowlist basis.
| -| RecordEvent | [RecordEventRequest](#telemetrygateway-v1-RecordEventRequest) | [RecordEventResponse](#telemetrygateway-v1-RecordEventResponse) |RecordEvent records a single telemetry event to the Telemetry Gateway service.
If the RPC succeeds, then the event was successfully published.
This mechanism is intended for low-volume managed services. Higher-volume
use cases should implement a batching mechanism and use the RecordEvents
RPC instead.
🚨 SECURITY: Callers exporting for single-tenant Sourcegraph should check
the attributes of the Event type to ensure that only the appropriate fields
are exported, as some fields should only be exported on an allowlist basis.
| +| RecordEvent | [RecordEventRequest](#telemetrygateway-v1-RecordEventRequest) | [RecordEventResponse](#telemetrygateway-v1-RecordEventResponse) |RecordEvent records a single telemetry event to the Telemetry Gateway service.
If the RPC succeeds, then the event was successfully published.
This RPC currently ONLY accepts events published by ManagedServiceIdentifier,
as this mechanism is intended for low-volume managed services. Higher-volume
use cases should implement a batching mechanism and use the RecordEvents
RPC instead.
🚨 SECURITY: Callers exporting for single-tenant Sourcegraph should check
the attributes of the Event type to ensure that only the appropriate fields
are exported, as some fields should only be exported on an allowlist basis.
| diff --git a/internal/authbearer/authbearer.go b/internal/authbearer/authbearer.go index f93238c97b6..bb5bd9543d3 100644 --- a/internal/authbearer/authbearer.go +++ b/internal/authbearer/authbearer.go @@ -8,19 +8,22 @@ import ( ) func ExtractBearer(h http.Header) (string, error) { - var token string - if authHeader := h.Get("Authorization"); authHeader != "" { - typ := strings.SplitN(authHeader, " ", 2) - if len(typ) != 2 { - return "", errors.New("token type missing in Authorization header") - } - if strings.ToLower(typ[0]) != "bearer" { - return "", errors.Newf("invalid token type %s", typ[0]) - } - - token = typ[1] + return ExtractBearerContents(authHeader) } - - return token, nil + return "", nil +} + +func ExtractBearerContents(s string) (string, error) { + if s == "" { + return "", errors.New("no token provided in Authorization header") + } + typ := strings.SplitN(s, " ", 2) + if len(typ) != 2 { + return "", errors.New("token type missing in Authorization header") + } + if strings.ToLower(typ[0]) != "bearer" { + return "", errors.Newf("invalid token type %s in Authorization header", typ[0]) + } + return typ[1], nil } diff --git a/internal/telemetrygateway/v1/telemetrygateway.pb.go b/internal/telemetrygateway/v1/telemetrygateway.pb.go index a307b6de0d7..7213fb96c48 100644 --- a/internal/telemetrygateway/v1/telemetrygateway.pb.go +++ b/internal/telemetrygateway/v1/telemetrygateway.pb.go @@ -119,7 +119,12 @@ type Identifier_UnlicensedInstance struct { type Identifier_ManagedService struct { // A service operated and managed by the Sourcegraph team, for example - // a service deployed by https://handbook.sourcegraph.com/departments/engineering/teams/core-services/managed-services/platform/ + // a service deployed by MSP: https://handbook.sourcegraph.com/departments/engineering/teams/core-services/managed-services/platform/ + // + // Valid SAMS client credentials are required to publish events under a + // managed service identifier. 
The required scope is + // 'telemetry_gateway::events::publish'. See go/sams-client-credentials and + // go/sams-token-scopes for more information. ManagedService *Identifier_ManagedServiceIdentifier `protobuf:"bytes,3,opt,name=managed_service,json=managedService,proto3,oneof"` } diff --git a/internal/telemetrygateway/v1/telemetrygateway.proto b/internal/telemetrygateway/v1/telemetrygateway.proto index c09b1ebf0da..581f64e9e89 100644 --- a/internal/telemetrygateway/v1/telemetrygateway.proto +++ b/internal/telemetrygateway/v1/telemetrygateway.proto @@ -31,7 +31,8 @@ service TelemeteryGatewayService { // RecordEvent records a single telemetry event to the Telemetry Gateway service. // If the RPC succeeds, then the event was successfully published. // - // This mechanism is intended for low-volume managed services. Higher-volume + // This RPC currently ONLY accepts events published by ManagedServiceIdentifier, + // as this mechanism is intended for low-volume managed services. Higher-volume // use cases should implement a batching mechanism and use the RecordEvents // RPC instead. // @@ -71,7 +72,12 @@ message Identifier { // An unlicensed Sourcegraph instance. UnlicensedInstanceIdentifier unlicensed_instance = 2; // A service operated and managed by the Sourcegraph team, for example - // a service deployed by https://handbook.sourcegraph.com/departments/engineering/teams/core-services/managed-services/platform/ + // a service deployed by MSP: https://handbook.sourcegraph.com/departments/engineering/teams/core-services/managed-services/platform/ + // + // Valid SAMS client credentials are required to publish events under a + // managed service identifier. The required scope is + // 'telemetry_gateway::events::publish'. See go/sams-client-credentials and + // go/sams-token-scopes for more information. 
ManagedServiceIdentifier managed_service = 3; } } diff --git a/internal/telemetrygateway/v1/telemetrygateway_grpc.pb.go b/internal/telemetrygateway/v1/telemetrygateway_grpc.pb.go index 5db87f3b2bb..05155265458 100644 --- a/internal/telemetrygateway/v1/telemetrygateway_grpc.pb.go +++ b/internal/telemetrygateway/v1/telemetrygateway_grpc.pb.go @@ -49,7 +49,8 @@ type TelemeteryGatewayServiceClient interface { // RecordEvent records a single telemetry event to the Telemetry Gateway service. // If the RPC succeeds, then the event was successfully published. // - // This mechanism is intended for low-volume managed services. Higher-volume + // This RPC currently ONLY accepts events published by ManagedServiceIdentifier, + // as this mechanism is intended for low-volume managed services. Higher-volume // use cases should implement a batching mechanism and use the RecordEvents // RPC instead. // @@ -125,7 +126,8 @@ type TelemeteryGatewayServiceServer interface { // RecordEvent records a single telemetry event to the Telemetry Gateway service. // If the RPC succeeds, then the event was successfully published. // - // This mechanism is intended for low-volume managed services. Higher-volume + // This RPC currently ONLY accepts events published by ManagedServiceIdentifier, + // as this mechanism is intended for low-volume managed services. Higher-volume // use cases should implement a batching mechanism and use the RecordEvents // RPC instead. 
// diff --git a/lib/managedservicesplatform/runtime/env.go b/lib/managedservicesplatform/runtime/env.go index 24957c7008e..f0b99c6b355 100644 --- a/lib/managedservicesplatform/runtime/env.go +++ b/lib/managedservicesplatform/runtime/env.go @@ -75,7 +75,8 @@ func (e *Env) validate() error { func (e *Env) Get(name, defaultValue, description string) string { rawValue := e.get(name, defaultValue, description) if rawValue == "" { - e.AddError(errors.Errorf("invalid value %q for %s: no value supplied", rawValue, name)) + e.AddError(errors.Errorf("invalid value %q for %s: no value supplied, description: %s", + rawValue, name, description)) return "" } diff --git a/sg.config.yaml b/sg.config.yaml index 124c22dd625..9c9b9986749 100644 --- a/sg.config.yaml +++ b/sg.config.yaml @@ -334,6 +334,11 @@ commands: TELEMETRY_GATEWAY_EVENTS_PUBSUB_ENABLED: false SRC_LOG_LEVEL: info GRPC_WEB_UI_ENABLED: true + # Set for convenience - use real values in sg.config.overwrite.yaml if you + # are interacting with RPCs that enforce SAMS M2M auth. See + # https://github.com/sourcegraph/accounts.sourcegraph.com/wiki/Operators-Cheat-Sheet#create-a-new-idp-client + TELEMETRY_GATEWAY_SAMS_CLIENT_ID: "foo" + TELEMETRY_GATEWAY_SAMS_CLIENT_SECRET: "bar" watch: - lib - internal