mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 15:51:43 +00:00
LLM-enhanced keyword context (#52815)
* Adds `fast` parameter to completions endpoint and `fastChatModel` param to site config. This is intended for faster chat models that are useful for simple generations. * Use the fast chat model to generate a local keyword search. This replaces the old keyword search mechanism, which stemmed/lemmatized every word in the user query. * Use the fast chat model to generate a small set of file fragments to search for. This is mainly useful for surfacing READMEs for questions like "What does this project do?" * Update the set of "files read" presented in the UI to include only those files actually read into the context window. Previously, we were showing all files returned by the context fetcher, but in reality, only a subset of these would fit into the context window.
This commit is contained in:
parent
e541d08400
commit
dd528f30fe
@ -111,7 +111,8 @@ async function startCLI() {
|
||||
}
|
||||
}
|
||||
|
||||
const finalPrompt = await transcript.toPrompt(getPreamble(codebase))
|
||||
const { prompt: finalPrompt, contextFiles } = await transcript.getPromptForLastInteraction(getPreamble(codebase))
|
||||
transcript.setUsedContextFilesForLastInteraction(contextFiles)
|
||||
|
||||
let text = ''
|
||||
streamCompletions(completionsClient, finalPrompt, {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
|
||||
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
|
||||
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/keyword-context'
|
||||
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/local-context'
|
||||
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
|
||||
import { isError } from '@sourcegraph/cody-shared/src/utils'
|
||||
|
||||
@ -26,7 +26,8 @@ export async function createCodebaseContext(
|
||||
{ useContext: contextType, serverEndpoint },
|
||||
codebase,
|
||||
embeddingsSearch,
|
||||
new LocalKeywordContextFetcherMock()
|
||||
new LocalKeywordContextFetcherMock(),
|
||||
null
|
||||
)
|
||||
|
||||
return codebaseContext
|
||||
|
||||
@ -45,7 +45,8 @@ export async function interactionFromMessage(
|
||||
new Interaction(
|
||||
{ speaker: 'human', text, displayText: text },
|
||||
{ speaker: 'assistant', text: '', displayText: '' },
|
||||
contextMessages
|
||||
contextMessages,
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
5
client/cody-shared/BUILD.bazel
generated
5
client/cody-shared/BUILD.bazel
generated
@ -58,6 +58,7 @@ ts_project(
|
||||
"src/chat/viewHelpers.ts",
|
||||
"src/codebase-context/index.ts",
|
||||
"src/codebase-context/messages.ts",
|
||||
"src/codebase-context/rerank.ts",
|
||||
"src/configuration.ts",
|
||||
"src/editor/index.ts",
|
||||
"src/editor/withPreselectedOptions.ts",
|
||||
@ -68,7 +69,7 @@ ts_project(
|
||||
"src/hallucinations-detector/index.ts",
|
||||
"src/intent-detector/client.ts",
|
||||
"src/intent-detector/index.ts",
|
||||
"src/keyword-context/index.ts",
|
||||
"src/local-context/index.ts",
|
||||
"src/prompt/constants.ts",
|
||||
"src/prompt/prompt-mixin.ts",
|
||||
"src/prompt/templates.ts",
|
||||
@ -93,6 +94,8 @@ ts_project(
|
||||
deps = [
|
||||
":node_modules/@sourcegraph/common",
|
||||
":node_modules/@sourcegraph/http-client",
|
||||
":node_modules/@types/xml2js",
|
||||
":node_modules/xml2js",
|
||||
"//:node_modules/@microsoft/fetch-event-source",
|
||||
"//:node_modules/@types/isomorphic-fetch",
|
||||
"//:node_modules/@types/marked",
|
||||
|
||||
@ -19,6 +19,10 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@sourcegraph/common": "workspace:*",
|
||||
"@sourcegraph/http-client": "workspace:*"
|
||||
"@sourcegraph/http-client": "workspace:*",
|
||||
"xml2js": "^0.6.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/xml2js": "^0.4.11"
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,9 @@ import { Message } from '../sourcegraph-api'
|
||||
import type { SourcegraphCompletionsClient } from '../sourcegraph-api/completions/client'
|
||||
import type { CompletionParameters, CompletionCallbacks } from '../sourcegraph-api/completions/types'
|
||||
|
||||
const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit<CompletionParameters, 'messages'> = {
|
||||
type ChatParameters = Omit<CompletionParameters, 'messages'>
|
||||
|
||||
const DEFAULT_CHAT_COMPLETION_PARAMETERS: ChatParameters = {
|
||||
temperature: 0.2,
|
||||
maxTokensToSample: SOLUTION_TOKEN_LENGTH,
|
||||
topK: -1,
|
||||
@ -13,10 +15,17 @@ const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit<CompletionParameters, 'messages'>
|
||||
export class ChatClient {
|
||||
constructor(private completions: SourcegraphCompletionsClient) {}
|
||||
|
||||
public chat(messages: Message[], cb: CompletionCallbacks): () => void {
|
||||
public chat(messages: Message[], cb: CompletionCallbacks, params?: Partial<ChatParameters>): () => void {
|
||||
const isLastMessageFromHuman = messages.length > 0 && messages[messages.length - 1].speaker === 'human'
|
||||
const augmentedMessages = isLastMessageFromHuman ? messages.concat([{ speaker: 'assistant' }]) : messages
|
||||
|
||||
return this.completions.stream({ messages: augmentedMessages, ...DEFAULT_CHAT_COMPLETION_PARAMETERS }, cb)
|
||||
return this.completions.stream(
|
||||
{
|
||||
...DEFAULT_CHAT_COMPLETION_PARAMETERS,
|
||||
...params,
|
||||
messages: augmentedMessages,
|
||||
},
|
||||
cb
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,7 +69,7 @@ export async function createClient({
|
||||
|
||||
const embeddingsSearch = repoId ? new SourcegraphEmbeddingsSearchClient(graphqlClient, repoId, true) : null
|
||||
|
||||
const codebaseContext = new CodebaseContext(config, config.codebase, embeddingsSearch, null)
|
||||
const codebaseContext = new CodebaseContext(config, config.codebase, embeddingsSearch, null, null)
|
||||
|
||||
const intentDetector = new SourcegraphIntentDetectorClient(graphqlClient)
|
||||
|
||||
@ -116,10 +116,11 @@ export async function createClient({
|
||||
|
||||
sendTranscript()
|
||||
|
||||
const prompt = await transcript.toPrompt(getPreamble(config.codebase))
|
||||
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(getPreamble(config.codebase))
|
||||
transcript.setUsedContextFilesForLastInteraction(contextFiles)
|
||||
|
||||
const responsePrefix = interaction.getAssistantMessage().prefix ?? ''
|
||||
let rawText = ''
|
||||
|
||||
chatClient.chat(prompt, {
|
||||
onChange(_rawText) {
|
||||
rawText = _rawText
|
||||
|
||||
@ -31,7 +31,8 @@ export class ChatQuestion implements Recipe {
|
||||
context.intentDetector,
|
||||
context.codebaseContext,
|
||||
context.editor.getActiveTextEditorSelection() || null
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@ -46,7 +46,8 @@ export class ContextSearch implements Recipe {
|
||||
text: '',
|
||||
displayText: await this.displaySearchResults(truncatedText, context.codebaseContext, wsRootPath),
|
||||
},
|
||||
new Promise(resolve => resolve([]))
|
||||
new Promise(resolve => resolve([])),
|
||||
[]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -32,7 +32,8 @@ export class ExplainCodeDetailed implements Recipe {
|
||||
truncatedFollowingText,
|
||||
selection,
|
||||
context.codebaseContext
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,7 +32,8 @@ export class ExplainCodeHighLevel implements Recipe {
|
||||
truncatedFollowingText,
|
||||
selection,
|
||||
context.codebaseContext
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,7 +103,8 @@ export class FileTouch implements Recipe {
|
||||
speaker: 'assistant',
|
||||
prefix: 'Working on it! I will notify you when the file is ready.\n',
|
||||
},
|
||||
this.getContextMessages(selection, currentDir)
|
||||
this.getContextMessages(selection, currentDir),
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@ -36,7 +36,8 @@ If you have no ideas because the code looks fine, feel free to say that it alrea
|
||||
prefix: assistantResponsePrefix,
|
||||
text: assistantResponsePrefix,
|
||||
},
|
||||
new Promise(resolve => resolve([]))
|
||||
new Promise(resolve => resolve([])),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,7 +67,8 @@ export class Fixup implements Recipe {
|
||||
speaker: 'assistant',
|
||||
prefix: 'Check your document for updates from Cody.\n',
|
||||
},
|
||||
this.getContextMessages(selection.selectedText, context.codebaseContext)
|
||||
this.getContextMessages(selection.selectedText, context.codebaseContext),
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@ -60,7 +60,8 @@ export class GenerateDocstring implements Recipe {
|
||||
truncatedFollowingText,
|
||||
selection,
|
||||
context.codebaseContext
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +52,8 @@ export class PrDescription implements Recipe {
|
||||
prefix: emptyGitCommitMessage,
|
||||
text: emptyGitCommitMessage,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
|
||||
@ -71,7 +72,8 @@ export class PrDescription implements Recipe {
|
||||
prefix: assistantResponsePrefix,
|
||||
text: assistantResponsePrefix,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,7 +71,8 @@ export class ReleaseNotes implements Recipe {
|
||||
prefix: emptyGitLogMessage,
|
||||
text: emptyGitLogMessage,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
|
||||
@ -91,7 +92,8 @@ export class ReleaseNotes implements Recipe {
|
||||
prefix: assistantResponsePrefix,
|
||||
text: assistantResponsePrefix,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,7 +44,8 @@ export class GenerateTest implements Recipe {
|
||||
truncatedFollowingText,
|
||||
selection,
|
||||
context.codebaseContext
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -65,7 +65,8 @@ export class GitHistory implements Recipe {
|
||||
prefix: emptyGitLogMessage,
|
||||
text: emptyGitLogMessage,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
|
||||
@ -84,7 +85,8 @@ export class GitHistory implements Recipe {
|
||||
prefix: assistantResponsePrefix,
|
||||
text: assistantResponsePrefix,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,7 +44,8 @@ export class ImproveVariableNames implements Recipe {
|
||||
truncatedFollowingText,
|
||||
selection,
|
||||
context.codebaseContext
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,7 +56,8 @@ export class InlineAssist implements Recipe {
|
||||
displayText,
|
||||
},
|
||||
{ speaker: 'assistant' },
|
||||
this.getContextMessages(truncatedText, context.codebaseContext, selection, context.editor)
|
||||
this.getContextMessages(truncatedText, context.codebaseContext, selection, context.editor),
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@ -31,7 +31,8 @@ export class NextQuestions implements Recipe {
|
||||
prefix: assistantResponsePrefix,
|
||||
text: assistantResponsePrefix,
|
||||
},
|
||||
this.getContextMessages(promptMessage, context.editor, context.intentDetector, context.codebaseContext)
|
||||
this.getContextMessages(promptMessage, context.editor, context.intentDetector, context.codebaseContext),
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@ -76,7 +76,8 @@ export class NonStop implements Recipe {
|
||||
speaker: 'assistant',
|
||||
prefix: 'Check your document for updates from Cody.',
|
||||
},
|
||||
this.getContextMessages(selection.selectedText, context.codebaseContext)
|
||||
this.getContextMessages(selection.selectedText, context.codebaseContext),
|
||||
[]
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@ -51,7 +51,8 @@ However if no optimization is possible; just say the code is already optimized.
|
||||
truncatedFollowingText,
|
||||
selection,
|
||||
context.codebaseContext
|
||||
)
|
||||
),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -39,7 +39,8 @@ export class TranslateToLanguage implements Recipe {
|
||||
prefix: assistantResponsePrefix,
|
||||
text: assistantResponsePrefix,
|
||||
},
|
||||
Promise.resolve([])
|
||||
Promise.resolve([]),
|
||||
[]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { OldContextMessage } from '../../codebase-context/messages'
|
||||
import { ContextFile, ContextMessage, OldContextMessage } from '../../codebase-context/messages'
|
||||
import { CHARS_PER_TOKEN, MAX_AVAILABLE_PROMPT_LENGTH } from '../../prompt/constants'
|
||||
import { PromptMixin } from '../../prompt/prompt-mixin'
|
||||
import { Message } from '../../sourcegraph-api'
|
||||
|
||||
import { Interaction, InteractionJSON } from './interaction'
|
||||
@ -20,18 +21,19 @@ export interface TranscriptJSON {
|
||||
}
|
||||
|
||||
/**
|
||||
* A transcript of a conversation between a human and an assistant.
|
||||
* The "model" class that tracks the call and response of the Cody chat box.
|
||||
* Any "controller" logic belongs outside of this class.
|
||||
*/
|
||||
export class Transcript {
|
||||
public static fromJSON(json: TranscriptJSON): Transcript {
|
||||
return new Transcript(
|
||||
json.interactions.map(
|
||||
({ humanMessage, assistantMessage, context, timestamp }) =>
|
||||
({ humanMessage, assistantMessage, fullContext, usedContextFiles, timestamp }) =>
|
||||
new Interaction(
|
||||
humanMessage,
|
||||
assistantMessage,
|
||||
Promise.resolve(
|
||||
context.map(message => {
|
||||
fullContext.map(message => {
|
||||
if (message.file) {
|
||||
return message
|
||||
}
|
||||
@ -44,6 +46,7 @@ export class Transcript {
|
||||
return message
|
||||
})
|
||||
),
|
||||
usedContextFiles,
|
||||
timestamp || new Date().toISOString()
|
||||
)
|
||||
),
|
||||
@ -144,20 +147,53 @@ export class Transcript {
|
||||
return -1
|
||||
}
|
||||
|
||||
public async toPrompt(preamble: Message[] = []): Promise<Message[]> {
|
||||
public async getPromptForLastInteraction(
|
||||
preamble: Message[] = []
|
||||
): Promise<{ prompt: Message[]; contextFiles: ContextFile[] }> {
|
||||
if (this.interactions.length === 0) {
|
||||
return { prompt: [], contextFiles: [] }
|
||||
}
|
||||
|
||||
const lastInteractionWithContextIndex = await this.getLastInteractionWithContextIndex()
|
||||
const messages: Message[] = []
|
||||
for (let index = 0; index < this.interactions.length; index++) {
|
||||
// Include context messages for the last interaction that has a non-empty context.
|
||||
const interactionMessages = await this.interactions[index].toPrompt(
|
||||
index === lastInteractionWithContextIndex
|
||||
)
|
||||
messages.push(...interactionMessages)
|
||||
const interaction = this.interactions[index]
|
||||
const humanMessage = PromptMixin.mixInto(interaction.getHumanMessage())
|
||||
const assistantMessage = interaction.getAssistantMessage()
|
||||
const contextMessages = await interaction.getFullContext()
|
||||
if (index === lastInteractionWithContextIndex) {
|
||||
messages.push(...contextMessages, humanMessage, assistantMessage)
|
||||
} else {
|
||||
messages.push(humanMessage, assistantMessage)
|
||||
}
|
||||
}
|
||||
|
||||
const preambleTokensUsage = preamble.reduce((acc, message) => acc + estimateTokensUsage(message), 0)
|
||||
const truncatedMessages = truncatePrompt(messages, MAX_AVAILABLE_PROMPT_LENGTH - preambleTokensUsage)
|
||||
return [...preamble, ...truncatedMessages]
|
||||
let truncatedMessages = truncatePrompt(messages, MAX_AVAILABLE_PROMPT_LENGTH - preambleTokensUsage)
|
||||
|
||||
// Return what context fits in the window
|
||||
const contextFiles: ContextFile[] = []
|
||||
for (const msg of truncatedMessages) {
|
||||
const contextFile = (msg as ContextMessage).file
|
||||
if (contextFile) {
|
||||
contextFiles.push(contextFile)
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out extraneous fields from ContextMessage instances
|
||||
truncatedMessages = truncatedMessages.map(({ speaker, text }) => ({ speaker, text }))
|
||||
|
||||
return {
|
||||
prompt: [...preamble, ...truncatedMessages],
|
||||
contextFiles,
|
||||
}
|
||||
}
|
||||
|
||||
public setUsedContextFilesForLastInteraction(contextFiles: ContextFile[]): void {
|
||||
if (this.interactions.length === 0) {
|
||||
throw new Error('Cannot set context files for empty transcript')
|
||||
}
|
||||
this.interactions[this.interactions.length - 1].setUsedContext(contextFiles)
|
||||
}
|
||||
|
||||
public toChat(): ChatMessage[] {
|
||||
|
||||
@ -1,89 +1,59 @@
|
||||
import { ContextMessage, ContextFile } from '../../codebase-context/messages'
|
||||
import { PromptMixin } from '../../prompt/prompt-mixin'
|
||||
import { Message } from '../../sourcegraph-api'
|
||||
|
||||
import { ChatMessage, InteractionMessage } from './messages'
|
||||
|
||||
export interface InteractionJSON {
|
||||
humanMessage: InteractionMessage
|
||||
assistantMessage: InteractionMessage
|
||||
context: ContextMessage[]
|
||||
fullContext: ContextMessage[]
|
||||
usedContextFiles: ContextFile[]
|
||||
timestamp: string
|
||||
}
|
||||
|
||||
export class Interaction {
|
||||
private readonly humanMessage: InteractionMessage
|
||||
private assistantMessage: InteractionMessage
|
||||
private cachedContextFiles: ContextFile[] = []
|
||||
public readonly timestamp: string
|
||||
private readonly context: Promise<ContextMessage[]>
|
||||
|
||||
constructor(
|
||||
humanMessage: InteractionMessage,
|
||||
assistantMessage: InteractionMessage,
|
||||
context: Promise<ContextMessage[]>,
|
||||
timestamp: string = new Date().toISOString()
|
||||
) {
|
||||
this.humanMessage = humanMessage
|
||||
this.assistantMessage = assistantMessage
|
||||
this.timestamp = timestamp
|
||||
|
||||
// This is some hacky behavior: returns a promise that resolves to the same array that was passed,
|
||||
// but also caches the context file names in memory as a side effect.
|
||||
this.context = context.then(messages => {
|
||||
const contextFilesMap = messages.reduce((map, { file }) => {
|
||||
if (!file?.fileName) {
|
||||
return map
|
||||
}
|
||||
map[`${file.repoName || 'repo'}@${file?.revision || 'HEAD'}/${file.fileName}`] = file
|
||||
return map
|
||||
}, {} as { [key: string]: ContextFile })
|
||||
|
||||
// Cache the context files so we don't have to block the UI when calling `toChat` by waiting for the context to resolve.
|
||||
this.cachedContextFiles = [
|
||||
...Object.keys(contextFilesMap)
|
||||
.sort((a, b) => a.localeCompare(b))
|
||||
.map((key: string) => contextFilesMap[key]),
|
||||
]
|
||||
|
||||
return messages
|
||||
})
|
||||
}
|
||||
private readonly humanMessage: InteractionMessage,
|
||||
private assistantMessage: InteractionMessage,
|
||||
private fullContext: Promise<ContextMessage[]>,
|
||||
private usedContextFiles: ContextFile[],
|
||||
public readonly timestamp: string = new Date().toISOString()
|
||||
) {}
|
||||
|
||||
public getAssistantMessage(): InteractionMessage {
|
||||
return this.assistantMessage
|
||||
return { ...this.assistantMessage }
|
||||
}
|
||||
|
||||
public setAssistantMessage(assistantMessage: InteractionMessage): void {
|
||||
this.assistantMessage = assistantMessage
|
||||
}
|
||||
|
||||
public getHumanMessage(): InteractionMessage {
|
||||
return { ...this.humanMessage }
|
||||
}
|
||||
|
||||
public async getFullContext(): Promise<ContextMessage[]> {
|
||||
const msgs = await this.fullContext
|
||||
return msgs.map(msg => ({ ...msg }))
|
||||
}
|
||||
|
||||
public async hasContext(): Promise<boolean> {
|
||||
const contextMessages = await this.context
|
||||
const contextMessages = await this.fullContext
|
||||
return contextMessages.length > 0
|
||||
}
|
||||
|
||||
public async toPrompt(includeContext: boolean): Promise<Message[]> {
|
||||
const messages: (ContextMessage | InteractionMessage)[] = [
|
||||
PromptMixin.mixInto(this.humanMessage),
|
||||
this.assistantMessage,
|
||||
]
|
||||
if (includeContext) {
|
||||
messages.unshift(...(await this.context))
|
||||
}
|
||||
|
||||
return messages.map(message => ({ speaker: message.speaker, text: message.text }))
|
||||
public setUsedContext(usedContextFiles: ContextFile[]): void {
|
||||
this.usedContextFiles = usedContextFiles
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts the interaction to chat message pair: one message from a human, one from an assistant.
|
||||
*/
|
||||
public toChat(): ChatMessage[] {
|
||||
return [this.humanMessage, { ...this.assistantMessage, contextFiles: this.cachedContextFiles }]
|
||||
return [this.humanMessage, { ...this.assistantMessage, contextFiles: this.usedContextFiles }]
|
||||
}
|
||||
|
||||
public async toChatPromise(): Promise<ChatMessage[]> {
|
||||
await this.context
|
||||
await this.fullContext
|
||||
return this.toChat()
|
||||
}
|
||||
|
||||
@ -91,7 +61,8 @@ export class Interaction {
|
||||
return {
|
||||
humanMessage: this.humanMessage,
|
||||
assistantMessage: this.assistantMessage,
|
||||
context: await this.context,
|
||||
fullContext: await this.fullContext,
|
||||
usedContextFiles: this.usedContextFiles,
|
||||
timestamp: this.timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ async function generateLongTranscript(): Promise<{ transcript: Transcript; token
|
||||
describe('Transcript', () => {
|
||||
it('generates an empty prompt with no interactions', async () => {
|
||||
const transcript = new Transcript()
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
assert.deepStrictEqual(prompt, [])
|
||||
})
|
||||
|
||||
@ -51,7 +51,7 @@ describe('Transcript', () => {
|
||||
const transcript = new Transcript()
|
||||
transcript.addInteraction(interaction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'how do access tokens work in sourcegraph' },
|
||||
{ speaker: 'assistant', text: undefined },
|
||||
@ -78,7 +78,8 @@ describe('Transcript', () => {
|
||||
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
|
||||
'dummy-codebase',
|
||||
embeddings,
|
||||
defaultKeywordContextFetcher
|
||||
defaultKeywordContextFetcher,
|
||||
null
|
||||
),
|
||||
})
|
||||
)
|
||||
@ -86,7 +87,7 @@ describe('Transcript', () => {
|
||||
const transcript = new Transcript()
|
||||
transcript.addInteraction(interaction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' },
|
||||
{ speaker: 'assistant', text: 'Ok.' },
|
||||
@ -114,7 +115,8 @@ describe('Transcript', () => {
|
||||
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
|
||||
'dummy-codebase',
|
||||
embeddings,
|
||||
defaultKeywordContextFetcher
|
||||
defaultKeywordContextFetcher,
|
||||
null
|
||||
),
|
||||
firstInteraction: true,
|
||||
})
|
||||
@ -123,7 +125,7 @@ describe('Transcript', () => {
|
||||
const transcript = new Transcript()
|
||||
transcript.addInteraction(interaction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' },
|
||||
{ speaker: 'assistant', text: 'Ok.' },
|
||||
@ -148,7 +150,8 @@ describe('Transcript', () => {
|
||||
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
|
||||
'dummy-codebase',
|
||||
embeddings,
|
||||
defaultKeywordContextFetcher
|
||||
defaultKeywordContextFetcher,
|
||||
null
|
||||
)
|
||||
|
||||
const chatQuestionRecipe = new ChatQuestion(() => {})
|
||||
@ -175,7 +178,7 @@ describe('Transcript', () => {
|
||||
)
|
||||
transcript.addInteraction(secondInteraction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'how do access tokens work in sourcegraph' },
|
||||
{ speaker: 'assistant', text: assistantResponse },
|
||||
@ -195,7 +198,7 @@ describe('Transcript', () => {
|
||||
const numExpectedInteractions = Math.floor(MAX_AVAILABLE_PROMPT_LENGTH / tokensPerInteraction)
|
||||
const numExpectedMessages = numExpectedInteractions * 2 // Each interaction has two messages.
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
assert.deepStrictEqual(prompt.length, numExpectedMessages)
|
||||
})
|
||||
|
||||
@ -212,7 +215,7 @@ describe('Transcript', () => {
|
||||
const numExpectedInteractions = Math.floor(MAX_AVAILABLE_PROMPT_LENGTH / tokensPerInteraction)
|
||||
const numExpectedMessages = numExpectedInteractions * 2 // Each interaction has two messages.
|
||||
|
||||
const prompt = await transcript.toPrompt(preamble)
|
||||
const { prompt } = await transcript.getPromptForLastInteraction(preamble)
|
||||
assert.deepStrictEqual(prompt.length, numExpectedMessages)
|
||||
assert.deepStrictEqual(preamble, prompt.slice(0, 4))
|
||||
})
|
||||
@ -233,7 +236,8 @@ describe('Transcript', () => {
|
||||
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
|
||||
'dummy-codebase',
|
||||
embeddings,
|
||||
defaultKeywordContextFetcher
|
||||
defaultKeywordContextFetcher,
|
||||
null
|
||||
)
|
||||
|
||||
const chatQuestionRecipe = new ChatQuestion(() => {})
|
||||
@ -249,7 +253,7 @@ describe('Transcript', () => {
|
||||
)
|
||||
transcript.addInteraction(interaction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'Use the following text from file `docs/README.md`:\n# Main' },
|
||||
{ speaker: 'assistant', text: 'Ok.' },
|
||||
@ -288,7 +292,7 @@ describe('Transcript', () => {
|
||||
)
|
||||
transcript.addInteraction(interaction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'how do access tokens work in sourcegraph' },
|
||||
{ speaker: 'assistant', text: undefined },
|
||||
@ -309,7 +313,8 @@ describe('Transcript', () => {
|
||||
{ useContext: 'embeddings', serverEndpoint: 'https://example.com' },
|
||||
'dummy-codebase',
|
||||
embeddings,
|
||||
defaultKeywordContextFetcher
|
||||
defaultKeywordContextFetcher,
|
||||
null
|
||||
)
|
||||
|
||||
const chatQuestionRecipe = new ChatQuestion(() => {})
|
||||
@ -344,7 +349,7 @@ describe('Transcript', () => {
|
||||
)
|
||||
transcript.addInteraction(thirdInteraction)
|
||||
|
||||
const prompt = await transcript.toPrompt()
|
||||
const { prompt } = await transcript.getPromptForLastInteraction()
|
||||
const expectedPrompt = [
|
||||
{ speaker: 'human', text: 'how do batch changes work in sourcegraph' },
|
||||
{ speaker: 'assistant', text: 'Smartly.' },
|
||||
|
||||
@ -222,7 +222,14 @@ export const useClient = ({
|
||||
}
|
||||
|
||||
const unifiedContextFetcherClient = new UnifiedContextFetcherClient(graphqlClient, repoIds)
|
||||
const codebaseContext = new CodebaseContext(config, undefined, null, null, unifiedContextFetcherClient)
|
||||
const codebaseContext = new CodebaseContext(
|
||||
config,
|
||||
undefined,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
unifiedContextFetcherClient
|
||||
)
|
||||
|
||||
const { humanChatInput = '', prefilledOptions } = options ?? {}
|
||||
// TODO(naman): save scope with each interaction
|
||||
@ -242,10 +249,13 @@ export const useClient = ({
|
||||
setIsMessageInProgressState(true)
|
||||
onEvent?.('submit')
|
||||
|
||||
const prompt = await transcript.toPrompt(getMultiRepoPreamble(repoNames))
|
||||
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(
|
||||
getMultiRepoPreamble(repoNames)
|
||||
)
|
||||
transcript.setUsedContextFilesForLastInteraction(contextFiles)
|
||||
|
||||
const responsePrefix = interaction.getAssistantMessage().prefix ?? ''
|
||||
let rawText = ''
|
||||
|
||||
return new Promise(resolve => {
|
||||
chatClient.chat(prompt, {
|
||||
onChange(_rawText) {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { Configuration } from '../configuration'
|
||||
import { EmbeddingsSearch } from '../embeddings'
|
||||
import { KeywordContextFetcher, KeywordContextFetcherResult } from '../keyword-context'
|
||||
import { FilenameContextFetcher, KeywordContextFetcher, ContextResult } from '../local-context'
|
||||
import { isMarkdownFile, populateCodeContextTemplate, populateMarkdownContextTemplate } from '../prompt/templates'
|
||||
import { Message } from '../sourcegraph-api'
|
||||
import { EmbeddingsSearchResult } from '../sourcegraph-api/graphql/client'
|
||||
@ -21,7 +21,9 @@ export class CodebaseContext {
|
||||
private codebase: string | undefined,
|
||||
private embeddings: EmbeddingsSearch | null,
|
||||
private keywords: KeywordContextFetcher | null,
|
||||
private unifiedContextFetcher?: UnifiedContextFetcher | null
|
||||
private filenames: FilenameContextFetcher | null,
|
||||
private unifiedContextFetcher?: UnifiedContextFetcher | null,
|
||||
private rerank?: (query: string, results: ContextResult[]) => Promise<ContextResult[]>
|
||||
) {}
|
||||
|
||||
public getCodebase(): string | undefined {
|
||||
@ -32,18 +34,28 @@ export class CodebaseContext {
|
||||
this.config = newConfig
|
||||
}
|
||||
|
||||
private mergeContextResults(keywordResults: ContextResult[], filenameResults: ContextResult[]): ContextResult[] {
|
||||
// Just take the single most relevant filename suggestion for now. Otherwise, because our reranking relies solely
|
||||
// on filename, the filename results would dominate the keyword results.
|
||||
return filenameResults.slice(-1).concat(keywordResults)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of context messages for a given query, sorted in *reverse* order of importance (that is,
|
||||
* the most important context message appears *last*)
|
||||
*/
|
||||
public async getContextMessages(query: string, options: ContextSearchOptions): Promise<ContextMessage[]> {
|
||||
switch (this.config.useContext) {
|
||||
case 'unified':
|
||||
return this.getUnifiedContextMessages(query, options)
|
||||
case 'keyword':
|
||||
return this.getKeywordContextMessages(query, options)
|
||||
return this.getLocalContextMessages(query, options)
|
||||
case 'none':
|
||||
return []
|
||||
default:
|
||||
return this.embeddings
|
||||
? this.getEmbeddingsContextMessages(query, options)
|
||||
: this.getKeywordContextMessages(query, options)
|
||||
: this.getLocalContextMessages(query, options)
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,7 +70,7 @@ export class CodebaseContext {
|
||||
public async getSearchResults(
|
||||
query: string,
|
||||
options: ContextSearchOptions
|
||||
): Promise<{ results: KeywordContextFetcherResult[] | EmbeddingsSearchResult[]; endpoint: string }> {
|
||||
): Promise<{ results: ContextResult[] | EmbeddingsSearchResult[]; endpoint: string }> {
|
||||
if (this.embeddings && this.config.useContext !== 'keyword') {
|
||||
return {
|
||||
results: await this.getEmbeddingSearchResults(query, options),
|
||||
@ -141,22 +153,29 @@ export class CodebaseContext {
|
||||
})
|
||||
}
|
||||
|
||||
private async getKeywordContextMessages(query: string, options: ContextSearchOptions): Promise<ContextMessage[]> {
|
||||
const results = await this.getKeywordSearchResults(query, options)
|
||||
return results.flatMap(({ content, fileName, repoName, revision }) => {
|
||||
const messageText = populateCodeContextTemplate(content, fileName)
|
||||
return getContextMessageWithResponse(messageText, { fileName, repoName, revision })
|
||||
})
|
||||
private async getLocalContextMessages(query: string, options: ContextSearchOptions): Promise<ContextMessage[]> {
|
||||
const keywordResults = this.getKeywordSearchResults(query, options)
|
||||
const filenameResults = this.getFilenameSearchResults(query, options)
|
||||
const combinedResults = this.mergeContextResults(await keywordResults, await filenameResults)
|
||||
const rerankedResults = await (this.rerank ? this.rerank(query, combinedResults) : combinedResults)
|
||||
const messages = resultsToMessages(rerankedResults)
|
||||
return messages
|
||||
}
|
||||
|
||||
private async getKeywordSearchResults(
|
||||
query: string,
|
||||
options: ContextSearchOptions
|
||||
): Promise<KeywordContextFetcherResult[]> {
|
||||
private async getKeywordSearchResults(query: string, options: ContextSearchOptions): Promise<ContextResult[]> {
|
||||
if (!this.keywords) {
|
||||
return []
|
||||
}
|
||||
return this.keywords.getContext(query, options.numCodeResults + options.numTextResults)
|
||||
const results = await this.keywords.getContext(query, options.numCodeResults + options.numTextResults)
|
||||
return results
|
||||
}
|
||||
|
||||
private async getFilenameSearchResults(query: string, options: ContextSearchOptions): Promise<ContextResult[]> {
|
||||
if (!this.filenames) {
|
||||
return []
|
||||
}
|
||||
const results = await this.filenames.getContext(query, options.numCodeResults + options.numTextResults)
|
||||
return results
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,3 +220,10 @@ function mergeConsecutiveResults(results: EmbeddingsSearchResult[]): string[] {
|
||||
|
||||
return mergedResults
|
||||
}
|
||||
|
||||
function resultsToMessages(results: ContextResult[]): ContextMessage[] {
|
||||
return results.flatMap(({ content, fileName, repoName, revision }) => {
|
||||
const messageText = populateCodeContextTemplate(content, fileName)
|
||||
return getContextMessageWithResponse(messageText, { fileName, repoName, revision })
|
||||
})
|
||||
}
|
||||
|
||||
93
client/cody-shared/src/codebase-context/rerank.ts
Normal file
93
client/cody-shared/src/codebase-context/rerank.ts
Normal file
@ -0,0 +1,93 @@
|
||||
import { parseStringPromise } from 'xml2js'
|
||||
|
||||
import { ChatClient } from '../chat/chat'
|
||||
import { ContextResult } from '../local-context'
|
||||
|
||||
export interface Reranker {
|
||||
rerank(userQuery: string, results: ContextResult[]): Promise<ContextResult[]>
|
||||
}
|
||||
|
||||
export class MockReranker implements Reranker {
|
||||
constructor(private rerank_: (userQuery: string, results: ContextResult[]) => Promise<ContextResult[]>) {}
|
||||
public rerank(userQuery: string, results: ContextResult[]): Promise<ContextResult[]> {
|
||||
return this.rerank_(userQuery, results)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A reranker class that uses a LLM to boost high-relevance results.
|
||||
*/
|
||||
export class LLMReranker implements Reranker {
|
||||
constructor(private chatClient: ChatClient) {}
|
||||
public async rerank(userQuery: string, results: ContextResult[]): Promise<ContextResult[]> {
|
||||
// Reverse the results so the most important appears first
|
||||
results = [...results].reverse()
|
||||
|
||||
let out = await new Promise<string>((resolve, reject) => {
|
||||
let responseText = ''
|
||||
this.chatClient.chat(
|
||||
[
|
||||
{
|
||||
speaker: 'human',
|
||||
text: `I am a professional computer programmer and need help deciding which of these files to read first to answer my question. My question is <userQuestion>${userQuery}</userQuestion>. Select the files from the following list that I should read to answer my question, ranked by most relevant first. Format the result as XML, like this: <list><item><filename>filename 1</filename><explanation>this is why I chose this item</explanation></item><item><filename>filename 2</filename><explanation>why I chose this item</explanation></item></list>\n${results
|
||||
.map(r => r.fileName)
|
||||
.join('\n')}`,
|
||||
},
|
||||
],
|
||||
{
|
||||
onChange: (text: string) => {
|
||||
responseText = text
|
||||
},
|
||||
onComplete: () => {
|
||||
resolve(responseText)
|
||||
},
|
||||
onError: (message: string, statusCode?: number) => {
|
||||
reject(new Error(`Status code ${statusCode}: ${message}`))
|
||||
},
|
||||
},
|
||||
{
|
||||
temperature: 0,
|
||||
fast: true,
|
||||
}
|
||||
)
|
||||
})
|
||||
if (out.indexOf('<list>') > 0) {
|
||||
out = out.slice(out.indexOf('<list>'))
|
||||
}
|
||||
if (out.indexOf('</list>') !== out.length - '</list>'.length) {
|
||||
out = out.slice(0, out.indexOf('</list>') + '</list>'.length)
|
||||
}
|
||||
const boostedFilenames = await parseXml(out)
|
||||
|
||||
const resultsMap = Object.fromEntries(results.map(r => [r.fileName, r]))
|
||||
const boostedNames = new Set<string>()
|
||||
const rerankedResults = []
|
||||
for (const boostedFilename of boostedFilenames) {
|
||||
const boostedResult = resultsMap[boostedFilename]
|
||||
if (!boostedResult) {
|
||||
continue
|
||||
}
|
||||
rerankedResults.push(boostedResult)
|
||||
boostedNames.add(boostedFilename)
|
||||
}
|
||||
for (const result of results) {
|
||||
if (!boostedNames.has(result.fileName)) {
|
||||
rerankedResults.push(result)
|
||||
}
|
||||
}
|
||||
|
||||
rerankedResults.reverse()
|
||||
return rerankedResults
|
||||
}
|
||||
}
|
||||
|
||||
async function parseXml(xml: string): Promise<string[]> {
|
||||
const result = await parseStringPromise(xml)
|
||||
const items = result.list.item
|
||||
const files: { filename: string; explanation: string }[] = items.map((item: any) => ({
|
||||
filename: item.filename[0],
|
||||
explanation: item.explanation[0],
|
||||
}))
|
||||
|
||||
return files.map(f => f.filename)
|
||||
}
|
||||
@ -1,11 +0,0 @@
|
||||
export interface KeywordContextFetcherResult {
|
||||
repoName?: string
|
||||
revision?: string
|
||||
fileName: string
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface KeywordContextFetcher {
|
||||
getContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]>
|
||||
getSearchContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]>
|
||||
}
|
||||
15
client/cody-shared/src/local-context/index.ts
Normal file
15
client/cody-shared/src/local-context/index.ts
Normal file
@ -0,0 +1,15 @@
|
||||
export interface ContextResult {
|
||||
repoName?: string
|
||||
revision?: string
|
||||
fileName: string
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface KeywordContextFetcher {
|
||||
getContext(query: string, numResults: number): Promise<ContextResult[]>
|
||||
getSearchContext(query: string, numResults: number): Promise<ContextResult[]>
|
||||
}
|
||||
|
||||
export interface FilenameContextFetcher {
|
||||
getContext(query: string, numResults: number): Promise<ContextResult[]>
|
||||
}
|
||||
@ -24,6 +24,7 @@ export interface CompletionResponse {
|
||||
}
|
||||
|
||||
export interface CompletionParameters {
|
||||
fast?: boolean
|
||||
messages: Message[]
|
||||
maxTokensToSample: number
|
||||
temperature?: number
|
||||
|
||||
@ -4,7 +4,7 @@ import { CodebaseContext } from '../codebase-context'
|
||||
import { ActiveTextEditor, ActiveTextEditorSelection, ActiveTextEditorVisibleContent, Editor } from '../editor'
|
||||
import { EmbeddingsSearch } from '../embeddings'
|
||||
import { IntentDetector } from '../intent-detector'
|
||||
import { KeywordContextFetcher, KeywordContextFetcherResult } from '../keyword-context'
|
||||
import { KeywordContextFetcher, ContextResult } from '../local-context'
|
||||
import { EmbeddingsSearchResults } from '../sourcegraph-api/graphql'
|
||||
|
||||
export class MockEmbeddingsClient implements EmbeddingsSearch {
|
||||
@ -37,11 +37,11 @@ export class MockIntentDetector implements IntentDetector {
|
||||
export class MockKeywordContextFetcher implements KeywordContextFetcher {
|
||||
constructor(private mocks: Partial<KeywordContextFetcher> = {}) {}
|
||||
|
||||
public getContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
|
||||
public getContext(query: string, numResults: number): Promise<ContextResult[]> {
|
||||
return this.mocks.getContext?.(query, numResults) ?? Promise.resolve([])
|
||||
}
|
||||
|
||||
public getSearchContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
|
||||
public getSearchContext(query: string, numResults: number): Promise<ContextResult[]> {
|
||||
return this.mocks.getSearchContext?.(query, numResults) ?? Promise.resolve([])
|
||||
}
|
||||
}
|
||||
@ -111,7 +111,8 @@ export function newRecipeContext(args?: Partial<RecipeContext>): RecipeContext {
|
||||
{ useContext: 'none', serverEndpoint: 'https://example.com' },
|
||||
'dummy-codebase',
|
||||
defaultEmbeddingsClient,
|
||||
defaultKeywordContextFetcher
|
||||
defaultKeywordContextFetcher,
|
||||
null
|
||||
),
|
||||
responseMultiplexer: args.responseMultiplexer || new BotResponseMultiplexer(),
|
||||
firstInteraction: args.firstInteraction ?? false,
|
||||
|
||||
@ -44,7 +44,8 @@ export async function handleHumanMessage(event: AppMentionEvent, appContext: App
|
||||
const response = await slackHelpers.postMessage(IN_PROGRESS_MESSAGE, channel, thread_ts)
|
||||
|
||||
// Generate a prompt and start completion streaming
|
||||
const prompt = await transcript.toPrompt(SLACK_PREAMBLE)
|
||||
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(SLACK_PREAMBLE)
|
||||
transcript.setUsedContextFilesForLastInteraction(contextFiles)
|
||||
console.log('PROMPT', prompt)
|
||||
startCompletionStreaming(prompt, channel, transcript, response?.ts)
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { memoize } from 'lodash'
|
||||
|
||||
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
|
||||
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
|
||||
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/keyword-context'
|
||||
import { KeywordContextFetcher } from '@sourcegraph/cody-shared/src/local-context'
|
||||
import { isError } from '@sourcegraph/cody-shared/src/utils'
|
||||
|
||||
import { sourcegraphClient } from './sourcegraph-client'
|
||||
@ -36,7 +36,8 @@ export async function createCodebaseContext(
|
||||
{ useContext: contextType, serverEndpoint },
|
||||
codebase,
|
||||
embeddingsSearch,
|
||||
new LocalKeywordContextFetcherMock()
|
||||
new LocalKeywordContextFetcherMock(),
|
||||
null
|
||||
)
|
||||
|
||||
return codebaseContext
|
||||
|
||||
@ -56,7 +56,7 @@ class SlackInteraction {
|
||||
}
|
||||
|
||||
public getTranscriptInteraction() {
|
||||
return new Interaction(this.humanMessage, this.assistantMessage, Promise.resolve(this.contextMessages))
|
||||
return new Interaction(this.humanMessage, this.assistantMessage, Promise.resolve(this.contextMessages), [])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
7
client/cody/BUILD.bazel
generated
7
client/cody/BUILD.bazel
generated
@ -52,9 +52,11 @@ ts_project(
|
||||
"src/extension.ts",
|
||||
"src/extension-api.ts",
|
||||
"src/external-services.ts",
|
||||
"src/keyword-context/local-keyword-context-fetcher.ts",
|
||||
"src/local-app-detector.ts",
|
||||
"src/local-context/filename-context-fetcher.ts",
|
||||
"src/local-context/local-keyword-context-fetcher.ts",
|
||||
"src/log.ts",
|
||||
"src/logged-rerank.ts",
|
||||
"src/main.ts",
|
||||
"src/non-stop/FixupCodeLenses.ts",
|
||||
"src/non-stop/FixupContentStore.ts",
|
||||
@ -111,6 +113,7 @@ ts_project(
|
||||
"//:node_modules/@storybook/react", #keep
|
||||
"//:node_modules/@types/classnames",
|
||||
"//:node_modules/@types/jest", #keep
|
||||
"//:node_modules/@types/lodash",
|
||||
"//:node_modules/@types/lru-cache",
|
||||
"//:node_modules/@types/node",
|
||||
"//:node_modules/@types/react",
|
||||
@ -121,6 +124,7 @@ ts_project(
|
||||
"//:node_modules/@vscode",
|
||||
"//:node_modules/@vscode/webview-ui-toolkit",
|
||||
"//:node_modules/classnames",
|
||||
"//:node_modules/lodash",
|
||||
"//:node_modules/react",
|
||||
"//:node_modules/react-dom",
|
||||
"//:node_modules/stream-json",
|
||||
@ -140,7 +144,6 @@ ts_project(
|
||||
"src/completions/context.test.ts",
|
||||
"src/completions/provider.test.ts",
|
||||
"src/configuration.test.ts",
|
||||
"src/keyword-context/local-keyword-context-fetcher.test.ts",
|
||||
"src/non-stop/diff.test.ts",
|
||||
"src/non-stop/tracked-range.test.ts",
|
||||
"src/non-stop/utils.test.ts",
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { spawnSync } from 'child_process'
|
||||
import path from 'path'
|
||||
|
||||
import * as vscode from 'vscode'
|
||||
@ -11,17 +12,24 @@ import { ChatMessage, ChatHistory } from '@sourcegraph/cody-shared/src/chat/tran
|
||||
import { reformatBotMessage } from '@sourcegraph/cody-shared/src/chat/viewHelpers'
|
||||
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
|
||||
import { ConfigurationWithAccessToken } from '@sourcegraph/cody-shared/src/configuration'
|
||||
import { Editor } from '@sourcegraph/cody-shared/src/editor'
|
||||
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
|
||||
import { Guardrails, annotateAttribution } from '@sourcegraph/cody-shared/src/guardrails'
|
||||
import { highlightTokens } from '@sourcegraph/cody-shared/src/hallucinations-detector'
|
||||
import { IntentDetector } from '@sourcegraph/cody-shared/src/intent-detector'
|
||||
import { Message } from '@sourcegraph/cody-shared/src/sourcegraph-api'
|
||||
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
|
||||
import { isError } from '@sourcegraph/cody-shared/src/utils'
|
||||
|
||||
import { View } from '../../webviews/NavBar'
|
||||
import { getFullConfig, updateConfiguration } from '../configuration'
|
||||
import { VSCodeEditor } from '../editor/vscode-editor'
|
||||
import { logEvent } from '../event-logger'
|
||||
import { LocalAppDetector } from '../local-app-detector'
|
||||
import { FilenameContextFetcher } from '../local-context/filename-context-fetcher'
|
||||
import { LocalKeywordContextFetcher } from '../local-context/local-keyword-context-fetcher'
|
||||
import { debug } from '../log'
|
||||
import { getRerankWithLog } from '../logged-rerank'
|
||||
import { FixupTask } from '../non-stop/FixupTask'
|
||||
import { LocalStorage } from '../services/LocalStorageProvider'
|
||||
import { CODY_ACCESS_TOKEN_SECRET, SecretStorage } from '../services/SecretStorageProvider'
|
||||
@ -38,7 +46,7 @@ import {
|
||||
isLoggedIn,
|
||||
} from './protocol'
|
||||
import { getRecipe } from './recipes'
|
||||
import { getAuthStatus, getCodebaseContext } from './utils'
|
||||
import { convertGitCloneURLToCodebaseName, getAuthStatus } from './utils'
|
||||
|
||||
export type Config = Pick<
|
||||
ConfigurationWithAccessToken,
|
||||
@ -107,7 +115,7 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
|
||||
}),
|
||||
vscode.workspace.onDidChangeConfiguration(async () => {
|
||||
this.config = await getFullConfig(this.secretStorage)
|
||||
const newCodebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor)
|
||||
const newCodebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor, chat)
|
||||
if (newCodebaseContext) {
|
||||
this.codebaseContext = newCodebaseContext
|
||||
await this.setAnonymousUserID()
|
||||
@ -376,7 +384,7 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
|
||||
}
|
||||
this.currentWorkspaceRoot = workspaceRoot
|
||||
|
||||
const codebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor)
|
||||
const codebaseContext = await getCodebaseContext(this.config, this.rgPath, this.editor, this.chat)
|
||||
if (!codebaseContext) {
|
||||
return
|
||||
}
|
||||
@ -430,12 +438,14 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
|
||||
default: {
|
||||
this.sendTranscript()
|
||||
|
||||
const prompt = await this.transcript.toPrompt(getPreamble(this.codebaseContext.getCodebase()))
|
||||
const { prompt, contextFiles } = await this.transcript.getPromptForLastInteraction(
|
||||
getPreamble(this.codebaseContext.getCodebase())
|
||||
)
|
||||
this.transcript.setUsedContextFilesForLastInteraction(contextFiles)
|
||||
this.sendPrompt(prompt, interaction.getAssistantMessage().prefix ?? '')
|
||||
await this.saveTranscriptToChatHistory()
|
||||
}
|
||||
}
|
||||
|
||||
logEvent(`CodyVSCodeExtension:recipe:${recipe.id}:executed`)
|
||||
}
|
||||
|
||||
@ -460,7 +470,10 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
|
||||
}
|
||||
transcript.addInteraction(interaction)
|
||||
|
||||
const prompt = await transcript.toPrompt(getPreamble(this.codebaseContext.getCodebase()))
|
||||
const { prompt, contextFiles } = await transcript.getPromptForLastInteraction(
|
||||
getPreamble(this.codebaseContext.getCodebase())
|
||||
)
|
||||
transcript.setUsedContextFilesForLastInteraction(contextFiles)
|
||||
|
||||
logEvent(`CodyVSCodeExtension:recipe:${recipe.id}:executed`)
|
||||
|
||||
@ -818,3 +831,49 @@ export class ChatViewProvider implements vscode.WebviewViewProvider, vscode.Disp
|
||||
this.disposables = []
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets codebase context for the current workspace.
|
||||
*
|
||||
* @param config Cody configuration
|
||||
* @param rgPath Path to rg (ripgrep) executable
|
||||
* @param editor Editor instance
|
||||
* @returns CodebaseContext if a codebase can be determined, else null
|
||||
*/
|
||||
export async function getCodebaseContext(
|
||||
config: Config,
|
||||
rgPath: string,
|
||||
editor: Editor,
|
||||
chatClient: ChatClient
|
||||
): Promise<CodebaseContext | null> {
|
||||
const client = new SourcegraphGraphQLAPIClient(config)
|
||||
const workspaceRoot = editor.getWorkspaceRootPath()
|
||||
if (!workspaceRoot) {
|
||||
return null
|
||||
}
|
||||
const gitCommand = spawnSync('git', ['remote', 'get-url', 'origin'], { cwd: workspaceRoot })
|
||||
const gitOutput = gitCommand.stdout.toString().trim()
|
||||
// Get codebase from config or fallback to getting repository name from git clone URL
|
||||
const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput)
|
||||
if (!codebase) {
|
||||
return null
|
||||
}
|
||||
// Check if repo is embedded in endpoint
|
||||
const repoId = await client.getRepoIdIfEmbeddingExists(codebase)
|
||||
if (isError(repoId)) {
|
||||
const infoMessage = `Cody could not find embeddings for '${codebase}' on your Sourcegraph instance.\n`
|
||||
console.info(infoMessage)
|
||||
return null
|
||||
}
|
||||
|
||||
const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null
|
||||
return new CodebaseContext(
|
||||
config,
|
||||
codebase,
|
||||
embeddingsSearch,
|
||||
new LocalKeywordContextFetcher(rgPath, editor, chatClient),
|
||||
new FilenameContextFetcher(rgPath, editor, chatClient),
|
||||
undefined,
|
||||
getRerankWithLog(chatClient)
|
||||
)
|
||||
}
|
||||
|
||||
@ -12,10 +12,11 @@ import * as path from 'path'
|
||||
export async function fastFilesExist(
|
||||
rgPath: string,
|
||||
rootPath: string,
|
||||
filePaths: string[]
|
||||
filePaths: string[],
|
||||
maxDepth?: number
|
||||
): Promise<{ [filePath: string]: boolean }> {
|
||||
const searchPattern = constructSearchPattern(filePaths)
|
||||
const rgOutput = await executeRg(rgPath, rootPath, searchPattern)
|
||||
const rgOutput = await executeRg(rgPath, rootPath, searchPattern, maxDepth)
|
||||
return processRgOutput(rgOutput, filePaths)
|
||||
}
|
||||
|
||||
@ -28,6 +29,7 @@ export function makeTrimRegex(sep: string): RegExp {
|
||||
|
||||
// Regex to match '**', '*' or path.sep at the start (^) or end ($) of the string.
|
||||
const trimRegex = makeTrimRegex(path.sep)
|
||||
|
||||
/**
|
||||
* Constructs a search pattern for the 'rg' tool.
|
||||
*
|
||||
@ -42,6 +44,7 @@ function constructSearchPattern(filePaths: string[]): string {
|
||||
})
|
||||
return `{${searchPatternParts.join(',')}}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the 'rg' tool and returns the output.
|
||||
*
|
||||
@ -50,11 +53,15 @@ function constructSearchPattern(filePaths: string[]): string {
|
||||
* @param searchPattern - The search pattern to use.
|
||||
* @returns The output from the 'rg' tool.
|
||||
*/
|
||||
async function executeRg(rgPath: string, rootPath: string, searchPattern: string): Promise<string> {
|
||||
async function executeRg(rgPath: string, rootPath: string, searchPattern: string, maxDepth?: number): Promise<string> {
|
||||
const args = ['--files', '-g', searchPattern, '--crlf', '--fixed-strings', '--no-config', '--no-ignore-global']
|
||||
if (maxDepth !== undefined) {
|
||||
args.push('--max-depth', `${maxDepth}`)
|
||||
}
|
||||
return new Promise((resolve, reject) => {
|
||||
execFile(
|
||||
rgPath,
|
||||
['--files', '-g', searchPattern, '--crlf', '--fixed-strings', '--no-config', '--no-ignore-global'],
|
||||
args,
|
||||
{
|
||||
cwd: rootPath,
|
||||
maxBuffer: 1024 * 1024 * 1024,
|
||||
|
||||
@ -1,15 +1,7 @@
|
||||
import { spawnSync } from 'child_process'
|
||||
|
||||
import { CodebaseContext } from '@sourcegraph/cody-shared/src/codebase-context'
|
||||
import { ConfigurationWithAccessToken } from '@sourcegraph/cody-shared/src/configuration'
|
||||
import { Editor } from '@sourcegraph/cody-shared/src/editor'
|
||||
import { SourcegraphEmbeddingsSearchClient } from '@sourcegraph/cody-shared/src/embeddings/client'
|
||||
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
|
||||
import { isError } from '@sourcegraph/cody-shared/src/utils'
|
||||
|
||||
import { LocalKeywordContextFetcher } from '../keyword-context/local-keyword-context-fetcher'
|
||||
|
||||
import { Config } from './ChatViewProvider'
|
||||
import { AuthStatus, defaultAuthStatus, isLocalApp, unauthenticatedStatus } from './protocol'
|
||||
|
||||
// Converts a git clone URL to the codebase name that includes the slash-separated code host, owner, and repository name
|
||||
@ -50,40 +42,11 @@ export function convertGitCloneURLToCodebaseName(cloneURL: string): string | nul
|
||||
}
|
||||
return null
|
||||
} catch (error) {
|
||||
console.log(`Cody could not extract repo name from clone URL ${cloneURL}:`, error)
|
||||
console.error(`Cody could not extract repo name from clone URL ${cloneURL}:`, error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCodebaseContext(
|
||||
config: Config,
|
||||
rgPath: string,
|
||||
editor: Editor
|
||||
): Promise<CodebaseContext | null> {
|
||||
const client = new SourcegraphGraphQLAPIClient(config)
|
||||
const workspaceRoot = editor.getWorkspaceRootPath()
|
||||
if (!workspaceRoot) {
|
||||
return null
|
||||
}
|
||||
const gitCommand = spawnSync('git', ['remote', 'get-url', 'origin'], { cwd: workspaceRoot })
|
||||
const gitOutput = gitCommand.stdout.toString().trim()
|
||||
// Get codebase from config or fallback to getting repository name from git clone URL
|
||||
const codebase = config.codebase || convertGitCloneURLToCodebaseName(gitOutput)
|
||||
if (!codebase) {
|
||||
return null
|
||||
}
|
||||
// Check if repo is embedded in endpoint
|
||||
const repoId = await client.getRepoIdIfEmbeddingExists(codebase)
|
||||
if (isError(repoId)) {
|
||||
const infoMessage = `Cody could not find embeddings for '${codebase}' on your Sourcegraph instance.\n`
|
||||
console.info(infoMessage)
|
||||
return null
|
||||
}
|
||||
|
||||
const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null
|
||||
return new CodebaseContext(config, codebase, embeddingsSearch, new LocalKeywordContextFetcher(rgPath, editor))
|
||||
}
|
||||
|
||||
let client: SourcegraphGraphQLAPIClient
|
||||
let configWithToken: Pick<ConfigurationWithAccessToken, 'serverEndpoint' | 'accessToken' | 'customHeaders'>
|
||||
|
||||
|
||||
@ -12,8 +12,10 @@ import { SourcegraphNodeCompletionsClient } from '@sourcegraph/cody-shared/src/s
|
||||
import { SourcegraphGraphQLAPIClient } from '@sourcegraph/cody-shared/src/sourcegraph-api/graphql'
|
||||
import { isError } from '@sourcegraph/cody-shared/src/utils'
|
||||
|
||||
import { LocalKeywordContextFetcher } from './keyword-context/local-keyword-context-fetcher'
|
||||
import { FilenameContextFetcher } from './local-context/filename-context-fetcher'
|
||||
import { LocalKeywordContextFetcher } from './local-context/local-keyword-context-fetcher'
|
||||
import { logger } from './log'
|
||||
import { getRerankWithLog } from './logged-rerank'
|
||||
|
||||
interface ExternalServices {
|
||||
intentDetector: IntentDetector
|
||||
@ -48,11 +50,15 @@ export async function configureExternalServices(
|
||||
}
|
||||
const embeddingsSearch = repoId && !isError(repoId) ? new SourcegraphEmbeddingsSearchClient(client, repoId) : null
|
||||
|
||||
const chatClient = new ChatClient(completions)
|
||||
const codebaseContext = new CodebaseContext(
|
||||
initialConfig,
|
||||
initialConfig.codebase,
|
||||
embeddingsSearch,
|
||||
new LocalKeywordContextFetcher(rgPath, editor)
|
||||
new LocalKeywordContextFetcher(rgPath, editor, chatClient),
|
||||
new FilenameContextFetcher(rgPath, editor, chatClient),
|
||||
undefined,
|
||||
getRerankWithLog(chatClient)
|
||||
)
|
||||
|
||||
const guardrails = new SourcegraphGuardrailsClient(client)
|
||||
@ -60,7 +66,7 @@ export async function configureExternalServices(
|
||||
return {
|
||||
intentDetector: new SourcegraphIntentDetectorClient(client),
|
||||
codebaseContext,
|
||||
chatClient: new ChatClient(completions),
|
||||
chatClient,
|
||||
completionsClient: completions,
|
||||
guardrails,
|
||||
onConfigurationChange: newConfig => {
|
||||
|
||||
@ -1,153 +0,0 @@
|
||||
import * as assert from 'assert'
|
||||
|
||||
import { Term, regexForTerms, userQueryToKeywordQuery } from './local-keyword-context-fetcher'
|
||||
|
||||
describe('keyword context', () => {
|
||||
it('userQueryToKeywordQuery', () => {
|
||||
const cases: { query: string; expected: Term[] }[] = [
|
||||
{
|
||||
query: 'Where is auth in Sourcegraph?',
|
||||
expected: [
|
||||
{
|
||||
count: 1,
|
||||
originals: ['Where', 'Where'],
|
||||
prefix: 'where',
|
||||
stem: 'where',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['auth', 'auth'],
|
||||
prefix: 'auth',
|
||||
stem: 'auth',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['Sourcegraph', 'Sourcegraph'],
|
||||
prefix: 'sourcegraph',
|
||||
stem: 'sourcegraph',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
query: `Explain the following code at a high level:
|
||||
uint32_t PackUInt32(const Color& color) {
|
||||
uint32_t result = 0;
|
||||
result |= static_cast<uint32_t>(color.r * 255 + 0.5f) << 24;
|
||||
result |= static_cast<uint32_t>(color.g * 255 + 0.5f) << 16;
|
||||
result |= static_cast<uint32_t>(color.b * 255 + 0.5f) << 8;
|
||||
result |= static_cast<uint32_t>(color.a * 255 + 0.5f);
|
||||
return result;
|
||||
}
|
||||
`,
|
||||
expected: [
|
||||
{
|
||||
count: 1,
|
||||
originals: ['Explain', 'Explain'],
|
||||
prefix: 'explain',
|
||||
stem: 'explain',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['following', 'following'],
|
||||
prefix: 'follow',
|
||||
stem: 'follow',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['code', 'code'],
|
||||
prefix: 'code',
|
||||
stem: 'code',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['high', 'high'],
|
||||
prefix: 'high',
|
||||
stem: 'high',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['level', 'level'],
|
||||
prefix: 'level',
|
||||
stem: 'level',
|
||||
},
|
||||
{
|
||||
count: 6,
|
||||
originals: ['uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t', 'uint32_t'],
|
||||
prefix: 'uint',
|
||||
stem: 'uinty2_t',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['PackUInt32', 'PackUInt32'],
|
||||
prefix: 'packuint',
|
||||
stem: 'packuinty2',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['const', 'const'],
|
||||
prefix: 'const',
|
||||
stem: 'const',
|
||||
},
|
||||
{
|
||||
count: 6,
|
||||
originals: ['Color', 'Color', 'color', 'color', 'color', 'color', 'color'],
|
||||
prefix: 'color',
|
||||
stem: 'color',
|
||||
},
|
||||
{
|
||||
count: 6,
|
||||
originals: ['result', 'result', 'result', 'result', 'result', 'result', 'result'],
|
||||
prefix: 'result',
|
||||
stem: 'result',
|
||||
},
|
||||
{
|
||||
count: 4,
|
||||
originals: ['static_cast', 'static_cast', 'static_cast', 'static_cast', 'static_cast'],
|
||||
prefix: 'static_cast',
|
||||
stem: 'static_cast',
|
||||
},
|
||||
{
|
||||
count: 4,
|
||||
originals: ['255', '255', '255', '255', '255'],
|
||||
prefix: '255',
|
||||
stem: '255',
|
||||
},
|
||||
{
|
||||
count: 1,
|
||||
originals: ['return', 'return'],
|
||||
prefix: 'return',
|
||||
stem: 'return',
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
for (const testcase of cases) {
|
||||
const actual = userQueryToKeywordQuery(testcase.query)
|
||||
assert.deepStrictEqual(actual, testcase.expected)
|
||||
}
|
||||
})
|
||||
it('query to regex', () => {
|
||||
const trials: {
|
||||
userQuery: string
|
||||
expRegex: string
|
||||
}[] = [
|
||||
{
|
||||
userQuery: 'Where is auth in Sourcegraph?',
|
||||
expRegex: '(?:where|auth|sourcegraph)',
|
||||
},
|
||||
{
|
||||
userQuery: 'saml auth handler',
|
||||
expRegex: '(?:saml|auth|handler)',
|
||||
},
|
||||
{
|
||||
userQuery: 'Where is the HTTP middleware defined in this codebase?',
|
||||
expRegex: '(?:where|http|middlewar|defin|codebas)',
|
||||
},
|
||||
]
|
||||
for (const trial of trials) {
|
||||
const terms = userQueryToKeywordQuery(trial.userQuery)
|
||||
const regex = regexForTerms(...terms)
|
||||
expect(regex).toEqual(trial.expRegex)
|
||||
}
|
||||
})
|
||||
})
|
||||
154
client/cody/src/local-context/filename-context-fetcher.ts
Normal file
154
client/cody/src/local-context/filename-context-fetcher.ts
Normal file
@ -0,0 +1,154 @@
|
||||
import { execFile } from 'child_process'
|
||||
import * as path from 'path'
|
||||
|
||||
import { uniq } from 'lodash'
|
||||
import * as vscode from 'vscode'
|
||||
|
||||
import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
|
||||
import { Editor } from '@sourcegraph/cody-shared/src/editor'
|
||||
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
|
||||
|
||||
import { debug } from '../log'
|
||||
|
||||
/**
|
||||
* A local context fetcher that uses a LLM to generate filename fragments, which are then used to
|
||||
* find files that are relevant based on their path or name.
|
||||
*/
|
||||
export class FilenameContextFetcher {
|
||||
constructor(private rgPath: string, private editor: Editor, private chatClient: ChatClient) {}
|
||||
|
||||
/**
|
||||
* Returns pieces of context relevant for the given query. Uses a filename search approach
|
||||
* @param query user query
|
||||
* @param numResults the number of context results to return
|
||||
* @returns a list of context results, sorted in *reverse* order (that is,
|
||||
* the most important result appears at the bottom)
|
||||
*/
|
||||
public async getContext(query: string, numResults: number): Promise<ContextResult[]> {
|
||||
const time0 = performance.now()
|
||||
|
||||
const rootPath = this.editor.getWorkspaceRootPath()
|
||||
if (!rootPath) {
|
||||
return []
|
||||
}
|
||||
const time1 = performance.now()
|
||||
const filenameFragments = await this.queryToFileFragments(query)
|
||||
const time2 = performance.now()
|
||||
const unsortedMatchingFiles = await this.getFilenames(rootPath, filenameFragments, 3)
|
||||
const time3 = performance.now()
|
||||
|
||||
const specialFragments = ['readme']
|
||||
const allBoostedFiles = []
|
||||
let remainingFiles = unsortedMatchingFiles
|
||||
let nextRemainingFiles = []
|
||||
for (const specialFragment of specialFragments) {
|
||||
const boostedFiles = []
|
||||
for (const fileName of remainingFiles) {
|
||||
const fileNameLower = fileName.toLocaleLowerCase()
|
||||
if (fileNameLower.includes(specialFragment)) {
|
||||
boostedFiles.push(fileName)
|
||||
} else {
|
||||
nextRemainingFiles.push(fileName)
|
||||
}
|
||||
}
|
||||
remainingFiles = nextRemainingFiles
|
||||
nextRemainingFiles = []
|
||||
allBoostedFiles.push(...boostedFiles.sort((a, b) => a.length - b.length))
|
||||
}
|
||||
|
||||
const sortedMatchingFiles = allBoostedFiles.concat(remainingFiles).slice(0, numResults)
|
||||
|
||||
const results = await Promise.all(
|
||||
sortedMatchingFiles
|
||||
.map(async fileName => {
|
||||
const uri = vscode.Uri.file(path.join(rootPath, fileName))
|
||||
const content = (await vscode.workspace.openTextDocument(uri)).getText()
|
||||
return {
|
||||
fileName,
|
||||
content,
|
||||
}
|
||||
})
|
||||
.reverse()
|
||||
)
|
||||
|
||||
const time4 = performance.now()
|
||||
debug(
|
||||
'FilenameContextFetcher:getContext',
|
||||
JSON.stringify({
|
||||
duration: time4 - time0,
|
||||
queryToFileFragments: { duration: time2 - time1, fragments: filenameFragments },
|
||||
getFilenames: { duration: time3 - time2 },
|
||||
}),
|
||||
{ verbose: { matchingFiles: unsortedMatchingFiles, results: results.map(r => r.fileName) } }
|
||||
)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
private async queryToFileFragments(query: string): Promise<string[]> {
|
||||
const filenameFragments = await new Promise<string[]>((resolve, reject) => {
|
||||
let responseText = ''
|
||||
this.chatClient.chat(
|
||||
[
|
||||
{
|
||||
speaker: 'human',
|
||||
text: `Write 3 filename fragments that would be contained by files in a git repository that are relevant to answering the following user query: <query>${query}</query> Your response should be only a space-delimited list of filename fragments and nothing else.`,
|
||||
},
|
||||
],
|
||||
{
|
||||
onChange: (text: string) => {
|
||||
responseText = text
|
||||
},
|
||||
onComplete: () => {
|
||||
resolve(responseText.split(/\s+/).filter(e => e.length > 0))
|
||||
},
|
||||
onError: (message: string, statusCode?: number) => {
|
||||
reject(new Error(message))
|
||||
},
|
||||
},
|
||||
{
|
||||
temperature: 0,
|
||||
fast: true,
|
||||
}
|
||||
)
|
||||
})
|
||||
const uniqueFragments = uniq(filenameFragments.map(e => e.toLocaleLowerCase()))
|
||||
return uniqueFragments
|
||||
}
|
||||
|
||||
private async getFilenames(rootPath: string, filenameFragments: string[], maxDepth: number): Promise<string[]> {
|
||||
const searchPattern = '{' + filenameFragments.map(fragment => `**${fragment}**`).join(',') + '}'
|
||||
const rgArgs = [
|
||||
'--files',
|
||||
'--iglob',
|
||||
searchPattern,
|
||||
'--crlf',
|
||||
'--fixed-strings',
|
||||
'--no-config',
|
||||
'--no-ignore-global',
|
||||
`--max-depth=${maxDepth}`,
|
||||
]
|
||||
const results = await new Promise<string>((resolve, reject) => {
|
||||
execFile(
|
||||
this.rgPath,
|
||||
rgArgs,
|
||||
{
|
||||
cwd: rootPath,
|
||||
maxBuffer: 1024 * 1024 * 1024,
|
||||
},
|
||||
(error, stdout, stderr) => {
|
||||
if (error?.code === 2) {
|
||||
reject(new Error(`${error.message}: ${stderr}`))
|
||||
} else {
|
||||
resolve(stdout)
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
return results
|
||||
.split('\n')
|
||||
.map(r => r.trim())
|
||||
.filter(r => r.length > 0)
|
||||
.sort((a, b) => a.length - b.length)
|
||||
}
|
||||
}
|
||||
@ -1,14 +1,17 @@
|
||||
import { execFile, spawn } from 'child_process'
|
||||
import * as path from 'path'
|
||||
|
||||
import Assembler from 'stream-json/Assembler'
|
||||
import StreamValues from 'stream-json/streamers/StreamValues'
|
||||
import * as vscode from 'vscode'
|
||||
import winkUtils from 'wink-nlp-utils'
|
||||
|
||||
import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
|
||||
import { Editor } from '@sourcegraph/cody-shared/src/editor'
|
||||
import { KeywordContextFetcher, KeywordContextFetcherResult } from '@sourcegraph/cody-shared/src/keyword-context'
|
||||
import { KeywordContextFetcher, ContextResult } from '@sourcegraph/cody-shared/src/local-context'
|
||||
|
||||
import { logEvent } from '../event-logger'
|
||||
import { debug } from '../log'
|
||||
|
||||
/**
|
||||
* Exclude files without extensions and hidden files (starts with '.')
|
||||
@ -17,6 +20,7 @@ import { logEvent } from '../event-logger'
|
||||
* Note: Ripgrep excludes binary files and respects .gitignore by default
|
||||
*/
|
||||
const fileExtRipgrepParams = [
|
||||
'--ignore-case',
|
||||
'-g',
|
||||
'*.*',
|
||||
'-g',
|
||||
@ -25,10 +29,10 @@ const fileExtRipgrepParams = [
|
||||
'!*.lock',
|
||||
'-g',
|
||||
'!*.snap',
|
||||
'--threads',
|
||||
'1',
|
||||
'--max-filesize',
|
||||
'1M',
|
||||
'10K',
|
||||
'--max-depth',
|
||||
'10',
|
||||
]
|
||||
|
||||
interface RipgrepStreamData {
|
||||
@ -68,62 +72,34 @@ export function regexForTerms(...terms: Term[]): string {
|
||||
return `(?:${inner.join('|')})`
|
||||
}
|
||||
|
||||
export function userQueryToKeywordQuery(query: string): Term[] {
|
||||
const longestCommonPrefix = (s: string, t: string): string => {
|
||||
let endIdx = 0
|
||||
for (let i = 0; i < s.length && i < t.length; i++) {
|
||||
if (s[i] !== t[i]) {
|
||||
break
|
||||
}
|
||||
endIdx = i + 1
|
||||
function longestCommonPrefix(s: string, t: string): string {
|
||||
let endIdx = 0
|
||||
for (let i = 0; i < s.length && i < t.length; i++) {
|
||||
if (s[i] !== t[i]) {
|
||||
break
|
||||
}
|
||||
return s.slice(0, endIdx)
|
||||
endIdx = i + 1
|
||||
}
|
||||
|
||||
const origWords: string[] = []
|
||||
for (const chunk of query.split(/\W+/)) {
|
||||
if (chunk.trim().length === 0) {
|
||||
continue
|
||||
}
|
||||
origWords.push(...winkUtils.string.tokenize0(chunk))
|
||||
}
|
||||
const filteredWords = winkUtils.tokens.removeWords(origWords)
|
||||
const terms = new Map<string, Term>()
|
||||
for (const word of filteredWords) {
|
||||
// Ignore ASCII-only strings of length 2 or less
|
||||
if (word.length <= 2) {
|
||||
let skip = true
|
||||
for (let i = 0; i < word.length; i++) {
|
||||
if (word.charCodeAt(i) >= 128) {
|
||||
// non-ASCII
|
||||
skip = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if (skip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
const stem = winkUtils.string.stem(word)
|
||||
const term = terms.get(stem) || {
|
||||
stem,
|
||||
originals: [word],
|
||||
prefix: longestCommonPrefix(word.toLowerCase(), stem),
|
||||
count: 0,
|
||||
}
|
||||
term.originals.push(word)
|
||||
term.count++
|
||||
terms.set(stem, term)
|
||||
}
|
||||
return [...terms.values()]
|
||||
return s.slice(0, endIdx)
|
||||
}
|
||||
|
||||
/**
|
||||
* A local context fetcher that uses a LLM to generate a keyword query, which is then
|
||||
* converted to a regex fed to ripgrep to search for files that are relevant to the
|
||||
* user query.
|
||||
*/
|
||||
export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
constructor(private rgPath: string, private editor: Editor) {}
|
||||
constructor(private rgPath: string, private editor: Editor, private chatClient: ChatClient) {}
|
||||
|
||||
public async getContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
|
||||
console.log('fetching keyword matches')
|
||||
/**
|
||||
* Returns pieces of context relevant for the given query. Uses a keyword-search-based
|
||||
* approach.
|
||||
* @param query user query
|
||||
* @param numResults the number of context results to return
|
||||
* @returns a list of context results, sorted in *reverse* order (that is,
|
||||
* the most important result appears at the bottom)
|
||||
*/
|
||||
public async getContext(query: string, numResults: number): Promise<ContextResult[]> {
|
||||
const startTime = performance.now()
|
||||
const rootPath = this.editor.getWorkspaceRootPath()
|
||||
if (!rootPath) {
|
||||
@ -141,18 +117,89 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
)
|
||||
const searchDuration = performance.now() - startTime
|
||||
logEvent('CodyVSCodeExtension:keywordContext:searchDuration', searchDuration, searchDuration)
|
||||
debug('LocalKeywordContextFetcher:getContext', JSON.stringify({ searchDuration }))
|
||||
return messagePairs.reverse().flat()
|
||||
}
|
||||
|
||||
private async userQueryToExpandedKeywords(query: string): Promise<Map<string, Term>> {
|
||||
const start = performance.now()
|
||||
const keywords = await new Promise<string[]>((resolve, reject) => {
|
||||
let responseText = ''
|
||||
this.chatClient.chat(
|
||||
[
|
||||
{
|
||||
speaker: 'human',
|
||||
text: `Write 3-5 keywords that you would use to search for code snippets that are relevant to answering the following user query: <query>${query}</query> Your response should be only a list of space-delimited keywords and nothing else.`,
|
||||
},
|
||||
],
|
||||
{
|
||||
onChange: (text: string) => {
|
||||
responseText = text
|
||||
},
|
||||
onComplete: () => {
|
||||
resolve(responseText.split(/\s+/).filter(e => e.length > 0))
|
||||
},
|
||||
onError: (message: string, statusCode?: number) => {
|
||||
reject(new Error(message))
|
||||
},
|
||||
},
|
||||
{
|
||||
temperature: 0,
|
||||
fast: true,
|
||||
}
|
||||
)
|
||||
})
|
||||
const terms = new Map<string, Term>()
|
||||
for (const kw of keywords) {
|
||||
const stem = winkUtils.string.stem(kw)
|
||||
if (terms.has(stem)) {
|
||||
continue
|
||||
}
|
||||
terms.set(stem, {
|
||||
count: 1,
|
||||
originals: [kw],
|
||||
prefix: longestCommonPrefix(kw.toLowerCase(), stem),
|
||||
stem,
|
||||
})
|
||||
}
|
||||
debug(
|
||||
'LocalKeywordContextFetcher:userQueryToExpandedKeywords',
|
||||
JSON.stringify({ duration: performance.now() - start })
|
||||
)
|
||||
return terms
|
||||
}
|
||||
|
||||
private async userQueryToKeywordQuery(query: string): Promise<Term[]> {
|
||||
const terms = new Map<string, Term>()
|
||||
const keywordExpansionStartTime = Date.now()
|
||||
const expandedTerms = await this.userQueryToExpandedKeywords(query)
|
||||
const keywordExpansionDuration = Date.now() - keywordExpansionStartTime
|
||||
for (const [stem, term] of expandedTerms) {
|
||||
if (terms.has(stem)) {
|
||||
continue
|
||||
}
|
||||
terms.set(stem, term)
|
||||
}
|
||||
debug(
|
||||
'LocalKeywordContextFetcher:userQueryToKeywordQuery',
|
||||
'keyword expansion',
|
||||
JSON.stringify({
|
||||
duration: keywordExpansionDuration,
|
||||
expandedTerms: [...expandedTerms.values()].map(v => v.prefix),
|
||||
})
|
||||
)
|
||||
const ret = [...terms.values()]
|
||||
return ret
|
||||
}
|
||||
|
||||
// Return context results for the Codebase Context Search recipe
|
||||
public async getSearchContext(query: string, numResults: number): Promise<KeywordContextFetcherResult[]> {
|
||||
console.log('fetching keyword search context...')
|
||||
public async getSearchContext(query: string, numResults: number): Promise<ContextResult[]> {
|
||||
const rootPath = this.editor.getWorkspaceRootPath()
|
||||
if (!rootPath) {
|
||||
return []
|
||||
}
|
||||
|
||||
const stems = userQueryToKeywordQuery(query)
|
||||
const stems = (await this.userQueryToKeywordQuery(query))
|
||||
.map(t => (t.prefix.length < 4 ? t.originals[0] : t.prefix))
|
||||
.join('|')
|
||||
const filesnamesWithScores = await this.fetchKeywordFiles(rootPath, query)
|
||||
@ -180,8 +227,10 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
terms: Term[],
|
||||
rootPath: string
|
||||
): Promise<{ [filename: string]: { bytesSearched: number } }> {
|
||||
const start = performance.now()
|
||||
const regexQuery = `\\b${regexForTerms(...terms)}`
|
||||
const proc = spawn(this.rgPath, ['-i', ...fileExtRipgrepParams, '--json', regexQuery, './'], {
|
||||
const rgArgs = [...fileExtRipgrepParams, '--json', regexQuery, '.']
|
||||
const proc = spawn(this.rgPath, rgArgs, {
|
||||
cwd: rootPath,
|
||||
stdio: ['ignore', 'pipe', process.stderr],
|
||||
windowsHide: true,
|
||||
@ -191,21 +240,40 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
bytesSearched: number
|
||||
}
|
||||
} = {}
|
||||
|
||||
// Process the ripgrep JSON output to get the file sizes. We use an object filter to
|
||||
// fast-filter out irrelevant lines of output
|
||||
const objectFilter = (assembler: Assembler): boolean | undefined => {
|
||||
// Each ripgrep JSON line begins with the following format:
|
||||
//
|
||||
// {"type":"begin|match|end","data":"...
|
||||
//
|
||||
// We only care about the "type":"end" lines, which contain the file size in bytes.
|
||||
if (assembler.key === null && assembler.stack.length === 0 && assembler.current.type) {
|
||||
return assembler.current.type === 'end'
|
||||
}
|
||||
// return undefined to indicate our uncertainty at this moment
|
||||
return undefined
|
||||
}
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
try {
|
||||
proc.stdout
|
||||
.pipe(StreamValues.withParser())
|
||||
.pipe(StreamValues.withParser({ objectFilter }))
|
||||
.on('data', data => {
|
||||
try {
|
||||
const typedData = data as RipgrepStreamData
|
||||
switch (typedData.value.type) {
|
||||
case 'end':
|
||||
if (!fileTermCounts[typedData.value.data.path.text]) {
|
||||
fileTermCounts[typedData.value.data.path.text] = { bytesSearched: 0 }
|
||||
case 'end': {
|
||||
let filename = typedData.value.data.path.text
|
||||
if (filename.startsWith(`.${path.sep}`)) {
|
||||
filename = filename.slice(2)
|
||||
}
|
||||
fileTermCounts[typedData.value.data.path.text].bytesSearched =
|
||||
typedData.value.data.stats.bytes_searched
|
||||
if (!fileTermCounts[filename]) {
|
||||
fileTermCounts[filename] = { bytesSearched: 0 }
|
||||
}
|
||||
fileTermCounts[filename].bytesSearched = typedData.value.data.stats.bytes_searched
|
||||
break
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
reject(error)
|
||||
@ -216,6 +284,7 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
reject(error)
|
||||
}
|
||||
})
|
||||
debug('fetchFileStats', JSON.stringify({ duration: performance.now() - start }))
|
||||
return fileTermCounts
|
||||
}
|
||||
|
||||
@ -227,20 +296,21 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
fileTermCounts: { [filename: string]: { [stem: string]: number } }
|
||||
termTotalFiles: { [stem: string]: number }
|
||||
}> {
|
||||
const start = performance.now()
|
||||
const termFileCountsArr: { fileCounts: { [filename: string]: number }; filesSearched: number }[] =
|
||||
await Promise.all(
|
||||
queryTerms.map(async term => {
|
||||
const rgArgs = [
|
||||
...fileExtRipgrepParams,
|
||||
'--count-matches',
|
||||
'--stats',
|
||||
`\\b${regexForTerms(term)}`,
|
||||
'.',
|
||||
]
|
||||
const out = await new Promise<string>((resolve, reject) => {
|
||||
execFile(
|
||||
this.rgPath,
|
||||
[
|
||||
'-i',
|
||||
...fileExtRipgrepParams,
|
||||
'--count-matches',
|
||||
'--stats',
|
||||
`\\b${regexForTerms(term)}`,
|
||||
'./',
|
||||
],
|
||||
rgArgs,
|
||||
{
|
||||
cwd: rootPath,
|
||||
maxBuffer: 1024 * 1024 * 1024,
|
||||
@ -272,8 +342,12 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
continue
|
||||
}
|
||||
try {
|
||||
let filename = terms[0]
|
||||
if (filename.startsWith(`.${path.sep}`)) {
|
||||
filename = filename.slice(2)
|
||||
}
|
||||
const count = parseInt(terms[1], 10)
|
||||
fileCounts[terms[0]] = count
|
||||
fileCounts[filename] = count
|
||||
} catch {
|
||||
console.error(`could not parse count from ${terms[1]}`)
|
||||
}
|
||||
@ -282,6 +356,7 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
})
|
||||
)
|
||||
|
||||
debug('LocalKeywordContextFetcher.fetchFileMatches', JSON.stringify({ duration: performance.now() - start }))
|
||||
let totalFilesSearched = -1
|
||||
for (const { filesSearched } of termFileCountsArr) {
|
||||
if (totalFilesSearched >= 0 && totalFilesSearched !== filesSearched) {
|
||||
@ -316,13 +391,15 @@ export class LocalKeywordContextFetcher implements KeywordContextFetcher {
|
||||
rootPath: string,
|
||||
rawQuery: string
|
||||
): Promise<{ filename: string; score: number }[]> {
|
||||
const query = userQueryToKeywordQuery(rawQuery)
|
||||
const query = await this.userQueryToKeywordQuery(rawQuery)
|
||||
|
||||
const fetchFilesStart = performance.now()
|
||||
const fileMatchesPromise = this.fetchFileMatches(query, rootPath)
|
||||
const fileStatsPromise = this.fetchFileStats(query, rootPath)
|
||||
|
||||
const fileMatches = await fileMatchesPromise
|
||||
const fileStats = await fileStatsPromise
|
||||
const fetchFilesDuration = performance.now() - fetchFilesStart
|
||||
debug('LocalKeywordContextFetcher:fetchKeywordFiles', JSON.stringify({ fetchFilesDuration }))
|
||||
|
||||
const { fileTermCounts, termTotalFiles, totalFiles } = fileMatches
|
||||
const idfDict = idf(termTotalFiles, totalFiles)
|
||||
24
client/cody/src/logged-rerank.ts
Normal file
24
client/cody/src/logged-rerank.ts
Normal file
@ -0,0 +1,24 @@
|
||||
import { ChatClient } from '@sourcegraph/cody-shared/src/chat/chat'
|
||||
import { LLMReranker } from '@sourcegraph/cody-shared/src/codebase-context/rerank'
|
||||
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
|
||||
|
||||
import { debug } from './log'
|
||||
import { TestSupport } from './test-support'
|
||||
|
||||
export function getRerankWithLog(
|
||||
chatClient: ChatClient
|
||||
): (query: string, results: ContextResult[]) => Promise<ContextResult[]> {
|
||||
if (TestSupport.instance) {
|
||||
const reranker = TestSupport.instance.getReranker()
|
||||
return (query: string, results: ContextResult[]): Promise<ContextResult[]> => reranker.rerank(query, results)
|
||||
}
|
||||
|
||||
const reranker = new LLMReranker(chatClient)
|
||||
return async (userQuery: string, results: ContextResult[]): Promise<ContextResult[]> => {
|
||||
const start = performance.now()
|
||||
const rerankedResults = await reranker.rerank(userQuery, results)
|
||||
const duration = performance.now() - start
|
||||
debug('Reranker:rerank', JSON.stringify({ duration }))
|
||||
return rerankedResults
|
||||
}
|
||||
}
|
||||
@ -43,12 +43,13 @@ export class LocalStorage {
|
||||
([humanMessage, assistantMessageAndContextFiles]) => ({
|
||||
humanMessage,
|
||||
assistantMessage: assistantMessageAndContextFiles,
|
||||
context: assistantMessageAndContextFiles.contextFiles
|
||||
fullContext: assistantMessageAndContextFiles.contextFiles
|
||||
? assistantMessageAndContextFiles.contextFiles.map(fileName => ({
|
||||
speaker: 'assistant',
|
||||
fileName,
|
||||
}))
|
||||
: [],
|
||||
usedContextFiles: [],
|
||||
// Timestamp not recoverable so we use the group timestamp
|
||||
timestamp: id,
|
||||
})
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import { ChatMessage } from '@sourcegraph/cody-shared/src/chat/transcript/messages'
|
||||
import { MockReranker, Reranker } from '@sourcegraph/cody-shared/src/codebase-context/rerank'
|
||||
import { ContextResult } from '@sourcegraph/cody-shared/src/local-context'
|
||||
|
||||
import { ChatViewProvider } from './chat/ChatViewProvider'
|
||||
import { FixupTask } from './non-stop/FixupTask'
|
||||
@ -39,6 +41,17 @@ export class TestSupport {
|
||||
|
||||
public chatViewProvider = new Rendezvous<ChatViewProvider>()
|
||||
|
||||
public reranker: Reranker | undefined
|
||||
|
||||
public getReranker(): Reranker {
|
||||
if (!this.reranker) {
|
||||
return new MockReranker(
|
||||
(_: string, results: ContextResult[]): Promise<ContextResult[]> => Promise.resolve(results)
|
||||
)
|
||||
}
|
||||
return this.reranker
|
||||
}
|
||||
|
||||
public async chatTranscript(): Promise<ChatMessage[]> {
|
||||
return (await this.chatViewProvider.get()).transcriptForTesting(this)
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ type CompletionsResolver interface {
|
||||
|
||||
type CompletionsArgs struct {
|
||||
Input CompletionsInput
|
||||
Fast bool
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
|
||||
@ -45,7 +45,14 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
|
||||
return "", errors.New("completions are not configured or disabled")
|
||||
}
|
||||
|
||||
ctx, done := httpapi.Trace(ctx, "resolver", completionsConfig.ChatModel).
|
||||
var chatModel string
|
||||
if args.Fast {
|
||||
chatModel = completionsConfig.FastChatModel
|
||||
} else {
|
||||
chatModel = completionsConfig.ChatModel
|
||||
}
|
||||
|
||||
ctx, done := httpapi.Trace(ctx, "resolver", chatModel).
|
||||
WithErrorP(&err).
|
||||
Build()
|
||||
defer done()
|
||||
@ -66,7 +73,7 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
|
||||
|
||||
params := convertParams(args)
|
||||
// No way to configure the model through the request, we hard code to chat.
|
||||
params.Model = completionsConfig.ChatModel
|
||||
params.Model = chatModel
|
||||
resp, err := client.Complete(ctx, types.CompletionsFeatureChat, params)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "client.Complete")
|
||||
|
||||
@ -79,6 +79,10 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) *schema.Completio
|
||||
completionsConfig.ChatModel = completionsConfig.Model
|
||||
}
|
||||
|
||||
if completionsConfig.FastChatModel == "" {
|
||||
completionsConfig.FastChatModel = completionsConfig.ChatModel
|
||||
}
|
||||
|
||||
// TODO: Temporary workaround to fix instances where no completion model is set.
|
||||
if completionsConfig.CompletionModel == "" {
|
||||
completionsConfig.CompletionModel = "claude-instant-v1"
|
||||
@ -102,6 +106,7 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) *schema.Completio
|
||||
// TODO: These are not required right now as upstream overwrites this,
|
||||
// but should we switch to Cody Gateway they will be.
|
||||
ChatModel: "claude-v1",
|
||||
FastChatModel: "claude-instant-v1",
|
||||
CompletionModel: "claude-instant-v1",
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,13 +28,15 @@ func TestGetCompletionsConfig(t *testing.T) {
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
ChatModel: "claude-v1",
|
||||
FastChatModel: "claude-instant-v1",
|
||||
CompletionModel: "claude-instant-v1",
|
||||
},
|
||||
},
|
||||
want: autogold.Expect(&schema.Completions{
|
||||
ChatModel: "claude-v1", CompletionModel: "claude-instant-v1",
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
FastChatModel: "claude-instant-v1",
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
}),
|
||||
},
|
||||
{
|
||||
|
||||
@ -19,7 +19,7 @@ func NewCodeCompletionsHandler(logger log.Logger, db database.DB) http.Handler {
|
||||
logger = logger.Scoped("code", "code completions handler")
|
||||
|
||||
rl := NewRateLimiter(db, redispool.Store, types.CompletionsFeatureCode)
|
||||
return newCompletionsHandler(rl, "code", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
|
||||
return newCompletionsHandler(rl, "code", func(requestParams types.CodyCompletionRequestParameters, c *schema.Completions) string {
|
||||
// No user defined models for now.
|
||||
// TODO(eseliger): Look into reviving this, but it was unused so far.
|
||||
return c.CompletionModel
|
||||
|
||||
@ -19,7 +19,12 @@ import (
|
||||
// being cancelled.
|
||||
const maxRequestDuration = time.Minute
|
||||
|
||||
func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(types.CompletionRequestParameters, *schema.Completions) string, handle func(context.Context, types.CompletionRequestParameters, types.CompletionsClient, http.ResponseWriter)) http.Handler {
|
||||
func newCompletionsHandler(
|
||||
rl RateLimiter,
|
||||
traceFamily string,
|
||||
getModel func(types.CodyCompletionRequestParameters, *schema.Completions) string,
|
||||
handle func(context.Context, types.CompletionRequestParameters, types.CompletionsClient, http.ResponseWriter),
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, fmt.Sprintf("unsupported method %s", r.Method), http.StatusMethodNotAllowed)
|
||||
@ -40,13 +45,17 @@ func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(typ
|
||||
return
|
||||
}
|
||||
|
||||
var requestParams types.CompletionRequestParameters
|
||||
var requestParams types.CodyCompletionRequestParameters
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestParams); err != nil {
|
||||
http.Error(w, "could not decode request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Model is not configurable but technically allowed in the request body right now.
|
||||
if requestParams.Model != "" {
|
||||
http.Error(w, "user-specified models are not allowed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
requestParams.Model = getModel(requestParams, completionsConfig)
|
||||
|
||||
var err error
|
||||
@ -77,7 +86,7 @@ func newCompletionsHandler(rl RateLimiter, traceFamily string, getModel func(typ
|
||||
return
|
||||
}
|
||||
|
||||
handle(ctx, requestParams, completionClient, w)
|
||||
handle(ctx, requestParams.CompletionRequestParameters, completionClient, w)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -19,8 +19,11 @@ func NewChatCompletionsStreamHandler(logger log.Logger, db database.DB) http.Han
|
||||
logger = logger.Scoped("chat", "chat completions handler")
|
||||
rl := NewRateLimiter(db, redispool.Store, types.CompletionsFeatureChat)
|
||||
|
||||
return newCompletionsHandler(rl, "chat", func(requestParams types.CompletionRequestParameters, c *schema.Completions) string {
|
||||
return newCompletionsHandler(rl, "chat", func(requestParams types.CodyCompletionRequestParameters, c *schema.Completions) string {
|
||||
// No user defined models for now.
|
||||
if requestParams.Fast {
|
||||
return c.FastChatModel
|
||||
}
|
||||
return c.ChatModel
|
||||
}, func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter) {
|
||||
eventWriter, err := streamhttp.NewWriter(w)
|
||||
|
||||
@ -37,6 +37,14 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str
|
||||
return fmt.Sprintf("%s %s", prefix, m.Text), nil
|
||||
}
|
||||
|
||||
type CodyCompletionRequestParameters struct {
|
||||
CompletionRequestParameters
|
||||
|
||||
// When Fast is true, then it is used as a hint to prefer a model
|
||||
// that is faster (but probably "dumber").
|
||||
Fast bool
|
||||
}
|
||||
|
||||
type CompletionRequestParameters struct {
|
||||
// Prompt exists only for backwards compatibility. Do not use it in new
|
||||
// implementations. It will be removed once we are reasonably sure 99%
|
||||
|
||||
@ -1445,6 +1445,13 @@ importers:
|
||||
'@sourcegraph/http-client':
|
||||
specifier: workspace:*
|
||||
version: link:../http-client
|
||||
xml2js:
|
||||
specifier: ^0.6.0
|
||||
version: 0.6.0
|
||||
devDependencies:
|
||||
'@types/xml2js':
|
||||
specifier: ^0.4.11
|
||||
version: 0.4.11
|
||||
|
||||
client/cody-slack:
|
||||
dependencies:
|
||||
@ -10016,6 +10023,12 @@ packages:
|
||||
dependencies:
|
||||
'@types/node': 13.13.5
|
||||
|
||||
/@types/xml2js@0.4.11:
|
||||
resolution: {integrity: sha512-JdigeAKmCyoJUiQljjr7tQG3if9NkqGUgwEUqBvV0N7LM4HyQk7UXCnusRa1lnvXAEYJ8mw8GtZWioagNztOwA==}
|
||||
dependencies:
|
||||
'@types/node': 13.13.5
|
||||
dev: true
|
||||
|
||||
/@types/yargs-parser@21.0.0:
|
||||
resolution: {integrity: sha512-iO9ZQHkZxHn4mSakYV0vFHAVDyEOIJQrV2uZ06HxEPcx+mt8swXoZHIbaaJ2crJYFfErySgktuTZ3BeLz+XmFA==}
|
||||
|
||||
@ -26143,7 +26156,6 @@ packages:
|
||||
|
||||
/sax@1.2.4:
|
||||
resolution: {integrity: sha512-NqVDv9TpANUjFm0N8uM5GxL36UgKi9/atZw+x7YFnQ8ckwFGKrl4xX4yWtrey3UJm5nP1kUbnYgLopqWNSRhWw==}
|
||||
dev: true
|
||||
|
||||
/saxes@5.0.1:
|
||||
resolution: {integrity: sha512-5LBh1Tls8c9xgGjw3QrMwETmTMVk0oFgvrFSvWx62llR2hcEInrKNZ2GZCCuuy2lvWrdl5jhbpeqc5hRYKFOcw==}
|
||||
@ -30109,6 +30121,14 @@ packages:
|
||||
xmlbuilder: 11.0.1
|
||||
dev: true
|
||||
|
||||
/xml2js@0.6.0:
|
||||
resolution: {integrity: sha512-eLTh0kA8uHceqesPqSE+VvO1CDDJWMwlQfB6LuN6T8w6MaDJ8Txm8P7s5cHD0miF0V+GGTZrDQfxPZQVsur33w==}
|
||||
engines: {node: '>=4.0.0'}
|
||||
dependencies:
|
||||
sax: 1.2.4
|
||||
xmlbuilder: 11.0.1
|
||||
dev: false
|
||||
|
||||
/xml@1.0.1:
|
||||
resolution: {integrity: sha512-huCv9IH9Tcf95zuYCsQraZtWnJvBtLVE0QHMOs8bWyZAFZNDcYjsPq1nEx8jKA9y+Beo9v+7OBPRisQTjinQMw==}
|
||||
dev: true
|
||||
@ -30116,7 +30136,6 @@ packages:
|
||||
/xmlbuilder@11.0.1:
|
||||
resolution: {integrity: sha512-fDlsI/kFEx7gLvbecc0/ohLG50fugQp8ryHzMTuW9vSa1GJ0XYWKnhsUx7oie3G98+r56aTQIUB4kht42R3JvA==}
|
||||
engines: {node: '>=4.0'}
|
||||
dev: true
|
||||
|
||||
/xmlchars@2.2.0:
|
||||
resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==}
|
||||
|
||||
@ -564,6 +564,8 @@ type Completions struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
// Endpoint description: The endpoint under which to reach the provider. Currently only used for provider types "sourcegraph", "openai" and "anthropic". The default values are "https://cody-gateway.sourcegraph.com", "https://api.openai.com/v1/chat/completions", and "https://api.anthropic.com/v1/complete" for Sourcegraph, OpenAI, and Anthropic, respectively.
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
// FastChatModel description: The model used for fast chat completions.
|
||||
FastChatModel string `json:"fastChatModel,omitempty"`
|
||||
// Model description: DEPRECATED. Use chatModel instead.
|
||||
Model string `json:"model,omitempty"`
|
||||
// PerUserCodeCompletionsDailyLimit description: If > 0, enables the maximum number of code completions requests allowed to be made by a single user account in a day. On instances that allow anonymous requests, the rate limit is enforced by IP.
|
||||
|
||||
@ -2422,6 +2422,10 @@
|
||||
"description": "DEPRECATED. Use chatModel instead.",
|
||||
"type": "string"
|
||||
},
|
||||
"fastChatModel": {
|
||||
"description": "The model used for fast chat completions.",
|
||||
"type": "string"
|
||||
},
|
||||
"chatModel": {
|
||||
"description": "The model used for chat completions. If using the default provider 'sourcegraph', a reasonable default model will be set.",
|
||||
"type": "string"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user