mirror of https://github.com/sourcegraph/sourcegraph.git, synced 2026-02-06 17:11:49 +00:00
Claude 3 support and /messages API for Enterprise (#60953)
Closes #61166. This PR adds support for Claude 3 and the /messages API to the existing Anthropic provider in the Sourcegraph instance. To ensure a smooth experience, there are a couple of edge cases we need to handle. Because the URL is configurable, a customer could hard-set it to /complete, so we need to return a proper error in that case. Clients might not know what is configured, so they may send requests in either the "old" or the "new" format. We convert between the two as best as we can; eventually, clients will send prompts only in the /messages format, and we introduce Cody API versioning for this purpose. We also support /complete-style prompts (with a trailing assistant message, "holes" with no response, and a system prompt embedded in the messages) when a legacy client connects. --------- Co-authored-by: Chris Warwick <christopher.warwick@sourcegraph.com>
This commit is contained in:
parent
65ea643174
commit
0ce98cd386
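For context, the two wire formats this change reconciles differ as follows: a legacy /complete-style request carries a single flattened prompt string built from "\n\nHuman:" / "\n\nAssistant:" turns (ending in an open Assistant turn), while a /messages-style request carries structured turns with roles and typed content blocks, and puts the system prompt in a top-level field instead of a message. The diff below converts the former into the latter for legacy clients; a concrete JSON sketch follows the new request-parameter structs further down.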
@@ -40,7 +40,9 @@ function modelBadgeVariant(model: string, mode: 'completions' | 'embeddings'): '
        case 'anthropic/claude-instant-v1.1-100k':
        case 'anthropic/claude-instant-v1.2':
        case 'anthropic/claude-instant-1.2':
        case 'anthropic/claude-instant-1.2-cyan':
        case 'anthropic/claude-3-sonnet-20240229':
        case 'anthropic/claude-3-opus-20240229':
        case 'anthropic/claude-3-haiku-20240307':
        // See here: https://platform.openai.com/docs/models/model-endpoint-compatibility
        // for currently available Anthropic models. Note that we also need to
        // allow list the models on the Cody Gateway side.
@@ -25,7 +25,6 @@ go_library(
        "//cmd/cody-gateway/shared/config",
        "//internal/codygateway",
        "//internal/completions/client/anthropic",
        "//internal/completions/client/anthropicmessages",
        "//internal/completions/client/fireworks",
        "//internal/completions/client/openai",
        "//internal/conf/conftypes",
@@ -18,7 +18,7 @@ import (
    "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
    "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
    "github.com/sourcegraph/sourcegraph/internal/codygateway"
    "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages"
    "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
    "github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
    "github.com/sourcegraph/sourcegraph/internal/httpcli"
)
@@ -255,7 +255,7 @@ func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logge
    }

    // Otherwise, we have to parse the event stream from anthropic.
    dec := anthropicmessages.NewDecoder(r)
    dec := anthropic.NewDecoder(r)
    for dec.Scan() {
        data := dec.Data()
@@ -74,10 +74,13 @@ func (c *completionsResolver) Completions(ctx context.Context, args graphqlbacke
        return "", err
    }

    // GraphQL API is considered a legacy API
    version := types.CompletionsVersionLegacy

    params := convertParams(args)
    // No way to configure the model through the request, we hard code to chat.
    params.Model = chatModel
    resp, err := client.Complete(ctx, types.CompletionsFeatureChat, params)
    resp, err := client.Complete(ctx, types.CompletionsFeatureChat, version, params)
    if err != nil {
        return "", errors.Wrap(err, "client.Complete")
    }
@@ -11,11 +11,12 @@ import (
    "github.com/sourcegraph/sourcegraph/lib/errors"
)

func NewClient(cli httpcli.Doer, apiURL, accessToken string) types.CompletionsClient {
func NewClient(cli httpcli.Doer, apiURL, accessToken string, viaGateway bool) types.CompletionsClient {
    return &anthropicClient{
        cli:         cli,
        accessToken: accessToken,
        apiURL:      apiURL,
        viaGateway:  viaGateway,
    }
}
@@ -27,42 +28,54 @@ type anthropicClient struct {
    cli         httpcli.Doer
    accessToken string
    apiURL      string
    viaGateway  bool
}

func (a *anthropicClient) Complete(
    ctx context.Context,
    feature types.CompletionsFeature,
    version types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
    resp, err := a.makeRequest(ctx, requestParams, false)
    resp, err := a.makeRequest(ctx, requestParams, version, false)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    var response anthropicCompletionResponse
    var response anthropicNonStreamingResponse
    if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
        return nil, err
    }

    completion := ""
    for _, content := range response.Content {
        completion += content.Text
    }

    return &types.CompletionResponse{
        Completion: response.Completion,
        Completion: completion,
        StopReason: response.StopReason,
    }, nil
}

func (a *anthropicClient) Stream(
    ctx context.Context,
    feature types.CompletionsFeature,
    version types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
    sendEvent types.SendCompletionEvent,
) error {
    resp, err := a.makeRequest(ctx, requestParams, true)
    resp, err := a.makeRequest(ctx, requestParams, version, true)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    dec := NewDecoder(resp.Body)

    completion := ""
    for dec.Scan() {
        if ctx.Err() != nil && ctx.Err() == context.Canceled {
            return nil
@@ -75,48 +88,70 @@ func (a *anthropicClient) Stream(
            continue
        }

        var event anthropicCompletionResponse
        stopReason := ""
        var event anthropicStreamingResponse
        if err := json.Unmarshal(data, &event); err != nil {
            return errors.Errorf("failed to decode event payload: %w - body: %s", err, string(data))
        }

        switch event.Type {
        case "content_block_delta":
            if event.Delta != nil {
                completion += event.Delta.Text
            }
        case "message_delta":
            if event.Delta != nil {
                stopReason = event.Delta.StopReason
            }
        default:
            continue
        }

        err = sendEvent(types.CompletionResponse{
            Completion: event.Completion,
            StopReason: event.StopReason,
            Completion: completion,
            StopReason: stopReason,
        })
        if err != nil {
            return err
        }
    }

    return dec.Err()
}
func (a *anthropicClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
    prompt, err := GetPrompt(requestParams.Messages)
func (a *anthropicClient) makeRequest(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, stream bool) (*http.Response, error) {
    convertedMessages := requestParams.Messages
    stopSequences := removeWhitespaceOnlySequences(requestParams.StopSequences)
    if version == types.CompletionsVersionLegacy {
        convertedMessages = convertFromLegacyMessages(convertedMessages)
    }
    var payload any
    messages, err := toAnthropicMessages(convertedMessages)
    if err != nil {
        return nil, err
    }
    // Backcompat: Remove this code once enough clients are upgraded and we drop the
    // Prompt field on requestParams.
    if prompt == "" {
        prompt = requestParams.Prompt
    messagesPayload := anthropicRequestParameters{
        Messages:      messages,
        Stream:        stream,
        StopSequences: stopSequences,
        Model:         pinModel(requestParams.Model),
        Temperature:   requestParams.Temperature,
        MaxTokens:     requestParams.MaxTokensToSample,
        TopP:          requestParams.TopP,
        TopK:          requestParams.TopK,
    }

    if len(requestParams.StopSequences) == 0 {
        requestParams.StopSequences = []string{HUMAN_PROMPT}
    if !a.viaGateway {
        // Convert the eventual first message from `system` to a top-level system prompt
        messagesPayload.System = "" // prevent the upstream API from setting this
        if len(messagesPayload.Messages) > 0 && messagesPayload.Messages[0].Role == types.SYSTEM_MESSAGE_SPEAKER {
            messagesPayload.System = messagesPayload.Messages[0].Content[0].Text
            messagesPayload.Messages = messagesPayload.Messages[1:]
        }
    }

    payload := anthropicCompletionsRequestParameters{
        Stream:            stream,
        StopSequences:     requestParams.StopSequences,
        Model:             requestParams.Model,
        Temperature:       requestParams.Temperature,
        MaxTokensToSample: requestParams.MaxTokensToSample,
        TopP:              requestParams.TopP,
        TopK:              requestParams.TopK,
        Prompt:            prompt,
    }
    payload = messagesPayload

    reqBody, err := json.Marshal(payload)
    if err != nil {
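Note the viaGateway distinction above: when talking to Anthropic directly, a leading `system` message is hoisted into the top-level `system` field, since the upstream /messages API takes the system prompt there rather than as a message. When requests go through Cody Gateway the field is left unset, presumably because the Gateway performs the equivalent step itself (the struct comment below notes it does not need to be set in that case).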
@@ -135,12 +170,7 @@ func (a *anthropicClient) makeRequest(ctx context.Context, requestParams types.C
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Client", clientID)
    req.Header.Set("X-API-Key", a.accessToken)
    // Set the API version so responses are in the expected format.
    // NOTE: When changing this here, Cody Gateway currently overwrites this header
    // with 2023-01-01, so it will not be respected in Gateway usage and we will
    // have to fall back to the old parser, or implement a mechanism on the Gateway
    // side that understands the version header we send here and switch out the parser.
    req.Header.Set("anthropic-version", "2023-01-01")
    req.Header.Set("anthropic-version", "2023-06-01")

    resp, err := a.cli.Do(req)
    if err != nil {
@@ -154,18 +184,57 @@ func (a *anthropicClient) makeRequest(ctx context.C
    return resp, nil
}

type anthropicCompletionsRequestParameters struct {
    Prompt            string   `json:"prompt"`
    Temperature       float32  `json:"temperature"`
    MaxTokensToSample int      `json:"max_tokens_to_sample"`
    StopSequences     []string `json:"stop_sequences"`
    TopK              int      `json:"top_k"`
    TopP              float32  `json:"top_p"`
    Model             string   `json:"model"`
    Stream            bool     `json:"stream"`
type anthropicRequestParameters struct {
    Messages      []anthropicMessage `json:"messages,omitempty"`
    Model         string             `json:"model"`
    Temperature   float32            `json:"temperature,omitempty"`
    TopP          float32            `json:"top_p,omitempty"`
    TopK          int                `json:"top_k,omitempty"`
    Stream        bool               `json:"stream,omitempty"`
    StopSequences []string           `json:"stop_sequences,omitempty"`
    MaxTokens     int                `json:"max_tokens,omitempty"`

    // These are not accepted from the client and instead are only used to talk to the upstream LLM
    // APIs directly (these do NOT need to be set when talking to Cody Gateway)
    System string `json:"system,omitempty"`
}

type anthropicCompletionResponse struct {
    Completion string `json:"completion"`
    StopReason string `json:"stop_reason"`
type anthropicMessage struct {
    Role    string                    `json:"role"` // "user", "assistant", or "system" (only allowed for the first message)
    Content []anthropicMessageContent `json:"content"`
}

type anthropicMessageContent struct {
    Type string `json:"type"` // "text" or "image" (not yet supported)
    Text string `json:"text"`
}

type anthropicNonStreamingResponse struct {
    Content    []anthropicMessageContent `json:"content"`
    StopReason string                    `json:"stop_reason"`
}

// anthropicStreamingResponse captures all relevant-to-us fields from each relevant SSE event from https://docs.anthropic.com/claude/reference/messages_post.
type anthropicStreamingResponse struct {
    Type         string                                `json:"type"`
    Delta        *anthropicStreamingResponseTextBucket `json:"delta"`
    ContentBlock *anthropicStreamingResponseTextBucket `json:"content_block"`
}

type anthropicStreamingResponseTextBucket struct {
    Text       string `json:"text"`        // for event `content_block_delta`
    StopReason string `json:"stop_reason"` // for event `message_delta`
}

// The /stream API does not support unpinned models
func pinModel(model string) string {
    switch model {
    case "claude-instant-1",
        "claude-instant-v1":
        return "claude-instant-1.2"
    case "claude-2":
        return "claude-2.0"
    default:
        return model
    }
}
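To make the wire format concrete, here is a minimal sketch of the request body these structs marshal to (assuming the types and pinModel above are in scope; the values are illustrative):

    payload := anthropicRequestParameters{
        Messages: []anthropicMessage{{
            Role:    "user",
            Content: []anthropicMessageContent{{Type: "text", Text: "Hello"}},
        }},
        Model:     pinModel("claude-2"), // pinned to "claude-2.0"
        MaxTokens: 300,
    }
    b, _ := json.Marshal(payload)
    // string(b) == `{"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}],"model":"claude-2.0","max_tokens":300}`

Thanks to the omitempty tags, unset sampling parameters and the direct-call-only System field are dropped from the payload entirely.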
@@ -3,7 +3,6 @@ package anthropic

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "net/http"
    "testing"
@@ -23,11 +22,11 @@ func (c *mockDoer) Do(r *http.Request) (*http.Response, error) {
    return c.do(r)
}

func linesToResponse(lines []string) []byte {
func linesToResponse(lines []string, separator string) []byte {
    responseBytes := []byte{}
    for _, line := range lines {
        responseBytes = append(responseBytes, []byte(fmt.Sprintf("data: %s", line))...)
        responseBytes = append(responseBytes, []byte("\r\n\r\n")...)
        responseBytes = append(responseBytes, []byte(line)...)
        responseBytes = append(responseBytes, []byte(separator)...)
    }
    return responseBytes
}
@@ -37,21 +36,37 @@ func getMockClient(responseBody []byte) types.CompletionsClient {
        func(r *http.Request) (*http.Response, error) {
            return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseBody))}, nil
        },
    }, "", "")
    }, "", "", false)
}

func TestValidAnthropicStream(t *testing.T) {
    var mockAnthropicResponseLines = []string{
        `{"completion": "Sure!"}`,
        `{"completion": "Sure! The Fibonacci sequence is defined as:\n\nF0 = 0\nF1 = 1\nFn = Fn-1 + Fn-2\n\nSo in Python, you can write it like this:\ndef fibonacci(n):\n if n < 2:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n\nOr iteratively:\ndef fibonacci(n):\n a, b = 0, 1\n for i in range(n):\n a, b = b, a + b\n return a\n\nSo for example:\nprint(fibonacci(8)) # 21"}`,
        `2023.28.2 8:54`, // To test skipping over non-JSON data.
        `{"completion": "Sure! The Fibonacci sequence is defined as:\n\nF0 = 0\nF1 = 1\nFn = Fn-1 + Fn-2\n\nSo in Python, you can write it like this:\ndef fibonacci(n):\n if n < 2:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n\nOr iteratively:\ndef fibonacci(n):\n a, b = 0, 1\n for i in range(n):\n a, b = b, a + b\n return a\n\nSo for example:\nprint(fibonacci(8)) # 21\n\nThe iterative"}`,
        "[DONE]",
func TestValidAnthropicMessagesStream(t *testing.T) {
    var mockAnthropicMessagesResponseLines = []string{
        `event: message_start
data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}`,
        `event: content_block_start
data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}`,
        `event: ping
data: {"type": "ping"}`,
        `event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "He"}}`,
        `event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "llo"}}`,
        `event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}`,
        `event: content_block_stop
data: {"type": "content_block_stop", "index": 0}`,
        `event: message_delta
data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}`,
        `event: message_stop
data: {"type": "message_stop"}`,
    }

    mockClient := getMockClient(linesToResponse(mockAnthropicResponseLines))
    mockClient := getMockClient(linesToResponse(mockAnthropicMessagesResponseLines, "\n\n"))
    events := []types.CompletionResponse{}
    err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error {
    stream := true
    err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{
        Stream: &stream,
    }, func(event types.CompletionResponse) error {
        events = append(events, event)
        return nil
    })
@@ -61,11 +76,11 @@ func TestValidAnthropicStream(t *testing.T) {
    autogold.ExpectFile(t, events)
}

func TestInvalidAnthropicStream(t *testing.T) {
    var mockAnthropicInvalidResponseLines = []string{`{]`}
func TestInvalidAnthropicMessagesStream(t *testing.T) {
    var mockAnthropicInvalidResponseLines = []string{`data:{]`}

    mockClient := getMockClient(linesToResponse(mockAnthropicInvalidResponseLines))
    err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
    mockClient := getMockClient(linesToResponse(mockAnthropicInvalidResponseLines, "\r\n\r\n"))
    err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
    if err == nil {
        t.Fatal("expected error, got nil")
    }
@@ -80,10 +95,10 @@ func TestErrStatusNotOK(t *testing.T) {
                Body: io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
            }, nil
        },
    }, "", "")
    }, "", "", false)

    t.Run("Complete", func(t *testing.T) {
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{})
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{})
        require.Error(t, err)
        assert.Nil(t, resp)

@@ -93,7 +108,7 @@ func TestErrStatusNotOK(t *testing.T) {
    })

    t.Run("Stream", func(t *testing.T) {
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        require.Error(t, err)

        autogold.Expect("Anthropic: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
@@ -101,3 +116,60 @@ func TestErrStatusNotOK(t *testing.T) {
        assert.True(t, ok)
    })
}

func TestCompleteApiToMessages(t *testing.T) {
    var response *http.Request
    mockClient := NewClient(&mockDoer{
        func(r *http.Request) (*http.Response, error) {
            response = r
            return &http.Response{
                StatusCode: http.StatusTooManyRequests,
                Body:       io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
            }, nil
        },
    }, "", "", false)
    messages := []types.Message{
        {Speaker: "human", Text: "¡Hola!"},
        // /complete prompts can have human messages without an assistant response. These should
        // be ignored.
        {Speaker: "assistant", Text: ""},
        {Speaker: "human", Text: "Servus!"},
        // /complete prompts might end with an empty assistant message
        {Speaker: "assistant"},
    }

    t.Run("Complete", func(t *testing.T) {
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{Messages: messages})
        require.Error(t, err)
        assert.Nil(t, resp)

        assert.NotNil(t, response)
        body, err := io.ReadAll(response.Body)
        assert.NoError(t, err)

        autogold.Expect(body).Equal(t, []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Servus!"}]}],"model":""}`))
    })

    t.Run("Stream", func(t *testing.T) {
        stream := true
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{Messages: messages, Stream: &stream}, func(event types.CompletionResponse) error { return nil })
        require.Error(t, err)

        assert.NotNil(t, response)
        body, err := io.ReadAll(response.Body)
        assert.NoError(t, err)

        autogold.Expect(body).Equal(t, []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Servus!"}]}],"model":"","stream":true}`))
    })
}

func TestPinModel(t *testing.T) {
    t.Run("Claude Instant", func(t *testing.T) {
        assert.Equal(t, pinModel("claude-instant-1"), "claude-instant-1.2")
        assert.Equal(t, pinModel("claude-instant-v1"), "claude-instant-1.2")
    })

    t.Run("Claude 2", func(t *testing.T) {
        assert.Equal(t, pinModel("claude-2"), "claude-2.0")
    })
}
@@ -35,6 +35,9 @@ func NewDecoder(r io.Reader) *decoder {
        if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
            return i + 4, data[:i], nil
        }
        if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
            return i + 2, data[:i], nil
        }
        // If we're at EOF, we have a final, non-terminated event. This should
        // be empty.
        if atEOF {
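As a standalone illustration of the split behavior added above, the following runs as-is (the input string is made up; the split function mirrors the diff, checking for \r\n\r\n before \n\n):

    package main

    import (
        "bufio"
        "bytes"
        "fmt"
        "strings"
    )

    func main() {
        // Events end at the first blank line, whether the stream uses
        // \r\n\r\n or \n\n separators.
        split := func(data []byte, atEOF bool) (int, []byte, error) {
            if atEOF && len(data) == 0 {
                return 0, nil, nil
            }
            if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
                return i + 4, data[:i], nil
            }
            if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
                return i + 2, data[:i], nil
            }
            if atEOF {
                return len(data), data, nil
            }
            return 0, nil, nil // request more data
        }

        s := bufio.NewScanner(strings.NewReader("data:a\r\n\r\nevent: ping\ndata: pong\n\n"))
        s.Split(split)
        for s.Scan() {
            fmt.Printf("%q\n", s.Text()) // "data:a", then "event: ping\ndata: pong"
        }
    }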
@@ -56,26 +59,30 @@ func (d *decoder) Scan() bool {
        return false
    }
    for d.scanner.Scan() {
        // event: $_name
        // data: json($data)|[DONE]
        line := d.scanner.Bytes()
        typ, data := splitColon(line)
        switch {
        case bytes.Equal(typ, []byte("data")):
            d.data = data
            // Check for special sentinel value used by the Anthropic API to
            // indicate that the stream is done.
            if bytes.Equal(data, doneBytes) {
                d.done = true

        lines := bytes.Split(d.scanner.Bytes(), []byte("\n"))
        for _, line := range lines {
            typ, data := splitColon(line)

            switch {
            case bytes.Equal(typ, []byte("data")):
                d.data = data
                // Check for special sentinel value used by the Anthropic API to
                // indicate that the stream is done.
                if bytes.Equal(data, doneBytes) {
                    d.done = true
                    return false
                }
                return true
            case bytes.Equal(typ, []byte("event")):
                // Anthropic sends the event name in the data payload as well so we ignore it for now
                continue
            default:
                d.err = errors.Errorf("malformed data, expected data: %s %q", typ, line)
                return false
            }
            return true
        case bytes.Equal(typ, []byte("event")):
            // Anthropic occasionally sends ping events.
            // Just ignore these and continue scanning.
            continue
        default:
            d.err = errors.Errorf("malformed data, expected data: %s %q", typ, line)
            return false
        }
    }
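The restructuring above follows the shape of the /messages SSE stream: each scanner token is now a whole event block that can contain both an `event:` line and a `data:` line, so Scan splits the block on newlines and extracts the `data:` payload instead of assuming one field per token.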
@@ -83,7 +90,6 @@ func (d *decoder) Scan() bool {
    return false
}

// Event returns the event data of the last decoded event
func (d *decoder) Data() []byte {
    return d.data
}
@@ -45,12 +45,12 @@ func TestDecoder(t *testing.T) {
    t.Run("InterleavedPing", func(t *testing.T) {
        events, err := decodeAll("data:a\r\n\r\nevent: ping\r\ndata: 2023-04-28 21:18:31.866238\r\n\r\ndata:b\r\n\r\ndata: [DONE]\r\n\r\n")
        require.NoError(t, err)
        require.Equal(t, events, []event{{data: "a"}, {data: "b"}})
        require.Equal(t, events, []event{{data: "a"}, {data: "2023-04-28 21:18:31.866238"}, {data: "b"}})
    })

    t.Run("Ends after done", func(t *testing.T) {
        events, err := decodeAll("data:a\r\n\r\nevent: ping\r\ndata: 2023-04-28 21:18:31.866238\r\n\r\ndata:b\r\n\r\ndata: [DONE]\r\n\r\ndata:c\r\n\r\n")
        require.NoError(t, err)
        require.Equal(t, events, []event{{data: "a"}, {data: "b"}})
        require.Equal(t, events, []event{{data: "a"}, {data: "2023-04-28 21:18:31.866238"}, {data: "b"}})
    })
}
@@ -7,21 +7,96 @@ import (
    "github.com/sourcegraph/sourcegraph/lib/errors"
)

const HUMAN_PROMPT = "\n\nHuman:"
const ASSISTANT_PROMPT = "\n\nAssistant:"

func GetPrompt(messages []types.Message) (string, error) {
    prompt := make([]string, 0, len(messages))
    for idx, message := range messages {
        if idx > 0 && messages[idx-1].Speaker == message.Speaker {
            return "", errors.Newf("found consecutive messages with the same speaker '%s'", message.Speaker)
func removeWhitespaceOnlySequences(sequences []string) []string {
    var result []string
    for _, sequence := range sequences {
        if len(strings.TrimSpace(sequence)) > 0 {
            result = append(result, sequence)
        }

        messagePrompt, err := message.GetPrompt(HUMAN_PROMPT, ASSISTANT_PROMPT)
        if err != nil {
            return "", err
        }
        prompt = append(prompt, messagePrompt)
    }
    return strings.Join(prompt, ""), nil
    return result
}

func toAnthropicMessages(messages []types.Message) ([]anthropicMessage, error) {
    anthropicMessages := make([]anthropicMessage, 0, len(messages))

    for i, message := range messages {
        speaker := message.Speaker
        text := message.Text

        anthropicRole := message.Speaker

        switch speaker {
        case types.SYSTEM_MESSAGE_SPEAKER:
            if i != 0 {
                return nil, errors.New("system role can only be used in the first message")
            }
        case types.ASSISTANT_MESSAGE_SPEAKER:
        case types.HUMAN_MESSAGE_SPEAKER:
            anthropicRole = "user"
        default:
            return nil, errors.Errorf("unexpected role: %s", speaker)
        }

        if text == "" {
            return nil, errors.New("message content cannot be empty")
        }

        anthropicMessages = append(anthropicMessages, anthropicMessage{
            Role:    anthropicRole,
            Content: []anthropicMessageContent{{Text: text, Type: "text"}},
        })
    }

    return anthropicMessages, nil
}

func convertFromLegacyMessages(messages []types.Message) []types.Message {
    filteredMessages := make([]types.Message, 0)
    skipNext := false
    for i, message := range messages {
        if skipNext {
            skipNext = false
            continue
        }

        // 1. If the first message is "system prompt like" convert it to an actual system prompt
        //
        // Note: The prefix we scan for here is used in the current chat prompts for VS Code and the
        // old Web UI prompt.
        if i == 0 && strings.HasPrefix(message.Text, "You are Cody, an AI") {
            message.Speaker = types.SYSTEM_MESSAGE_SPEAKER
            skipNext = true
        }

        if i == len(messages)-1 && message.Speaker == types.ASSISTANT_MESSAGE_SPEAKER {
            // 2. If the last message is from an `assistant` with no or empty `text`, omit it
            if message.Text == "" {
                continue
            }

            // 3. Final assistant content cannot end with trailing whitespace
            message.Text = strings.TrimRight(message.Text, " \t\n\r")
        }

        // 4. If there is any assistant message in the middle of the messages without a `text`, omit
        // both the empty assistant message as well as the unanswered question from the `user`

        // Don't apply this to the human message before the last message (it should always be included)
        if i >= len(messages)-2 {
            filteredMessages = append(filteredMessages, message)
            continue
        }
        // If the next message is an assistant message with no or empty `content`, omit the current and
        // the next one
        nextMessage := messages[i+1]
        if (nextMessage.Speaker == types.ASSISTANT_MESSAGE_SPEAKER && nextMessage.Text == "") ||
            (message.Speaker == types.ASSISTANT_MESSAGE_SPEAKER && message.Text == "") {
            continue
        }
        filteredMessages = append(filteredMessages, message)
    }

    return filteredMessages
}
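As a quick illustration of toAnthropicMessages (assuming the types above are in scope; the input is made up):

    msgs, err := toAnthropicMessages([]types.Message{
        {Speaker: "system", Text: "You are terse."},
        {Speaker: "human", Text: "Hello"},
        {Speaker: "assistant", Text: "Hi!"},
    })
    // err == nil; the roles come out as "system", "user", "assistant", each
    // with a single {"type": "text"} content block. A system speaker anywhere
    // but position 0, an unknown speaker, or empty text yields an error instead.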
@@ -3,58 +3,64 @@ package anthropic

import (
    "testing"

    "github.com/hexops/autogold/v2"

    "github.com/sourcegraph/sourcegraph/internal/completions/types"
)

func TestGetPrompt(t *testing.T) {
    tests := []struct {
        name     string
        messages []types.Message
        want     string
        wantErr  bool
    }{
        {
            name: "success",
            messages: []types.Message{
                {Speaker: "human", Text: "Hello"},
                {Speaker: "assistant", Text: "Hi there!"},
            },
            want: "\n\nHuman: Hello\n\nAssistant: Hi there!",
        },
        {
            name: "empty message",
            messages: []types.Message{
                {Speaker: "human", Text: "Hello"},
                {Speaker: "assistant", Text: ""},
            },
            want: "\n\nHuman: Hello\n\nAssistant:",
        },
        {
            name: "consecutive same speaker error",
            messages: []types.Message{
                {Speaker: "human", Text: "Hello"},
                {Speaker: "human", Text: "Hi"},
            },
            wantErr: true,
        },
        {
            name: "invalid speaker",
            messages: []types.Message{
                {Speaker: "human1", Text: "Hello"},
                {Speaker: "human2", Text: "Hi"},
            },
            wantErr: true,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := GetPrompt(tt.messages)
            if (err != nil) != tt.wantErr {
                t.Fatalf("getPrompt() error = %v, wantErr %v", err, tt.wantErr)
            }
            if got != tt.want {
                t.Fatalf("getPrompt() = %v, want %v", got, tt.want)
            }
        })
func TestLegacyMessageConversion(t *testing.T) {
    messages := []types.Message{
        // Convert legacy system-like messages to actual system messages
        {Speaker: "human", Text: "You are Cody, an AI-powered coding assistant created by Sourcegraph. You also have an Austrian dialect."},
        {Speaker: "assistant", Text: "I understand"},

        // Removes any messages that did not get an answer
        {Speaker: "human", Text: "Write a poem"},
        {Speaker: "assistant"}, // <- can happen when the connection is interrupted

        {Speaker: "human", Text: "Write a poem"},
        {Speaker: "Roses are red, violets are blue, here is a poem just for you!"},

        {Speaker: "human", Text: "Write another poem"},
        // Removes the last empty assistant message
        {Speaker: "assistant"},
    }

    convertedMessages := convertFromLegacyMessages(messages)

    autogold.Expect([]types.Message{
        {
            Speaker: "system",
            Text:    "You are Cody, an AI-powered coding assistant created by Sourcegraph. You also have an Austrian dialect.",
        },
        {
            Speaker: "human",
            Text:    "Write a poem",
        },
        {Speaker: "Roses are red, violets are blue, here is a poem just for you!"},
        {
            Speaker: "human",
            Text:    "Write another poem",
        },
    }).Equal(t, convertedMessages)
}

func TestLegacyMessageConversionWithTrailingAssistantResponse(t *testing.T) {
    messages := []types.Message{
        {Speaker: "human", Text: "Write another poem"},
        // Trims trailing whitespace from the last assistant message
        {Speaker: "assistant", Text: "Roses are red, "},
    }

    convertedMessages := convertFromLegacyMessages(messages)

    autogold.Expect([]types.Message{{
        Speaker: "human",
        Text:    "Write another poem",
    },
        {
            Speaker: "assistant",
            Text:    "Roses are red,",
        },
    }).Equal(t, convertedMessages)
}
internal/completions/client/anthropic/testdata/TestValidAnthropicMessagesStream.golden (vendored) — new file, 11 lines
@@ -0,0 +1,11 @@
[]types.CompletionResponse{
    {
        Completion: "He",
    },
    {Completion: "Hello"},
    {Completion: "Hello!"},
    {
        Completion: "Hello!",
        StopReason: "end_turn",
    },
}
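Note that the golden events are cumulative ("He", "Hello", "Hello!") rather than per-delta: Stream accumulates content_block_delta text before each sendEvent call, preserving the legacy /complete semantics in which every event carried the full completion so far.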
@@ -1,19 +0,0 @@
load("//dev:go_defs.bzl", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
    name = "anthropicmessages",
    srcs = ["decoder.go"],
    importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages",
    visibility = ["//:__subpackages__"],
    deps = ["//lib/errors"],
)

go_test(
    name = "anthropicmessages_test",
    timeout = "short",
    srcs = ["decoder_test.go"],
    data = glob(["testdata/**"]),
    embed = [":anthropicmessages"],
    deps = ["@com_github_stretchr_testify//require"],
)
@@ -1,106 +0,0 @@
package anthropicmessages

import (
    "bufio"
    "bytes"
    "io"

    "github.com/sourcegraph/sourcegraph/lib/errors"
)

const maxPayloadSize = 10 * 1024 * 1024 // 10mb

var doneBytes = []byte("[DONE]")

// decoder decodes streaming events from a Server Sent Event stream. It only supports
// streams generated by the Anthropic completions API. IE this is not a fully
// compliant Server Sent Events decoder.
//
// Adapted from internal/search/streaming/http/decoder.go.
type decoder struct {
    scanner *bufio.Scanner
    done    bool
    data    []byte
    err     error
}

func NewDecoder(r io.Reader) *decoder {
    scanner := bufio.NewScanner(r)
    scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
    // bufio.ScanLines, except we look for \n\n which separate events.
    split := func(data []byte, atEOF bool) (int, []byte, error) {
        if atEOF && len(data) == 0 {
            return 0, nil, nil
        }
        if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
            return i + 2, data[:i], nil
        }
        // If we're at EOF, we have a final, non-terminated event. This should
        // be empty.
        if atEOF {
            return len(data), data, nil
        }
        // Request more data.
        return 0, nil, nil
    }
    scanner.Split(split)
    return &decoder{
        scanner: scanner,
    }
}

// Scan advances the decoder to the next event in the stream. It returns
// false when it either hits the end of the stream or an error.
func (d *decoder) Scan() bool {
    if d.done {
        return false
    }
    for d.scanner.Scan() {
        // event: $_name
        // data: json($data)|[DONE]

        lines := bytes.Split(d.scanner.Bytes(), []byte("\n"))
        for _, line := range lines {
            typ, data := splitColon(line)

            switch {
            case bytes.Equal(typ, []byte("data")):
                d.data = data
                // Check for special sentinel value used by the Anthropic API to
                // indicate that the stream is done.
                if bytes.Equal(data, doneBytes) {
                    d.done = true
                    return false
                }
                return true
            case bytes.Equal(typ, []byte("event")):
                // Anthropic sends the event name in the data payload as well so we ignore it for snow
                continue
            default:
                d.err = errors.Errorf("malformed data, expected data: %s %q", typ, line)
                return false
            }
        }
    }

    d.err = d.scanner.Err()
    return false
}

// Event returns the event data of the last decoded event
func (d *decoder) Data() []byte {
    return d.data
}

// Err returns the last encountered error
func (d *decoder) Err() error {
    return d.err
}

func splitColon(data []byte) ([]byte, []byte) {
    i := bytes.Index(data, []byte(":"))
    if i < 0 {
        return bytes.TrimSpace(data), nil
    }
    return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
}
@@ -1,50 +0,0 @@
package anthropicmessages

import (
    "strings"
    "testing"

    "github.com/stretchr/testify/require"
)

func TestDecoder(t *testing.T) {
    t.Parallel()

    type event struct {
        data string
    }

    decodeAll := func(input string) ([]event, error) {
        dec := NewDecoder(strings.NewReader(input))
        var events []event
        for dec.Scan() {
            events = append(events, event{
                data: string(dec.Data()),
            })
        }
        return events, dec.Err()
    }

    t.Run("Single", func(t *testing.T) {
        events, err := decodeAll("event:foo\ndata:b\n\n")
        require.NoError(t, err)
        require.Equal(t, events, []event{{data: "b"}})
    })

    t.Run("Multiple", func(t *testing.T) {
        events, err := decodeAll("event:foo\ndata:b\n\nevent:foo\ndata:c\n\n")
        require.NoError(t, err)
        require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
    })

    t.Run("ErrExpectedData", func(t *testing.T) {
        _, err := decodeAll("datas:b\n\n")
        require.Contains(t, err.Error(), "malformed data, expected data")
    })

    t.Run("InterleavedPing", func(t *testing.T) {
        events, err := decodeAll("data:a\n\nevent: ping\ndata: pong\n\ndata:b\n\ndata: [DONE]\n\n")
        require.NoError(t, err)
        require.Equal(t, events, []event{{data: "a"}, {data: "pong"}, {data: "b"}})
    })
}
@@ -3,11 +3,13 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
    name = "awsbedrock",
    srcs = ["bedrock.go"],
    srcs = [
        "bedrock.go",
        "prompt.go",
    ],
    importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/awsbedrock",
    visibility = ["//:__subpackages__"],
    deps = [
        "//internal/completions/client/anthropic",
        "//internal/completions/types",
        "//internal/httpcli",
        "//lib/errors",
@@ -21,9 +23,13 @@ go_library(

go_test(
    name = "awsbedrock_test",
    srcs = ["bedrock_test.go"],
    srcs = [
        "bedrock_test.go",
        "prompt_test.go",
    ],
    embed = [":awsbedrock"],
    deps = [
        "//internal/completions/types",
        "@com_github_aws_aws_sdk_go_v2_config//:config",
        "@com_github_stretchr_testify//require",
    ],
@@ -19,7 +19,6 @@ import (
    "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/credentials"

    "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
    "github.com/sourcegraph/sourcegraph/internal/completions/types"
    "github.com/sourcegraph/sourcegraph/internal/httpcli"
    "github.com/sourcegraph/sourcegraph/lib/errors"
@@ -46,6 +45,7 @@ type awsBedrockAnthropicCompletionStreamClient struct {
func (c *awsBedrockAnthropicCompletionStreamClient) Complete(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
    resp, err := c.makeRequest(ctx, requestParams, false)
@@ -68,6 +68,7 @@ func (c *awsBedrockAnthropicCompletionStreamClient) Complete(
func (a *awsBedrockAnthropicCompletionStreamClient) Stream(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
    sendEvent types.SendCompletionEvent,
) error {
@@ -158,7 +159,7 @@ func (c *awsBedrockAnthropicCompletionStreamClient) makeRequest(ctx context.Cont
        requestParams.TopP = 0
    }

    prompt, err := anthropic.GetPrompt(requestParams.Messages)
    prompt, err := GetPrompt(requestParams.Messages)
    if err != nil {
        return nil, err
    }
@@ -169,7 +170,7 @@ func (c *awsBedrockAnthropicCompletionStreamClient) makeRequest(ctx context.Cont
    }

    if len(requestParams.StopSequences) == 0 {
        requestParams.StopSequences = []string{anthropic.HUMAN_PROMPT}
        requestParams.StopSequences = []string{HUMAN_PROMPT}
    }

    if requestParams.MaxTokensToSample == 0 {
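AWS Bedrock continues to speak the legacy /complete-style prompt format, which is why GetPrompt and the HUMAN_PROMPT/ASSISTANT_PROMPT constants move out of the shared anthropic package into a new awsbedrock prompt.go below rather than being deleted outright.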
internal/completions/client/awsbedrock/prompt.go — new file, 27 lines
@@ -0,0 +1,27 @@
package awsbedrock

import (
    "strings"

    "github.com/sourcegraph/sourcegraph/internal/completions/types"
    "github.com/sourcegraph/sourcegraph/lib/errors"
)

const HUMAN_PROMPT = "\n\nHuman:"
const ASSISTANT_PROMPT = "\n\nAssistant:"

func GetPrompt(messages []types.Message) (string, error) {
    prompt := make([]string, 0, len(messages))
    for idx, message := range messages {
        if idx > 0 && messages[idx-1].Speaker == message.Speaker {
            return "", errors.Newf("found consecutive messages with the same speaker '%s'", message.Speaker)
        }

        messagePrompt, err := message.GetPrompt(HUMAN_PROMPT, ASSISTANT_PROMPT)
        if err != nil {
            return "", err
        }
        prompt = append(prompt, messagePrompt)
    }
    return strings.Join(prompt, ""), nil
}
internal/completions/client/awsbedrock/prompt_test.go — new file, 60 lines
@@ -0,0 +1,60 @@
package awsbedrock

import (
    "testing"

    "github.com/sourcegraph/sourcegraph/internal/completions/types"
)

func TestGetPrompt(t *testing.T) {
    tests := []struct {
        name     string
        messages []types.Message
        want     string
        wantErr  bool
    }{
        {
            name: "success",
            messages: []types.Message{
                {Speaker: "human", Text: "Hello"},
                {Speaker: "assistant", Text: "Hi there!"},
            },
            want: "\n\nHuman: Hello\n\nAssistant: Hi there!",
        },
        {
            name: "empty message",
            messages: []types.Message{
                {Speaker: "human", Text: "Hello"},
                {Speaker: "assistant", Text: ""},
            },
            want: "\n\nHuman: Hello\n\nAssistant:",
        },
        {
            name: "consecutive same speaker error",
            messages: []types.Message{
                {Speaker: "human", Text: "Hello"},
                {Speaker: "human", Text: "Hi"},
            },
            wantErr: true,
        },
        {
            name: "invalid speaker",
            messages: []types.Message{
                {Speaker: "human1", Text: "Hello"},
                {Speaker: "human2", Text: "Hi"},
            },
            wantErr: true,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := GetPrompt(tt.messages)
            if (err != nil) != tt.wantErr {
                t.Fatalf("getPrompt() error = %v, wantErr %v", err, tt.wantErr)
            }
            if got != tt.want {
                t.Fatalf("getPrompt() = %v, want %v", got, tt.want)
            }
        })
    }
}
@@ -123,6 +123,7 @@ type azureCompletionClient struct {
func (c *azureCompletionClient) Complete(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {

@@ -181,6 +182,7 @@ func completeChat(
func (c *azureCompletionClient) Stream(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
    sendEvent types.SendCompletionEvent,
) error {
@@ -312,7 +314,7 @@ func getChatMessages(messages []types.Message) []azopenai.ChatRequestMessageClas
    switch m.Speaker {
    case types.HUMAN_MESSAGE_SPEAKER:
        azureMessages[i] = &azopenai.ChatRequestUserMessage{Content: azopenai.NewChatRequestUserMessageContent(message)}
    case types.ASISSTANT_MESSAGE_SPEAKER:
    case types.ASSISTANT_MESSAGE_SPEAKER:
        azureMessages[i] = &azopenai.ChatRequestAssistantMessage{Content: &message}
    }
@@ -77,7 +77,7 @@ func TestErrStatusNotOK(t *testing.T) {

    mockClient, _ := NewClient(getAzureAPIClient, "", "")
    t.Run("Complete", func(t *testing.T) {
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{})
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{})
        require.Error(t, err)
        assert.Nil(t, resp)

@@ -87,7 +87,7 @@ func TestErrStatusNotOK(t *testing.T) {
    })

    t.Run("Stream", func(t *testing.T) {
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        require.Error(t, err)

        autogold.Expect("AzureOpenAI: unexpected status code 429: too many requests").Equal(t, err.Error())
@@ -114,7 +114,7 @@ func TestGenericErr(t *testing.T) {

    mockClient, _ := NewClient(getAzureAPIClient, "", "")
    t.Run("Complete", func(t *testing.T) {
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{})
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{})
        require.Error(t, err)
        assert.Nil(t, resp)

@@ -124,7 +124,7 @@ func TestGenericErr(t *testing.T) {
    })

    t.Run("Stream", func(t *testing.T) {
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        require.Error(t, err)

        autogold.Expect("error").Equal(t, err.Error())
@@ -33,7 +33,7 @@ func Get(
func getBasic(endpoint string, provider conftypes.CompletionsProviderName, accessToken string) (types.CompletionsClient, error) {
    switch provider {
    case conftypes.CompletionsProviderNameAnthropic:
        return anthropic.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
        return anthropic.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, false), nil
    case conftypes.CompletionsProviderNameOpenAI:
        return openai.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken), nil
    case conftypes.CompletionsProviderNameAzureOpenAI:
@@ -41,20 +41,20 @@ type codyGatewayClient struct {
    accessToken string
}

func (c *codyGatewayClient) Stream(ctx context.Context, feature types.CompletionsFeature, requestParams types.CompletionRequestParameters, sendEvent types.SendCompletionEvent) error {
func (c *codyGatewayClient) Stream(ctx context.Context, feature types.CompletionsFeature, version types.CompletionsVersion, requestParams types.CompletionRequestParameters, sendEvent types.SendCompletionEvent) error {
    cc, err := c.clientForParams(feature, &requestParams)
    if err != nil {
        return err
    }
    return overwriteErrSource(cc.Stream(ctx, feature, requestParams, sendEvent))
    return overwriteErrSource(cc.Stream(ctx, feature, version, requestParams, sendEvent))
}

func (c *codyGatewayClient) Complete(ctx context.Context, feature types.CompletionsFeature, requestParams types.CompletionRequestParameters) (*types.CompletionResponse, error) {
func (c *codyGatewayClient) Complete(ctx context.Context, feature types.CompletionsFeature, version types.CompletionsVersion, requestParams types.CompletionRequestParameters) (*types.CompletionResponse, error) {
    cc, err := c.clientForParams(feature, &requestParams)
    if err != nil {
        return nil, err
    }
    resp, err := cc.Complete(ctx, feature, requestParams)
    resp, err := cc.Complete(ctx, feature, version, requestParams)
    return resp, overwriteErrSource(err)
}

@@ -80,7 +80,7 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
    // gatewayDoer that authenticates against the Gateway's API.
    switch provider {
    case string(conftypes.CompletionsProviderNameAnthropic):
        return anthropic.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/anthropic"), "", ""), nil
        return anthropic.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/anthropic-messages"), "", "", true), nil
    case string(conftypes.CompletionsProviderNameOpenAI):
        return openai.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/openai"), "", ""), nil
    case string(conftypes.CompletionsProviderNameFireworks):
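Two things change in the Gateway path above: Anthropic traffic is now routed to the /v1/completions/anthropic-messages endpoint, and the client is constructed with viaGateway=true, so the top-level system-field promotion in makeRequest is skipped and left to the Gateway side.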
@@ -46,6 +46,7 @@ type fireworksClient struct {
func (c *fireworksClient) Complete(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
    resp, err := c.makeRequest(ctx, feature, requestParams, false)
@@ -83,6 +84,7 @@ func (c *fireworksClient) Complete(
func (c *fireworksClient) Stream(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
    sendEvent types.SendCompletionEvent,
) error {
@@ -190,7 +192,7 @@ func (c *fireworksClient) makeRequest(ctx context.Context, feature types.Complet
    switch m.Speaker {
    case types.HUMAN_MESSAGE_SPEAKER:
        role = "user"
    case types.ASISSTANT_MESSAGE_SPEAKER:
    case types.ASSISTANT_MESSAGE_SPEAKER:
        role = "assistant"
    default:
        role = strings.ToLower(role)
@@ -33,7 +33,7 @@ func TestErrStatusNotOK(t *testing.T) {
    }, "", "")

    t.Run("Complete", func(t *testing.T) {
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureCode, types.CompletionRequestParameters{Messages: []types.Message{{Text: "Hey"}}})
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureCode, types.CompletionsVersionLegacy, types.CompletionRequestParameters{Messages: []types.Message{{Text: "Hey"}}})
        require.Error(t, err)
        assert.Nil(t, resp)

@@ -43,7 +43,7 @@ func TestErrStatusNotOK(t *testing.T) {
    })

    t.Run("Stream", func(t *testing.T) {
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureCode, types.CompletionRequestParameters{Messages: []types.Message{{Text: "Hey"}}}, func(event types.CompletionResponse) error { return nil })
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureCode, types.CompletionsVersionLegacy, types.CompletionRequestParameters{Messages: []types.Message{{Text: "Hey"}}}, func(event types.CompletionResponse) error { return nil })
        require.Error(t, err)

        autogold.Expect("Fireworks: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
@@ -31,9 +31,9 @@ type observedClient struct {

var _ types.CompletionsClient = (*observedClient)(nil)

func (o *observedClient) Stream(ctx context.Context, feature types.CompletionsFeature, params types.CompletionRequestParameters, send types.SendCompletionEvent) (err error) {
func (o *observedClient) Stream(ctx context.Context, feature types.CompletionsFeature, version types.CompletionsVersion, params types.CompletionRequestParameters, send types.SendCompletionEvent) (err error) {
    ctx, tr, endObservation := o.ops.stream.With(ctx, &err, observation.Args{
        Attrs: append(params.Attrs(feature), attribute.String("feature", string(feature))),
        Attrs: append(params.Attrs(feature), attribute.String("feature", string(feature)), attribute.Int("version", int(version))),
        MetricLabelValues: []string{params.Model},
    })
    defer endObservation(1, observation.Args{})
@@ -48,12 +48,12 @@ func (o *observedClient) Stream(ctx context.Context, feature types.CompletionsFe
        return send(event)
    }

    return o.inner.Stream(ctx, feature, params, tracedSend)
    return o.inner.Stream(ctx, feature, version, params, tracedSend)
}

func (o *observedClient) Complete(ctx context.Context, feature types.CompletionsFeature, params types.CompletionRequestParameters) (resp *types.CompletionResponse, err error) {
func (o *observedClient) Complete(ctx context.Context, feature types.CompletionsFeature, version types.CompletionsVersion, params types.CompletionRequestParameters) (resp *types.CompletionResponse, err error) {
    ctx, _, endObservation := o.ops.complete.With(ctx, &err, observation.Args{
        Attrs: append(params.Attrs(feature), attribute.String("feature", string(feature))),
        Attrs: append(params.Attrs(feature), attribute.String("feature", string(feature)), attribute.Int("version", int(version))),
        MetricLabelValues: []string{params.Model},
    })
    defer endObservation(1, observation.Args{})
@@ -64,7 +64,7 @@ func (o *observedClient) Complete(ctx context.Context, feature types.Completions
        },
    })

    return o.inner.Complete(ctx, feature, params)
    return o.inner.Complete(ctx, feature, version, params)
}

type operations struct {
@@ -30,6 +30,7 @@ type openAIChatCompletionStreamClient struct {
func (c *openAIChatCompletionStreamClient) Complete(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
    var resp *http.Response
@@ -69,6 +70,7 @@ func (c *openAIChatCompletionStreamClient) Complete(
func (c *openAIChatCompletionStreamClient) Stream(
    ctx context.Context,
    feature types.CompletionsFeature,
    _ types.CompletionsVersion,
    requestParams types.CompletionRequestParameters,
    sendEvent types.SendCompletionEvent,
) error {
@@ -155,7 +157,7 @@ func (c *openAIChatCompletionStreamClient) makeRequest(ctx context.Context, requ
    switch m.Speaker {
    case types.HUMAN_MESSAGE_SPEAKER:
        role = "user"
    case types.ASISSTANT_MESSAGE_SPEAKER:
    case types.ASSISTANT_MESSAGE_SPEAKER:
        role = "assistant"
        //
    default:
@@ -33,7 +33,7 @@ func TestErrStatusNotOK(t *testing.T) {
    }, "", "")

    t.Run("Complete", func(t *testing.T) {
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{})
        resp, err := mockClient.Complete(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{})
        require.Error(t, err)
        assert.Nil(t, resp)

@@ -43,7 +43,7 @@ func TestErrStatusNotOK(t *testing.T) {
    })

    t.Run("Stream", func(t *testing.T) {
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        err := mockClient.Stream(context.Background(), types.CompletionsFeatureChat, types.CompletionsVersionLegacy, types.CompletionRequestParameters{}, func(event types.CompletionResponse) error { return nil })
        require.Error(t, err)

        autogold.Expect("OpenAI: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error())
@ -58,6 +58,7 @@ func allowedCustomModel(model string) string {
"fireworks/" + fireworks.Llama234bCodeInstruct,
"fireworks/" + fireworks.Mistral7bInstruct,
"anthropic/claude-instant-1.2",
"anthropic/claude-3-haiku-20240307",
// Deprecated model identifiers
"anthropic/claude-instant-v1",
"anthropic/claude-instant-1",

@ -80,6 +80,17 @@ func newCompletionsHandler(
return
}

var version types.CompletionsVersion
versionParam := r.URL.Query().Get("api-version")
if versionParam == "" {
version = types.CompletionsVersionLegacy
} else if versionParam == "1" {
version = types.CompletionsV1
} else {
http.Error(w, "Unsupported API Version (Please update your client)", http.StatusNotAcceptable)
return
}

var requestParams types.CodyCompletionRequestParameters
if err := json.NewDecoder(r.Body).Decode(&requestParams); err != nil {
http.Error(w, "could not decode request body", http.StatusBadRequest)
@ -179,7 +190,7 @@ func newCompletionsHandler(
}
}

responseHandler(ctx, requestParams.CompletionRequestParameters, completionClient, w, userStore, test)
responseHandler(ctx, requestParams.CompletionRequestParameters, version, completionClient, w, userStore, test)
})
}

@ -200,23 +211,23 @@ func respondRateLimited(w http.ResponseWriter, err RateLimitExceededError, isDot

// newSwitchingResponseHandler handles requests to an LLM provider, and wraps the correct
// handler based on the requestParams.Stream flag.
func newSwitchingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
func newSwitchingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
nonStreamer := newNonStreamingResponseHandler(logger, db, feature)
streamer := newStreamingResponseHandler(logger, db, feature)
return func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
return func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
if requestParams.IsStream(feature) {
streamer(ctx, requestParams, cc, w, userStore, test)
streamer(ctx, requestParams, version, cc, w, userStore, test)
} else {
// TODO(#59832): Add attribution to non-streaming endpoint.
nonStreamer(ctx, requestParams, cc, w, userStore)
nonStreamer(ctx, requestParams, version, cc, w, userStore)
}
}
}

// newStreamingResponseHandler handles streaming requests to an LLM provider,
// It writes events to an SSE stream as they come in.
func newStreamingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
return func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
func newStreamingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
return func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore, test guardrails.AttributionTest) {
var eventWriter = sync.OnceValue[*streamhttp.Writer](func() *streamhttp.Writer {
eventWriter, err := streamhttp.NewWriter(w)
if err != nil {
@ -271,7 +282,7 @@ func newStreamingResponseHandler(logger log.Logger, db database.DB, feature type
f = ff
}
}
err := cc.Stream(ctx, feature, requestParams,
err := cc.Stream(ctx, feature, version, requestParams,
func(event types.CompletionResponse) error {
if !firstEventObserved {
firstEventObserved = true
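
Aside on the sync.OnceValue usage visible above: the SSE writer is constructed lazily, at most once, on first use, which presumably keeps the handler from committing to a streaming response before the first event (or error) arrives. A self-contained sketch of the same pattern (names illustrative; requires Go 1.21+):

package main

import (
    "fmt"
    "sync"
)

func main() {
    // makeWriter stands in for streamhttp.NewWriter: the function passed to
    // sync.OnceValue runs at most once, on first call, and every subsequent
    // call returns the same cached value.
    makeWriter := sync.OnceValue(func() string {
        fmt.Println("initializing writer (runs once)")
        return "writer"
    })

    fmt.Println(makeWriter()) // triggers initialization
    fmt.Println(makeWriter()) // returns the cached value, no re-init
}
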
@ -347,9 +358,9 @@ func newStreamingResponseHandler(logger log.Logger, db database.DB, feature type
// newNonStreamingResponseHandler handles non-streaming requests to an LLM provider,
// awaiting the complete response before writing it back in a structured JSON response
// to the client.
func newNonStreamingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore) {
return func(ctx context.Context, requestParams types.CompletionRequestParameters, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore) {
completion, err := cc.Complete(ctx, feature, requestParams)
func newNonStreamingResponseHandler(logger log.Logger, db database.DB, feature types.CompletionsFeature) func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore) {
return func(ctx context.Context, requestParams types.CompletionRequestParameters, version types.CompletionsVersion, cc types.CompletionsClient, w http.ResponseWriter, userStore database.UserStore) {
completion, err := cc.Complete(ctx, feature, version, requestParams)
if err != nil {
logFields := []log.Field{log.Error(err)}

@ -10,7 +10,8 @@ import (
)

const HUMAN_MESSAGE_SPEAKER = "human"
const ASISSTANT_MESSAGE_SPEAKER = "assistant"
const ASSISTANT_MESSAGE_SPEAKER = "assistant"
const SYSTEM_MESSAGE_SPEAKER = "system"

type Message struct {
Speaker string `json:"speaker"`
@ -18,7 +19,7 @@ type Message struct {
}

func (m Message) IsValidSpeaker() bool {
return m.Speaker == HUMAN_MESSAGE_SPEAKER || m.Speaker == ASISSTANT_MESSAGE_SPEAKER
return m.Speaker == HUMAN_MESSAGE_SPEAKER || m.Speaker == ASSISTANT_MESSAGE_SPEAKER
}

func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (string, error) {
@ -26,7 +27,7 @@ func (m Message) GetPrompt(humanPromptPrefix, assistantPromptPrefix string) (str
switch m.Speaker {
case HUMAN_MESSAGE_SPEAKER:
prefix = humanPromptPrefix
case ASISSTANT_MESSAGE_SPEAKER:
case ASSISTANT_MESSAGE_SPEAKER:
prefix = assistantPromptPrefix
default:
return "", errors.Newf("expected message speaker to be 'human' or 'assistant', got %s", m.Speaker)
@ -167,11 +168,18 @@ func (b CompletionsFeature) ID() int {
}
}

type CompletionsVersion int

const (
CompletionsVersionLegacy CompletionsVersion = 0
CompletionsV1 CompletionsVersion = 1
)

type CompletionsClient interface {
// Stream executions a completions request, streaming results to the callback.
// Callers should check for ErrStatusNotOK and handle the error appropriately.
Stream(context.Context, CompletionsFeature, CompletionRequestParameters, SendCompletionEvent) error
Stream(context.Context, CompletionsFeature, CompletionsVersion, CompletionRequestParameters, SendCompletionEvent) error
// Complete executions a completions request until done. Callers should check
// for ErrStatusNotOK and handle the error appropriately.
Complete(context.Context, CompletionsFeature, CompletionRequestParameters) (*CompletionResponse, error)
Complete(context.Context, CompletionsFeature, CompletionsVersion, CompletionRequestParameters) (*CompletionResponse, error)
}
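
With the interface change above, every provider receives the negotiated version on each call; the OpenAI client simply ignores it (the _ types.CompletionsVersion parameter earlier in this diff), while the Anthropic provider has to translate legacy prompts into /messages form. A toy, self-contained sketch of that dispatch shape only (the translate helper is invented for illustration and is not the provider's actual code):

package main

import "fmt"

type CompletionsVersion int

const (
    CompletionsVersionLegacy CompletionsVersion = 0
    CompletionsV1            CompletionsVersion = 1
)

// translate stands in for the conversion an Anthropic provider must perform:
// legacy clients send /complete-style prompts (trailing assistant message,
// system prompt inline) that need reshaping for /messages, while
// api-version=1 clients already send the new shape.
func translate(v CompletionsVersion, prompt string) (string, error) {
    switch v {
    case CompletionsVersionLegacy:
        return "messages-from-legacy(" + prompt + ")", nil
    case CompletionsV1:
        return prompt, nil
    default:
        return "", fmt.Errorf("unsupported completions version %d", v)
    }
}

func main() {
    for _, v := range []CompletionsVersion{CompletionsVersionLegacy, CompletionsV1} {
        out, _ := translate(v, "prompt")
        fmt.Printf("version %d -> %s\n", v, out)
    }
}
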
@ -720,7 +720,7 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
} else if completionsConfig.Provider == string(conftypes.CompletionsProviderNameAnthropic) {
// If no endpoint is configured, use a default value.
if completionsConfig.Endpoint == "" {
completionsConfig.Endpoint = "https://api.anthropic.com/v1/complete"
completionsConfig.Endpoint = "https://api.anthropic.com/v1/messages"
}

// If not access token is set, we cannot talk to Anthropic. Bail.
@ -1213,9 +1213,9 @@ func anthropicDefaultMaxPromptTokens(model string) int {
return 100_000
}
if model == "claude-2" || model == "claude-2.0" || model == "claude-2.1" || model == "claude-v2" {
// TODO: Technically, v2 also uses a 100k window, but we should validate
// that returning 100k here is the right thing to do.
if model == "claude-2" || model == "claude-2.0" || model == "claude-2.1" || model == "claude-v2" || model == "claude-3-sonnet-20240229" || model == "claude-3-opus-20240229" || model == "claude-3-haiku-20240307" {
// TODO: Technically, v2 and v3 also uses a 100k/200k window respectively, but we should
// validate that returning 100k here is the right thing to do.
return 12_000
}
// For now, all other claude models have a 9k token window.
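
Per the branch above, the Claude 3 models are grouped with Claude 2 at a conservative 12k default prompt window for now, pending validation of the larger 100k/200k context windows. A quick self-contained check of that intent against a local copy of the logic (the 100k suffix condition is an assumption about the unexported function's surrounding lines; the 12k and 9k branches mirror the diff):

package main

import (
    "fmt"
    "strings"
)

// defaultMaxPromptTokens is a local copy for illustration, not the conf
// package function itself.
func defaultMaxPromptTokens(model string) int {
    if strings.HasSuffix(model, "-100k") { // assumed condition for the 100k branch
        return 100_000
    }
    switch model {
    case "claude-2", "claude-2.0", "claude-2.1", "claude-v2",
        "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-haiku-20240307":
        return 12_000
    }
    return 9_000 // all other Claude models keep the 9k window for now
}

func main() {
    for _, m := range []string{"claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2"} {
        fmt.Println(m, "->", defaultMaxPromptTokens(m))
    }
}
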
@ -416,7 +416,7 @@ func TestGetCompletionsConfig(t *testing.T) {
CompletionModelMaxTokens: 9000,
AccessToken: "asdf",
Provider: "anthropic",
Endpoint: "https://api.anthropic.com/v1/complete",
Endpoint: "https://api.anthropic.com/v1/messages",
},
},
{
@ -441,7 +441,7 @@ func TestGetCompletionsConfig(t *testing.T) {
CompletionModelMaxTokens: 9000,
AccessToken: "asdf",
Provider: "anthropic",
Endpoint: "https://api.anthropic.com/v1/complete",
Endpoint: "https://api.anthropic.com/v1/messages",
},
},
{

@ -640,7 +640,7 @@ type Completions struct {
CompletionModelMaxTokens int `json:"completionModelMaxTokens,omitempty"`
// Enabled description: DEPRECATED. Use cody.enabled instead to turn Cody on/off.
Enabled *bool `json:"enabled,omitempty"`
// Endpoint description: The endpoint under which to reach the provider. Currently only used for provider types "sourcegraph", "openai" and "anthropic". The default values are "https://cody-gateway.sourcegraph.com", "https://api.openai.com/v1/chat/completions", and "https://api.anthropic.com/v1/complete" for Sourcegraph, OpenAI, and Anthropic, respectively.
// Endpoint description: The endpoint under which to reach the provider. Currently only used for provider types "sourcegraph", "openai" and "anthropic". The default values are "https://cody-gateway.sourcegraph.com", "https://api.openai.com/v1/chat/completions", and "https://api.anthropic.com/v1/messages" for Sourcegraph, OpenAI, and Anthropic, respectively.
Endpoint string `json:"endpoint,omitempty"`
// FastChatModel description: The model used for fast chat completions.
FastChatModel string `json:"fastChatModel,omitempty"`

@ -2912,7 +2912,7 @@
},
"endpoint": {
"type": "string",
"description": "The endpoint under which to reach the provider. Currently only used for provider types \"sourcegraph\", \"openai\" and \"anthropic\". The default values are \"https://cody-gateway.sourcegraph.com\", \"https://api.openai.com/v1/chat/completions\", and \"https://api.anthropic.com/v1/complete\" for Sourcegraph, OpenAI, and Anthropic, respectively."
"description": "The endpoint under which to reach the provider. Currently only used for provider types \"sourcegraph\", \"openai\" and \"anthropic\". The default values are \"https://cody-gateway.sourcegraph.com\", \"https://api.openai.com/v1/chat/completions\", and \"https://api.anthropic.com/v1/messages\" for Sourcegraph, OpenAI, and Anthropic, respectively."
},
"perUserDailyLimit": {
"description": "If > 0, limits the number of completions requests allowed for a user in a day. On instances that allow anonymous requests, we enforce the rate limit by IP.",