diff --git a/client/cody-cli/src/app.ts b/client/cody-cli/src/app.ts index 984a9887bd6..585f16d406d 100644 --- a/client/cody-cli/src/app.ts +++ b/client/cody-cli/src/app.ts @@ -111,7 +111,8 @@ async function startCLI() { } } - const finalPrompt = await transcript.toPrompt(getPreamble(codebase)) + const { prompt: finalPrompt, contextFiles } = await transcript.getPromptForLastInteraction(getPreamble(codebase)) + transcript.setUsedContextFilesForLastInteraction(contextFiles) let text = '' streamCompletions(completionsClient, finalPrompt, { diff --git a/client/cody-cli/src/context.ts b/client/cody-cli/src/context.ts index aa7eac68e01..dd7ddb0aa05 100644 --- a/client/cody-cli/src/context.ts +++ b/client/cody-cli/src/context.ts @@ -1,6 +1,6 @@ import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context' import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client' -import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/keyword-context' +import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/local-context' import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql' import { isError } from '@sourcegraph/cody-shared/src/utils' @@ -26,7 +26,8 @@ export async function createCodebaseContext( { useContext: contextType, serverEndpoint }, codebase, embeddingsSearch, - new LocalKeywordContextFetcherMock() + new LocalKeywordContextFetcherMock(), + null ) return codebaseContext diff --git a/client/cody-cli/src/interactions.ts b/client/cody-cli/src/interactions.ts index e78ff8ceee4..922b5009b5c 100644 --- a/client/cody-cli/src/interactions.ts +++ b/client/cody-cli/src/interactions.ts @@ -45,7 +45,8 @@ export async function interactionFromMessage( new Interaction( { speaker: 'human', text, displayText: text }, { speaker: 'assistant', text: '', displayText: '' }, - contextMessages + contextMessages, + [] ) ) } diff --git a/client/cody-shared/BUILD.bazel b/client/cody-shared/BUILD.bazel index 79571dc6d07..47deda52863 100644 --- a/client/cody-shared/BUILD.bazel +++ b/client/cody-shared/BUILD.bazel @@ -58,6 +58,7 @@ ts_project( "src/chat/viewHelpers.ts", "src/codebase-context/index.ts", "src/codebase-context/messages.ts", + "src/codebase-context/rerank.ts", "src/configuration.ts", "src/editor/index.ts", "src/editor/withPreselectedOptions.ts", @@ -68,7 +69,7 @@ ts_project( "src/hallucinations-detector/index.ts", "src/intent-detector/client.ts", "src/intent-detector/index.ts", - "src/keyword-context/index.ts", + "src/local-context/index.ts", "src/prompt/constants.ts", "src/prompt/prompt-mixin.ts", "src/prompt/templates.ts", @@ -93,6 +94,8 @@ ts_project( deps = [ ":node_modules/@sourcegraph/common", ":node_modules/@sourcegraph/http-client", + ":node_modules/@types/xml2js", + ":node_modules/xml2js", "//:node_modules/@microsoft/fetch-event-source", "//:node_modules/@types/isomorphic-fetch", "//:node_modules/@types/marked", diff --git a/client/cody-shared/package.json b/client/cody-shared/package.json index dd3b0c3638c..75d194370bd 100644 --- a/client/cody-shared/package.json +++ b/client/cody-shared/package.json @@ -19,6 +19,10 @@ }, "dependencies": { "@sourcegraph/common": "workspace:*", - "@sourcegraph/http-client": "workspace:*" + "@sourcegraph/http-client": "workspace:*", + "xml2js": "^0.6.0" + }, + "devDependencies": { + "@types/xml2js": "^0.4.11" } } diff --git a/client/cody-shared/src/chat/chat.ts b/client/cody-shared/src/chat/chat.ts index 90ef9898301..2f5a2934d4a 100644 --- a/client/cody-shared/src/chat/chat.ts +++ b/client/cody-shared/src/chat/chat.ts @@ -3,7 +3,9 @@ import { Message } from '../sourcegraph-api' import type { SourcegraphCompletionsClient } from '../sourcegraph-api/completions/client' import type { CompletionParameters, CompletionCallbacks } from '../sourcegraph-api/completions/types' -const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit = { +type ChatParameters = Omit + +const DEFAULT_CHAT_COMPLETION_PARAMETERS: ChatParameters = { temperature: 0.2, maxTokensToSample: SOLUTION_TOKEN_LENGTH, topK: -1, @@ -13,10 +15,17 @@ const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit export class ChatClient { constructor(private completions: SourcegraphCompletionsClient) {} - public chat(messages: Message[], cb: CompletionCallbacks): () => void { + public chat(messages: Message[], cb: CompletionCallbacks, params?: Partial): () => void { const isLastMessageFromHuman = messages.length > 0 && messages[messages.length - 1].speaker === 'human' const augmentedMessages = isLastMessageFromHuman ? messages.concat([{ speaker: 'assistant' }]) : messages - return this.completions.stream({ messages: augmentedMessages, ...DEFAULT_CHAT_COMPLETION_PARAMETERS }, cb) + return this.completions.stream( + { + ...DEFAULT_CHAT_COMPLETION_PARAMETERS, + ...params, + messages: augmentedMessages, + }, + cb + ) } } diff --git a/client/cody-shared/src/chat/client.ts b/client/cody-shared/src/chat/client.ts index efb4b5a0589..df8051127eb 100644 --- a/client/cody-shared/src/chat/client.ts +++ b/client/cody-shared/src/chat/client.ts @@ -69,7 +69,7 @@ export async function createClient({ const embeddingsSearch = repoId ? new SourcegraphEmbeddingsSearchClient(graphqlClient, repoId, true) : null - const codebaseContext = new CodebaseContext(config, config.codebase, embeddingsSearch, null) + const codebaseContext = new CodebaseContext(config, config.codebase, embeddingsSearch, null, null) const intentDetector = new SourcegraphIntentDetectorClient(graphqlClient) @@ -116,10 +116,11 @@ export async function createClient({ sendTranscript() - const prompt = await transcript.toPrompt(getPreamble(config.codebase)) + const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(getPreamble(config.codebase)) + transcript.setUsedContextFilesForLastInteraction(contextFiles) + const responsePrefix = interaction.getAssistantMessage().prefix ?? '' let rawText = '' - chatClient.chat(prompt, { onChange(_rawText) { rawText = _rawText diff --git a/client/cody-shared/src/chat/recipes/chat-question.ts b/client/cody-shared/src/chat/recipes/chat-question.ts index a9a68b57049..a96abf297d4 100644 --- a/client/cody-shared/src/chat/recipes/chat-question.ts +++ b/client/cody-shared/src/chat/recipes/chat-question.ts @@ -31,7 +31,8 @@ export class ChatQuestion implements Recipe { context.intentDetector, context.codebaseContext, context.editor.getActiveTextEditorSelection() || null - ) + ), + [] ) ) } diff --git a/client/cody-shared/src/chat/recipes/context-search.ts b/client/cody-shared/src/chat/recipes/context-search.ts index 60ce0e91a29..7db67ddcb3e 100644 --- a/client/cody-shared/src/chat/recipes/context-search.ts +++ b/client/cody-shared/src/chat/recipes/context-search.ts @@ -46,7 +46,8 @@ export class ContextSearch implements Recipe { text: '', displayText: await this.displaySearchResults(truncatedText, context.codebaseContext, wsRootPath), }, - new Promise(resolve => resolve([])) + new Promise(resolve => resolve([])), + [] ) } diff --git a/client/cody-shared/src/chat/recipes/explain-code-detailed.ts b/client/cody-shared/src/chat/recipes/explain-code-detailed.ts index 70ca04ad30b..a40afa94b4c 100644 --- a/client/cody-shared/src/chat/recipes/explain-code-detailed.ts +++ b/client/cody-shared/src/chat/recipes/explain-code-detailed.ts @@ -32,7 +32,8 @@ export class ExplainCodeDetailed implements Recipe { truncatedFollowingText, selection, context.codebaseContext - ) + ), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/explain-code-high-level.ts b/client/cody-shared/src/chat/recipes/explain-code-high-level.ts index 61f23352261..7be00c5488d 100644 --- a/client/cody-shared/src/chat/recipes/explain-code-high-level.ts +++ b/client/cody-shared/src/chat/recipes/explain-code-high-level.ts @@ -32,7 +32,8 @@ export class ExplainCodeHighLevel implements Recipe { truncatedFollowingText, selection, context.codebaseContext - ) + ), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/file-touch.ts b/client/cody-shared/src/chat/recipes/file-touch.ts index 96fc197e53a..c0eadc7b4a5 100644 --- a/client/cody-shared/src/chat/recipes/file-touch.ts +++ b/client/cody-shared/src/chat/recipes/file-touch.ts @@ -103,7 +103,8 @@ export class FileTouch implements Recipe { speaker: 'assistant', prefix: 'Working on it! I will notify you when the file is ready.\n', }, - this.getContextMessages(selection, currentDir) + this.getContextMessages(selection, currentDir), + [] ) ) } diff --git a/client/cody-shared/src/chat/recipes/find-code-smells.ts b/client/cody-shared/src/chat/recipes/find-code-smells.ts index 6b61b3ec619..1d2db10b67e 100644 --- a/client/cody-shared/src/chat/recipes/find-code-smells.ts +++ b/client/cody-shared/src/chat/recipes/find-code-smells.ts @@ -36,7 +36,8 @@ If you have no ideas because the code looks fine, feel free to say that it alrea prefix: assistantResponsePrefix, text: assistantResponsePrefix, }, - new Promise(resolve => resolve([])) + new Promise(resolve => resolve([])), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/fixup.ts b/client/cody-shared/src/chat/recipes/fixup.ts index 93bed995064..4768026fb9f 100644 --- a/client/cody-shared/src/chat/recipes/fixup.ts +++ b/client/cody-shared/src/chat/recipes/fixup.ts @@ -67,7 +67,8 @@ export class Fixup implements Recipe { speaker: 'assistant', prefix: 'Check your document for updates from Cody.\n', }, - this.getContextMessages(selection.selectedText, context.codebaseContext) + this.getContextMessages(selection.selectedText, context.codebaseContext), + [] ) ) } diff --git a/client/cody-shared/src/chat/recipes/generate-docstring.ts b/client/cody-shared/src/chat/recipes/generate-docstring.ts index 4bede5cb717..b60efba4ef6 100644 --- a/client/cody-shared/src/chat/recipes/generate-docstring.ts +++ b/client/cody-shared/src/chat/recipes/generate-docstring.ts @@ -60,7 +60,8 @@ export class GenerateDocstring implements Recipe { truncatedFollowingText, selection, context.codebaseContext - ) + ), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/generate-pr-description.ts b/client/cody-shared/src/chat/recipes/generate-pr-description.ts index d9211225ac1..535e9c1fbc9 100644 --- a/client/cody-shared/src/chat/recipes/generate-pr-description.ts +++ b/client/cody-shared/src/chat/recipes/generate-pr-description.ts @@ -52,7 +52,8 @@ export class PrDescription implements Recipe { prefix: emptyGitCommitMessage, text: emptyGitCommitMessage, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } @@ -71,7 +72,8 @@ export class PrDescription implements Recipe { prefix: assistantResponsePrefix, text: assistantResponsePrefix, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/generate-release-notes.ts b/client/cody-shared/src/chat/recipes/generate-release-notes.ts index 20b731b4e62..b76dfc92ab8 100644 --- a/client/cody-shared/src/chat/recipes/generate-release-notes.ts +++ b/client/cody-shared/src/chat/recipes/generate-release-notes.ts @@ -71,7 +71,8 @@ export class ReleaseNotes implements Recipe { prefix: emptyGitLogMessage, text: emptyGitLogMessage, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } @@ -91,7 +92,8 @@ export class ReleaseNotes implements Recipe { prefix: assistantResponsePrefix, text: assistantResponsePrefix, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/generate-test.ts b/client/cody-shared/src/chat/recipes/generate-test.ts index af1aa0dc332..20cc6308e04 100644 --- a/client/cody-shared/src/chat/recipes/generate-test.ts +++ b/client/cody-shared/src/chat/recipes/generate-test.ts @@ -44,7 +44,8 @@ export class GenerateTest implements Recipe { truncatedFollowingText, selection, context.codebaseContext - ) + ), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/git-log.ts b/client/cody-shared/src/chat/recipes/git-log.ts index c79b2108e01..fe03311c3b8 100644 --- a/client/cody-shared/src/chat/recipes/git-log.ts +++ b/client/cody-shared/src/chat/recipes/git-log.ts @@ -65,7 +65,8 @@ export class GitHistory implements Recipe { prefix: emptyGitLogMessage, text: emptyGitLogMessage, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } @@ -84,7 +85,8 @@ export class GitHistory implements Recipe { prefix: assistantResponsePrefix, text: assistantResponsePrefix, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/improve-variable-names.ts b/client/cody-shared/src/chat/recipes/improve-variable-names.ts index 91c80a0a8e5..2cdd9d4a142 100644 --- a/client/cody-shared/src/chat/recipes/improve-variable-names.ts +++ b/client/cody-shared/src/chat/recipes/improve-variable-names.ts @@ -44,7 +44,8 @@ export class ImproveVariableNames implements Recipe { truncatedFollowingText, selection, context.codebaseContext - ) + ), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/inline-chat.ts b/client/cody-shared/src/chat/recipes/inline-chat.ts index cdaeba1a57e..6bac1bab589 100644 --- a/client/cody-shared/src/chat/recipes/inline-chat.ts +++ b/client/cody-shared/src/chat/recipes/inline-chat.ts @@ -56,7 +56,8 @@ export class InlineAssist implements Recipe { displayText, }, { speaker: 'assistant' }, - this.getContextMessages(truncatedText, context.codebaseContext, selection, context.editor) + this.getContextMessages(truncatedText, context.codebaseContext, selection, context.editor), + [] ) ) } diff --git a/client/cody-shared/src/chat/recipes/next-questions.ts b/client/cody-shared/src/chat/recipes/next-questions.ts index 5199017754c..b4d3f9deaa1 100644 --- a/client/cody-shared/src/chat/recipes/next-questions.ts +++ b/client/cody-shared/src/chat/recipes/next-questions.ts @@ -31,7 +31,8 @@ export class NextQuestions implements Recipe { prefix: assistantResponsePrefix, text: assistantResponsePrefix, }, - this.getContextMessages(promptMessage, context.editor, context.intentDetector, context.codebaseContext) + this.getContextMessages(promptMessage, context.editor, context.intentDetector, context.codebaseContext), + [] ) ) } diff --git a/client/cody-shared/src/chat/recipes/non-stop.ts b/client/cody-shared/src/chat/recipes/non-stop.ts index fe8dbb6d15e..86dc42a16e9 100644 --- a/client/cody-shared/src/chat/recipes/non-stop.ts +++ b/client/cody-shared/src/chat/recipes/non-stop.ts @@ -76,7 +76,8 @@ export class NonStop implements Recipe { speaker: 'assistant', prefix: 'Check your document for updates from Cody.', }, - this.getContextMessages(selection.selectedText, context.codebaseContext) + this.getContextMessages(selection.selectedText, context.codebaseContext), + [] ) ) } diff --git a/client/cody-shared/src/chat/recipes/optimize-code.ts b/client/cody-shared/src/chat/recipes/optimize-code.ts index 474400ada0b..90309f91a49 100644 --- a/client/cody-shared/src/chat/recipes/optimize-code.ts +++ b/client/cody-shared/src/chat/recipes/optimize-code.ts @@ -51,7 +51,8 @@ However if no optimization is possible; just say the code is already optimized. truncatedFollowingText, selection, context.codebaseContext - ) + ), + [] ) } } diff --git a/client/cody-shared/src/chat/recipes/translate.ts b/client/cody-shared/src/chat/recipes/translate.ts index c3950dd3578..77713b2ff87 100644 --- a/client/cody-shared/src/chat/recipes/translate.ts +++ b/client/cody-shared/src/chat/recipes/translate.ts @@ -39,7 +39,8 @@ export class TranslateToLanguage implements Recipe { prefix: assistantResponsePrefix, text: assistantResponsePrefix, }, - Promise.resolve([]) + Promise.resolve([]), + [] ) } } diff --git a/client/cody-shared/src/chat/transcript/index.ts b/client/cody-shared/src/chat/transcript/index.ts index 4e965329f86..afc11430a61 100644 --- a/client/cody-shared/src/chat/transcript/index.ts +++ b/client/cody-shared/src/chat/transcript/index.ts @@ -1,5 +1,6 @@ -import { OldContextMessage } from '../../codebase-context/messages' +import { ContextFile, ContextMessage, OldContextMessage } from '../../codebase-context/messages' import { CHARS_PER_TOKEN, MAX_AVAILABLE_PROMPT_LENGTH } from '../../prompt/constants' +import { PromptMixin } from '../../prompt/prompt-mixin' import { Message } from '../../sourcegraph-api' import { Interaction, InteractionJSON } from './interaction' @@ -20,18 +21,19 @@ export interface TranscriptJSON { } /** - * A transcript of a conversation between a human and an assistant. + * The "model" class that tracks the call and response of the Cody chat box. + * Any "controller" logic belongs outside of this class. */ export class Transcript { public static fromJSON(json: TranscriptJSON): Transcript { return new Transcript( json.interactions.map( - ({ humanMessage, assistantMessage, context, timestamp }) => + ({ humanMessage, assistantMessage, fullContext, usedContextFiles, timestamp }) => new Interaction( humanMessage, assistantMessage, Promise.resolve( - context.map(message => { + fullContext.map(message => { if (message.file) { return message } @@ -44,6 +46,7 @@ export class Transcript { return message }) ), + usedContextFiles, timestamp || new Date().toISOString() ) ), @@ -144,20 +147,53 @@ export class Transcript { return -1 } - public async toPrompt(preamble: Message[] = []): Promise { + public async getPromptForLastInteraction( + preamble: Message[] = [] + ): Promise<{ prompt: Message[]; contextFiles: ContextFile[] }> { + if (this.interactions.length === 0) { + return { prompt: [], contextFiles: [] } + } + const lastInteractionWithContextIndex = await this.getLastInteractionWithContextIndex() const messages: Message[] = [] for (let index = 0; index < this.interactions.length; index++) { - // Include context messages for the last interaction that has a non-empty context. - const interactionMessages = await this.interactions[index].toPrompt( - index === lastInteractionWithContextIndex - ) - messages.push(...interactionMessages) + const interaction = this.interactions[index] + const humanMessage = PromptMixin.mixInto(interaction.getHumanMessage()) + const assistantMessage = interaction.getAssistantMessage() + const contextMessages = await interaction.getFullContext() + if (index === lastInteractionWithContextIndex) { + messages.push(...contextMessages, humanMessage, assistantMessage) + } else { + messages.push(humanMessage, assistantMessage) + } } const preambleTokensUsage = preamble.reduce((acc, message) => acc + estimateTokensUsage(message), 0) - const truncatedMessages = truncatePrompt(messages, MAX_AVAILABLE_PROMPT_LENGTH - preambleTokensUsage) - return [...preamble, ...truncatedMessages] + let truncatedMessages = truncatePrompt(messages, MAX_AVAILABLE_PROMPT_LENGTH - preambleTokensUsage) + + // Return what context fits in the window + const contextFiles: ContextFile[] = [] + for (const msg of truncatedMessages) { + const contextFile = (msg as ContextMessage).file + if (contextFile) { + contextFiles.push(contextFile) + } + } + + // Filter out extraneous fields from ContextMessage instances + truncatedMessages = truncatedMessages.map(({ speaker, text }) => ({ speaker, text })) + + return { + prompt: [...preamble, ...truncatedMessages], + contextFiles, + } + } + + public setUsedContextFilesForLastInteraction(contextFiles: ContextFile[]): void { + if (this.interactions.length === 0) { + throw new Error('Cannot set context files for empty transcript') + } + this.interactions[this.interactions.length - 1].setUsedContext(contextFiles) } public toChat(): ChatMessage[] { diff --git a/client/cody-shared/src/chat/transcript/interaction.ts b/client/cody-shared/src/chat/transcript/interaction.ts index 75b0edfd1de..7376fe9b470 100644 --- a/client/cody-shared/src/chat/transcript/interaction.ts +++ b/client/cody-shared/src/chat/transcript/interaction.ts @@ -1,89 +1,59 @@ import { ContextMessage, ContextFile } from '../../codebase-context/messages' -import { PromptMixin } from '../../prompt/prompt-mixin' -import { Message } from '../../sourcegraph-api' import { ChatMessage, InteractionMessage } from './messages' export interface InteractionJSON { humanMessage: InteractionMessage assistantMessage: InteractionMessage - context: ContextMessage[] + fullContext: ContextMessage[] + usedContextFiles: ContextFile[] timestamp: string } export class Interaction { - private readonly humanMessage: InteractionMessage - private assistantMessage: InteractionMessage - private cachedContextFiles: ContextFile[] = [] - public readonly timestamp: string - private readonly context: Promise - constructor( - humanMessage: InteractionMessage, - assistantMessage: InteractionMessage, - context: Promise, - timestamp: string = new Date().toISOString() - ) { - this.humanMessage = humanMessage - this.assistantMessage = assistantMessage - this.timestamp = timestamp - - // This is some hacky behavior: returns a promise that resolves to the same array that was passed, - // but also caches the context file names in memory as a side effect. - this.context = context.then(messages => { - const contextFilesMap = messages.reduce((map, { file }) => { - if (!file?.fileName) { - return map - } - map[`${file.repoName || 'repo'}@${file?.revision || 'HEAD'}/${file.fileName}`] = file - return map - }, {} as { [key: string]: ContextFile }) - - // Cache the context files so we don't have to block the UI when calling `toChat` by waiting for the context to resolve. - this.cachedContextFiles = [ - ...Object.keys(contextFilesMap) - .sort((a, b) => a.localeCompare(b)) - .map((key: string) => contextFilesMap[key]), - ] - - return messages - }) - } + private readonly humanMessage: InteractionMessage, + private assistantMessage: InteractionMessage, + private fullContext: Promise, + private usedContextFiles: ContextFile[], + public readonly timestamp: string = new Date().toISOString() + ) {} public getAssistantMessage(): InteractionMessage { - return this.assistantMessage + return { ...this.assistantMessage } } public setAssistantMessage(assistantMessage: InteractionMessage): void { this.assistantMessage = assistantMessage } + public getHumanMessage(): InteractionMessage { + return { ...this.humanMessage } + } + + public async getFullContext(): Promise { + const msgs = await this.fullContext + return msgs.map(msg => ({ ...msg })) + } + public async hasContext(): Promise { - const contextMessages = await this.context + const contextMessages = await this.fullContext return contextMessages.length > 0 } - public async toPrompt(includeContext: boolean): Promise { - const messages: (ContextMessage | InteractionMessage)[] = [ - PromptMixin.mixInto(this.humanMessage), - this.assistantMessage, - ] - if (includeContext) { - messages.unshift(...(await this.context)) - } - - return messages.map(message => ({ speaker: message.speaker, text: message.text })) + public setUsedContext(usedContextFiles: ContextFile[]): void { + this.usedContextFiles = usedContextFiles } /** * Converts the interaction to chat message pair: one message from a human, one from an assistant. */ public toChat(): ChatMessage[] { - return [this.humanMessage, { ...this.assistantMessage, contextFiles: this.cachedContextFiles }] + return [this.humanMessage, { ...this.assistantMessage, contextFiles: this.usedContextFiles }] } public async toChatPromise(): Promise { - await this.context + await this.fullContext return this.toChat() } @@ -91,7 +61,8 @@ export class Interaction { return { humanMessage: this.humanMessage, assistantMessage: this.assistantMessage, - context: await this.context, + fullContext: await this.fullContext, + usedContextFiles: this.usedContextFiles, timestamp: this.timestamp, } } diff --git a/client/cody-shared/src/chat/transcript/transcript.test.ts b/client/cody-shared/src/chat/transcript/transcript.test.ts index d36166a28e3..5653cfd6401 100644 --- a/client/cody-shared/src/chat/transcript/transcript.test.ts +++ b/client/cody-shared/src/chat/transcript/transcript.test.ts @@ -38,7 +38,7 @@ async function generateLongTranscript(): Promise<{ transcript: Transcript; token describe('Transcript', () => { it('generates an empty prompt with no interactions', async () => { const transcript = new Transcript() - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() assert.deepStrictEqual(prompt, []) }) @@ -51,7 +51,7 @@ describe('Transcript', () => { const transcript = new Transcript() transcript.addInteraction(interaction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'how do access tokens work in sourcegraph' }, { speaker: 'assistant', text: undefined }, @@ -78,7 +78,8 @@ describe('Transcript', () => { { useContext: 'embeddings', serverEndpoint: 'https://example.com' }, 'dummy-codebase', embeddings, - defaultKeywordContextFetcher + defaultKeywordContextFetcher, + null ), }) ) @@ -86,7 +87,7 @@ describe('Transcript', () => { const transcript = new Transcript() transcript.addInteraction(interaction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' }, { speaker: 'assistant', text: 'Ok.' }, @@ -114,7 +115,8 @@ describe('Transcript', () => { { useContext: 'embeddings', serverEndpoint: 'https://example.com' }, 'dummy-codebase', embeddings, - defaultKeywordContextFetcher + defaultKeywordContextFetcher, + null ), firstInteraction: true, }) @@ -123,7 +125,7 @@ describe('Transcript', () => { const transcript = new Transcript() transcript.addInteraction(interaction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' }, { speaker: 'assistant', text: 'Ok.' }, @@ -148,7 +150,8 @@ describe('Transcript', () => { { useContext: 'embeddings', serverEndpoint: 'https://example.com' }, 'dummy-codebase', embeddings, - defaultKeywordContextFetcher + defaultKeywordContextFetcher, + null ) const chatQuestionRecipe = new ChatQuestion(() => {}) @@ -175,7 +178,7 @@ describe('Transcript', () => { ) transcript.addInteraction(secondInteraction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'how do access tokens work in sourcegraph' }, { speaker: 'assistant', text: assistantResponse }, @@ -195,7 +198,7 @@ describe('Transcript', () => { const numExpectedInteractions = Math.floor(MAX_AVAILABLE_PROMPT_LENGTH / tokensPerInteraction) const numExpectedMessages = numExpectedInteractions * 2 // Each interaction has two messages. - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() assert.deepStrictEqual(prompt.length, numExpectedMessages) }) @@ -212,7 +215,7 @@ describe('Transcript', () => { const numExpectedInteractions = Math.floor(MAX_AVAILABLE_PROMPT_LENGTH / tokensPerInteraction) const numExpectedMessages = numExpectedInteractions * 2 // Each interaction has two messages. - const prompt = await transcript.toPrompt(preamble) + const { prompt } = await transcript.getPromptForLastInteraction(preamble) assert.deepStrictEqual(prompt.length, numExpectedMessages) assert.deepStrictEqual(preamble, prompt.slice(0, 4)) }) @@ -233,7 +236,8 @@ describe('Transcript', () => { { useContext: 'embeddings', serverEndpoint: 'https://example.com' }, 'dummy-codebase', embeddings, - defaultKeywordContextFetcher + defaultKeywordContextFetcher, + null ) const chatQuestionRecipe = new ChatQuestion(() => {}) @@ -249,7 +253,7 @@ describe('Transcript', () => { ) transcript.addInteraction(interaction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' }, { speaker: 'assistant', text: 'Ok.' }, @@ -288,7 +292,7 @@ describe('Transcript', () => { ) transcript.addInteraction(interaction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'how do access tokens work in sourcegraph' }, { speaker: 'assistant', text: undefined }, @@ -309,7 +313,8 @@ describe('Transcript', () => { { useContext: 'embeddings', serverEndpoint: 'https://example.com' }, 'dummy-codebase', embeddings, - defaultKeywordContextFetcher + defaultKeywordContextFetcher, + null ) const chatQuestionRecipe = new ChatQuestion(() => {}) @@ -344,7 +349,7 @@ describe('Transcript', () => { ) transcript.addInteraction(thirdInteraction) - const prompt = await transcript.toPrompt() + const { prompt } = await transcript.getPromptForLastInteraction() const expectedPrompt = [ { speaker: 'human', text: 'how do batch changes work in sourcegraph' }, { speaker: 'assistant', text: 'Smartly.' }, diff --git a/client/cody-shared/src/chat/useClient.ts b/client/cody-shared/src/chat/useClient.ts index ac4a138f6dc..cfb4333d6ef 100644 --- a/client/cody-shared/src/chat/useClient.ts +++ b/client/cody-shared/src/chat/useClient.ts @@ -222,7 +222,14 @@ export const useClient = ({ } const unifiedContextFetcherClient = new UnifiedContextFetcherClient(graphqlClient, repoIds) - const codebaseContext = new CodebaseContext(config, undefined, null, null, unifiedContextFetcherClient) + const codebaseContext = new CodebaseContext( + config, + undefined, + null, + null, + null, + unifiedContextFetcherClient + ) const { humanChatInput = '', prefilledOptions } = options ?? {} // TODO(naman): save scope with each interaction @@ -242,10 +249,13 @@ export const useClient = ({ setIsMessageInProgressState(true) onEvent?.('submit') - const prompt = await transcript.toPrompt(getMultiRepoPreamble(repoNames)) + const { prompt, contextFiles } = await transcript.getPromptForLastInteraction( + getMultiRepoPreamble(repoNames) + ) + transcript.setUsedContextFilesForLastInteraction(contextFiles) + const responsePrefix = interaction.getAssistantMessage().prefix ?? '' let rawText = '' - return new Promise(resolve => { chatClient.chat(prompt, { onChange(_rawText) { diff --git a/client/cody-shared/src/codebase-context/index.ts b/client/cody-shared/src/codebase-context/index.ts index 195b5ae6b0a..f571a410f63 100644 --- a/client/cody-shared/src/codebase-context/index.ts +++ b/client/cody-shared/src/codebase-context/index.ts @@ -1,6 +1,6 @@ import { Configuration } from '../configuration' import { EmbeddingsSearch } from '../embeddings' -import { KeywordContextFetcher, KeywordContextFetcherResult } from '../keyword-context' +import { FilenameContextFetcher, KeywordContextFetcher, ContextResult } from '../local-context' import { isMarkdownFile, populateCodeContextTemplate, populateMarkdownContextTemplate } from '../prompt/templates' import { Message } from '../sourcegraph-api' import { EmbeddingsSearchResult } from '../sourcegraph-api/graphql/client' @@ -21,7 +21,9 @@ export class CodebaseContext { private codebase: string | undefined, private embeddings: EmbeddingsSearch | null, private keywords: KeywordContextFetcher | null, - private unifiedContextFetcher?: UnifiedContextFetcher | null + private filenames: FilenameContextFetcher | null, + private unifiedContextFetcher?: UnifiedContextFetcher | null, + private rerank?: (query: string, results: ContextResult[]) => Promise ) {} public getCodebase(): string | undefined { @@ -32,18 +34,28 @@ export class CodebaseContext { this.config = newConfig } + private mergeContextResults(keywordResults: ContextResult[], filenameResults: ContextResult[]): ContextResult[] { + // Just take the single most relevant filename suggestion for now. Otherwise, because our reranking relies solely + // on filename, the filename results would dominate the keyword results. + return filenameResults.slice(-1).concat(keywordResults) + } + + /** + * Returns list of context messages for a given query, sorted in *reverse* order of importance (that is, + * the most important context message appears *last*) + */ public async getContextMessages(query: string, options: ContextSearchOptions): Promise { switch (this.config.useContext) { case 'unified': return this.getUnifiedContextMessages(query, options) case 'keyword': - return this.getKeywordContextMessages(query, options) + return this.getLocalContextMessages(query, options) case 'none': return [] default: return this.embeddings ? this.getEmbeddingsContextMessages(query, options) - : this.getKeywordContextMessages(query, options) + : this.getLocalContextMessages(query, options) } } @@ -58,7 +70,7 @@ export class CodebaseContext { public async getSearchResults( query: string, options: ContextSearchOptions - ): Promise<{ results: KeywordContextFetcherResult[] | EmbeddingsSearchResult[]; endpoint: string }> { + ): Promise<{ results: ContextResult[] | EmbeddingsSearchResult[]; endpoint: string }> { if (this.embeddings && this.config.useContext !== 'keyword') { return { results: await this.getEmbeddingSearchResults(query, options), @@ -141,22 +153,29 @@ export class CodebaseContext { }) } - private async getKeywordContextMessages(query: string, options: ContextSearchOptions): Promise { - const results = await this.getKeywordSearchResults(query, options) - return results.flatMap(({ content, fileName, repoName, revision }) => { - const messageText = populateCodeContextTemplate(content, fileName) - return getContextMessageWithResponse(messageText, { fileName, repoName, revision }) - }) + private async getLocalContextMessages(query: string, options: ContextSearchOptions): Promise { + const keywordResults = this.getKeywordSearchResults(query, options) + const filenameResults = this.getFilenameSearchResults(query, options) + const combinedResults = this.mergeContextResults(await keywordResults, await filenameResults) + const rerankedResults = await (this.rerank ? this.rerank(query, combinedResults) : combinedResults) + const messages = resultsToMessages(rerankedResults) + return messages } - private async getKeywordSearchResults( - query: string, - options: ContextSearchOptions - ): Promise { + private async getKeywordSearchResults(query: string, options: ContextSearchOptions): Promise { if (!this.keywords) { return [] } - return this.keywords.getContext(query, options.numCodeResults + options.numTextResults) + const results = await this.keywords.getContext(query, options.numCodeResults + options.numTextResults) + return results + } + + private async getFilenameSearchResults(query: string, options: ContextSearchOptions): Promise { + if (!this.filenames) { + return [] + } + const results = await this.filenames.getContext(query, options.numCodeResults + options.numTextResults) + return results } } @@ -201,3 +220,10 @@ function mergeConsecutiveResults(results: EmbeddingsSearchResult[]): string[] { return mergedResults } + +function resultsToMessages(results: ContextResult[]): ContextMessage[] { + return results.flatMap(({ content, fileName, repoName, revision }) => { + const messageText = populateCodeContextTemplate(content, fileName) + return getContextMessageWithResponse(messageText, { fileName, repoName, revision }) + }) +} diff --git a/client/cody-shared/src/codebase-context/rerank.ts b/client/cody-shared/src/codebase-context/rerank.ts new file mode 100644 index 00000000000..2de8f33805c --- /dev/null +++ b/client/cody-shared/src/codebase-context/rerank.ts @@ -0,0 +1,93 @@ +import { parseStringPromise } from 'xml2js' + +import { ChatClient } from '../chat/chat' +import { ContextResult } from '../local-context' + +export interface Reranker { + rerank(userQuery: string, results: ContextResult[]): Promise +} + +export class MockReranker implements Reranker { + constructor(private rerank_: (userQuery: string, results: ContextResult[]) => Promise) {} + public rerank(userQuery: string, results: ContextResult[]): Promise { + return this.rerank_(userQuery, results) + } +} + +/** + * A reranker class that uses a LLM to boost high-relevance results. + */ +export class LLMReranker implements Reranker { + constructor(private chatClient: ChatClient) {} + public async rerank(userQuery: string, results: ContextResult[]): Promise { + // Reverse the results so the most important appears first + results = [...results].reverse() + + let out = await new Promise((resolve, reject) => { + let responseText = '' + this.chatClient.chat( + [ + { + speaker: 'human', + text: `I am a professional computer programmer and need help deciding which of these files to read first to answer my question. My question is ${userQuery}. Select the files from the following list that I should read to answer my question, ranked by most relevant first. Format the result as XML, like this: filename 1this is why I chose this itemfilename 2why I chose this item\n${results + .map(r => r.fileName) + .join('\n')}`, + }, + ], + { + onChange: (text: string) => { + responseText = text + }, + onComplete: () => { + resolve(responseText) + }, + onError: (message: string, statusCode?: number) => { + reject(new Error(`Status code ${statusCode}: ${message}`)) + }, + }, + { + temperature: 0, + fast: true, + } + ) + }) + if (out.indexOf('') > 0) { + out = out.slice(out.indexOf('')) + } + if (out.indexOf('') !== out.length - ''.length) { + out = out.slice(0, out.indexOf('') + ''.length) + } + const boostedFilenames = await parseXml(out) + + const resultsMap = Object.fromEntries(results.map(r => [r.fileName, r])) + const boostedNames = new Set() + const rerankedResults = [] + for (const boostedFilename of boostedFilenames) { + const boostedResult = resultsMap[boostedFilename] + if (!boostedResult) { + continue + } + rerankedResults.push(boostedResult) + boostedNames.add(boostedFilename) + } + for (const result of results) { + if (!boostedNames.has(result.fileName)) { + rerankedResults.push(result) + } + } + + rerankedResults.reverse() + return rerankedResults + } +} + +async function parseXml(xml: string): Promise { + const result = await parseStringPromise(xml) + const items = result.list.item + const files: { filename: string; explanation: string }[] = items.map((item: any) => ({ + filename: item.filename[0], + explanation: item.explanation[0], + })) + + return files.map(f => f.filename) +} diff --git a/client/cody-shared/src/keyword-context/index.ts b/client/cody-shared/src/keyword-context/index.ts deleted file mode 100644 index 35f8b3a3399..00000000000 --- a/client/cody-shared/src/keyword-context/index.ts +++ /dev/null @@ -1,11 +0,0 @@ -export interface KeywordContextFetcherResult { - repoName?: string - revision?: string - fileName: string - content: string -} - -export interface KeywordContextFetcher { - getContext(query: string, numResults: number): Promise - getSearchContext(query: string, numResults: number): Promise -} diff --git a/client/cody-shared/src/local-context/index.ts b/client/cody-shared/src/local-context/index.ts new file mode 100644 index 00000000000..c76b0c6157a --- /dev/null +++ b/client/cody-shared/src/local-context/index.ts @@ -0,0 +1,15 @@ +export interface ContextResult { + repoName?: string + revision?: string + fileName: string + content: string +} + +export interface KeywordContextFetcher { + getContext(query: string, numResults: number): Promise + getSearchContext(query: string, numResults: number): Promise +} + +export interface FilenameContextFetcher { + getContext(query: string, numResults: number): Promise +} diff --git a/client/cody-shared/src/sourcegraph-api/completions/types.ts b/client/cody-shared/src/sourcegraph-api/completions/types.ts index c992087d64e..c083e75138f 100644 --- a/client/cody-shared/src/sourcegraph-api/completions/types.ts +++ b/client/cody-shared/src/sourcegraph-api/completions/types.ts @@ -24,6 +24,7 @@ export interface CompletionResponse { } export interface CompletionParameters { + fast?: boolean messages: Message[] maxTokensToSample: number temperature?: number diff --git a/client/cody-shared/src/test/mocks.ts b/client/cody-shared/src/test/mocks.ts index 054c82b0204..474ba4d881d 100644 --- a/client/cody-shared/src/test/mocks.ts +++ b/client/cody-shared/src/test/mocks.ts @@ -4,7 +4,7 @@ import { CodebaseContext } from '../codebase-context' import { ActiveTextEditor, ActiveTextEditorSelection, ActiveTextEditorVisibleContent, Editor } from '../editor' import { EmbeddingsSearch } from '../embeddings' import { IntentDetector } from '../intent-detector' -import { KeywordContextFetcher, KeywordContextFetcherResult } from '../keyword-context' +import { KeywordContextFetcher, ContextResult } from '../local-context' import { EmbeddingsSearchResults } from '../sourcegraph-api/graphql' export class MockEmbeddingsClient implements EmbeddingsSearch { @@ -37,11 +37,11 @@ export class MockIntentDetector implements IntentDetector { export class MockKeywordContextFetcher implements KeywordContextFetcher { constructor(private mocks: Partial = {}) {} - public getContext(query: string, numResults: number): Promise { + public getContext(query: string, numResults: number): Promise { return this.mocks.getContext?.(query, numResults) ?? Promise.resolve([]) } - public getSearchContext(query: string, numResults: number): Promise { + public getSearchContext(query: string, numResults: number): Promise { return this.mocks.getSearchContext?.(query, numResults) ?? Promise.resolve([]) } } @@ -111,7 +111,8 @@ export function newRecipeContext(args?: Partial): RecipeContext { { useContext: 'none', serverEndpoint: 'https://example.com' }, 'dummy-codebase', defaultEmbeddingsClient, - defaultKeywordContextFetcher + defaultKeywordContextFetcher, + null ), responseMultiplexer: args.responseMultiplexer || new BotResponseMultiplexer(), firstInteraction: args.firstInteraction ?? false, diff --git a/client/cody-slack/src/mention-handler.ts b/client/cody-slack/src/mention-handler.ts index cb185ac4008..a4618953d40 100644 --- a/client/cody-slack/src/mention-handler.ts +++ b/client/cody-slack/src/mention-handler.ts @@ -44,7 +44,8 @@ export async function handleHumanMessage(event: AppMentionEvent, appContext: App const response = await slackHelpers.postMessage(IN_PROGRESS_MESSAGE, channel, thread_ts) // Generate a prompt and start completion streaming - const prompt = await transcript.toPrompt(SLACK_PREAMBLE) + const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(SLACK_PREAMBLE) + transcript.setUsedContextFilesForLastInteraction(contextFiles) console.log('PROMPT', prompt) startCompletionStreaming(prompt, channel, transcript, response?.ts) } diff --git a/client/cody-slack/src/services/codebase-context.ts b/client/cody-slack/src/services/codebase-context.ts index d2838353346..12b24942ef8 100644 --- a/client/cody-slack/src/services/codebase-context.ts +++ b/client/cody-slack/src/services/codebase-context.ts @@ -2,7 +2,7 @@ import { memoize } from 'lodash' import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context' import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client' -import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/keyword-context' +import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/local-context' import { isError } from '@sourcegraph/cody-shared/src/utils' import { sourcegraphClient } from './sourcegraph-client' @@ -36,7 +36,8 @@ export async function createCodebaseContext( { useContext: contextType, serverEndpoint }, codebase, embeddingsSearch, - new LocalKeywordContextFetcherMock() + new LocalKeywordContextFetcherMock(), + null ) return codebaseContext diff --git a/client/cody-slack/src/slack/message-interaction.ts b/client/cody-slack/src/slack/message-interaction.ts index 90f6f22f4be..deed6d8ab1d 100644 --- a/client/cody-slack/src/slack/message-interaction.ts +++ b/client/cody-slack/src/slack/message-interaction.ts @@ -56,7 +56,7 @@ class SlackInteraction { } public getTranscriptInteraction() { - return new Interaction(this.humanMessage, this.assistantMessage, Promise.resolve(this.contextMessages)) + return new Interaction(this.humanMessage, this.assistantMessage, Promise.resolve(this.contextMessages), []) } } diff --git a/client/cody/BUILD.bazel b/client/cody/BUILD.bazel index 21791c484a3..4c90ac5dd18 100644 --- a/client/cody/BUILD.bazel +++ b/client/cody/BUILD.bazel @@ -52,9 +52,11 @@ ts_project( "src/extension.ts", "src/extension-api.ts", "src/external-services.ts", - "src/keyword-context/local-keyword-context-fetcher.ts", "src/local-app-detector.ts", + "src/local-context/filename-context-fetcher.ts", + "src/local-context/local-keyword-context-fetcher.ts", "src/log.ts", + "src/logged-rerank.ts", "src/main.ts", "src/non-stop/FixupCodeLenses.ts", "src/non-stop/FixupContentStore.ts", @@ -111,6 +113,7 @@ ts_project( "//:node_modules/@storybook/react", #keep "//:node_modules/@types/classnames", "//:node_modules/@types/jest", #keep + "//:node_modules/@types/lodash", "//:node_modules/@types/lru-cache", "//:node_modules/@types/node", "//:node_modules/@types/react", @@ -121,6 +124,7 @@ ts_project( "//:node_modules/@vscode", "//:node_modules/@vscode/webview-ui-toolkit", "//:node_modules/classnames", + "//:node_modules/lodash", "//:node_modules/react", "//:node_modules/react-dom", "//:node_modules/stream-json", @@ -140,7 +144,6 @@ ts_project( "src/completions/context.test.ts", "src/completions/provider.test.ts", "src/configuration.test.ts", - "src/keyword-context/local-keyword-context-fetcher.test.ts", "src/non-stop/diff.test.ts", "src/non-stop/tracked-range.test.ts", "src/non-stop/utils.test.ts", diff --git a/client/cody/src/chat/ChatViewProvider.ts b/client/cody/src/chat/ChatViewProvider.ts index 64f5aac30b5..5a71c4827db 100644 --- a/client/cody/src/chat/ChatViewProvider.ts +++ b/client/cody/src/chat/ChatViewProvider.ts @@ -1,3 +1,4 @@ +import { spawnSync } from 'child_process' import path from 'path' import * as vscode from 'vscode' @@ -11,17 +12,24 @@ import { ChatMessage, ChatHistory } from '@sourcegraph/cody-shared/src/chat/tran import { reformatBotMessage } from '@sourcegraph/cody-shared/src/chat/viewHelpers' import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context' import { ConfigurationWithAccessToken } from '@sourcegraph/cody-shared/src/configuration' +import { Editor } from '@sourcegraph/cody-shared/src/editor' +import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client' import { Guardrails, annotateAttribution } from '@sourcegraph/cody-shared/src/guardrails' import { highlightTokens } from '@sourcegraph/cody-shared/src/hallucinations-detector' import { IntentDetector } from '@sourcegraph/cody-shared/src/intent-detector' import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api' +import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql' +import { isError } from '@sourcegraph/cody-shared/src/utils' import { View } from '../../webviews/NavBar' import { getFullConfig, updateConfiguration } from '../configuration' import { VSCodeEditor } from '../editor/vscode-editor' import { logEvent } from '../event-logger' import { LocalAppDetector } from '../local-app-detector' +import { FilenameContextFetcher } from '../local-context/filename-context-fetcher' +import { LocalKeywordContextFetcher } from '../local-context/local-keyword-context-fetcher' import { debug } from '../log' +import { getRerankWithLog } from '../logged-rerank' import { FixupTask } from '../non-stop/FixupTask' import { LocalStorage } from '../services/LocalStorageProvider' import { CODY_ACCESS_TOKEN_SECRET, SecretStorage } from '../services/SecretStorageProvider' @@ -38,7 +46,7 @@ import { isLoggedIn, } from './protocol' import { getRecipe } from './recipes' -import { getAuthStatus, getCodebaseContext } from './utils' +import { convertGitCloneURLToCodebaseName, getAuthStatus } from './utils' export type Config = Pick< ConfigurationWithAccessToken, @@ -107,7 +115,7 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp }), vscode.workspace.onDidChangeConfiguration(async () => { this.config = await getFullConfig(this.secretStorage) - const newCodebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor) + const newCodebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor, chat) if (newCodebaseContext) { this.codebaseContext = newCodebaseContext await this.setAnonymousUserID() @@ -376,7 +384,7 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp } this.currentWorkspaceRoot = workspaceRoot - const codebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor) + const codebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor, this.chat) if (!codebaseContext) { return } @@ -430,12 +438,14 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp default: { this.sendTranscript() - const prompt = await this.transcript.toPrompt(getPreamble(this.codebaseContext.getCodebase())) + const { prompt, contextFiles } = await this.transcript.getPromptForLastInteraction( + getPreamble(this.codebaseContext.getCodebase()) + ) + this.transcript.setUsedContextFilesForLastInteraction(contextFiles) this.sendPrompt(prompt, interaction.getAssistantMessage().prefix ?? '') await this.saveTranscriptToChatHistory() } } - logEvent(`CodyVSCodeExtension:recipe:${recipe.id}:executed`) } @@ -460,7 +470,10 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp } transcript.addInteraction(interaction) - const prompt = await transcript.toPrompt(getPreamble(this.codebaseContext.getCodebase())) + const { prompt, contextFiles } = await transcript.getPromptForLastInteraction( + getPreamble(this.codebaseContext.getCodebase()) + ) + transcript.setUsedContextFilesForLastInteraction(contextFiles) logEvent(`CodyVSCodeExtension:recipe:${recipe.id}:executed`) @@ -818,3 +831,49 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp this.disposables = [] } } + +/** + * Gets codebase context for the current workspace. + * + * @param config Cody configuration + * @param rgPath Path to rg (ripgrep) executable + * @param editor Editor instance + * @returns CodebaseContext if a codebase can be determined, else null + */ +export async function getCodebaseContext( + config: Config, + rgPath: string, + editor: Editor, + chatClient: ChatClient +): Promise { + const client = new SourcegraphGraphQLAPIClient(config) + const workspaceRoot = editor.getWorkspaceRootPath() + if (!workspaceRoot) { + return null + } + const gitCommand = spawnSync('git', ['remote', 'get-url', 'origin'], { cwd: workspaceRoot }) + const gitOutput = gitCommand.stdout.toString().trim() + // Get codebase from config or fallback to getting repository name from git clone URL + const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput) + if (!codebase) { + return null + } + // Check if repo is embedded in endpoint + const repoId = await client.getRepoIdIfEmbeddingExists(codebase) + if (isError(repoId)) { + const infoMessage = `Cody could not find embeddings for '${codebase}' on your Sourcegraph instance.\n` + console.info(infoMessage) + return null + } + + const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null + return new CodebaseContext( + config, + codebase, + embeddingsSearch, + new LocalKeywordContextFetcher(rgPath, editor, chatClient), + new FilenameContextFetcher(rgPath, editor, chatClient), + undefined, + getRerankWithLog(chatClient) + ) +} diff --git a/client/cody/src/chat/fastFileFinder.ts b/client/cody/src/chat/fastFileFinder.ts index 33753e98877..ccccf0af491 100644 --- a/client/cody/src/chat/fastFileFinder.ts +++ b/client/cody/src/chat/fastFileFinder.ts @@ -12,10 +12,11 @@ import * as path from 'path' export async function fastFilesExist( rgPath: string, rootPath: string, - filePaths: string[] + filePaths: string[], + maxDepth?: number ): Promise<{ [filePath: string]: boolean }> { const searchPattern = constructSearchPattern(filePaths) - const rgOutput = await executeRg(rgPath, rootPath, searchPattern) + const rgOutput = await executeRg(rgPath, rootPath, searchPattern, maxDepth) return processRgOutput(rgOutput, filePaths) } @@ -28,6 +29,7 @@ export function makeTrimRegex(sep: string): RegExp { // Regex to match '**', '*' or path.sep at the start (^) or end ($) of the string. const trimRegex = makeTrimRegex(path.sep) + /** * Constructs a search pattern for the 'rg' tool. * @@ -42,6 +44,7 @@ function constructSearchPattern(filePaths: string[]): string { }) return `{${searchPatternParts.join(',')}}` } + /** * Executes the 'rg' tool and returns the output. * @@ -50,11 +53,15 @@ function constructSearchPattern(filePaths: string[]): string { * @param searchPattern - The search pattern to use. * @returns The output from the 'rg' tool. */ -async function executeRg(rgPath: string, rootPath: string, searchPattern: string): Promise { +async function executeRg(rgPath: string, rootPath: string, searchPattern: string, maxDepth?: number): Promise { + const args = ['--files', '-g', searchPattern, '--crlf', '--fixed-strings', '--no-config', '--no-ignore-global'] + if (maxDepth !== undefined) { + args.push('--max-depth', `${maxDepth}`) + } return new Promise((resolve, reject) => { execFile( rgPath, - ['--files', '-g', searchPattern, '--crlf', '--fixed-strings', '--no-config', '--no-ignore-global'], + args, { cwd: rootPath, maxBuffer: 1024 * 1024 * 1024, diff --git a/client/cody/src/chat/utils.ts b/client/cody/src/chat/utils.ts index 9cead6bdba6..d3f42602cc6 100644 --- a/client/cody/src/chat/utils.ts +++ b/client/cody/src/chat/utils.ts @@ -1,15 +1,7 @@ -import { spawnSync } from 'child_process' - -import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context' import { ConfigurationWithAccessToken } from '@sourcegraph/cody-shared/src/configuration' -import { Editor } from '@sourcegraph/cody-shared/src/editor' -import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client' import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql' import { isError } from '@sourcegraph/cody-shared/src/utils' -import { LocalKeywordContextFetcher } from '../keyword-context/local-keyword-context-fetcher' - -import { Config } from './ChatViewProvider' import { AuthStatus, defaultAuthStatus, isLocalApp, unauthenticatedStatus } from './protocol' // Converts a git clone URL to the codebase name that includes the slash-separated code host, owner, and repository name @@ -50,40 +42,11 @@ export function convertGitCloneURLToCodebaseName(cloneURL: string): string | nul } return null } catch (error) { - console.log(`Cody could not extract repo name from clone URL ${cloneURL}:`, error) + console.error(`Cody could not extract repo name from clone URL ${cloneURL}:`, error) return null } } -export async function getCodebaseContext( - config: Config, - rgPath: string, - editor: Editor -): Promise { - const client = new SourcegraphGraphQLAPIClient(config) - const workspaceRoot = editor.getWorkspaceRootPath() - if (!workspaceRoot) { - return null - } - const gitCommand = spawnSync('git', ['remote', 'get-url', 'origin'], { cwd: workspaceRoot }) - const gitOutput = gitCommand.stdout.toString().trim() - // Get codebase from config or fallback to getting repository name from git clone URL - const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput) - if (!codebase) { - return null - } - // Check if repo is embedded in endpoint - const repoId = await client.getRepoIdIfEmbeddingExists(codebase) - if (isError(repoId)) { - const infoMessage = `Cody could not find embeddings for '${codebase}' on your Sourcegraph instance.\n` - console.info(infoMessage) - return null - } - - const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null - return new CodebaseContext(config, codebase, embeddingsSearch, new LocalKeywordContextFetcher(rgPath, editor)) -} - let client: SourcegraphGraphQLAPIClient let configWithToken: Pick diff --git a/client/cody/src/external-services.ts b/client/cody/src/external-services.ts index 943cca3d48a..00fafd883f0 100644 --- a/client/cody/src/external-services.ts +++ b/client/cody/src/external-services.ts @@ -12,8 +12,10 @@ import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/s import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql' import { isError } from '@sourcegraph/cody-shared/src/utils' -import { LocalKeywordContextFetcher } from './keyword-context/local-keyword-context-fetcher' +import { FilenameContextFetcher } from './local-context/filename-context-fetcher' +import { LocalKeywordContextFetcher } from './local-context/local-keyword-context-fetcher' import { logger } from './log' +import { getRerankWithLog } from './logged-rerank' interface ExternalServices { intentDetector: IntentDetector @@ -48,11 +50,15 @@ export async function configureExternalServices( } const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null + const chatClient = new ChatClient(completions) const codebaseContext = new CodebaseContext( initialConfig, initialConfig.codebase, embeddingsSearch, - new LocalKeywordContextFetcher(rgPath, editor) + new LocalKeywordContextFetcher(rgPath, editor, chatClient), + new FilenameContextFetcher(rgPath, editor, chatClient), + undefined, + getRerankWithLog(chatClient) ) const guardrails = new SourcegraphGuardrailsClient(client) @@ -60,7 +66,7 @@ export async function configureExternalServices( return { intentDetector: new SourcegraphIntentDetectorClient(client), codebaseContext, - chatClient: new ChatClient(completions), + chatClient, completionsClient: completions, guardrails, onConfigurationChange: newConfig => { diff --git a/client/cody/src/keyword-context/local-keyword-context-fetcher.test.ts b/client/cody/src/keyword-context/local-keyword-context-fetcher.test.ts deleted file mode 100644 index a578777719d..00000000000 --- a/client/cody/src/keyword-context/local-keyword-context-fetcher.test.ts +++ /dev/null @@ -1,153 +0,0 @@ -import * as assert from 'assert' - -import { Term, regexForTerms, userQueryToKeywordQuery } from './local-keyword-context-fetcher' - -describe('keyword context', () => { - it('userQueryToKeywordQuery', () => { - const cases: { query: string; expected: Term[] }[] = [ - { - query: 'Where is auth in Sourcegraph?', - expected: [ - { - count: 1, - originals: ['Where', 'Where'], - prefix: 'where', - stem: 'where', - }, - { - count: 1, - originals: ['auth', 'auth'], - prefix: 'auth', - stem: 'auth', - }, - { - count: 1, - originals: ['Sourcegraph', 'Sourcegraph'], - prefix: 'sourcegraph', - stem: 'sourcegraph', - }, - ], - }, - { - query: `Explain the following code at a high level: -uint32_t PackUInt32(const Color& color) { - uint32_t result = 0; - result |= static_cast(color.r * 255 + 0.5f) << 24; - result |= static_cast(color.g * 255 + 0.5f) << 16; - result |= static_cast(color.b * 255 + 0.5f) << 8; - result |= static_cast(color.a * 255 + 0.5f); - return result; -} -`, - expected: [ - { - count: 1, - originals: ['Explain', 'Explain'], - prefix: 'explain', - stem: 'explain', - }, - { - count: 1, - originals: ['following', 'following'], - prefix: 'follow', - stem: 'follow', - }, - { - count: 1, - originals: ['code', 'code'], - prefix: 'code', - stem: 'code', - }, - { - count: 1, - originals: ['high', 'high'], - prefix: 'high', - stem: 'high', - }, - { - count: 1, - originals: ['level', 'level'], - prefix: 'level', - stem: 'level', - }, - { - count: 6, - originals: ['uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t'], - prefix: 'uint', - stem: 'uinty2_t', - }, - { - count: 1, - originals: ['PackUInt32', 'PackUInt32'], - prefix: 'packuint', - stem: 'packuinty2', - }, - { - count: 1, - originals: ['const', 'const'], - prefix: 'const', - stem: 'const', - }, - { - count: 6, - originals: ['Color', 'Color', 'color', 'color', 'color', 'color', 'color'], - prefix: 'color', - stem: 'color', - }, - { - count: 6, - originals: ['result', 'result', 'result', 'result', 'result', 'result', 'result'], - prefix: 'result', - stem: 'result', - }, - { - count: 4, - originals: ['static_cast', 'static_cast', 'static_cast', 'static_cast', 'static_cast'], - prefix: 'static_cast', - stem: 'static_cast', - }, - { - count: 4, - originals: ['255', '255', '255', '255', '255'], - prefix: '255', - stem: '255', - }, - { - count: 1, - originals: ['return', 'return'], - prefix: 'return', - stem: 'return', - }, - ], - }, - ] - for (const testcase of cases) { - const actual = userQueryToKeywordQuery(testcase.query) - assert.deepStrictEqual(actual, testcase.expected) - } - }) - it('query to regex', () => { - const trials: { - userQuery: string - expRegex: string - }[] = [ - { - userQuery: 'Where is auth in Sourcegraph?', - expRegex: '(?:where|auth|sourcegraph)', - }, - { - userQuery: 'saml auth handler', - expRegex: '(?:saml|auth|handler)', - }, - { - userQuery: 'Where is the HTTP middleware defined in this codebase?', - expRegex: '(?:where|http|middlewar|defin|codebas)', - }, - ] - for (const trial of trials) { - const terms = userQueryToKeywordQuery(trial.userQuery) - const regex = regexForTerms(...terms) - expect(regex).toEqual(trial.expRegex) - } - }) -}) diff --git a/client/cody/src/local-context/filename-context-fetcher.ts b/client/cody/src/local-context/filename-context-fetcher.ts new file mode 100644 index 00000000000..b6550618685 --- /dev/null +++ b/client/cody/src/local-context/filename-context-fetcher.ts @@ -0,0 +1,154 @@ +import { execFile } from 'child_process' +import * as path from 'path' + +import { uniq } from 'lodash' +import * as vscode from 'vscode' + +import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat' +import { Editor } from '@sourcegraph/cody-shared/src/editor' +import { ContextResult } from '@sourcegraph/cody-shared/src/local-context' + +import { debug } from '../log' + +/** + * A local context fetcher that uses a LLM to generate filename fragments, which are then used to + * find files that are relevant based on their path or name. + */ +export class FilenameContextFetcher { + constructor(private rgPath: string, private editor: Editor, private chatClient: ChatClient) {} + + /** + * Returns pieces of context relevant for the given query. Uses a filename search approach + * @param query user query + * @param numResults the number of context results to return + * @returns a list of context results, sorted in *reverse* order (that is, + * the most important result appears at the bottom) + */ + public async getContext(query: string, numResults: number): Promise { + const time0 = performance.now() + + const rootPath = this.editor.getWorkspaceRootPath() + if (!rootPath) { + return [] + } + const time1 = performance.now() + const filenameFragments = await this.queryToFileFragments(query) + const time2 = performance.now() + const unsortedMatchingFiles = await this.getFilenames(rootPath, filenameFragments, 3) + const time3 = performance.now() + + const specialFragments = ['readme'] + const allBoostedFiles = [] + let remainingFiles = unsortedMatchingFiles + let nextRemainingFiles = [] + for (const specialFragment of specialFragments) { + const boostedFiles = [] + for (const fileName of remainingFiles) { + const fileNameLower = fileName.toLocaleLowerCase() + if (fileNameLower.includes(specialFragment)) { + boostedFiles.push(fileName) + } else { + nextRemainingFiles.push(fileName) + } + } + remainingFiles = nextRemainingFiles + nextRemainingFiles = [] + allBoostedFiles.push(...boostedFiles.sort((a, b) => a.length - b.length)) + } + + const sortedMatchingFiles = allBoostedFiles.concat(remainingFiles).slice(0, numResults) + + const results = await Promise.all( + sortedMatchingFiles + .map(async fileName => { + const uri = vscode.Uri.file(path.join(rootPath, fileName)) + const content = (await vscode.workspace.openTextDocument(uri)).getText() + return { + fileName, + content, + } + }) + .reverse() + ) + + const time4 = performance.now() + debug( + 'FilenameContextFetcher:getContext', + JSON.stringify({ + duration: time4 - time0, + queryToFileFragments: { duration: time2 - time1, fragments: filenameFragments }, + getFilenames: { duration: time3 - time2 }, + }), + { verbose: { matchingFiles: unsortedMatchingFiles, results: results.map(r => r.fileName) } } + ) + + return results + } + + private async queryToFileFragments(query: string): Promise { + const filenameFragments = await new Promise((resolve, reject) => { + let responseText = '' + this.chatClient.chat( + [ + { + speaker: 'human', + text: `Write 3 filename fragments that would be contained by files in a git repository that are relevant to answering the following user query: ${query} Your response should be only a space-delimited list of filename fragments and nothing else.`, + }, + ], + { + onChange: (text: string) => { + responseText = text + }, + onComplete: () => { + resolve(responseText.split(/\s+/).filter(e => e.length > 0)) + }, + onError: (message: string, statusCode?: number) => { + reject(new Error(message)) + }, + }, + { + temperature: 0, + fast: true, + } + ) + }) + const uniqueFragments = uniq(filenameFragments.map(e => e.toLocaleLowerCase())) + return uniqueFragments + } + + private async getFilenames(rootPath: string, filenameFragments: string[], maxDepth: number): Promise { + const searchPattern = '{' + filenameFragments.map(fragment => `**${fragment}**`).join(',') + '}' + const rgArgs = [ + '--files', + '--iglob', + searchPattern, + '--crlf', + '--fixed-strings', + '--no-config', + '--no-ignore-global', + `--max-depth=${maxDepth}`, + ] + const results = await new Promise((resolve, reject) => { + execFile( + this.rgPath, + rgArgs, + { + cwd: rootPath, + maxBuffer: 1024 * 1024 * 1024, + }, + (error, stdout, stderr) => { + if (error?.code === 2) { + reject(new Error(`${error.message}: ${stderr}`)) + } else { + resolve(stdout) + } + } + ) + }) + return results + .split('\n') + .map(r => r.trim()) + .filter(r => r.length > 0) + .sort((a, b) => a.length - b.length) + } +} diff --git a/client/cody/src/keyword-context/local-keyword-context-fetcher.ts b/client/cody/src/local-context/local-keyword-context-fetcher.ts similarity index 66% rename from client/cody/src/keyword-context/local-keyword-context-fetcher.ts rename to client/cody/src/local-context/local-keyword-context-fetcher.ts index 4bb0efe90ec..84d5eb7147d 100644 --- a/client/cody/src/keyword-context/local-keyword-context-fetcher.ts +++ b/client/cody/src/local-context/local-keyword-context-fetcher.ts @@ -1,14 +1,17 @@ import { execFile, spawn } from 'child_process' import * as path from 'path' +import Assembler from 'stream-json/Assembler' import StreamValues from 'stream-json/streamers/StreamValues' import * as vscode from 'vscode' import winkUtils from 'wink-nlp-utils' +import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat' import { Editor } from '@sourcegraph/cody-shared/src/editor' -import { KeywordContextFetcher, KeywordContextFetcherResult } from '@sourcegraph/cody-shared/src/keyword-context' +import { KeywordContextFetcher, ContextResult } from '@sourcegraph/cody-shared/src/local-context' import { logEvent } from '../event-logger' +import { debug } from '../log' /** * Exclude files without extensions and hidden files (starts with '.') @@ -17,6 +20,7 @@ import { logEvent } from '../event-logger' * Note: Ripgrep excludes binary files and respects .gitignore by default */ const fileExtRipgrepParams = [ + '--ignore-case', '-g', '*.*', '-g', @@ -25,10 +29,10 @@ const fileExtRipgrepParams = [ '!*.lock', '-g', '!*.snap', - '--threads', - '1', '--max-filesize', - '1M', + '10K', + '--max-depth', + '10', ] interface RipgrepStreamData { @@ -68,62 +72,34 @@ export function regexForTerms(...terms: Term[]): string { return `(?:${inner.join('|')})` } -export function userQueryToKeywordQuery(query: string): Term[] { - const longestCommonPrefix = (s: string, t: string): string => { - let endIdx = 0 - for (let i = 0; i < s.length && i < t.length; i++) { - if (s[i] !== t[i]) { - break - } - endIdx = i + 1 +function longestCommonPrefix(s: string, t: string): string { + let endIdx = 0 + for (let i = 0; i < s.length && i < t.length; i++) { + if (s[i] !== t[i]) { + break } - return s.slice(0, endIdx) + endIdx = i + 1 } - - const origWords: string[] = [] - for (const chunk of query.split(/\W+/)) { - if (chunk.trim().length === 0) { - continue - } - origWords.push(...winkUtils.string.tokenize0(chunk)) - } - const filteredWords = winkUtils.tokens.removeWords(origWords) - const terms = new Map() - for (const word of filteredWords) { - // Ignore ASCII-only strings of length 2 or less - if (word.length <= 2) { - let skip = true - for (let i = 0; i < word.length; i++) { - if (word.charCodeAt(i) >= 128) { - // non-ASCII - skip = false - break - } - } - if (skip) { - continue - } - } - - const stem = winkUtils.string.stem(word) - const term = terms.get(stem) || { - stem, - originals: [word], - prefix: longestCommonPrefix(word.toLowerCase(), stem), - count: 0, - } - term.originals.push(word) - term.count++ - terms.set(stem, term) - } - return [...terms.values()] + return s.slice(0, endIdx) } +/** + * A local context fetcher that uses a LLM to generate a keyword query, which is then + * converted to a regex fed to ripgrep to search for files that are relevant to the + * user query. + */ export class LocalKeywordContextFetcher implements KeywordContextFetcher { - constructor(private rgPath: string, private editor: Editor) {} + constructor(private rgPath: string, private editor: Editor, private chatClient: ChatClient) {} - public async getContext(query: string, numResults: number): Promise { - console.log('fetching keyword matches') + /** + * Returns pieces of context relevant for the given query. Uses a keyword-search-based + * approach. + * @param query user query + * @param numResults the number of context results to return + * @returns a list of context results, sorted in *reverse* order (that is, + * the most important result appears at the bottom) + */ + public async getContext(query: string, numResults: number): Promise { const startTime = performance.now() const rootPath = this.editor.getWorkspaceRootPath() if (!rootPath) { @@ -141,18 +117,89 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { ) const searchDuration = performance.now() - startTime logEvent('CodyVSCodeExtension:keywordContext:searchDuration', searchDuration, searchDuration) + debug('LocalKeywordContextFetcher:getContext', JSON.stringify({ searchDuration })) return messagePairs.reverse().flat() } + private async userQueryToExpandedKeywords(query: string): Promise> { + const start = performance.now() + const keywords = await new Promise((resolve, reject) => { + let responseText = '' + this.chatClient.chat( + [ + { + speaker: 'human', + text: `Write 3-5 keywords that you would use to search for code snippets that are relevant to answering the following user query: ${query} Your response should be only a list of space-delimited keywords and nothing else.`, + }, + ], + { + onChange: (text: string) => { + responseText = text + }, + onComplete: () => { + resolve(responseText.split(/\s+/).filter(e => e.length > 0)) + }, + onError: (message: string, statusCode?: number) => { + reject(new Error(message)) + }, + }, + { + temperature: 0, + fast: true, + } + ) + }) + const terms = new Map() + for (const kw of keywords) { + const stem = winkUtils.string.stem(kw) + if (terms.has(stem)) { + continue + } + terms.set(stem, { + count: 1, + originals: [kw], + prefix: longestCommonPrefix(kw.toLowerCase(), stem), + stem, + }) + } + debug( + 'LocalKeywordContextFetcher:userQueryToExpandedKeywords', + JSON.stringify({ duration: performance.now() - start }) + ) + return terms + } + + private async userQueryToKeywordQuery(query: string): Promise { + const terms = new Map() + const keywordExpansionStartTime = Date.now() + const expandedTerms = await this.userQueryToExpandedKeywords(query) + const keywordExpansionDuration = Date.now() - keywordExpansionStartTime + for (const [stem, term] of expandedTerms) { + if (terms.has(stem)) { + continue + } + terms.set(stem, term) + } + debug( + 'LocalKeywordContextFetcher:userQueryToKeywordQuery', + 'keyword expansion', + JSON.stringify({ + duration: keywordExpansionDuration, + expandedTerms: [...expandedTerms.values()].map(v => v.prefix), + }) + ) + const ret = [...terms.values()] + return ret + } + // Return context results for the Codebase Context Search recipe - public async getSearchContext(query: string, numResults: number): Promise { - console.log('fetching keyword search context...') + public async getSearchContext(query: string, numResults: number): Promise { const rootPath = this.editor.getWorkspaceRootPath() if (!rootPath) { return [] } - const stems = userQueryToKeywordQuery(query) + const stems = (await this.userQueryToKeywordQuery(query)) .map(t => (t.prefix.length < 4 ? t.originals[0] : t.prefix)) .join('|') const filesnamesWithScores = await this.fetchKeywordFiles(rootPath, query) @@ -180,8 +227,10 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { terms: Term[], rootPath: string ): Promise<{ [filename: string]: { bytesSearched: number } }> { + const start = performance.now() const regexQuery = `\\b${regexForTerms(...terms)}` - const proc = spawn(this.rgPath, ['-i', ...fileExtRipgrepParams, '--json', regexQuery, './'], { + const rgArgs = [...fileExtRipgrepParams, '--json', regexQuery, '.'] + const proc = spawn(this.rgPath, rgArgs, { cwd: rootPath, stdio: ['ignore', 'pipe', process.stderr], windowsHide: true, @@ -191,21 +240,40 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { bytesSearched: number } } = {} + + // Process the ripgrep JSON output to get the file sizes. We use an object filter to + // fast-filter out irrelevant lines of output + const objectFilter = (assembler: Assembler): boolean | undefined => { + // Each ripgrep JSON line begins with the following format: + // + // {"type":"begin|match|end","data":"... + // + // We only care about the "type":"end" lines, which contain the file size in bytes. + if (assembler.key === null && assembler.stack.length === 0 && assembler.current.type) { + return assembler.current.type === 'end' + } + // return undefined to indicate our uncertainty at this moment + return undefined + } await new Promise((resolve, reject) => { try { proc.stdout - .pipe(StreamValues.withParser()) + .pipe(StreamValues.withParser({ objectFilter })) .on('data', data => { try { const typedData = data as RipgrepStreamData switch (typedData.value.type) { - case 'end': - if (!fileTermCounts[typedData.value.data.path.text]) { - fileTermCounts[typedData.value.data.path.text] = { bytesSearched: 0 } + case 'end': { + let filename = typedData.value.data.path.text + if (filename.startsWith(`.${path.sep}`)) { + filename = filename.slice(2) } - fileTermCounts[typedData.value.data.path.text].bytesSearched = - typedData.value.data.stats.bytes_searched + if (!fileTermCounts[filename]) { + fileTermCounts[filename] = { bytesSearched: 0 } + } + fileTermCounts[filename].bytesSearched = typedData.value.data.stats.bytes_searched break + } } } catch (error) { reject(error) @@ -216,6 +284,7 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { reject(error) } }) + debug('fetchFileStats', JSON.stringify({ duration: performance.now() - start })) return fileTermCounts } @@ -227,20 +296,21 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { fileTermCounts: { [filename: string]: { [stem: string]: number } } termTotalFiles: { [stem: string]: number } }> { + const start = performance.now() const termFileCountsArr: { fileCounts: { [filename: string]: number }; filesSearched: number }[] = await Promise.all( queryTerms.map(async term => { + const rgArgs = [ + ...fileExtRipgrepParams, + '--count-matches', + '--stats', + `\\b${regexForTerms(term)}`, + '.', + ] const out = await new Promise((resolve, reject) => { execFile( this.rgPath, - [ - '-i', - ...fileExtRipgrepParams, - '--count-matches', - '--stats', - `\\b${regexForTerms(term)}`, - './', - ], + rgArgs, { cwd: rootPath, maxBuffer: 1024 * 1024 * 1024, @@ -272,8 +342,12 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { continue } try { + let filename = terms[0] + if (filename.startsWith(`.${path.sep}`)) { + filename = filename.slice(2) + } const count = parseInt(terms[1], 10) - fileCounts[terms[0]] = count + fileCounts[filename] = count } catch { console.error(`could not parse count from ${terms[1]}`) } @@ -282,6 +356,7 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { }) ) + debug('LocalKeywordContextFetcher.fetchFileMatches', JSON.stringify({ duration: performance.now() - start })) let totalFilesSearched = -1 for (const { filesSearched } of termFileCountsArr) { if (totalFilesSearched >= 0 && totalFilesSearched !== filesSearched) { @@ -316,13 +391,15 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher { rootPath: string, rawQuery: string ): Promise<{ filename: string; score: number }[]> { - const query = userQueryToKeywordQuery(rawQuery) + const query = await this.userQueryToKeywordQuery(rawQuery) + const fetchFilesStart = performance.now() const fileMatchesPromise = this.fetchFileMatches(query, rootPath) const fileStatsPromise = this.fetchFileStats(query, rootPath) - const fileMatches = await fileMatchesPromise const fileStats = await fileStatsPromise + const fetchFilesDuration = performance.now() - fetchFilesStart + debug('LocalKeywordContextFetcher:fetchKeywordFiles', JSON.stringify({ fetchFilesDuration })) const { fileTermCounts, termTotalFiles, totalFiles } = fileMatches const idfDict = idf(termTotalFiles, totalFiles) diff --git a/client/cody/src/logged-rerank.ts b/client/cody/src/logged-rerank.ts new file mode 100644 index 00000000000..ba3d21003c7 --- /dev/null +++ b/client/cody/src/logged-rerank.ts @@ -0,0 +1,24 @@ +import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat' +import { LLMReranker } from '@sourcegraph/cody-shared/src/codebase-context/rerank' +import { ContextResult } from '@sourcegraph/cody-shared/src/local-context' + +import { debug } from './log' +import { TestSupport } from './test-support' + +export function getRerankWithLog( + chatClient: ChatClient +): (query: string, results: ContextResult[]) => Promise { + if (TestSupport.instance) { + const reranker = TestSupport.instance.getReranker() + return (query: string, results: ContextResult[]): Promise => reranker.rerank(query, results) + } + + const reranker = new LLMReranker(chatClient) + return async (userQuery: string, results: ContextResult[]): Promise => { + const start = performance.now() + const rerankedResults = await reranker.rerank(userQuery, results) + const duration = performance.now() - start + debug('Reranker:rerank', JSON.stringify({ duration })) + return rerankedResults + } +} diff --git a/client/cody/src/services/LocalStorageProvider.ts b/client/cody/src/services/LocalStorageProvider.ts index af1dea25e9f..630e539ddf2 100644 --- a/client/cody/src/services/LocalStorageProvider.ts +++ b/client/cody/src/services/LocalStorageProvider.ts @@ -43,12 +43,13 @@ export class LocalStorage { ([humanMessage, assistantMessageAndContextFiles]) => ({ humanMessage, assistantMessage: assistantMessageAndContextFiles, - context: assistantMessageAndContextFiles.contextFiles + fullContext: assistantMessageAndContextFiles.contextFiles ? assistantMessageAndContextFiles.contextFiles.map(fileName => ({ speaker: 'assistant', fileName, })) : [], + usedContextFiles: [], // Timestamp not recoverable so we use the group timestamp timestamp: id, }) diff --git a/client/cody/src/test-support.ts b/client/cody/src/test-support.ts index beccc67e090..da797ffcc4b 100644 --- a/client/cody/src/test-support.ts +++ b/client/cody/src/test-support.ts @@ -1,4 +1,6 @@ import { ChatMessage } from '@sourcegraph/cody-shared/src/chat/transcript/messages' +import { MockReranker, Reranker } from '@sourcegraph/cody-shared/src/codebase-context/rerank' +import { ContextResult } from '@sourcegraph/cody-shared/src/local-context' import { ChatViewProvider } from './chat/ChatViewProvider' import { FixupTask } from './non-stop/FixupTask' @@ -39,6 +41,17 @@ export class TestSupport { public chatViewProvider = new Rendezvous() + public reranker: Reranker | undefined + + public getReranker(): Reranker { + if (!this.reranker) { + return new MockReranker( + (_: string, results: ContextResult[]): Promise => Promise.resolve(results) + ) + } + return this.reranker + } + public async chatTranscript(): Promise { return (await this.chatViewProvider.get()).transcriptForTesting(this) } diff --git a/cmd/frontend/graphqlbackend/completions.go b/cmd/frontend/graphqlbackend/completions.go index 439e68dbca9..6ab027e3f82 100644 --- a/cmd/frontend/graphqlbackend/completions.go +++ b/cmd/frontend/graphqlbackend/completions.go @@ -8,6 +8,7 @@ type CompletionsResolver interface { type CompletionsArgs struct { Input CompletionsInput + Fast bool } type Message struct { diff --git a/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go b/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go index d9b0c75e8e0..8a5d1062f4e 100644 --- a/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go +++ b/enterprise/cmd/frontend/internal/completions/resolvers/resolver.go @@ -45,7 +45,14 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke return "", errors.New("completions are not configured or disabled") } - ctx, done := httpapi.Trace(ctx, "resolver", completionsConfig.ChatModel). + var chatModel string + if args.Fast { + chatModel = completionsConfig.FastChatModel + } else { + chatModel = completionsConfig.ChatModel + } + + ctx, done := httpapi.Trace(ctx, "resolver", chatModel). WithErrorP(&err). Build() defer done() @@ -66,7 +73,7 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke params := convertParams(args) // No way to configure the model through the request, we hard code to chat. - params.Model = completionsConfig.ChatModel + params.Model = chatModel resp, err := client.Complete(ctx, types.CompletionsFeatureChat, params) if err != nil { return "", errors.Wrap(err, "client.Complete") diff --git a/enterprise/internal/completions/client/client.go b/enterprise/internal/completions/client/client.go index 61e96b472c0..bc9e4b5fe3e 100644 --- a/enterprise/internal/completions/client/client.go +++ b/enterprise/internal/completions/client/client.go @@ -79,6 +79,10 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) *schema.Completio completionsConfig.ChatModel = completionsConfig.Model } + if completionsConfig.FastChatModel == "" { + completionsConfig.FastChatModel = completionsConfig.ChatModel + } + // TODO: Temporary workaround to fix instances where no completion model is set. if completionsConfig.CompletionModel == "" { completionsConfig.CompletionModel = "claude-instant-v1" @@ -102,6 +106,7 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) *schema.Completio // TODO: These are not required right now as upstream overwrites this, // but should we switch to Cody Gateway they will be. ChatModel: "claude-v1", + FastChatModel: "claude-instant-v1", CompletionModel: "claude-instant-v1", } } diff --git a/enterprise/internal/completions/client/client_test.go b/enterprise/internal/completions/client/client_test.go index e84224e2816..121a2d6e4f2 100644 --- a/enterprise/internal/completions/client/client_test.go +++ b/enterprise/internal/completions/client/client_test.go @@ -28,13 +28,15 @@ func TestGetCompletionsConfig(t *testing.T) { Enabled: true, Provider: "anthropic", ChatModel: "claude-v1", + FastChatModel: "claude-instant-v1", CompletionModel: "claude-instant-v1", }, }, want: autogold.Expect(&schema.Completions{ ChatModel: "claude-v1", CompletionModel: "claude-instant-v1", - Enabled: true, - Provider: "anthropic", + FastChatModel: "claude-instant-v1", + Enabled: true, + Provider: "anthropic", }), }, { diff --git a/enterprise/internal/completions/httpapi/codecompletion.go b/enterprise/internal/completions/httpapi/codecompletion.go index f5e2557138a..c4bfec91061 100644 --- a/enterprise/internal/completions/httpapi/codecompletion.go +++ b/enterprise/internal/completions/httpapi/codecompletion.go @@ -19,7 +19,7 @@ func NewCodeCompletionsHandler(logger log.Logger, db database.DB) http.Handler { logger = logger.Scoped("code", "code completions handler") rl := NewRateLimiter(db, redispool.Store, types.CompletionsFeatureCode) - return newCompletionsHandler(rl, "code", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string { + return newCompletionsHandler(rl, "code", func(requestParams types.CodyCompletionRequestParameters, c *schema.Completions) string { // No user defined models for now. // TODO(eseliger): Look into reviving this, but it was unused so far. return c.CompletionModel diff --git a/enterprise/internal/completions/httpapi/handler.go b/enterprise/internal/completions/httpapi/handler.go index d77767ec3e7..1b3c1fadbe4 100644 --- a/enterprise/internal/completions/httpapi/handler.go +++ b/enterprise/internal/completions/httpapi/handler.go @@ -19,7 +19,12 @@ import ( // being cancelled. const maxRequestDuration = time.Minute -func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(types.CompletionRequestParameters, *schema.Completions) string, handle func(context.Context, types.CompletionRequestParameters, types.CompletionsClient, http.ResponseWriter)) http.Handler { +func newCompletionsHandler( + rl RateLimiter, + traceFamily string, + getModel func(types.CodyCompletionRequestParameters, *schema.Completions) string, + handle func(context.Context, types.CompletionRequestParameters, types.CompletionsClient, http.ResponseWriter), +) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, fmt.Sprintf("unsupported method %s", r.Method), http.StatusMethodNotAllowed) @@ -40,13 +45,17 @@ func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(typ return } - var requestParams types.CompletionRequestParameters + var requestParams types.CodyCompletionRequestParameters if err := json.NewDecoder(r.Body).Decode(&requestParams); err != nil { http.Error(w, "could not decode request body", http.StatusBadRequest) return } // TODO: Model is not configurable but technically allowed in the request body right now. + if requestParams.Model != "" { + http.Error(w, "user-specified models are not allowed", http.StatusBadRequest) + return + } requestParams.Model = getModel(requestParams, completionsConfig) var err error @@ -77,7 +86,7 @@ func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(typ return } - handle(ctx, requestParams, completionClient, w) + handle(ctx, requestParams.CompletionRequestParameters, completionClient, w) }) } diff --git a/enterprise/internal/completions/httpapi/stream.go b/enterprise/internal/completions/httpapi/stream.go index a6b67568ebb..ab525fb156e 100644 --- a/enterprise/internal/completions/httpapi/stream.go +++ b/enterprise/internal/completions/httpapi/stream.go @@ -19,8 +19,11 @@ func NewChatCompletionsStreamHandler(logger log.Logger, db database.DB) http.Han logger = logger.Scoped("chat", "chat completions handler") rl := NewRateLimiter(db, redispool.Store, types.CompletionsFeatureChat) - return newCompletionsHandler(rl, "chat", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string { + return newCompletionsHandler(rl, "chat", func(requestParams types.CodyCompletionRequestParameters, c *schema.Completions) string { // No user defined models for now. + if requestParams.Fast { + return c.FastChatModel + } return c.ChatModel }, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) { eventWriter, err := streamhttp.NewWriter(w) diff --git a/enterprise/internal/completions/types/types.go b/enterprise/internal/completions/types/types.go index ec3f4002d52..3615b5f679a 100644 --- a/enterprise/internal/completions/types/types.go +++ b/enterprise/internal/completions/types/types.go @@ -37,6 +37,14 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str return fmt.Sprintf("%s %s", prefix, m.Text), nil } +type CodyCompletionRequestParameters struct { + CompletionRequestParameters + + // When Fast is true, then it is used as a hint to prefer a model + // that is faster (but probably "dumber"). + Fast bool +} + type CompletionRequestParameters struct { // Prompt exists only for backwards compatibility. Do not use it in new // implementations. It will be removed once we are reasonably sure 99% diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8a6fd0d8e02..0433c1359ae 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1445,6 +1445,13 @@ importers: '@sourcegraph/http-client': specifier: workspace:* version: link:../http-client + xml2js: + specifier: ^0.6.0 + version: 0.6.0 + devDependencies: + '@types/xml2js': + specifier: ^0.4.11 + version: 0.4.11 client/cody-slack: dependencies: @@ -10016,6 +10023,12 @@ packages: dependencies: '@types/node': 13.13.5 + /@types/xml2js@0.4.11: + resolution: {integrity: sha512-JdigeAKmCyoJUiQljjr7tQG3if9NkqGUgwEUqBvV0N7LM4HyQk7UXCnusRa1lnvXAEYJ8mw8GtZWioagNztOwA==} + dependencies: + '@types/node': 13.13.5 + dev: true + /@types/yargs-parser@21.0.0: resolution: {integrity: sha512-iO9ZQHkZxHn4mSakYV0vFHAVDyEOIJQrV2uZ06HxEPcx+mt8swXoZHIbaaJ2crJYFfErySgktuTZ3BeLz+XmFA==} @@ -26143,7 +26156,6 @@ packages: /sax@1.2.4: resolution: {integrity: sha512-NqVDv9TpANUjFm0N8uM5GxL36UgKi9/atZw+x7YFnQ8ckwFGKrl4xX4yWtrey3UJm5nP1kUbnYgLopqWNSRhWw==} - dev: true /saxes@5.0.1: resolution: {integrity: sha512-5LBh1Tls8c9xgGjw3QrMwETmTMVk0oFgvrFSvWx62llR2hcEInrKNZ2GZCCuuy2lvWrdl5jhbpeqc5hRYKFOcw==} @@ -30109,6 +30121,14 @@ packages: xmlbuilder: 11.0.1 dev: true + /xml2js@0.6.0: + resolution: {integrity: sha512-eLTh0kA8uHceqesPqSE+VvO1CDDJWMwlQfB6LuN6T8w6MaDJ8Txm8P7s5cHD0miF0V+GGTZrDQfxPZQVsur33w==} + engines: {node: '>=4.0.0'} + dependencies: + sax: 1.2.4 + xmlbuilder: 11.0.1 + dev: false + /xml@1.0.1: resolution: {integrity: sha512-huCv9IH9Tcf95zuYCsQraZtWnJvBtLVE0QHMOs8bWyZAFZNDcYjsPq1nEx8jKA9y+Beo9v+7OBPRisQTjinQMw==} dev: true @@ -30116,7 +30136,6 @@ packages: /xmlbuilder@11.0.1: resolution: {integrity: sha512-fDlsI/kFEx7gLvbecc0/ohLG50fugQp8ryHzMTuW9vSa1GJ0XYWKnhsUx7oie3G98+r56aTQIUB4kht42R3JvA==} engines: {node: '>=4.0'} - dev: true /xmlchars@2.2.0: resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==} diff --git a/schema/schema.go b/schema/schema.go index 3ca713a8f0a..2e074fe6a9f 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -564,6 +564,8 @@ type Completions struct { Enabled bool `json:"enabled"` // Endpoint description: The endpoint under which to reach the provider. Currently only used for provider types "sourcegraph", "openai" and "anthropic". The default values are "https://cody-gateway.sourcegraph.com", "https://api.openai.com/v1/chat/completions", and "https://api.anthropic.com/v1/complete" for Sourcegraph, OpenAI, and Anthropic, respectively. Endpoint string `json:"endpoint,omitempty"` + // FastChatModel description: The model used for fast chat completions. + FastChatModel string `json:"fastChatModel,omitempty"` // Model description: DEPRECATED. Use chatModel instead. Model string `json:"model,omitempty"` // PerUserCodeCompletionsDailyLimit description: If > 0, enables the maximum number of code completions requests allowed to be made by a single user account in a day. On instances that allow anonymous requests, the rate limit is enforced by IP. diff --git a/schema/site.schema.json b/schema/site.schema.json index b7e91954536..1bbcb97ad7b 100644 --- a/schema/site.schema.json +++ b/schema/site.schema.json @@ -2422,6 +2422,10 @@ "description": "DEPRECATED. Use chatModel instead.", "type": "string" }, + "fastChatModel": { + "description": "The model used for fast chat completions.", + "type": "string" + }, "chatModel": { "description": "The model used for chat completions. If using the default provider 'sourcegraph', a reasonable default model will be set.", "type": "string"