LLM-enhanced keyword context (#52815)

* Adds a `fast` parameter to the completions endpoint and a `fastChatModel`
parameter to site config, intended for faster chat models that are
suited to simple generations (see the sketch after this list).
* Use the fast chat model to generate a local keyword search. This
replaces the old keyword search mechanism, which stemmed/lemmatized
every word in the user query.
* Use the fast chat model to generate a small set of file fragments to
search for. This is mainly useful for surfacing READMEs for questions
like "What does this project do?"
* Update the set of "files read" presented in the UI to include only
those files actually read into the context window. Previously, we were
showing all files returned by the context fetcher, but in reality, only
a subset of these would fit into the context window.
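
A minimal sketch of how the new `fast` flag is used from the client side. The message text here is invented; the callback shape and parameter merge follow the ChatClient diff below, and `chatClient` is assumed to wrap a SourcegraphCompletionsClient:

// Ask the server to route this request to the fast chat model.
chatClient.chat(
    [{ speaker: 'human', text: 'Write 3-5 keywords for: where is auth handled?' }],
    {
        onChange: text => console.log('partial response:', text),
        onComplete: () => console.log('done'),
        onError: (message, statusCode) => console.error(statusCode, message),
    },
    // Merged over the default chat parameters; `fast: true` asks the endpoint
    // to use the model named by the new `fastChatModel` site-config param.
    { temperature: 0, fast: true }
)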
Beyang Liu 2023-06-12 11:56:08 -07:00 committed by GitHub
parent e541d08400
commit dd528f30fe
60 changed files with 849 additions and 445 deletions

View File

@ -111,7 +111,8 @@ async function startCLI() {
}
}
const finalPrompt = await transcript.toPrompt(getPreamble(codebase))
const { prompt: finalPrompt, contextFiles } = await transcript.getPromptForLastInteraction(getPreamble(codebase))
transcript.setUsedContextFilesForLastInteraction(contextFiles)
let text = ''
streamCompletions(completionsClient, finalPrompt, {

View File

@ -1,6 +1,6 @@
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/keyword-context'
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/local-context'
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
import { isError } from '@sourcegraph/cody-shared/src/utils'
@ -26,7 +26,8 @@ export async function createCodebaseContext(
{ useContext: contextType, serverEndpoint },
codebase,
embeddingsSearch,
new LocalKeywordContextFetcherMock()
new LocalKeywordContextFetcherMock(),
null
)
return codebaseContext

View File

@ -45,7 +45,8 @@ export async function interactionFromMessage(
new Interaction(
{ speaker: 'human', text, displayText: text },
{ speaker: 'assistant', text: '', displayText: '' },
contextMessages
contextMessages,
[]
)
)
}

View File

@ -58,6 +58,7 @@ ts_project(
"src/chat/viewHelpers.ts",
"src/codebase-context/index.ts",
"src/codebase-context/messages.ts",
"src/codebase-context/rerank.ts",
"src/configuration.ts",
"src/editor/index.ts",
"src/editor/withPreselectedOptions.ts",
@ -68,7 +69,7 @@ ts_project(
"src/hallucinations-detector/index.ts",
"src/intent-detector/client.ts",
"src/intent-detector/index.ts",
"src/keyword-context/index.ts",
"src/local-context/index.ts",
"src/prompt/constants.ts",
"src/prompt/prompt-mixin.ts",
"src/prompt/templates.ts",
@ -93,6 +94,8 @@ ts_project(
deps = [
":node_modules/@sourcegraph/common",
":node_modules/@sourcegraph/http-client",
":node_modules/@types/xml2js",
":node_modules/xml2js",
"//:node_modules/@microsoft/fetch-event-source",
"//:node_modules/@types/isomorphic-fetch",
"//:node_modules/@types/marked",

View File

@ -19,6 +19,10 @@
},
"dependencies": {
"@sourcegraph/common": "workspace:*",
"@sourcegraph/http-client": "workspace:*"
"@sourcegraph/http-client": "workspace:*",
"xml2js": "^0.6.0"
},
"devDependencies": {
"@types/xml2js": "^0.4.11"
}
}

View File

@ -3,7 +3,9 @@ import { Message } from '../sourcegraph-api'
import type { SourcegraphCompletionsClient } from '../sourcegraph-api/completions/client'
import type { CompletionParameters, CompletionCallbacks } from '../sourcegraph-api/completions/types'
const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit<CompletionParameters, 'messages'> = {
type ChatParameters = Omit<CompletionParameters, 'messages'>
const DEFAULT_CHAT_COMPLETION_PARAMETERS: ChatParameters = {
temperature: 0.2,
maxTokensToSample: SOLUTION_TOKEN_LENGTH,
topK: -1,
@ -13,10 +15,17 @@ const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit<CompletionParameters, 'messages'>
export class ChatClient {
constructor(private completions: SourcegraphCompletionsClient) {}
public chat(messages: Message[], cb: CompletionCallbacks): () => void {
public chat(messages: Message[], cb: CompletionCallbacks, params?: Partial<ChatParameters>): () => void {
const isLastMessageFromHuman = messages.length > 0 && messages[messages.length - 1].speaker === 'human'
const augmentedMessages = isLastMessageFromHuman ? messages.concat([{ speaker: 'assistant' }]) : messages
return this.completions.stream({ messages: augmentedMessages, ...DEFAULT_CHAT_COMPLETION_PARAMETERS }, cb)
return this.completions.stream(
{
...DEFAULT_CHAT_COMPLETION_PARAMETERS,
...params,
messages: augmentedMessages,
},
cb
)
}
}
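
Note the spread order above: caller-supplied `params` override the defaults, but `messages` is applied last so it can never be clobbered. Spelled out with illustrative values (assuming the defaults and `augmentedMessages` from this file):

const request = {
    ...DEFAULT_CHAT_COMPLETION_PARAMETERS, // temperature: 0.2, topK: -1, ...
    ...{ temperature: 0, fast: true },     // caller override, e.g. from the reranker
    messages: augmentedMessages,           // always last: callers cannot replace messages
}
// request.temperature === 0, request.fast === true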

View File

@ -69,7 +69,7 @@ export async function createClient({
const embeddingsSearch = repoId ? new SourcegraphEmbeddingsSearchClient(graphqlClient, repoId, true) : null
const codebaseContext = new CodebaseContext(config, config.codebase, embeddingsSearch, null)
const codebaseContext = new CodebaseContext(config, config.codebase, embeddingsSearch, null, null)
const intentDetector = new SourcegraphIntentDetectorClient(graphqlClient)
@ -116,10 +116,11 @@ export async function createClient({
sendTranscript()
const prompt = await transcript.toPrompt(getPreamble(config.codebase))
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(getPreamble(config.codebase))
transcript.setUsedContextFilesForLastInteraction(contextFiles)
const responsePrefix = interaction.getAssistantMessage().prefix ?? ''
let rawText = ''
chatClient.chat(prompt, {
onChange(_rawText) {
rawText = _rawText

View File

@ -31,7 +31,8 @@ export class ChatQuestion implements Recipe {
context.intentDetector,
context.codebaseContext,
context.editor.getActiveTextEditorSelection() || null
)
),
[]
)
)
}

View File

@ -46,7 +46,8 @@ export class ContextSearch implements Recipe {
text: '',
displayText: await this.displaySearchResults(truncatedText, context.codebaseContext, wsRootPath),
},
new Promise(resolve => resolve([]))
new Promise(resolve => resolve([])),
[]
)
}

View File

@ -32,7 +32,8 @@ export class ExplainCodeDetailed implements Recipe {
truncatedFollowingText,
selection,
context.codebaseContext
)
),
[]
)
}
}

View File

@ -32,7 +32,8 @@ export class ExplainCodeHighLevel implements Recipe {
truncatedFollowingText,
selection,
context.codebaseContext
)
),
[]
)
}
}

View File

@ -103,7 +103,8 @@ export class FileTouch implements Recipe {
speaker: 'assistant',
prefix: 'Working on it! I will notify you when the file is ready.\n',
},
this.getContextMessages(selection, currentDir)
this.getContextMessages(selection, currentDir),
[]
)
)
}

View File

@ -36,7 +36,8 @@ If you have no ideas because the code looks fine, feel free to say that it alrea
prefix: assistantResponsePrefix,
text: assistantResponsePrefix,
},
new Promise(resolve => resolve([]))
new Promise(resolve => resolve([])),
[]
)
}
}

View File

@ -67,7 +67,8 @@ export class Fixup implements Recipe {
speaker: 'assistant',
prefix: 'Check your document for updates from Cody.\n',
},
this.getContextMessages(selection.selectedText, context.codebaseContext)
this.getContextMessages(selection.selectedText, context.codebaseContext),
[]
)
)
}

View File

@ -60,7 +60,8 @@ export class GenerateDocstring implements Recipe {
truncatedFollowingText,
selection,
context.codebaseContext
)
),
[]
)
}
}

View File

@ -52,7 +52,8 @@ export class PrDescription implements Recipe {
prefix: emptyGitCommitMessage,
text: emptyGitCommitMessage,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
@ -71,7 +72,8 @@ export class PrDescription implements Recipe {
prefix: assistantResponsePrefix,
text: assistantResponsePrefix,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
}

View File

@ -71,7 +71,8 @@ export class ReleaseNotes implements Recipe {
prefix: emptyGitLogMessage,
text: emptyGitLogMessage,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
@ -91,7 +92,8 @@ export class ReleaseNotes implements Recipe {
prefix: assistantResponsePrefix,
text: assistantResponsePrefix,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
}

View File

@ -44,7 +44,8 @@ export class GenerateTest implements Recipe {
truncatedFollowingText,
selection,
context.codebaseContext
)
),
[]
)
}
}

View File

@ -65,7 +65,8 @@ export class GitHistory implements Recipe {
prefix: emptyGitLogMessage,
text: emptyGitLogMessage,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
@ -84,7 +85,8 @@ export class GitHistory implements Recipe {
prefix: assistantResponsePrefix,
text: assistantResponsePrefix,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
}

View File

@ -44,7 +44,8 @@ export class ImproveVariableNames implements Recipe {
truncatedFollowingText,
selection,
context.codebaseContext
)
),
[]
)
}
}

View File

@ -56,7 +56,8 @@ export class InlineAssist implements Recipe {
displayText,
},
{ speaker: 'assistant' },
this.getContextMessages(truncatedText, context.codebaseContext, selection, context.editor)
this.getContextMessages(truncatedText, context.codebaseContext, selection, context.editor),
[]
)
)
}

View File

@ -31,7 +31,8 @@ export class NextQuestions implements Recipe {
prefix: assistantResponsePrefix,
text: assistantResponsePrefix,
},
this.getContextMessages(promptMessage, context.editor, context.intentDetector, context.codebaseContext)
this.getContextMessages(promptMessage, context.editor, context.intentDetector, context.codebaseContext),
[]
)
)
}

View File

@ -76,7 +76,8 @@ export class NonStop implements Recipe {
speaker: 'assistant',
prefix: 'Check your document for updates from Cody.',
},
this.getContextMessages(selection.selectedText, context.codebaseContext)
this.getContextMessages(selection.selectedText, context.codebaseContext),
[]
)
)
}

View File

@ -51,7 +51,8 @@ However if no optimization is possible; just say the code is already optimized.
truncatedFollowingText,
selection,
context.codebaseContext
)
),
[]
)
}
}

View File

@ -39,7 +39,8 @@ export class TranslateToLanguage implements Recipe {
prefix: assistantResponsePrefix,
text: assistantResponsePrefix,
},
Promise.resolve([])
Promise.resolve([]),
[]
)
}
}

View File

@ -1,5 +1,6 @@
import { OldContextMessage } from '../../codebase-context/messages'
import { ContextFile, ContextMessage, OldContextMessage } from '../../codebase-context/messages'
import { CHARS_PER_TOKEN, MAX_AVAILABLE_PROMPT_LENGTH } from '../../prompt/constants'
import { PromptMixin } from '../../prompt/prompt-mixin'
import { Message } from '../../sourcegraph-api'
import { Interaction, InteractionJSON } from './interaction'
@ -20,18 +21,19 @@ export interface TranscriptJSON {
}
/**
* A transcript of a conversation between a human and an assistant.
* The "model" class that tracks the call and response of the Cody chat box.
* Any "controller" logic belongs outside of this class.
*/
export class Transcript {
public static fromJSON(json: TranscriptJSON): Transcript {
return new Transcript(
json.interactions.map(
({ humanMessage, assistantMessage, context, timestamp }) =>
({ humanMessage, assistantMessage, fullContext, usedContextFiles, timestamp }) =>
new Interaction(
humanMessage,
assistantMessage,
Promise.resolve(
context.map(message => {
fullContext.map(message => {
if (message.file) {
return message
}
@ -44,6 +46,7 @@ export class Transcript {
return message
})
),
usedContextFiles,
timestamp || new Date().toISOString()
)
),
@ -144,20 +147,53 @@ export class Transcript {
return -1
}
public async toPrompt(preamble: Message[] = []): Promise<Message[]> {
public async getPromptForLastInteraction(
preamble: Message[] = []
): Promise<{ prompt: Message[]; contextFiles: ContextFile[] }> {
if (this.interactions.length === 0) {
return { prompt: [], contextFiles: [] }
}
const lastInteractionWithContextIndex = await this.getLastInteractionWithContextIndex()
const messages: Message[] = []
for (let index = 0; index < this.interactions.length; index++) {
// Include context messages for the last interaction that has a non-empty context.
const interactionMessages = await this.interactions[index].toPrompt(
index === lastInteractionWithContextIndex
)
messages.push(...interactionMessages)
const interaction = this.interactions[index]
const humanMessage = PromptMixin.mixInto(interaction.getHumanMessage())
const assistantMessage = interaction.getAssistantMessage()
const contextMessages = await interaction.getFullContext()
if (index === lastInteractionWithContextIndex) {
messages.push(...contextMessages, humanMessage, assistantMessage)
} else {
messages.push(humanMessage, assistantMessage)
}
}
const preambleTokensUsage = preamble.reduce((acc, message) => acc + estimateTokensUsage(message), 0)
const truncatedMessages = truncatePrompt(messages, MAX_AVAILABLE_PROMPT_LENGTH - preambleTokensUsage)
return [...preamble, ...truncatedMessages]
let truncatedMessages = truncatePrompt(messages, MAX_AVAILABLE_PROMPT_LENGTH - preambleTokensUsage)
// Return what context fits in the window
const contextFiles: ContextFile[] = []
for (const msg of truncatedMessages) {
const contextFile = (msg as ContextMessage).file
if (contextFile) {
contextFiles.push(contextFile)
}
}
// Filter out extraneous fields from ContextMessage instances
truncatedMessages = truncatedMessages.map(({ speaker, text }) => ({ speaker, text }))
return {
prompt: [...preamble, ...truncatedMessages],
contextFiles,
}
}
public setUsedContextFilesForLastInteraction(contextFiles: ContextFile[]): void {
if (this.interactions.length === 0) {
throw new Error('Cannot set context files for empty transcript')
}
this.interactions[this.interactions.length - 1].setUsedContext(contextFiles)
}
public toChat(): ChatMessage[] {

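Every call site in this commit follows the same two-step pattern with the new API:

const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(preamble)
// Record only the context files that survived truncation, so the UI's
// "files read" list matches what actually entered the context window.
transcript.setUsedContextFilesForLastInteraction(contextFiles)
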
View File

@ -1,89 +1,59 @@
import { ContextMessage, ContextFile } from '../../codebase-context/messages'
import { PromptMixin } from '../../prompt/prompt-mixin'
import { Message } from '../../sourcegraph-api'
import { ChatMessage, InteractionMessage } from './messages'
export interface InteractionJSON {
humanMessage: InteractionMessage
assistantMessage: InteractionMessage
context: ContextMessage[]
fullContext: ContextMessage[]
usedContextFiles: ContextFile[]
timestamp: string
}
export class Interaction {
private readonly humanMessage: InteractionMessage
private assistantMessage: InteractionMessage
private cachedContextFiles: ContextFile[] = []
public readonly timestamp: string
private readonly context: Promise<ContextMessage[]>
constructor(
humanMessage: InteractionMessage,
assistantMessage: InteractionMessage,
context: Promise<ContextMessage[]>,
timestamp: string = new Date().toISOString()
) {
this.humanMessage = humanMessage
this.assistantMessage = assistantMessage
this.timestamp = timestamp
// This is some hacky behavior: returns a promise that resolves to the same array that was passed,
// but also caches the context file names in memory as a side effect.
this.context = context.then(messages => {
const contextFilesMap = messages.reduce((map, { file }) => {
if (!file?.fileName) {
return map
}
map[`${file.repoName || 'repo'}@${file?.revision || 'HEAD'}/${file.fileName}`] = file
return map
}, {} as { [key: string]: ContextFile })
// Cache the context files so we don't have to block the UI when calling `toChat` by waiting for the context to resolve.
this.cachedContextFiles = [
...Object.keys(contextFilesMap)
.sort((a, b) => a.localeCompare(b))
.map((key: string) => contextFilesMap[key]),
]
return messages
})
}
private readonly humanMessage: InteractionMessage,
private assistantMessage: InteractionMessage,
private fullContext: Promise<ContextMessage[]>,
private usedContextFiles: ContextFile[],
public readonly timestamp: string = new Date().toISOString()
) {}
public getAssistantMessage(): InteractionMessage {
return this.assistantMessage
return { ...this.assistantMessage }
}
public setAssistantMessage(assistantMessage: InteractionMessage): void {
this.assistantMessage = assistantMessage
}
public getHumanMessage(): InteractionMessage {
return { ...this.humanMessage }
}
public async getFullContext(): Promise<ContextMessage[]> {
const msgs = await this.fullContext
return msgs.map(msg => ({ ...msg }))
}
public async hasContext(): Promise<boolean> {
const contextMessages = await this.context
const contextMessages = await this.fullContext
return contextMessages.length > 0
}
public async toPrompt(includeContext: boolean): Promise<Message[]> {
const messages: (ContextMessage | InteractionMessage)[] = [
PromptMixin.mixInto(this.humanMessage),
this.assistantMessage,
]
if (includeContext) {
messages.unshift(...(await this.context))
}
return messages.map(message => ({ speaker: message.speaker, text: message.text }))
public setUsedContext(usedContextFiles: ContextFile[]): void {
this.usedContextFiles = usedContextFiles
}
/**
* Converts the interaction to a chat message pair: one message from a human, one from an assistant.
*/
public toChat(): ChatMessage[] {
return [this.humanMessage, { ...this.assistantMessage, contextFiles: this.cachedContextFiles }]
return [this.humanMessage, { ...this.assistantMessage, contextFiles: this.usedContextFiles }]
}
public async toChatPromise(): Promise<ChatMessage[]> {
await this.context
await this.fullContext
return this.toChat()
}
@ -91,7 +61,8 @@ export class Interaction {
return {
humanMessage: this.humanMessage,
assistantMessage: this.assistantMessage,
context: await this.context,
fullContext: await this.fullContext,
usedContextFiles: this.usedContextFiles,
timestamp: this.timestamp,
}
}

View File

@ -38,7 +38,7 @@ async function generateLongTranscript(): Promise<{ transcript: Transcript; token
describe('Transcript', () => {
it('generates an empty prompt with no interactions', async () => {
const transcript = new Transcript()
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
assert.deepStrictEqual(prompt, [])
})
@ -51,7 +51,7 @@ describe('Transcript', () => {
const transcript = new Transcript()
transcript.addInteraction(interaction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'how do access tokens work in sourcegraph' },
{ speaker: 'assistant', text: undefined },
@ -78,7 +78,8 @@ describe('Transcript', () => {
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
'dummy-codebase',
embeddings,
defaultKeywordContextFetcher
defaultKeywordContextFetcher,
null
),
})
)
@ -86,7 +87,7 @@ describe('Transcript', () => {
const transcript = new Transcript()
transcript.addInteraction(interaction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' },
{ speaker: 'assistant', text: 'Ok.' },
@ -114,7 +115,8 @@ describe('Transcript', () => {
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
'dummy-codebase',
embeddings,
defaultKeywordContextFetcher
defaultKeywordContextFetcher,
null
),
firstInteraction: true,
})
@ -123,7 +125,7 @@ describe('Transcript', () => {
const transcript = new Transcript()
transcript.addInteraction(interaction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' },
{ speaker: 'assistant', text: 'Ok.' },
@ -148,7 +150,8 @@ describe('Transcript', () => {
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
'dummy-codebase',
embeddings,
defaultKeywordContextFetcher
defaultKeywordContextFetcher,
null
)
const chatQuestionRecipe = new ChatQuestion(() => {})
@ -175,7 +178,7 @@ describe('Transcript', () => {
)
transcript.addInteraction(secondInteraction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'how do access tokens work in sourcegraph' },
{ speaker: 'assistant', text: assistantResponse },
@ -195,7 +198,7 @@ describe('Transcript', () => {
const numExpectedInteractions = Math.floor(MAX_AVAILABLE_PROMPT_LENGTH / tokensPerInteraction)
const numExpectedMessages = numExpectedInteractions * 2 // Each interaction has two messages.
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
assert.deepStrictEqual(prompt.length, numExpectedMessages)
})
@ -212,7 +215,7 @@ describe('Transcript', () => {
const numExpectedInteractions = Math.floor(MAX_AVAILABLE_PROMPT_LENGTH / tokensPerInteraction)
const numExpectedMessages = numExpectedInteractions * 2 // Each interaction has two messages.
const prompt = await transcript.toPrompt(preamble)
const { prompt } = await transcript.getPromptForLastInteraction(preamble)
assert.deepStrictEqual(prompt.length, numExpectedMessages)
assert.deepStrictEqual(preamble, prompt.slice(0, 4))
})
@ -233,7 +236,8 @@ describe('Transcript', () => {
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
'dummy-codebase',
embeddings,
defaultKeywordContextFetcher
defaultKeywordContextFetcher,
null
)
const chatQuestionRecipe = new ChatQuestion(() => {})
@ -249,7 +253,7 @@ describe('Transcript', () => {
)
transcript.addInteraction(interaction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' },
{ speaker: 'assistant', text: 'Ok.' },
@ -288,7 +292,7 @@ describe('Transcript', () => {
)
transcript.addInteraction(interaction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'how do access tokens work in sourcegraph' },
{ speaker: 'assistant', text: undefined },
@ -309,7 +313,8 @@ describe('Transcript', () => {
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
'dummy-codebase',
embeddings,
defaultKeywordContextFetcher
defaultKeywordContextFetcher,
null
)
const chatQuestionRecipe = new ChatQuestion(() => {})
@ -344,7 +349,7 @@ describe('Transcript', () => {
)
transcript.addInteraction(thirdInteraction)
const prompt = await transcript.toPrompt()
const { prompt } = await transcript.getPromptForLastInteraction()
const expectedPrompt = [
{ speaker: 'human', text: 'how do batch changes work in sourcegraph' },
{ speaker: 'assistant', text: 'Smartly.' },

View File

@ -222,7 +222,14 @@ export const useClient = ({
}
const unifiedContextFetcherClient = new UnifiedContextFetcherClient(graphqlClient, repoIds)
const codebaseContext = new CodebaseContext(config, undefined, null, null, unifiedContextFetcherClient)
const codebaseContext = new CodebaseContext(
config,
undefined,
null,
null,
null,
unifiedContextFetcherClient
)
const { humanChatInput = '', prefilledOptions } = options ?? {}
// TODO(naman): save scope with each interaction
@ -242,10 +249,13 @@ export const useClient = ({
setIsMessageInProgressState(true)
onEvent?.('submit')
const prompt = await transcript.toPrompt(getMultiRepoPreamble(repoNames))
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(
getMultiRepoPreamble(repoNames)
)
transcript.setUsedContextFilesForLastInteraction(contextFiles)
const responsePrefix = interaction.getAssistantMessage().prefix ?? ''
let rawText = ''
return new Promise(resolve => {
chatClient.chat(prompt, {
onChange(_rawText) {

View File

@ -1,6 +1,6 @@
import { Configuration } from '../configuration'
import { EmbeddingsSearch } from '../embeddings'
import { KeywordContextFetcher, KeywordContextFetcherResult } from '../keyword-context'
import { FilenameContextFetcher, KeywordContextFetcher, ContextResult } from '../local-context'
import { isMarkdownFile, populateCodeContextTemplate, populateMarkdownContextTemplate } from '../prompt/templates'
import { Message } from '../sourcegraph-api'
import { EmbeddingsSearchResult } from '../sourcegraph-api/graphql/client'
@ -21,7 +21,9 @@ export class CodebaseContext {
private codebase: string | undefined,
private embeddings: EmbeddingsSearch | null,
private keywords: KeywordContextFetcher | null,
private unifiedContextFetcher?: UnifiedContextFetcher | null
private filenames: FilenameContextFetcher | null,
private unifiedContextFetcher?: UnifiedContextFetcher | null,
private rerank?: (query: string, results: ContextResult[]) => Promise<ContextResult[]>
) {}
public getCodebase(): string | undefined {
@ -32,18 +34,28 @@ export class CodebaseContext {
this.config = newConfig
}
private mergeContextResults(keywordResults: ContextResult[], filenameResults: ContextResult[]): ContextResult[] {
// Just take the single most relevant filename suggestion for now. Otherwise, because our reranking relies solely
// on filename, the filename results would dominate the keyword results.
return filenameResults.slice(-1).concat(keywordResults)
}
/**
* Returns a list of context messages for a given query, sorted in *reverse* order of importance (that is,
* the most important context message appears *last*)
*/
public async getContextMessages(query: string, options: ContextSearchOptions): Promise<ContextMessage[]> {
switch (this.config.useContext) {
case 'unified':
return this.getUnifiedContextMessages(query, options)
case 'keyword':
return this.getKeywordContextMessages(query, options)
return this.getLocalContextMessages(query, options)
case 'none':
return []
default:
return this.embeddings
? this.getEmbeddingsContextMessages(query, options)
: this.getKeywordContextMessages(query, options)
: this.getLocalContextMessages(query, options)
}
}
@ -58,7 +70,7 @@ export class CodebaseContext {
public async getSearchResults(
query: string,
options: ContextSearchOptions
): Promise<{ results: KeywordContextFetcherResult[] | EmbeddingsSearchResult[]; endpoint: string }> {
): Promise<{ results: ContextResult[] | EmbeddingsSearchResult[]; endpoint: string }> {
if (this.embeddings && this.config.useContext !== 'keyword') {
return {
results: await this.getEmbeddingSearchResults(query, options),
@ -141,22 +153,29 @@ export class CodebaseContext {
})
}
private async getKeywordContextMessages(query: string, options: ContextSearchOptions): Promise<ContextMessage[]> {
const results = await this.getKeywordSearchResults(query, options)
return results.flatMap(({ content, fileName, repoName, revision }) => {
const messageText = populateCodeContextTemplate(content, fileName)
return getContextMessageWithResponse(messageText, { fileName, repoName, revision })
})
private async getLocalContextMessages(query: string, options: ContextSearchOptions): Promise<ContextMessage[]> {
const keywordResults = this.getKeywordSearchResults(query, options)
const filenameResults = this.getFilenameSearchResults(query, options)
const combinedResults = this.mergeContextResults(await keywordResults, await filenameResults)
const rerankedResults = await (this.rerank ? this.rerank(query, combinedResults) : combinedResults)
const messages = resultsToMessages(rerankedResults)
return messages
}
private async getKeywordSearchResults(
query: string,
options: ContextSearchOptions
): Promise<KeywordContextFetcherResult[]> {
private async getKeywordSearchResults(query: string, options: ContextSearchOptions): Promise<ContextResult[]> {
if (!this.keywords) {
return []
}
return this.keywords.getContext(query, options.numCodeResults + options.numTextResults)
const results = await this.keywords.getContext(query, options.numCodeResults + options.numTextResults)
return results
}
private async getFilenameSearchResults(query: string, options: ContextSearchOptions): Promise<ContextResult[]> {
if (!this.filenames) {
return []
}
const results = await this.filenames.getContext(query, options.numCodeResults + options.numTextResults)
return results
}
}
@ -201,3 +220,10 @@ function mergeConsecutiveResults(results: EmbeddingsSearchResult[]): string[] {
return mergedResults
}
function resultsToMessages(results: ContextResult[]): ContextMessage[] {
return results.flatMap(({ content, fileName, repoName, revision }) => {
const messageText = populateCodeContextTemplate(content, fileName)
return getContextMessageWithResponse(messageText, { fileName, repoName, revision })
})
}
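
Wiring it together, as a sketch with illustrative arguments (the real construction appears in external-services.ts and the chat view provider later in this diff):

const codebaseContext = new CodebaseContext(
    config,   // e.g. { useContext: 'keyword', serverEndpoint: '...' }
    codebase,
    null,     // no embeddings: fall back to local context
    new LocalKeywordContextFetcher(rgPath, editor, chatClient),
    new FilenameContextFetcher(rgPath, editor, chatClient),
    undefined, // no unified context fetcher
    getRerankWithLog(chatClient)
)
// Keyword and filename results are merged, LLM-reranked, and converted to
// messages in *reverse* order of importance (most important last).
const messages = await codebaseContext.getContextMessages(query, { numCodeResults: 12, numTextResults: 3 })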

View File

@ -0,0 +1,93 @@
import { parseStringPromise } from 'xml2js'
import { ChatClient } from '../chat/chat'
import { ContextResult } from '../local-context'
export interface Reranker {
rerank(userQuery: string, results: ContextResult[]): Promise<ContextResult[]>
}
export class MockReranker implements Reranker {
constructor(private rerank_: (userQuery: string, results: ContextResult[]) => Promise<ContextResult[]>) {}
public rerank(userQuery: string, results: ContextResult[]): Promise<ContextResult[]> {
return this.rerank_(userQuery, results)
}
}
/**
* A reranker that uses an LLM to boost high-relevance results.
*/
export class LLMReranker implements Reranker {
constructor(private chatClient: ChatClient) {}
public async rerank(userQuery: string, results: ContextResult[]): Promise<ContextResult[]> {
// Reverse the results so the most important appears first
results = [...results].reverse()
let out = await new Promise<string>((resolve, reject) => {
let responseText = ''
this.chatClient.chat(
[
{
speaker: 'human',
text: `I am a professional computer programmer and need help deciding which of these files to read first to answer my question. My question is <userQuestion>${userQuery}</userQuestion>. Select the files from the following list that I should read to answer my question, ranked by most relevant first. Format the result as XML, like this: <list><item><filename>filename 1</filename><explanation>this is why I chose this item</explanation></item><item><filename>filename 2</filename><explanation>why I chose this item</explanation></item></list>\n${results
.map(r => r.fileName)
.join('\n')}`,
},
],
{
onChange: (text: string) => {
responseText = text
},
onComplete: () => {
resolve(responseText)
},
onError: (message: string, statusCode?: number) => {
reject(new Error(`Status code ${statusCode}: ${message}`))
},
},
{
temperature: 0,
fast: true,
}
)
})
if (out.indexOf('<list>') > 0) {
out = out.slice(out.indexOf('<list>'))
}
if (out.indexOf('</list>') !== out.length - '</list>'.length) {
out = out.slice(0, out.indexOf('</list>') + '</list>'.length)
}
const boostedFilenames = await parseXml(out)
const resultsMap = Object.fromEntries(results.map(r => [r.fileName, r]))
const boostedNames = new Set<string>()
const rerankedResults = []
for (const boostedFilename of boostedFilenames) {
const boostedResult = resultsMap[boostedFilename]
if (!boostedResult) {
continue
}
rerankedResults.push(boostedResult)
boostedNames.add(boostedFilename)
}
for (const result of results) {
if (!boostedNames.has(result.fileName)) {
rerankedResults.push(result)
}
}
rerankedResults.reverse()
return rerankedResults
}
}
async function parseXml(xml: string): Promise<string[]> {
const result = await parseStringPromise(xml)
const items = result.list.item
const files: { filename: string; explanation: string }[] = items.map((item: any) => ({
filename: item.filename[0],
explanation: item.explanation[0],
}))
return files.map(f => f.filename)
}
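
Usage sketch (file names and contents invented; only `fileName` is shown to the model):

const reranker = new LLMReranker(chatClient)
const reranked = await reranker.rerank('How is authentication implemented?', [
    { fileName: 'docs/README.md', content: '...' },
    { fileName: 'web/src/auth/index.ts', content: '...' },
])
// The output keeps the reverse-importance convention: files the LLM
// selected are moved toward the end of the list.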

View File

@ -1,11 +0,0 @@
export interface KeywordContextFetcherResult {
repoName?: string
revision?: string
fileName: string
content: string
}
export interface KeywordContextFetcher {
getContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]>
getSearchContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]>
}

View File

@ -0,0 +1,15 @@
export interface ContextResult {
repoName?: string
revision?: string
fileName: string
content: string
}
export interface KeywordContextFetcher {
getContext(query: string, numResults: number): Promise<ContextResult[]>
getSearchContext(query: string, numResults: number): Promise<ContextResult[]>
}
export interface FilenameContextFetcher {
getContext(query: string, numResults: number): Promise<ContextResult[]>
}
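
For reference, a `ContextResult` value per the interface above (values invented):

const result: ContextResult = {
    repoName: 'github.com/sourcegraph/cody', // optional
    revision: 'HEAD',                        // optional
    fileName: 'lib/shared/src/local-context/index.ts',
    content: 'export interface ContextResult { ... }',
}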

View File

@ -24,6 +24,7 @@ export interface CompletionResponse {
}
export interface CompletionParameters {
fast?: boolean
messages: Message[]
maxTokensToSample: number
temperature?: number

View File

@ -4,7 +4,7 @@ import { CodebaseContext } from '../codebase-context'
import { ActiveTextEditor, ActiveTextEditorSelection, ActiveTextEditorVisibleContent, Editor } from '../editor'
import { EmbeddingsSearch } from '../embeddings'
import { IntentDetector } from '../intent-detector'
import { KeywordContextFetcher, KeywordContextFetcherResult } from '../keyword-context'
import { KeywordContextFetcher, ContextResult } from '../local-context'
import { EmbeddingsSearchResults } from '../sourcegraph-api/graphql'
export class MockEmbeddingsClient implements EmbeddingsSearch {
@ -37,11 +37,11 @@ export class MockIntentDetector implements IntentDetector {
export class MockKeywordContextFetcher implements KeywordContextFetcher {
constructor(private mocks: Partial<KeywordContextFetcher> = {}) {}
public getContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
public getContext(query: string, numResults: number): Promise<ContextResult[]> {
return this.mocks.getContext?.(query, numResults) ?? Promise.resolve([])
}
public getSearchContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
public getSearchContext(query: string, numResults: number): Promise<ContextResult[]> {
return this.mocks.getSearchContext?.(query, numResults) ?? Promise.resolve([])
}
}
@ -111,7 +111,8 @@ export function newRecipeContext(args?: Partial<RecipeContext>): RecipeContext {
{ useContext: 'none', serverEndpoint: 'https://example.com' },
'dummy-codebase',
defaultEmbeddingsClient,
defaultKeywordContextFetcher
defaultKeywordContextFetcher,
null
),
responseMultiplexer: args.responseMultiplexer || new BotResponseMultiplexer(),
firstInteraction: args.firstInteraction ?? false,

View File

@ -44,7 +44,8 @@ export async function handleHumanMessage(event: AppMentionEvent, appContext: App
const response = await slackHelpers.postMessage(IN_PROGRESS_MESSAGE, channel, thread_ts)
// Generate a prompt and start completion streaming
const prompt = await transcript.toPrompt(SLACK_PREAMBLE)
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(SLACK_PREAMBLE)
transcript.setUsedContextFilesForLastInteraction(contextFiles)
console.log('PROMPT', prompt)
startCompletionStreaming(prompt, channel, transcript, response?.ts)
}

View File

@ -2,7 +2,7 @@ import { memoize } from 'lodash'
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/keyword-context'
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/local-context'
import { isError } from '@sourcegraph/cody-shared/src/utils'
import { sourcegraphClient } from './sourcegraph-client'
@ -36,7 +36,8 @@ export async function createCodebaseContext(
{ useContext: contextType, serverEndpoint },
codebase,
embeddingsSearch,
new LocalKeywordContextFetcherMock()
new LocalKeywordContextFetcherMock(),
null
)
return codebaseContext

View File

@ -56,7 +56,7 @@ class SlackInteraction {
}
public getTranscriptInteraction() {
return new Interaction(this.humanMessage, this.assistantMessage, Promise.resolve(this.contextMessages))
return new Interaction(this.humanMessage, this.assistantMessage, Promise.resolve(this.contextMessages), [])
}
}

View File

@ -52,9 +52,11 @@ ts_project(
"src/extension.ts",
"src/extension-api.ts",
"src/external-services.ts",
"src/keyword-context/local-keyword-context-fetcher.ts",
"src/local-app-detector.ts",
"src/local-context/filename-context-fetcher.ts",
"src/local-context/local-keyword-context-fetcher.ts",
"src/log.ts",
"src/logged-rerank.ts",
"src/main.ts",
"src/non-stop/FixupCodeLenses.ts",
"src/non-stop/FixupContentStore.ts",
@ -111,6 +113,7 @@ ts_project(
"//:node_modules/@storybook/react", #keep
"//:node_modules/@types/classnames",
"//:node_modules/@types/jest", #keep
"//:node_modules/@types/lodash",
"//:node_modules/@types/lru-cache",
"//:node_modules/@types/node",
"//:node_modules/@types/react",
@ -121,6 +124,7 @@ ts_project(
"//:node_modules/@vscode",
"//:node_modules/@vscode/webview-ui-toolkit",
"//:node_modules/classnames",
"//:node_modules/lodash",
"//:node_modules/react",
"//:node_modules/react-dom",
"//:node_modules/stream-json",
@ -140,7 +144,6 @@ ts_project(
"src/completions/context.test.ts",
"src/completions/provider.test.ts",
"src/configuration.test.ts",
"src/keyword-context/local-keyword-context-fetcher.test.ts",
"src/non-stop/diff.test.ts",
"src/non-stop/tracked-range.test.ts",
"src/non-stop/utils.test.ts",

View File

@ -1,3 +1,4 @@
import { spawnSync } from 'child_process'
import path from 'path'
import * as vscode from 'vscode'
@ -11,17 +12,24 @@ import { ChatMessage, ChatHistory } from '@sourcegraph/cody-shared/src/chat/tran
import { reformatBotMessage } from '@sourcegraph/cody-shared/src/chat/viewHelpers'
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
import { ConfigurationWithAccessToken } from '@sourcegraph/cody-shared/src/configuration'
import { Editor } from '@sourcegraph/cody-shared/src/editor'
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
import { Guardrails, annotateAttribution } from '@sourcegraph/cody-shared/src/guardrails'
import { highlightTokens } from '@sourcegraph/cody-shared/src/hallucinations-detector'
import { IntentDetector } from '@sourcegraph/cody-shared/src/intent-detector'
import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
import { isError } from '@sourcegraph/cody-shared/src/utils'
import { View } from '../../webviews/NavBar'
import { getFullConfig, updateConfiguration } from '../configuration'
import { VSCodeEditor } from '../editor/vscode-editor'
import { logEvent } from '../event-logger'
import { LocalAppDetector } from '../local-app-detector'
import { FilenameContextFetcher } from '../local-context/filename-context-fetcher'
import { LocalKeywordContextFetcher } from '../local-context/local-keyword-context-fetcher'
import { debug } from '../log'
import { getRerankWithLog } from '../logged-rerank'
import { FixupTask } from '../non-stop/FixupTask'
import { LocalStorage } from '../services/LocalStorageProvider'
import { CODY_ACCESS_TOKEN_SECRET, SecretStorage } from '../services/SecretStorageProvider'
@ -38,7 +46,7 @@ import {
isLoggedIn,
} from './protocol'
import { getRecipe } from './recipes'
import { getAuthStatus, getCodebaseContext } from './utils'
import { convertGitCloneURLToCodebaseName, getAuthStatus } from './utils'
export type Config = Pick<
ConfigurationWithAccessToken,
@ -107,7 +115,7 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
}),
vscode.workspace.onDidChangeConfiguration(async () => {
this.config = await getFullConfig(this.secretStorage)
const newCodebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor)
const newCodebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor, chat)
if (newCodebaseContext) {
this.codebaseContext = newCodebaseContext
await this.setAnonymousUserID()
@ -376,7 +384,7 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
}
this.currentWorkspaceRoot = workspaceRoot
const codebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor)
const codebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor, this.chat)
if (!codebaseContext) {
return
}
@ -430,12 +438,14 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
default: {
this.sendTranscript()
const prompt = await this.transcript.toPrompt(getPreamble(this.codebaseContext.getCodebase()))
const { prompt, contextFiles } = await this.transcript.getPromptForLastInteraction(
getPreamble(this.codebaseContext.getCodebase())
)
this.transcript.setUsedContextFilesForLastInteraction(contextFiles)
this.sendPrompt(prompt, interaction.getAssistantMessage().prefix ?? '')
await this.saveTranscriptToChatHistory()
}
}
logEvent(`CodyVSCodeExtension:recipe:${recipe.id}:executed`)
}
@ -460,7 +470,10 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
}
transcript.addInteraction(interaction)
const prompt = await transcript.toPrompt(getPreamble(this.codebaseContext.getCodebase()))
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(
getPreamble(this.codebaseContext.getCodebase())
)
transcript.setUsedContextFilesForLastInteraction(contextFiles)
logEvent(`CodyVSCodeExtension:recipe:${recipe.id}:executed`)
@ -818,3 +831,49 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
this.disposables = []
}
}
/**
* Gets codebase context for the current workspace.
*
* @param config Cody configuration
* @param rgPath Path to rg (ripgrep) executable
* @param editor Editor instance
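* @param chatClient Chat client backing the LLM-based keyword/filename fetchers and reranker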
* @returns CodebaseContext if a codebase can be determined, else null
*/
export async function getCodebaseContext(
config: Config,
rgPath: string,
editor: Editor,
chatClient: ChatClient
): Promise<CodebaseContext | null> {
const client = new SourcegraphGraphQLAPIClient(config)
const workspaceRoot = editor.getWorkspaceRootPath()
if (!workspaceRoot) {
return null
}
const gitCommand = spawnSync('git', ['remote', 'get-url', 'origin'], { cwd: workspaceRoot })
const gitOutput = gitCommand.stdout.toString().trim()
// Get codebase from config or fallback to getting repository name from git clone URL
const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput)
if (!codebase) {
return null
}
// Check if repo is embedded in endpoint
const repoId = await client.getRepoIdIfEmbeddingExists(codebase)
if (isError(repoId)) {
const infoMessage = `Cody could not find embeddings for '${codebase}' on your Sourcegraph instance.\n`
console.info(infoMessage)
return null
}
const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null
return new CodebaseContext(
config,
codebase,
embeddingsSearch,
new LocalKeywordContextFetcher(rgPath, editor, chatClient),
new FilenameContextFetcher(rgPath, editor, chatClient),
undefined,
getRerankWithLog(chatClient)
)
}
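
The codebase-name fallback maps a git remote URL onto the slash-separated `host/owner/repo` form via `convertGitCloneURLToCodebaseName` (kept in utils.ts below); assuming a standard GitHub remote, the mapping looks like:

// 'https://github.com/sourcegraph/cody.git' -> 'github.com/sourcegraph/cody'
const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput)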

View File

@ -12,10 +12,11 @@ import * as path from 'path'
export async function fastFilesExist(
rgPath: string,
rootPath: string,
filePaths: string[]
filePaths: string[],
maxDepth?: number
): Promise<{ [filePath: string]: boolean }> {
const searchPattern = constructSearchPattern(filePaths)
const rgOutput = await executeRg(rgPath, rootPath, searchPattern)
const rgOutput = await executeRg(rgPath, rootPath, searchPattern, maxDepth)
return processRgOutput(rgOutput, filePaths)
}
@ -28,6 +29,7 @@ export function makeTrimRegex(sep: string): RegExp {
// Regex to match '**', '*' or path.sep at the start (^) or end ($) of the string.
const trimRegex = makeTrimRegex(path.sep)
/**
* Constructs a search pattern for the 'rg' tool.
*
@ -42,6 +44,7 @@ function constructSearchPattern(filePaths: string[]): string {
})
return `{${searchPatternParts.join(',')}}`
}
/**
* Executes the 'rg' tool and returns the output.
*
@ -50,11 +53,15 @@ function constructSearchPattern(filePaths: string[]): string {
* @param searchPattern - The search pattern to use.
* @returns The output from the 'rg' tool.
*/
async function executeRg(rgPath: string, rootPath: string, searchPattern: string): Promise<string> {
async function executeRg(rgPath: string, rootPath: string, searchPattern: string, maxDepth?: number): Promise<string> {
const args = ['--files', '-g', searchPattern, '--crlf', '--fixed-strings', '--no-config', '--no-ignore-global']
if (maxDepth !== undefined) {
args.push('--max-depth', `${maxDepth}`)
}
return new Promise((resolve, reject) => {
execFile(
rgPath,
['--files', '-g', searchPattern, '--crlf', '--fixed-strings', '--no-config', '--no-ignore-global'],
args,
{
cwd: rootPath,
maxBuffer: 1024 * 1024 * 1024,

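A usage sketch for the new depth limit (paths invented):

// Restrict the existence check to the top two directory levels of the repo.
const exists = await fastFilesExist(rgPath, '/path/to/repo', ['README.md', 'docs/intro.md'], 2)
// -> e.g. { 'README.md': true, 'docs/intro.md': false }
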
View File

@ -1,15 +1,7 @@
import { spawnSync } from 'child_process'
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
import { ConfigurationWithAccessToken } from '@sourcegraph/cody-shared/src/configuration'
import { Editor } from '@sourcegraph/cody-shared/src/editor'
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
import { isError } from '@sourcegraph/cody-shared/src/utils'
import { LocalKeywordContextFetcher } from '../keyword-context/local-keyword-context-fetcher'
import { Config } from './ChatViewProvider'
import { AuthStatus, defaultAuthStatus, isLocalApp, unauthenticatedStatus } from './protocol'
// Converts a git clone URL to the codebase name that includes the slash-separated code host, owner, and repository name
@ -50,40 +42,11 @@ export function convertGitCloneURLToCodebaseName(cloneURL: string): string | nul
}
return null
} catch (error) {
console.log(`Cody could not extract repo name from clone URL ${cloneURL}:`, error)
console.error(`Cody could not extract repo name from clone URL ${cloneURL}:`, error)
return null
}
}
export async function getCodebaseContext(
config: Config,
rgPath: string,
editor: Editor
): Promise<CodebaseContext | null> {
const client = new SourcegraphGraphQLAPIClient(config)
const workspaceRoot = editor.getWorkspaceRootPath()
if (!workspaceRoot) {
return null
}
const gitCommand = spawnSync('git', ['remote', 'get-url', 'origin'], { cwd: workspaceRoot })
const gitOutput = gitCommand.stdout.toString().trim()
// Get codebase from config or fallback to getting repository name from git clone URL
const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput)
if (!codebase) {
return null
}
// Check if repo is embedded in endpoint
const repoId = await client.getRepoIdIfEmbeddingExists(codebase)
if (isError(repoId)) {
const infoMessage = `Cody could not find embeddings for '${codebase}' on your Sourcegraph instance.\n`
console.info(infoMessage)
return null
}
const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null
return new CodebaseContext(config, codebase, embeddingsSearch, new LocalKeywordContextFetcher(rgPath, editor))
}
let client: SourcegraphGraphQLAPIClient
let configWithToken: Pick<ConfigurationWithAccessToken, 'serverEndpoint' | 'accessToken' | 'customHeaders'>

View File

@ -12,8 +12,10 @@ import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/s
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
import { isError } from '@sourcegraph/cody-shared/src/utils'
import { LocalKeywordContextFetcher } from './keyword-context/local-keyword-context-fetcher'
import { FilenameContextFetcher } from './local-context/filename-context-fetcher'
import { LocalKeywordContextFetcher } from './local-context/local-keyword-context-fetcher'
import { logger } from './log'
import { getRerankWithLog } from './logged-rerank'
interface ExternalServices {
intentDetector: IntentDetector
@ -48,11 +50,15 @@ export async function configureExternalServices(
}
const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null
const chatClient = new ChatClient(completions)
const codebaseContext = new CodebaseContext(
initialConfig,
initialConfig.codebase,
embeddingsSearch,
new LocalKeywordContextFetcher(rgPath, editor)
new LocalKeywordContextFetcher(rgPath, editor, chatClient),
new FilenameContextFetcher(rgPath, editor, chatClient),
undefined,
getRerankWithLog(chatClient)
)
const guardrails = new SourcegraphGuardrailsClient(client)
@ -60,7 +66,7 @@ export async function configureExternalServices(
return {
intentDetector: new SourcegraphIntentDetectorClient(client),
codebaseContext,
chatClient: new ChatClient(completions),
chatClient,
completionsClient: completions,
guardrails,
onConfigurationChange: newConfig => {

View File

@ -1,153 +0,0 @@
import * as assert from 'assert'
import { Term, regexForTerms, userQueryToKeywordQuery } from './local-keyword-context-fetcher'
describe('keyword context', () => {
it('userQueryToKeywordQuery', () => {
const cases: { query: string; expected: Term[] }[] = [
{
query: 'Where is auth in Sourcegraph?',
expected: [
{
count: 1,
originals: ['Where', 'Where'],
prefix: 'where',
stem: 'where',
},
{
count: 1,
originals: ['auth', 'auth'],
prefix: 'auth',
stem: 'auth',
},
{
count: 1,
originals: ['Sourcegraph', 'Sourcegraph'],
prefix: 'sourcegraph',
stem: 'sourcegraph',
},
],
},
{
query: `Explain the following code at a high level:
uint32_t PackUInt32(const Color& color) {
uint32_t result = 0;
result |= static_cast<uint32_t>(color.r * 255 + 0.5f) << 24;
result |= static_cast<uint32_t>(color.g * 255 + 0.5f) << 16;
result |= static_cast<uint32_t>(color.b * 255 + 0.5f) << 8;
result |= static_cast<uint32_t>(color.a * 255 + 0.5f);
return result;
}
`,
expected: [
{
count: 1,
originals: ['Explain', 'Explain'],
prefix: 'explain',
stem: 'explain',
},
{
count: 1,
originals: ['following', 'following'],
prefix: 'follow',
stem: 'follow',
},
{
count: 1,
originals: ['code', 'code'],
prefix: 'code',
stem: 'code',
},
{
count: 1,
originals: ['high', 'high'],
prefix: 'high',
stem: 'high',
},
{
count: 1,
originals: ['level', 'level'],
prefix: 'level',
stem: 'level',
},
{
count: 6,
originals: ['uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t'],
prefix: 'uint',
stem: 'uinty2_t',
},
{
count: 1,
originals: ['PackUInt32', 'PackUInt32'],
prefix: 'packuint',
stem: 'packuinty2',
},
{
count: 1,
originals: ['const', 'const'],
prefix: 'const',
stem: 'const',
},
{
count: 6,
originals: ['Color', 'Color', 'color', 'color', 'color', 'color', 'color'],
prefix: 'color',
stem: 'color',
},
{
count: 6,
originals: ['result', 'result', 'result', 'result', 'result', 'result', 'result'],
prefix: 'result',
stem: 'result',
},
{
count: 4,
originals: ['static_cast', 'static_cast', 'static_cast', 'static_cast', 'static_cast'],
prefix: 'static_cast',
stem: 'static_cast',
},
{
count: 4,
originals: ['255', '255', '255', '255', '255'],
prefix: '255',
stem: '255',
},
{
count: 1,
originals: ['return', 'return'],
prefix: 'return',
stem: 'return',
},
],
},
]
for (const testcase of cases) {
const actual = userQueryToKeywordQuery(testcase.query)
assert.deepStrictEqual(actual, testcase.expected)
}
})
it('query to regex', () => {
const trials: {
userQuery: string
expRegex: string
}[] = [
{
userQuery: 'Where is auth in Sourcegraph?',
expRegex: '(?:where|auth|sourcegraph)',
},
{
userQuery: 'saml auth handler',
expRegex: '(?:saml|auth|handler)',
},
{
userQuery: 'Where is the HTTP middleware defined in this codebase?',
expRegex: '(?:where|http|middlewar|defin|codebas)',
},
]
for (const trial of trials) {
const terms = userQueryToKeywordQuery(trial.userQuery)
const regex = regexForTerms(...terms)
expect(regex).toEqual(trial.expRegex)
}
})
})

View File

@ -0,0 +1,154 @@
import { execFile } from 'child_process'
import * as path from 'path'
import { uniq } from 'lodash'
import * as vscode from 'vscode'
import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
import { Editor } from '@sourcegraph/cody-shared/src/editor'
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
import { debug } from '../log'
/**
* A local context fetcher that uses an LLM to generate filename fragments, which are then used to
* find files that are relevant based on their path or name.
*/
export class FilenameContextFetcher {
constructor(private rgPath: string, private editor: Editor, private chatClient: ChatClient) {}
/**
* Returns pieces of context relevant for the given query. Uses a filename search approach
* @param query user query
* @param numResults the number of context results to return
* @returns a list of context results, sorted in *reverse* order (that is,
* the most important result appears at the bottom)
*/
public async getContext(query: string, numResults: number): Promise<ContextResult[]> {
const time0 = performance.now()
const rootPath = this.editor.getWorkspaceRootPath()
if (!rootPath) {
return []
}
const time1 = performance.now()
const filenameFragments = await this.queryToFileFragments(query)
const time2 = performance.now()
const unsortedMatchingFiles = await this.getFilenames(rootPath, filenameFragments, 3)
const time3 = performance.now()
const specialFragments = ['readme']
const allBoostedFiles = []
let remainingFiles = unsortedMatchingFiles
let nextRemainingFiles = []
for (const specialFragment of specialFragments) {
const boostedFiles = []
for (const fileName of remainingFiles) {
const fileNameLower = fileName.toLocaleLowerCase()
if (fileNameLower.includes(specialFragment)) {
boostedFiles.push(fileName)
} else {
nextRemainingFiles.push(fileName)
}
}
remainingFiles = nextRemainingFiles
nextRemainingFiles = []
allBoostedFiles.push(...boostedFiles.sort((a, b) => a.length - b.length))
}
const sortedMatchingFiles = allBoostedFiles.concat(remainingFiles).slice(0, numResults)
const results = await Promise.all(
sortedMatchingFiles
.map(async fileName => {
const uri = vscode.Uri.file(path.join(rootPath, fileName))
const content = (await vscode.workspace.openTextDocument(uri)).getText()
return {
fileName,
content,
}
})
.reverse()
)
const time4 = performance.now()
debug(
'FilenameContextFetcher:getContext',
JSON.stringify({
duration: time4 - time0,
queryToFileFragments: { duration: time2 - time1, fragments: filenameFragments },
getFilenames: { duration: time3 - time2 },
}),
{ verbose: { matchingFiles: unsortedMatchingFiles, results: results.map(r => r.fileName) } }
)
return results
}
private async queryToFileFragments(query: string): Promise<string[]> {
const filenameFragments = await new Promise<string[]>((resolve, reject) => {
let responseText = ''
this.chatClient.chat(
[
{
speaker: 'human',
text: `Write 3 filename fragments that would be contained by files in a git repository that are relevant to answering the following user query: <query>${query}</query> Your response should be only a space-delimited list of filename fragments and nothing else.`,
},
],
{
onChange: (text: string) => {
responseText = text
},
onComplete: () => {
resolve(responseText.split(/\s+/).filter(e => e.length > 0))
},
onError: (message: string, statusCode?: number) => {
reject(new Error(message))
},
},
{
temperature: 0,
fast: true,
}
)
})
const uniqueFragments = uniq(filenameFragments.map(e => e.toLocaleLowerCase()))
return uniqueFragments
}
private async getFilenames(rootPath: string, filenameFragments: string[], maxDepth: number): Promise<string[]> {
const searchPattern = '{' + filenameFragments.map(fragment => `**${fragment}**`).join(',') + '}'
const rgArgs = [
'--files',
'--iglob',
searchPattern,
'--crlf',
'--fixed-strings',
'--no-config',
'--no-ignore-global',
`--max-depth=${maxDepth}`,
]
const results = await new Promise<string>((resolve, reject) => {
execFile(
this.rgPath,
rgArgs,
{
cwd: rootPath,
maxBuffer: 1024 * 1024 * 1024,
},
(error, stdout, stderr) => {
if (error?.code === 2) {
reject(new Error(`${error.message}: ${stderr}`))
} else {
resolve(stdout)
}
}
)
})
return results
.split('\n')
.map(r => r.trim())
.filter(r => r.length > 0)
.sort((a, b) => a.length - b.length)
}
}
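
Usage sketch, with the query from the commit description:

const fetcher = new FilenameContextFetcher(rgPath, editor, chatClient)
// The fast model proposes filename fragments (e.g. 'readme'); rg matches them
// against paths up to 3 levels deep. README-like files are boosted, and the
// returned list is in reverse order of importance (most relevant last).
const results = await fetcher.getContext('What does this project do?', 5)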

View File

@ -1,14 +1,17 @@
import { execFile, spawn } from 'child_process'
import * as path from 'path'
import Assembler from 'stream-json/Assembler'
import StreamValues from 'stream-json/streamers/StreamValues'
import * as vscode from 'vscode'
import winkUtils from 'wink-nlp-utils'
import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
import { Editor } from '@sourcegraph/cody-shared/src/editor'
import { KeywordContextFetcher, KeywordContextFetcherResult } from '@sourcegraph/cody-shared/src/keyword-context'
import { KeywordContextFetcher, ContextResult } from '@sourcegraph/cody-shared/src/local-context'
import { logEvent } from '../event-logger'
import { debug } from '../log'
/**
* Exclude files without extensions and hidden files (starts with '.')
@ -17,6 +20,7 @@ import { logEvent } from '../event-logger'
* Note: Ripgrep excludes binary files and respects .gitignore by default
*/
const fileExtRipgrepParams = [
'--ignore-case',
'-g',
'*.*',
'-g',
@ -25,10 +29,10 @@ const fileExtRipgrepParams = [
'!*.lock',
'-g',
'!*.snap',
'--threads',
'1',
'--max-filesize',
'1M',
'10K',
'--max-depth',
'10',
]
interface RipgrepStreamData {
@ -68,62 +72,34 @@ export function regexForTerms(...terms: Term[]): string {
return `(?:${inner.join('|')})`
}
export function userQueryToKeywordQuery(query: string): Term[] {
    const longestCommonPrefix = (s: string, t: string): string => {
        let endIdx = 0
        for (let i = 0; i < s.length && i < t.length; i++) {
            if (s[i] !== t[i]) {
                break
            }
            endIdx = i + 1
        }
        return s.slice(0, endIdx)
    }
    const origWords: string[] = []
    for (const chunk of query.split(/\W+/)) {
        if (chunk.trim().length === 0) {
            continue
        }
        origWords.push(...winkUtils.string.tokenize0(chunk))
    }
    const filteredWords = winkUtils.tokens.removeWords(origWords)
    const terms = new Map<string, Term>()
    for (const word of filteredWords) {
        // Ignore ASCII-only strings of length 2 or less
        if (word.length <= 2) {
            let skip = true
            for (let i = 0; i < word.length; i++) {
                if (word.charCodeAt(i) >= 128) {
                    // non-ASCII
                    skip = false
                    break
                }
            }
            if (skip) {
                continue
            }
        }
        const stem = winkUtils.string.stem(word)
        const term = terms.get(stem) || {
            stem,
            originals: [word],
            prefix: longestCommonPrefix(word.toLowerCase(), stem),
            count: 0,
        }
        term.originals.push(word)
        term.count++
        terms.set(stem, term)
    }
    return [...terms.values()]
}
function longestCommonPrefix(s: string, t: string): string {
    let endIdx = 0
    for (let i = 0; i < s.length && i < t.length; i++) {
        if (s[i] !== t[i]) {
            break
        }
        endIdx = i + 1
    }
    return s.slice(0, endIdx)
}
/**
 * A local context fetcher that uses an LLM to generate a keyword query, which is
 * then converted to a regex and fed to ripgrep to search for files relevant to
 * the user query.
 */
export class LocalKeywordContextFetcher implements KeywordContextFetcher {
constructor(private rgPath: string, private editor: Editor) {}
constructor(private rgPath: string, private editor: Editor, private chatClient: ChatClient) {}
public async getContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
console.log('fetching keyword matches')
/**
* Returns pieces of context relevant for the given query. Uses a keyword-search-based
* approach.
* @param query user query
* @param numResults the number of context results to return
* @returns a list of context results, sorted in *reverse* order (that is,
* the most important result appears at the bottom)
*/
public async getContext(query: string, numResults: number): Promise<ContextResult[]> {
const startTime = performance.now()
const rootPath = this.editor.getWorkspaceRootPath()
if (!rootPath) {
@ -141,18 +117,89 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
)
const searchDuration = performance.now() - startTime
logEvent('CodyVSCodeExtension:keywordContext:searchDuration', searchDuration, searchDuration)
debug('LocalKeywordContextFetcher:getContext', JSON.stringify({ searchDuration }))
return messagePairs.reverse().flat()
}
private async userQueryToExpandedKeywords(query: string): Promise<Map<string, Term>> {
const start = performance.now()
const keywords = await new Promise<string[]>((resolve, reject) => {
let responseText = ''
this.chatClient.chat(
[
{
speaker: 'human',
text: `Write 3-5 keywords that you would use to search for code snippets that are relevant to answering the following user query: <query>${query}</query> Your response should be only a list of space-delimited keywords and nothing else.`,
},
],
{
onChange: (text: string) => {
responseText = text
},
onComplete: () => {
resolve(responseText.split(/\s+/).filter(e => e.length > 0))
},
onError: (message: string, statusCode?: number) => {
reject(new Error(message))
},
},
{
temperature: 0,
fast: true,
}
)
})
const terms = new Map<string, Term>()
for (const kw of keywords) {
const stem = winkUtils.string.stem(kw)
if (terms.has(stem)) {
continue
}
terms.set(stem, {
count: 1,
originals: [kw],
prefix: longestCommonPrefix(kw.toLowerCase(), stem),
stem,
})
}
debug(
'LocalKeywordContextFetcher:userQueryToExpandedKeywords',
JSON.stringify({ duration: performance.now() - start })
)
return terms
}
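    // Hypothetical example (not from the source): for the query "where are
    // embeddings configured?", the model might return "embeddings config settings",
    // which this loop turns into terms such as
    // { stem: 'embed', originals: ['embeddings'], prefix: 'embed', count: 1 }.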
private async userQueryToKeywordQuery(query: string): Promise<Term[]> {
const terms = new Map<string, Term>()
const keywordExpansionStartTime = Date.now()
const expandedTerms = await this.userQueryToExpandedKeywords(query)
const keywordExpansionDuration = Date.now() - keywordExpansionStartTime
for (const [stem, term] of expandedTerms) {
if (terms.has(stem)) {
continue
}
terms.set(stem, term)
}
debug(
'LocalKeywordContextFetcher:userQueryToKeywordQuery',
'keyword expansion',
JSON.stringify({
duration: keywordExpansionDuration,
expandedTerms: [...expandedTerms.values()].map(v => v.prefix),
})
)
const ret = [...terms.values()]
return ret
}
// Return context results for the Codebase Context Search recipe
public async getSearchContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
console.log('fetching keyword search context...')
public async getSearchContext(query: string, numResults: number): Promise<ContextResult[]> {
const rootPath = this.editor.getWorkspaceRootPath()
if (!rootPath) {
return []
}
const stems = userQueryToKeywordQuery(query)
const stems = (await this.userQueryToKeywordQuery(query))
.map(t => (t.prefix.length < 4 ? t.originals[0] : t.prefix))
.join('|')
const filenamesWithScores = await this.fetchKeywordFiles(rootPath, query)
@ -180,8 +227,10 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
terms: Term[],
rootPath: string
): Promise<{ [filename: string]: { bytesSearched: number } }> {
const start = performance.now()
const regexQuery = `\\b${regexForTerms(...terms)}`
const proc = spawn(this.rgPath, ['-i', ...fileExtRipgrepParams, '--json', regexQuery, './'], {
const rgArgs = [...fileExtRipgrepParams, '--json', regexQuery, '.']
const proc = spawn(this.rgPath, rgArgs, {
cwd: rootPath,
stdio: ['ignore', 'pipe', process.stderr],
windowsHide: true,
@ -191,21 +240,40 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
bytesSearched: number
}
} = {}
// Process the ripgrep JSON output to get the file sizes. We use an object filter to
// cheaply discard irrelevant lines of output.
const objectFilter = (assembler: Assembler): boolean | undefined => {
// Each ripgrep JSON line begins with the following format:
//
// {"type":"begin|match|end","data":"...
//
// We only care about the "type":"end" lines, which contain the file size in bytes.
if (assembler.key === null && assembler.stack.length === 0 && assembler.current.type) {
return assembler.current.type === 'end'
}
// return undefined to indicate our uncertainty at this moment
return undefined
}
await new Promise<void>((resolve, reject) => {
try {
proc.stdout
.pipe(StreamValues.withParser())
.pipe(StreamValues.withParser({ objectFilter }))
.on('data', data => {
try {
const typedData = data as RipgrepStreamData
switch (typedData.value.type) {
case 'end':
    if (!fileTermCounts[typedData.value.data.path.text]) {
        fileTermCounts[typedData.value.data.path.text] = { bytesSearched: 0 }
    }
    fileTermCounts[typedData.value.data.path.text].bytesSearched =
        typedData.value.data.stats.bytes_searched
case 'end': {
    let filename = typedData.value.data.path.text
    if (filename.startsWith(`.${path.sep}`)) {
        filename = filename.slice(2)
    }
    if (!fileTermCounts[filename]) {
        fileTermCounts[filename] = { bytesSearched: 0 }
    }
    fileTermCounts[filename].bytesSearched = typedData.value.data.stats.bytes_searched
    break
}
}
} catch (error) {
reject(error)
@ -216,6 +284,7 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
reject(error)
}
})
debug('fetchFileStats', JSON.stringify({ duration: performance.now() - start }))
return fileTermCounts
}
@ -227,20 +296,21 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
fileTermCounts: { [filename: string]: { [stem: string]: number } }
termTotalFiles: { [stem: string]: number }
}> {
const start = performance.now()
const termFileCountsArr: { fileCounts: { [filename: string]: number }; filesSearched: number }[] =
await Promise.all(
queryTerms.map(async term => {
const rgArgs = [
...fileExtRipgrepParams,
'--count-matches',
'--stats',
`\\b${regexForTerms(term)}`,
'.',
]
const out = await new Promise<string>((resolve, reject) => {
execFile(
this.rgPath,
[
'-i',
...fileExtRipgrepParams,
'--count-matches',
'--stats',
`\\b${regexForTerms(term)}`,
'./',
],
rgArgs,
{
cwd: rootPath,
maxBuffer: 1024 * 1024 * 1024,
@ -272,8 +342,12 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
continue
}
try {
let filename = terms[0]
if (filename.startsWith(`.${path.sep}`)) {
filename = filename.slice(2)
}
const count = parseInt(terms[1], 10)
fileCounts[terms[0]] = count
fileCounts[filename] = count
} catch {
console.error(`could not parse count from ${terms[1]}`)
}
@ -282,6 +356,7 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
})
)
debug('LocalKeywordContextFetcher.fetchFileMatches', JSON.stringify({ duration: performance.now() - start }))
let totalFilesSearched = -1
for (const { filesSearched } of termFileCountsArr) {
if (totalFilesSearched >= 0 && totalFilesSearched !== filesSearched) {
@ -316,13 +391,15 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
rootPath: string,
rawQuery: string
): Promise<{ filename: string; score: number }[]> {
const query = userQueryToKeywordQuery(rawQuery)
const query = await this.userQueryToKeywordQuery(rawQuery)
const fetchFilesStart = performance.now()
const fileMatchesPromise = this.fetchFileMatches(query, rootPath)
const fileStatsPromise = this.fetchFileStats(query, rootPath)
const fileMatches = await fileMatchesPromise
const fileStats = await fileStatsPromise
const fetchFilesDuration = performance.now() - fetchFilesStart
debug('LocalKeywordContextFetcher:fetchKeywordFiles', JSON.stringify({ fetchFilesDuration }))
const { fileTermCounts, termTotalFiles, totalFiles } = fileMatches
const idfDict = idf(termTotalFiles, totalFiles)
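The hunk ends where the tf-idf scoring begins; the idf helper itself is outside this diff. A rough sketch of the standard formulation such a helper typically implements (an assumption about its behavior, not the actual source):

function idf(termTotalFiles: { [stem: string]: number }, totalFiles: number): { [stem: string]: number } {
    const idfDict: { [stem: string]: number } = {}
    for (const [stem, fileCount] of Object.entries(termTotalFiles)) {
        // rarer terms (fewer matching files) get a higher weight
        idfDict[stem] = Math.log((totalFiles + 1) / (fileCount + 1))
    }
    return idfDict
}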

View File

@ -0,0 +1,24 @@
import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
import { LLMReranker } from '@sourcegraph/cody-shared/src/codebase-context/rerank'
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
import { debug } from './log'
import { TestSupport } from './test-support'
export function getRerankWithLog(
chatClient: ChatClient
): (query: string, results: ContextResult[]) => Promise<ContextResult[]> {
if (TestSupport.instance) {
const reranker = TestSupport.instance.getReranker()
return (query: string, results: ContextResult[]): Promise<ContextResult[]> => reranker.rerank(query, results)
}
const reranker = new LLMReranker(chatClient)
return async (userQuery: string, results: ContextResult[]): Promise<ContextResult[]> => {
const start = performance.now()
const rerankedResults = await reranker.rerank(userQuery, results)
const duration = performance.now() - start
debug('Reranker:rerank', JSON.stringify({ duration }))
return rerankedResults
}
}
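A sketch of how this factory might be consumed, assuming a caller that already holds a ChatClient and some context fetcher; the wiring shown here is illustrative rather than taken from this diff, and the './logged-rerank' module path is a guess at this new file's name.

import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
import { getRerankWithLog } from './logged-rerank'

async function fetchAndRerank(
    chatClient: ChatClient,
    fetchContext: (query: string, numResults: number) => Promise<ContextResult[]>,
    query: string
): Promise<ContextResult[]> {
    const rerank = getRerankWithLog(chatClient) // uses the MockReranker under TestSupport
    const results = await fetchContext(query, 12)
    return rerank(query, results) // debug-logs 'Reranker:rerank' with the duration
}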

View File

@ -43,12 +43,13 @@ export class LocalStorage {
([humanMessage, assistantMessageAndContextFiles]) => ({
humanMessage,
assistantMessage: assistantMessageAndContextFiles,
context: assistantMessageAndContextFiles.contextFiles
fullContext: assistantMessageAndContextFiles.contextFiles
? assistantMessageAndContextFiles.contextFiles.map(fileName => ({
speaker: 'assistant',
fileName,
}))
: [],
usedContextFiles: [],
// Timestamp is not recoverable, so we use the group timestamp
timestamp: id,
})
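The migration renames the persisted context field and adds an empty usedContextFiles, since the old format never recorded which files actually fit into the context window. A hypothetical migrated entry (all values invented for illustration):

{
    "humanMessage": { "speaker": "human", "text": "What does this project do?" },
    "assistantMessage": { "speaker": "assistant", "text": "It is a code AI assistant." },
    "fullContext": [{ "speaker": "assistant", "fileName": "README.md" }],
    "usedContextFiles": [],
    "timestamp": "2023-06-12T11:56:08-07:00"
}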

View File

@ -1,4 +1,6 @@
import { ChatMessage } from '@sourcegraph/cody-shared/src/chat/transcript/messages'
import { MockReranker, Reranker } from '@sourcegraph/cody-shared/src/codebase-context/rerank'
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
import { ChatViewProvider } from './chat/ChatViewProvider'
import { FixupTask } from './non-stop/FixupTask'
@ -39,6 +41,17 @@ export class TestSupport {
public chatViewProvider = new Rendezvous<ChatViewProvider>()
public reranker: Reranker | undefined
public getReranker(): Reranker {
if (!this.reranker) {
return new MockReranker(
(_: string, results: ContextResult[]): Promise<ContextResult[]> => Promise.resolve(results)
)
}
return this.reranker
}
public async chatTranscript(): Promise<ChatMessage[]> {
return (await this.chatViewProvider.get()).transcriptForTesting(this)
}

View File

@ -8,6 +8,7 @@ type CompletionsResolver interface {
type CompletionsArgs struct {
Input CompletionsInput
Fast bool
}
type Message struct {
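On the GraphQL side, the new resolver argument presumably surfaces as an optional fast flag on the completions field. A hedged client-side sketch; the field and argument casing are assumptions, since the schema itself is not part of this diff:

query Completions($input: CompletionsInput!) {
    completions(input: $input, fast: true)
}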

View File

@ -45,7 +45,14 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
return "", errors.New("completions are not configured or disabled")
}
ctx, done := httpapi.Trace(ctx, "resolver", completionsConfig.ChatModel).
var chatModel string
if args.Fast {
chatModel = completionsConfig.FastChatModel
} else {
chatModel = completionsConfig.ChatModel
}
ctx, done := httpapi.Trace(ctx, "resolver", chatModel).
WithErrorP(&err).
Build()
defer done()
@ -66,7 +73,7 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
params := convertParams(args)
// The model cannot be configured through the request; we hard-code the chat model.
params.Model = completionsConfig.ChatModel
params.Model = chatModel
resp, err := client.Complete(ctx, types.CompletionsFeatureChat, params)
if err != nil {
return "", errors.Wrap(err, "client.Complete")

View File

@ -79,6 +79,10 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) *schema.Completio
completionsConfig.ChatModel = completionsConfig.Model
}
if completionsConfig.FastChatModel == "" {
completionsConfig.FastChatModel = completionsConfig.ChatModel
}
// TODO: Temporary workaround to fix instances where no completion model is set.
if completionsConfig.CompletionModel == "" {
completionsConfig.CompletionModel = "claude-instant-v1"
@ -102,6 +106,7 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) *schema.Completio
// TODO: These are not required right now as upstream overwrites this,
// but should we switch to Cody Gateway they will be.
ChatModel: "claude-v1",
FastChatModel: "claude-instant-v1",
CompletionModel: "claude-instant-v1",
}
}

View File

@ -28,13 +28,15 @@ func TestGetCompletionsConfig(t *testing.T) {
Enabled: true,
Provider: "anthropic",
ChatModel: "claude-v1",
FastChatModel: "claude-instant-v1",
CompletionModel: "claude-instant-v1",
},
},
want: autogold.Expect(&schema.Completions{
ChatModel: "claude-v1", CompletionModel: "claude-instant-v1",
Enabled: true,
Provider: "anthropic",
FastChatModel: "claude-instant-v1",
Enabled: true,
Provider: "anthropic",
}),
},
{

View File

@ -19,7 +19,7 @@ func NewCodeCompletionsHandler(logger log.Logger, db database.DB) http.Handler {
logger = logger.Scoped("code", "code completions handler")
rl := NewRateLimiter(db, redispool.Store, types.CompletionsFeatureCode)
return newCompletionsHandler(rl, "code", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
return newCompletionsHandler(rl, "code", func(requestParams types.CodyCompletionRequestParameters, c *schema.Completions) string {
// No user defined models for now.
// TODO(eseliger): Look into reviving this, but it was unused so far.
return c.CompletionModel

View File

@ -19,7 +19,12 @@ import (
// being cancelled.
const maxRequestDuration = time.Minute
func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(types.CompletionRequestParameters, *schema.Completions) string, handle func(context.Context, types.CompletionRequestParameters, types.CompletionsClient, http.ResponseWriter)) http.Handler {
func newCompletionsHandler(
rl RateLimiter,
traceFamily string,
getModel func(types.CodyCompletionRequestParameters, *schema.Completions) string,
handle func(context.Context, types.CompletionRequestParameters, types.CompletionsClient, http.ResponseWriter),
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, fmt.Sprintf("unsupported method %s", r.Method), http.StatusMethodNotAllowed)
@ -40,13 +45,17 @@ func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(typ
return
}
var requestParams types.CompletionRequestParameters
var requestParams types.CodyCompletionRequestParameters
if err := json.NewDecoder(r.Body).Decode(&requestParams); err != nil {
http.Error(w, "could not decode request body", http.StatusBadRequest)
return
}
// TODO: Model is not configurable but technically allowed in the request body right now.
if requestParams.Model != "" {
http.Error(w, "user-specified models are not allowed", http.StatusBadRequest)
return
}
requestParams.Model = getModel(requestParams, completionsConfig)
var err error
@ -77,7 +86,7 @@ func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(typ
return
}
handle(ctx, requestParams, completionClient, w)
handle(ctx, requestParams.CompletionRequestParameters, completionClient, w)
})
}

View File

@ -19,8 +19,11 @@ func NewChatCompletionsStreamHandler(logger log.Logger, db database.DB) http.Han
logger = logger.Scoped("chat", "chat completions handler")
rl := NewRateLimiter(db, redispool.Store, types.CompletionsFeatureChat)
return newCompletionsHandler(rl, "chat", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
return newCompletionsHandler(rl, "chat", func(requestParams types.CodyCompletionRequestParameters, c *schema.Completions) string {
// No user defined models for now.
if requestParams.Fast {
return c.FastChatModel
}
return c.ChatModel
}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
eventWriter, err := streamhttp.NewWriter(w)
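From a client's perspective, opting into the fast model is one extra field in the request body. A sketch under stated assumptions: the /.api/completions/stream path and the body fields (messages, temperature, maxTokensToSample) are inferred from context rather than shown verbatim in this diff.

async function requestFastChat(endpoint: string, accessToken: string): Promise<Response> {
    return fetch(`${endpoint}/.api/completions/stream`, {
        method: 'POST',
        headers: { Authorization: `token ${accessToken}` },
        body: JSON.stringify({
            messages: [{ speaker: 'human', text: 'Write 3-5 keywords for: how is auth handled?' }],
            temperature: 0,
            maxTokensToSample: 200,
            fast: true, // routed to FastChatModel by the handler above
            // 'model' must be left unset: user-specified models are rejected with a 400
        }),
    })
}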

View File

@ -37,6 +37,14 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str
return fmt.Sprintf("%s %s", prefix, m.Text), nil
}
type CodyCompletionRequestParameters struct {
CompletionRequestParameters
// When Fast is true, it is used as a hint to prefer a model that is
// faster (but likely less capable).
Fast bool
}
type CompletionRequestParameters struct {
// Prompt exists only for backwards compatibility. Do not use it in new
// implementations. It will be removed once we are reasonably sure 99%

View File

@ -1445,6 +1445,13 @@ importers:
'@sourcegraph/http-client':
specifier: workspace:*
version: link:../http-client
xml2js:
specifier: ^0.6.0
version: 0.6.0
devDependencies:
'@types/xml2js':
specifier: ^0.4.11
version: 0.4.11
client/cody-slack:
dependencies:
@ -10016,6 +10023,12 @@ packages:
dependencies:
'@types/node': 13.13.5
/@types/xml2js@0.4.11:
resolution: {integrity: sha512-JdigeAKmCyoJUiQljjr7tQG3if9NkqGUgwEUqBvV0N7LM4HyQk7UXCnusRa1lnvXAEYJ8mw8GtZWioagNztOwA==}
dependencies:
'@types/node': 13.13.5
dev: true
/@types/yargs-parser@21.0.0:
resolution: {integrity: sha512-iO9ZQHkZxHn4mSakYV0vFHAVDyEOIJQrV2uZ06HxEPcx+mt8swXoZHIbaaJ2crJYFfErySgktuTZ3BeLz+XmFA==}
@ -26143,7 +26156,6 @@ packages:
/sax@1.2.4:
resolution: {integrity: sha512-NqVDv9TpANUjFm0N8uM5GxL36UgKi9/atZw+x7YFnQ8ckwFGKrl4xX4yWtrey3UJm5nP1kUbnYgLopqWNSRhWw==}
dev: true
/saxes@5.0.1:
resolution: {integrity: sha512-5LBh1Tls8c9xgGjw3QrMwETmTMVk0oFgvrFSvWx62llR2hcEInrKNZ2GZCCuuy2lvWrdl5jhbpeqc5hRYKFOcw==}
@ -30109,6 +30121,14 @@ packages:
xmlbuilder: 11.0.1
dev: true
/xml2js@0.6.0:
resolution: {integrity: sha512-eLTh0kA8uHceqesPqSE+VvO1CDDJWMwlQfB6LuN6T8w6MaDJ8Txm8P7s5cHD0miF0V+GGTZrDQfxPZQVsur33w==}
engines: {node: '>=4.0.0'}
dependencies:
sax: 1.2.4
xmlbuilder: 11.0.1
dev: false
/xml@1.0.1:
resolution: {integrity: sha512-huCv9IH9Tcf95zuYCsQraZtWnJvBtLVE0QHMOs8bWyZAFZNDcYjsPq1nEx8jKA9y+Beo9v+7OBPRisQTjinQMw==}
dev: true
@ -30116,7 +30136,6 @@ packages:
/xmlbuilder@11.0.1:
resolution: {integrity: sha512-fDlsI/kFEx7gLvbecc0/ohLG50fugQp8ryHzMTuW9vSa1GJ0XYWKnhsUx7oie3G98+r56aTQIUB4kht42R3JvA==}
engines: {node: '>=4.0'}
dev: true
/xmlchars@2.2.0:
resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==}

View File

@ -564,6 +564,8 @@ type Completions struct {
Enabled bool `json:"enabled"`
// Endpoint description: The endpoint under which to reach the provider. Currently only used for provider types "sourcegraph", "openai" and "anthropic". The default values are "https://cody-gateway.sourcegraph.com", "https://api.openai.com/v1/chat/completions", and "https://api.anthropic.com/v1/complete" for Sourcegraph, OpenAI, and Anthropic, respectively.
Endpoint string `json:"endpoint,omitempty"`
// FastChatModel description: The model used for fast chat completions.
FastChatModel string `json:"fastChatModel,omitempty"`
// Model description: DEPRECATED. Use chatModel instead.
Model string `json:"model,omitempty"`
// PerUserCodeCompletionsDailyLimit description: If > 0, enables the maximum number of code completions requests allowed to be made by a single user account in a day. On instances that allow anonymous requests, the rate limit is enforced by IP.

View File

@ -2422,6 +2422,10 @@
"description": "DEPRECATED. Use chatModel instead.",
"type": "string"
},
"fastChatModel": {
"description": "The model used for fast chat completions.",
"type": "string"
},
"chatModel": {
"description": "The model used for chat completions. If using the default provider 'sourcegraph', a reasonable default model will be set.",
"type": "string"