Enterprise instance uses gateway for attribution (#59513)

* Remove local snippet attribution

* Draft

* Attribution service updated with gateway client

* Thursday progress

* Attribution success test

* Unify types for gateway request/response

* Attribution client test

* BAZEL

* Update gateway request/response references

* Remove Attribution request TODO
This commit is contained in:
Cezary Bartoszuk 2024-01-16 10:16:32 +01:00 committed by GitHub
parent ea521b5757
commit f5bcbfcbb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 185 additions and 395 deletions

View File

@ -22,7 +22,6 @@ go_test(
name = "attribution_test",
srcs = ["handler_test.go"],
deps = [
":attribution",
"//cmd/cody-gateway/internal/actor",
"//cmd/cody-gateway/internal/auth",
"//cmd/cody-gateway/internal/dotcom",

View File

@ -20,30 +20,6 @@ import (
// If a higher value is given, then this default is set.
const LimitUpperBound = 4
// Request for attribution search. Expected in JSON form as the body of POST request.
type Request struct {
// Snippet is the text to search attribution of.
Snippet string
// Limit is the upper bound of number of responses we want to get.
Limit int
}
// Response of attribution search. Contains some repositories to which the snippet can be attributed to.
type Response struct {
// Repositories which contain code matching search snippet.
Repositories []Repository
// TotalCount denotes how many total matches there were (including listed repositories).
TotalCount int
// LimitHit is true if the number of search hits goes beyond limit specified in request.
LimitHit bool
}
// Repository represents matching of search content against a repository.
type Repository struct {
// Name of the repo on dotcom. Like github.com/sourcegraph/sourcegraph.
Name string
}
// NewHandler creates a REST handler for attribution search.
// graphql.Client can be nil which disables the search.
func NewHandler(client graphql.Client, baseLogger log.Logger) http.Handler {
@ -59,7 +35,7 @@ func NewHandler(client graphql.Client, baseLogger log.Logger) http.Handler {
response.JSONError(logger, w, http.StatusServiceUnavailable, errors.New("attribution search not enabled"))
return
}
var request Request
var request codygateway.AttributionRequest
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
response.JSONError(logger, w, http.StatusBadRequest, err)
return
@ -73,13 +49,13 @@ func NewHandler(client graphql.Client, baseLogger log.Logger) http.Handler {
response.JSONError(logger, w, http.StatusServiceUnavailable, err)
return
}
var rs []Repository
var rs []codygateway.AttributionRepository
for _, n := range searchResponse.SnippetAttribution.Nodes {
rs = append(rs, Repository{Name: n.RepositoryName})
rs = append(rs, codygateway.AttributionRepository{Name: n.RepositoryName})
}
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(&Response{
if err := json.NewEncoder(w).Encode(&codygateway.AttributionResponse{
Repositories: rs,
TotalCount: searchResponse.SnippetAttribution.TotalCount,
LimitHit: searchResponse.SnippetAttribution.LimitHit,

View File

@ -32,7 +32,6 @@ import (
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/dotcom"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/attribution"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
)
@ -73,7 +72,7 @@ func runFakeGraphQL(t *testing.T) *fakeGraphQL {
// request creates an attribution search request to the gateway.
func request(t *testing.T) *http.Request {
requestBody, err := json.Marshal(&attribution.Request{
requestBody, err := json.Marshal(&codygateway.AttributionRequest{
Snippet: strings.Join([]string{
"for n != 1 {",
" if n % 2 == 0 {",
@ -127,10 +126,10 @@ func TestSuccess(t *testing.T) {
t.Error(w.Body.String())
t.Fatalf("expected OK, got %d", got)
}
var gotResponseBody attribution.Response
var gotResponseBody codygateway.AttributionResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&gotResponseBody))
wantResponseBody := &attribution.Response{
Repositories: []attribution.Repository{
wantResponseBody := &codygateway.AttributionResponse{
Repositories: []codygateway.AttributionRepository{
{Name: "github.com/sourcegraph/sourcegraph"},
{Name: "github.com/sourcegraph/cody"},
},

View File

@ -9,15 +9,12 @@ go_library(
"//cmd/frontend/enterprise",
"//cmd/frontend/envvar",
"//cmd/frontend/internal/guardrails/attribution",
"//cmd/frontend/internal/guardrails/dotcom",
"//cmd/frontend/internal/guardrails/resolvers",
"//internal/codeintel",
"//internal/codygateway",
"//internal/conf/conftypes",
"//internal/database",
"//internal/gitserver",
"//internal/httpcli",
"//internal/observation",
"//internal/search/client",
"//lib/errors",
],
)

View File

@ -10,16 +10,9 @@ go_library(
importpath = "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/attribution",
visibility = ["//cmd/frontend:__subpackages__"],
deps = [
"//cmd/frontend/internal/guardrails/dotcom",
"//internal/api",
"//internal/codygateway",
"//internal/metrics",
"//internal/observation",
"//internal/search",
"//internal/search/client",
"//internal/search/streaming",
"//lib/errors",
"//lib/pointers",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
"@io_opentelemetry_go_otel//attribute",
],
@ -30,16 +23,8 @@ go_test(
srcs = ["attribution_test.go"],
embed = [":attribution"],
deps = [
"//cmd/frontend/internal/guardrails/dotcom",
"//internal/database/dbmocks",
"//internal/codygateway",
"//internal/observation",
"//internal/search/backend",
"//internal/search/client",
"//internal/search/job",
"//internal/types",
"@com_github_google_go_cmp//cmp",
"@com_github_khan_genqlient//graphql",
"@com_github_sourcegraph_log//logtest",
"@com_github_sourcegraph_zoekt//:zoekt",
"@com_github_stretchr_testify//require",
],
)

View File

@ -2,43 +2,19 @@ package attribution
import (
"context"
"fmt"
"sync"
"github.com/sourcegraph/conc/pool"
"go.opentelemetry.io/otel/attribute"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/dotcom"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search"
"github.com/sourcegraph/sourcegraph/internal/search/client"
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
// ServiceOpts configures Service.
type ServiceOpts struct {
// SearchClient is used to find attribution on the local instance.
SearchClient client.SearchClient
// SourcegraphDotComClient is a graphql client that is queried if
// federating out to sourcegraph.com is enabled.
SourcegraphDotComClient dotcom.Client
// SourcegraphDotComFederate is true if this instance should also federate
// to sourcegraph.com.
SourcegraphDotComFederate bool
}
// Service is for the attribution service which searches for matches on
// snippets of code.
//
// Use NewService to construct this value.
type Service struct {
ServiceOpts
client codygateway.Client
operations *operations
}
@ -46,10 +22,10 @@ type Service struct {
//
// Note: this registers metrics so should only be called once with the same
// observationCtx.
func NewService(observationCtx *observation.Context, opts ServiceOpts) *Service {
func NewService(observationCtx *observation.Context, client codygateway.Client) *Service {
return &Service{
operations: newOperations(observationCtx),
ServiceOpts: opts,
operations: newOperations(observationCtx),
client: client,
}
}
@ -89,173 +65,13 @@ func (c *Service) SnippetAttribution(ctx context.Context, snippet string, limit
},
})
defer endObservationWithResult(traceLogger, endObservation, &result)()
limitHitErr := errors.New("limit hit error")
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
// we massage results in this function and possibly cancel if we can stop
// looking.
truncateAtLimit := func(result *SnippetAttributions) {
if result == nil {
return
}
if limit <= len(result.RepositoryNames) {
result.LimitHit = true
result.RepositoryNames = result.RepositoryNames[:limit]
}
if result.LimitHit {
cancel(limitHitErr)
}
}
// TODO(keegancsmith) how should we handle partial errors?
p := pool.New().WithContext(ctx).WithCancelOnError().WithFirstError()
// We don't use NewWithResults since we want local results to come before dotcom
var local, dotcom *SnippetAttributions
p.Go(func(ctx context.Context) error {
var err error
local, err = c.snippetAttributionLocal(ctx, snippet, limit)
truncateAtLimit(local)
return err
})
if c.SourcegraphDotComFederate {
p.Go(func(ctx context.Context) error {
var err error
dotcom, err = c.snippetAttributionDotCom(ctx, snippet, limit)
truncateAtLimit(dotcom)
return err
})
}
if err := p.Wait(); err != nil && context.Cause(ctx) != limitHitErr {
return nil, err
}
var agg SnippetAttributions
seen := map[string]struct{}{}
for _, result := range []*SnippetAttributions{local, dotcom} {
if result == nil {
continue
}
// Limitation: We just add to TotalCount even though that may mean we
// overcount (both dotcom and local instance have the repo)
agg.TotalCount += result.TotalCount
agg.LimitHit = agg.LimitHit || result.LimitHit
for _, name := range result.RepositoryNames {
if _, ok := seen[name]; ok {
// We have already counted this repo in the above TotalCount
// increment, so undo that.
agg.TotalCount--
continue
}
seen[name] = struct{}{}
agg.RepositoryNames = append(agg.RepositoryNames, name)
}
}
// we call truncateAtLimit on the aggregated result to ensure we only
// return upto limit. Note this function will call cancel but that is fine
// since we just return after this.
truncateAtLimit(&agg)
return &agg, nil
}
func (c *Service) snippetAttributionLocal(ctx context.Context, snippet string, limit int) (result *SnippetAttributions, err error) {
ctx, traceLogger, endObservation := c.operations.snippetAttributionLocal.With(ctx, &err, observation.Args{})
defer endObservationWithResult(traceLogger, endObservation, &result)()
const (
version = "V3"
searchMode = search.Precise
protocol = search.Streaming
)
patternType := "literal"
searchQuery := fmt.Sprintf("type:file select:repo index:only case:yes count:%d content:%q", limit, snippet)
inputs, err := c.SearchClient.Plan(
ctx,
version,
&patternType,
searchQuery,
searchMode,
protocol,
pointers.Ptr(int32(0)),
)
if err != nil {
return nil, errors.Wrap(err, "failed to create search plan")
}
// TODO(keegancsmith) Reading the SearchClient code it seems to miss out
// on some of the observability that we instead add in at a later stage.
// For example the search dataset in honeycomb will be missing. Will have
// to follow-up with observability and maybe solve it for all users.
//
// Note: In our current API we could just store repo names in seen. But it
// is safer to rely on searches ranking for result stability than doing
// something like sorting by name from the map.
var (
mu sync.Mutex
seen = map[api.RepoID]struct{}{}
repoNames []string
limitHit bool
)
_, err = c.SearchClient.Execute(ctx, streaming.StreamFunc(func(ev streaming.SearchEvent) {
mu.Lock()
defer mu.Unlock()
limitHit = limitHit || ev.Stats.IsLimitHit
for _, m := range ev.Results {
repo := m.RepoName()
if _, ok := seen[repo.ID]; ok {
continue
}
seen[repo.ID] = struct{}{}
repoNames = append(repoNames, string(repo.Name))
}
}), inputs)
if err != nil {
return nil, errors.Wrap(err, "failed to execute search")
}
// Note: Our search API is missing total count internally, but Zoekt does
// expose this. For now we just count what we found.
totalCount := len(repoNames)
if len(repoNames) > limit {
repoNames = repoNames[:limit]
}
return &SnippetAttributions{
RepositoryNames: repoNames,
TotalCount: totalCount,
LimitHit: limitHit,
}, nil
}
func (c *Service) snippetAttributionDotCom(ctx context.Context, snippet string, limit int) (result *SnippetAttributions, err error) {
ctx, traceLogger, endObservation := c.operations.snippetAttributionDotCom.With(ctx, &err, observation.Args{})
defer endObservationWithResult(traceLogger, endObservation, &result)()
resp, err := dotcom.SnippetAttribution(ctx, c.SourcegraphDotComClient, snippet, limit)
attribution, err := c.client.Attribution(ctx, snippet, limit)
if err != nil {
return nil, err
}
var repoNames []string
for _, node := range resp.SnippetAttribution.Nodes {
repoNames = append(repoNames, node.RepositoryName)
}
return &SnippetAttributions{
RepositoryNames: repoNames,
TotalCount: resp.SnippetAttribution.TotalCount,
LimitHit: resp.SnippetAttribution.LimitHit,
RepositoryNames: attribution.Repositories,
TotalCount: len(attribution.Repositories), // TODO: Remove total count.
LimitHit: attribution.LimitHit,
}, nil
}

View File

@ -2,138 +2,33 @@ package attribution
import (
"context"
"fmt"
"testing"
"github.com/Khan/genqlient/graphql"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/zoekt"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/dotcom"
"github.com/sourcegraph/sourcegraph/internal/database/dbmocks"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/observation"
searchbackend "github.com/sourcegraph/sourcegraph/internal/search/backend"
"github.com/sourcegraph/sourcegraph/internal/search/client"
"github.com/sourcegraph/sourcegraph/internal/search/job"
"github.com/sourcegraph/sourcegraph/internal/types"
)
func TestAttribution(t *testing.T) {
ctx := context.Background()
type fakeGateway struct {
codygateway.Client
}
// inputs
localCount, dotcomCount := 5, 5
limit := localCount + dotcomCount + 1
localNames := genRepoNames("localrepo-", localCount)
dotcomNames := genRepoNames("dotcomrepo-", dotcomCount)
func (f fakeGateway) Attribution(ctx context.Context, snippet string, limit int) (codygateway.Attribution, error) {
return codygateway.Attribution{
Repositories: []string{"repo1", "repo2"},
LimitHit: false,
}, nil
}
// we want the localNames back followed by dotcomNames
wantCount := localCount + dotcomCount
wantNames := append(genRepoNames("localrepo-", localCount), genRepoNames("dotcomrepo-", dotcomCount)...)
svc := NewService(observation.TestContextTB(t), ServiceOpts{
SearchClient: mockSearchClient(t, localNames),
SourcegraphDotComClient: mockDotComClient(t, dotcomNames),
SourcegraphDotComFederate: true,
})
result, err := svc.SnippetAttribution(ctx, "test", limit)
if err != nil {
t.Fatal(err)
}
want := &SnippetAttributions{
TotalCount: wantCount,
func TestSuccess(t *testing.T) {
gateway := fakeGateway{}
service := NewService(observation.TestContextTB(t), gateway)
attribution, err := service.SnippetAttribution(context.Background(), "snippet", 3)
require.NoError(t, err)
require.Equal(t, &SnippetAttributions{
RepositoryNames: []string{"repo1", "repo2"},
TotalCount: 2,
LimitHit: false,
RepositoryNames: wantNames,
}
if d := cmp.Diff(want, result); d != "" {
t.Fatalf("unexpected (-want, +got):\n%s", d)
}
// With a limit of one we expect one of local or dotcom, depending on
// which one returns first.
result, err = svc.SnippetAttribution(ctx, "test", 1)
if err != nil {
t.Fatal(err)
}
if !result.LimitHit {
t.Fatal("we expected the limit to be hit")
}
if len(result.RepositoryNames) != 1 {
t.Fatalf("we wanted one result, got %v", result.RepositoryNames)
}
if name := result.RepositoryNames[0]; name != "localrepo-1" && name != "dotcomrepo-1" {
t.Fatalf("we wanted the first result, got %v", result.RepositoryNames)
}
}
func genRepoNames(prefix string, count int) []string {
var names []string
for i := 1; i <= count; i++ {
names = append(names, fmt.Sprintf("%s%d", prefix, i))
}
return names
}
// mockSearchClient returns a client which will return matches. This exercises
// more of the search code path to give a bit more confidence we are correctly
// calling Plan and Execute vs a dumb SearchClient mock.
func mockSearchClient(t testing.TB, repoNames []string) client.SearchClient {
repos := dbmocks.NewMockRepoStore()
repos.ListMinimalReposFunc.SetDefaultReturn([]types.MinimalRepo{}, nil)
repos.CountFunc.SetDefaultReturn(0, nil)
db := dbmocks.NewMockDB()
db.ReposFunc.SetDefaultReturn(repos)
var matches []zoekt.FileMatch
for i, name := range repoNames {
matches = append(matches, zoekt.FileMatch{
RepositoryID: uint32(i),
Repository: name,
})
}
mockZoekt := &searchbackend.FakeStreamer{
Repos: []*zoekt.RepoListEntry{},
Results: []*zoekt.SearchResult{{
Files: matches,
}},
}
return client.Mocked(job.RuntimeClients{
Logger: logtest.Scoped(t),
DB: db,
Zoekt: mockZoekt,
})
}
func mockDotComClient(t testing.TB, repoNames []string) dotcom.Client {
return makeRequester(func(ctx context.Context, req *graphql.Request, resp *graphql.Response) error {
// :O :O generated type names :O :O
var nodes []dotcom.SnippetAttributionSnippetAttributionSnippetAttributionConnectionNodesSnippetAttribution
for _, name := range repoNames {
nodes = append(nodes, dotcom.SnippetAttributionSnippetAttributionSnippetAttributionConnectionNodesSnippetAttribution{
RepositoryName: name,
})
}
data := resp.Data.(*dotcom.SnippetAttributionResponse)
*data = dotcom.SnippetAttributionResponse{
// :O
SnippetAttribution: dotcom.SnippetAttributionSnippetAttributionSnippetAttributionConnection{
TotalCount: len(repoNames),
Nodes: nodes,
},
}
return context.Cause(ctx)
})
}
type makeRequester func(ctx context.Context, req *graphql.Request, resp *graphql.Response) error
func (f makeRequester) MakeRequest(ctx context.Context, req *graphql.Request, resp *graphql.Response) error {
return f(ctx, req, resp)
}, attribution)
}

View File

@ -6,16 +6,13 @@ import (
"github.com/sourcegraph/sourcegraph/cmd/frontend/enterprise"
"github.com/sourcegraph/sourcegraph/cmd/frontend/envvar"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/attribution"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/dotcom"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/guardrails/resolvers"
"github.com/sourcegraph/sourcegraph/internal/codeintel"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search/client"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func Init(
@ -26,26 +23,17 @@ func Init(
_ conftypes.UnifiedWatchable,
enterpriseServices *enterprise.Services,
) error {
opts := attribution.ServiceOpts{
SearchClient: client.New(observationCtx.Logger, db, gitserver.NewClient("http.guardrails.search")),
// Guardrails is only available in enterprise instances.
if envvar.SourcegraphDotComMode() {
return nil
}
// TODO(keegancsmith) configuration for access token and enabling.
if !envvar.SourcegraphDotComMode() {
httpClient, err := httpcli.UncachedExternalClientFactory.Doer()
if err != nil {
return errors.Wrap(err, "failed to initialize external http client for guardrails")
}
endpoint := "https://sourcegraph.com/.api/graphql"
accessToken := ""
opts.SourcegraphDotComFederate = true
opts.SourcegraphDotComClient = dotcom.NewClient(httpClient, endpoint, accessToken)
client, ok := codygateway.NewClientFromSiteConfig(httpcli.ExternalDoer)
if !ok {
// TODO handle error
return nil
}
enterpriseServices.GuardrailsResolver = &resolvers.GuardrailsResolver{
AttributionService: attribution.NewService(observationCtx, opts),
AttributionService: attribution.NewService(observationCtx, client),
}
return nil
}

View File

@ -1,3 +1,4 @@
load("//dev:go_defs.bzl", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
@ -18,3 +19,12 @@ go_library(
"//lib/errors",
],
)
go_test(
name = "codygateway_test",
srcs = ["client_test.go"],
deps = [
":codygateway",
"@com_github_stretchr_testify//require",
],
)

View File

@ -1,6 +1,7 @@
package codygateway
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -33,8 +34,14 @@ func (rl LimitStatus) PercentUsed() int {
return int(math.Ceil(float64(rl.IntervalUsage) / float64(rl.IntervalLimit) * 100))
}
type Attribution struct {
Repositories []string
LimitHit bool
}
type Client interface {
GetLimits(ctx context.Context) ([]LimitStatus, error)
Attribution(ctx context.Context, snippet string, limit int) (Attribution, error)
}
func NewClientFromSiteConfig(cli httpcli.Doer) (_ Client, ok bool) {
@ -47,6 +54,7 @@ func NewClientFromSiteConfig(cli httpcli.Doer) (_ Client, ok bool) {
return nil, false
}
// TODO: What if customer is BYOK? How do we talk to gateway then?
// If neither completions nor embeddings use Cody Gateway, return empty.
ccUsingGateway := cc != nil && cc.Provider == conftypes.CompletionsProviderNameSourcegraph
ecUsingGateway := ec != nil && ec.Provider == conftypes.EmbeddingsProviderNameSourcegraph
@ -129,3 +137,43 @@ func (c *client) GetLimits(ctx context.Context) ([]LimitStatus, error) {
return rateLimits, nil
}
func (c *client) Attribution(ctx context.Context, snippet string, limit int) (Attribution, error) {
u, err := url.Parse(c.endpoint)
if err != nil {
return Attribution{}, err
}
u.Path = "v1/attribution"
body := new(bytes.Buffer)
if err := json.NewEncoder(body).Encode(AttributionRequest{
Snippet: snippet,
Limit: limit,
}); err != nil {
return Attribution{}, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body.Bytes()))
if err != nil {
return Attribution{}, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken))
resp, err := c.cli.Do(req)
if err != nil {
return Attribution{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return Attribution{}, errors.Newf("request failed with status: %d", errors.Safe(resp.StatusCode))
}
var gatewayResponse AttributionResponse
if err := json.NewDecoder(resp.Body).Decode(&gatewayResponse); err != nil {
return Attribution{}, errors.Wrap(err, "cannot interpret gateway response")
}
a := Attribution{
Repositories: make([]string, len(gatewayResponse.Repositories)),
LimitHit: gatewayResponse.LimitHit,
}
for i, r := range gatewayResponse.Repositories {
a.Repositories[i] = r.Name
}
return a, nil
}

View File

@ -0,0 +1,51 @@
package codygateway_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
)
type attributionHandler struct {
t *testing.T
requests []codygateway.AttributionRequest
}
func (h *attributionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var request codygateway.AttributionRequest
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
h.requests = append(h.requests, request)
w.WriteHeader(http.StatusOK)
require.NoError(h.t, json.NewEncoder(w).Encode(codygateway.AttributionResponse{
Repositories: []codygateway.AttributionRepository{{"repo1"}, {"repo2"}},
LimitHit: false,
}))
}
func TestAttribution(t *testing.T) {
h := &attributionHandler{
t: t,
}
srv := httptest.NewServer(h)
t.Cleanup(srv.Close)
client := codygateway.NewClient(http.DefaultClient, srv.URL, "token")
attribution, err := client.Attribution(context.Background(), "snippet", 3)
require.NoError(t, err)
require.Equal(t, []codygateway.AttributionRequest{{
Snippet: "snippet",
Limit: 3,
}}, h.requests)
require.Equal(t, codygateway.Attribution{
Repositories: []string{"repo1", "repo2"},
LimitHit: false,
}, attribution)
}

View File

@ -80,3 +80,29 @@ type ActorRateLimitNotifyConfig struct {
// SlackWebhookURL is the URL of the Slack webhook to send the alerts to.
SlackWebhookURL string
}
// AttributionRequest is request for attribution search.
// Expected in JSON form as the body of POST request.
type AttributionRequest struct {
// Snippet is the text to search attribution of.
Snippet string `json:"snippet"`
// Limit is the upper bound of number of responses we want to get.
Limit int `json:"limit"`
}
// AttributionResponse is response of attribution search.
// Contains some repositories to which the snippet can be attributed to.
type AttributionResponse struct {
// Repositories which contain code matching search snippet.
Repositories []AttributionRepository
// TotalCount denotes how many total matches there were (including listed repositories).
TotalCount int `json:"totalCount,omitempty"`
// LimitHit is true if the number of search hits goes beyond limit specified in request.
LimitHit bool `json:"limitHit,omitempty"`
}
// AttributionRepository represents matching of search content against a repository.
type AttributionRepository struct {
// Name of the repo on dotcom. Like github.com/sourcegraph/sourcegraph.
Name string `json:"name"`
}