feat(cody-gateway): Add FLAGGED_MODEL_NAMES check (#63013)

* Cody Gateway: Add FLAGGED_MODEL_NAMES check

* Update cmd/cody-gateway/internal/httpapi/completions/flagging.go

Co-authored-by: Quinn Slack <quinn@slack.org>

---------

Co-authored-by: Quinn Slack <quinn@slack.org>
Chris Smith 2024-05-31 13:12:27 -07:00 committed by GitHub
parent 5833a98185
commit c4b5c73260
7 changed files with 80 additions and 1 deletion


@@ -128,6 +128,7 @@ func (a *AnthropicHandlerMethods) validateRequest(ctx context.Context, logger lo
func (a *AnthropicHandlerMethods) shouldFlagRequest(ctx context.Context, logger log.Logger, ar anthropicRequest) (*flaggingResult, error) {
result, err := isFlaggedRequest(a.anthropicTokenizer,
flaggingRequest{
ModelName: ar.Model,
FlattenedPrompt: ar.Prompt,
MaxTokens: int(ar.MaxTokensToSample),
},


@@ -9,6 +9,7 @@ import (
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -172,6 +173,7 @@ func (a *AnthropicMessagesHandlerMethods) validateRequest(ctx context.Context, l
func (a *AnthropicMessagesHandlerMethods) shouldFlagRequest(ctx context.Context, logger log.Logger, ar anthropicMessagesRequest) (*flaggingResult, error) {
result, err := isFlaggedRequest(a.tokenizer,
flaggingRequest{
ModelName: ar.Model,
FlattenedPrompt: ar.BuildPrompt(),
MaxTokens: int(ar.MaxTokens),
},


@@ -2,6 +2,7 @@ package completions
import (
"context"
"slices"
"strings"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
@@ -24,6 +25,10 @@ type flaggingConfig struct {
MaxTokensToSampleFlaggingLimit int
ResponseTokenBlockingLimit int
// FlaggedModelNames is a slice of LLM model names, e.g. "gpt-3.5-turbo",
// that will lead to the request getting flagged.
FlaggedModelNames []string
// If false, flaggingResult.shouldBlock will always be false when returned by isFlaggedRequest.
RequestBlockingEnabled bool
}
@@ -39,11 +44,15 @@ func makeFlaggingConfig(cfg config.FlaggingConfig) flaggingConfig {
PromptTokenBlockingLimit: cfg.PromptTokenBlockingLimit,
MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
ResponseTokenBlockingLimit: cfg.ResponseTokenBlockingLimit,
FlaggedModelNames: cfg.FlaggedModelNames,
RequestBlockingEnabled: cfg.RequestBlockingEnabled,
}
}
type flaggingRequest struct {
// ModelName is the slug for the specific LLM model.
// e.g. "llama-v2-13b-code"
ModelName string
FlattenedPrompt string
MaxTokens int
}
@@ -64,6 +73,10 @@ func isFlaggedRequest(tk tokenizer.Tokenizer, r flaggingRequest, cfg flaggingCon
var reasons []string
prompt := strings.ToLower(r.FlattenedPrompt)
if r.ModelName != "" && slices.Contains(cfg.FlaggedModelNames, r.ModelName) {
reasons = append(reasons, "model_used")
}
if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
reasons = append(reasons, "unknown_prompt")
}
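
For reference, a minimal self-contained sketch of the new model-name check in isolation. The flaggingConfig and flaggingRequest shapes below only mirror the fields shown in this diff; everything else (package main, the helper name) is illustrative, not the gateway's actual code:

package main

import (
	"fmt"
	"slices"
)

// Trimmed-down stand-ins for the gateway types shown in the diff above.
type flaggingConfig struct {
	FlaggedModelNames []string
}

type flaggingRequest struct {
	ModelName string
}

// modelUsedReason mirrors the new branch in isFlaggedRequest: an empty model
// name is ignored; otherwise an exact match against the configured list adds
// the "model_used" reason.
func modelUsedReason(r flaggingRequest, cfg flaggingConfig) []string {
	var reasons []string
	if r.ModelName != "" && slices.Contains(cfg.FlaggedModelNames, r.ModelName) {
		reasons = append(reasons, "model_used")
	}
	return reasons
}

func main() {
	cfg := flaggingConfig{FlaggedModelNames: []string{"dangerous-llm-model"}}

	fmt.Println(modelUsedReason(flaggingRequest{ModelName: "dangerous-llm-model"}, cfg)) // [model_used]
	fmt.Println(modelUsedReason(flaggingRequest{ModelName: "gpt-4o"}, cfg))              // []
	fmt.Println(modelUsedReason(flaggingRequest{ModelName: ""}, cfg))                    // []
}

Note that slices.Contains is an exact, case-sensitive string comparison, so the configured names need to match the model slug sent by the client verbatim.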


@@ -40,12 +40,14 @@ func TestMakeFlaggingConfig(t *testing.T) {
func TestIsFlaggedRequest(t *testing.T) {
validPreamble := "You are cody-gateway."
flaggedModelNames := []string{"dangerous-llm-model"}
basicCfg := flaggingConfig{
PromptTokenFlaggingLimit: 18000,
PromptTokenBlockingLimit: 20000,
MaxTokensToSampleFlaggingLimit: 4000,
ResponseTokenBlockingLimit: 4000,
FlaggedModelNames: flaggedModelNames,
RequestBlockingEnabled: true,
}
cfgWithPreamble := flaggingConfig{
@@ -54,6 +56,7 @@ func TestIsFlaggedRequest(t *testing.T) {
MaxTokensToSampleFlaggingLimit: 4000,
ResponseTokenBlockingLimit: 4000,
RequestBlockingEnabled: true,
FlaggedModelNames: flaggedModelNames,
AllowedPromptPatterns: []string{strings.ToLower(validPreamble)},
}
@@ -67,6 +70,7 @@ func TestIsFlaggedRequest(t *testing.T) {
return isFlaggedRequest(
tokenizer,
flaggingRequest{
ModelName: "random-model-name",
FlattenedPrompt: prompt,
MaxTokens: 200,
},
@@ -202,4 +206,54 @@ func TestIsFlaggedRequest(t *testing.T) {
assert.Greater(t, result.promptTokenCount, tokenCountConfig.PromptTokenBlockingLimit)
})
})
t.Run("ModelNameFlagging", func(t *testing.T) {
t.Run("NotSet", func(t *testing.T) {
result, err := isFlaggedRequest(
tokenizer,
flaggingRequest{
ModelName: "", // Test that this is OK.
FlattenedPrompt: validPreamble + "legit request",
MaxTokens: 200,
},
cfgWithPreamble)
require.NoError(t, err)
require.False(t, result.IsFlagged())
})
t.Run("NotConfigured", func(t *testing.T) {
// If no models are specified in the config, no model name should cause flagging.
cfgWithoutModelsList := cfgWithPreamble // copy
cfgWithoutModelsList.FlaggedModelNames = nil
result, err := isFlaggedRequest(
tokenizer,
flaggingRequest{
ModelName: "arbitrary-model-name",
FlattenedPrompt: validPreamble + "legit request",
MaxTokens: 200,
},
cfgWithoutModelsList)
require.NoError(t, err)
require.False(t, result.IsFlagged())
})
t.Run("FlaggedModel", func(t *testing.T) {
result, err := isFlaggedRequest(
tokenizer,
flaggingRequest{
ModelName: flaggedModelNames[0],
FlattenedPrompt: validPreamble + "legit request",
MaxTokens: 200,
},
cfgWithPreamble)
require.NoError(t, err)
require.True(t, result.IsFlagged())
require.Equal(t, 1, len(result.reasons))
assert.Equal(t, "model_used", result.reasons[0])
})
})
}


@@ -108,8 +108,9 @@ func (*GoogleHandlerMethods) validateRequest(_ context.Context, _ log.Logger, fe
func (g *GoogleHandlerMethods) shouldFlagRequest(_ context.Context, _ log.Logger, req googleRequest) (*flaggingResult, error) {
result, err := isFlaggedRequest(
nil, // tokenizer, meaning token counts aren't considered for flagging.
flaggingRequest{
ModelName: req.Model,
FlattenedPrompt: req.BuildPrompt(),
MaxTokens: int(req.GenerationConfig.MaxOutputTokens),
},


@@ -9,6 +9,7 @@ import (
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -122,6 +123,7 @@ func (o *OpenAIHandlerMethods) shouldFlagRequest(_ context.Context, _ log.Logger
result, err := isFlaggedRequest(
nil, /* tokenizer, meaning token counts aren't considered for flagging. */
flaggingRequest{
ModelName: req.Model,
FlattenedPrompt: req.BuildPrompt(),
MaxTokens: int(req.MaxTokens),
},
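
The Google and OpenAI handlers pass a nil tokenizer, so token-count checks are skipped for those providers while the new model-name check still applies. A rough sketch of that pattern with illustrative names (the interface, function name, and the non-"model_used" reason string are assumptions, not the gateway's actual API):

package main

import (
	"fmt"
	"slices"
)

// tokenCounter stands in for the tokenizer that the Google/OpenAI handlers pass as nil.
type tokenCounter interface {
	CountTokens(prompt string) int
}

// flagReasons sketches the shape of isFlaggedRequest: token-based checks only
// run when a tokenizer is available, while the model-name check always runs.
// Only "model_used" is a reason string taken from the diff; the other is made up.
func flagReasons(tk tokenCounter, model, prompt string, flaggedModels []string, promptTokenFlaggingLimit int) []string {
	var reasons []string
	if model != "" && slices.Contains(flaggedModels, model) {
		reasons = append(reasons, "model_used")
	}
	if tk != nil && tk.CountTokens(prompt) > promptTokenFlaggingLimit {
		reasons = append(reasons, "prompt_token_count") // illustrative reason name
	}
	return reasons
}

func main() {
	// With a nil tokenizer, only the model-name check can flag the request.
	fmt.Println(flagReasons(nil, "dangerous-llm-model", "some prompt", []string{"dangerous-llm-model"}, 18000))
	// Output: [model_used]
}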


@@ -150,6 +150,10 @@ type FlaggingConfig struct {
// So MaxTokensToSampleFlaggingLimit should be <= MaxTokensToSample.
MaxTokensToSampleFlaggingLimit int
// FlaggedModelNames is a list of provider-specific model names (e.g. "gpt-3.5")
// that, if used, will lead to the request being flagged.
FlaggedModelNames []string
// ResponseTokenBlockingLimit is the maximum number of tokens we allow before outright blocking
// a response. e.g. the client sends a request desiring a response with 100 max tokens, we will
// block it IFF the ResponseTokenBlockingLimit is less than 100.
@@ -363,6 +367,8 @@ func (c *Config) loadFlaggingConfig(cfg *FlaggingConfig, envVarPrefix string) {
cfg.PromptTokenBlockingLimit = c.GetInt(envVarPrefix+"PROMPT_TOKEN_BLOCKING_LIMIT", "20000", "Maximum number of prompt tokens to allow without blocking.")
cfg.PromptTokenFlaggingLimit = c.GetInt(envVarPrefix+"PROMPT_TOKEN_FLAGGING_LIMIT", "18000", "Maximum number of prompt tokens to allow without flagging.")
cfg.ResponseTokenBlockingLimit = c.GetInt(envVarPrefix+"RESPONSE_TOKEN_BLOCKING_LIMIT", "4000", "Maximum number of completion tokens to allow without blocking.")
cfg.FlaggedModelNames = maybeLoadLowercaseSlice("FLAGGED_MODEL_NAMES", "LLM models that will always lead to the request getting flagged.")
}
// splitMaybe splits the provided string on commas, but returns nil if given the empty string.
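
The helper that loads FLAGGED_MODEL_NAMES is not shown in this diff. Based on the splitMaybe doc comment above, a sketch of how a comma-separated value could become a lowercase slice; the loadFlaggedModelNames name and the trimming/lowercasing details are assumptions, only the comma-splitting and nil-for-empty behavior is documented in the source:

package main

import (
	"fmt"
	"strings"
)

// splitMaybe mirrors the documented behavior: split on commas, but return nil
// for the empty string, so an unset variable yields no flagged models.
func splitMaybe(s string) []string {
	if s == "" {
		return nil
	}
	return strings.Split(s, ",")
}

// loadFlaggedModelNames is a hypothetical stand-in for the loader call in the
// diff: trim and lowercase each entry so comparisons are predictable.
func loadFlaggedModelNames(envValue string) []string {
	parts := splitMaybe(envValue)
	for i, p := range parts {
		parts[i] = strings.ToLower(strings.TrimSpace(p))
	}
	return parts
}

func main() {
	fmt.Println(loadFlaggedModelNames("Dangerous-LLM-Model, gpt-3.5-turbo")) // [dangerous-llm-model gpt-3.5-turbo]
	fmt.Println(loadFlaggedModelNames(""))                                   // []
}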