Mirror of https://github.com/sourcegraph/sourcegraph.git, synced 2026-02-06 17:31:43 +00:00
feat(cody-gateway): Add FLAGGED_MODEL_NAMES check (#63013)
* Cody Gateway: Add FLAGGED_MODEL_NAMES check
* Update cmd/cody-gateway/internal/httpapi/completions/flagging.go

Co-authored-by: Quinn Slack <quinn@slack.org>
This commit is contained in:
parent 5833a98185
commit c4b5c73260
@@ -128,6 +128,7 @@ func (a *AnthropicHandlerMethods) validateRequest(ctx context.Context, logger lo
 func (a *AnthropicHandlerMethods) shouldFlagRequest(ctx context.Context, logger log.Logger, ar anthropicRequest) (*flaggingResult, error) {
 	result, err := isFlaggedRequest(a.anthropicTokenizer,
 		flaggingRequest{
+			ModelName:       ar.Model,
 			FlattenedPrompt: ar.Prompt,
 			MaxTokens:       int(ar.MaxTokensToSample),
 		},
@@ -9,6 +9,7 @@ import (
 	"strings"

 	"github.com/sourcegraph/log"

+	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
 	"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -172,6 +173,7 @@ func (a *AnthropicMessagesHandlerMethods) validateRequest(ctx context.Context, l
 func (a *AnthropicMessagesHandlerMethods) shouldFlagRequest(ctx context.Context, logger log.Logger, ar anthropicMessagesRequest) (*flaggingResult, error) {
 	result, err := isFlaggedRequest(a.tokenizer,
 		flaggingRequest{
+			ModelName:       ar.Model,
 			FlattenedPrompt: ar.BuildPrompt(),
 			MaxTokens:       int(ar.MaxTokens),
 		},
@@ -2,6 +2,7 @@ package completions

 import (
 	"context"
+	"slices"
 	"strings"

 	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
@@ -24,6 +25,10 @@ type flaggingConfig struct {
 	MaxTokensToSampleFlaggingLimit int
 	ResponseTokenBlockingLimit     int

+	// FlaggedModelNames is a slice of LLM model names, e.g. "gpt-3.5-turbo",
+	// that will lead to the request getting flagged.
+	FlaggedModelNames []string
+
 	// If false, flaggingResult.shouldBlock will always be false when returned by isFlaggedRequest.
 	RequestBlockingEnabled bool
 }
@@ -39,11 +44,15 @@ func makeFlaggingConfig(cfg config.FlaggingConfig) flaggingConfig {
 		PromptTokenBlockingLimit:       cfg.PromptTokenBlockingLimit,
 		MaxTokensToSampleFlaggingLimit: cfg.MaxTokensToSampleFlaggingLimit,
 		ResponseTokenBlockingLimit:     cfg.ResponseTokenBlockingLimit,
+		FlaggedModelNames:              cfg.FlaggedModelNames,
 		RequestBlockingEnabled:         cfg.RequestBlockingEnabled,
 	}
 }

 type flaggingRequest struct {
+	// ModelName is the slug for the specific LLM model,
+	// e.g. "llama-v2-13b-code".
+	ModelName       string
 	FlattenedPrompt string
 	MaxTokens       int
 }
@@ -64,6 +73,10 @@ func isFlaggedRequest(tk tokenizer.Tokenizer, r flaggingRequest, cfg flaggingCon
 	var reasons []string
 	prompt := strings.ToLower(r.FlattenedPrompt)

+	if r.ModelName != "" && slices.Contains(cfg.FlaggedModelNames, r.ModelName) {
+		reasons = append(reasons, "model_used")
+	}
+
 	if hasValidPattern, _ := containsAny(prompt, cfg.AllowedPromptPatterns); len(cfg.AllowedPromptPatterns) > 0 && !hasValidPattern {
 		reasons = append(reasons, "unknown_prompt")
 	}
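
The added check above is a plain membership test against the configured list. A minimal standalone sketch of just that logic — the function and variable names here are illustrative, not part of the Gateway's actual API surface:

	package main

	import (
		"fmt"
		"slices"
	)

	// shouldFlagModel mirrors the new check in isFlaggedRequest: an empty
	// model name is never flagged; otherwise the name is looked up in the
	// configured flagged-model list.
	func shouldFlagModel(modelName string, flaggedModelNames []string) bool {
		return modelName != "" && slices.Contains(flaggedModelNames, modelName)
	}

	func main() {
		flagged := []string{"dangerous-llm-model"}
		fmt.Println(shouldFlagModel("dangerous-llm-model", flagged)) // true
		fmt.Println(shouldFlagModel("claude-instant-1", flagged))    // false
		fmt.Println(shouldFlagModel("", flagged))                    // false
	}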
@@ -40,12 +40,14 @@ func TestMakeFlaggingConfig(t *testing.T) {

 func TestIsFlaggedRequest(t *testing.T) {
 	validPreamble := "You are cody-gateway."
+	flaggedModelNames := []string{"dangerous-llm-model"}

 	basicCfg := flaggingConfig{
 		PromptTokenFlaggingLimit:       18000,
 		PromptTokenBlockingLimit:       20000,
 		MaxTokensToSampleFlaggingLimit: 4000,
 		ResponseTokenBlockingLimit:     4000,
+		FlaggedModelNames:              flaggedModelNames,
 		RequestBlockingEnabled:         true,
 	}
 	cfgWithPreamble := flaggingConfig{
@@ -54,6 +56,7 @@ func TestIsFlaggedRequest(t *testing.T) {
 		MaxTokensToSampleFlaggingLimit: 4000,
 		ResponseTokenBlockingLimit:     4000,
 		RequestBlockingEnabled:         true,
+		FlaggedModelNames:              flaggedModelNames,
 		AllowedPromptPatterns:          []string{strings.ToLower(validPreamble)},
 	}
@@ -67,6 +70,7 @@ func TestIsFlaggedRequest(t *testing.T) {
 		return isFlaggedRequest(
 			tokenizer,
 			flaggingRequest{
+				ModelName:       "random-model-name",
 				FlattenedPrompt: prompt,
 				MaxTokens:       200,
 			},
@@ -202,4 +206,54 @@ func TestIsFlaggedRequest(t *testing.T) {
 			assert.Greater(t, result.promptTokenCount, tokenCountConfig.PromptTokenBlockingLimit)
 		})
 	})
+
+	t.Run("ModelNameFlagging", func(t *testing.T) {
+		t.Run("NotSet", func(t *testing.T) {
+			result, err := isFlaggedRequest(
+				tokenizer,
+				flaggingRequest{
+					ModelName:       "", // Test that this is OK.
+					FlattenedPrompt: validPreamble + "legit request",
+					MaxTokens:       200,
+				},
+				cfgWithPreamble)
+
+			require.NoError(t, err)
+			require.False(t, result.IsFlagged())
+		})
+
+		t.Run("NotConfigured", func(t *testing.T) {
+			// If no models are listed in the config, the request should not be flagged.
+			cfgWithoutModelsList := cfgWithPreamble // copy
+			cfgWithoutModelsList.FlaggedModelNames = nil
+
+			result, err := isFlaggedRequest(
+				tokenizer,
+				flaggingRequest{
+					ModelName:       "arbitrary-model-name",
+					FlattenedPrompt: validPreamble + "legit request",
+					MaxTokens:       200,
+				},
+				cfgWithoutModelsList)
+
+			require.NoError(t, err)
+			require.False(t, result.IsFlagged())
+		})
+
+		t.Run("FlaggedModel", func(t *testing.T) {
+			result, err := isFlaggedRequest(
+				tokenizer,
+				flaggingRequest{
+					ModelName:       flaggedModelNames[0],
+					FlattenedPrompt: validPreamble + "legit request",
+					MaxTokens:       200,
+				},
+				cfgWithPreamble)
+
+			require.NoError(t, err)
+			require.True(t, result.IsFlagged())
+			require.Equal(t, 1, len(result.reasons))
+			assert.Equal(t, "model_used", result.reasons[0])
+		})
+	})
 }
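
The three new subtests cover an empty model name, an unset FlaggedModelNames list, and a configured match. Assuming the test file sits alongside flagging.go in the completions package (the diff view omits file names), the new cases can be run in isolation with:

	go test ./cmd/cody-gateway/internal/httpapi/completions/ -run 'TestIsFlaggedRequest/ModelNameFlagging' -v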
@@ -108,8 +108,9 @@ func (*GoogleHandlerMethods) validateRequest(_ context.Context, _ log.Logger, fe

 func (g *GoogleHandlerMethods) shouldFlagRequest(_ context.Context, _ log.Logger, req googleRequest) (*flaggingResult, error) {
 	result, err := isFlaggedRequest(
-		nil, /* tokenizer, meaning token counts aren't considered when for flagging consideration. */
+		nil, // tokenizer; left nil, so token counts aren't considered for flagging.
 		flaggingRequest{
+			ModelName:       req.Model,
 			FlattenedPrompt: req.BuildPrompt(),
 			MaxTokens:       int(req.GenerationConfig.MaxOutputTokens),
 		},
@@ -9,6 +9,7 @@ import (
 	"strings"

 	"github.com/sourcegraph/log"

+	"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
 	"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
 	"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -122,6 +123,7 @@ func (o *OpenAIHandlerMethods) shouldFlagRequest(_ context.Context, _ log.Logger
 	result, err := isFlaggedRequest(
 		nil, /* tokenizer; left nil, so token counts aren't considered for flagging. */
 		flaggingRequest{
+			ModelName:       req.Model,
 			FlattenedPrompt: req.BuildPrompt(),
 			MaxTokens:       int(req.MaxTokens),
 		},
@@ -150,6 +150,10 @@ type FlaggingConfig struct {
 	// So MaxTokensToSampleFlaggingLimit should be <= MaxTokensToSample.
 	MaxTokensToSampleFlaggingLimit int

+	// FlaggedModelNames is a list of provider-specific model names (e.g. "gpt-3.5")
+	// that, if used, will lead to the request being flagged.
+	FlaggedModelNames []string
+
 	// ResponseTokenBlockingLimit is the maximum number of tokens we allow before outright blocking
 	// a response. e.g. the client sends a request desiring a response with 100 max tokens, we will
 	// block it IFF the ResponseTokenBlockingLimit is less than 100.
@@ -363,6 +367,8 @@ func (c *Config) loadFlaggingConfig(cfg *FlaggingConfig, envVarPrefix string) {
 	cfg.PromptTokenBlockingLimit = c.GetInt(envVarPrefix+"PROMPT_TOKEN_BLOCKING_LIMIT", "20000", "Maximum number of prompt tokens to allow without blocking.")
 	cfg.PromptTokenFlaggingLimit = c.GetInt(envVarPrefix+"PROMPT_TOKEN_FLAGGING_LIMIT", "18000", "Maximum number of prompt tokens to allow without flagging.")
 	cfg.ResponseTokenBlockingLimit = c.GetInt(envVarPrefix+"RESPONSE_TOKEN_BLOCKING_LIMIT", "4000", "Maximum number of completion tokens to allow without blocking.")
+
+	cfg.FlaggedModelNames = maybeLoadLowercaseSlice("FLAGGED_MODEL_NAMES", "LLM models that will always lead to the request getting flagged.")
 }

 // splitMaybe splits the provided string on commas, but returns nil if given the empty string.
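
The splitMaybe body is cut off in this view; a minimal implementation consistent with its doc comment — a sketch, not the verbatim source — would be:

	// splitMaybe splits the provided string on commas, but returns nil if
	// given the empty string. (Reconstructed from the doc comment above;
	// requires the "strings" import.)
	func splitMaybe(input string) []string {
		if input == "" {
			return nil
		}
		return strings.Split(input, ",")
	}

With loading wired up this way, an operator would set something like FLAGGED_MODEL_NAMES="gpt-3.5-turbo,dangerous-llm-model", and — assuming maybeLoadLowercaseSlice lowercases the comma-separated value, as its name suggests — any request naming one of those models would be flagged with reason "model_used".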