Add support for Claude 3 to Cody Gateway (#60830)

Co-authored-by: Chris Warwick <christopher.warwick@sourcegraph.com>
Philipp Spiess 2024-03-06 18:13:03 +01:00 committed by GitHub
parent 9e6e4e1bc9
commit 5da08a3bc8
14 changed files with 821 additions and 86 deletions

View File

@ -5,7 +5,9 @@ go_library(
name = "completions",
srcs = [
"anthropic.go",
"anthropicmessages.go",
"fireworks.go",
"flagging.go",
"openai.go",
"upstream.go",
],
@ -23,6 +25,7 @@ go_library(
"//cmd/cody-gateway/shared/config",
"//internal/codygateway",
"//internal/completions/client/anthropic",
"//internal/completions/client/anthropicmessages",
"//internal/completions/client/fireworks",
"//internal/completions/client/openai",
"//internal/conf/conftypes",
@ -42,6 +45,7 @@ go_test(
name = "completions_test",
srcs = [
"anthropic_test.go",
"anthropicmessages_test.go",
"fireworks_test.go",
"openai_test.go",
],

View File

@ -6,9 +6,9 @@ import (
"encoding/json"
"io"
"net/http"
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
@ -248,61 +248,19 @@ func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, cfg
if ar.Model != "claude-2" && ar.Model != "claude-2.0" && ar.Model != "claude-2.1" && ar.Model != "claude-v1" {
return nil, nil
}
var reasons []string
prompt := strings.ToLower(ar.Prompt)
if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
reasons = append(reasons, "unknown_prompt")
}
// If this request has a very high token count for responses, then flag it.
if ar.MaxTokensToSample > int32(cfg.MaxTokensToSampleFlaggingLimit) {
reasons = append(reasons, "high_max_tokens_to_sample")
}
// If this prompt consists of a very large number of tokens, then flag it.
tokenCount, err := ar.GetPromptTokenCount(tk)
if err != nil {
return &flaggingResult{}, errors.Wrap(err, "tokenize prompt")
}
if tokenCount > cfg.PromptTokenFlaggingLimit {
reasons = append(reasons, "high_prompt_token_count")
}
if len(reasons) > 0 { // request is flagged
blocked := false
hasBlockedPhrase, phrase := containsAny(prompt, cfg.BlockedPromptPatterns)
if tokenCount > cfg.PromptTokenBlockingLimit || ar.MaxTokensToSample > int32(cfg.ResponseTokenBlockingLimit) || hasBlockedPhrase {
blocked = true
}
promptPrefix := ar.Prompt
if len(promptPrefix) > logPromptPrefixLength {
promptPrefix = promptPrefix[0:logPromptPrefixLength]
}
res := &flaggingResult{
reasons: reasons,
maxTokensToSample: int(ar.MaxTokensToSample),
promptPrefix: promptPrefix,
promptTokenCount: tokenCount,
shouldBlock: blocked,
}
if hasBlockedPhrase {
res.blockedPhrase = &phrase
}
return res, nil
}
return nil, nil
}
func containsAny(prompt string, patterns []string) (bool, string) {
prompt = strings.ToLower(prompt)
for _, pattern := range patterns {
if strings.Contains(prompt, pattern) {
return true, pattern
}
}
return false, ""
return isFlaggedRequest(tk,
flaggingRequest{
FlattenedPrompt: ar.Prompt,
MaxTokens: int(ar.MaxTokensToSample),
},
flaggingConfig{
AllowedPromptPatterns: cfg.AllowedPromptPatterns,
BlockedPromptPatterns: cfg.BlockedPromptPatterns,
PromptTokenFlaggingLimit: cfg.PromptTokenFlaggingLimit,
PromptTokenBlockingLimit: cfg.PromptTokenBlockingLimit,
MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
ResponseTokenBlockingLimit: cfg.ResponseTokenBlockingLimit,
},
)
}

View File

@ -6,10 +6,11 @@ import (
"testing"
"github.com/hexops/autogold/v2"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
)
@ -58,7 +59,7 @@ func TestIsFlaggedAnthropicRequest(t *testing.T) {
require.Equal(t, int32(result.maxTokensToSample), ar.MaxTokensToSample)
})
t.Run("high prompt token count (below block limit)", func(t *testing.T) {
t.Run("high prompt token count (above block limit)", func(t *testing.T) {
tokenLengths, err := tk.Tokenize(validPreamble)
require.NoError(t, err)

View File

@ -0,0 +1,298 @@
package completions
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
)
const anthropicMessagesAPIURL = "https://api.anthropic.com/v1/messages"
// This implements the newer `/messages` API by Anthropic
// https://docs.anthropic.com/claude/reference/messages_post
func NewAnthropicMessagesHandler(
baseLogger log.Logger,
eventLogger events.Logger,
rs limiter.RedisStore,
rateLimitNotifier notify.RateLimitNotifier,
httpClient httpcli.Doer,
config config.AnthropicConfig,
promptRecorder PromptRecorder,
autoFlushStreamingResponses bool,
) (http.Handler, error) {
// Tokenizer only needs to be initialized once, and can be shared globally.
tokenizer, err := tokenizer.NewAnthropicClaudeTokenizer()
if err != nil {
return nil, err
}
return makeUpstreamHandler[anthropicMessagesRequest](
baseLogger,
eventLogger,
rs,
rateLimitNotifier,
httpClient,
string(conftypes.CompletionsProviderNameAnthropic),
func(_ codygateway.Feature) string { return anthropicMessagesAPIURL },
config.AllowedModels,
&AnthropicMessagesHandlerMethods{config: config, tokenizer: tokenizer, promptRecorder: promptRecorder},
// Anthropic primarily rate-limits spikes via concurrent-request limits,
// so set a default retry-after that Sourcegraph clients are likely to
// accept (the default SRC_HTTP_CLI_EXTERNAL_RETRY_AFTER_MAX_DURATION),
// since retrying may get past the concurrency limit without raising an
// error to the user.
2, // seconds
autoFlushStreamingResponses,
nil,
), nil
}
// anthropicMessagesRequest captures all known fields from https://console.anthropic.com/docs/api/reference.
type anthropicMessagesRequest struct {
Messages []anthropicMessage `json:"messages,omitempty"`
Model string `json:"model"`
MaxTokens int32 `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
TopK int32 `json:"top_k,omitempty"`
Stream bool `json:"stream,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
// These are not accepted from the client and instead are only used to talk
// to the upstream LLM APIs.
Metadata *anthropicMessagesRequestMetadata `json:"metadata,omitempty"`
System string `json:"system,omitempty"`
}
type anthropicMessage struct {
Role string `json:"role"` // "user", "assistant", or "system" (only allowed for the first message)
Content []anthropicMessageContent `json:"content"`
}
type anthropicMessageContent struct {
Type string `json:"type"` // "text" or "image" (not yet supported)
Text string `json:"text"`
}
type anthropicMessagesRequestMetadata struct {
UserID string `json:"user_id,omitempty"`
}
func (ar anthropicMessagesRequest) ShouldStream() bool {
return ar.Stream
}
func (ar anthropicMessagesRequest) GetModel() string {
return ar.Model
}
// Note: This is not the actual prompt sent to Anthropic, but it is a close
// enough approximation for counting tokens.
func (r anthropicMessagesRequest) BuildPrompt() string {
var sb strings.Builder
for _, m := range r.Messages {
switch m.Role {
case "user":
sb.WriteString("Human: ")
case "assistant":
sb.WriteString("Assistant: ")
case "system":
sb.WriteString("System: ")
default:
return ""
}
for _, c := range m.Content {
if c.Type == "text" {
sb.WriteString(c.Text)
}
}
sb.WriteString("\n\n")
}
return sb.String()
}
// anthropicMessagesNonStreamingResponse captures all relevant-to-us fields from https://docs.anthropic.com/claude/reference/messages_post.
type anthropicMessagesNonStreamingResponse struct {
Content []anthropicMessageContent `json:"content"`
Usage anthropicMessagesResponseUsage `json:"usage"`
StopReason string `json:"stop_reason"`
}
// anthropicMessagesStreamingResponse captures all relevant-to-us fields from each relevant SSE event from https://docs.anthropic.com/claude/reference/messages_post.
type anthropicMessagesStreamingResponse struct {
Type string `json:"type"`
Delta *anthropicMessagesStreamingResponseTextBucket `json:"delta"`
ContentBlock *anthropicMessagesStreamingResponseTextBucket `json:"content_block"`
Usage *anthropicMessagesResponseUsage `json:"usage"`
Message *anthropicStreamingResponseMessage `json:"message"`
}
type anthropicStreamingResponseMessage struct {
Usage *anthropicMessagesResponseUsage `json:"usage"`
}
type anthropicMessagesStreamingResponseTextBucket struct {
Text string `json:"text"`
}
type anthropicMessagesResponseUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type AnthropicMessagesHandlerMethods struct {
tokenizer *tokenizer.Tokenizer
promptRecorder PromptRecorder
config config.AnthropicConfig
}
func (a *AnthropicMessagesHandlerMethods) validateRequest(ctx context.Context, logger log.Logger, _ codygateway.Feature, ar anthropicMessagesRequest) (int, *flaggingResult, error) {
if ar.MaxTokens > int32(a.config.MaxTokensToSample) {
return http.StatusBadRequest, nil, errors.Errorf("max_tokens exceeds maximum allowed value of %d: %d", a.config.MaxTokensToSample, ar.MaxTokens)
}
if result, err := isFlaggedAnthropicMessagesRequest(a.tokenizer, ar, a.config); err != nil {
logger.Error("error checking AnthropicMessages request - treating as non-flagged",
log.Error(err))
} else if result.IsFlagged() {
// Record flagged prompts even in the hot path: flagged requests usually take a long time upstream anyway, so this does not make things meaningfully worse.
if err := a.promptRecorder.Record(ctx, ar.BuildPrompt()); err != nil {
logger.Warn("failed to record flagged prompt", log.Error(err))
}
if a.config.RequestBlockingEnabled && result.shouldBlock {
return http.StatusBadRequest, result, errors.Errorf("request blocked - if you think this is a mistake, please contact support@sourcegraph.com")
}
return 0, result, nil
}
return 0, nil, nil
}
func (a *AnthropicMessagesHandlerMethods) transformBody(body *anthropicMessagesRequest, identifier string) {
// Overwrite the metadata field; we don't want to allow users to specify it:
body.Metadata = &anthropicMessagesRequestMetadata{
// We forward the actor ID to support tracking.
UserID: identifier,
}
// Remove the `anthropic/` prefix from the model string
body.Model = strings.TrimPrefix(body.Model, "anthropic/")
// If the first message uses the `system` role, convert it into the top-level system prompt
body.System = "" // prevent the upstream API from setting this
if len(body.Messages) > 0 && body.Messages[0].Role == "system" {
body.System = body.Messages[0].Content[0].Text
body.Messages = body.Messages[1:]
}
}
func (a *AnthropicMessagesHandlerMethods) getRequestMetadata(body anthropicMessagesRequest) (model string, additionalMetadata map[string]any) {
return body.Model, map[string]any{
"stream": body.Stream,
"max_tokens": body.MaxTokens,
}
}
func (a *AnthropicMessagesHandlerMethods) transformRequest(r *http.Request) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("X-API-Key", a.config.AccessToken)
r.Header.Set("anthropic-version", "2023-06-01")
}
func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logger, body anthropicMessagesRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
// First, extract prompt usage details from the request.
for _, m := range body.Messages {
for _, c := range m.Content {
promptUsage.characters += len(c.Text)
}
}
promptUsage.characters += len(body.System)
// Try to parse the request we saw, if it was non-streaming, we can simply parse
// it as JSON.
if !body.ShouldStream() {
var res anthropicMessagesNonStreamingResponse
if err := json.NewDecoder(r).Decode(&res); err != nil {
logger.Error("failed to parse Anthropic response as JSON", log.Error(err))
return promptUsage, completionUsage
}
// Extract character data from response by summing up all text
for _, c := range res.Content {
completionUsage.characters += len(c.Text)
}
// Extract token usage data from the response
completionUsage.tokens = res.Usage.OutputTokens
promptUsage.tokens = res.Usage.InputTokens
return promptUsage, completionUsage
}
// Otherwise, we have to parse the event stream from anthropic.
dec := anthropicmessages.NewDecoder(r)
for dec.Scan() {
data := dec.Data()
// Gracefully skip over any data that isn't JSON-like. Anthropic's API sometimes sends
// non-documented data over the stream, like timestamps.
if !bytes.HasPrefix(data, []byte("{")) {
continue
}
var event anthropicMessagesStreamingResponse
if err := json.Unmarshal(data, &event); err != nil {
logger.Error("failed to decode event payload", log.Error(err), log.String("body", string(data)))
continue
}
switch event.Type {
case "message_start":
if event.Message != nil && event.Message.Usage != nil {
promptUsage.tokens = event.Message.Usage.InputTokens
}
case "content_block_delta":
if event.Delta != nil {
completionUsage.characters += len(event.Delta.Text)
}
case "message_delta":
if event.Usage != nil {
completionUsage.tokens = event.Usage.OutputTokens
}
}
}
if err := dec.Err(); err != nil {
logger.Error("failed to decode Anthropic streaming response", log.Error(err))
}
return promptUsage, completionUsage
}
func isFlaggedAnthropicMessagesRequest(tk *tokenizer.Tokenizer, r anthropicMessagesRequest, cfg config.AnthropicConfig) (*flaggingResult, error) {
return isFlaggedRequest(tk,
flaggingRequest{
FlattenedPrompt: r.BuildPrompt(),
MaxTokens: int(r.MaxTokens),
},
flaggingConfig{
AllowedPromptPatterns: cfg.AllowedPromptPatterns,
BlockedPromptPatterns: cfg.BlockedPromptPatterns,
PromptTokenFlaggingLimit: cfg.PromptTokenFlaggingLimit,
PromptTokenBlockingLimit: cfg.PromptTokenBlockingLimit,
MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
ResponseTokenBlockingLimit: cfg.ResponseTokenBlockingLimit,
},
)
}
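
Reviewer note: below is a minimal, hypothetical sketch (not part of this commit) of what BuildPrompt produces. Each message is flattened into a System/Human/Assistant block separated by blank lines, and that flattened string is used only for token counting and flagging, never sent upstream. The test name is illustrative.

package completions

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestBuildPromptFlattening_Example is a hypothetical test, not part of this
// commit. It shows the flattened prompt used only for token counting.
func TestBuildPromptFlattening_Example(t *testing.T) {
	r := anthropicMessagesRequest{
		Model: "anthropic/claude-3-sonnet-20240229",
		Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "You are cody-gateway."}}},
			{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: "Hello"}}},
			{Role: "assistant", Content: []anthropicMessageContent{{Type: "text", Text: "Hi there!"}}},
		},
	}
	require.Equal(t,
		"System: You are cody-gateway.\n\nHuman: Hello\n\nAssistant: Hi there!\n\n",
		r.BuildPrompt())
}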

View File

@ -0,0 +1,166 @@
package completions
import (
"encoding/json"
"strings"
"testing"
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
)
func TestIsFlaggedAnthropicMessagesRequest(t *testing.T) {
validPreamble := "You are cody-gateway."
cfg := config.AnthropicConfig{
PromptTokenFlaggingLimit: 18000,
PromptTokenBlockingLimit: 20000,
MaxTokensToSampleFlaggingLimit: 1000,
ResponseTokenBlockingLimit: 1000,
}
cfgWithPreamble := config.AnthropicConfig{
PromptTokenFlaggingLimit: 18000,
PromptTokenBlockingLimit: 20000,
MaxTokensToSampleFlaggingLimit: 1000,
ResponseTokenBlockingLimit: 1000,
AllowedPromptPatterns: []string{strings.ToLower(validPreamble)},
}
tk, err := tokenizer.NewAnthropicClaudeTokenizer()
require.NoError(t, err)
t.Run("works for known preamble", func(t *testing.T) {
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
require.NoError(t, err)
require.Nil(t, result)
})
t.Run("missing known preamble", func(t *testing.T) {
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "some prompt without known preamble"}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
require.NoError(t, err)
require.True(t, result.IsFlagged())
require.False(t, result.shouldBlock)
require.Contains(t, result.reasons, "unknown_prompt")
})
t.Run("preamble not configured ", func(t *testing.T) {
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "some prompt without known preamble"}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfg)
require.NoError(t, err)
require.False(t, result.IsFlagged())
})
t.Run("high max tokens to sample", func(t *testing.T) {
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", MaxTokens: 10000, Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfg)
require.NoError(t, err)
require.True(t, result.IsFlagged())
require.True(t, result.shouldBlock)
require.Contains(t, result.reasons, "high_max_tokens_to_sample")
require.Equal(t, int32(result.maxTokensToSample), r.MaxTokens)
})
t.Run("high prompt token count and bad phrase", func(t *testing.T) {
cfgWithBadPhrase := cfgWithPreamble // copy so the shared cfgWithPreamble is not mutated
cfgWithBadPhrase.BlockedPromptPatterns = []string{"bad phrase"}
longPrompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1)
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble + " " + longPrompt + "bad phrase"}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithBadPhrase)
require.NoError(t, err)
require.True(t, result.IsFlagged())
require.True(t, result.shouldBlock)
})
t.Run("low prompt token count and bad phrase", func(t *testing.T) {
cfgWithBadPhrase := cfgWithPreamble // copy so the shared cfgWithPreamble is not mutated
cfgWithBadPhrase.BlockedPromptPatterns = []string{"bad phrase"}
longPrompt := strings.Repeat("word ", 5)
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble + " " + longPrompt + "bad phrase"}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithBadPhrase)
require.NoError(t, err)
// for now, we should not flag requests purely because of bad phrases
require.False(t, result.IsFlagged())
})
t.Run("high prompt token count (above block limit)", func(t *testing.T) {
tokenLengths, err := tk.Tokenize(validPreamble)
require.NoError(t, err)
validPreambleTokens := len(tokenLengths)
longPrompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1)
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: longPrompt}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
require.NoError(t, err)
require.True(t, result.IsFlagged())
require.False(t, result.shouldBlock)
require.Contains(t, result.reasons, "high_prompt_token_count")
require.Equal(t, result.promptTokenCount, validPreambleTokens+4+cfg.PromptTokenFlaggingLimit+4)
})
t.Run("high prompt token count (below block limit)", func(t *testing.T) {
tokenLengths, err := tk.Tokenize(validPreamble)
require.NoError(t, err)
validPreambleTokens := len(tokenLengths)
longPrompt := strings.Repeat("word ", cfg.PromptTokenBlockingLimit+1)
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: longPrompt}}},
}}
result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
require.NoError(t, err)
require.True(t, result.IsFlagged())
require.True(t, result.shouldBlock)
require.Contains(t, result.reasons, "high_prompt_token_count")
require.Equal(t, result.promptTokenCount, validPreambleTokens+4+cfg.PromptTokenBlockingLimit+4)
})
}
func TestAnthropicMessagesRequestJSON(t *testing.T) {
_, err := tokenizer.NewAnthropicClaudeTokenizer()
require.NoError(t, err)
r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: "Hello world"}}},
}}
b, err := json.MarshalIndent(r, "", "\t")
require.NoError(t, err)
autogold.Expect(`{
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello world"
}
]
}
],
"model": "anthropic/claude-3-sonnet-20240229"
}`).Equal(t, string(b))
}
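
A hypothetical sketch (not part of this commit) of how parseResponseAndUsage reads usage from a streaming response: message_start carries the prompt (input) token count, content_block_delta contributes completion characters, and message_delta carries the final completion (output) token count. The test name and the synthetic stream are illustrative.

package completions

import (
	"strings"
	"testing"

	"github.com/sourcegraph/log/logtest"
	"github.com/stretchr/testify/require"
)

// TestParseStreamingUsage_Example is a hypothetical test, not part of this commit.
func TestParseStreamingUsage_Example(t *testing.T) {
	// Minimal synthetic /messages event stream in the shape the decoder expects.
	stream := "event: message_start\n" +
		"data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":12}}}\n\n" +
		"event: content_block_delta\n" +
		"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n" +
		"event: message_delta\n" +
		"data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n"

	methods := &AnthropicMessagesHandlerMethods{}
	promptUsage, completionUsage := methods.parseResponseAndUsage(
		logtest.Scoped(t),
		anthropicMessagesRequest{Stream: true},
		strings.NewReader(stream),
	)

	require.Equal(t, 12, promptUsage.tokens)    // from message_start
	require.Equal(t, 3, completionUsage.tokens) // from message_delta
	require.Equal(t, len("Hello"), completionUsage.characters)
}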

View File

@ -0,0 +1,101 @@
package completions
import (
"strings"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type flaggingConfig struct {
// Phrases we look for in the prompt to consider it valid.
// Each phrase is lower case.
AllowedPromptPatterns []string
// Phrases we look for in a flagged request to consider blocking the response.
// Each phrase is lower case. Can be empty (to disable blocking).
BlockedPromptPatterns []string
// Token limits used to decide whether a request should be flagged or blocked.
PromptTokenFlaggingLimit int
PromptTokenBlockingLimit int
MaxTokensToSampleFlaggingLimit int
ResponseTokenBlockingLimit int
}
type flaggingRequest struct {
FlattenedPrompt string
MaxTokens int
}
type flaggingResult struct {
shouldBlock bool
blockedPhrase *string
reasons []string
promptPrefix string
maxTokensToSample int
promptTokenCount int
}
func isFlaggedRequest(tk *tokenizer.Tokenizer, r flaggingRequest, cfg flaggingConfig) (*flaggingResult, error) {
var reasons []string
prompt := strings.ToLower(r.FlattenedPrompt)
if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
reasons = append(reasons, "unknown_prompt")
}
// If this request has a very high token count for responses, then flag it.
if r.MaxTokens > cfg.MaxTokensToSampleFlaggingLimit {
reasons = append(reasons, "high_max_tokens_to_sample")
}
// If this prompt consists of a very large number of tokens, then flag it.
tokens, err := tk.Tokenize(r.FlattenedPrompt)
if err != nil {
return &flaggingResult{}, errors.Wrap(err, "tokenize prompt")
}
tokenCount := len(tokens)
if tokenCount > cfg.PromptTokenFlaggingLimit {
reasons = append(reasons, "high_prompt_token_count")
}
if len(reasons) > 0 { // request is flagged
blocked := false
hasBlockedPhrase, phrase := containsAny(prompt, cfg.BlockedPromptPatterns)
if tokenCount > cfg.PromptTokenBlockingLimit || r.MaxTokens > cfg.ResponseTokenBlockingLimit || hasBlockedPhrase {
blocked = true
}
promptPrefix := r.FlattenedPrompt
if len(promptPrefix) > logPromptPrefixLength {
promptPrefix = promptPrefix[0:logPromptPrefixLength]
}
res := &flaggingResult{
reasons: reasons,
maxTokensToSample: r.MaxTokens,
promptPrefix: promptPrefix,
promptTokenCount: tokenCount,
shouldBlock: blocked,
}
if hasBlockedPhrase {
res.blockedPhrase = &phrase
}
return res, nil
}
return nil, nil
}
func (f *flaggingResult) IsFlagged() bool {
return f != nil
}
func containsAny(prompt string, patterns []string) (bool, string) {
prompt = strings.ToLower(prompt)
for _, pattern := range patterns {
if strings.Contains(prompt, pattern) {
return true, pattern
}
}
return false, ""
}
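
A hypothetical sketch (not part of this commit) that exercises the extracted isFlaggedRequest helper directly, making the flag-versus-block distinction concrete: a prompt above the flagging limit but below the blocking limit is flagged without being blocked. The test name and threshold values mirror the existing tests and are illustrative.

package completions

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
)

// TestIsFlaggedRequest_Example is a hypothetical test, not part of this commit.
func TestIsFlaggedRequest_Example(t *testing.T) {
	tk, err := tokenizer.NewAnthropicClaudeTokenizer()
	require.NoError(t, err)

	cfg := flaggingConfig{
		PromptTokenFlaggingLimit:       18000,
		PromptTokenBlockingLimit:       20000,
		MaxTokensToSampleFlaggingLimit: 1000,
		ResponseTokenBlockingLimit:     1000,
	}

	// Above the flagging limit (~18k tokens) but below the blocking limit (20k).
	prompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1)
	result, err := isFlaggedRequest(tk, flaggingRequest{FlattenedPrompt: prompt, MaxTokens: 256}, cfg)
	require.NoError(t, err)
	require.True(t, result.IsFlagged())
	require.False(t, result.shouldBlock)
	require.Contains(t, result.reasons, "high_prompt_token_count")
}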

View File

@ -130,8 +130,10 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// Convert allowedModels to the Cody Gateway configuration format with the
// provider as a prefix. This aligns with the models returned when we query
// for rate limits from actor sources.
for i := range allowedModels {
allowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, allowedModels[i])
clonedAllowedModels := make([]string, len(allowedModels))
copy(clonedAllowedModels, allowedModels)
for i := range clonedAllowedModels {
clonedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, clonedAllowedModels[i])
}
// turn off sanitization for profanity detection
@ -284,7 +286,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// the prefix yet when extracted - we need to add it back here. This
// full gatewayModel is also used in events tracking.
gatewayModel := fmt.Sprintf("%s/%s", upstreamName, model)
if allowed := intersection(allowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) {
if allowed := intersection(clonedAllowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) {
response.JSONError(logger, w, http.StatusBadRequest,
errors.Newf("model %q is not allowed, allowed: [%s]",
gatewayModel, strings.Join(allowed, ", ")))
@ -508,16 +510,3 @@ func intersection(a, b []string) (c []string) {
}
return c
}
type flaggingResult struct {
shouldBlock bool
blockedPhrase *string
reasons []string
promptPrefix string
maxTokensToSample int
promptTokenCount int
}
func (f *flaggingResult) IsFlagged() bool {
return f != nil
}
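
For context on the clonedAllowedModels change above, here is a minimal sketch (not part of this commit) of the aliasing problem it avoids: both the legacy Anthropic handler and the new /messages handler are now built from the same config.Anthropic.AllowedModels slice, so prefixing that slice in place would double-prefix the models for the second handler.

package main

import "fmt"

// prefixModels mirrors what makeUpstreamHandler now does: it copies the slice
// before adding the provider prefix, leaving the caller's slice untouched.
func prefixModels(upstreamName string, allowedModels []string) []string {
	cloned := make([]string, len(allowedModels))
	copy(cloned, allowedModels)
	for i := range cloned {
		cloned[i] = fmt.Sprintf("%s/%s", upstreamName, cloned[i])
	}
	return cloned
}

func main() {
	shared := []string{"claude-2", "claude-3-sonnet-20240229"}
	fmt.Println(prefixModels("anthropic", shared)) // [anthropic/claude-2 anthropic/claude-3-sonnet-20240229]
	fmt.Println(prefixModels("anthropic", shared)) // same result: shared was never mutated
	// Mutating shared in place instead would have produced
	// "anthropic/anthropic/claude-2" for the second handler.
}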

View File

@ -7,12 +7,13 @@ import (
"github.com/Khan/genqlient/graphql"
"github.com/gorilla/mux"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth"
@ -83,7 +84,6 @@ func NewHandler(
httpClient,
config.Anthropic,
promptRecorder,
config.AutoFlushStreamingResponses,
)
if err != nil {
@ -106,6 +106,37 @@ func NewHandler(
otelhttp.WithPublicEndpoint(),
),
))
anthropicMessagesHandler, err := completions.NewAnthropicMessagesHandler(
logger,
eventLogger,
rs,
config.RateLimitNotifier,
httpClient,
config.Anthropic,
promptRecorder,
config.AutoFlushStreamingResponses,
)
if err != nil {
return nil, errors.Wrap(err, "init anthropicMessages handler")
}
v1router.Path("/completions/anthropic-messages").Methods(http.MethodPost).Handler(
overhead.HTTPMiddleware(latencyHistogram,
instrumentation.HTTPMiddleware("v1.completions.anthropicmessages",
gaugeHandler(
counter,
attributesAnthropicCompletions,
authr.Middleware(
requestlogger.Middleware(
logger,
anthropicMessagesHandler,
),
),
),
otelhttp.WithPublicEndpoint(),
),
))
} else {
logger.Error("Anthropic access token not set")
}
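
A hypothetical client sketch (not part of this commit) showing the shape of a request to the newly registered route. The host, path prefix, and authorization scheme below are assumptions; consult the gateway's auth middleware for the actual token format.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The body matches the anthropicMessagesRequest shape accepted by the
	// handler; the model keeps its "anthropic/" prefix, which the gateway
	// strips before talking to the upstream API.
	body := []byte(`{
		"model": "anthropic/claude-3-sonnet-20240229",
		"max_tokens": 256,
		"messages": [
			{"role": "system", "content": [{"type": "text", "text": "You are cody-gateway."}]},
			{"role": "user", "content": [{"type": "text", "text": "Hello"}]}
		]
	}`)

	// Hypothetical host and token; the bearer auth scheme is an assumption.
	req, err := http.NewRequest(http.MethodPost,
		"https://cody-gateway.example.com/v1/completions/anthropic-messages",
		bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer <gateway-access-token>")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out))
}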

View File

@ -141,6 +141,8 @@ func (c *Config) Load() {
"claude-instant-v1.2",
"claude-instant-1.2",
"claude-instant-1.2-cyan",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
}, ","),
"Anthropic models that can be used."))
if c.Anthropic.AccessToken != "" && len(c.Anthropic.AllowedModels) == 0 {

View File

@ -348,6 +348,7 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
if !isProUser {
return []string{
// Remove after the Claude 3 rollout is complete
"anthropic/claude-2.0",
"anthropic/claude-instant-v1",
"anthropic/claude-instant-1.2",
@ -356,6 +357,13 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
}
return []string{
"anthropic/claude-3-sonnet-20240229",
"anthropic/claude-3-opus-20240229",
"fireworks/" + fireworks.Mixtral8x7bInstruct,
"openai/gpt-3.5-turbo",
"openai/gpt-4-1106-preview",
// Remove after the Claude 3 rollout is complete
"anthropic/claude-2",
"anthropic/claude-2.0",
"anthropic/claude-2.1",
@ -363,9 +371,6 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
"anthropic/claude-instant-1.2",
"anthropic/claude-instant-v1",
"anthropic/claude-instant-1",
"openai/gpt-3.5-turbo",
"openai/gpt-4-1106-preview",
"fireworks/" + fireworks.Mixtral8x7bInstruct,
}
case types.CompletionsFeatureCode:
return []string{

View File

@ -0,0 +1,19 @@
load("//dev:go_defs.bzl", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "anthropicmessages",
srcs = ["decoder.go"],
importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages",
visibility = ["//:__subpackages__"],
deps = ["//lib/errors"],
)
go_test(
name = "anthropicmessages_test",
timeout = "short",
srcs = ["decoder_test.go"],
data = glob(["testdata/**"]),
embed = [":anthropicmessages"],
deps = ["@com_github_stretchr_testify//require"],
)

View File

@ -0,0 +1,106 @@
package anthropicmessages
import (
"bufio"
"bytes"
"io"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
const maxPayloadSize = 10 * 1024 * 1024 // 10mb
var doneBytes = []byte("[DONE]")
// decoder decodes streaming events from a Server-Sent Events stream. It only
// supports streams generated by the Anthropic completions API, i.e. it is not
// a fully compliant Server-Sent Events decoder.
//
// Adapted from internal/search/streaming/http/decoder.go.
type decoder struct {
scanner *bufio.Scanner
done bool
data []byte
err error
}
func NewDecoder(r io.Reader) *decoder {
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
// bufio.ScanLines, except we look for \n\n which separate events.
split := func(data []byte, atEOF bool) (int, []byte, error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
return i + 2, data[:i], nil
}
// If we're at EOF, we have a final, non-terminated event. This should
// be empty.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(split)
return &decoder{
scanner: scanner,
}
}
// Scan advances the decoder to the next event in the stream. It returns
// false when it either hits the end of the stream or an error.
func (d *decoder) Scan() bool {
if d.done {
return false
}
for d.scanner.Scan() {
// event: $_name
// data: json($data)|[DONE]
lines := bytes.Split(d.scanner.Bytes(), []byte("\n"))
for _, line := range lines {
typ, data := splitColon(line)
switch {
case bytes.Equal(typ, []byte("data")):
d.data = data
// Check for special sentinel value used by the Anthropic API to
// indicate that the stream is done.
if bytes.Equal(data, doneBytes) {
d.done = true
return false
}
return true
case bytes.Equal(typ, []byte("event")):
// Anthropic sends the event name in the data payload as well, so we ignore it for now
continue
default:
d.err = errors.Errorf("malformed data, expected data: %s %q", typ, line)
return false
}
}
}
d.err = d.scanner.Err()
return false
}
// Data returns the event data of the last decoded event
func (d *decoder) Data() []byte {
return d.data
}
// Err returns the last encountered error
func (d *decoder) Err() error {
return d.err
}
func splitColon(data []byte) ([]byte, []byte) {
i := bytes.Index(data, []byte(":"))
if i < 0 {
return bytes.TrimSpace(data), nil
}
return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
}

View File

@ -0,0 +1,50 @@
package anthropicmessages
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDecoder(t *testing.T) {
t.Parallel()
type event struct {
data string
}
decodeAll := func(input string) ([]event, error) {
dec := NewDecoder(strings.NewReader(input))
var events []event
for dec.Scan() {
events = append(events, event{
data: string(dec.Data()),
})
}
return events, dec.Err()
}
t.Run("Single", func(t *testing.T) {
events, err := decodeAll("event:foo\ndata:b\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "b"}})
})
t.Run("Multiple", func(t *testing.T) {
events, err := decodeAll("event:foo\ndata:b\n\nevent:foo\ndata:c\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
})
t.Run("ErrExpectedData", func(t *testing.T) {
_, err := decodeAll("datas:b\n\n")
require.Contains(t, err.Error(), "malformed data, expected data")
})
t.Run("InterleavedPing", func(t *testing.T) {
events, err := decodeAll("data:a\n\nevent: ping\ndata: pong\n\ndata:b\n\ndata: [DONE]\n\n")
require.NoError(t, err)
require.Equal(t, events, []event{{data: "a"}, {data: "pong"}, {data: "b"}})
})
}

View File

@ -74,21 +74,26 @@ func isAllowedCustomChatModel(model string, isProUser bool) bool {
// When updating these two lists, make sure you also update `allowedModels` in codygateway_dotcom_user.go.
if isProUser {
switch model {
case "anthropic/claude-2",
case "anthropic/claude-3-sonnet-20240229",
"anthropic/claude-3-opus-20240229",
"fireworks/" + fireworks.Mixtral8x7bInstruct,
"openai/gpt-3.5-turbo",
"openai/gpt-4-1106-preview",
// Remove after the Claude 3 rollout is complete
"anthropic/claude-2",
"anthropic/claude-2.0",
"anthropic/claude-2.1",
"anthropic/claude-instant-1.2-cyan",
"anthropic/claude-instant-1.2",
"anthropic/claude-instant-v1",
"anthropic/claude-instant-1",
"openai/gpt-3.5-turbo",
"openai/gpt-4-1106-preview",
"fireworks/" + fireworks.Mixtral8x7bInstruct:
"anthropic/claude-instant-1":
return true
}
} else {
switch model {
case "anthropic/claude-2",
case // Remove after the Claude 3 rollout is complete
"anthropic/claude-2",
"anthropic/claude-2.0",
"anthropic/claude-instant-v1",
"anthropic/claude-instant-1":