diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropic.go b/cmd/cody-gateway/internal/httpapi/completions/anthropic.go
index 70d73f83ce8..8b4eb170835 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/anthropic.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/anthropic.go
@@ -128,6 +128,7 @@ func (a *AnthropicHandlerMethods) validateRequest(ctx context.Context, logger lo
 
 func (a *AnthropicHandlerMethods) shouldFlagRequest(ctx context.Context, logger log.Logger, ar anthropicRequest) (*flaggingResult, error) {
 	result, err := isFlaggedRequest(a.anthropicTokenizer, flaggingRequest{
+		ModelName:       ar.Model,
 		FlattenedPrompt: ar.Prompt,
 		MaxTokens:       int(ar.MaxTokensToSample),
 	},
diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go
index c254ed353a1..256b05994f4 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"github.com/sourcegraph/log"
 
+	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
 	"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -172,6 +173,7 @@ func (a *AnthropicMessagesHandlerMethods) validateRequest(ctx context.Context, l
 
 func (a *AnthropicMessagesHandlerMethods) shouldFlagRequest(ctx context.Context, logger log.Logger, ar anthropicMessagesRequest) (*flaggingResult, error) {
 	result, err := isFlaggedRequest(a.tokenizer, flaggingRequest{
+		ModelName:       ar.Model,
 		FlattenedPrompt: ar.BuildPrompt(),
 		MaxTokens:       int(ar.MaxTokens),
 	},
diff --git a/cmd/cody-gateway/internal/httpapi/completions/flagging.go b/cmd/cody-gateway/internal/httpapi/completions/flagging.go
index e28b62aef02..26b3de231fe 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/flagging.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/flagging.go
@@ -2,6 +2,7 @@ package completions
 
 import (
 	"context"
+	"slices"
 	"strings"
 
 	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
@@ -24,6 +25,10 @@ type flaggingConfig struct {
 	MaxTokensToSampleFlaggingLimit int
 	ResponseTokenBlockingLimit     int
 
+	// FlaggedModelNames is a slice of LLM model names, e.g. "gpt-3.5-turbo",
+	// that will lead to the request getting flagged.
+	FlaggedModelNames []string
+
 	// If false, flaggingResult.shouldBlock will always be false when returned by isFlaggedRequest.
 	RequestBlockingEnabled bool
 }
@@ -39,11 +44,15 @@ func makeFlaggingConfig(cfg config.FlaggingConfig) flaggingConfig {
 		PromptTokenBlockingLimit:       cfg.PromptTokenBlockingLimit,
 		MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
 		ResponseTokenBlockingLimit:     cfg.ResponseTokenBlockingLimit,
+		FlaggedModelNames:              cfg.FlaggedModelNames,
 		RequestBlockingEnabled:         cfg.RequestBlockingEnabled,
 	}
 }
 
 type flaggingRequest struct {
+	// ModelName is the slug for the specific LLM model,
+	// e.g. "llama-v2-13b-code".
"llama-v2-13b-code" + ModelName string FlattenedPrompt string MaxTokens int } @@ -64,6 +73,10 @@ func isFlaggedRequest(tk tokenizer.Tokenizer, r flaggingRequest, cfg flaggingCon var reasons []string prompt := strings.ToLower(r.FlattenedPrompt) + if r.ModelName != "" && slices.Contains(cfg.FlaggedModelNames, r.ModelName) { + reasons = append(reasons, "model_used") + } + if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern { reasons = append(reasons, "unknown_prompt") } diff --git a/cmd/cody-gateway/internal/httpapi/completions/flagging_test.go b/cmd/cody-gateway/internal/httpapi/completions/flagging_test.go index 7592bf7cf3e..04c1282609e 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/flagging_test.go +++ b/cmd/cody-gateway/internal/httpapi/completions/flagging_test.go @@ -40,12 +40,14 @@ func TestMakeFlaggingConfig(t *testing.T) { func TestIsFlaggedRequest(t *testing.T) { validPreamble := "You are cody-gateway." + flaggedModelNames := []string{"dangerous-llm-model"} basicCfg := flaggingConfig{ PromptTokenFlaggingLimit: 18000, PromptTokenBlockingLimit: 20000, MaxTokensToSampleFlaggingLimit: 4000, ResponseTokenBlockingLimit: 4000, + FlaggedModelNames: flaggedModelNames, RequestBlockingEnabled: true, } cfgWithPreamble := flaggingConfig{ @@ -54,6 +56,7 @@ func TestIsFlaggedRequest(t *testing.T) { MaxTokensToSampleFlaggingLimit: 4000, ResponseTokenBlockingLimit: 4000, RequestBlockingEnabled: true, + FlaggedModelNames: flaggedModelNames, AllowedPromptPatterns: []string{strings.ToLower(validPreamble)}, } @@ -67,6 +70,7 @@ func TestIsFlaggedRequest(t *testing.T) { return isFlaggedRequest( tokenizer, flaggingRequest{ + ModelName: "random-model-name", FlattenedPrompt: prompt, MaxTokens: 200, }, @@ -202,4 +206,54 @@ func TestIsFlaggedRequest(t *testing.T) { assert.Greater(t, result.promptTokenCount, tokenCountConfig.PromptTokenBlockingLimit) }) }) + + t.Run("ModelNameFlagging", func(t *testing.T) { + t.Run("NotSet", func(t *testing.T) { + result, err := isFlaggedRequest( + tokenizer, + flaggingRequest{ + ModelName: "", // Test that this is OK. + FlattenedPrompt: validPreamble + "legit request", + MaxTokens: 200, + }, + cfgWithPreamble) + + require.NoError(t, err) + require.False(t, result.IsFlagged()) + }) + + t.Run("NotConfigured", func(t *testing.T) { + // If no models were listed specified in the config, should work. 
+			cfgWithoutModelsList := cfgWithPreamble // copy
+			cfgWithoutModelsList.FlaggedModelNames = nil
+
+			result, err := isFlaggedRequest(
+				tokenizer,
+				flaggingRequest{
+					ModelName:       "arbitrary-model-name",
+					FlattenedPrompt: validPreamble + "legit request",
+					MaxTokens:       200,
+				},
+				cfgWithoutModelsList)
+
+			require.NoError(t, err)
+			require.False(t, result.IsFlagged())
+		})
+
+		t.Run("FlaggedModel", func(t *testing.T) {
+			result, err := isFlaggedRequest(
+				tokenizer,
+				flaggingRequest{
+					ModelName:       flaggedModelNames[0],
+					FlattenedPrompt: validPreamble + "legit request",
+					MaxTokens:       200,
+				},
+				cfgWithPreamble)
+
+			require.NoError(t, err)
+			require.True(t, result.IsFlagged())
+			require.Equal(t, 1, len(result.reasons))
+			assert.Equal(t, "model_used", result.reasons[0])
+		})
+	})
 }
diff --git a/cmd/cody-gateway/internal/httpapi/completions/google.go b/cmd/cody-gateway/internal/httpapi/completions/google.go
index 8936aad8661..55c9e13a8e9 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/google.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/google.go
@@ -108,8 +108,9 @@ func (*GoogleHandlerMethods) validateRequest(_ context.Context, _ log.Logger, fe
 
 func (g *GoogleHandlerMethods) shouldFlagRequest(_ context.Context, _ log.Logger, req googleRequest) (*flaggingResult, error) {
 	result, err := isFlaggedRequest(
-		nil, /* tokenizer, meaning token counts aren't considered when for flagging consideration. */
+		nil, // tokenizer; token counts aren't considered for flagging.
 		flaggingRequest{
+			ModelName:       req.Model,
 			FlattenedPrompt: req.BuildPrompt(),
 			MaxTokens:       int(req.GenerationConfig.MaxOutputTokens),
 		},
diff --git a/cmd/cody-gateway/internal/httpapi/completions/openai.go b/cmd/cody-gateway/internal/httpapi/completions/openai.go
index fadcf8d6568..c407cdcb627 100644
--- a/cmd/cody-gateway/internal/httpapi/completions/openai.go
+++ b/cmd/cody-gateway/internal/httpapi/completions/openai.go
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"github.com/sourcegraph/log"
 
+	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
 	"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -122,6 +123,7 @@ func (o *OpenAIHandlerMethods) shouldFlagRequest(_ context.Context, _ log.Logger
 	result, err := isFlaggedRequest(
 		nil, /* tokenizer, meaning token counts aren't considered when for flagging consideration. */
 		flaggingRequest{
+			ModelName:       req.Model,
 			FlattenedPrompt: req.BuildPrompt(),
 			MaxTokens:       int(req.MaxTokens),
 		},
diff --git a/cmd/cody-gateway/shared/config/config.go b/cmd/cody-gateway/shared/config/config.go
index 2296c63408d..89b9ac3b391 100644
--- a/cmd/cody-gateway/shared/config/config.go
+++ b/cmd/cody-gateway/shared/config/config.go
@@ -150,6 +150,10 @@ type FlaggingConfig struct {
 	// So MaxTokensToSampleFlaggingLimit should be <= MaxTokensToSample.
 	MaxTokensToSampleFlaggingLimit int
 
+	// FlaggedModelNames is a list of provider-specific model names (e.g. "gpt-3.5")
+	// that, if used, will lead to the request being flagged.
+	FlaggedModelNames []string
+
 	// ResponseTokenBlockingLimit is the maximum number of tokens we allow before outright blocking
 	// a response. e.g. the client sends a request desiring a response with 100 max tokens, we will
 	// block it IFF the ResponseTokenBlockingLimit is less than 100.
@@ -363,6 +367,8 @@ func (c *Config) loadFlaggingConfig(cfg *FlaggingConfig, envVarPrefix string) {
 	cfg.PromptTokenBlockingLimit = c.GetInt(envVarPrefix+"PROMPT_TOKEN_BLOCKING_LIMIT", "20000", "Maximum number of prompt tokens to allow without blocking.")
 	cfg.PromptTokenFlaggingLimit = c.GetInt(envVarPrefix+"PROMPT_TOKEN_FLAGGING_LIMIT", "18000", "Maximum number of prompt tokens to allow without flagging.")
 	cfg.ResponseTokenBlockingLimit = c.GetInt(envVarPrefix+"RESPONSE_TOKEN_BLOCKING_LIMIT", "4000", "Maximum number of completion tokens to allow without blocking.")
+
+	cfg.FlaggedModelNames = maybeLoadLowercaseSlice("FLAGGED_MODEL_NAMES", "LLM models that will always lead to the request getting flagged.")
 }
 
 // splitMaybe splits the provided string on commas, but returns nil if given the empty string.
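
For illustration only (not part of the diff): a minimal sketch of how the new FlaggedModelNames check behaves once this change is applied, mirroring the FlaggedModel test case above. The config values and the "dangerous-llm-model" name are made up, and passing a nil tokenizer follows the OpenAI/Google handlers, so token counts are not considered.

	// Hypothetical flagging config; only the fields relevant to the sketch are set.
	cfg := flaggingConfig{
		MaxTokensToSampleFlaggingLimit: 4000,
		ResponseTokenBlockingLimit:     4000,
		FlaggedModelNames:              []string{"dangerous-llm-model"}, // made-up model name
		RequestBlockingEnabled:         true,
	}

	result, err := isFlaggedRequest(
		nil, // no tokenizer, so prompt token counts are skipped
		flaggingRequest{
			ModelName:       "dangerous-llm-model",
			FlattenedPrompt: "You are cody-gateway. legit request",
			MaxTokens:       200,
		},
		cfg)
	// err == nil; result.IsFlagged() == true; result.reasons == []string{"model_used"}

Because the check appends the dedicated reason "model_used", a model-based flag can be told apart from other reasons such as "unknown_prompt" when inspecting flagged requests.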