mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 17:31:43 +00:00
Integrate Cohere re-ranking API (#63877)
Integrates Cohere re-ranking [API](https://cohere.com/rerank) for server-side Cody Context ([RFC 969](https://linear.app/sourcegraph/project/v1-of-two-stage-intent-detection-context-retrieval-system-c4f7093e9eab/overview)). Before this PR, we only supported `identity` ranker (which returned all items in the input order), which is still the default choice (when Cohere API key is not provided). Closes https://linear.app/sourcegraph/issue/AI-134/add-non-poc-ranking ## Test plan - tested locally, use ``` "cody.serverSideContext": { "reranker": { "type": "cohere", "apiKey": "TOKEN" } } ``` to test locally
This commit is contained in:
parent
658d12ea35
commit
25929d1be9
@ -1,13 +1,23 @@
|
||||
import * as React from 'react'
|
||||
|
||||
import { mdiArrowLeftBoldBoxOutline } from '@mdi/js'
|
||||
|
||||
import { useLocation } from 'react-router-dom'
|
||||
|
||||
import { asError, type ErrorLike, isErrorLike, logger } from '@sourcegraph/common'
|
||||
import type { TelemetryV2Props } from '@sourcegraph/shared/src/telemetry'
|
||||
import { EVENT_LOGGER } from '@sourcegraph/shared/src/telemetry/web/eventLogger'
|
||||
import { Button, Link, LoadingSpinner, Alert, Text, Input, ErrorAlert, Form, Container, Icon } from '@sourcegraph/wildcard'
|
||||
import {
|
||||
Button,
|
||||
Link,
|
||||
LoadingSpinner,
|
||||
Alert,
|
||||
Text,
|
||||
Input,
|
||||
ErrorAlert,
|
||||
Form,
|
||||
Container,
|
||||
Icon,
|
||||
} from '@sourcegraph/wildcard'
|
||||
|
||||
import type { AuthenticatedUser } from '../auth'
|
||||
import { LoaderButton } from '../components/LoaderButton'
|
||||
@ -159,7 +169,8 @@ class ResetPasswordCodeForm extends React.PureComponent<ResetPasswordCodeFormPro
|
||||
if (this.state.submitOrError === null) {
|
||||
return (
|
||||
<Alert variant="success">
|
||||
Your password was reset. <Link to={`/sign-in?email=${email}`}>Sign in with your new password</Link> to continue.
|
||||
Your password was reset. <Link to={`/sign-in?email=${email}`}>Sign in with your new password</Link>{' '}
|
||||
to continue.
|
||||
</Alert>
|
||||
)
|
||||
}
|
||||
@ -168,7 +179,10 @@ class ResetPasswordCodeForm extends React.PureComponent<ResetPasswordCodeFormPro
|
||||
<>
|
||||
{isErrorLike(this.state.submitOrError) && <ErrorAlert error={this.state.submitOrError} />}
|
||||
<Container className="w-100">
|
||||
<Link to='/password-reset'><Icon className="mr-1" aria-hidden={true} svgPath={mdiArrowLeftBoldBoxOutline} />Raise request for a different account</Link>
|
||||
<Link to="/password-reset">
|
||||
<Icon className="mr-1" aria-hidden={true} svgPath={mdiArrowLeftBoldBoxOutline} />
|
||||
Raise request for a different account
|
||||
</Link>
|
||||
<Text className="mt-1 text-center text-muted font-weight-bold mb-3">{email}</Text>
|
||||
<Form data-testid="reset-password-page-form" onSubmit={this.handleSubmitResetPassword}>
|
||||
<PasswordInput
|
||||
|
||||
@ -21,6 +21,8 @@ go_library(
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
"//schema",
|
||||
"@com_github_cohere_ai_cohere_go_v2//:cohere-go",
|
||||
"@com_github_cohere_ai_cohere_go_v2//client",
|
||||
"@com_github_sourcegraph_conc//iter",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
],
|
||||
|
||||
@ -24,6 +24,9 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/internal/types"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
|
||||
cohere "github.com/cohere-ai/cohere-go/v2"
|
||||
"github.com/cohere-ai/cohere-go/v2/client"
|
||||
)
|
||||
|
||||
func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient *codycontext.CodyContextClient, logger log.Logger) graphqlbackend.CodyContextResolver {
|
||||
@ -34,6 +37,8 @@ func NewResolver(db database.DB, gitserverClient gitserver.Client, contextClient
|
||||
logger: logger,
|
||||
intentApiHttpClient: httpcli.UncachedExternalDoer,
|
||||
intentBackendConfig: conf.CodyIntentConfig(),
|
||||
reranker: conf.CodyReranker(),
|
||||
cohereConfig: conf.CodyRerankerCohereConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,6 +49,8 @@ type Resolver struct {
|
||||
logger log.Logger
|
||||
intentApiHttpClient httpcli.Doer
|
||||
intentBackendConfig *schema.IntentDetectionAPI
|
||||
reranker conf.CodyRerankerBackend
|
||||
cohereConfig *schema.CodyRerankerCohere
|
||||
}
|
||||
|
||||
func (r *Resolver) RecordContext(ctx context.Context, args graphqlbackend.RecordContextArgs) (*graphqlbackend.EmptyResponse, error) {
|
||||
@ -74,14 +81,15 @@ func (r *Resolver) RankContext(ctx context.Context, args graphqlbackend.RankCont
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ranker, used, err := r.rerank(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := rankContextResponse{
|
||||
ranker: "identity",
|
||||
ranker: string(ranker),
|
||||
used: used,
|
||||
}
|
||||
r.logger.Info("ranking context", log.String("interactionID", args.InteractionID), log.String("ranker", res.ranker), log.Int("contextItemCount", len(args.ContextItems)))
|
||||
|
||||
for i := range args.ContextItems {
|
||||
res.used = append(res.used, int32(i))
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
@ -260,6 +268,35 @@ func (r *Resolver) fileChunkToResolver(ctx context.Context, chunk *codycontext.F
|
||||
return graphqlbackend.NewFileChunkContextResolver(gitTreeEntryResolver, chunk.StartLine, endLine), nil
|
||||
}
|
||||
|
||||
func (r *Resolver) rerank(ctx context.Context, args graphqlbackend.RankContextArgs) (conf.CodyRerankerBackend, []int32, error) {
|
||||
if r.reranker == conf.CodyRerankerIdentity {
|
||||
var used []int32
|
||||
for i := range args.ContextItems {
|
||||
used = append(used, int32(i))
|
||||
}
|
||||
return conf.CodyRerankerIdentity, used, nil
|
||||
}
|
||||
co := client.NewClient(client.WithToken(r.cohereConfig.ApiKey))
|
||||
|
||||
req := &cohere.RerankRequest{
|
||||
Query: args.Query,
|
||||
Model: cohere.String(r.cohereConfig.Model),
|
||||
}
|
||||
for _, ci := range args.ContextItems {
|
||||
req.Documents = append(req.Documents, &cohere.RerankRequestDocumentsItem{String: ci.Content})
|
||||
}
|
||||
resp, err := co.Rerank(ctx, req)
|
||||
if err != nil {
|
||||
r.logger.Error("cohere reranking error", log.String("interactionId", args.InteractionID), log.String("query", args.Query), log.Error(err))
|
||||
return conf.CodyRerankerCohere, nil, err
|
||||
}
|
||||
var used []int32
|
||||
for _, r := range resp.Results {
|
||||
used = append(used, int32(r.Index))
|
||||
}
|
||||
return conf.CodyRerankerCohere, used, nil
|
||||
}
|
||||
|
||||
// countLines finds the number of lines corresponding to the number of runes. We 'round down'
|
||||
// to ensure that we don't return more characters than our budget.
|
||||
func countLines(content string, numRunes int) int {
|
||||
|
||||
7
deps.bzl
7
deps.bzl
@ -1181,6 +1181,13 @@ def go_dependencies():
|
||||
sum = "h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q=",
|
||||
version = "v0.0.0-20150114235600-33e0aa1cb7c0",
|
||||
)
|
||||
go_repository(
|
||||
name = "com_github_cohere_ai_cohere_go_v2",
|
||||
build_file_proto_mode = "disable_global",
|
||||
importpath = "github.com/cohere-ai/cohere-go/v2",
|
||||
sum = "h1:NtxtcqkJ3ZBj8DFgk/4hpOrGK7CGnllGNpQn1bkaqQs=",
|
||||
version = "v2.8.2",
|
||||
)
|
||||
go_repository(
|
||||
name = "com_github_common_nighthawk_go_figure",
|
||||
build_file_proto_mode = "disable_global",
|
||||
|
||||
1
go.mod
1
go.mod
@ -269,6 +269,7 @@ require (
|
||||
github.com/bevzzz/nb v0.3.0
|
||||
github.com/bevzzz/nb-synth v0.0.0-20240128164931-35fdda0583a0
|
||||
github.com/bevzzz/nb/extension/extra/goldmark-jupyter v0.0.0-20240131001330-e69229bd9da4
|
||||
github.com/cohere-ai/cohere-go/v2 v2.8.2
|
||||
github.com/derision-test/go-mockgen/v2 v2.0.1
|
||||
github.com/dghubble/gologin/v2 v2.4.0
|
||||
github.com/edsrzf/mmap-go v1.1.0
|
||||
|
||||
2
go.sum
2
go.sum
@ -1000,6 +1000,8 @@ github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZe
|
||||
github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs=
|
||||
github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30=
|
||||
github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg=
|
||||
github.com/cohere-ai/cohere-go/v2 v2.8.2 h1:NtxtcqkJ3ZBj8DFgk/4hpOrGK7CGnllGNpQn1bkaqQs=
|
||||
github.com/cohere-ai/cohere-go/v2 v2.8.2/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
|
||||
github.com/containerd/containerd v1.7.12 h1:+KQsnv4VnzyxWcfO9mlxxELaoztsDEjOuCMPAuPqgU0=
|
||||
github.com/containerd/containerd v1.7.12/go.mod h1:/5OMpE1p0ylxtEUGY8kuCYkDRzJm9NO1TFMWjUpdevk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
|
||||
@ -293,6 +293,30 @@ func CodyIntentConfig() *schema.IntentDetectionAPI {
|
||||
return Get().ExperimentalFeatures.CodyServerSideContext.IntentDetectionAPI
|
||||
}
|
||||
|
||||
type CodyRerankerBackend string
|
||||
|
||||
const (
|
||||
CodyRerankerIdentity CodyRerankerBackend = "identity"
|
||||
CodyRerankerCohere CodyRerankerBackend = "cohere"
|
||||
)
|
||||
|
||||
func CodyReranker() CodyRerankerBackend {
|
||||
if Get().ExperimentalFeatures == nil || Get().ExperimentalFeatures.CodyServerSideContext == nil || Get().ExperimentalFeatures.CodyServerSideContext.Reranker == nil {
|
||||
return CodyRerankerIdentity
|
||||
}
|
||||
if Get().ExperimentalFeatures.CodyServerSideContext.Reranker.Identity != nil {
|
||||
return CodyRerankerIdentity
|
||||
}
|
||||
return CodyRerankerCohere
|
||||
}
|
||||
|
||||
func CodyRerankerCohereConfig() *schema.CodyRerankerCohere {
|
||||
if CodyReranker() != CodyRerankerCohere {
|
||||
return nil
|
||||
}
|
||||
return Get().ExperimentalFeatures.CodyServerSideContext.Reranker.Cohere
|
||||
}
|
||||
|
||||
func ExecutorsEnabled() bool {
|
||||
return Get().ExecutorsAccessToken != ""
|
||||
}
|
||||
|
||||
@ -676,10 +676,24 @@ type CodyProConfig struct {
|
||||
UseEmbeddedUI bool `json:"useEmbeddedUI,omitempty"`
|
||||
}
|
||||
|
||||
// CodyRerankerCohere description: Re-ranker using Cohere API
|
||||
type CodyRerankerCohere struct {
|
||||
ApiKey string `json:"apiKey"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// CodyRerankerIdentity description: Identity re-ranker
|
||||
type CodyRerankerIdentity struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// CodyServerSideContext description: Configuration for Server-side context API
|
||||
type CodyServerSideContext struct {
|
||||
// IntentDetectionAPI description: Configuration for intent detection API
|
||||
IntentDetectionAPI *IntentDetectionAPI `json:"intentDetectionAPI,omitempty"`
|
||||
// Reranker description: Reranker to use for rankContext requests
|
||||
Reranker *Reranker `json:"reranker,omitempty"`
|
||||
}
|
||||
|
||||
// CommitGraphUpdates description: Customize strategy used for commit graph updates
|
||||
@ -2303,6 +2317,38 @@ type RequestMessage struct {
|
||||
Params any `json:"params,omitempty"`
|
||||
Settings map[string]any `json:"settings,omitempty"`
|
||||
}
|
||||
|
||||
// Reranker description: Reranker to use for rankContext requests
|
||||
type Reranker struct {
|
||||
Identity *CodyRerankerIdentity
|
||||
Cohere *CodyRerankerCohere
|
||||
}
|
||||
|
||||
func (v Reranker) MarshalJSON() ([]byte, error) {
|
||||
if v.Identity != nil {
|
||||
return json.Marshal(v.Identity)
|
||||
}
|
||||
if v.Cohere != nil {
|
||||
return json.Marshal(v.Cohere)
|
||||
}
|
||||
return nil, errors.New("tagged union type must have exactly 1 non-nil field value")
|
||||
}
|
||||
func (v *Reranker) UnmarshalJSON(data []byte) error {
|
||||
var d struct {
|
||||
DiscriminantProperty string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
switch d.DiscriminantProperty {
|
||||
case "cohere":
|
||||
return json.Unmarshal(data, &v.Cohere)
|
||||
case "identity":
|
||||
return json.Unmarshal(data, &v.Identity)
|
||||
}
|
||||
return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"identity", "cohere"})
|
||||
}
|
||||
|
||||
type Responders struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
@ -656,6 +656,27 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"reranker": {
|
||||
"description": "Reranker to use for rankContext requests",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["identity", "cohere"]
|
||||
}
|
||||
},
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/definitions/CodyRerankerIdentity"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/CodyRerankerCohere"
|
||||
}
|
||||
],
|
||||
"!go": {
|
||||
"taggedUnionType": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4570,6 +4591,35 @@
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"CodyRerankerIdentity": {
|
||||
"description": "Identity re-ranker",
|
||||
"type": "object",
|
||||
"required": ["type"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "identity"
|
||||
}
|
||||
}
|
||||
},
|
||||
"CodyRerankerCohere": {
|
||||
"description": "Re-ranker using Cohere API",
|
||||
"type": "object",
|
||||
"required": ["type", "apiKey"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "cohere"
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"default": "rerank-english-v3.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user