Detect patterns in prompt (#60547)

WIP on detection
This commit is contained in:
Rafał Gajdulewicz 2024-02-15 11:31:58 +01:00 committed by GitHub
parent 6ac54905a9
commit 3e15c63b3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 60 additions and 1 deletions

View File

@ -32,6 +32,7 @@ go_library(
"//lib/errors",
"@com_github_grafana_regexp//:regexp",
"@com_github_sourcegraph_log//:log",
"@com_github_twin_go_away//:go-away",
"@io_opentelemetry_go_otel//attribute",
"@io_opentelemetry_go_otel//codes",
"@io_opentelemetry_go_otel_trace//:trace",

View File

@ -135,6 +135,7 @@ func NewAnthropicHandler(
// user.
2, // seconds
autoFlushStreamingResponses,
config.DetectedPromptPatterns,
), nil
}
@ -162,6 +163,10 @@ func (ar anthropicRequest) GetModel() string {
return ar.Model
}
func (ar anthropicRequest) BuildPrompt() string {
return ar.Prompt
}
type anthropicTokenCount struct {
count int
err error

View File

@ -7,6 +7,7 @@ import (
"io"
"math/rand"
"net/http"
"strings"
"github.com/sourcegraph/log"
@ -58,6 +59,7 @@ func NewFireworksHandler(
// do any retries
30, // seconds
autoFlushStreamingResponses,
nil,
)
}
@ -84,6 +86,17 @@ func (fr fireworksRequest) GetModel() string {
return fr.Model
}
func (fr fireworksRequest) BuildPrompt() string {
if fr.Prompt != "" {
return fr.Prompt
}
var sb strings.Builder
for _, m := range fr.Messages {
sb.WriteString(m.Content + "\n")
}
return sb.String()
}
type message struct {
Role string `json:"role"`
Content string `json:"content"`

View File

@ -6,10 +6,10 @@ import (
"encoding/json"
"io"
"net/http"
"strings"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -49,6 +49,7 @@ func NewOpenAIHandler(
// help in a minute-long rate limit window.
30, // seconds
autoFlushStreamingResponses,
nil,
)
}
@ -81,6 +82,14 @@ func (r openaiRequest) GetModel() string {
return r.Model
}
func (r openaiRequest) BuildPrompt() string {
var sb strings.Builder
for _, m := range r.Messages {
sb.WriteString(m.Content + "\n")
}
return sb.String()
}
type openaiUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`

View File

@ -12,6 +12,7 @@ import (
"strings"
"time"
goaway "github.com/TwiN/go-away"
"github.com/sourcegraph/log"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
@ -92,6 +93,8 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] interface {
type UpstreamRequest interface {
GetModel() string
ShouldStream() bool
// BuildPrompt returns the aggregated prompt (either full prompt as generated by Client, or all messages concatenated)
BuildPrompt() string
}
func makeUpstreamHandler[ReqT UpstreamRequest](
@ -115,6 +118,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// response.
defaultRetryAfterSeconds int,
autoFlushStreamingResponses bool,
patternsToDetect []string,
) http.Handler {
baseLogger = baseLogger.Scoped(upstreamName).
// This URL is used only for logging reason so we default to the chat endpoint
@ -249,6 +253,15 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// Retrieve metadata from the initial request.
model, requestMetadata := methods.getRequestMetadata(body)
prompt := body.BuildPrompt()
if goaway.IsProfane(prompt) {
requestMetadata["is_profane"] = true
}
for _, p := range patternsToDetect {
if strings.Contains(prompt, p) {
requestMetadata["detected_phrases"] = true
}
}
// Match the model against the allowlist of models, which are configured
// with the Cody Gateway model format "$PROVIDER/$MODEL_NAME". Models

View File

@ -258,3 +258,9 @@ func gaugeHandler(counter metric.Int64UpDownCounter, attrs attribute.Set, handle
counter.Add(context.Background(), -1, metric.WithAttributeSet(attrs))
})
}
type CompletionsConfig struct {
logger log.Logger
eventLogger events.Logger
rs limiter.RedisStore
}

View File

@ -72,6 +72,7 @@ type AnthropicConfig struct {
AccessToken string
MaxTokensToSample int
AllowedPromptPatterns []string
DetectedPromptPatterns []string
RequestBlockingEnabled bool
}
@ -135,6 +136,7 @@ func (c *Config) Load() {
}
c.Anthropic.MaxTokensToSample = c.GetInt("CODY_GATEWAY_ANTHROPIC_MAX_TOKENS_TO_SAMPLE", "10000", "Maximum permitted value of maxTokensToSample")
c.Anthropic.AllowedPromptPatterns = splitMaybe(c.GetOptional("CODY_GATEWAY_ANTHROPIC_ALLOWED_PROMPT_PATTERNS", "Prompt patterns to allow."))
c.Anthropic.DetectedPromptPatterns = splitMaybe(c.GetOptional("CODY_GATEWAY_ANTHROPIC_DETECTED_PROMPT_PATTERNS", "Patterns to detect in prompt."))
c.Anthropic.RequestBlockingEnabled = c.GetBool("CODY_GATEWAY_ANTHROPIC_REQUEST_BLOCKING_ENABLED", "false", "Whether we should block requests that match our blocking criteria.")
c.OpenAI.AccessToken = c.GetOptional("CODY_GATEWAY_OPENAI_ACCESS_TOKEN", "The OpenAI access token to be used.")

View File

@ -5460,6 +5460,13 @@ def go_dependencies():
sum = "h1:Y/M5lygoNPKwVNLMPXgVfsRT40CSFKXCxuU8LoHySjs=",
version = "v0.0.0-20230623042737-f9a4f7ef6531",
)
go_repository(
name = "com_github_twin_go_away",
build_file_proto_mode = "disable_global",
importpath = "github.com/TwiN/go-away",
sum = "h1:80AjDyeTjfQaSFYbALzRcDKMAmxKW0a5PoxwXKZlW2A=",
version = "v1.6.12",
)
go_repository(
name = "com_github_uber_gonduit",
build_file_proto_mode = "disable_global",

1
go.mod
View File

@ -250,6 +250,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.4.1
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.4.0
github.com/TwiN/go-away v1.6.12
github.com/aws/constructs-go/constructs/v10 v10.2.69
github.com/aws/jsii-runtime-go v1.84.0
github.com/dghubble/gologin/v2 v2.4.0

2
go.sum
View File

@ -152,6 +152,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko
github.com/RoaringBitmap/roaring v1.3.0 h1:aQmu9zQxDU0uhwR8SXOH/OrqEf+X8A0LQmwW3JX8Lcg=
github.com/RoaringBitmap/roaring v1.3.0/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE=
github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ=
github.com/TwiN/go-away v1.6.12 h1:80AjDyeTjfQaSFYbALzRcDKMAmxKW0a5PoxwXKZlW2A=
github.com/TwiN/go-away v1.6.12/go.mod h1:MpvIC9Li3minq+CGgbgUDvQ9tDaeW35k5IXZrF9MVas=
github.com/XSAM/otelsql v0.27.0 h1:i9xtxtdcqXV768a5C6SoT/RkG+ue3JTOgkYInzlTOqs=
github.com/XSAM/otelsql v0.27.0/go.mod h1:0mFB3TvLa7NCuhm/2nU7/b2wEtsczkj8Rey8ygO7V+A=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=