mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:51:59 +00:00
Unify API for code completions and chat (#51962)
This PR makes the input and output of these two endpoints the same, making them effectively the same thing; the only differences are that another model is chosen for code and that code completions are non-streaming. This will be useful long-term because there is less divergence, and it will make implementing additional upstream providers easier, as only one method needs to be implemented in most cases.
This commit is contained in:
parent
c2045704d8
commit
fe8e70d94d
@ -1,10 +1,10 @@
|
||||
import { fetchEventSource } from '@microsoft/fetch-event-source'
|
||||
|
||||
import { SourcegraphCompletionsClient } from './client'
|
||||
import type { Event, CompletionParameters, CompletionCallbacks, CodeCompletionResponse } from './types'
|
||||
import type { Event, CompletionParameters, CompletionCallbacks, CompletionResponse } from './types'
|
||||
|
||||
export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsClient {
|
||||
public complete(): Promise<CodeCompletionResponse> {
|
||||
public complete(): Promise<CompletionResponse> {
|
||||
throw new Error('SourcegraphBrowserCompletionsClient.complete not implemented')
|
||||
}
|
||||
|
||||
|
||||
@ -1,19 +1,13 @@
|
||||
import { ConfigurationWithAccessToken } from '../../configuration'
|
||||
|
||||
import {
|
||||
Event,
|
||||
CompletionParameters,
|
||||
CompletionCallbacks,
|
||||
CodeCompletionParameters,
|
||||
CodeCompletionResponse,
|
||||
} from './types'
|
||||
import { Event, CompletionCallbacks, CompletionParameters, CompletionResponse } from './types'
|
||||
|
||||
export interface CompletionLogger {
|
||||
startCompletion(params: CodeCompletionParameters | CompletionParameters):
|
||||
startCompletion(params: CompletionParameters):
|
||||
| undefined
|
||||
| {
|
||||
onError: (error: string) => void
|
||||
onComplete: (response: string | CodeCompletionResponse) => void
|
||||
onComplete: (response: string | CompletionResponse) => void
|
||||
onEvents: (events: Event[]) => void
|
||||
}
|
||||
}
|
||||
@ -52,8 +46,5 @@ export abstract class SourcegraphCompletionsClient {
|
||||
}
|
||||
|
||||
public abstract stream(params: CompletionParameters, cb: CompletionCallbacks): () => void
|
||||
public abstract complete(
|
||||
params: CodeCompletionParameters,
|
||||
abortSignal: AbortSignal
|
||||
): Promise<CodeCompletionResponse>
|
||||
public abstract complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse>
|
||||
}
|
||||
|
||||
@ -6,10 +6,10 @@ import { toPartialUtf8String } from '../utils'
|
||||
|
||||
import { SourcegraphCompletionsClient } from './client'
|
||||
import { parseEvents } from './parse'
|
||||
import { CompletionParameters, CompletionCallbacks, CodeCompletionParameters, CodeCompletionResponse } from './types'
|
||||
import { CompletionParameters, CompletionCallbacks, CompletionResponse } from './types'
|
||||
|
||||
export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClient {
|
||||
public async complete(params: CodeCompletionParameters, abortSignal: AbortSignal): Promise<CodeCompletionResponse> {
|
||||
public async complete(params: CompletionParameters, abortSignal: AbortSignal): Promise<CompletionResponse> {
|
||||
const log = this.logger?.startCompletion(params)
|
||||
|
||||
const requestFn = this.codeCompletionsEndpoint.startsWith('https://') ? https.request : http.request
|
||||
@ -18,7 +18,7 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie
|
||||
if (this.config.accessToken) {
|
||||
headersInstance.set('Authorization', `token ${this.config.accessToken}`)
|
||||
}
|
||||
const completion = await new Promise<CodeCompletionResponse>((resolve, reject) => {
|
||||
const completion = await new Promise<CompletionResponse>((resolve, reject) => {
|
||||
const req = requestFn(
|
||||
this.codeCompletionsEndpoint,
|
||||
{
|
||||
@ -44,7 +44,7 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie
|
||||
reject(new Error(buffer))
|
||||
}
|
||||
|
||||
const resp = JSON.parse(buffer) as CodeCompletionResponse
|
||||
const resp = JSON.parse(buffer) as CompletionResponse
|
||||
if (typeof resp.completion !== 'string' || typeof resp.stopReason !== 'string') {
|
||||
const message = `response does not satisfy CodeCompletionResponse: ${buffer}`
|
||||
log?.onError(message)
|
||||
|
||||
@ -37,14 +37,14 @@ function parseEventData(eventType: Event['type'], dataLine: string): Event | Err
|
||||
const jsonData = dataLine.slice(DATA_LINE_PREFIX.length)
|
||||
switch (eventType) {
|
||||
case 'completion': {
|
||||
const data = parseJSON<{ completion: string }>(jsonData)
|
||||
const data = parseJSON<{ completion: string; stopReason: string }>(jsonData)
|
||||
if (isError(data)) {
|
||||
return data
|
||||
}
|
||||
if (typeof data.completion === undefined) {
|
||||
return new Error('invalid completion event')
|
||||
}
|
||||
return { type: eventType, completion: data.completion }
|
||||
return { type: eventType, completion: data.completion, stopReason: data.stopReason }
|
||||
}
|
||||
case 'error': {
|
||||
const data = parseJSON<{ error: string }>(jsonData)
|
||||
|
||||
@ -2,9 +2,8 @@ export interface DoneEvent {
|
||||
type: 'done'
|
||||
}
|
||||
|
||||
export interface CompletionEvent {
|
||||
export interface CompletionEvent extends CompletionResponse {
|
||||
type: 'completion'
|
||||
completion: string
|
||||
}
|
||||
|
||||
export interface ErrorEvent {
|
||||
@ -19,31 +18,19 @@ export interface Message {
|
||||
text?: string
|
||||
}
|
||||
|
||||
export interface CodeCompletionResponse {
|
||||
export interface CompletionResponse {
|
||||
completion: string
|
||||
stop: string | null
|
||||
stopReason: string
|
||||
truncated: boolean
|
||||
exception: string | null
|
||||
logID: string
|
||||
}
|
||||
|
||||
export interface CodeCompletionParameters {
|
||||
prompt: string
|
||||
temperature: number
|
||||
maxTokensToSample: number
|
||||
stopSequences: string[]
|
||||
topK: number
|
||||
topP: number
|
||||
model?: string
|
||||
}
|
||||
|
||||
export interface CompletionParameters {
|
||||
messages: Message[]
|
||||
temperature: number
|
||||
maxTokensToSample: number
|
||||
topK: number
|
||||
topP: number
|
||||
temperature?: number
|
||||
stopSequences?: string[]
|
||||
topK?: number
|
||||
topP?: number
|
||||
model?: string
|
||||
}
|
||||
|
||||
export interface CompletionCallbacks {
|
||||
|
||||
@ -3,14 +3,14 @@ import { CompletionsCache } from './cache'
|
||||
describe('CompletionsCache', () => {
|
||||
it('returns the cached completion items', () => {
|
||||
const cache = new CompletionsCache()
|
||||
cache.add([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
|
||||
cache.add([{ prefix: 'foo\n', content: 'bar', messages: [] }])
|
||||
|
||||
expect(cache.get('foo\n')).toEqual([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
|
||||
})
|
||||
|
||||
it('returns the cached items when the prefix includes characters from the completion', () => {
|
||||
const cache = new CompletionsCache()
|
||||
cache.add([{ prefix: 'foo\n', content: 'bar', prompt: '' }])
|
||||
cache.add([{ prefix: 'foo\n', content: 'bar', messages: [] }])
|
||||
|
||||
expect(cache.get('foo\nb')).toEqual([{ prefix: 'foo\nb', content: 'ar', prompt: '' }])
|
||||
expect(cache.get('foo\nba')).toEqual([{ prefix: 'foo\nba', content: 'r', prompt: '' }])
|
||||
@ -18,7 +18,7 @@ describe('CompletionsCache', () => {
|
||||
|
||||
it('returns the cached items when the prefix has less whitespace', () => {
|
||||
const cache = new CompletionsCache()
|
||||
cache.add([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
|
||||
cache.add([{ prefix: 'foo \n ', content: 'bar', messages: [] }])
|
||||
|
||||
expect(cache.get('foo \n ')).toEqual([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
|
||||
expect(cache.get('foo \n ')).toEqual([{ prefix: 'foo \n ', content: 'bar', prompt: '' }])
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import * as vscode from 'vscode'
|
||||
|
||||
import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
|
||||
import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
|
||||
|
||||
import { logEvent } from '../event-logger'
|
||||
@ -372,7 +373,7 @@ function getCurrentDocContext(
|
||||
|
||||
export interface Completion {
|
||||
prefix: string
|
||||
prompt: string
|
||||
messages: Message[]
|
||||
content: string
|
||||
stopReason?: string
|
||||
}
|
||||
|
||||
@ -1,18 +1,15 @@
|
||||
import * as anthropic from '@anthropic-ai/sdk'
|
||||
|
||||
import { ReferenceSnippet } from './context'
|
||||
import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
|
||||
|
||||
export interface Message {
|
||||
role: 'human' | 'ai'
|
||||
text: string | null
|
||||
}
|
||||
import { ReferenceSnippet } from './context'
|
||||
|
||||
export function messagesToText(messages: Message[]): string {
|
||||
return messages
|
||||
.map(
|
||||
message =>
|
||||
`${message.role === 'human' ? anthropic.HUMAN_PROMPT : anthropic.AI_PROMPT}${
|
||||
message.text === null ? '' : ' ' + message.text
|
||||
`${message.speaker === 'human' ? anthropic.HUMAN_PROMPT : anthropic.AI_PROMPT}${
|
||||
message.text === undefined ? '' : ' ' + message.text
|
||||
}`
|
||||
)
|
||||
.join('')
|
||||
@ -39,7 +36,7 @@ export class SingleLinePromptTemplate implements PromptTemplate {
|
||||
const lastHumanLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
'Complete the following file:\n' +
|
||||
'```' +
|
||||
@ -47,7 +44,7 @@ export class SingleLinePromptTemplate implements PromptTemplate {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text:
|
||||
'Here is the completion of the file:\n' +
|
||||
'```' +
|
||||
@ -57,11 +54,11 @@ export class SingleLinePromptTemplate implements PromptTemplate {
|
||||
} else {
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text: 'Write some code',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: `Here is some code:\n\`\`\`\n${prefix}`,
|
||||
},
|
||||
]
|
||||
@ -102,7 +99,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
|
||||
const lastHumanLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
'Complete the following file:\n' +
|
||||
'```' +
|
||||
@ -110,7 +107,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text:
|
||||
'Here is the completion of the file:\n' +
|
||||
'```' +
|
||||
@ -120,11 +117,11 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
|
||||
} else {
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text: 'Write some code',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: `Here is some code:\n\`\`\`\n${prefix}`,
|
||||
},
|
||||
]
|
||||
@ -135,7 +132,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
|
||||
for (const snippet of snippets) {
|
||||
const snippetMessages: Message[] = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
`Add the following code snippet (from file ${snippet.filename}) to your knowledge base:\n` +
|
||||
'```' +
|
||||
@ -143,7 +140,7 @@ export class KnowledgeBasePromptTemplate implements PromptTemplate {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: 'Okay, I have added it to my knowledge base.',
|
||||
},
|
||||
]
|
||||
|
||||
@ -2,13 +2,14 @@ import * as anthropic from '@anthropic-ai/sdk'
|
||||
|
||||
import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/nodeClient'
|
||||
import {
|
||||
CodeCompletionParameters,
|
||||
CodeCompletionResponse,
|
||||
CompletionParameters,
|
||||
CompletionResponse,
|
||||
Message,
|
||||
} from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'
|
||||
|
||||
import { Completion } from '.'
|
||||
import { ReferenceSnippet } from './context'
|
||||
import { Message, messagesToText } from './prompts'
|
||||
import { messagesToText } from './prompts'
|
||||
|
||||
export abstract class CompletionProvider {
|
||||
constructor(
|
||||
@ -32,7 +33,7 @@ export abstract class CompletionProvider {
|
||||
|
||||
// Creates the resulting prompt and adds as many snippets from the reference
|
||||
// list as possible.
|
||||
protected createPrompt(): string {
|
||||
protected createPrompt(): Message[] {
|
||||
const prefixMessages = this.createPromptPrefix()
|
||||
const referenceSnippetMessages: Message[] = []
|
||||
|
||||
@ -50,7 +51,7 @@ export abstract class CompletionProvider {
|
||||
if (suffix.length > 0) {
|
||||
const suffixContext: Message[] = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
'Add the following code snippet to your knowledge base:\n' +
|
||||
'```' +
|
||||
@ -58,7 +59,7 @@ export abstract class CompletionProvider {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: 'Okay, I have added it to my knowledge base.',
|
||||
},
|
||||
]
|
||||
@ -74,7 +75,7 @@ export abstract class CompletionProvider {
|
||||
for (const snippet of this.snippets) {
|
||||
const snippetMessages: Message[] = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
`Add the following code snippet (from file ${snippet.filename}) to your knowledge base:\n` +
|
||||
'```' +
|
||||
@ -82,7 +83,7 @@ export abstract class CompletionProvider {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: 'Okay, I have added it to my knowledge base.',
|
||||
},
|
||||
]
|
||||
@ -94,7 +95,7 @@ export abstract class CompletionProvider {
|
||||
remainingChars -= numSnippetChars
|
||||
}
|
||||
|
||||
return messagesToText([...referenceSnippetMessages, ...prefixMessages])
|
||||
return [...referenceSnippetMessages, ...prefixMessages]
|
||||
}
|
||||
|
||||
public abstract generateCompletions(abortSignal: AbortSignal, n?: number): Promise<Completion[]>
|
||||
@ -115,7 +116,7 @@ export class MultilineCompletionProvider extends CompletionProvider {
|
||||
const endLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
'Complete the following file:\n' +
|
||||
'```' +
|
||||
@ -123,18 +124,18 @@ export class MultilineCompletionProvider extends CompletionProvider {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: `Here is the completion of the file:\n\`\`\`\n${prefixLines.slice(endLine).join('\n')}`,
|
||||
},
|
||||
]
|
||||
} else {
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text: 'Write some code',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: `Here is some code:\n\`\`\`\n${prefix}`,
|
||||
},
|
||||
]
|
||||
@ -164,7 +165,8 @@ export class MultilineCompletionProvider extends CompletionProvider {
|
||||
|
||||
// Create prompt
|
||||
const prompt = this.createPrompt()
|
||||
if (prompt.length > this.promptChars) {
|
||||
const textPrompt = messagesToText(prompt)
|
||||
if (textPrompt.length > this.promptChars) {
|
||||
throw new Error('prompt length exceeded maximum alloted chars')
|
||||
}
|
||||
|
||||
@ -172,13 +174,8 @@ export class MultilineCompletionProvider extends CompletionProvider {
|
||||
const responses = await batchCompletions(
|
||||
this.completionsClient,
|
||||
{
|
||||
prompt,
|
||||
stopSequences: [anthropic.HUMAN_PROMPT],
|
||||
messages: prompt,
|
||||
maxTokensToSample: this.responseTokens,
|
||||
model: 'claude-instant-v1.0',
|
||||
temperature: 1, // default value (source: https://console.anthropic.com/docs/api/reference)
|
||||
topK: -1, // default value
|
||||
topP: -1, // default value
|
||||
},
|
||||
n || this.defaultN,
|
||||
abortSignal
|
||||
@ -186,7 +183,7 @@ export class MultilineCompletionProvider extends CompletionProvider {
|
||||
// Post-process
|
||||
return responses.map(resp => ({
|
||||
prefix,
|
||||
prompt,
|
||||
messages: prompt,
|
||||
content: this.postProcess(resp.completion),
|
||||
stopReason: resp.stopReason,
|
||||
}))
|
||||
@ -206,7 +203,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
|
||||
const endLine = Math.max(Math.floor(prefixLines.length / 2), prefixLines.length - 5)
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text:
|
||||
'Complete the following file:\n' +
|
||||
'```' +
|
||||
@ -214,7 +211,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
|
||||
'```',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text:
|
||||
'Here is the completion of the file:\n' +
|
||||
'```' +
|
||||
@ -224,11 +221,11 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
|
||||
} else {
|
||||
prefixMessages = [
|
||||
{
|
||||
role: 'human',
|
||||
speaker: 'human',
|
||||
text: 'Write some code',
|
||||
},
|
||||
{
|
||||
role: 'ai',
|
||||
speaker: 'assistant',
|
||||
text: `Here is some code:\n\`\`\`\n${this.prefix}${this.injectPrefix}`,
|
||||
},
|
||||
]
|
||||
@ -272,10 +269,9 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
|
||||
const responses = await batchCompletions(
|
||||
this.completionsClient,
|
||||
{
|
||||
prompt,
|
||||
messages: prompt,
|
||||
stopSequences: [anthropic.HUMAN_PROMPT, '\n'],
|
||||
maxTokensToSample: this.responseTokens,
|
||||
model: 'claude-instant-v1.0',
|
||||
temperature: 1,
|
||||
topK: -1,
|
||||
topP: -1,
|
||||
@ -286,7 +282,7 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
|
||||
// Post-process
|
||||
return responses.map(resp => ({
|
||||
prefix,
|
||||
prompt,
|
||||
messages: prompt,
|
||||
content: this.postProcess(resp.completion),
|
||||
stopReason: resp.stopReason,
|
||||
}))
|
||||
@ -295,11 +291,11 @@ export class EndOfLineCompletionProvider extends CompletionProvider {
|
||||
|
||||
async function batchCompletions(
|
||||
client: SourcegraphNodeCompletionsClient,
|
||||
params: CodeCompletionParameters,
|
||||
params: CompletionParameters,
|
||||
n: number,
|
||||
abortSignal: AbortSignal
|
||||
): Promise<CodeCompletionResponse[]> {
|
||||
const responses: Promise<CodeCompletionResponse>[] = []
|
||||
): Promise<CompletionResponse[]> {
|
||||
const responses: Promise<CompletionResponse>[] = []
|
||||
for (let i = 0; i < n; i++) {
|
||||
responses.push(client.complete(params, abortSignal))
|
||||
}
|
||||
|
||||
@ -2,9 +2,8 @@ import vscode from 'vscode'
|
||||
|
||||
import { CompletionLogger } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/client'
|
||||
import {
|
||||
CodeCompletionParameters,
|
||||
CodeCompletionResponse,
|
||||
CompletionParameters,
|
||||
CompletionResponse,
|
||||
Event,
|
||||
} from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/types'
|
||||
|
||||
@ -19,7 +18,7 @@ if (config.debug) {
|
||||
}
|
||||
|
||||
export const logger: CompletionLogger = {
|
||||
startCompletion(params: CodeCompletionParameters | CompletionParameters) {
|
||||
startCompletion(params: CompletionParameters) {
|
||||
if (!outputChannel) {
|
||||
return undefined
|
||||
}
|
||||
@ -47,7 +46,7 @@ export const logger: CompletionLogger = {
|
||||
outputChannel!.appendLine('')
|
||||
}
|
||||
|
||||
function onComplete(result: string | CodeCompletionResponse): void {
|
||||
function onComplete(result: string | CompletionResponse): void {
|
||||
if (hasFinished) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
|
||||
}
|
||||
|
||||
var last string
|
||||
if err := client.Stream(ctx, convertParams(args), func(event types.ChatCompletionEvent) error {
|
||||
if err := client.Stream(ctx, convertParams(args), func(event types.CompletionResponse) error {
|
||||
// each completion is just a partial of the final result, since we're in a sync request anyway
|
||||
// we will just wait for the final completion event
|
||||
last = event.Completion
|
||||
@ -76,8 +76,8 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
|
||||
return last, nil
|
||||
}
|
||||
|
||||
func convertParams(args graphqlbackend.CompletionsArgs) types.ChatCompletionRequestParameters {
|
||||
return types.ChatCompletionRequestParameters{
|
||||
func convertParams(args graphqlbackend.CompletionsArgs) types.CompletionRequestParameters {
|
||||
return types.CompletionRequestParameters{
|
||||
Messages: convertMessages(args.Input.Messages),
|
||||
Temperature: float32(args.Input.Temperature),
|
||||
MaxTokensToSample: int(args.Input.MaxTokensToSample),
|
||||
|
||||
@ -15,7 +15,7 @@ import (
|
||||
const anthropicAPIURL = "https://api.anthropic.com/v1/complete"
|
||||
|
||||
func newAnthropicHandler(logger log.Logger, eventLogger events.Logger, accessToken string) http.Handler {
|
||||
return makeUpstreamHandler[anthropicRequest](
|
||||
return makeUpstreamHandler(
|
||||
logger,
|
||||
eventLogger,
|
||||
anthropicAPIURL,
|
||||
|
||||
@ -15,7 +15,7 @@ import (
|
||||
const openAIURL = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
func newOpenAIHandler(logger log.Logger, eventLogger events.Logger, accessToken string, orgID string) http.Handler {
|
||||
return makeUpstreamHandler[openaiRequest](
|
||||
return makeUpstreamHandler(
|
||||
logger,
|
||||
eventLogger,
|
||||
openAIURL,
|
||||
|
||||
@ -12,17 +12,6 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
type AnthropicCompletionsRequestParameters struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
const ProviderName = "anthropic"
|
||||
|
||||
func NewClient(cli httpcli.Doer, accessToken string, model string) types.CompletionsClient {
|
||||
@ -33,77 +22,41 @@ func NewClient(cli httpcli.Doer, accessToken string, model string) types.Complet
|
||||
}
|
||||
}
|
||||
|
||||
const apiURL = "https://api.anthropic.com/v1/complete"
|
||||
const clientID = "sourcegraph/1.0"
|
||||
|
||||
type anthropicClient struct {
|
||||
cli httpcli.Doer
|
||||
accessToken string
|
||||
model string
|
||||
}
|
||||
|
||||
const apiURL = "https://api.anthropic.com/v1/complete"
|
||||
const clientID = "sourcegraph/1.0"
|
||||
|
||||
var stopSequences = []string{HUMAN_PROMPT}
|
||||
|
||||
var allowedClientSpecifiedModels = map[string]struct{}{
|
||||
"claude-instant-v1.0": {},
|
||||
}
|
||||
|
||||
func (a *anthropicClient) Complete(
|
||||
ctx context.Context,
|
||||
requestParams types.CodeCompletionRequestParameters,
|
||||
) (*types.CodeCompletionResponse, error) {
|
||||
var model string
|
||||
if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
|
||||
model = requestParams.Model
|
||||
} else {
|
||||
model = a.model
|
||||
}
|
||||
payload := AnthropicCompletionsRequestParameters{
|
||||
Stream: false,
|
||||
StopSequences: requestParams.StopSequences,
|
||||
Model: model,
|
||||
Temperature: float32(requestParams.Temperature),
|
||||
MaxTokensToSample: requestParams.MaxTokensToSample,
|
||||
TopP: float32(requestParams.TopP),
|
||||
TopK: requestParams.TopK,
|
||||
Prompt: requestParams.Prompt,
|
||||
}
|
||||
|
||||
resp, err := a.makeRequest(ctx, payload)
|
||||
requestParams types.CompletionRequestParameters,
|
||||
) (*types.CompletionResponse, error) {
|
||||
resp, err := a.makeRequest(ctx, requestParams, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var response types.CodeCompletionResponse
|
||||
var response anthropicCompletionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
return &types.CompletionResponse{
|
||||
Completion: response.Completion,
|
||||
StopReason: response.StopReason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *anthropicClient) Stream(
|
||||
ctx context.Context,
|
||||
requestParams types.ChatCompletionRequestParameters,
|
||||
requestParams types.CompletionRequestParameters,
|
||||
sendEvent types.SendCompletionEvent,
|
||||
) error {
|
||||
prompt, err := getPrompt(requestParams.Messages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := AnthropicCompletionsRequestParameters{
|
||||
Stream: true,
|
||||
StopSequences: stopSequences,
|
||||
Model: a.model,
|
||||
Temperature: requestParams.Temperature,
|
||||
MaxTokensToSample: requestParams.MaxTokensToSample,
|
||||
TopP: requestParams.TopP,
|
||||
TopK: requestParams.TopK,
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
resp, err := a.makeRequest(ctx, payload)
|
||||
resp, err := a.makeRequest(ctx, requestParams, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -122,12 +75,15 @@ func (a *anthropicClient) Stream(
|
||||
continue
|
||||
}
|
||||
|
||||
var event types.ChatCompletionEvent
|
||||
var event anthropicCompletionResponse
|
||||
if err := json.Unmarshal(data, &event); err != nil {
|
||||
return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
|
||||
}
|
||||
|
||||
err = sendEvent(event)
|
||||
err = sendEvent(types.CompletionResponse{
|
||||
Completion: event.Completion,
|
||||
StopReason: event.StopReason,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -136,7 +92,32 @@ func (a *anthropicClient) Stream(
|
||||
return dec.Err()
|
||||
}
|
||||
|
||||
func (a *anthropicClient) makeRequest(ctx context.Context, payload AnthropicCompletionsRequestParameters) (*http.Response, error) {
|
||||
func (a *anthropicClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
|
||||
prompt, err := getPrompt(requestParams.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Backcompat: Remove this code once enough clients are upgraded and we drop the
|
||||
// Prompt field on requestParams.
|
||||
if prompt == "" {
|
||||
prompt = requestParams.Prompt
|
||||
}
|
||||
|
||||
if len(requestParams.StopSequences) == 0 {
|
||||
requestParams.StopSequences = []string{HUMAN_PROMPT}
|
||||
}
|
||||
|
||||
payload := anthropicCompletionsRequestParameters{
|
||||
Stream: stream,
|
||||
StopSequences: requestParams.StopSequences,
|
||||
Model: a.model,
|
||||
Temperature: requestParams.Temperature,
|
||||
MaxTokensToSample: requestParams.MaxTokensToSample,
|
||||
TopP: requestParams.TopP,
|
||||
TopK: requestParams.TopK,
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -168,3 +149,19 @@ func (a *anthropicClient) makeRequest(ctx context.Context, payload AnthropicComp
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type anthropicCompletionsRequestParameters struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type anthropicCompletionResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
}
|
||||
|
||||
@ -49,8 +49,8 @@ func TestValidAnthropicStream(t *testing.T) {
|
||||
}
|
||||
|
||||
mockClient := getMockClient(linesToResponse(mockAnthropicResponseLines))
|
||||
events := []types.ChatCompletionEvent{}
|
||||
err := mockClient.Stream(context.Background(), types.ChatCompletionRequestParameters{}, func(event types.ChatCompletionEvent) error {
|
||||
events := []types.CompletionResponse{}
|
||||
err := mockClient.Stream(context.Background(), types.CompletionRequestParameters{}, func(event types.CompletionResponse) error {
|
||||
events = append(events, event)
|
||||
return nil
|
||||
})
|
||||
@ -64,7 +64,7 @@ func TestInvalidAnthropicStream(t *testing.T) {
|
||||
var mockAnthropicInvalidResponseLines = []string{`{]`}
|
||||
|
||||
mockClient := getMockClient(linesToResponse(mockAnthropicInvalidResponseLines))
|
||||
err := mockClient.Stream(context.Background(), types.ChatCompletionRequestParameters{}, func(event types.ChatCompletionEvent) error { return nil })
|
||||
err := mockClient.Stream(context.Background(), types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
[]types.ChatCompletionEvent{
|
||||
[]types.CompletionResponse{
|
||||
{
|
||||
Completion: "Sure!",
|
||||
},
|
||||
|
||||
@ -36,14 +36,14 @@ func NewClient(cli httpcli.Doer, accessToken string, model string) types.Complet
|
||||
|
||||
func (a *dotcomClient) Complete(
|
||||
ctx context.Context,
|
||||
requestParams types.CodeCompletionRequestParameters,
|
||||
) (*types.CodeCompletionResponse, error) {
|
||||
requestParams types.CompletionRequestParameters,
|
||||
) (*types.CompletionResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *dotcomClient) Stream(
|
||||
ctx context.Context,
|
||||
requestParams types.ChatCompletionRequestParameters,
|
||||
requestParams types.CompletionRequestParameters,
|
||||
sendEvent types.SendCompletionEvent,
|
||||
) error {
|
||||
reqBody, err := json.Marshal(requestParams)
|
||||
@ -75,7 +75,7 @@ func (a *dotcomClient) Stream(
|
||||
return nil
|
||||
}
|
||||
|
||||
var event types.ChatCompletionEvent
|
||||
var event types.CompletionResponse
|
||||
if err := json.Unmarshal(dec.Data(), &event); err != nil {
|
||||
return errors.Errorf("failed to decode event payload: %w", err)
|
||||
}
|
||||
|
||||
@ -32,13 +32,13 @@ type openAIChatCompletionStreamClient struct {
|
||||
model string
|
||||
}
|
||||
|
||||
func (a *openAIChatCompletionStreamClient) Complete(ctx context.Context, requestParams types.CodeCompletionRequestParameters) (*types.CodeCompletionResponse, error) {
|
||||
func (a *openAIChatCompletionStreamClient) Complete(ctx context.Context, requestParams types.CompletionRequestParameters) (*types.CompletionResponse, error) {
|
||||
return nil, errors.New("openAIChatCompletionStreamClient.Complete: unimplemented")
|
||||
}
|
||||
|
||||
func (a *openAIChatCompletionStreamClient) Stream(
|
||||
ctx context.Context,
|
||||
requestParams types.ChatCompletionRequestParameters,
|
||||
requestParams types.CompletionRequestParameters,
|
||||
sendEvent types.SendCompletionEvent,
|
||||
) error {
|
||||
if requestParams.TopK < 0 {
|
||||
@ -128,9 +128,13 @@ func (a *openAIChatCompletionStreamClient) Stream(
|
||||
return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
|
||||
}
|
||||
|
||||
if len(event.Choices) > 0 && event.Choices[0].FinishReason == nil {
|
||||
if len(event.Choices) > 0 {
|
||||
content = append(content, event.Choices[0].Delta.Content)
|
||||
err = sendEvent(types.ChatCompletionEvent{Completion: strings.Join(content, "")})
|
||||
ev := types.CompletionResponse{Completion: strings.Join(content, "")}
|
||||
if event.Choices[0].FinishReason != nil {
|
||||
ev.StopReason = *event.Choices[0].FinishReason
|
||||
}
|
||||
err = sendEvent(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -13,12 +13,24 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
var allowedClientSpecifiedModels = map[string]struct{}{
|
||||
// TODO(eseliger): This list should probably be configurable.
|
||||
"claude-instant-v1.0": {},
|
||||
}
|
||||
|
||||
// NewCodeCompletionsHandler is an http handler which sends back code completion results
|
||||
func NewCodeCompletionsHandler(_ log.Logger, db database.DB) http.Handler {
|
||||
rl := NewRateLimiter(db, redispool.Store, RateLimitScopeCodeCompletion)
|
||||
return newCompletionsHandler(rl, "codeCompletions", func(c *schema.Completions) string {
|
||||
return c.CompletionModel
|
||||
}, func(ctx context.Context, requestParams types.CodeCompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
|
||||
return newCompletionsHandler(rl, "codeCompletions", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
|
||||
var model string
|
||||
if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
|
||||
model = requestParams.Model
|
||||
} else {
|
||||
model = c.CompletionModel
|
||||
}
|
||||
|
||||
return model
|
||||
}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
|
||||
completion, err := cc.Complete(ctx, requestParams)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
// being cancelled.
|
||||
const maxRequestDuration = time.Minute
|
||||
|
||||
func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel func(*schema.Completions) string, handle func(context.Context, T, types.CompletionsClient, http.ResponseWriter)) http.Handler {
|
||||
func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel func(T, *schema.Completions) string, handle func(context.Context, T, types.CompletionsClient, http.ResponseWriter)) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, fmt.Sprintf("unsupported method %s", r.Method), http.StatusMethodNotAllowed)
|
||||
@ -45,7 +45,7 @@ func newCompletionsHandler[T any](rl RateLimiter, traceFamily string, getModel f
|
||||
return
|
||||
}
|
||||
|
||||
model := getModel(completionsConfig)
|
||||
model := getModel(requestParams, completionsConfig)
|
||||
|
||||
var err error
|
||||
ctx, done := Trace(ctx, traceFamily, model).
|
||||
|
||||
@ -17,9 +17,16 @@ import (
|
||||
// NewCompletionsStreamHandler is an http handler which streams back completions results.
|
||||
func NewCompletionsStreamHandler(logger log.Logger, db database.DB) http.Handler {
|
||||
rl := NewRateLimiter(db, redispool.Store, RateLimitScopeCompletion)
|
||||
return newCompletionsHandler(rl, "stream", func(c *schema.Completions) string {
|
||||
return c.ChatModel
|
||||
}, func(ctx context.Context, requestParams types.ChatCompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
|
||||
return newCompletionsHandler(rl, "stream", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
|
||||
var model string
|
||||
if _, isAllowed := allowedClientSpecifiedModels[requestParams.Model]; isAllowed {
|
||||
model = requestParams.Model
|
||||
} else {
|
||||
model = c.ChatModel
|
||||
}
|
||||
|
||||
return model
|
||||
}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
|
||||
eventWriter, err := streamhttp.NewWriter(w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
@ -31,7 +38,7 @@ func NewCompletionsStreamHandler(logger log.Logger, db database.DB) http.Handler
|
||||
_ = eventWriter.Event("done", map[string]any{})
|
||||
}()
|
||||
|
||||
err = cc.Stream(ctx, requestParams, func(event types.ChatCompletionEvent) error { return eventWriter.Event("completion", event) })
|
||||
err = cc.Stream(ctx, requestParams, func(event types.CompletionResponse) error { return eventWriter.Event("completion", event) })
|
||||
if err != nil {
|
||||
trace.Logger(ctx, logger).Error("error while streaming completions", log.Error(err))
|
||||
_ = eventWriter.Event("error", map[string]string{"error": err.Error()})
|
||||
|
||||
@ -15,36 +15,8 @@ type Message struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type CodeCompletionRequestParameters struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokensToSample int `json:"maxTokensToSample"`
|
||||
StopSequences []string `json:"stopSequences"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Tags map[string]string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
type CodeCompletionResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
Stop *string `json:"stop"`
|
||||
StopReason string `json:"stopReason"`
|
||||
Truncated bool `json:"truncated"`
|
||||
Exception *string `json:"exception"`
|
||||
LogID string `json:"logID"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequestParameters struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokensToSample int `json:"maxTokensToSample"`
|
||||
TopK int `json:"topK"`
|
||||
TopP float32 `json:"topP"`
|
||||
}
|
||||
|
||||
type ChatCompletionEvent struct {
|
||||
Completion string `json:"completion"`
|
||||
func (m Message) IsValidSpeaker() bool {
|
||||
return m.Speaker == HUMAN_MESSAGE_SPEAKER || m.Speaker == ASISSTANT_MESSAGE_SPEAKER
|
||||
}
|
||||
|
||||
func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (string, error) {
|
||||
@ -65,9 +37,28 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str
|
||||
return fmt.Sprintf("%s %s", prefix, m.Text), nil
|
||||
}
|
||||
|
||||
type SendCompletionEvent func(event ChatCompletionEvent) error
|
||||
type CompletionRequestParameters struct {
|
||||
// Prompt exists only for backwards compatibility. Do not use it in new
|
||||
// implementations. It will be removed once we are reasonably sure 99%
|
||||
// of VSCode extension installations are upgraded to a new Cody version.
|
||||
Prompt string `json:"prompt"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokensToSample int `json:"maxTokensToSample,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
TopP float32 `json:"topP,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stopReason"`
|
||||
}
|
||||
|
||||
type SendCompletionEvent func(event CompletionResponse) error
|
||||
|
||||
type CompletionsClient interface {
|
||||
Stream(ctx context.Context, requestParams ChatCompletionRequestParameters, sendEvent SendCompletionEvent) error
|
||||
Complete(ctx context.Context, requestParams CodeCompletionRequestParameters) (*CodeCompletionResponse, error)
|
||||
Stream(ctx context.Context, requestParams CompletionRequestParameters, sendEvent SendCompletionEvent) error
|
||||
Complete(ctx context.Context, requestParams CompletionRequestParameters) (*CompletionResponse, error)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user