mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:51:59 +00:00
Enterprise instance uses gateway for attribution (#59513)
* Remove local snippet attribution
* Draft
* Attribution service updated with gateway client
* Thursday progress
* Attribution success test
* Unify types for gateway request/response
* Attribution client test
* BAZEL
* Update gateway request/response references
* Remove Attribution request TODO
This commit is contained in:
parent
ea521b5757
commit
f5bcbfcbb2
@ -22,7 +22,6 @@ go_test(
|
||||
name = "attribution_test",
|
||||
srcs = ["handler_test.go"],
|
||||
deps = [
|
||||
":attribution",
|
||||
"//cmd/cody-gateway/internal/actor",
|
||||
"//cmd/cody-gateway/internal/auth",
|
||||
"//cmd/cody-gateway/internal/dotcom",
|
||||
|
||||
@ -20,30 +20,6 @@ import (
|
||||
// If a higher value is given, then this default is set.
|
||||
const LimitUpperBound = 4
|
||||
|
||||
// Request for attribution search. Expected in JSON form as the body of POST request.
|
||||
type Request struct {
|
||||
// Snippet is the text to search attribution of.
|
||||
Snippet string
|
||||
// Limit is the upper bound of number of responses we want to get.
|
||||
Limit int
|
||||
}
|
||||
|
||||
// Response of attribution search. Contains some repositories to which the snippet can be attributed.
|
||||
type Response struct {
|
||||
// Repositories which contain code matching search snippet.
|
||||
Repositories []Repository
|
||||
// TotalCount denotes how many total matches there were (including listed repositories).
|
||||
TotalCount int
|
||||
// LimitHit is true if the number of search hits goes beyond limit specified in request.
|
||||
LimitHit bool
|
||||
}
|
||||
|
||||
// Repository represents matching of search content against a repository.
|
||||
type Repository struct {
|
||||
// Name of the repo on dotcom. Like github.com/sourcegraph/sourcegraph.
|
||||
Name string
|
||||
}
|
||||
|
||||
// NewHandler creates a REST handler for attribution search.
|
||||
// graphql.Client can be nil which disables the search.
|
||||
func NewHandler(client graphql.Client, baseLogger log.Logger) http.Handler {
|
||||
@ -59,7 +35,7 @@ func NewHandler(client graphql.Client, baseLogger log.Logger) http.Handler {
|
||||
response.JSONError(logger, w, http.StatusServiceUnavailable, errors.New("attribution search not enabled"))
|
||||
return
|
||||
}
|
||||
var request Request
|
||||
var request codygateway.AttributionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
response.JSONError(logger, w, http.StatusBadRequest, err)
|
||||
return
|
||||
@ -73,13 +49,13 @@ func NewHandler(client graphql.Client, baseLogger log.Logger) http.Handler {
|
||||
response.JSONError(logger, w, http.StatusServiceUnavailable, err)
|
||||
return
|
||||
}
|
||||
var rs []Repository
|
||||
var rs []codygateway.AttributionRepository
|
||||
for _, n := range searchResponse.SnippetAttribution.Nodes {
|
||||
rs = append(rs, Repository{Name: n.RepositoryName})
|
||||
rs = append(rs, codygateway.AttributionRepository{Name: n.RepositoryName})
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(&Response{
|
||||
if err := json.NewEncoder(w).Encode(&codygateway.AttributionResponse{
|
||||
Repositories: rs,
|
||||
TotalCount: searchResponse.SnippetAttribution.TotalCount,
|
||||
LimitHit: searchResponse.SnippetAttribution.LimitHit,
|
||||
|
||||
@ -32,7 +32,6 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/dotcom"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/attribution"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
)
|
||||
|
||||
@ -73,7 +72,7 @@ func runFakeGraphQL(t *testing.T) *fakeGraphQL {
|
||||
|
||||
// request creates an attribution search request to the gateway.
|
||||
func request(t *testing.T) *http.Request {
|
||||
requestBody, err := json.Marshal(&attribution.Request{
|
||||
requestBody, err := json.Marshal(&codygateway.AttributionRequest{
|
||||
Snippet: strings.Join([]string{
|
||||
"for n != 1 {",
|
||||
" if n % 2 == 0 {",
|
||||
@ -127,10 +126,10 @@ func TestSuccess(t *testing.T) {
|
||||
t.Error(w.Body.String())
|
||||
t.Fatalf("expected OK, got %d", got)
|
||||
}
|
||||
var gotResponseBody attribution.Response
|
||||
var gotResponseBody codygateway.AttributionResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&gotResponseBody))
|
||||
wantResponseBody := &attribution.Response{
|
||||
Repositories: []attribution.Repository{
|
||||
wantResponseBody := &codygateway.AttributionResponse{
|
||||
Repositories: []codygateway.AttributionRepository{
|
||||
{Name: "github.com/sourcegraph/sourcegraph"},
|
||||
{Name: "github.com/sourcegraph/cody"},
|
||||
},
|
||||
|
||||
@ -9,15 +9,12 @@ go_library(
|
||||
"//cmd/frontend/enterprise",
|
||||
"//cmd/frontend/envvar",
|
||||
"//cmd/frontend/internal/guardrails/attribution",
|
||||
"//cmd/frontend/internal/guardrails/dotcom",
|
||||
"//cmd/frontend/internal/guardrails/resolvers",
|
||||
"//internal/codeintel",
|
||||
"//internal/codygateway",
|
||||
"//internal/conf/conftypes",
|
||||
"//internal/database",
|
||||
"//internal/gitserver",
|
||||
"//internal/httpcli",
|
||||
"//internal/observation",
|
||||
"//internal/search/client",
|
||||
"//lib/errors",
|
||||
],
|
||||
)
|
||||
|
||||
@ -10,16 +10,9 @@ go_library(
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/attribution",
|
||||
visibility = ["//cmd/frontend:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/frontend/internal/guardrails/dotcom",
|
||||
"//internal/api",
|
||||
"//internal/codygateway",
|
||||
"//internal/metrics",
|
||||
"//internal/observation",
|
||||
"//internal/search",
|
||||
"//internal/search/client",
|
||||
"//internal/search/streaming",
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@io_opentelemetry_go_otel//attribute",
|
||||
],
|
||||
@ -30,16 +23,8 @@ go_test(
|
||||
srcs = ["attribution_test.go"],
|
||||
embed = [":attribution"],
|
||||
deps = [
|
||||
"//cmd/frontend/internal/guardrails/dotcom",
|
||||
"//internal/database/dbmocks",
|
||||
"//internal/codygateway",
|
||||
"//internal/observation",
|
||||
"//internal/search/backend",
|
||||
"//internal/search/client",
|
||||
"//internal/search/job",
|
||||
"//internal/types",
|
||||
"@com_github_google_go_cmp//cmp",
|
||||
"@com_github_khan_genqlient//graphql",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_sourcegraph_zoekt//:zoekt",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
)
|
||||
|
||||
@ -2,43 +2,19 @@ package attribution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/conc/pool"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/dotcom"
|
||||
"github.com/sourcegraph/sourcegraph/internal/api"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/client"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
// ServiceOpts configures Service.
|
||||
type ServiceOpts struct {
|
||||
// SearchClient is used to find attribution on the local instance.
|
||||
SearchClient client.SearchClient
|
||||
|
||||
// SourcegraphDotComClient is a graphql client that is queried if
|
||||
// federating out to sourcegraph.com is enabled.
|
||||
SourcegraphDotComClient dotcom.Client
|
||||
|
||||
// SourcegraphDotComFederate is true if this instance should also federate
|
||||
// to sourcegraph.com.
|
||||
SourcegraphDotComFederate bool
|
||||
}
|
||||
|
||||
// Service is for the attribution service which searches for matches on
|
||||
// snippets of code.
|
||||
//
|
||||
// Use NewService to construct this value.
|
||||
type Service struct {
|
||||
ServiceOpts
|
||||
|
||||
client codygateway.Client
|
||||
operations *operations
|
||||
}
|
||||
|
||||
@ -46,10 +22,10 @@ type Service struct {
|
||||
//
|
||||
// Note: this registers metrics so should only be called once with the same
|
||||
// observationCtx.
|
||||
func NewService(observationCtx *observation.Context, opts ServiceOpts) *Service {
|
||||
func NewService(observationCtx *observation.Context, client codygateway.Client) *Service {
|
||||
return &Service{
|
||||
operations: newOperations(observationCtx),
|
||||
ServiceOpts: opts,
|
||||
operations: newOperations(observationCtx),
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,173 +65,13 @@ func (c *Service) SnippetAttribution(ctx context.Context, snippet string, limit
|
||||
},
|
||||
})
|
||||
defer endObservationWithResult(traceLogger, endObservation, &result)()
|
||||
|
||||
limitHitErr := errors.New("limit hit error")
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
// we massage results in this function and possibly cancel if we can stop
|
||||
// looking.
|
||||
truncateAtLimit := func(result *SnippetAttributions) {
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if limit <= len(result.RepositoryNames) {
|
||||
result.LimitHit = true
|
||||
result.RepositoryNames = result.RepositoryNames[:limit]
|
||||
}
|
||||
if result.LimitHit {
|
||||
cancel(limitHitErr)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(keegancsmith) how should we handle partial errors?
|
||||
p := pool.New().WithContext(ctx).WithCancelOnError().WithFirstError()
|
||||
|
||||
// We don't use NewWithResults since we want local results to come before dotcom
|
||||
var local, dotcom *SnippetAttributions
|
||||
|
||||
p.Go(func(ctx context.Context) error {
|
||||
var err error
|
||||
local, err = c.snippetAttributionLocal(ctx, snippet, limit)
|
||||
truncateAtLimit(local)
|
||||
return err
|
||||
})
|
||||
|
||||
if c.SourcegraphDotComFederate {
|
||||
p.Go(func(ctx context.Context) error {
|
||||
var err error
|
||||
dotcom, err = c.snippetAttributionDotCom(ctx, snippet, limit)
|
||||
truncateAtLimit(dotcom)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
if err := p.Wait(); err != nil && context.Cause(ctx) != limitHitErr {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var agg SnippetAttributions
|
||||
seen := map[string]struct{}{}
|
||||
for _, result := range []*SnippetAttributions{local, dotcom} {
|
||||
if result == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Limitation: We just add to TotalCount even though that may mean we
|
||||
// overcount (both dotcom and local instance have the repo)
|
||||
agg.TotalCount += result.TotalCount
|
||||
agg.LimitHit = agg.LimitHit || result.LimitHit
|
||||
for _, name := range result.RepositoryNames {
|
||||
if _, ok := seen[name]; ok {
|
||||
// We have already counted this repo in the above TotalCount
|
||||
// increment, so undo that.
|
||||
agg.TotalCount--
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
agg.RepositoryNames = append(agg.RepositoryNames, name)
|
||||
}
|
||||
}
|
||||
|
||||
// we call truncateAtLimit on the aggregated result to ensure we only
|
||||
// return up to limit. Note this function will call cancel but that is fine
|
||||
// since we just return after this.
|
||||
truncateAtLimit(&agg)
|
||||
|
||||
return &agg, nil
|
||||
}
|
||||
|
||||
func (c *Service) snippetAttributionLocal(ctx context.Context, snippet string, limit int) (result *SnippetAttributions, err error) {
|
||||
ctx, traceLogger, endObservation := c.operations.snippetAttributionLocal.With(ctx, &err, observation.Args{})
|
||||
defer endObservationWithResult(traceLogger, endObservation, &result)()
|
||||
|
||||
const (
|
||||
version = "V3"
|
||||
searchMode = search.Precise
|
||||
protocol = search.Streaming
|
||||
)
|
||||
|
||||
patternType := "literal"
|
||||
searchQuery := fmt.Sprintf("type:file select:repo index:only case:yes count:%d content:%q", limit, snippet)
|
||||
|
||||
inputs, err := c.SearchClient.Plan(
|
||||
ctx,
|
||||
version,
|
||||
&patternType,
|
||||
searchQuery,
|
||||
searchMode,
|
||||
protocol,
|
||||
pointers.Ptr(int32(0)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create search plan")
|
||||
}
|
||||
|
||||
// TODO(keegancsmith) Reading the SearchClient code it seems to miss out
|
||||
// on some of the observability that we instead add in at a later stage.
|
||||
// For example the search dataset in honeycomb will be missing. Will have
|
||||
// to follow-up with observability and maybe solve it for all users.
|
||||
//
|
||||
// Note: In our current API we could just store repo names in seen. But it
|
||||
// is safer to rely on search's ranking for result stability than doing
|
||||
// something like sorting by name from the map.
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seen = map[api.RepoID]struct{}{}
|
||||
repoNames []string
|
||||
limitHit bool
|
||||
)
|
||||
_, err = c.SearchClient.Execute(ctx, streaming.StreamFunc(func(ev streaming.SearchEvent) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
limitHit = limitHit || ev.Stats.IsLimitHit
|
||||
|
||||
for _, m := range ev.Results {
|
||||
repo := m.RepoName()
|
||||
if _, ok := seen[repo.ID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[repo.ID] = struct{}{}
|
||||
repoNames = append(repoNames, string(repo.Name))
|
||||
}
|
||||
}), inputs)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute search")
|
||||
}
|
||||
|
||||
// Note: Our search API is missing total count internally, but Zoekt does
|
||||
// expose this. For now we just count what we found.
|
||||
totalCount := len(repoNames)
|
||||
if len(repoNames) > limit {
|
||||
repoNames = repoNames[:limit]
|
||||
}
|
||||
|
||||
return &SnippetAttributions{
|
||||
RepositoryNames: repoNames,
|
||||
TotalCount: totalCount,
|
||||
LimitHit: limitHit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Service) snippetAttributionDotCom(ctx context.Context, snippet string, limit int) (result *SnippetAttributions, err error) {
|
||||
ctx, traceLogger, endObservation := c.operations.snippetAttributionDotCom.With(ctx, &err, observation.Args{})
|
||||
defer endObservationWithResult(traceLogger, endObservation, &result)()
|
||||
|
||||
resp, err := dotcom.SnippetAttribution(ctx, c.SourcegraphDotComClient, snippet, limit)
|
||||
attribution, err := c.client.Attribution(ctx, snippet, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var repoNames []string
|
||||
for _, node := range resp.SnippetAttribution.Nodes {
|
||||
repoNames = append(repoNames, node.RepositoryName)
|
||||
}
|
||||
|
||||
return &SnippetAttributions{
|
||||
RepositoryNames: repoNames,
|
||||
TotalCount: resp.SnippetAttribution.TotalCount,
|
||||
LimitHit: resp.SnippetAttribution.LimitHit,
|
||||
RepositoryNames: attribution.Repositories,
|
||||
TotalCount: len(attribution.Repositories), // TODO: Remove total count.
|
||||
LimitHit: attribution.LimitHit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -2,138 +2,33 @@ package attribution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Khan/genqlient/graphql"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
"github.com/sourcegraph/zoekt"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/dotcom"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/dbmocks"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
searchbackend "github.com/sourcegraph/sourcegraph/internal/search/backend"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/client"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/job"
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
)
|
||||
|
||||
func TestAttribution(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
type fakeGateway struct {
|
||||
codygateway.Client
|
||||
}
|
||||
|
||||
// inputs
|
||||
localCount, dotcomCount := 5, 5
|
||||
limit := localCount + dotcomCount + 1
|
||||
localNames := genRepoNames("localrepo-", localCount)
|
||||
dotcomNames := genRepoNames("dotcomrepo-", dotcomCount)
|
||||
func (f fakeGateway) Attribution(ctx context.Context, snippet string, limit int) (codygateway.Attribution, error) {
|
||||
return codygateway.Attribution{
|
||||
Repositories: []string{"repo1", "repo2"},
|
||||
LimitHit: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// we want the localNames back followed by dotcomNames
|
||||
wantCount := localCount + dotcomCount
|
||||
wantNames := append(genRepoNames("localrepo-", localCount), genRepoNames("dotcomrepo-", dotcomCount)...)
|
||||
|
||||
svc := NewService(observation.TestContextTB(t), ServiceOpts{
|
||||
SearchClient: mockSearchClient(t, localNames),
|
||||
SourcegraphDotComClient: mockDotComClient(t, dotcomNames),
|
||||
SourcegraphDotComFederate: true,
|
||||
})
|
||||
|
||||
result, err := svc.SnippetAttribution(ctx, "test", limit)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := &SnippetAttributions{
|
||||
TotalCount: wantCount,
|
||||
func TestSuccess(t *testing.T) {
|
||||
gateway := fakeGateway{}
|
||||
service := NewService(observation.TestContextTB(t), gateway)
|
||||
attribution, err := service.SnippetAttribution(context.Background(), "snippet", 3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &SnippetAttributions{
|
||||
RepositoryNames: []string{"repo1", "repo2"},
|
||||
TotalCount: 2,
|
||||
LimitHit: false,
|
||||
RepositoryNames: wantNames,
|
||||
}
|
||||
if d := cmp.Diff(want, result); d != "" {
|
||||
t.Fatalf("unexpected (-want, +got):\n%s", d)
|
||||
}
|
||||
|
||||
// With a limit of one we expect one of local or dotcom, depending on
|
||||
// which one returns first.
|
||||
result, err = svc.SnippetAttribution(ctx, "test", 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.LimitHit {
|
||||
t.Fatal("we expected the limit to be hit")
|
||||
}
|
||||
if len(result.RepositoryNames) != 1 {
|
||||
t.Fatalf("we wanted one result, got %v", result.RepositoryNames)
|
||||
}
|
||||
if name := result.RepositoryNames[0]; name != "localrepo-1" && name != "dotcomrepo-1" {
|
||||
t.Fatalf("we wanted the first result, got %v", result.RepositoryNames)
|
||||
}
|
||||
}
|
||||
|
||||
func genRepoNames(prefix string, count int) []string {
|
||||
var names []string
|
||||
for i := 1; i <= count; i++ {
|
||||
names = append(names, fmt.Sprintf("%s%d", prefix, i))
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// mockSearchClient returns a client which will return matches. This exercises
|
||||
// more of the search code path to give a bit more confidence we are correctly
|
||||
// calling Plan and Execute vs a dumb SearchClient mock.
|
||||
func mockSearchClient(t testing.TB, repoNames []string) client.SearchClient {
|
||||
repos := dbmocks.NewMockRepoStore()
|
||||
repos.ListMinimalReposFunc.SetDefaultReturn([]types.MinimalRepo{}, nil)
|
||||
repos.CountFunc.SetDefaultReturn(0, nil)
|
||||
|
||||
db := dbmocks.NewMockDB()
|
||||
db.ReposFunc.SetDefaultReturn(repos)
|
||||
|
||||
var matches []zoekt.FileMatch
|
||||
for i, name := range repoNames {
|
||||
matches = append(matches, zoekt.FileMatch{
|
||||
RepositoryID: uint32(i),
|
||||
Repository: name,
|
||||
})
|
||||
}
|
||||
mockZoekt := &searchbackend.FakeStreamer{
|
||||
Repos: []*zoekt.RepoListEntry{},
|
||||
Results: []*zoekt.SearchResult{{
|
||||
Files: matches,
|
||||
}},
|
||||
}
|
||||
|
||||
return client.Mocked(job.RuntimeClients{
|
||||
Logger: logtest.Scoped(t),
|
||||
DB: db,
|
||||
Zoekt: mockZoekt,
|
||||
})
|
||||
}
|
||||
|
||||
func mockDotComClient(t testing.TB, repoNames []string) dotcom.Client {
|
||||
return makeRequester(func(ctx context.Context, req *graphql.Request, resp *graphql.Response) error {
|
||||
// :O :O generated type names :O :O
|
||||
var nodes []dotcom.SnippetAttributionSnippetAttributionSnippetAttributionConnectionNodesSnippetAttribution
|
||||
for _, name := range repoNames {
|
||||
nodes = append(nodes, dotcom.SnippetAttributionSnippetAttributionSnippetAttributionConnectionNodesSnippetAttribution{
|
||||
RepositoryName: name,
|
||||
})
|
||||
}
|
||||
|
||||
data := resp.Data.(*dotcom.SnippetAttributionResponse)
|
||||
*data = dotcom.SnippetAttributionResponse{
|
||||
// :O
|
||||
SnippetAttribution: dotcom.SnippetAttributionSnippetAttributionSnippetAttributionConnection{
|
||||
TotalCount: len(repoNames),
|
||||
Nodes: nodes,
|
||||
},
|
||||
}
|
||||
|
||||
return context.Cause(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
type makeRequester func(ctx context.Context, req *graphql.Request, resp *graphql.Response) error
|
||||
|
||||
func (f makeRequester) MakeRequest(ctx context.Context, req *graphql.Request, resp *graphql.Response) error {
|
||||
return f(ctx, req, resp)
|
||||
}, attribution)
|
||||
}
|
||||
|
||||
@ -6,16 +6,13 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/enterprise"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/envvar"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/attribution"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/dotcom"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/resolvers"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codeintel"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/internal/gitserver"
|
||||
"github.com/sourcegraph/sourcegraph/internal/httpcli"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/client"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
func Init(
|
||||
@ -26,26 +23,17 @@ func Init(
|
||||
_ conftypes.UnifiedWatchable,
|
||||
enterpriseServices *enterprise.Services,
|
||||
) error {
|
||||
opts := attribution.ServiceOpts{
|
||||
SearchClient: client.New(observationCtx.Logger, db, gitserver.NewClient("http.guardrails.search")),
|
||||
// Guardrails is only available in enterprise instances.
|
||||
if envvar.SourcegraphDotComMode() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(keegancsmith) configuration for access token and enabling.
|
||||
if !envvar.SourcegraphDotComMode() {
|
||||
httpClient, err := httpcli.UncachedExternalClientFactory.Doer()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to initialize external http client for guardrails")
|
||||
}
|
||||
endpoint := "https://sourcegraph.com/.api/graphql"
|
||||
accessToken := ""
|
||||
|
||||
opts.SourcegraphDotComFederate = true
|
||||
opts.SourcegraphDotComClient = dotcom.NewClient(httpClient, endpoint, accessToken)
|
||||
client, ok := codygateway.NewClientFromSiteConfig(httpcli.ExternalDoer)
|
||||
if !ok {
|
||||
// TODO handle error
|
||||
return nil
|
||||
}
|
||||
|
||||
enterpriseServices.GuardrailsResolver = &resolvers.GuardrailsResolver{
|
||||
AttributionService: attribution.NewService(observationCtx, opts),
|
||||
AttributionService: attribution.NewService(observationCtx, client),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
load("//dev:go_defs.bzl", "go_test")
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
@ -18,3 +19,12 @@ go_library(
|
||||
"//lib/errors",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "codygateway_test",
|
||||
srcs = ["client_test.go"],
|
||||
deps = [
|
||||
":codygateway",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package codygateway
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -33,8 +34,14 @@ func (rl LimitStatus) PercentUsed() int {
|
||||
return int(math.Ceil(float64(rl.IntervalUsage) / float64(rl.IntervalLimit) * 100))
|
||||
}
|
||||
|
||||
type Attribution struct {
|
||||
Repositories []string
|
||||
LimitHit bool
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
GetLimits(ctx context.Context) ([]LimitStatus, error)
|
||||
Attribution(ctx context.Context, snippet string, limit int) (Attribution, error)
|
||||
}
|
||||
|
||||
func NewClientFromSiteConfig(cli httpcli.Doer) (_ Client, ok bool) {
|
||||
@ -47,6 +54,7 @@ func NewClientFromSiteConfig(cli httpcli.Doer) (_ Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// TODO: What if customer is BYOK? How do we talk to gateway then?
|
||||
// If neither completions nor embeddings use Cody Gateway, return empty.
|
||||
ccUsingGateway := cc != nil && cc.Provider == conftypes.CompletionsProviderNameSourcegraph
|
||||
ecUsingGateway := ec != nil && ec.Provider == conftypes.EmbeddingsProviderNameSourcegraph
|
||||
@ -129,3 +137,43 @@ func (c *client) GetLimits(ctx context.Context) ([]LimitStatus, error) {
|
||||
|
||||
return rateLimits, nil
|
||||
}
|
||||
|
||||
func (c *client) Attribution(ctx context.Context, snippet string, limit int) (Attribution, error) {
|
||||
u, err := url.Parse(c.endpoint)
|
||||
if err != nil {
|
||||
return Attribution{}, err
|
||||
}
|
||||
u.Path = "v1/attribution"
|
||||
body := new(bytes.Buffer)
|
||||
if err := json.NewEncoder(body).Encode(AttributionRequest{
|
||||
Snippet: snippet,
|
||||
Limit: limit,
|
||||
}); err != nil {
|
||||
return Attribution{}, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body.Bytes()))
|
||||
if err != nil {
|
||||
return Attribution{}, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken))
|
||||
resp, err := c.cli.Do(req)
|
||||
if err != nil {
|
||||
return Attribution{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Attribution{}, errors.Newf("request failed with status: %d", errors.Safe(resp.StatusCode))
|
||||
}
|
||||
var gatewayResponse AttributionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&gatewayResponse); err != nil {
|
||||
return Attribution{}, errors.Wrap(err, "cannot interpret gateway response")
|
||||
}
|
||||
a := Attribution{
|
||||
Repositories: make([]string, len(gatewayResponse.Repositories)),
|
||||
LimitHit: gatewayResponse.LimitHit,
|
||||
}
|
||||
for i, r := range gatewayResponse.Repositories {
|
||||
a.Repositories[i] = r.Name
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
51
internal/codygateway/client_test.go
Normal file
51
internal/codygateway/client_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
package codygateway_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
)
|
||||
|
||||
type attributionHandler struct {
|
||||
t *testing.T
|
||||
requests []codygateway.AttributionRequest
|
||||
}
|
||||
|
||||
func (h *attributionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var request codygateway.AttributionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
h.requests = append(h.requests, request)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
require.NoError(h.t, json.NewEncoder(w).Encode(codygateway.AttributionResponse{
|
||||
Repositories: []codygateway.AttributionRepository{{"repo1"}, {"repo2"}},
|
||||
LimitHit: false,
|
||||
}))
|
||||
}
|
||||
|
||||
func TestAttribution(t *testing.T) {
|
||||
h := &attributionHandler{
|
||||
t: t,
|
||||
}
|
||||
srv := httptest.NewServer(h)
|
||||
t.Cleanup(srv.Close)
|
||||
client := codygateway.NewClient(http.DefaultClient, srv.URL, "token")
|
||||
attribution, err := client.Attribution(context.Background(), "snippet", 3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []codygateway.AttributionRequest{{
|
||||
Snippet: "snippet",
|
||||
Limit: 3,
|
||||
}}, h.requests)
|
||||
require.Equal(t, codygateway.Attribution{
|
||||
Repositories: []string{"repo1", "repo2"},
|
||||
LimitHit: false,
|
||||
}, attribution)
|
||||
}
|
||||
@ -80,3 +80,29 @@ type ActorRateLimitNotifyConfig struct {
|
||||
// SlackWebhookURL is the URL of the Slack webhook to send the alerts to.
|
||||
SlackWebhookURL string
|
||||
}
|
||||
|
||||
// AttributionRequest is a request for attribution search.
|
||||
// Expected in JSON form as the body of a POST request.
|
||||
type AttributionRequest struct {
|
||||
// Snippet is the text to search attribution of.
|
||||
Snippet string `json:"snippet"`
|
||||
// Limit is the upper bound of number of responses we want to get.
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// AttributionResponse is the response of an attribution search.
|
||||
// Contains some repositories to which the snippet can be attributed.
|
||||
type AttributionResponse struct {
|
||||
// Repositories which contain code matching search snippet.
|
||||
Repositories []AttributionRepository
|
||||
// TotalCount denotes how many total matches there were (including listed repositories).
|
||||
TotalCount int `json:"totalCount,omitempty"`
|
||||
// LimitHit is true if the number of search hits goes beyond limit specified in request.
|
||||
LimitHit bool `json:"limitHit,omitempty"`
|
||||
}
|
||||
|
||||
// AttributionRepository represents matching of search content against a repository.
|
||||
type AttributionRepository struct {
|
||||
// Name of the repo on dotcom. Like github.com/sourcegraph/sourcegraph.
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user