Unify API for code completions and chat (#51962)

This PR unifies the request and response shapes of the code-completions and chat
endpoints, making them effectively the same API. The only remaining differences are
that a different model is chosen for code completions and that code completions are
non-streaming. Long term this reduces divergence between the two paths, and it makes
additional upstream providers easier to implement, since in most cases only a single
method is needed.
Erik Seliger, 2023-05-17 00:11:14 +02:00 (committed by GitHub)
parent c2045704d8
commit fe8e70d94d
22 changed files with 201 additions and 219 deletions
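For orientation before the per-file diffs: below is a minimal sketch of the shape this PR converges on. One parameter type and one response type are shared by both endpoints, so an upstream provider only has to implement a single request path. The type names mirror the new `types.go` at the end of this diff; `fakeProvider` is a hypothetical stand-in, not code from this PR.

```go
package main

import (
	"context"
	"fmt"
)

// Unified shapes, mirroring the new types.CompletionRequestParameters and
// types.CompletionResponse introduced by this PR.
type Message struct {
	Speaker string `json:"speaker"` // "human" or "assistant"
	Text    string `json:"text"`
}

type CompletionRequestParameters struct {
	Messages          []Message `json:"messages"`
	MaxTokensToSample int       `json:"maxTokensToSample,omitempty"`
	StopSequences     []string  `json:"stopSequences,omitempty"`
	Model             string    `json:"model,omitempty"`
}

type CompletionResponse struct {
	Completion string `json:"completion"`
	StopReason string `json:"stopReason"`
}

type SendCompletionEvent func(event CompletionResponse) error

// One interface serves both endpoints: chat uses Stream, code completions use
// the non-streaming Complete, but the parameters and response are identical.
type CompletionsClient interface {
	Stream(ctx context.Context, params CompletionRequestParameters, sendEvent SendCompletionEvent) error
	Complete(ctx context.Context, params CompletionRequestParameters) (*CompletionResponse, error)
}

// fakeProvider is a hypothetical upstream provider. Because the request and
// response types are shared, one internal request path can back both methods.
type fakeProvider struct{}

func (fakeProvider) Complete(ctx context.Context, params CompletionRequestParameters) (*CompletionResponse, error) {
	return &CompletionResponse{Completion: "done", StopReason: "stop_sequence"}, nil
}

func (fakeProvider) Stream(ctx context.Context, params CompletionRequestParameters, sendEvent SendCompletionEvent) error {
	return sendEvent(CompletionResponse{Completion: "done", StopReason: "stop_sequence"})
}

func main() {
	var c CompletionsClient = fakeProvider{}
	resp, _ := c.Complete(context.Background(), CompletionRequestParameters{
		Messages: []Message{{Speaker: "human", Text: "Write some code"}},
	})
	fmt.Println(resp.Completion)
}
```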

View File

@@ -1,10 +1,10 @@
import { fetchEventSource } from '@microsoft/fetch-event-source'
import { SourcegraphCompletionsClient } from './client'
import type { Event, CompletionParameters, CompletionCallbacks, CodeCompletionResponse } from './types'
import type { Event, CompletionParameters, CompletionCallbacks, CompletionResponse } from './types'
export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsClient {
public complete(): Promise<CodeCompletionResponse> {
public complete(): Promise<CompletionResponse> {
throw new Error('SourcegraphBrowserCompletionsClient.complete not implemented')
}

View File

@@ -1,19 +1,13 @@
import { ConfigurationWithAccessToken } from '../../configuration'
import {
Event,
CompletionParameters,
CompletionCallbacks,
CodeCompletionParameters,
CodeCompletionResponse,
} from './types'
import { Event, CompletionCallbacks, CompletionParameters, CompletionResponse } from './types'
export interface CompletionLogger {
startCompletion(params: CodeCompletionParameters | CompletionParameters):
startCompletion(params: CompletionParameters):
| undefined
| {
onError: (error: string) => void
onComplete: (response: string | CodeCompletionResponse) => void
onComplete: (response: string | CompletionResponse) => void
onEvents: (events: Event[]) => void
}
}
@@ -52,8 +46,5 @@ export abstract class SourcegraphCompletionsClient {
}
public abstract stream(params: CompletionParameters, cb: CompletionCallbacks): () => void
public abstract complete(
params: CodeCompletionParameters,
abortSignal: AbortSignal
): Promise<CodeCompletionResponse>
public abstract complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse>
}

View File

@@ -6,10 +6,10 @@ import { toPartialUtf8String } from '../utils'
import { SourcegraphCompletionsClient } from './client'
import { parseEvents } from './parse'
import { CompletionParameters, CompletionCallbacks, CodeCompletionParameters, CodeCompletionResponse } from './types'
import { CompletionParameters, CompletionCallbacks, CompletionResponse } from './types'
export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClient {
public async complete(params: CodeCompletionParameters, abortSignal: AbortSignal): Promise<CodeCompletionResponse> {
public async complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse> {
const log = this.logger?.startCompletion(params)
const requestFn = this.codeCompletionsEndpoint.startsWith('https://') ? https.request : http.request
@@ -18,7 +18,7 @@ export class SourcegraphNodeCompletionsClie
if (this.config.accessToken) {
headersInstance.set('Authorization', `token ${this.config.accessToken}`)
}
const completion = await new Promise<CodeCompletionResponse>((resolve, reject) => {
const completion = await new Promise<CompletionResponse>((resolve, reject) => {
const req = requestFn(
this.codeCompletionsEndpoint,
{
@@ -44,7 +44,7 @@ export class SourcegraphNodeCompletionsClie
reject(new Error(buffer))
}
const resp = JSON.parse(buffer) as CodeCompletionResponse
const resp = JSON.parse(buffer) as CompletionResponse
if (typeof resp.completion !== 'string' || typeof resp.stopReason !== 'string') {
const message = `response does not satisfy CompletionResponse: ${buffer}`
log?.onError(message)

View File

@@ -37,14 +37,14 @@ function parseEventData(eventType: Event['type'], dataLine: string): Event | Err
const jsonData = dataLine.slice(DATA_LINE_PREFIX.length)
switch (eventType) {
case 'completion': {
const data = parseJSON<{ completion: string }>(jsonData)
const data = parseJSON<{ completion: string; stopReason: string }>(jsonData)
if (isError(data)) {
return data
}
if (data.completion === undefined) {
return new Error('invalid completion event')
}
return { type: eventType, completion: data.completion }
return { type: eventType, completion: data.completion, stopReason: data.stopReason }
}
case 'error': {
const data = parseJSON<{ error: string }>(jsonData)

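The wire-level change behind this hunk is small but load-bearing: the `completion` server-sent event now carries a `stopReason` next to the completion text, so chat events and code-completion responses share one shape. A hedged Go illustration of decoding such an event payload (field names are taken from the JSON tags in `types.go` below; the payload values are made up):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Payload of a "completion" event after this PR: stopReason is now included,
// so clients no longer need a separate code-completion response type.
type CompletionResponse struct {
	Completion string `json:"completion"`
	StopReason string `json:"stopReason"`
}

func main() {
	data := []byte(`{"completion":"Sure!","stopReason":"stop_sequence"}`)

	var ev CompletionResponse
	if err := json.Unmarshal(data, &ev); err != nil {
		panic(err)
	}
	fmt.Printf("%s (%s)\n", ev.Completion, ev.StopReason)
}
```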
View File

@@ -2,9 +2,8 @@ export interface DoneEvent {
type: 'done'
}
export interface CompletionEvent {
export interface CompletionEvent extends CompletionResponse {
type: 'completion'
completion: string
}
export interface ErrorEvent {
@@ -19,31 +18,19 @@ export interface Message {
text?: string
}
export interface CodeCompletionResponse {
export interface CompletionResponse {
completion: string
stop: string | null
stopReason: string
truncated: boolean
exception: string | null
logID: string
}
export interface CodeCompletionParameters {
prompt: string
temperature: number
maxTokensToSample: number
stopSequences: string[]
topK: number
topP: number
model?: string
}
export interface CompletionParameters {
messages: Message[]
temperature: number
maxTokensToSample: number
topK: number
topP: number
temperature?: number
stopSequences?: string[]
topK?: number
topP?: number
model?: string
}
export interface CompletionCallbacks {

View File

@@ -3,14 +3,14 @@ import { CompletionsCache } from './cache'
describe('CompletionsCache', () => {
it('returns the cached completion items', () => {
const cache = new CompletionsCache()
cache.add([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
cache.add([{ prefix: 'foo\n', content: 'bar', messages: [] }])
expect(cache.get('foo\n')).toEqual([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
})
it('returns the cached items when the prefix includes characters from the completion', () => {
const cache = new CompletionsCache()
cache.add([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
cache.add([{ prefix: 'foo\n', content: 'bar', messages: [] }])
expect(cache.get('foo\nb')).toEqual([{ prefix: 'foo\nb', content: 'ar', prompt: '' }])
expect(cache.get('foo\nba')).toEqual([{ prefix: 'foo\nba', content: 'r', prompt: '' }])
@@ -18,7 +18,7 @@ describe('CompletionsCache', () => {
it('returns the cached items when the prefix has less whitespace', () => {
const cache = new CompletionsCache()
cache.add([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
cache.add([{ prefix: 'foo \n ', content: 'bar', messages: [] }])
expect(cache.get('foo \n ')).toEqual([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
expect(cache.get('foo \n ')).toEqual([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])

View File

@@ -1,5 +1,6 @@
import * as vscode from 'vscode'
import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
import { logEvent } from '../event-logger'
@@ -372,7 +373,7 @@ function getCurrentDocContext(
export interface Completion {
prefix: string
prompt: string
messages: Message[]
content: string
stopReason?: string
}

View File

@@ -1,18 +1,15 @@
import * as anthropic from '@anthropic-ai/sdk'
import { ReferenceSnippet } from './context'
import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
export interface Message {
role: 'human' | 'ai'
text: string | null
}
import { ReferenceSnippet } from './context'
export function messagesToText(messages: Message[]): string {
return messages
.map(
message =>
`${message.role === 'human' ? anthropic.HUMAN_PROMPT : anthropic.AI_PROMPT}${
message.text === null ? '' : ' ' + message.text
`${message.speaker === 'human' ? anthropic.HUMAN_PROMPT : anthropic.AI_PROMPT}${
message.text === undefined ? '' : ' ' + message.text
}`
)
.join('')
@@ -39,7 +36,7 @@ export class SingleLinePromptTemplate implements PromptTemplate {
const lastHumanLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
prefixMessages = [
{
role: 'human',
speaker: 'human',
text:
'Complete the following file:\n' +
'```' +
@@ -47,7 +44,7 @@ export class SingleLinePromptTemplate implements PromptTemplate {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text:
'Here is the completion of the file:\n' +
'```' +
@@ -57,11 +54,11 @@ export class SingleLinePromptTemplate implements PromptTemplate {
} else {
prefixMessages = [
{
role: 'human',
speaker: 'human',
text: 'Write some code',
},
{
role: 'ai',
speaker: 'assistant',
text: `Here is some code:\n\`\`\`\n${prefix}`,
},
]
@@ -102,7 +99,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
const lastHumanLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
prefixMessages = [
{
role: 'human',
speaker: 'human',
text:
'Complete the following file:\n' +
'```' +
@@ -110,7 +107,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text:
'Here is the completion of the file:\n' +
'```' +
@@ -120,11 +117,11 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
} else {
prefixMessages = [
{
role: 'human',
speaker: 'human',
text: 'Write some code',
},
{
role: 'ai',
speaker: 'assistant',
text: `Here is some code:\n\`\`\`\n${prefix}`,
},
]
@@ -135,7 +132,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
for (const snippet of snippets) {
const snippetMessages: Message[] = [
{
role: 'human',
speaker: 'human',
text:
`Add the following code snippet (from file ${snippet.filename}) to your knowledge base:\n` +
'```' +
@@ -143,7 +140,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text: 'Okay, I have added it to my knowledge base.',
},
]

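With the extension now on the shared `Message` type (`speaker: 'human' | 'assistant'`) instead of its ad-hoc `role: 'human' | 'ai'`, every Anthropic-bound caller needs the same speaker-to-prompt flattening that `messagesToText` performs above. A minimal Go sketch of the equivalent conversion, assuming Anthropic's `\n\nHuman:` / `\n\nAssistant:` prefixes:

```go
package main

import (
	"fmt"
	"strings"
)

const (
	humanPrompt     = "\n\nHuman:"
	assistantPrompt = "\n\nAssistant:"
)

type Message struct {
	Speaker string // "human" or "assistant"
	Text    string
}

// messagesToText flattens chat-style messages into Anthropic's single-string
// prompt format, mirroring the TypeScript messagesToText above: prefix per
// speaker, then the text (skipped when empty so a trailing assistant turn
// leaves the model to continue).
func messagesToText(messages []Message) string {
	var sb strings.Builder
	for _, m := range messages {
		prefix := humanPrompt
		if m.Speaker == "assistant" {
			prefix = assistantPrompt
		}
		sb.WriteString(prefix)
		if m.Text != "" {
			sb.WriteString(" " + m.Text)
		}
	}
	return sb.String()
}

func main() {
	fmt.Printf("%q\n", messagesToText([]Message{
		{Speaker: "human", Text: "Write some code"},
		{Speaker: "assistant", Text: "Here is some code:"},
	}))
}
```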
View File

@@ -2,13 +2,14 @@ import * as anthropic from '@anthropic-ai/sdk'
import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
import {
CodeCompletionParameters,
CodeCompletionResponse,
CompletionParameters,
CompletionResponse,
Message,
} from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'
import { Completion } from '.'
import { ReferenceSnippet } from './context'
import { Message, messagesToText } from './prompts'
import { messagesToText } from './prompts'
export abstract class CompletionProvider {
constructor(
@@ -32,7 +33,7 @@ export abstract class CompletionProvider {
// Creates the resulting prompt and adds as many snippets from the reference
// list as possible.
protected createPrompt(): string {
protected createPrompt(): Message[] {
const prefixMessages = this.createPromptPrefix()
const referenceSnippetMessages: Message[] = []
@@ -50,7 +51,7 @@ export abstract class CompletionProvider {
if (suffix.length > 0) {
const suffixContext: Message[] = [
{
role: 'human',
speaker: 'human',
text:
'Add the following code snippet to your knowledge base:\n' +
'```' +
@@ -58,7 +59,7 @@ export abstract class CompletionProvider {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text: 'Okay, I have added it to my knowledge base.',
},
]
@@ -74,7 +75,7 @@ export abstract class CompletionProvider {
for (const snippet of this.snippets) {
const snippetMessages: Message[] = [
{
role: 'human',
speaker: 'human',
text:
`Add the following code snippet (from file ${snippet.filename}) to your knowledge base:\n` +
'```' +
@@ -82,7 +83,7 @@ export abstract class CompletionProvider {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text: 'Okay, I have added it to my knowledge base.',
},
]
@@ -94,7 +95,7 @@ export abstract class CompletionProvider {
remainingChars -= numSnippetChars
}
return messagesToText([...referenceSnippetMessages, ...prefixMessages])
return [...referenceSnippetMessages, ...prefixMessages]
}
public abstract generateCompletions(abortSignal: AbortSignal, n?: number): Promise<Completion[]>
@@ -115,7 +116,7 @@ export class MultilineCompletionProvider extends CompletionProvider {
const endLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
prefixMessages = [
{
role: 'human',
speaker: 'human',
text:
'Complete the following file:\n' +
'```' +
@@ -123,18 +124,18 @@ export class MultilineCompletionProvider extends CompletionProvider {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text: `Here is the completion of the file:\n\`\`\`\n${prefixLines.slice(endLine).join('\n')}`,
},
]
} else {
prefixMessages = [
{
role: 'human',
speaker: 'human',
text: 'Write some code',
},
{
role: 'ai',
speaker: 'assistant',
text: `Here is some code:\n\`\`\`\n${prefix}`,
},
]
@@ -164,7 +165,8 @@ export class MultilineCompletionProvider extends CompletionProvider {
// Create prompt
const prompt = this.createPrompt()
if (prompt.length > this.promptChars) {
const textPrompt = messagesToText(prompt)
if (textPrompt.length > this.promptChars) {
throw new Error('prompt length exceeded maximum allotted chars')
}
@@ -172,13 +174,8 @@ export class MultilineCompletionProvider extends CompletionProvider {
const responses = await batchCompletions(
this.completionsClient,
{
prompt,
stopSequences: [anthropic.HUMAN_PROMPT],
messages: prompt,
maxTokensToSample: this.responseTokens,
model: 'claude-instant-v1.0',
temperature: 1, // default value (source: https://console.anthropic.com/docs/api/reference)
topK: -1, // default value
topP: -1, // default value
},
n || this.defaultN,
abortSignal
@@ -186,7 +183,7 @@ export class MultilineCompletionProvider extends CompletionProvider {
// Post-process
return responses.map(resp => ({
prefix,
prompt,
messages: prompt,
content: this.postProcess(resp.completion),
stopReason: resp.stopReason,
}))
@@ -206,7 +203,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
const endLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
prefixMessages = [
{
role: 'human',
speaker: 'human',
text:
'Complete the following file:\n' +
'```' +
@@ -214,7 +211,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
'```',
},
{
role: 'ai',
speaker: 'assistant',
text:
'Here is the completion of the file:\n' +
'```' +
@@ -224,11 +221,11 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
} else {
prefixMessages = [
{
role: 'human',
speaker: 'human',
text: 'Write some code',
},
{
role: 'ai',
speaker: 'assistant',
text: `Here is some code:\n\`\`\`\n${this.prefix}${this.injectPrefix}`,
},
]
@@ -272,10 +269,9 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
const responses = await batchCompletions(
this.completionsClient,
{
prompt,
messages: prompt,
stopSequences: [anthropic.HUMAN_PROMPT, '\n'],
maxTokensToSample: this.responseTokens,
model: 'claude-instant-v1.0',
temperature: 1,
topK: -1,
topP: -1,
@@ -286,7 +282,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
// Post-process
return responses.map(resp => ({
prefix,
prompt,
messages: prompt,
content: this.postProcess(resp.completion),
stopReason: resp.stopReason,
}))
@@ -295,11 +291,11 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
async function batchCompletions(
client: SourcegraphNodeCompletionsClient,
params: CodeCompletionParameters,
params: CompletionParameters,
n: number,
abortSignal: AbortSignal
): Promise<CodeCompletionResponse[]> {
const responses: Promise<CodeCompletionResponse>[] = []
): Promise<CompletionResponse[]> {
const responses: Promise<CompletionResponse>[] = []
for (let i = 0; i < n; i++) {
responses.push(client.complete(params, abortSignal))
}

View File

@@ -2,9 +2,8 @@ import vscode from 'vscode'
import { CompletionLogger } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/client'
import {
CodeCompletionParameters,
CodeCompletionResponse,
CompletionParameters,
CompletionResponse,
Event,
} from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'
@@ -19,7 +18,7 @@ if (config.debug) {
}
export const logger: CompletionLogger = {
startCompletion(params: CodeCompletionParameters | CompletionParameters) {
startCompletion(params: CompletionParameters) {
if (!outputChannel) {
return undefined
}
@@ -47,7 +46,7 @@ export const logger: CompletionLogger = {
outputChannel!.appendLine('')
}
function onComplete(result: string | CodeCompletionResponse): void {
function onComplete(result: string | CompletionResponse): void {
if (hasFinished) {
return
}

View File

@@ -65,7 +65,7 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
}
var last string
if err := client.Stream(ctx, convertParams(args), func(event types.ChatCompletionEvent) error {
if err := client.Stream(ctx, convertParams(args), func(event types.CompletionResponse) error {
// Each completion event is only a partial of the final result; since we're in a sync
// request anyway, we just wait for the final completion event.
last = event.Completion
@@ -76,8 +76,8 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
return last, nil
}
func convertParams(args graphqlbackend.CompletionsArgs) types.ChatCompletionRequestParameters {
return types.ChatCompletionRequestParameters{
func convertParams(args graphqlbackend.CompletionsArgs) types.CompletionRequestParameters {
return types.CompletionRequestParameters{
Messages: convertMessages(args.Input.Messages),
Temperature: float32(args.Input.Temperature),
MaxTokensToSample: int(args.Input.MaxTokensToSample),

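The resolver above also shows why unification is cheap for synchronous callers: a blocking result can always be adapted from the streaming method by keeping only the last event, since every event carries the full partial completion. A generic sketch of that adapter; `completeViaStream` is a hypothetical helper, not code from this PR:

```go
package main

import (
	"context"
	"fmt"
)

type CompletionResponse struct {
	Completion string
	StopReason string
}

type streamFunc func(ctx context.Context, send func(CompletionResponse) error) error

// completeViaStream adapts a streaming completion into a synchronous one.
// Each event is a partial of the final result, so we simply remember the
// last one, exactly as the GraphQL resolver above does.
func completeViaStream(ctx context.Context, stream streamFunc) (string, error) {
	var last string
	if err := stream(ctx, func(ev CompletionResponse) error {
		last = ev.Completion
		return nil
	}); err != nil {
		return "", err
	}
	return last, nil
}

func main() {
	// A fake stream emitting growing partials, as the providers do.
	fake := func(ctx context.Context, send func(CompletionResponse) error) error {
		for _, s := range []string{"Sure", "Sure!", "Sure! Here you go."} {
			if err := send(CompletionResponse{Completion: s}); err != nil {
				return err
			}
		}
		return nil
	}

	out, _ := completeViaStream(context.Background(), fake)
	fmt.Println(out) // Sure! Here you go.
}
```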
View File

@@ -15,7 +15,7 @@ import (
const anthropicAPIURL = "https://api.anthropic.com/v1/complete"
func newAnthropicHandler(logger log.Logger, eventLogger events.Logger, accessToken string) http.Handler {
return makeUpstreamHandler[anthropicRequest](
return makeUpstreamHandler(
logger,
eventLogger,
anthropicAPIURL,

View File

@@ -15,7 +15,7 @@ import (
const openAIURL = "https://api.openai.com/v1/chat/completions"
func newOpenAIHandler(logger log.Logger, eventLogger events.Logger, accessToken string, orgID string) http.Handler {
return makeUpstreamHandler[openaiRequest](
return makeUpstreamHandler(
logger,
eventLogger,
openAIURL,

View File

@@ -12,17 +12,6 @@ import (
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type AnthropicCompletionsRequestParameters struct {
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
Model string `json:"model"`
Stream bool `json:"stream"`
}
const ProviderName = "anthropic"
func NewClient(cli httpcli.Doer, accessToken string, model string) types.CompletionsClient {
@@ -33,77 +22,41 @@ func NewClient(cli httpcli.Doer, accessToken string, model string) types.Complet
}
}
const apiURL = "https://api.anthropic.com/v1/complete"
const clientID = "sourcegraph/1.0"
type anthropicClient struct {
cli httpcli.Doer
accessToken string
model string
}
const apiURL = "https://api.anthropic.com/v1/complete"
const clientID = "sourcegraph/1.0"
var stopSequences = []string{HUMAN_PROMPT}
var allowedClientSpecifiedModels = map[string]struct{}{
"claude-instant-v1.0": {},
}
func (a *anthropicClient) Complete(
ctx context.Context,
requestParams types.CodeCompletionRequestParameters,
) (*types.CodeCompletionResponse, error) {
var model string
if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
model = requestParams.Model
} else {
model = a.model
}
payload := AnthropicCompletionsRequestParameters{
Stream: false,
StopSequences: requestParams.StopSequences,
Model: model,
Temperature: float32(requestParams.Temperature),
MaxTokensToSample: requestParams.MaxTokensToSample,
TopP: float32(requestParams.TopP),
TopK: requestParams.TopK,
Prompt: requestParams.Prompt,
}
resp, err := a.makeRequest(ctx, payload)
requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
resp, err := a.makeRequest(ctx, requestParams, false)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var response types.CodeCompletionResponse
var response anthropicCompletionResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return nil, err
}
return &response, nil
return &types.CompletionResponse{
Completion: response.Completion,
StopReason: response.StopReason,
}, nil
}
func (a *anthropicClient) Stream(
ctx context.Context,
requestParams types.ChatCompletionRequestParameters,
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
) error {
prompt, err := getPrompt(requestParams.Messages)
if err != nil {
return err
}
payload := AnthropicCompletionsRequestParameters{
Stream: true,
StopSequences: stopSequences,
Model: a.model,
Temperature: requestParams.Temperature,
MaxTokensToSample: requestParams.MaxTokensToSample,
TopP: requestParams.TopP,
TopK: requestParams.TopK,
Prompt: prompt,
}
resp, err := a.makeRequest(ctx, payload)
resp, err := a.makeRequest(ctx, requestParams, true)
if err != nil {
return err
}
@@ -122,12 +75,15 @@ func (a *anthropicClient) Stream(
continue
}
var event types.ChatCompletionEvent
var event anthropicCompletionResponse
if err := json.Unmarshal(data, &event); err != nil {
return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
}
err = sendEvent(event)
err = sendEvent(types.CompletionResponse{
Completion: event.Completion,
StopReason: event.StopReason,
})
if err != nil {
return err
}
@@ -136,7 +92,32 @@ func (a *anthropicClient) Stream(
return dec.Err()
}
func (a *anthropicClient) makeRequest(ctx context.Context, payload AnthropicCompletionsRequestParameters) (*http.Response, error) {
func (a *anthropicClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
prompt, err := getPrompt(requestParams.Messages)
if err != nil {
return nil, err
}
// Backcompat: Remove this code once enough clients are upgraded and we drop the
// Prompt field on requestParams.
if prompt == "" {
prompt = requestParams.Prompt
}
if len(requestParams.StopSequences) == 0 {
requestParams.StopSequences = []string{HUMAN_PROMPT}
}
payload := anthropicCompletionsRequestParameters{
Stream: stream,
StopSequences: requestParams.StopSequences,
Model: a.model,
Temperature: requestParams.Temperature,
MaxTokensToSample: requestParams.MaxTokensToSample,
TopP: requestParams.TopP,
TopK: requestParams.TopK,
Prompt: prompt,
}
reqBody, err := json.Marshal(payload)
if err != nil {
return nil, err
@@ -168,3 +149,19 @@ func (a *anthropicClient) makeRequest(ctx context.Context, payload AnthropicComp
return resp, nil
}
type anthropicCompletionsRequestParameters struct {
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
Model string `json:"model"`
Stream bool `json:"stream"`
}
type anthropicCompletionResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
}

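`getPrompt` is called from `makeRequest` above, but its body sits outside this diff. A plausible sketch, assuming it concatenates per-message prompts in the style of `Message.GetPrompt` (see `types.go` below) and appends a trailing assistant prefix when the conversation does not already end on an assistant turn; both assumptions are mine, not confirmed by the diff:

```go
package main

import (
	"fmt"
	"strings"
)

// Names follow the Anthropic provider package.
const (
	HUMAN_PROMPT = "\n\nHuman:"
	AI_PROMPT    = "\n\nAssistant:"
)

type Message struct {
	Speaker string
	Text    string
}

// getPrompt renders the unified Messages into Anthropic's single-string
// prompt format. The trailing AI_PROMPT (so the model answers as the
// assistant) is an assumption about the real implementation.
func getPrompt(messages []Message) (string, error) {
	parts := make([]string, 0, len(messages)+1)
	for _, m := range messages {
		switch m.Speaker {
		case "human":
			parts = append(parts, fmt.Sprintf("%s %s", HUMAN_PROMPT, m.Text))
		case "assistant":
			parts = append(parts, fmt.Sprintf("%s %s", AI_PROMPT, m.Text))
		default:
			return "", fmt.Errorf("invalid speaker %q", m.Speaker)
		}
	}
	if len(messages) == 0 || messages[len(messages)-1].Speaker != "assistant" {
		parts = append(parts, AI_PROMPT)
	}
	return strings.Join(parts, ""), nil
}

func main() {
	p, _ := getPrompt([]Message{{Speaker: "human", Text: "Write some code"}})
	fmt.Printf("%q\n", p)
}
```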
View File

@@ -49,8 +49,8 @@ func TestValidAnthropicStream(t *testing.T) {
}
mockClient := getMockClient(linesToResponse(mockAnthropicResponseLines))
events := []types.ChatCompletionEvent{}
err := mockClient.Stream(context.Background(), types.ChatCompletionRequestParameters{}, func(event types.ChatCompletionEvent) error {
events := []types.CompletionResponse{}
err := mockClient.Stream(context.Background(), types.CompletionRequestParameters{}, func(event types.CompletionResponse) error {
events = append(events, event)
return nil
})
@@ -64,7 +64,7 @@ func TestInvalidAnthropicStream(t *testing.T) {
var mockAnthropicInvalidResponseLines = []string{`{]`}
mockClient := getMockClient(linesToResponse(mockAnthropicInvalidResponseLines))
err := mockClient.Stream(context.Background(), types.ChatCompletionRequestParameters{}, func(event types.ChatCompletionEvent) error { return nil })
err := mockClient.Stream(context.Background(), types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
if err == nil {
t.Fatal("expected error, got nil")
}

View File

@@ -1,4 +1,4 @@
[]types.ChatCompletionEvent{
[]types.CompletionResponse{
{
Completion: "Sure!",
},

View File

@@ -36,14 +36,14 @@ func NewClient(cli httpcli.Doer, accessToken string, model string) types.Complet
func (a *dotcomClient) Complete(
ctx context.Context,
requestParams types.CodeCompletionRequestParameters,
) (*types.CodeCompletionResponse, error) {
requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
return nil, errors.New("not implemented")
}
func (a *dotcomClient) Stream(
ctx context.Context,
requestParams types.ChatCompletionRequestParameters,
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
) error {
reqBody, err := json.Marshal(requestParams)
@@ -75,7 +75,7 @@ func (a *dotcomClient) Stream(
return nil
}
var event types.ChatCompletionEvent
var event types.CompletionResponse
if err := json.Unmarshal(dec.Data(), &event); err != nil {
return errors.Errorf("failed to decode event payload: %w", err)
}

View File

@@ -32,13 +32,13 @@ type openAIChatCompletionStreamClient struct {
model string
}
func (a *openAIChatCompletionStreamClient) Complete(ctx context.Context, requestParams types.CodeCompletionRequestParameters) (*types.CodeCompletionResponse, error) {
func (a *openAIChatCompletionStreamClient) Complete(ctx context.Context, requestParams types.CompletionRequestParameters) (*types.CompletionResponse, error) {
return nil, errors.New("openAIChatCompletionStreamClient.Complete: unimplemented")
}
func (a *openAIChatCompletionStreamClient) Stream(
ctx context.Context,
requestParams types.ChatCompletionRequestParameters,
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
) error {
if requestParams.TopK < 0 {
@@ -128,9 +128,13 @@ func (a *openAIChatCompletionStreamClient) Stream(
return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
}
if len(event.Choices) > 0 && event.Choices[0].FinishReason == nil {
if len(event.Choices) > 0 {
content = append(content, event.Choices[0].Delta.Content)
err = sendEvent(types.ChatCompletionEvent{Completion: strings.Join(content, "")})
ev := types.CompletionResponse{Completion: strings.Join(content, "")}
if event.Choices[0].FinishReason != nil {
ev.StopReason = *event.Choices[0].FinishReason
}
err = sendEvent(ev)
if err != nil {
return err
}

View File

@@ -13,12 +13,24 @@ import (
"github.com/sourcegraph/sourcegraph/schema"
)
var allowedClientSpecifiedModels = map[string]struct{}{
// TODO(eseliger): This list should probably be configurable.
"claude-instant-v1.0": {},
}
// NewCodeCompletionsHandler is an http handler which sends back code completion results
func NewCodeCompletionsHandler(_ log.Logger, db database.DB) http.Handler {
rl := NewRateLimiter(db, redispool.Store, RateLimitScopeCodeCompletion)
return newCompletionsHandler(rl, "codeCompletions", func(c *schema.Completions) string {
return c.CompletionModel
}, func(ctx context.Context, requestParams types.CodeCompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
return newCompletionsHandler(rl, "codeCompletions", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
var model string
if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
model = requestParams.Model
} else {
model = c.CompletionModel
}
return model
}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
completion, err := cc.Complete(ctx, requestParams)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@@ -18,7 +18,7 @@ import (
// being cancelled.
const maxRequestDuration = time.Minute
func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel func(*schema.Completions) string, handle func(context.Context, T, types.CompletionsClient, http.ResponseWriter)) http.Handler {
func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel func(T, *schema.Completions) string, handle func(context.Context, T, types.CompletionsClient, http.ResponseWriter)) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, fmt.Sprintf("unsupported method %s", r.Method), http.StatusMethodNotAllowed)
@@ -45,7 +45,7 @@ func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel f
return
}
model := getModel(completionsConfig)
model := getModel(requestParams, completionsConfig)
var err error
ctx, done := Trace(ctx, traceFamily, model).

View File

@@ -17,9 +17,16 @@ import (
// NewCompletionsStreamHandler is an http handler which streams back completion results.
func NewCompletionsStreamHandler(logger log.Logger, db database.DB) http.Handler {
rl := NewRateLimiter(db, redispool.Store, RateLimitScopeCompletion)
return newCompletionsHandler(rl, "stream", func(c *schema.Completions) string {
return c.ChatModel
}, func(ctx context.Context, requestParams types.ChatCompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
return newCompletionsHandler(rl, "stream", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
var model string
if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
model = requestParams.Model
} else {
model = c.ChatModel
}
return model
}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
eventWriter, err := streamhttp.NewWriter(w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -31,7 +38,7 @@ func NewCompletionsStreamHandler(logger log.Logger, db database.DB) http.Handler
_ = eventWriter.Event("done", map[string]any{})
}()
err = cc.Stream(ctx, requestParams, func(event types.ChatCompletionEvent) error { return eventWriter.Event("completion", event) })
err = cc.Stream(ctx, requestParams, func(event types.CompletionResponse) error { return eventWriter.Event("completion", event) })
if err != nil {
trace.Logger(ctx, logger).Error("error while streaming completions", log.Error(err))
_ = eventWriter.Event("error", map[string]string{"error": err.Error()})

View File

@@ -15,36 +15,8 @@ type Message struct {
Text string `json:"text"`
}
type CodeCompletionRequestParameters struct {
Prompt string `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokensToSample int `json:"maxTokensToSample"`
StopSequences []string `json:"stopSequences"`
TopK int `json:"topK,omitempty"`
TopP float64 `json:"topP,omitempty"`
Model string `json:"model"`
Tags map[string]string `json:"tags,omitempty"`
}
type CodeCompletionResponse struct {
Completion string `json:"completion"`
Stop *string `json:"stop"`
StopReason string `json:"stopReason"`
Truncated bool `json:"truncated"`
Exception *string `json:"exception"`
LogID string `json:"logID"`
}
type ChatCompletionRequestParameters struct {
Messages []Message `json:"messages"`
Temperature float32 `json:"temperature"`
MaxTokensToSample int `json:"maxTokensToSample"`
TopK int `json:"topK"`
TopP float32 `json:"topP"`
}
type ChatCompletionEvent struct {
Completion string `json:"completion"`
func (m Message) IsValidSpeaker() bool {
return m.Speaker == HUMAN_MESSAGE_SPEAKER || m.Speaker == ASISSTANT_MESSAGE_SPEAKER
}
func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (string, error) {
@@ -65,9 +37,28 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str
return fmt.Sprintf("%s %s", prefix, m.Text), nil
}
type SendCompletionEvent func(event ChatCompletionEvent) error
type CompletionRequestParameters struct {
// Prompt exists only for backwards compatibility. Do not use it in new
// implementations. It will be removed once we are reasonably sure 99%
// of VSCode extension installations are upgraded to a new Cody version.
Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
MaxTokensToSample int `json:"maxTokensToSample,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
TopK int `json:"topK,omitempty"`
TopP float32 `json:"topP,omitempty"`
Model string `json:"model,omitempty"`
}
type CompletionResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stopReason"`
}
type SendCompletionEvent func(event CompletionResponse) error
type CompletionsClient interface {
Stream(ctx context.Context, requestParams ChatCompletionRequestParameters, sendEvent SendCompletionEvent) error
Complete(ctx context.Context, requestParams CodeCompletionRequestParameters) (*CodeCompletionResponse, error)
Stream(ctx context.Context, requestParams CompletionRequestParameters, sendEvent SendCompletionEvent) error
Complete(ctx context.Context, requestParams CompletionRequestParameters) (*CompletionResponse, error)
}
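
To close the loop on the types above: both endpoints now post a `CompletionRequestParameters`; the code-completion path merely adds stop sequences and, optionally, an allowlisted model, and the server resolves `ChatModel` or `CompletionModel` per endpoint. A hedged example (the message values echo the VS Code extension diff; the program itself is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Message struct {
	Speaker string `json:"speaker"`
	Text    string `json:"text"`
}

type CompletionRequestParameters struct {
	// Prompt exists only for backwards compatibility with old clients.
	Prompt            string    `json:"prompt"`
	Messages          []Message `json:"messages"`
	MaxTokensToSample int       `json:"maxTokensToSample,omitempty"`
	StopSequences     []string  `json:"stopSequences,omitempty"`
	Model             string    `json:"model,omitempty"`
}

func main() {
	msgs := []Message{
		{Speaker: "human", Text: "Write some code"},
		{Speaker: "assistant", Text: "Here is some code:"},
	}

	// Chat request: streamed by the server; the model comes from site config.
	chat := CompletionRequestParameters{Messages: msgs, MaxTokensToSample: 1000}

	// Code completion request: same shape, non-streaming endpoint; adds stop
	// sequences and a client-specified model from the allowlist.
	code := CompletionRequestParameters{
		Messages:          msgs,
		MaxTokensToSample: 200,
		StopSequences:     []string{"\n\nHuman:", "\n"},
		Model:             "claude-instant-v1.0",
	}

	for _, p := range []CompletionRequestParameters{chat, code} {
		b, _ := json.Marshal(p)
		fmt.Println(string(b))
	}
}
```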