Integrate Cohere re-ranking API (#63877)

Integrates Cohere re-ranking [API](https://cohere.com/rerank) for
server-side Cody Context ([RFC
969](https://linear.app/sourcegraph/project/v1-of-two-stage-intent-detection-context-retrieval-system-c4f7093e9eab/overview)).
Before this PR, we only supported `identity` ranker (which returned all
items in the input order), which is still the default choice (when
Cohere API key is not provided).

Closes https://linear.app/sourcegraph/issue/AI-134/add-non-poc-ranking

## Test plan

- tested locally, use 
```
"cody.serverSideContext": {
      "reranker": {
        "type": "cohere",
        "apiKey": "TOKEN"
      }
    }
```
to test locally
This commit is contained in:
Rafał Gajdulewicz 2024-07-17 21:20:13 +02:00 committed by GitHub
parent 658d12ea35
commit 25929d1be9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 192 additions and 9 deletions

View File

@ -1,13 +1,23 @@
import * as React from 'react'
import { mdiArrowLeftBoldBoxOutline } from '@mdi/js'
import { useLocation } from 'react-router-dom'
import { asError, type ErrorLike, isErrorLike, logger } from '@sourcegraph/common'
import type { TelemetryV2Props } from '@sourcegraph/shared/src/telemetry'
import { EVENT_LOGGER } from '@sourcegraph/shared/src/telemetry/web/eventLogger'
import { Button, Link, LoadingSpinner, Alert, Text, Input, ErrorAlert, Form, Container, Icon } from '@sourcegraph/wildcard'
import {
Button,
Link,
LoadingSpinner,
Alert,
Text,
Input,
ErrorAlert,
Form,
Container,
Icon,
} from '@sourcegraph/wildcard'
import type { AuthenticatedUser } from '../auth'
import { LoaderButton } from '../components/LoaderButton'
@ -159,7 +169,8 @@ class ResetPasswordCodeForm extends React.PureComponent<ResetPasswordCodeFormPro
if (this.state.submitOrError === null) {
return (
<Alert variant="success">
Your password was reset. <Link to={`/sign-in?email=${email}`}>Sign in with your new password</Link> to continue.
Your password was reset. <Link to={`/sign-in?email=${email}`}>Sign in with your new password</Link>{' '}
to continue.
</Alert>
)
}
@ -168,7 +179,10 @@ class ResetPasswordCodeForm extends React.PureComponent<ResetPasswordCodeFormPro
<>
{isErrorLike(this.state.submitOrError) && <ErrorAlert error={this.state.submitOrError} />}
<Container className="w-100">
<Link to='/password-reset'><Icon className="mr-1" aria-hidden={true} svgPath={mdiArrowLeftBoldBoxOutline} />Raise request for a different account</Link>
<Link to="/password-reset">
<Icon className="mr-1" aria-hidden={true} svgPath={mdiArrowLeftBoldBoxOutline} />
Raise request for a different account
</Link>
<Text className="mt-1 text-center text-muted font-weight-bold mb-3">{email}</Text>
<Form data-testid="reset-password-page-form" onSubmit={this.handleSubmitResetPassword}>
<PasswordInput

View File

@ -21,6 +21,8 @@ go_library(
"//lib/errors",
"//lib/pointers",
"//schema",
"@com_github_cohere_ai_cohere_go_v2//:cohere-go",
"@com_github_cohere_ai_cohere_go_v2//client",
"@com_github_sourcegraph_conc//iter",
"@com_github_sourcegraph_log//:log",
],

View File

@ -24,6 +24,9 @@ import (
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
cohere "github.com/cohere-ai/cohere-go/v2"
"github.com/cohere-ai/cohere-go/v2/client"
)
func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient *codycontext.CodyContextClient, logger log.Logger) graphqlbackend.CodyContextResolver {
@ -34,6 +37,8 @@ func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient
logger: logger,
intentApiHttpClient: httpcli.UncachedExternalDoer,
intentBackendConfig: conf.CodyIntentConfig(),
reranker: conf.CodyReranker(),
cohereConfig: conf.CodyRerankerCohereConfig(),
}
}
@ -44,6 +49,8 @@ type Resolver struct {
logger log.Logger
intentApiHttpClient httpcli.Doer
intentBackendConfig *schema.IntentDetectionAPI
reranker conf.CodyRerankerBackend
cohereConfig *schema.CodyRerankerCohere
}
func (r *Resolver) RecordContext(ctx context.Context, args graphqlbackend.RecordContextArgs) (*graphqlbackend.EmptyResponse, error) {
@ -74,14 +81,15 @@ func (r *Resolver) RankContext(ctx context.Context, args graphqlbackend.RankCont
if err != nil {
return nil, err
}
ranker, used, err := r.rerank(ctx, args)
if err != nil {
return nil, err
}
res := rankContextResponse{
ranker: "identity",
ranker: string(ranker),
used: used,
}
r.logger.Info("ranking context", log.String("interactionID", args.InteractionID), log.String("ranker", res.ranker), log.Int("contextItemCount", len(args.ContextItems)))
for i := range args.ContextItems {
res.used = append(res.used, int32(i))
}
return res, nil
}
@ -260,6 +268,35 @@ func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.F
return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, endLine), nil
}
func (r *Resolver) rerank(ctx context.Context, args graphqlbackend.RankContextArgs) (conf.CodyRerankerBackend, []int32, error) {
if r.reranker == conf.CodyRerankerIdentity {
var used []int32
for i := range args.ContextItems {
used = append(used, int32(i))
}
return conf.CodyRerankerIdentity, used, nil
}
co := client.NewClient(client.WithToken(r.cohereConfig.ApiKey))
req := &cohere.RerankRequest{
Query: args.Query,
Model: cohere.String(r.cohereConfig.Model),
}
for _, ci := range args.ContextItems {
req.Documents = append(req.Documents, &cohere.RerankRequestDocumentsItem{String: ci.Content})
}
resp, err := co.Rerank(ctx, req)
if err != nil {
r.logger.Error("cohere reranking error", log.String("interactionId", args.InteractionID), log.String("query", args.Query), log.Error(err))
return conf.CodyRerankerCohere, nil, err
}
var used []int32
for _, r := range resp.Results {
used = append(used, int32(r.Index))
}
return conf.CodyRerankerCohere, used, nil
}
// countLines finds the number of lines corresponding to the number of runes. We 'round down'
// to ensure that we don't return more characters than our budget.
func countLines(content string, numRunes int) int {

View File

@ -1181,6 +1181,13 @@ def go_dependencies():
sum = "h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q=",
version = "v0.0.0-20150114235600-33e0aa1cb7c0",
)
go_repository(
name = "com_github_cohere_ai_cohere_go_v2",
build_file_proto_mode = "disable_global",
importpath = "github.com/cohere-ai/cohere-go/v2",
sum = "h1:NtxtcqkJ3ZBj8DFgk/4hpOrGK7CGnllGNpQn1bkaqQs=",
version = "v2.8.2",
)
go_repository(
name = "com_github_common_nighthawk_go_figure",
build_file_proto_mode = "disable_global",

1
go.mod
View File

@ -269,6 +269,7 @@ require (
github.com/bevzzz/nb v0.3.0
github.com/bevzzz/nb-synth v0.0.0-20240128164931-35fdda0583a0
github.com/bevzzz/nb/extension/extra/goldmark-jupyter v0.0.0-20240131001330-e69229bd9da4
github.com/cohere-ai/cohere-go/v2 v2.8.2
github.com/derision-test/go-mockgen/v2 v2.0.1
github.com/dghubble/gologin/v2 v2.4.0
github.com/edsrzf/mmap-go v1.1.0

2
go.sum
View File

@ -1000,6 +1000,8 @@ github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZe
github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs=
github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30=
github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg=
github.com/cohere-ai/cohere-go/v2 v2.8.2 h1:NtxtcqkJ3ZBj8DFgk/4hpOrGK7CGnllGNpQn1bkaqQs=
github.com/cohere-ai/cohere-go/v2 v2.8.2/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
github.com/containerd/containerd v1.7.12 h1:+KQsnv4VnzyxWcfO9mlxxELaoztsDEjOuCMPAuPqgU0=
github.com/containerd/containerd v1.7.12/go.mod h1:/5OMpE1p0ylxtEUGY8kuCYkDRzJm9NO1TFMWjUpdevk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=

View File

@ -293,6 +293,30 @@ func CodyIntentConfig() *schema.IntentDetectionAPI {
return Get().ExperimentalFeatures.CodyServerSideContext.IntentDetectionAPI
}
type CodyRerankerBackend string
const (
CodyRerankerIdentity CodyRerankerBackend = "identity"
CodyRerankerCohere CodyRerankerBackend = "cohere"
)
func CodyReranker() CodyRerankerBackend {
if Get().ExperimentalFeatures == nil || Get().ExperimentalFeatures.CodyServerSideContext == nil || Get().ExperimentalFeatures.CodyServerSideContext.Reranker == nil {
return CodyRerankerIdentity
}
if Get().ExperimentalFeatures.CodyServerSideContext.Reranker.Identity != nil {
return CodyRerankerIdentity
}
return CodyRerankerCohere
}
func CodyRerankerCohereConfig() *schema.CodyRerankerCohere {
if CodyReranker() != CodyRerankerCohere {
return nil
}
return Get().ExperimentalFeatures.CodyServerSideContext.Reranker.Cohere
}
func ExecutorsEnabled() bool {
return Get().ExecutorsAccessToken != ""
}

View File

@ -676,10 +676,24 @@ type CodyProConfig struct {
UseEmbeddedUI bool `json:"useEmbeddedUI,omitempty"`
}
// CodyRerankerCohere description: Re-ranker using Cohere API
type CodyRerankerCohere struct {
ApiKey string `json:"apiKey"`
Model string `json:"model,omitempty"`
Type string `json:"type"`
}
// CodyRerankerIdentity description: Identity re-ranker
type CodyRerankerIdentity struct {
Type string `json:"type"`
}
// CodyServerSideContext description: Configuration for Server-side context API
type CodyServerSideContext struct {
// IntentDetectionAPI description: Configuration for intent detection API
IntentDetectionAPI *IntentDetectionAPI `json:"intentDetectionAPI,omitempty"`
// Reranker description: Reranker to use for rankContext requests
Reranker *Reranker `json:"reranker,omitempty"`
}
// CommitGraphUpdates description: Customize strategy used for commit graph updates
@ -2303,6 +2317,38 @@ type RequestMessage struct {
Params any `json:"params,omitempty"`
Settings map[string]any `json:"settings,omitempty"`
}
// Reranker description: Reranker to use for rankContext requests
type Reranker struct {
Identity *CodyRerankerIdentity
Cohere *CodyRerankerCohere
}
func (v Reranker) MarshalJSON() ([]byte, error) {
if v.Identity != nil {
return json.Marshal(v.Identity)
}
if v.Cohere != nil {
return json.Marshal(v.Cohere)
}
return nil, errors.New("tagged union type must have exactly 1 non-nil field value")
}
func (v *Reranker) UnmarshalJSON(data []byte) error {
var d struct {
DiscriminantProperty string `json:"type"`
}
if err := json.Unmarshal(data, &d); err != nil {
return err
}
switch d.DiscriminantProperty {
case "cohere":
return json.Unmarshal(data, &v.Cohere)
case "identity":
return json.Unmarshal(data, &v.Identity)
}
return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"identity", "cohere"})
}
type Responders struct {
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`

View File

@ -656,6 +656,27 @@
}
}
}
},
"reranker": {
"description": "Reranker to use for rankContext requests",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["identity", "cohere"]
}
},
"oneOf": [
{
"$ref": "#/definitions/CodyRerankerIdentity"
},
{
"$ref": "#/definitions/CodyRerankerCohere"
}
],
"!go": {
"taggedUnionType": true
}
}
}
}
@ -4570,6 +4591,35 @@
"type": "string"
}
}
},
"CodyRerankerIdentity": {
"description": "Identity re-ranker",
"type": "object",
"required": ["type"],
"properties": {
"type": {
"type": "string",
"const": "identity"
}
}
},
"CodyRerankerCohere": {
"description": "Re-ranker using Cohere API",
"type": "object",
"required": ["type", "apiKey"],
"properties": {
"type": {
"type": "string",
"const": "cohere"
},
"apiKey": {
"type": "string"
},
"model": {
"type": "string",
"default": "rerank-english-v3.0"
}
}
}
}
}