From 5da08a3bc8aa77046bb025fe87497d24c469aff3 Mon Sep 17 00:00:00 2001 From: Philipp Spiess Date: Wed, 6 Mar 2024 18:13:03 +0100 Subject: [PATCH] Add support for Claude 3 to Cody Gateway (#60830) Co-authored-by: Chris Warwick --- .../internal/httpapi/completions/BUILD.bazel | 4 + .../internal/httpapi/completions/anthropic.go | 72 +---- .../httpapi/completions/anthropic_test.go | 5 +- .../httpapi/completions/anthropicmessages.go | 298 ++++++++++++++++++ .../completions/anthropicmessages_test.go | 166 ++++++++++ .../internal/httpapi/completions/flagging.go | 101 ++++++ .../internal/httpapi/completions/upstream.go | 21 +- cmd/cody-gateway/internal/httpapi/handler.go | 35 +- cmd/cody-gateway/shared/config/config.go | 2 + .../codygateway_dotcom_user.go | 11 +- .../client/anthropicmessages/BUILD.bazel | 19 ++ .../client/anthropicmessages/decoder.go | 106 +++++++ .../client/anthropicmessages/decoder_test.go | 50 +++ internal/completions/httpapi/chat.go | 17 +- 14 files changed, 821 insertions(+), 86 deletions(-) create mode 100644 cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go create mode 100644 cmd/cody-gateway/internal/httpapi/completions/anthropicmessages_test.go create mode 100644 cmd/cody-gateway/internal/httpapi/completions/flagging.go create mode 100644 internal/completions/client/anthropicmessages/BUILD.bazel create mode 100644 internal/completions/client/anthropicmessages/decoder.go create mode 100644 internal/completions/client/anthropicmessages/decoder_test.go diff --git a/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel b/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel index 6bbdcb518ae..69351143a6c 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel +++ b/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel @@ -5,7 +5,9 @@ go_library( name = "completions", srcs = [ "anthropic.go", + "anthropicmessages.go", "fireworks.go", + "flagging.go", "openai.go", "upstream.go", ], @@ -23,6 +25,7 @@ go_library( "//cmd/cody-gateway/shared/config", "//internal/codygateway", "//internal/completions/client/anthropic", + "//internal/completions/client/anthropicmessages", "//internal/completions/client/fireworks", "//internal/completions/client/openai", "//internal/conf/conftypes", @@ -42,6 +45,7 @@ go_test( name = "completions_test", srcs = [ "anthropic_test.go", + "anthropicmessages_test.go", "fireworks_test.go", "openai_test.go", ], diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropic.go b/cmd/cody-gateway/internal/httpapi/completions/anthropic.go index 73ca55c8905..13d6af4667b 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/anthropic.go +++ b/cmd/cody-gateway/internal/httpapi/completions/anthropic.go @@ -6,9 +6,9 @@ import ( "encoding/json" "io" "net/http" - "strings" "github.com/sourcegraph/log" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config" "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic" @@ -248,61 +248,19 @@ func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, cfg if ar.Model != "claude-2" && ar.Model != "claude-2.0" && ar.Model != "claude-2.1" && ar.Model != "claude-v1" { return nil, nil } - var reasons []string - prompt := strings.ToLower(ar.Prompt) - - if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern { - reasons = append(reasons, "unknown_prompt") - } - - // If this request has a very high token count for responses, then flag it. 
- if ar.MaxTokensToSample > int32(cfg.MaxTokensToSampleFlaggingLimit) { - reasons = append(reasons, "high_max_tokens_to_sample") - } - - // If this prompt consists of a very large number of tokens, then flag it. - tokenCount, err := ar.GetPromptTokenCount(tk) - if err != nil { - return &flaggingResult{}, errors.Wrap(err, "tokenize prompt") - } - if tokenCount > cfg.PromptTokenFlaggingLimit { - reasons = append(reasons, "high_prompt_token_count") - } - - if len(reasons) > 0 { // request is flagged - blocked := false - hasBlockedPhrase, phrase := containsAny(prompt, cfg.BlockedPromptPatterns) - if tokenCount > cfg.PromptTokenBlockingLimit || ar.MaxTokensToSample > int32(cfg.ResponseTokenBlockingLimit) || hasBlockedPhrase { - blocked = true - } - - promptPrefix := ar.Prompt - if len(promptPrefix) > logPromptPrefixLength { - promptPrefix = promptPrefix[0:logPromptPrefixLength] - } - res := &flaggingResult{ - reasons: reasons, - maxTokensToSample: int(ar.MaxTokensToSample), - promptPrefix: promptPrefix, - promptTokenCount: tokenCount, - shouldBlock: blocked, - } - if hasBlockedPhrase { - res.blockedPhrase = &phrase - } - return res, nil - } - - return nil, nil -} - -func containsAny(prompt string, patterns []string) (bool, string) { - prompt = strings.ToLower(prompt) - for _, pattern := range patterns { - if strings.Contains(prompt, pattern) { - return true, pattern - } - } - return false, "" + return isFlaggedRequest(tk, + flaggingRequest{ + FlattenedPrompt: ar.Prompt, + MaxTokens: int(ar.MaxTokensToSample), + }, + flaggingConfig{ + AllowedPromptPatterns: cfg.AllowedPromptPatterns, + BlockedPromptPatterns: cfg.BlockedPromptPatterns, + PromptTokenFlaggingLimit: cfg.PromptTokenFlaggingLimit, + PromptTokenBlockingLimit: cfg.PromptTokenBlockingLimit, + MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit, + ResponseTokenBlockingLimit: cfg.ResponseTokenBlockingLimit, + }, + ) } diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropic_test.go b/cmd/cody-gateway/internal/httpapi/completions/anthropic_test.go index db5a52bd517..3916993d86e 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/anthropic_test.go +++ b/cmd/cody-gateway/internal/httpapi/completions/anthropic_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/hexops/autogold/v2" - "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer" ) @@ -58,7 +59,7 @@ func TestIsFlaggedAnthropicRequest(t *testing.T) { require.Equal(t, int32(result.maxTokensToSample), ar.MaxTokensToSample) }) - t.Run("high prompt token count (below block limit)", func(t *testing.T) { + t.Run("high prompt token count (above block limit)", func(t *testing.T) { tokenLengths, err := tk.Tokenize(validPreamble) require.NoError(t, err) diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go new file mode 100644 index 00000000000..cff8aa478a1 --- /dev/null +++ b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go @@ -0,0 +1,298 @@ +package completions + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/sourcegraph/log" + + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config" + "github.com/sourcegraph/sourcegraph/lib/errors" + + 
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer" + "github.com/sourcegraph/sourcegraph/internal/codygateway" + "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages" + "github.com/sourcegraph/sourcegraph/internal/conf/conftypes" + "github.com/sourcegraph/sourcegraph/internal/httpcli" +) + +const anthropicMessagesAPIURL = "https://api.anthropic.com/v1/messages" + +// This implements the newer `/messages` API by Anthropic +// https://docs.anthropic.com/claude/reference/messages_post +func NewAnthropicMessagesHandler( + baseLogger log.Logger, + eventLogger events.Logger, + rs limiter.RedisStore, + rateLimitNotifier notify.RateLimitNotifier, + httpClient httpcli.Doer, + config config.AnthropicConfig, + promptRecorder PromptRecorder, + autoFlushStreamingResponses bool, +) (http.Handler, error) { + // Tokenizer only needs to be initialized once, and can be shared globally. + tokenizer, err := tokenizer.NewAnthropicClaudeTokenizer() + if err != nil { + return nil, err + } + return makeUpstreamHandler[anthropicMessagesRequest]( + baseLogger, + eventLogger, + rs, + rateLimitNotifier, + httpClient, + string(conftypes.CompletionsProviderNameAnthropic), + func(_ codygateway.Feature) string { return anthropicMessagesAPIURL }, + config.AllowedModels, + &AnthropicMessagesHandlerMethods{config: config, tokenizer: tokenizer, promptRecorder: promptRecorder}, + + // Anthropic primarily uses concurrent requests to rate-limit spikes + // in requests, so set a default retry-after that is likely to be + // acceptable for Sourcegraph clients to retry (the default + // SRC_HTTP_CLI_EXTERNAL_RETRY_AFTER_MAX_DURATION) since we might be + // able to circumvent concurrents limits without raising an error to the + // user. + 2, // seconds + autoFlushStreamingResponses, + nil, + ), nil +} + +// AnthropicMessagesRequest captures all known fields from https://console.anthropic.com/docs/api/reference. +type anthropicMessagesRequest struct { + Messages []anthropicMessage `json:"messages,omitempty"` + Model string `json:"model"` + MaxTokens int32 `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + TopK int32 `json:"top_k,omitempty"` + Stream bool `json:"stream,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + + // These are not accepted from the client an instead are only used to talk + // to the upstream LLM APIs. + Metadata *anthropicMessagesRequestMetadata `json:"metadata,omitempty"` + System string `json:"system,omitempty"` +} + +type anthropicMessage struct { + Role string `json:"role"` // "user", "assistant", or "system" (only allowed for the first message) + Content []anthropicMessageContent `json:"content"` +} + +type anthropicMessageContent struct { + Type string `json:"type"` // "text" or "image" (not yet supported) + Text string `json:"text"` +} + +type anthropicMessagesRequestMetadata struct { + UserID string `json:"user_id,omitempty"` +} + +func (ar anthropicMessagesRequest) ShouldStream() bool { + return ar.Stream +} + +func (ar anthropicMessagesRequest) GetModel() string { + return ar.Model +} + +// Note: This is not the actual prompt send to Anthropic but it's a good +// approximation to measure tokens. 
+func (r anthropicMessagesRequest) BuildPrompt() string {
+	var sb strings.Builder
+	for _, m := range r.Messages {
+		switch m.Role {
+		case "user":
+			sb.WriteString("Human: ")
+		case "assistant":
+			sb.WriteString("Assistant: ")
+		case "system":
+			sb.WriteString("System: ")
+		default:
+			return ""
+		}
+
+		for _, c := range m.Content {
+			if c.Type == "text" {
+				sb.WriteString(c.Text)
+			}
+		}
+		sb.WriteString("\n\n")
+	}
+	return sb.String()
+}
+
+// AnthropicMessagesNonStreamingResponse captures all relevant-to-us fields from https://docs.anthropic.com/claude/reference/messages_post.
+type anthropicMessagesNonStreamingResponse struct {
+	Content    []anthropicMessageContent      `json:"content"`
+	Usage      anthropicMessagesResponseUsage `json:"usage"`
+	StopReason string                         `json:"stop_reason"`
+}
+
+// AnthropicMessagesStreamingResponse captures all relevant-to-us fields from each relevant SSE event from https://docs.anthropic.com/claude/reference/messages_post.
+type anthropicMessagesStreamingResponse struct {
+	Type         string                                        `json:"type"`
+	Delta        *anthropicMessagesStreamingResponseTextBucket `json:"delta"`
+	ContentBlock *anthropicMessagesStreamingResponseTextBucket `json:"content_block"`
+	Usage        *anthropicMessagesResponseUsage               `json:"usage"`
+	Message      *anthropicStreamingResponseMessage            `json:"message"`
+}
+
+type anthropicStreamingResponseMessage struct {
+	Usage *anthropicMessagesResponseUsage `json:"usage"`
+}
+
+type anthropicMessagesStreamingResponseTextBucket struct {
+	Text string `json:"text"`
+}
+
+type anthropicMessagesResponseUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+}
+
+type AnthropicMessagesHandlerMethods struct {
+	tokenizer      *tokenizer.Tokenizer
+	promptRecorder PromptRecorder
+	config         config.AnthropicConfig
+}
+
+func (a *AnthropicMessagesHandlerMethods) validateRequest(ctx context.Context, logger log.Logger, _ codygateway.Feature, ar anthropicMessagesRequest) (int, *flaggingResult, error) {
+	if ar.MaxTokens > int32(a.config.MaxTokensToSample) {
+		return http.StatusBadRequest, nil, errors.Errorf("max_tokens exceeds maximum allowed value of %d: %d", a.config.MaxTokensToSample, ar.MaxTokens)
+	}
+
+	if result, err := isFlaggedAnthropicMessagesRequest(a.tokenizer, ar, a.config); err != nil {
+		logger.Error("error checking AnthropicMessages request - treating as non-flagged",
+			log.Error(err))
+	} else if result.IsFlagged() {
+		// Record flagged prompts on the hot path - flagged requests usually take a long time upstream anyway, so this isn't going to make things meaningfully worse
+		if err := a.promptRecorder.Record(ctx, ar.BuildPrompt()); err != nil {
+			logger.Warn("failed to record flagged prompt", log.Error(err))
+		}
+		if a.config.RequestBlockingEnabled && result.shouldBlock {
+			return http.StatusBadRequest, result, errors.Errorf("request blocked - if you think this is a mistake, please contact support@sourcegraph.com")
+		}
+		return 0, result, nil
+	}
+
+	return 0, nil, nil
+}
+func (a *AnthropicMessagesHandlerMethods) transformBody(body *anthropicMessagesRequest, identifier string) {
+	// Overwrite the metadata field; we don't want to allow users to specify it:
+	body.Metadata = &anthropicMessagesRequestMetadata{
+		// We forward the actor ID to support tracking.
+		UserID: identifier,
+	}
+
+	// Remove the `anthropic/` prefix from the model string
+	body.Model = strings.TrimPrefix(body.Model, "anthropic/")
+
+	// If the first message uses the `system` role, convert it to a top-level system prompt
+	body.System = "" // drop any client-provided value
+	if len(body.Messages) > 0 && body.Messages[0].Role == "system" {
+		body.System = body.Messages[0].Content[0].Text
+		body.Messages = body.Messages[1:]
+	}
+}
+func (a *AnthropicMessagesHandlerMethods) getRequestMetadata(body anthropicMessagesRequest) (model string, additionalMetadata map[string]any) {
+	return body.Model, map[string]any{
+		"stream":     body.Stream,
+		"max_tokens": body.MaxTokens,
+	}
+}
+func (a *AnthropicMessagesHandlerMethods) transformRequest(r *http.Request) {
+	r.Header.Set("Content-Type", "application/json")
+	r.Header.Set("X-API-Key", a.config.AccessToken)
+	r.Header.Set("anthropic-version", "2023-06-01")
+}
+func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logger, body anthropicMessagesRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
+	// First, extract prompt usage details from the request.
+	for _, m := range body.Messages {
+		for _, c := range m.Content { promptUsage.characters += len(c.Text) }
+	}
+	promptUsage.characters += len(body.System)
+
+	// Next, parse the response we saw. If it was non-streaming, we can simply
+	// parse it as JSON.
+	if !body.ShouldStream() {
+		var res anthropicMessagesNonStreamingResponse
+		if err := json.NewDecoder(r).Decode(&res); err != nil {
+			logger.Error("failed to parse Anthropic response as JSON", log.Error(err))
+			return promptUsage, completionUsage
+		}
+
+		// Extract character data from response by summing up all text
+		for _, c := range res.Content {
+			completionUsage.characters += len(c.Text)
+		}
+		// Extract token usage data from the response
+		completionUsage.tokens = res.Usage.OutputTokens
+		promptUsage.tokens = res.Usage.InputTokens
+
+		return promptUsage, completionUsage
+	}
+
+	// Otherwise, we have to parse the event stream from Anthropic.
+	dec := anthropicmessages.NewDecoder(r)
+	for dec.Scan() {
+		data := dec.Data()
+
+		// Gracefully skip over any data that isn't JSON-like. Anthropic's API sometimes sends
+		// undocumented data over the stream, like timestamps.
+ if !bytes.HasPrefix(data, []byte("{")) { + continue + } + + var event anthropicMessagesStreamingResponse + if err := json.Unmarshal(data, &event); err != nil { + logger.Error("failed to decode event payload", log.Error(err), log.String("body", string(data))) + continue + } + + switch event.Type { + case "message_start": + if event.Message != nil && event.Message.Usage != nil { + promptUsage.tokens = event.Message.Usage.InputTokens + } + case "content_block_delta": + if event.Delta != nil { + completionUsage.characters += len(event.Delta.Text) + } + case "message_delta": + if event.Usage != nil { + completionUsage.tokens = event.Usage.OutputTokens + } + } + } + if err := dec.Err(); err != nil { + logger.Error("failed to decode Anthropic streaming response", log.Error(err)) + } + + return promptUsage, completionUsage +} + +func isFlaggedAnthropicMessagesRequest(tk *tokenizer.Tokenizer, r anthropicMessagesRequest, cfg config.AnthropicConfig) (*flaggingResult, error) { + return isFlaggedRequest(tk, + flaggingRequest{ + FlattenedPrompt: r.BuildPrompt(), + MaxTokens: int(r.MaxTokens), + }, + flaggingConfig{ + AllowedPromptPatterns: cfg.AllowedPromptPatterns, + BlockedPromptPatterns: cfg.BlockedPromptPatterns, + PromptTokenFlaggingLimit: cfg.PromptTokenFlaggingLimit, + PromptTokenBlockingLimit: cfg.PromptTokenBlockingLimit, + MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit, + ResponseTokenBlockingLimit: cfg.ResponseTokenBlockingLimit, + }, + ) +} diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages_test.go b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages_test.go new file mode 100644 index 00000000000..1551410de0b --- /dev/null +++ b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages_test.go @@ -0,0 +1,166 @@ +package completions + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/hexops/autogold/v2" + "github.com/stretchr/testify/require" + + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config" + + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer" +) + +func TestIsFlaggedAnthropicMessagesRequest(t *testing.T) { + validPreamble := "You are cody-gateway." 
+
+	cfg := config.AnthropicConfig{
+		PromptTokenFlaggingLimit:       18000,
+		PromptTokenBlockingLimit:       20000,
+		MaxTokensToSampleFlaggingLimit: 1000,
+		ResponseTokenBlockingLimit:     1000,
+	}
+	cfgWithPreamble := config.AnthropicConfig{
+		PromptTokenFlaggingLimit:       18000,
+		PromptTokenBlockingLimit:       20000,
+		MaxTokensToSampleFlaggingLimit: 1000,
+		ResponseTokenBlockingLimit:     1000,
+		AllowedPromptPatterns:          []string{strings.ToLower(validPreamble)},
+	}
+	tk, err := tokenizer.NewAnthropicClaudeTokenizer()
+	require.NoError(t, err)
+
+	t.Run("works for known preamble", func(t *testing.T) {
+		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
+			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
+		}}
+		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
+		require.NoError(t, err)
+		require.Nil(t, result)
+	})
+
+	t.Run("missing known preamble", func(t *testing.T) {
+		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
+			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "some prompt without known preamble"}}},
+		}}
+		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
+		require.NoError(t, err)
+		require.True(t, result.IsFlagged())
+		require.False(t, result.shouldBlock)
+		require.Contains(t, result.reasons, "unknown_prompt")
+	})
+
+	t.Run("preamble not configured", func(t *testing.T) {
+		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
+			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "some prompt without known preamble"}}},
+		}}
+		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfg)
+		require.NoError(t, err)
+		require.False(t, result.IsFlagged())
+	})
+
+	t.Run("high max tokens to sample", func(t *testing.T) {
+		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", MaxTokens: 10000, Messages: []anthropicMessage{
+			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
+		}}
+		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfg)
+		require.NoError(t, err)
+		require.True(t, result.IsFlagged())
+		require.True(t, result.shouldBlock)
+		require.Contains(t, result.reasons, "high_max_tokens_to_sample")
+		require.Equal(t, int32(result.maxTokensToSample), r.MaxTokens)
+	})
+
+	t.Run("high prompt token count and bad phrase", func(t *testing.T) {
+		cfgWithBadPhrase := cfgWithPreamble // copy so we don't mutate the shared config
+		cfgWithBadPhrase.BlockedPromptPatterns = []string{"bad phrase"}
+		longPrompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1)
+		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
+			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble + " " + longPrompt + "bad phrase"}}},
+		}}
+		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithBadPhrase)
+		require.NoError(t, err)
+		require.True(t, result.IsFlagged())
+		require.True(t, result.shouldBlock)
+	})
+
+	t.Run("low prompt token count and bad phrase", func(t *testing.T) {
+		cfgWithBadPhrase := cfgWithPreamble // copy so we don't mutate the shared config
+		cfgWithBadPhrase.BlockedPromptPatterns = []string{"bad phrase"}
+		longPrompt := strings.Repeat("word ", 5)
+		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
+			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble + " " +
longPrompt + "bad phrase"}}}, + }} + result, err := isFlaggedAnthropicMessagesRequest(tk, r, *cfgWithBadPhrase) + require.NoError(t, err) + // for now, we should not flag requests purely because of bad phrases + require.False(t, result.IsFlagged()) + }) + + t.Run("high prompt token count (above block limit)", func(t *testing.T) { + tokenLengths, err := tk.Tokenize(validPreamble) + require.NoError(t, err) + + validPreambleTokens := len(tokenLengths) + longPrompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1) + r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{ + {Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}}, + {Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: longPrompt}}}, + }} + + result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble) + require.NoError(t, err) + require.True(t, result.IsFlagged()) + require.False(t, result.shouldBlock) + require.Contains(t, result.reasons, "high_prompt_token_count") + require.Equal(t, result.promptTokenCount, validPreambleTokens+4+cfg.PromptTokenFlaggingLimit+4, cfg) + }) + + t.Run("high prompt token count (below block limit)", func(t *testing.T) { + tokenLengths, err := tk.Tokenize(validPreamble) + require.NoError(t, err) + + validPreambleTokens := len(tokenLengths) + longPrompt := strings.Repeat("word ", cfg.PromptTokenBlockingLimit+1) + r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{ + {Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}}, + {Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: longPrompt}}}, + }} + + result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble) + require.NoError(t, err) + require.True(t, result.IsFlagged()) + require.True(t, result.shouldBlock) + require.Contains(t, result.reasons, "high_prompt_token_count") + require.Equal(t, result.promptTokenCount, validPreambleTokens+4+cfg.PromptTokenBlockingLimit+4) + }) +} + +func TestAnthropicMessagesRequestJSON(t *testing.T) { + _, err := tokenizer.NewAnthropicClaudeTokenizer() + require.NoError(t, err) + + r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{ + {Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: "Hello world"}}}, + }} + + b, err := json.MarshalIndent(r, "", "\t") + require.NoError(t, err) + + autogold.Expect(`{ +"messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello world" + } + ] + } +], +"model": "anthropic/claude-3-sonnet-20240229" +}`).Equal(t, string(b)) +} diff --git a/cmd/cody-gateway/internal/httpapi/completions/flagging.go b/cmd/cody-gateway/internal/httpapi/completions/flagging.go new file mode 100644 index 00000000000..f548a01bedd --- /dev/null +++ b/cmd/cody-gateway/internal/httpapi/completions/flagging.go @@ -0,0 +1,101 @@ +package completions + +import ( + "strings" + + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +type flaggingConfig struct { + // Phrases we look for in the prompt to consider it valid. + // Each phrase is lower case. + AllowedPromptPatterns []string + // Phrases we look for in a flagged request to consider blocking the response. + // Each phrase is lower case. Can be empty (to disable blocking). 
+	BlockedPromptPatterns []string
+	// Token limits above which a request is flagged, and the (higher)
+	// limits above which a flagged request is also blocked.
+	PromptTokenFlaggingLimit       int
+	PromptTokenBlockingLimit       int
+	MaxTokensToSampleFlaggingLimit int
+	ResponseTokenBlockingLimit     int
+}
+type flaggingRequest struct {
+	FlattenedPrompt string
+	MaxTokens       int
+}
+type flaggingResult struct {
+	shouldBlock       bool
+	blockedPhrase     *string
+	reasons           []string
+	promptPrefix      string
+	maxTokensToSample int
+	promptTokenCount  int
+}
+
+func isFlaggedRequest(tk *tokenizer.Tokenizer, r flaggingRequest, cfg flaggingConfig) (*flaggingResult, error) {
+	var reasons []string
+
+	prompt := strings.ToLower(r.FlattenedPrompt)
+
+	if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
+		reasons = append(reasons, "unknown_prompt")
+	}
+
+	// If this request has a very high token count for responses, then flag it.
+	if r.MaxTokens > cfg.MaxTokensToSampleFlaggingLimit {
+		reasons = append(reasons, "high_max_tokens_to_sample")
+	}
+
+	// If this prompt consists of a very large number of tokens, then flag it.
+	tokens, err := tk.Tokenize(r.FlattenedPrompt)
+	if err != nil {
+		return &flaggingResult{}, errors.Wrap(err, "tokenize prompt")
+	}
+	tokenCount := len(tokens)
+
+	if tokenCount > cfg.PromptTokenFlaggingLimit {
+		reasons = append(reasons, "high_prompt_token_count")
+	}
+
+	if len(reasons) > 0 { // request is flagged
+		blocked := false
+		hasBlockedPhrase, phrase := containsAny(prompt, cfg.BlockedPromptPatterns)
+		if tokenCount > cfg.PromptTokenBlockingLimit || r.MaxTokens > cfg.ResponseTokenBlockingLimit || hasBlockedPhrase {
+			blocked = true
+		}
+
+		promptPrefix := r.FlattenedPrompt
+		if len(promptPrefix) > logPromptPrefixLength {
+			promptPrefix = promptPrefix[0:logPromptPrefixLength]
+		}
+		res := &flaggingResult{
+			reasons:           reasons,
+			maxTokensToSample: r.MaxTokens,
+			promptPrefix:      promptPrefix,
+			promptTokenCount:  tokenCount,
+			shouldBlock:       blocked,
+		}
+		if hasBlockedPhrase {
+			res.blockedPhrase = &phrase
+		}
+		return res, nil
+	}
+
+	return nil, nil
+}
+
+func (f *flaggingResult) IsFlagged() bool {
+	return f != nil
+}
+
+func containsAny(prompt string, patterns []string) (bool, string) {
+	prompt = strings.ToLower(prompt)
+	for _, pattern := range patterns {
+		if strings.Contains(prompt, pattern) {
+			return true, pattern
+		}
+	}
+	return false, ""
+}
diff --git a/cmd/cody-gateway/internal/httpapi/completions/upstream.go b/cmd/cody-gateway/internal/httpapi/completions/upstream.go
index d358e134bfa..4420b45528a 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/upstream.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/upstream.go
@@ -130,8 +130,10 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
 	// Convert allowedModels to the Cody Gateway configuration format with the
 	// provider as a prefix. This aligns with the models returned when we query
 	// for rate limits from actor sources.
-	for i := range allowedModels {
-		allowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, allowedModels[i])
+	clonedAllowedModels := make([]string, len(allowedModels))
+	copy(clonedAllowedModels, allowedModels)
+	for i := range clonedAllowedModels {
+		clonedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, clonedAllowedModels[i])
 	}
 
 	// turn off sanitization for profanity detection
@@ -284,7 +286,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
 		// the prefix yet when extracted - we need to add it back here. 
This // full gatewayModel is also used in events tracking. gatewayModel := fmt.Sprintf("%s/%s", upstreamName, model) - if allowed := intersection(allowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) { + if allowed := intersection(clonedAllowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) { response.JSONError(logger, w, http.StatusBadRequest, errors.Newf("model %q is not allowed, allowed: [%s]", gatewayModel, strings.Join(allowed, ", "))) @@ -508,16 +510,3 @@ func intersection(a, b []string) (c []string) { } return c } - -type flaggingResult struct { - shouldBlock bool - blockedPhrase *string - reasons []string - promptPrefix string - maxTokensToSample int - promptTokenCount int -} - -func (f *flaggingResult) IsFlagged() bool { - return f != nil -} diff --git a/cmd/cody-gateway/internal/httpapi/handler.go b/cmd/cody-gateway/internal/httpapi/handler.go index b3c3631c08c..e0176847e55 100644 --- a/cmd/cody-gateway/internal/httpapi/handler.go +++ b/cmd/cody-gateway/internal/httpapi/handler.go @@ -7,12 +7,13 @@ import ( "github.com/Khan/genqlient/graphql" "github.com/gorilla/mux" "github.com/sourcegraph/log" - "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead" + "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config" "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth" @@ -83,7 +84,6 @@ func NewHandler( httpClient, config.Anthropic, promptRecorder, - config.AutoFlushStreamingResponses, ) if err != nil { @@ -106,6 +106,37 @@ func NewHandler( otelhttp.WithPublicEndpoint(), ), )) + + anthropicMessagesHandler, err := completions.NewAnthropicMessagesHandler( + logger, + eventLogger, + rs, + config.RateLimitNotifier, + httpClient, + config.Anthropic, + promptRecorder, + config.AutoFlushStreamingResponses, + ) + if err != nil { + return nil, errors.Wrap(err, "init anthropicMessages handler") + } + + v1router.Path("/completions/anthropic-messages").Methods(http.MethodPost).Handler( + overhead.HTTPMiddleware(latencyHistogram, + instrumentation.HTTPMiddleware("v1.completions.anthropicmessages", + gaugeHandler( + counter, + attributesAnthropicCompletions, + authr.Middleware( + requestlogger.Middleware( + logger, + anthropicMessagesHandler, + ), + ), + ), + otelhttp.WithPublicEndpoint(), + ), + )) } else { logger.Error("Anthropic access token not set") } diff --git a/cmd/cody-gateway/shared/config/config.go b/cmd/cody-gateway/shared/config/config.go index 3a3939d0ba7..cf92f4bac7a 100644 --- a/cmd/cody-gateway/shared/config/config.go +++ b/cmd/cody-gateway/shared/config/config.go @@ -141,6 +141,8 @@ func (c *Config) Load() { "claude-instant-v1.2", "claude-instant-1.2", "claude-instant-1.2-cyan", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", }, ","), "Anthropic models that can be used.")) if c.Anthropic.AccessToken != "" && len(c.Anthropic.AllowedModels) == 0 { diff --git a/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go b/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go index 099a3a687ac..b61fbaf72d9 100644 --- a/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go +++ b/cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go @@ -348,6 +348,7 @@ func 
allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
 	if !isProUser {
 		return []string{
+			// Remove after the Claude 3 rollout is complete
 			"anthropic/claude-2.0",
 			"anthropic/claude-instant-v1",
 			"anthropic/claude-instant-1.2",
@@ -356,6 +357,13 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
 		}
 
 		return []string{
+			"anthropic/claude-3-sonnet-20240229",
+			"anthropic/claude-3-opus-20240229",
+			"fireworks/" + fireworks.Mixtral8x7bInstruct,
+			"openai/gpt-3.5-turbo",
+			"openai/gpt-4-1106-preview",
+
+			// Remove after the Claude 3 rollout is complete
 			"anthropic/claude-2",
 			"anthropic/claude-2.0",
 			"anthropic/claude-2.1",
@@ -363,9 +371,6 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
 			"anthropic/claude-instant-1.2",
 			"anthropic/claude-instant-v1",
 			"anthropic/claude-instant-1",
-			"openai/gpt-3.5-turbo",
-			"openai/gpt-4-1106-preview",
-			"fireworks/" + fireworks.Mixtral8x7bInstruct,
 		}
 	case types.CompletionsFeatureCode:
 		return []string{
diff --git a/internal/completions/client/anthropicmessages/BUILD.bazel b/internal/completions/client/anthropicmessages/BUILD.bazel
new file mode 100644
index 00000000000..5dc470a63af
--- /dev/null
+++ b/internal/completions/client/anthropicmessages/BUILD.bazel
@@ -0,0 +1,19 @@
+load("//dev:go_defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+    name = "anthropicmessages",
+    srcs = ["decoder.go"],
+    importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages",
+    visibility = ["//:__subpackages__"],
+    deps = ["//lib/errors"],
+)
+
+go_test(
+    name = "anthropicmessages_test",
+    timeout = "short",
+    srcs = ["decoder_test.go"],
+    data = glob(["testdata/**"]),
+    embed = [":anthropicmessages"],
+    deps = ["@com_github_stretchr_testify//require"],
+)
diff --git a/internal/completions/client/anthropicmessages/decoder.go b/internal/completions/client/anthropicmessages/decoder.go
new file mode 100644
index 00000000000..5c5b40c832e
--- /dev/null
+++ b/internal/completions/client/anthropicmessages/decoder.go
@@ -0,0 +1,106 @@
+package anthropicmessages
+
+import (
+	"bufio"
+	"bytes"
+	"io"
+
+	"github.com/sourcegraph/sourcegraph/lib/errors"
+)
+
+const maxPayloadSize = 10 * 1024 * 1024 // 10 MB
+
+var doneBytes = []byte("[DONE]")
+
+// decoder decodes streaming events from a Server-Sent Events stream. It only supports
+// streams generated by the Anthropic completions API, i.e. this is not a fully
+// compliant Server-Sent Events decoder.
+//
+// Adapted from internal/search/streaming/http/decoder.go.
+type decoder struct {
+	scanner *bufio.Scanner
+	done    bool
+	data    []byte
+	err     error
+}
+
+func NewDecoder(r io.Reader) *decoder {
+	scanner := bufio.NewScanner(r)
+	scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
+	// bufio.ScanLines, except we look for \n\n which separates events.
+	split := func(data []byte, atEOF bool) (int, []byte, error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
+			return i + 2, data[:i], nil
+		}
+		// If we're at EOF, we have a final, non-terminated event. This should
+		// be empty.
+		if atEOF {
+			return len(data), data, nil
+		}
+		// Request more data.
+		return 0, nil, nil
+	}
+	scanner.Split(split)
+	return &decoder{
+		scanner: scanner,
+	}
+}
+
+// Scan advances the decoder to the next event in the stream. It returns
+// false when it reaches the end of the stream or encounters an error.
+func (d *decoder) Scan() bool {
+	if d.done {
+		return false
+	}
+	for d.scanner.Scan() {
+		// event: $_name
+		// data: json($data)|[DONE]
+
+		lines := bytes.Split(d.scanner.Bytes(), []byte("\n"))
+		for _, line := range lines {
+			typ, data := splitColon(line)
+
+			switch {
+			case bytes.Equal(typ, []byte("data")):
+				d.data = data
+				// Check for special sentinel value used by the Anthropic API to
+				// indicate that the stream is done.
+				if bytes.Equal(data, doneBytes) {
+					d.done = true
+					return false
+				}
+				return true
+			case bytes.Equal(typ, []byte("event")):
+				// Anthropic sends the event name in the data payload as well, so we ignore it for now
+				continue
+			default:
+				d.err = errors.Errorf("malformed data, expected data: %s %q", typ, line)
+				return false
+			}
+		}
+	}
+
+	d.err = d.scanner.Err()
+	return false
+}
+
+// Data returns the event data of the last decoded event
+func (d *decoder) Data() []byte {
+	return d.data
+}
+
+// Err returns the last encountered error
+func (d *decoder) Err() error {
+	return d.err
+}
+
+func splitColon(data []byte) ([]byte, []byte) {
+	i := bytes.Index(data, []byte(":"))
+	if i < 0 {
+		return bytes.TrimSpace(data), nil
+	}
+	return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
+}
diff --git a/internal/completions/client/anthropicmessages/decoder_test.go b/internal/completions/client/anthropicmessages/decoder_test.go
new file mode 100644
index 00000000000..1e6be96f1f3
--- /dev/null
+++ b/internal/completions/client/anthropicmessages/decoder_test.go
@@ -0,0 +1,50 @@
+package anthropicmessages
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestDecoder(t *testing.T) {
+	t.Parallel()
+
+	type event struct {
+		data string
+	}
+
+	decodeAll := func(input string) ([]event, error) {
+		dec := NewDecoder(strings.NewReader(input))
+		var events []event
+		for dec.Scan() {
+			events = append(events, event{
+				data: string(dec.Data()),
+			})
+		}
+		return events, dec.Err()
+	}
+
+	t.Run("Single", func(t *testing.T) {
+		events, err := decodeAll("event:foo\ndata:b\n\n")
+		require.NoError(t, err)
+		require.Equal(t, events, []event{{data: "b"}})
+	})
+
+	t.Run("Multiple", func(t *testing.T) {
+		events, err := decodeAll("event:foo\ndata:b\n\nevent:foo\ndata:c\n\n")
+		require.NoError(t, err)
+		require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
+	})
+
+	t.Run("ErrExpectedData", func(t *testing.T) {
+		_, err := decodeAll("datas:b\n\n")
+		require.Contains(t, err.Error(), "malformed data, expected data")
+	})
+
+	t.Run("InterleavedPing", func(t *testing.T) {
+		events, err := decodeAll("data:a\n\nevent: ping\ndata: pong\n\ndata:b\n\ndata: [DONE]\n\n")
+		require.NoError(t, err)
+		require.Equal(t, events, []event{{data: "a"}, {data: "pong"}, {data: "b"}})
+	})
+}
diff --git a/internal/completions/httpapi/chat.go b/internal/completions/httpapi/chat.go
index 5b42c16034e..a95b1c17ddc 100644
--- a/internal/completions/httpapi/chat.go
+++ b/internal/completions/httpapi/chat.go
@@ -74,21 +74,26 @@ func isAllowedCustomChatModel(model string, isProUser bool) bool {
 	// When updating these two lists, make sure you also update `allowedModels` in codygateway_dotcom_user.go.
if isProUser { switch model { - case "anthropic/claude-2", + case "anthropic/claude-3-sonnet-20240229", + "anthropic/claude-3-opus-20240229", + "fireworks/" + fireworks.Mixtral8x7bInstruct, + "openai/gpt-3.5-turbo", + "openai/gpt-4-1106-preview", + + // Remove after the Claude 3 rollout is complete + "anthropic/claude-2", "anthropic/claude-2.0", "anthropic/claude-2.1", "anthropic/claude-instant-1.2-cyan", "anthropic/claude-instant-1.2", "anthropic/claude-instant-v1", - "anthropic/claude-instant-1", - "openai/gpt-3.5-turbo", - "openai/gpt-4-1106-preview", - "fireworks/" + fireworks.Mixtral8x7bInstruct: + "anthropic/claude-instant-1": return true } } else { switch model { - case "anthropic/claude-2", + case // Remove after the Claude 3 rollout is complete + "anthropic/claude-2", "anthropic/claude-2.0", "anthropic/claude-instant-v1", "anthropic/claude-instant-1":
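
For reviewers, a minimal, self-contained Go sketch of the request rewriting that transformBody above performs before a request is proxied to the upstream /v1/messages endpoint. The types here are local stand-ins mirroring anthropicMessagesRequest from this patch, and the actor ID ("actor-123") is a made-up placeholder:

package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-ins for the request types in this patch (illustration only).
type content struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

type message struct {
	Role    string    `json:"role"`
	Content []content `json:"content"`
}

type metadata struct {
	UserID string `json:"user_id,omitempty"`
}

type request struct {
	Messages  []message `json:"messages,omitempty"`
	Model     string    `json:"model"`
	MaxTokens int32     `json:"max_tokens,omitempty"`
	Metadata  *metadata `json:"metadata,omitempty"`
	System    string    `json:"system,omitempty"`
}

func main() {
	// What a client sends to Cody Gateway: a provider-prefixed model name
	// and a leading "system" message.
	in := request{
		Model: "anthropic/claude-3-sonnet-20240229",
		Messages: []message{
			{Role: "system", Content: []content{{Type: "text", Text: "You are Cody."}}},
			{Role: "user", Content: []content{{Type: "text", Text: "Hello"}}},
		},
		MaxTokens: 256,
	}

	// What the gateway forwards upstream, mirroring transformBody: the
	// "anthropic/" prefix is stripped, metadata.user_id is overwritten with
	// the actor identifier, and a leading system message is hoisted into the
	// top-level system field.
	out := in
	out.Model = "claude-3-sonnet-20240229"
	out.Metadata = &metadata{UserID: "actor-123"} // placeholder actor ID
	if len(out.Messages) > 0 && out.Messages[0].Role == "system" {
		out.System = out.Messages[0].Content[0].Text
		out.Messages = out.Messages[1:]
	}

	b, _ := json.MarshalIndent(out, "", "  ")
	fmt.Println(string(b))
}

Running this prints the upstream payload with system set to "You are Cody.", a single user message, and the unprefixed model name - the same shape the new handler sends to api.anthropic.com.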