diff --git a/client/cody-shared/src/sourcegraph-api/completions/browserClient.ts b/client/cody-shared/src/sourcegraph-api/completions/browserClient.ts
index deef1a0ce18..93dbfd1e819 100644
--- a/client/cody-shared/src/sourcegraph-api/completions/browserClient.ts
+++ b/client/cody-shared/src/sourcegraph-api/completions/browserClient.ts
@@ -1,10 +1,10 @@
 import { fetchEventSource } from '@microsoft/fetch-event-source'
 
 import { SourcegraphCompletionsClient } from './client'
-import type { Event, CompletionParameters, CompletionCallbacks, CodeCompletionResponse } from './types'
+import type { Event, CompletionParameters, CompletionCallbacks, CompletionResponse } from './types'
 
 export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsClient {
-    public complete(): Promise<CodeCompletionResponse> {
+    public complete(): Promise<CompletionResponse> {
         throw new Error('SourcegraphBrowserCompletionsClient.complete not implemented')
     }
 
diff --git a/client/cody-shared/src/sourcegraph-api/completions/client.ts b/client/cody-shared/src/sourcegraph-api/completions/client.ts
index f6ff925e7c9..ee9a0dc199a 100644
--- a/client/cody-shared/src/sourcegraph-api/completions/client.ts
+++ b/client/cody-shared/src/sourcegraph-api/completions/client.ts
@@ -1,19 +1,13 @@
 import { ConfigurationWithAccessToken } from '../../configuration'
 
-import {
-    Event,
-    CompletionParameters,
-    CompletionCallbacks,
-    CodeCompletionParameters,
-    CodeCompletionResponse,
-} from './types'
+import { Event, CompletionCallbacks, CompletionParameters, CompletionResponse } from './types'
 
 export interface CompletionLogger {
-    startCompletion(params: CodeCompletionParameters | CompletionParameters):
+    startCompletion(params: CompletionParameters):
         | undefined
         | {
               onError: (error: string) => void
-              onComplete: (response: string | CodeCompletionResponse) => void
+              onComplete: (response: string | CompletionResponse) => void
               onEvents: (events: Event[]) => void
           }
 }
@@ -52,8 +46,5 @@ export abstract class SourcegraphCompletionsClient {
     }
 
     public abstract stream(params: CompletionParameters, cb: CompletionCallbacks): () => void
-    public abstract complete(
-        params: CodeCompletionParameters,
-        abortSignal: AbortSignal
-    ): Promise<CodeCompletionResponse>
+    public abstract complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse>
 }
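
The chat and code-completion paths now share one request shape: `complete` and `stream` both take `CompletionParameters`. A minimal consumer sketch (the `declare` stands in for the constructor wiring, and the message texts are illustrative, not from this diff):

    import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
    import {
        CompletionParameters,
        CompletionResponse,
    } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'

    declare const client: SourcegraphNodeCompletionsClient

    async function completeOnce(): Promise<CompletionResponse> {
        const params: CompletionParameters = {
            messages: [
                { speaker: 'human', text: 'Write some code' },
                { speaker: 'assistant' }, // open turn for the model to fill
            ],
            maxTokensToSample: 200,
            // temperature, stopSequences, topK, topP, and model are optional;
            // the backend supplies defaults when they are omitted.
        }
        const abort = new AbortController()
        return client.complete(params, abort.signal)
    }
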
diff --git a/client/cody-shared/src/sourcegraph-api/completions/nodeClient.ts b/client/cody-shared/src/sourcegraph-api/completions/nodeClient.ts
index 9df0433adf2..45a022ab61f 100644
--- a/client/cody-shared/src/sourcegraph-api/completions/nodeClient.ts
+++ b/client/cody-shared/src/sourcegraph-api/completions/nodeClient.ts
@@ -6,10 +6,10 @@ import { toPartialUtf8String } from '../utils'
 
 import { SourcegraphCompletionsClient } from './client'
 import { parseEvents } from './parse'
-import { CompletionParameters, CompletionCallbacks, CodeCompletionParameters, CodeCompletionResponse } from './types'
+import { CompletionParameters, CompletionCallbacks, CompletionResponse } from './types'
 
 export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClient {
-    public async complete(params: CodeCompletionParameters, abortSignal: AbortSignal): Promise<CodeCompletionResponse> {
+    public async complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse> {
         const log = this.logger?.startCompletion(params)
 
         const requestFn = this.codeCompletionsEndpoint.startsWith('https://') ? https.request : http.request
@@ -18,7 +18,7 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie
         if (this.config.accessToken) {
             headersInstance.set('Authorization', `token ${this.config.accessToken}`)
         }
-        const completion = await new Promise<CodeCompletionResponse>((resolve, reject) => {
+        const completion = await new Promise<CompletionResponse>((resolve, reject) => {
             const req = requestFn(
                 this.codeCompletionsEndpoint,
                 {
@@ -44,7 +44,7 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie
                         reject(new Error(buffer))
                     }
 
-                    const resp = JSON.parse(buffer) as CodeCompletionResponse
+                    const resp = JSON.parse(buffer) as CompletionResponse
                     if (typeof resp.completion !== 'string' || typeof resp.stopReason !== 'string') {
                         const message = `response does not satisfy CodeCompletionResponse: ${buffer}`
                         log?.onError(message)
diff --git a/client/cody-shared/src/sourcegraph-api/completions/parse.ts b/client/cody-shared/src/sourcegraph-api/completions/parse.ts
index bc8d4c3ec06..f2db31ab330 100644
--- a/client/cody-shared/src/sourcegraph-api/completions/parse.ts
+++ b/client/cody-shared/src/sourcegraph-api/completions/parse.ts
@@ -37,14 +37,14 @@ function parseEventData(eventType: Event['type'], dataLine: string): Event | Err
     const jsonData = dataLine.slice(DATA_LINE_PREFIX.length)
     switch (eventType) {
         case 'completion': {
-            const data = parseJSON<{ completion: string }>(jsonData)
+            const data = parseJSON<{ completion: string; stopReason: string }>(jsonData)
             if (isError(data)) {
                 return data
             }
             if (typeof data.completion === undefined) {
                 return new Error('invalid completion event')
             }
-            return { type: eventType, completion: data.completion }
+            return { type: eventType, completion: data.completion, stopReason: data.stopReason }
         }
         case 'error': {
             const data = parseJSON<{ error: string }>(jsonData)
diff --git a/client/cody-shared/src/sourcegraph-api/completions/types.ts b/client/cody-shared/src/sourcegraph-api/completions/types.ts
index f0ff46da649..c992087d64e 100644
--- a/client/cody-shared/src/sourcegraph-api/completions/types.ts
+++ b/client/cody-shared/src/sourcegraph-api/completions/types.ts
@@ -2,9 +2,8 @@ export interface DoneEvent {
     type: 'done'
 }
 
-export interface CompletionEvent {
+export interface CompletionEvent extends CompletionResponse {
     type: 'completion'
-    completion: string
 }
 
 export interface ErrorEvent {
@@ -19,31 +18,19 @@ export interface Message {
     text?: string
 }
 
-export interface CodeCompletionResponse {
+export interface CompletionResponse {
     completion: string
-    stop: string | null
     stopReason: string
-    truncated: boolean
-    exception: string | null
-    logID: string
-}
-
-export interface CodeCompletionParameters {
-    prompt: string
-    temperature: number
-    maxTokensToSample: number
-    stopSequences: string[]
-    topK: number
-    topP: number
-    model?: string
 }
 
 export interface CompletionParameters {
     messages: Message[]
-    temperature: number
     maxTokensToSample: number
-    topK: number
-    topP: number
+    temperature?: number
+    stopSequences?: string[]
+    topK?: number
+    topP?: number
+    model?: string
 }
 
 export interface CompletionCallbacks {
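
The point of `CompletionEvent extends CompletionResponse`: a streamed completion event now carries the same fields as a one-shot response, so consumers can handle both with one code path. A self-contained restatement of the types above (not an import from the diff):

    interface CompletionResponse {
        completion: string
        stopReason: string
    }

    interface CompletionEvent extends CompletionResponse {
        type: 'completion'
    }

    // Accepts a one-shot response or any streamed completion event.
    function render(result: CompletionResponse): string {
        return `${result.completion} (stopped: ${result.stopReason})`
    }

    const event: CompletionEvent = { type: 'completion', completion: 'bar', stopReason: 'stop_sequence' }
    console.log(render(event)) // a CompletionEvent is assignable to CompletionResponse
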
diff --git a/client/cody/src/completions/cache.test.ts b/client/cody/src/completions/cache.test.ts
index d9c8a5d38c1..aaa79e3d0b2 100644
--- a/client/cody/src/completions/cache.test.ts
+++ b/client/cody/src/completions/cache.test.ts
@@ -3,14 +3,14 @@ import { CompletionsCache } from './cache'
 describe('CompletionsCache', () => {
     it('returns the cached completion items', () => {
         const cache = new CompletionsCache()
-        cache.add([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
+        cache.add([{ prefix: 'foo\n', content: 'bar', messages: [] }])
 
         expect(cache.get('foo\n')).toEqual([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
     })
 
     it('returns the cached items when the prefix includes characters from the completion', () => {
         const cache = new CompletionsCache()
-        cache.add([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
+        cache.add([{ prefix: 'foo\n', content: 'bar', messages: [] }])
 
         expect(cache.get('foo\nb')).toEqual([{ prefix: 'foo\nb', content: 'ar', prompt: '' }])
         expect(cache.get('foo\nba')).toEqual([{ prefix: 'foo\nba', content: 'r', prompt: '' }])
@@ -18,7 +18,7 @@ describe('CompletionsCache', () => {
 
     it('returns the cached items when the prefix has less whitespace', () => {
         const cache = new CompletionsCache()
-        cache.add([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
+        cache.add([{ prefix: 'foo \n ', content: 'bar', messages: [] }])
 
         expect(cache.get('foo \n ')).toEqual([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
         expect(cache.get('foo \n ')).toEqual([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
diff --git a/client/cody/src/completions/index.ts b/client/cody/src/completions/index.ts
index ecc8b5a09ca..f2d59e69c24 100644
--- a/client/cody/src/completions/index.ts
+++ b/client/cody/src/completions/index.ts
@@ -1,5 +1,6 @@
 import * as vscode from 'vscode'
 
+import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
 import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
 
 import { logEvent } from '../event-logger'
@@ -372,7 +373,7 @@ function getCurrentDocContext(
 
 export interface Completion {
     prefix: string
-    prompt: string
+    messages: Message[]
     content: string
     stopReason?: string
 }
diff --git a/client/cody/src/completions/prompts.ts b/client/cody/src/completions/prompts.ts
index d9400c087a9..ea44168cabe 100644
--- a/client/cody/src/completions/prompts.ts
+++ b/client/cody/src/completions/prompts.ts
@@ -1,18 +1,15 @@
 import * as anthropic from '@anthropic-ai/sdk'
 
-import { ReferenceSnippet } from './context'
+import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
 
-export interface Message {
-    role: 'human' | 'ai'
-    text: string | null
-}
+import { ReferenceSnippet } from './context'
 
 export function messagesToText(messages: Message[]): string {
     return messages
         .map(
             message =>
-                `${message.role === 'human' ? anthropic.HUMAN_PROMPT : anthropic.AI_PROMPT}${
-                    message.text === null ? '' : ' ' + message.text
+                `${message.speaker === 'human' ? anthropic.HUMAN_PROMPT : anthropic.AI_PROMPT}${
+                    message.text === undefined ? '' : ' ' + message.text
                 }`
         )
         .join('')
@@ -39,7 +36,7 @@ export class SingleLinePromptTemplate implements PromptTemplate {
         const lastHumanLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text:
                     'Complete the following file:\n' +
                     '```' +
@@ -47,7 +44,7 @@ export class SingleLinePromptTemplate implements PromptTemplate {
                     '```',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text:
                     'Here is the completion of the file:\n' +
                     '```' +
@@ -57,11 +54,11 @@
     } else {
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text: 'Write some code',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text: `Here is some code:\n\`\`\`\n${prefix}`,
             },
         ]
@@ -102,7 +99,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
         const lastHumanLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text:
                     'Complete the following file:\n' +
                     '```' +
@@ -110,7 +107,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
                     '```',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text:
                     'Here is the completion of the file:\n' +
                     '```' +
@@ -120,11 +117,11 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
     } else {
         prefixMessages = [
            {
-                role: 'human',
+                speaker: 'human',
                 text: 'Write some code',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text: `Here is some code:\n\`\`\`\n${prefix}`,
             },
         ]
@@ -135,7 +132,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
     for (const snippet of snippets) {
         const snippetMessages: Message[] = [
             {
-                role: 'human',
+                speaker: 'human',
                 text:
                     `Add the following code snippet (from file ${snippet.filename}) to your knowledge base:\n` +
                     '```' +
@@ -143,7 +140,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
                     '```',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text: 'Okay, I have added it to my knowledge base.',
             },
         ]
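
`messagesToText` is unchanged in spirit: it still flattens the transcript into an Anthropic-style prompt, only keyed off `speaker`/`undefined` instead of `role`/`null`. A rough standalone illustration (the two prompt constants are assumptions about the SDK's values at the time):

    const HUMAN_PROMPT = '\n\nHuman:' // assumed value of anthropic.HUMAN_PROMPT
    const AI_PROMPT = '\n\nAssistant:' // assumed value of anthropic.AI_PROMPT

    interface Message {
        speaker: 'human' | 'assistant'
        text?: string
    }

    function messagesToText(messages: Message[]): string {
        return messages
            .map(
                message =>
                    `${message.speaker === 'human' ? HUMAN_PROMPT : AI_PROMPT}${
                        message.text === undefined ? '' : ' ' + message.text
                    }`
            )
            .join('')
    }

    console.log(
        messagesToText([
            { speaker: 'human', text: 'Write some code' },
            { speaker: 'assistant' }, // textless message leaves the turn open
        ])
    )
    // => '\n\nHuman: Write some code\n\nAssistant:'
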
diff --git a/client/cody/src/completions/provider.ts b/client/cody/src/completions/provider.ts
index d6a64fcc63d..4800c7435a2 100644
--- a/client/cody/src/completions/provider.ts
+++ b/client/cody/src/completions/provider.ts
@@ -2,13 +2,14 @@ import * as anthropic from '@anthropic-ai/sdk'
 
 import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
 import {
-    CodeCompletionParameters,
-    CodeCompletionResponse,
+    CompletionParameters,
+    CompletionResponse,
+    Message,
 } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'
 
 import { Completion } from '.'
 import { ReferenceSnippet } from './context'
-import { Message, messagesToText } from './prompts'
+import { messagesToText } from './prompts'
 
 export abstract class CompletionProvider {
     constructor(
@@ -32,7 +33,7 @@ export abstract class CompletionProvider {
 
     // Creates the resulting prompt and adds as many snippets from the reference
    // list as possible.
-    protected createPrompt(): string {
+    protected createPrompt(): Message[] {
         const prefixMessages = this.createPromptPrefix()
         const referenceSnippetMessages: Message[] = []
 
@@ -50,7 +51,7 @@ export abstract class CompletionProvider {
         if (suffix.length > 0) {
             const suffixContext: Message[] = [
                 {
-                    role: 'human',
+                    speaker: 'human',
                     text:
                         'Add the following code snippet to your knowledge base:\n' +
                         '```' +
@@ -58,7 +59,7 @@ export abstract class CompletionProvider {
                         '```',
                 },
                 {
-                    role: 'ai',
+                    speaker: 'assistant',
                     text: 'Okay, I have added it to my knowledge base.',
                 },
             ]
@@ -74,7 +75,7 @@ export abstract class CompletionProvider {
         for (const snippet of this.snippets) {
             const snippetMessages: Message[] = [
                 {
-                    role: 'human',
+                    speaker: 'human',
                     text:
                         `Add the following code snippet (from file ${snippet.filename}) to your knowledge base:\n` +
                         '```' +
@@ -82,7 +83,7 @@ export abstract class CompletionProvider {
                         '```',
                 },
                 {
-                    role: 'ai',
+                    speaker: 'assistant',
                     text: 'Okay, I have added it to my knowledge base.',
                 },
             ]
@@ -94,7 +95,7 @@ export abstract class CompletionProvider {
             remainingChars -= numSnippetChars
         }
 
-        return messagesToText([...referenceSnippetMessages, ...prefixMessages])
+        return [...referenceSnippetMessages, ...prefixMessages]
     }
 
     public abstract generateCompletions(abortSignal: AbortSignal, n?: number): Promise<Completion[]>
@@ -115,7 +116,7 @@ export class MultilineCompletionProvider extends CompletionProvider {
         const endLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text:
                     'Complete the following file:\n' +
                     '```' +
@@ -123,18 +124,18 @@ export class MultilineCompletionProvider extends CompletionProvider {
                     '```',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text: `Here is the completion of the file:\n\`\`\`\n${prefixLines.slice(endLine).join('\n')}`,
             },
         ]
     } else {
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text: 'Write some code',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text: `Here is some code:\n\`\`\`\n${prefix}`,
             },
         ]
@@ -164,7 +165,8 @@ export class MultilineCompletionProvider extends CompletionProvider {
 
         // Create prompt
         const prompt = this.createPrompt()
-        if (prompt.length > this.promptChars) {
+        const textPrompt = messagesToText(prompt)
+        if (textPrompt.length > this.promptChars) {
             throw new Error('prompt length exceeded maximum alloted chars')
         }
 
@@ -172,13 +174,8 @@ export class MultilineCompletionProvider extends CompletionProvider {
         const responses = await batchCompletions(
             this.completionsClient,
             {
-                prompt,
-                stopSequences: [anthropic.HUMAN_PROMPT],
+                messages: prompt,
                 maxTokensToSample: this.responseTokens,
-                model: 'claude-instant-v1.0',
-                temperature: 1, // default value (source: https://console.anthropic.com/docs/api/reference)
-                topK: -1, // default value
-                topP: -1, // default value
             },
             n || this.defaultN,
             abortSignal
@@ -186,7 +183,7 @@ export class MultilineCompletionProvider extends CompletionProvider {
         // Post-process
         return responses.map(resp => ({
             prefix,
-            prompt,
+            messages: prompt,
             content: this.postProcess(resp.completion),
             stopReason: resp.stopReason,
         }))
@@ -206,7 +203,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
         const endLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text:
                     'Complete the following file:\n' +
                     '```' +
@@ -214,7 +211,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
                     '```',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text:
                     'Here is the completion of the file:\n' +
                     '```' +
@@ -224,11 +221,11 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
     } else {
         prefixMessages = [
             {
-                role: 'human',
+                speaker: 'human',
                 text: 'Write some code',
             },
             {
-                role: 'ai',
+                speaker: 'assistant',
                 text: `Here is some code:\n\`\`\`\n${this.prefix}${this.injectPrefix}`,
             },
         ]
@@ -272,10 +269,9 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
         const responses = await batchCompletions(
             this.completionsClient,
             {
-                prompt,
+                messages: prompt,
                 stopSequences: [anthropic.HUMAN_PROMPT, '\n'],
                 maxTokensToSample: this.responseTokens,
-                model: 'claude-instant-v1.0',
                 temperature: 1,
                 topK: -1,
                 topP: -1,
@@ -286,7 +282,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
 
         // Post-process
         return responses.map(resp => ({
             prefix,
-            prompt,
+            messages: prompt,
             content: this.postProcess(resp.completion),
             stopReason: resp.stopReason,
         }))
@@ -295,11 +291,11 @@
 
 async function batchCompletions(
     client: SourcegraphNodeCompletionsClient,
-    params: CodeCompletionParameters,
+    params: CompletionParameters,
     n: number,
     abortSignal: AbortSignal
-): Promise<CodeCompletionResponse[]> {
-    const responses: Promise<CodeCompletionResponse>[] = []
+): Promise<CompletionResponse[]> {
+    const responses: Promise<CompletionResponse>[] = []
     for (let i = 0; i < n; i++) {
         responses.push(client.complete(params, abortSignal))
     }
diff --git a/client/cody/src/log.ts b/client/cody/src/log.ts
index cbf2efc2a28..f793fc60c9d 100644
--- a/client/cody/src/log.ts
+++ b/client/cody/src/log.ts
@@ -2,9 +2,8 @@ import vscode from 'vscode'
 
 import { CompletionLogger } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/client'
 import {
-    CodeCompletionParameters,
-    CodeCompletionResponse,
     CompletionParameters,
+    CompletionResponse,
     Event,
 } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'
 
@@ -19,7 +18,7 @@ if (config.debug) {
 }
 
 export const logger: CompletionLogger = {
-    startCompletion(params: CodeCompletionParameters | CompletionParameters) {
+    startCompletion(params: CompletionParameters) {
         if (!outputChannel) {
             return undefined
         }
@@ -47,7 +46,7 @@ export const logger: CompletionLogger = {
         outputChannel!.appendLine('')
     }
 
-    function onComplete(result: string | CodeCompletionResponse): void {
+    function onComplete(result: string | CompletionResponse): void {
         if (hasFinished) {
             return
         }
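
`batchCompletions` now forwards `CompletionParameters` untouched; the fan-out itself is the same. The pattern in isolation (client interface reduced to the single method used here):

    interface CompletionParameters {
        messages: { speaker: 'human' | 'assistant'; text?: string }[]
        maxTokensToSample: number
        stopSequences?: string[]
    }

    interface CompletionResponse {
        completion: string
        stopReason: string
    }

    interface CompletionsClient {
        complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse>
    }

    // Issue n identical requests concurrently; one AbortSignal cancels them all.
    async function batchCompletions(
        client: CompletionsClient,
        params: CompletionParameters,
        n: number,
        abortSignal: AbortSignal
    ): Promise<CompletionResponse[]> {
        const responses: Promise<CompletionResponse>[] = []
        for (let i = 0; i < n; i++) {
            responses.push(client.complete(params, abortSignal))
        }
        return Promise.all(responses)
    }
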
diff --git a/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go b/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go
index a14824676b2..19ddf8d31b4 100644
--- a/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go
+++ b/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go
@@ -65,7 +65,7 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
 	}
 
 	var last string
-	if err := client.Stream(ctx, convertParams(args), func(event types.ChatCompletionEvent) error {
+	if err := client.Stream(ctx, convertParams(args), func(event types.CompletionResponse) error {
 		// each completion is just a partial of the final result, since we're in a sync request anyway
 		// we will just wait for the final completion event
 		last = event.Completion
@@ -76,8 +76,8 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
 	return last, nil
 }
 
-func convertParams(args graphqlbackend.CompletionsArgs) types.ChatCompletionRequestParameters {
-	return types.ChatCompletionRequestParameters{
+func convertParams(args graphqlbackend.CompletionsArgs) types.CompletionRequestParameters {
+	return types.CompletionRequestParameters{
 		Messages:          convertMessages(args.Input.Messages),
 		Temperature:       float32(args.Input.Temperature),
 		MaxTokensToSample: int(args.Input.MaxTokensToSample),
diff --git a/enterprise/cmd/llm-proxy/internal/httpapi/anthropic.go b/enterprise/cmd/llm-proxy/internal/httpapi/anthropic.go
index 943d566de88..7bcd05f5446 100644
--- a/enterprise/cmd/llm-proxy/internal/httpapi/anthropic.go
+++ b/enterprise/cmd/llm-proxy/internal/httpapi/anthropic.go
@@ -15,7 +15,7 @@ import (
 const anthropicAPIURL = "https://api.anthropic.com/v1/complete"
 
 func newAnthropicHandler(logger log.Logger, eventLogger events.Logger, accessToken string) http.Handler {
-	return makeUpstreamHandler[anthropicRequest](
+	return makeUpstreamHandler(
 		logger,
 		eventLogger,
 		anthropicAPIURL,
diff --git a/enterprise/cmd/llm-proxy/internal/httpapi/openai.go b/enterprise/cmd/llm-proxy/internal/httpapi/openai.go
index c1aa70d477e..cd6d15c25ba 100644
--- a/enterprise/cmd/llm-proxy/internal/httpapi/openai.go
+++ b/enterprise/cmd/llm-proxy/internal/httpapi/openai.go
@@ -15,7 +15,7 @@ import (
 const openAIURL = "https://api.openai.com/v1/chat/completions"
 
 func newOpenAIHandler(logger log.Logger, eventLogger events.Logger, accessToken string, orgID string) http.Handler {
-	return makeUpstreamHandler[openaiRequest](
+	return makeUpstreamHandler(
 		logger,
 		eventLogger,
 		openAIURL,
diff --git a/enterprise/internal/completions/client/anthropic/anthropic.go b/enterprise/internal/completions/client/anthropic/anthropic.go
index bce55c7525a..bbbc7da29fd 100644
--- a/enterprise/internal/completions/client/anthropic/anthropic.go
+++ b/enterprise/internal/completions/client/anthropic/anthropic.go
@@ -12,17 +12,6 @@ import (
 	"github.com/sourcegraph/sourcegraph/lib/errors"
 )
 
-type AnthropicCompletionsRequestParameters struct {
-	Prompt            string   `json:"prompt"`
-	Temperature       float32  `json:"temperature"`
-	MaxTokensToSample int      `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences"`
-	TopK              int      `json:"top_k"`
-	TopP              float32  `json:"top_p"`
-	Model             string   `json:"model"`
-	Stream            bool     `json:"stream"`
-}
-
 const ProviderName = "anthropic"
 
 func NewClient(cli httpcli.Doer, accessToken string, model string) types.CompletionsClient {
@@ -33,77 +22,41 @@ func NewClient(cli httpcli.Doer, accessToken string, model string) types.Complet
 	}
 }
 
+const apiURL = "https://api.anthropic.com/v1/complete"
+const clientID = "sourcegraph/1.0"
+
 type anthropicClient struct {
 	cli         httpcli.Doer
 	accessToken string
 	model       string
 }
 
-const apiURL = "https://api.anthropic.com/v1/complete"
-const clientID = "sourcegraph/1.0"
-
-var stopSequences = []string{HUMAN_PROMPT}
-
-var allowedClientSpecifiedModels = map[string]struct{}{
-	"claude-instant-v1.0": {},
-}
-
 func (a *anthropicClient) Complete(
 	ctx context.Context,
-	requestParams types.CodeCompletionRequestParameters,
-) (*types.CodeCompletionResponse, error) {
-	var model string
-	if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
-		model = requestParams.Model
-	} else {
-		model = a.model
-	}
-	payload := AnthropicCompletionsRequestParameters{
-		Stream:            false,
-		StopSequences:     requestParams.StopSequences,
-		Model:             model,
-		Temperature:       float32(requestParams.Temperature),
-		MaxTokensToSample: requestParams.MaxTokensToSample,
-		TopP:              float32(requestParams.TopP),
-		TopK:              requestParams.TopK,
-		Prompt:            requestParams.Prompt,
-	}
-
-	resp, err := a.makeRequest(ctx, payload)
+	requestParams types.CompletionRequestParameters,
+) (*types.CompletionResponse, error) {
+	resp, err := a.makeRequest(ctx, requestParams, false)
 	if err != nil {
 		return nil, err
 	}
 	defer resp.Body.Close()
 
-	var response types.CodeCompletionResponse
+	var response anthropicCompletionResponse
 	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
 		return nil, err
 	}
-	return &response, nil
+	return &types.CompletionResponse{
+		Completion: response.Completion,
+		StopReason: response.StopReason,
+	}, nil
 }
 
 func (a *anthropicClient) Stream(
 	ctx context.Context,
-	requestParams types.ChatCompletionRequestParameters,
+	requestParams types.CompletionRequestParameters,
 	sendEvent types.SendCompletionEvent,
 ) error {
-	prompt, err := getPrompt(requestParams.Messages)
-	if err != nil {
-		return err
-	}
-
-	payload := AnthropicCompletionsRequestParameters{
-		Stream:            true,
-		StopSequences:     stopSequences,
-		Model:             a.model,
-		Temperature:       requestParams.Temperature,
-		MaxTokensToSample: requestParams.MaxTokensToSample,
-		TopP:              requestParams.TopP,
-		TopK:              requestParams.TopK,
-		Prompt:            prompt,
-	}
-
-	resp, err := a.makeRequest(ctx, payload)
+	resp, err := a.makeRequest(ctx, requestParams, true)
 	if err != nil {
 		return err
 	}
@@ -122,12 +75,15 @@ func (a *anthropicClient) Stream(
 			continue
 		}
 
-		var event types.ChatCompletionEvent
+		var event anthropicCompletionResponse
 		if err := json.Unmarshal(data, &event); err != nil {
 			return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
 		}
 
-		err = sendEvent(event)
+		err = sendEvent(types.CompletionResponse{
+			Completion: event.Completion,
+			StopReason: event.StopReason,
+		})
 		if err != nil {
 			return err
 		}
@@ -136,7 +92,32 @@
 	return dec.Err()
 }
 
-func (a *anthropicClient) makeRequest(ctx context.Context, payload AnthropicCompletionsRequestParameters) (*http.Response, error) {
+func (a *anthropicClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
+	prompt, err := getPrompt(requestParams.Messages)
+	if err != nil {
+		return nil, err
+	}
+	// Backcompat: Remove this code once enough clients are upgraded and we drop the
+	// Prompt field on requestParams.
+	if prompt == "" {
+		prompt = requestParams.Prompt
+	}
+
+	if len(requestParams.StopSequences) == 0 {
+		requestParams.StopSequences = []string{HUMAN_PROMPT}
+	}
+
+	payload := anthropicCompletionsRequestParameters{
+		Stream:            stream,
+		StopSequences:     requestParams.StopSequences,
+		Model:             a.model,
+		Temperature:       requestParams.Temperature,
+		MaxTokensToSample: requestParams.MaxTokensToSample,
+		TopP:              requestParams.TopP,
+		TopK:              requestParams.TopK,
+		Prompt:            prompt,
+	}
+
 	reqBody, err := json.Marshal(payload)
 	if err != nil {
 		return nil, err
@@ -168,3 +149,19 @@ func (a *anthropicClient) makeRequest(ctx context.Context, payload AnthropicComp
 
 	return resp, nil
 }
+
+type anthropicCompletionsRequestParameters struct {
+	Prompt            string   `json:"prompt"`
+	Temperature       float32  `json:"temperature"`
+	MaxTokensToSample int      `json:"max_tokens_to_sample"`
+	StopSequences     []string `json:"stop_sequences"`
+	TopK              int      `json:"top_k"`
+	TopP              float32  `json:"top_p"`
+	Model             string   `json:"model"`
+	Stream            bool     `json:"stream"`
+}
+
+type anthropicCompletionResponse struct {
+	Completion string `json:"completion"`
+	StopReason string `json:"stop_reason"`
+}
diff --git a/enterprise/internal/completions/client/anthropic/anthropic_test.go b/enterprise/internal/completions/client/anthropic/anthropic_test.go
index 50a22d9fc49..cca6777065a 100644
--- a/enterprise/internal/completions/client/anthropic/anthropic_test.go
+++ b/enterprise/internal/completions/client/anthropic/anthropic_test.go
@@ -49,8 +49,8 @@ func TestValidAnthropicStream(t *testing.T) {
 	}
 
 	mockClient := getMockClient(linesToResponse(mockAnthropicResponseLines))
-	events := []types.ChatCompletionEvent{}
-	err := mockClient.Stream(context.Background(), types.ChatCompletionRequestParameters{}, func(event types.ChatCompletionEvent) error {
+	events := []types.CompletionResponse{}
+	err := mockClient.Stream(context.Background(), types.CompletionRequestParameters{}, func(event types.CompletionResponse) error {
 		events = append(events, event)
 		return nil
 	})
@@ -64,7 +64,7 @@ func TestInvalidAnthropicStream(t *testing.T) {
 	var mockAnthropicInvalidResponseLines = []string{`{]`}
 
 	mockClient := getMockClient(linesToResponse(mockAnthropicInvalidResponseLines))
-	err := mockClient.Stream(context.Background(), types.ChatCompletionRequestParameters{}, func(event types.ChatCompletionEvent) error { return nil })
+	err := mockClient.Stream(context.Background(), types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
 	if err == nil {
 		t.Fatal("expected error, got nil")
 	}
diff --git a/enterprise/internal/completions/client/anthropic/testdata/TestValidAnthropicStream.golden b/enterprise/internal/completions/client/anthropic/testdata/TestValidAnthropicStream.golden
index 94e59ac815c..cc5a6cf692c 100644
--- a/enterprise/internal/completions/client/anthropic/testdata/TestValidAnthropicStream.golden
+++ b/enterprise/internal/completions/client/anthropic/testdata/TestValidAnthropicStream.golden
@@ -1,4 +1,4 @@
-[]types.ChatCompletionEvent{
+[]types.CompletionResponse{
 	{
 		Completion: "Sure!",
 	},
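
Worth noting how the pieces line up: `makeRequest` above defaults an empty stop-sequence list to `[HUMAN_PROMPT]` and the server now chooses the model, which is what let `MultilineCompletionProvider` in provider.ts drop those fields while `EndOfLineCompletionProvider` keeps pinning its own. A sketch of the two parameter objects as the providers now build them (the `HUMAN_PROMPT` value is an assumption about the SDK constant):

    interface Message { speaker: 'human' | 'assistant'; text?: string }

    const HUMAN_PROMPT = '\n\nHuman:' // assumed value of anthropic.HUMAN_PROMPT
    const prompt: Message[] = [{ speaker: 'human', text: 'Write some code' }]
    const responseTokens = 200

    // Multiline: only messages and a token budget; the server defaults
    // stopSequences to [HUMAN_PROMPT] and picks the model.
    const multilineParams = {
        messages: prompt,
        maxTokensToSample: responseTokens,
    }

    // End-of-line: must stop at the first newline, so it still sends explicit
    // stop sequences and sampling parameters.
    const endOfLineParams = {
        messages: prompt,
        stopSequences: [HUMAN_PROMPT, '\n'],
        maxTokensToSample: responseTokens,
        temperature: 1,
        topK: -1,
        topP: -1,
    }
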
diff --git a/enterprise/internal/completions/client/dotcom/dotcom.go b/enterprise/internal/completions/client/dotcom/dotcom.go
index 7d3986c63c7..3c588b31985 100644
--- a/enterprise/internal/completions/client/dotcom/dotcom.go
+++ b/enterprise/internal/completions/client/dotcom/dotcom.go
@@ -36,14 +36,14 @@ func NewClient(cli httpcli.Doer, accessToken string, model string) types.Complet
 
 func (a *dotcomClient) Complete(
 	ctx context.Context,
-	requestParams types.CodeCompletionRequestParameters,
-) (*types.CodeCompletionResponse, error) {
+	requestParams types.CompletionRequestParameters,
+) (*types.CompletionResponse, error) {
 	return nil, errors.New("not implemented")
 }
 
 func (a *dotcomClient) Stream(
 	ctx context.Context,
-	requestParams types.ChatCompletionRequestParameters,
+	requestParams types.CompletionRequestParameters,
 	sendEvent types.SendCompletionEvent,
 ) error {
 	reqBody, err := json.Marshal(requestParams)
@@ -75,7 +75,7 @@ func (a *dotcomClient) Stream(
 			return nil
 		}
 
-		var event types.ChatCompletionEvent
+		var event types.CompletionResponse
 		if err := json.Unmarshal(dec.Data(), &event); err != nil {
 			return errors.Errorf("failed to decode event payload: %w", err)
 		}
diff --git a/enterprise/internal/completions/client/openai/openai.go b/enterprise/internal/completions/client/openai/openai.go
index ce8666520fa..c37092c7516 100644
--- a/enterprise/internal/completions/client/openai/openai.go
+++ b/enterprise/internal/completions/client/openai/openai.go
@@ -32,13 +32,13 @@ type openAIChatCompletionStreamClient struct {
 	model       string
 }
 
-func (a *openAIChatCompletionStreamClient) Complete(ctx context.Context, requestParams types.CodeCompletionRequestParameters) (*types.CodeCompletionResponse, error) {
+func (a *openAIChatCompletionStreamClient) Complete(ctx context.Context, requestParams types.CompletionRequestParameters) (*types.CompletionResponse, error) {
 	return nil, errors.New("openAIChatCompletionStreamClient.Complete: unimplemented")
 }
 
 func (a *openAIChatCompletionStreamClient) Stream(
 	ctx context.Context,
-	requestParams types.ChatCompletionRequestParameters,
+	requestParams types.CompletionRequestParameters,
 	sendEvent types.SendCompletionEvent,
 ) error {
 	if requestParams.TopK < 0 {
@@ -128,9 +128,13 @@ func (a *openAIChatCompletionStreamClient) Stream(
 			return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
 		}
 
-		if len(event.Choices) > 0 && event.Choices[0].FinishReason == nil {
+		if len(event.Choices) > 0 {
 			content = append(content, event.Choices[0].Delta.Content)
-			err = sendEvent(types.ChatCompletionEvent{Completion: strings.Join(content, "")})
+			ev := types.CompletionResponse{Completion: strings.Join(content, "")}
+			if event.Choices[0].FinishReason != nil {
+				ev.StopReason = *event.Choices[0].FinishReason
+			}
+			err = sendEvent(ev)
 			if err != nil {
 				return err
 			}
+ "claude-instant-v1.0": {}, +} + // NewCodeCompletionsHandler is an http handler which sends back code completion results func NewCodeCompletionsHandler(_ log.Logger, db database.DB) http.Handler { rl := NewRateLimiter(db, redispool.Store, RateLimitScopeCodeCompletion) - return newCompletionsHandler(rl, "codeCompletions", func(c *schema.Completions) string { - return c.CompletionModel - }, func(ctx context.Context, requestParams types.CodeCompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) { + return newCompletionsHandler(rl, "codeCompletions", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string { + var model string + if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed { + model = requestParams.Model + } else { + model = c.CompletionModel + } + + return model + }, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) { completion, err := cc.Complete(ctx, requestParams) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/enterprise/internal/completions/httpapi/handler.go b/enterprise/internal/completions/httpapi/handler.go index 71e39b936cc..aaeff66784e 100644 --- a/enterprise/internal/completions/httpapi/handler.go +++ b/enterprise/internal/completions/httpapi/handler.go @@ -18,7 +18,7 @@ import ( // being cancelled. const maxRequestDuration = time.Minute -func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel func(*schema.Completions) string, handle func(context.Context, T, types.CompletionsClient, http.ResponseWriter)) http.Handler { +func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel func(T, *schema.Completions) string, handle func(context.Context, T, types.CompletionsClient, http.ResponseWriter)) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, fmt.Sprintf("unsupported method %s", r.Method), http.StatusMethodNotAllowed) @@ -45,7 +45,7 @@ func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel f return } - model := getModel(completionsConfig) + model := getModel(requestParams, completionsConfig) var err error ctx, done := Trace(ctx, traceFamily, model). diff --git a/enterprise/internal/completions/httpapi/stream.go b/enterprise/internal/completions/httpapi/stream.go index 7b04f58df13..444ca4dfa6d 100644 --- a/enterprise/internal/completions/httpapi/stream.go +++ b/enterprise/internal/completions/httpapi/stream.go @@ -17,9 +17,16 @@ import ( // NewCompletionsStreamHandler is an http handler which streams back completions results. 
func NewCompletionsStreamHandler(logger log.Logger, db database.DB) http.Handler {
 	rl := NewRateLimiter(db, redispool.Store, RateLimitScopeCompletion)
-	return newCompletionsHandler(rl, "stream", func(c *schema.Completions) string {
-		return c.ChatModel
-	}, func(ctx context.Context, requestParams types.ChatCompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
+	return newCompletionsHandler(rl, "stream", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
+		var model string
+		if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
+			model = requestParams.Model
+		} else {
+			model = c.ChatModel
+		}
+
+		return model
+	}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
 		eventWriter, err := streamhttp.NewWriter(w)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -31,7 +38,7 @@ func NewCompletionsStreamHandler(logger log.Logger, db database.DB) http.Handler
 			_ = eventWriter.Event("done", map[string]any{})
 		}()
 
-		err = cc.Stream(ctx, requestParams, func(event types.ChatCompletionEvent) error { return eventWriter.Event("completion", event) })
+		err = cc.Stream(ctx, requestParams, func(event types.CompletionResponse) error { return eventWriter.Event("completion", event) })
 		if err != nil {
 			trace.Logger(ctx, logger).Error("error while streaming completions", log.Error(err))
 			_ = eventWriter.Event("error", map[string]string{"error": err.Error()})
diff --git a/enterprise/internal/completions/types/types.go b/enterprise/internal/completions/types/types.go
index 8df293afe87..ec33629f478 100644
--- a/enterprise/internal/completions/types/types.go
+++ b/enterprise/internal/completions/types/types.go
@@ -15,36 +15,8 @@ type Message struct {
 	Text    string `json:"text"`
 }
 
-type CodeCompletionRequestParameters struct {
-	Prompt            string            `json:"prompt"`
-	Temperature       float64           `json:"temperature,omitempty"`
-	MaxTokensToSample int               `json:"maxTokensToSample"`
-	StopSequences     []string          `json:"stopSequences"`
-	TopK              int               `json:"topK,omitempty"`
-	TopP              float64           `json:"topP,omitempty"`
-	Model             string            `json:"model"`
-	Tags              map[string]string `json:"tags,omitempty"`
-}
-
-type CodeCompletionResponse struct {
-	Completion string  `json:"completion"`
-	Stop       *string `json:"stop"`
-	StopReason string  `json:"stopReason"`
-	Truncated  bool    `json:"truncated"`
-	Exception  *string `json:"exception"`
-	LogID      string  `json:"logID"`
-}
-
-type ChatCompletionRequestParameters struct {
-	Messages          []Message `json:"messages"`
-	Temperature       float32   `json:"temperature"`
-	MaxTokensToSample int       `json:"maxTokensToSample"`
-	TopK              int       `json:"topK"`
-	TopP              float32   `json:"topP"`
-}
-
-type ChatCompletionEvent struct {
-	Completion string `json:"completion"`
+func (m Message) IsValidSpeaker() bool {
+	return m.Speaker == HUMAN_MESSAGE_SPEAKER || m.Speaker == ASISSTANT_MESSAGE_SPEAKER
 }
 
 func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (string, error) {
@@ -65,9 +37,28 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str
 	return fmt.Sprintf("%s %s", prefix, m.Text), nil
 }
 
-type SendCompletionEvent func(event ChatCompletionEvent) error
+type CompletionRequestParameters struct {
+	// Prompt exists only for backwards compatibility. Do not use it in new
+	// implementations. It will be removed once we are reasonably sure 99%
+	// of VSCode extension installations are upgraded to a new Cody version.
+ Prompt string `json:"prompt"` + Messages []Message `json:"messages"` + MaxTokensToSample int `json:"maxTokensToSample,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + TopK int `json:"topK,omitempty"` + TopP float32 `json:"topP,omitempty"` + Model string `json:"model,omitempty"` +} + +type CompletionResponse struct { + Completion string `json:"completion"` + StopReason string `json:"stopReason"` +} + +type SendCompletionEvent func(event CompletionResponse) error type CompletionsClient interface { - Stream(ctx context.Context, requestParams ChatCompletionRequestParameters, sendEvent SendCompletionEvent) error - Complete(ctx context.Context, requestParams CodeCompletionRequestParameters) (*CodeCompletionResponse, error) + Stream(ctx context.Context, requestParams CompletionRequestParameters, sendEvent SendCompletionEvent) error + Complete(ctx context.Context, requestParams CompletionRequestParameters) (*CompletionResponse, error) }