mirror of https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:51:55 +00:00

Add support for Claude 3 to Cody Gateway (#60830)

Co-authored-by: Chris Warwick <christopher.warwick@sourcegraph.com>

This commit is contained in:
parent 9e6e4e1bc9
commit 5da08a3bc8
@@ -5,7 +5,9 @@ go_library(
    name = "completions",
    srcs = [
        "anthropic.go",
        "anthropicmessages.go",
        "fireworks.go",
        "flagging.go",
        "openai.go",
        "upstream.go",
    ],
@@ -23,6 +25,7 @@ go_library(
        "//cmd/cody-gateway/shared/config",
        "//internal/codygateway",
        "//internal/completions/client/anthropic",
        "//internal/completions/client/anthropicmessages",
        "//internal/completions/client/fireworks",
        "//internal/completions/client/openai",
        "//internal/conf/conftypes",
@@ -42,6 +45,7 @@ go_test(
    name = "completions_test",
    srcs = [
        "anthropic_test.go",
        "anthropicmessages_test.go",
        "fireworks_test.go",
        "openai_test.go",
    ],
@@ -6,9 +6,9 @@ import (
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropic"
@@ -248,61 +248,19 @@ func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, cfg
	if ar.Model != "claude-2" && ar.Model != "claude-2.0" && ar.Model != "claude-2.1" && ar.Model != "claude-v1" {
		return nil, nil
	}
	var reasons []string

	prompt := strings.ToLower(ar.Prompt)

	if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
		reasons = append(reasons, "unknown_prompt")
	}

	// If this request has a very high token count for responses, then flag it.
	if ar.MaxTokensToSample > int32(cfg.MaxTokensToSampleFlaggingLimit) {
		reasons = append(reasons, "high_max_tokens_to_sample")
	}

	// If this prompt consists of a very large number of tokens, then flag it.
	tokenCount, err := ar.GetPromptTokenCount(tk)
	if err != nil {
		return &flaggingResult{}, errors.Wrap(err, "tokenize prompt")
	}
	if tokenCount > cfg.PromptTokenFlaggingLimit {
		reasons = append(reasons, "high_prompt_token_count")
	}

	if len(reasons) > 0 { // request is flagged
		blocked := false
		hasBlockedPhrase, phrase := containsAny(prompt, cfg.BlockedPromptPatterns)
		if tokenCount > cfg.PromptTokenBlockingLimit || ar.MaxTokensToSample > int32(cfg.ResponseTokenBlockingLimit) || hasBlockedPhrase {
			blocked = true
		}

		promptPrefix := ar.Prompt
		if len(promptPrefix) > logPromptPrefixLength {
			promptPrefix = promptPrefix[0:logPromptPrefixLength]
		}
		res := &flaggingResult{
			reasons:           reasons,
			maxTokensToSample: int(ar.MaxTokensToSample),
			promptPrefix:      promptPrefix,
			promptTokenCount:  tokenCount,
			shouldBlock:       blocked,
		}
		if hasBlockedPhrase {
			res.blockedPhrase = &phrase
		}
		return res, nil
	}

	return nil, nil
}

func containsAny(prompt string, patterns []string) (bool, string) {
	prompt = strings.ToLower(prompt)
	for _, pattern := range patterns {
		if strings.Contains(prompt, pattern) {
			return true, pattern
		}
	}
	return false, ""
	return isFlaggedRequest(tk,
		flaggingRequest{
			FlattenedPrompt: ar.Prompt,
			MaxTokens:       int(ar.MaxTokensToSample),
		},
		flaggingConfig{
			AllowedPromptPatterns:          cfg.AllowedPromptPatterns,
			BlockedPromptPatterns:          cfg.BlockedPromptPatterns,
			PromptTokenFlaggingLimit:       cfg.PromptTokenFlaggingLimit,
			PromptTokenBlockingLimit:       cfg.PromptTokenBlockingLimit,
			MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
			ResponseTokenBlockingLimit:     cfg.ResponseTokenBlockingLimit,
		},
	)
}
@@ -6,10 +6,11 @@ import (
	"testing"

	"github.com/hexops/autogold/v2"
	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
)
@@ -58,7 +59,7 @@ func TestIsFlaggedAnthropicRequest(t *testing.T) {
		require.Equal(t, int32(result.maxTokensToSample), ar.MaxTokensToSample)
	})

	t.Run("high prompt token count (below block limit)", func(t *testing.T) {
	t.Run("high prompt token count (above block limit)", func(t *testing.T) {
		tokenLengths, err := tk.Tokenize(validPreamble)
		require.NoError(t, err)
@@ -0,0 +1,298 @@
package completions

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
	"github.com/sourcegraph/sourcegraph/lib/errors"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/events"
	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/limiter"
	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/notify"
	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
	"github.com/sourcegraph/sourcegraph/internal/codygateway"
	"github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages"
	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
	"github.com/sourcegraph/sourcegraph/internal/httpcli"
)

const anthropicMessagesAPIURL = "https://api.anthropic.com/v1/messages"

// This implements the newer `/messages` API by Anthropic:
// https://docs.anthropic.com/claude/reference/messages_post
func NewAnthropicMessagesHandler(
	baseLogger log.Logger,
	eventLogger events.Logger,
	rs limiter.RedisStore,
	rateLimitNotifier notify.RateLimitNotifier,
	httpClient httpcli.Doer,
	config config.AnthropicConfig,
	promptRecorder PromptRecorder,
	autoFlushStreamingResponses bool,
) (http.Handler, error) {
	// Tokenizer only needs to be initialized once, and can be shared globally.
	tokenizer, err := tokenizer.NewAnthropicClaudeTokenizer()
	if err != nil {
		return nil, err
	}
	return makeUpstreamHandler[anthropicMessagesRequest](
		baseLogger,
		eventLogger,
		rs,
		rateLimitNotifier,
		httpClient,
		string(conftypes.CompletionsProviderNameAnthropic),
		func(_ codygateway.Feature) string { return anthropicMessagesAPIURL },
		config.AllowedModels,
		&AnthropicMessagesHandlerMethods{config: config, tokenizer: tokenizer, promptRecorder: promptRecorder},

		// Anthropic primarily uses concurrent requests to rate-limit spikes
		// in requests, so set a default retry-after that is likely to be
		// acceptable for Sourcegraph clients to retry (the default
		// SRC_HTTP_CLI_EXTERNAL_RETRY_AFTER_MAX_DURATION) since we might be
		// able to circumvent concurrency limits without raising an error to
		// the user.
		2, // seconds
		autoFlushStreamingResponses,
		nil,
	), nil
}
// AnthropicMessagesRequest captures all known fields from https://console.anthropic.com/docs/api/reference.
type anthropicMessagesRequest struct {
	Messages      []anthropicMessage `json:"messages,omitempty"`
	Model         string             `json:"model"`
	MaxTokens     int32              `json:"max_tokens,omitempty"`
	Temperature   float32            `json:"temperature,omitempty"`
	TopP          float32            `json:"top_p,omitempty"`
	TopK          int32              `json:"top_k,omitempty"`
	Stream        bool               `json:"stream,omitempty"`
	StopSequences []string           `json:"stop_sequences,omitempty"`

	// These are not accepted from the client and instead are only used to talk
	// to the upstream LLM APIs.
	Metadata *anthropicMessagesRequestMetadata `json:"metadata,omitempty"`
	System   string                            `json:"system,omitempty"`
}

type anthropicMessage struct {
	Role    string                    `json:"role"` // "user", "assistant", or "system" (only allowed for the first message)
	Content []anthropicMessageContent `json:"content"`
}

type anthropicMessageContent struct {
	Type string `json:"type"` // "text" or "image" (not yet supported)
	Text string `json:"text"`
}

type anthropicMessagesRequestMetadata struct {
	UserID string `json:"user_id,omitempty"`
}

func (ar anthropicMessagesRequest) ShouldStream() bool {
	return ar.Stream
}

func (ar anthropicMessagesRequest) GetModel() string {
	return ar.Model
}

// Note: This is not the actual prompt sent to Anthropic, but it's a good
// approximation to measure tokens.
func (r anthropicMessagesRequest) BuildPrompt() string {
	var sb strings.Builder
	for _, m := range r.Messages {
		switch m.Role {
		case "user":
			sb.WriteString("Human: ")
		case "assistant":
			sb.WriteString("Assistant: ")
		case "system":
			sb.WriteString("System: ")
		default:
			return ""
		}

		for _, c := range m.Content {
			if c.Type == "text" {
				sb.WriteString(c.Text)
			}
		}
		sb.WriteString("\n\n")
	}
	return sb.String()
}
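To illustrate BuildPrompt (a sketch, not part of the commit): a request with a system and a user message is flattened into the familiar Human/Assistant format, one role prefix per message.

	// Illustrative only: the request below is a made-up example.
	r := anthropicMessagesRequest{Messages: []anthropicMessage{
		{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "You are cody-gateway."}}},
		{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: "Hello"}}},
	}}
	flattened := r.BuildPrompt()
	// flattened == "System: You are cody-gateway.\n\nHuman: Hello\n\n"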
// AnthropicMessagesNonStreamingResponse captures all relevant-to-us fields from https://docs.anthropic.com/claude/reference/messages_post.
type anthropicMessagesNonStreamingResponse struct {
	Content    []anthropicMessageContent      `json:"content"`
	Usage      anthropicMessagesResponseUsage `json:"usage"`
	StopReason string                         `json:"stop_reason"`
}

// AnthropicMessagesStreamingResponse captures all relevant-to-us fields from each relevant SSE event from https://docs.anthropic.com/claude/reference/messages_post.
type anthropicMessagesStreamingResponse struct {
	Type         string                                        `json:"type"`
	Delta        *anthropicMessagesStreamingResponseTextBucket `json:"delta"`
	ContentBlock *anthropicMessagesStreamingResponseTextBucket `json:"content_block"`
	Usage        *anthropicMessagesResponseUsage               `json:"usage"`
	Message      *anthropicStreamingResponseMessage            `json:"message"`
}

type anthropicStreamingResponseMessage struct {
	Usage *anthropicMessagesResponseUsage `json:"usage"`
}

type anthropicMessagesStreamingResponseTextBucket struct {
	Text string `json:"text"`
}

type anthropicMessagesResponseUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

type AnthropicMessagesHandlerMethods struct {
	tokenizer      *tokenizer.Tokenizer
	promptRecorder PromptRecorder
	config         config.AnthropicConfig
}

func (a *AnthropicMessagesHandlerMethods) validateRequest(ctx context.Context, logger log.Logger, _ codygateway.Feature, ar anthropicMessagesRequest) (int, *flaggingResult, error) {
	if ar.MaxTokens > int32(a.config.MaxTokensToSample) {
		return http.StatusBadRequest, nil, errors.Errorf("max_tokens exceeds maximum allowed value of %d: %d", a.config.MaxTokensToSample, ar.MaxTokens)
	}

	if result, err := isFlaggedAnthropicMessagesRequest(a.tokenizer, ar, a.config); err != nil {
		logger.Error("error checking AnthropicMessages request - treating as non-flagged",
			log.Error(err))
	} else if result.IsFlagged() {
		// Record flagged prompts in the hot path - they usually take a long time
		// on the backend side, so this isn't going to make things meaningfully worse.
		if err := a.promptRecorder.Record(ctx, ar.BuildPrompt()); err != nil {
			logger.Warn("failed to record flagged prompt", log.Error(err))
		}
		if a.config.RequestBlockingEnabled && result.shouldBlock {
			return http.StatusBadRequest, result, errors.Errorf("request blocked - if you think this is a mistake, please contact support@sourcegraph.com")
		}
		return 0, result, nil
	}

	return 0, nil, nil
}

func (a *AnthropicMessagesHandlerMethods) transformBody(body *anthropicMessagesRequest, identifier string) {
	// Overwrite the metadata field, we don't want to allow users to specify it:
	body.Metadata = &anthropicMessagesRequestMetadata{
		// We forward the actor ID to support tracking.
		UserID: identifier,
	}

	// Remove the `anthropic/` prefix from the model string.
	body.Model = strings.TrimPrefix(body.Model, "anthropic/")

	// Convert an eventual first `system` message into the top-level system prompt.
	body.System = "" // prevent the upstream API from setting this
	if len(body.Messages) > 0 && body.Messages[0].Role == "system" {
		body.System = body.Messages[0].Content[0].Text
		body.Messages = body.Messages[1:]
	}
}
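A sketch (not from the commit) of what transformBody does to an incoming request; the actor ID "actor-123" is a made-up placeholder:

	body := &anthropicMessagesRequest{
		Model: "anthropic/claude-3-sonnet-20240229",
		Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "You are cody-gateway."}}},
			{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: "Hello"}}},
		},
	}
	a.transformBody(body, "actor-123")
	// body.Model           == "claude-3-sonnet-20240229" (prefix stripped)
	// body.System          == "You are cody-gateway."    (first message hoisted)
	// len(body.Messages)   == 1                          (only the user message remains)
	// body.Metadata.UserID == "actor-123"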
func (a *AnthropicMessagesHandlerMethods) getRequestMetadata(body anthropicMessagesRequest) (model string, additionalMetadata map[string]any) {
	return body.Model, map[string]any{
		"stream":     body.Stream,
		"max_tokens": body.MaxTokens,
	}
}

func (a *AnthropicMessagesHandlerMethods) transformRequest(r *http.Request) {
	r.Header.Set("Content-Type", "application/json")
	r.Header.Set("X-API-Key", a.config.AccessToken)
	r.Header.Set("anthropic-version", "2023-06-01")
}

func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logger, body anthropicMessagesRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
	// First, extract prompt usage details from the request.
	for _, m := range body.Messages {
		promptUsage.characters += len(m.Content)
	}
	promptUsage.characters += len(body.System)

	// Try to parse the request we saw; if it was non-streaming, we can simply
	// parse it as JSON.
	if !body.ShouldStream() {
		var res anthropicMessagesNonStreamingResponse
		if err := json.NewDecoder(r).Decode(&res); err != nil {
			logger.Error("failed to parse Anthropic response as JSON", log.Error(err))
			return promptUsage, completionUsage
		}

		// Extract character data from the response by summing up all text.
		for _, c := range res.Content {
			completionUsage.characters += len(c.Text)
		}
		// Extract token usage data from the response.
		completionUsage.tokens = res.Usage.OutputTokens
		promptUsage.tokens = res.Usage.InputTokens

		return promptUsage, completionUsage
	}

	// Otherwise, we have to parse the event stream from Anthropic.
	dec := anthropicmessages.NewDecoder(r)
	for dec.Scan() {
		data := dec.Data()

		// Gracefully skip over any data that isn't JSON-like. Anthropic's API sometimes sends
		// non-documented data over the stream, like timestamps.
		if !bytes.HasPrefix(data, []byte("{")) {
			continue
		}

		var event anthropicMessagesStreamingResponse
		if err := json.Unmarshal(data, &event); err != nil {
			logger.Error("failed to decode event payload", log.Error(err), log.String("body", string(data)))
			continue
		}

		switch event.Type {
		case "message_start":
			if event.Message != nil && event.Message.Usage != nil {
				promptUsage.tokens = event.Message.Usage.InputTokens
			}
		case "content_block_delta":
			if event.Delta != nil {
				completionUsage.characters += len(event.Delta.Text)
			}
		case "message_delta":
			if event.Usage != nil {
				completionUsage.tokens = event.Usage.OutputTokens
			}
		}
	}
	if err := dec.Err(); err != nil {
		logger.Error("failed to decode Anthropic streaming response", log.Error(err))
	}

	return promptUsage, completionUsage
}
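For reference, a sketch of the stream shape the streaming branch above consumes; the events are simplified from Anthropic's documented /messages stream, and the token numbers are invented:

	// event: message_start
	// data: {"type":"message_start","message":{"usage":{"input_tokens":12}}}
	//
	// event: content_block_delta
	// data: {"type":"content_block_delta","delta":{"text":"Hello"}}
	//
	// event: message_delta
	// data: {"type":"message_delta","usage":{"output_tokens":5}}

message_start carries the prompt token count, each content_block_delta contributes response characters, and message_delta carries the final completion token count.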
func isFlaggedAnthropicMessagesRequest(tk *tokenizer.Tokenizer, r anthropicMessagesRequest, cfg config.AnthropicConfig) (*flaggingResult, error) {
	return isFlaggedRequest(tk,
		flaggingRequest{
			FlattenedPrompt: r.BuildPrompt(),
			MaxTokens:       int(r.MaxTokens),
		},
		flaggingConfig{
			AllowedPromptPatterns:          cfg.AllowedPromptPatterns,
			BlockedPromptPatterns:          cfg.BlockedPromptPatterns,
			PromptTokenFlaggingLimit:       cfg.PromptTokenFlaggingLimit,
			PromptTokenBlockingLimit:       cfg.PromptTokenBlockingLimit,
			MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
			ResponseTokenBlockingLimit:     cfg.ResponseTokenBlockingLimit,
		},
	)
}
@@ -0,0 +1,166 @@
package completions

import (
	"encoding/json"
	"strings"
	"testing"

	"github.com/hexops/autogold/v2"
	"github.com/stretchr/testify/require"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
)

func TestIsFlaggedAnthropicMessagesRequest(t *testing.T) {
	validPreamble := "You are cody-gateway."

	cfg := config.AnthropicConfig{
		PromptTokenFlaggingLimit:       18000,
		PromptTokenBlockingLimit:       20000,
		MaxTokensToSampleFlaggingLimit: 1000,
		ResponseTokenBlockingLimit:     1000,
	}
	cfgWithPreamble := config.AnthropicConfig{
		PromptTokenFlaggingLimit:       18000,
		PromptTokenBlockingLimit:       20000,
		MaxTokensToSampleFlaggingLimit: 1000,
		ResponseTokenBlockingLimit:     1000,
		AllowedPromptPatterns:          []string{strings.ToLower(validPreamble)},
	}
	tk, err := tokenizer.NewAnthropicClaudeTokenizer()
	require.NoError(t, err)

	t.Run("works for known preamble", func(t *testing.T) {
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
		}}
		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
		require.NoError(t, err)
		require.Nil(t, result)
	})

	t.Run("missing known preamble", func(t *testing.T) {
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "some prompt without known preamble"}}},
		}}
		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
		require.NoError(t, err)
		require.True(t, result.IsFlagged())
		require.False(t, result.shouldBlock)
		require.Contains(t, result.reasons, "unknown_prompt")
	})

	t.Run("preamble not configured", func(t *testing.T) {
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: "some prompt without known preamble"}}},
		}}
		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfg)
		require.NoError(t, err)
		require.False(t, result.IsFlagged())
	})

	t.Run("high max tokens to sample", func(t *testing.T) {
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", MaxTokens: 10000, Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
		}}
		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfg)
		require.NoError(t, err)
		require.True(t, result.IsFlagged())
		require.True(t, result.shouldBlock)
		require.Contains(t, result.reasons, "high_max_tokens_to_sample")
		require.Equal(t, int32(result.maxTokensToSample), r.MaxTokens)
	})

	t.Run("high prompt token count and bad phrase", func(t *testing.T) {
		cfgWithBadPhrase := &cfgWithPreamble
		cfgWithBadPhrase.BlockedPromptPatterns = []string{"bad phrase"}
		longPrompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1)
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble + " " + longPrompt + "bad phrase"}}},
		}}
		result, err := isFlaggedAnthropicMessagesRequest(tk, r, *cfgWithBadPhrase)
		require.NoError(t, err)
		require.True(t, result.IsFlagged())
		require.True(t, result.shouldBlock)
	})

	t.Run("low prompt token count and bad phrase", func(t *testing.T) {
		cfgWithBadPhrase := &cfgWithPreamble
		cfgWithBadPhrase.BlockedPromptPatterns = []string{"bad phrase"}
		longPrompt := strings.Repeat("word ", 5)
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble + " " + longPrompt + "bad phrase"}}},
		}}
		result, err := isFlaggedAnthropicMessagesRequest(tk, r, *cfgWithBadPhrase)
		require.NoError(t, err)
		// for now, we should not flag requests purely because of bad phrases
		require.False(t, result.IsFlagged())
	})

	t.Run("high prompt token count (above block limit)", func(t *testing.T) {
		tokenLengths, err := tk.Tokenize(validPreamble)
		require.NoError(t, err)

		validPreambleTokens := len(tokenLengths)
		longPrompt := strings.Repeat("word ", cfg.PromptTokenFlaggingLimit+1)
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
			{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: longPrompt}}},
		}}

		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
		require.NoError(t, err)
		require.True(t, result.IsFlagged())
		require.False(t, result.shouldBlock)
		require.Contains(t, result.reasons, "high_prompt_token_count")
		require.Equal(t, result.promptTokenCount, validPreambleTokens+4+cfg.PromptTokenFlaggingLimit+4, cfg)
	})

	t.Run("high prompt token count (below block limit)", func(t *testing.T) {
		tokenLengths, err := tk.Tokenize(validPreamble)
		require.NoError(t, err)

		validPreambleTokens := len(tokenLengths)
		longPrompt := strings.Repeat("word ", cfg.PromptTokenBlockingLimit+1)
		r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
			{Role: "system", Content: []anthropicMessageContent{{Type: "text", Text: validPreamble}}},
			{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: longPrompt}}},
		}}

		result, err := isFlaggedAnthropicMessagesRequest(tk, r, cfgWithPreamble)
		require.NoError(t, err)
		require.True(t, result.IsFlagged())
		require.True(t, result.shouldBlock)
		require.Contains(t, result.reasons, "high_prompt_token_count")
		require.Equal(t, result.promptTokenCount, validPreambleTokens+4+cfg.PromptTokenBlockingLimit+4)
	})
}

func TestAnthropicMessagesRequestJSON(t *testing.T) {
	_, err := tokenizer.NewAnthropicClaudeTokenizer()
	require.NoError(t, err)

	r := anthropicMessagesRequest{Model: "anthropic/claude-3-sonnet-20240229", Messages: []anthropicMessage{
		{Role: "user", Content: []anthropicMessageContent{{Type: "text", Text: "Hello world"}}},
	}}

	b, err := json.MarshalIndent(r, "", "\t")
	require.NoError(t, err)

	autogold.Expect(`{
	"messages": [
		{
			"role": "user",
			"content": [
				{
					"type": "text",
					"text": "Hello world"
				}
			]
		}
	],
	"model": "anthropic/claude-3-sonnet-20240229"
}`).Equal(t, string(b))
}
101 cmd/cody-gateway/internal/httpapi/completions/flagging.go (Normal file)

@@ -0,0 +1,101 @@
package completions

import (
	"strings"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/tokenizer"
	"github.com/sourcegraph/sourcegraph/lib/errors"
)

type flaggingConfig struct {
	// Phrases we look for in the prompt to consider it valid.
	// Each phrase is lower case.
	AllowedPromptPatterns []string
	// Phrases we look for in a flagged request to consider blocking the response.
	// Each phrase is lower case. Can be empty (to disable blocking).
	BlockedPromptPatterns []string
	// Phrases we look for in a request to collect data.
	// Each phrase is lower case. Can be empty (to disable data collection).
	PromptTokenFlaggingLimit       int
	PromptTokenBlockingLimit       int
	MaxTokensToSampleFlaggingLimit int
	ResponseTokenBlockingLimit     int
}

type flaggingRequest struct {
	FlattenedPrompt string
	MaxTokens       int
}

type flaggingResult struct {
	shouldBlock       bool
	blockedPhrase     *string
	reasons           []string
	promptPrefix      string
	maxTokensToSample int
	promptTokenCount  int
}

func isFlaggedRequest(tk *tokenizer.Tokenizer, r flaggingRequest, cfg flaggingConfig) (*flaggingResult, error) {
	var reasons []string

	prompt := strings.ToLower(r.FlattenedPrompt)

	if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
		reasons = append(reasons, "unknown_prompt")
	}

	// If this request has a very high token count for responses, then flag it.
	if r.MaxTokens > cfg.MaxTokensToSampleFlaggingLimit {
		reasons = append(reasons, "high_max_tokens_to_sample")
	}

	// If this prompt consists of a very large number of tokens, then flag it.
	tokens, err := tk.Tokenize(r.FlattenedPrompt)
	if err != nil {
		return &flaggingResult{}, errors.Wrap(err, "tokenize prompt")
	}
	tokenCount := len(tokens)

	if tokenCount > cfg.PromptTokenFlaggingLimit {
		reasons = append(reasons, "high_prompt_token_count")
	}

	if len(reasons) > 0 { // request is flagged
		blocked := false
		hasBlockedPhrase, phrase := containsAny(prompt, cfg.BlockedPromptPatterns)
		if tokenCount > cfg.PromptTokenBlockingLimit || r.MaxTokens > cfg.ResponseTokenBlockingLimit || hasBlockedPhrase {
			blocked = true
		}

		promptPrefix := r.FlattenedPrompt
		if len(promptPrefix) > logPromptPrefixLength {
			promptPrefix = promptPrefix[0:logPromptPrefixLength]
		}
		res := &flaggingResult{
			reasons:           reasons,
			maxTokensToSample: r.MaxTokens,
			promptPrefix:      promptPrefix,
			promptTokenCount:  tokenCount,
			shouldBlock:       blocked,
		}
		if hasBlockedPhrase {
			res.blockedPhrase = &phrase
		}
		return res, nil
	}

	return nil, nil
}

func (f *flaggingResult) IsFlagged() bool {
	return f != nil
}

func containsAny(prompt string, patterns []string) (bool, string) {
	prompt = strings.ToLower(prompt)
	for _, pattern := range patterns {
		if strings.Contains(prompt, pattern) {
			return true, pattern
		}
	}
	return false, ""
}
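A sketch of how the two limit tiers interact, using the same values as the tests above; tk and prompt are assumed to be in scope:

	cfg := flaggingConfig{
		PromptTokenFlaggingLimit:       18000,
		PromptTokenBlockingLimit:       20000,
		MaxTokensToSampleFlaggingLimit: 1000,
		ResponseTokenBlockingLimit:     1000,
	}
	// A ~19k-token prompt exceeds the flagging limit but not the blocking
	// limit: the request is flagged ("high_prompt_token_count") with
	// shouldBlock == false. Past ~20k tokens, or if a blocked phrase is
	// present, shouldBlock becomes true.
	res, err := isFlaggedRequest(tk, flaggingRequest{FlattenedPrompt: prompt, MaxTokens: 100}, cfg)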
@@ -130,8 +130,10 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
	// Convert allowedModels to the Cody Gateway configuration format with the
	// provider as a prefix. This aligns with the models returned when we query
	// for rate limits from actor sources.
	for i := range allowedModels {
		allowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, allowedModels[i])
	clonedAllowedModels := make([]string, len(allowedModels))
	copy(clonedAllowedModels, allowedModels)
	for i := range clonedAllowedModels {
		clonedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, clonedAllowedModels[i])
	}

	// turn off sanitization for profanity detection
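The clone matters because the old loop rewrote the caller's allowedModels slice in place, so every other holder of that slice saw the prefixed values. A minimal illustration (not from the commit):

	models := []string{"claude-3-sonnet-20240229"}
	for i := range models {
		models[i] = fmt.Sprintf("%s/%s", "anthropic", models[i])
	}
	// models[0] is now "anthropic/claude-3-sonnet-20240229" for anyone else
	// sharing the slice; running the loop again would double the prefix.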
@@ -284,7 +286,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
	// the prefix yet when extracted - we need to add it back here. This
	// full gatewayModel is also used in events tracking.
	gatewayModel := fmt.Sprintf("%s/%s", upstreamName, model)
	if allowed := intersection(allowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) {
	if allowed := intersection(clonedAllowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) {
		response.JSONError(logger, w, http.StatusBadRequest,
			errors.Newf("model %q is not allowed, allowed: [%s]",
				gatewayModel, strings.Join(allowed, ", ")))
@@ -508,16 +510,3 @@ func intersection(a, b []string) (c []string) {
	}
	return c
}

type flaggingResult struct {
	shouldBlock       bool
	blockedPhrase     *string
	reasons           []string
	promptPrefix      string
	maxTokensToSample int
	promptTokenCount  int
}

func (f *flaggingResult) IsFlagged() bool {
	return f != nil
}
@@ -7,12 +7,13 @@ import (
	"github.com/Khan/genqlient/graphql"
	"github.com/gorilla/mux"
	"github.com/sourcegraph/log"
	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/httpapi/overhead"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"

	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/auth"
@@ -83,7 +84,6 @@ func NewHandler(
			httpClient,
			config.Anthropic,
			promptRecorder,

			config.AutoFlushStreamingResponses,
		)
		if err != nil {
@@ -106,6 +106,37 @@ func NewHandler(
				otelhttp.WithPublicEndpoint(),
			),
		))

		anthropicMessagesHandler, err := completions.NewAnthropicMessagesHandler(
			logger,
			eventLogger,
			rs,
			config.RateLimitNotifier,
			httpClient,
			config.Anthropic,
			promptRecorder,
			config.AutoFlushStreamingResponses,
		)
		if err != nil {
			return nil, errors.Wrap(err, "init anthropicMessages handler")
		}

		v1router.Path("/completions/anthropic-messages").Methods(http.MethodPost).Handler(
			overhead.HTTPMiddleware(latencyHistogram,
				instrumentation.HTTPMiddleware("v1.completions.anthropicmessages",
					gaugeHandler(
						counter,
						attributesAnthropicCompletions,
						authr.Middleware(
							requestlogger.Middleware(
								logger,
								anthropicMessagesHandler,
							),
						),
					),
					otelhttp.WithPublicEndpoint(),
				),
			))
	} else {
		logger.Error("Anthropic access token not set")
	}
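A hypothetical client call against the new route; the host and token are placeholders, and the payload shape follows anthropicMessagesRequest above:

	req, _ := http.NewRequest(http.MethodPost,
		"https://cody-gateway.example.com/v1/completions/anthropic-messages",
		strings.NewReader(`{"model":"anthropic/claude-3-sonnet-20240229","max_tokens":300,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`))
	req.Header.Set("Authorization", "Bearer <gateway-access-token>") // placeholder credential
	resp, err := http.DefaultClient.Do(req)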
@@ -141,6 +141,8 @@ func (c *Config) Load() {
		"claude-instant-v1.2",
		"claude-instant-1.2",
		"claude-instant-1.2-cyan",
		"claude-3-opus-20240229",
		"claude-3-sonnet-20240229",
	}, ","),
		"Anthropic models that can be used."))
	if c.Anthropic.AccessToken != "" && len(c.Anthropic.AllowedModels) == 0 {
@@ -348,6 +348,7 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {

	if !isProUser {
		return []string{
			// Remove after the Claude 3 rollout is complete
			"anthropic/claude-2.0",
			"anthropic/claude-instant-v1",
			"anthropic/claude-instant-1.2",
@@ -356,6 +357,13 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
	}

	return []string{
		"anthropic/claude-3-sonnet-20240229",
		"anthropic/claude-3-opus-20240229",
		"fireworks/" + fireworks.Mixtral8x7bInstruct,
		"openai/gpt-3.5-turbo",
		"openai/gpt-4-1106-preview",

		// Remove after the Claude 3 rollout is complete
		"anthropic/claude-2",
		"anthropic/claude-2.0",
		"anthropic/claude-2.1",
@@ -363,9 +371,6 @@ func allowedModels(scope types.CompletionsFeature, isProUser bool) []string {
		"anthropic/claude-instant-1.2",
		"anthropic/claude-instant-v1",
		"anthropic/claude-instant-1",
		"openai/gpt-3.5-turbo",
		"openai/gpt-4-1106-preview",
		"fireworks/" + fireworks.Mixtral8x7bInstruct,
	}
case types.CompletionsFeatureCode:
	return []string{
19 internal/completions/client/anthropicmessages/BUILD.bazel (Normal file)

@@ -0,0 +1,19 @@
load("//dev:go_defs.bzl", "go_test")
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "anthropicmessages",
|
||||
srcs = ["decoder.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/anthropicmessages",
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = ["//lib/errors"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "anthropicmessages_test",
|
||||
timeout = "short",
|
||||
srcs = ["decoder_test.go"],
|
||||
data = glob(["testdata/**"]),
|
||||
embed = [":anthropicmessages"],
|
||||
deps = ["@com_github_stretchr_testify//require"],
|
||||
)
|
||||
106 internal/completions/client/anthropicmessages/decoder.go (Normal file)

@@ -0,0 +1,106 @@
package anthropicmessages

import (
	"bufio"
	"bytes"
	"io"

	"github.com/sourcegraph/sourcegraph/lib/errors"
)

const maxPayloadSize = 10 * 1024 * 1024 // 10MB

var doneBytes = []byte("[DONE]")

// decoder decodes streaming events from a Server Sent Event stream. It only supports
// streams generated by the Anthropic completions API, i.e. this is not a fully
// compliant Server Sent Events decoder.
//
// Adapted from internal/search/streaming/http/decoder.go.
type decoder struct {
	scanner *bufio.Scanner
	done    bool
	data    []byte
	err     error
}

func NewDecoder(r io.Reader) *decoder {
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
	// bufio.ScanLines, except we look for \n\n which separate events.
	split := func(data []byte, atEOF bool) (int, []byte, error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
			return i + 2, data[:i], nil
		}
		// If we're at EOF, we have a final, non-terminated event. This should
		// be empty.
		if atEOF {
			return len(data), data, nil
		}
		// Request more data.
		return 0, nil, nil
	}
	scanner.Split(split)
	return &decoder{
		scanner: scanner,
	}
}

// Scan advances the decoder to the next event in the stream. It returns
// false when it either hits the end of the stream or an error.
func (d *decoder) Scan() bool {
	if d.done {
		return false
	}
	for d.scanner.Scan() {
		// event: $_name
		// data: json($data)|[DONE]

		lines := bytes.Split(d.scanner.Bytes(), []byte("\n"))
		for _, line := range lines {
			typ, data := splitColon(line)

			switch {
			case bytes.Equal(typ, []byte("data")):
				d.data = data
				// Check for the special sentinel value used by the Anthropic API to
				// indicate that the stream is done.
				if bytes.Equal(data, doneBytes) {
					d.done = true
					return false
				}
				return true
			case bytes.Equal(typ, []byte("event")):
				// Anthropic sends the event name in the data payload as well, so we ignore it for now.
				continue
			default:
				d.err = errors.Errorf("malformed data, expected data: %s %q", typ, line)
				return false
			}
		}
	}

	d.err = d.scanner.Err()
	return false
}

// Data returns the event data of the last decoded event.
func (d *decoder) Data() []byte {
	return d.data
}

// Err returns the last encountered error.
func (d *decoder) Err() error {
	return d.err
}

func splitColon(data []byte) ([]byte, []byte) {
	i := bytes.Index(data, []byte(":"))
	if i < 0 {
		return bytes.TrimSpace(data), nil
	}
	return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
}
@@ -0,0 +1,50 @@
package anthropicmessages

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestDecoder(t *testing.T) {
	t.Parallel()

	type event struct {
		data string
	}

	decodeAll := func(input string) ([]event, error) {
		dec := NewDecoder(strings.NewReader(input))
		var events []event
		for dec.Scan() {
			events = append(events, event{
				data: string(dec.Data()),
			})
		}
		return events, dec.Err()
	}

	t.Run("Single", func(t *testing.T) {
		events, err := decodeAll("event:foo\ndata:b\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "b"}})
	})

	t.Run("Multiple", func(t *testing.T) {
		events, err := decodeAll("event:foo\ndata:b\n\nevent:foo\ndata:c\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "b"}, {data: "c"}})
	})

	t.Run("ErrExpectedData", func(t *testing.T) {
		_, err := decodeAll("datas:b\n\n")
		require.Contains(t, err.Error(), "malformed data, expected data")
	})

	t.Run("InterleavedPing", func(t *testing.T) {
		events, err := decodeAll("data:a\n\nevent: ping\ndata: pong\n\ndata:b\n\ndata: [DONE]\n\n")
		require.NoError(t, err)
		require.Equal(t, events, []event{{data: "a"}, {data: "pong"}, {data: "b"}})
	})
}
@@ -74,21 +74,26 @@ func isAllowedCustomChatModel(model string, isProUser bool) bool {
	// When updating these two lists, make sure you also update `allowedModels` in codygateway_dotcom_user.go.
	if isProUser {
		switch model {
		case "anthropic/claude-2",
		case "anthropic/claude-3-sonnet-20240229",
			"anthropic/claude-3-opus-20240229",
			"fireworks/" + fireworks.Mixtral8x7bInstruct,
			"openai/gpt-3.5-turbo",
			"openai/gpt-4-1106-preview",

			// Remove after the Claude 3 rollout is complete
			"anthropic/claude-2",
			"anthropic/claude-2.0",
			"anthropic/claude-2.1",
			"anthropic/claude-instant-1.2-cyan",
			"anthropic/claude-instant-1.2",
			"anthropic/claude-instant-v1",
			"anthropic/claude-instant-1",
			"openai/gpt-3.5-turbo",
			"openai/gpt-4-1106-preview",
			"fireworks/" + fireworks.Mixtral8x7bInstruct:
			"anthropic/claude-instant-1":
			return true
		}
	} else {
		switch model {
		case "anthropic/claude-2",
		case // Remove after the Claude 3 rollout is complete
			"anthropic/claude-2",
			"anthropic/claude-2.0",
			"anthropic/claude-instant-v1",
			"anthropic/claude-instant-1":