mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 19:51:50 +00:00
parent
6ac54905a9
commit
3e15c63b3d
@ -32,6 +32,7 @@ go_library(
|
||||
"//lib/errors",
|
||||
"@com_github_grafana_regexp//:regexp",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
"@com_github_twin_go_away//:go-away",
|
||||
"@io_opentelemetry_go_otel//attribute",
|
||||
"@io_opentelemetry_go_otel//codes",
|
||||
"@io_opentelemetry_go_otel_trace//:trace",
|
||||
|
||||
@ -135,6 +135,7 @@ func NewAnthropicHandler(
|
||||
// user.
|
||||
2, // seconds
|
||||
autoFlushStreamingResponses,
|
||||
config.DetectedPromptPatterns,
|
||||
), nil
|
||||
}
|
||||
|
||||
@ -162,6 +163,10 @@ func (ar anthropicRequest) GetModel() string {
|
||||
return ar.Model
|
||||
}
|
||||
|
||||
func (ar anthropicRequest) BuildPrompt() string {
|
||||
return ar.Prompt
|
||||
}
|
||||
|
||||
type anthropicTokenCount struct {
|
||||
count int
|
||||
err error
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
@ -58,6 +59,7 @@ func NewFireworksHandler(
|
||||
// do any retries
|
||||
30, // seconds
|
||||
autoFlushStreamingResponses,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
@ -84,6 +86,17 @@ func (fr fireworksRequest) GetModel() string {
|
||||
return fr.Model
|
||||
}
|
||||
|
||||
func (fr fireworksRequest) BuildPrompt() string {
|
||||
if fr.Prompt != "" {
|
||||
return fr.Prompt
|
||||
}
|
||||
var sb strings.Builder
|
||||
for _, m := range fr.Messages {
|
||||
sb.WriteString(m.Content + "\n")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
|
||||
@ -6,10 +6,10 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/shared/config"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
|
||||
@ -49,6 +49,7 @@ func NewOpenAIHandler(
|
||||
// help in a minute-long rate limit window.
|
||||
30, // seconds
|
||||
autoFlushStreamingResponses,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
@ -81,6 +82,14 @@ func (r openaiRequest) GetModel() string {
|
||||
return r.Model
|
||||
}
|
||||
|
||||
func (r openaiRequest) BuildPrompt() string {
|
||||
var sb strings.Builder
|
||||
for _, m := range r.Messages {
|
||||
sb.WriteString(m.Content + "\n")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type openaiUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
goaway "github.com/TwiN/go-away"
|
||||
"github.com/sourcegraph/log"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
@ -92,6 +93,8 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] interface {
|
||||
type UpstreamRequest interface {
|
||||
GetModel() string
|
||||
ShouldStream() bool
|
||||
// BuildPrompt returns the aggregated prompt (either full prompt as generated by Client, or all messages concatenated)
|
||||
BuildPrompt() string
|
||||
}
|
||||
|
||||
func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
@ -115,6 +118,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
// response.
|
||||
defaultRetryAfterSeconds int,
|
||||
autoFlushStreamingResponses bool,
|
||||
patternsToDetect []string,
|
||||
) http.Handler {
|
||||
baseLogger = baseLogger.Scoped(upstreamName).
|
||||
// This URL is used only for logging reason so we default to the chat endpoint
|
||||
@ -249,6 +253,15 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
|
||||
// Retrieve metadata from the initial request.
|
||||
model, requestMetadata := methods.getRequestMetadata(body)
|
||||
prompt := body.BuildPrompt()
|
||||
if goaway.IsProfane(prompt) {
|
||||
requestMetadata["is_profane"] = true
|
||||
}
|
||||
for _, p := range patternsToDetect {
|
||||
if strings.Contains(prompt, p) {
|
||||
requestMetadata["detected_phrases"] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Match the model against the allowlist of models, which are configured
|
||||
// with the Cody Gateway model format "$PROVIDER/$MODEL_NAME". Models
|
||||
|
||||
@ -258,3 +258,9 @@ func gaugeHandler(counter metric.Int64UpDownCounter, attrs attribute.Set, handle
|
||||
counter.Add(context.Background(), -1, metric.WithAttributeSet(attrs))
|
||||
})
|
||||
}
|
||||
|
||||
type CompletionsConfig struct {
|
||||
logger log.Logger
|
||||
eventLogger events.Logger
|
||||
rs limiter.RedisStore
|
||||
}
|
||||
|
||||
@ -72,6 +72,7 @@ type AnthropicConfig struct {
|
||||
AccessToken string
|
||||
MaxTokensToSample int
|
||||
AllowedPromptPatterns []string
|
||||
DetectedPromptPatterns []string
|
||||
RequestBlockingEnabled bool
|
||||
}
|
||||
|
||||
@ -135,6 +136,7 @@ func (c *Config) Load() {
|
||||
}
|
||||
c.Anthropic.MaxTokensToSample = c.GetInt("CODY_GATEWAY_ANTHROPIC_MAX_TOKENS_TO_SAMPLE", "10000", "Maximum permitted value of maxTokensToSample")
|
||||
c.Anthropic.AllowedPromptPatterns = splitMaybe(c.GetOptional("CODY_GATEWAY_ANTHROPIC_ALLOWED_PROMPT_PATTERNS", "Prompt patterns to allow."))
|
||||
c.Anthropic.DetectedPromptPatterns = splitMaybe(c.GetOptional("CODY_GATEWAY_ANTHROPIC_DETECTED_PROMPT_PATTERNS", "Patterns to detect in prompt."))
|
||||
c.Anthropic.RequestBlockingEnabled = c.GetBool("CODY_GATEWAY_ANTHROPIC_REQUEST_BLOCKING_ENABLED", "false", "Whether we should block requests that match our blocking criteria.")
|
||||
|
||||
c.OpenAI.AccessToken = c.GetOptional("CODY_GATEWAY_OPENAI_ACCESS_TOKEN", "The OpenAI access token to be used.")
|
||||
|
||||
7
deps.bzl
7
deps.bzl
@ -5460,6 +5460,13 @@ def go_dependencies():
|
||||
sum = "h1:Y/M5lygoNPKwVNLMPXgVfsRT40CSFKXCxuU8LoHySjs=",
|
||||
version = "v0.0.0-20230623042737-f9a4f7ef6531",
|
||||
)
|
||||
go_repository(
|
||||
name = "com_github_twin_go_away",
|
||||
build_file_proto_mode = "disable_global",
|
||||
importpath = "github.com/TwiN/go-away",
|
||||
sum = "h1:80AjDyeTjfQaSFYbALzRcDKMAmxKW0a5PoxwXKZlW2A=",
|
||||
version = "v1.6.12",
|
||||
)
|
||||
go_repository(
|
||||
name = "com_github_uber_gonduit",
|
||||
build_file_proto_mode = "disable_global",
|
||||
|
||||
1
go.mod
1
go.mod
@ -250,6 +250,7 @@ require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.4.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.4.0
|
||||
github.com/TwiN/go-away v1.6.12
|
||||
github.com/aws/constructs-go/constructs/v10 v10.2.69
|
||||
github.com/aws/jsii-runtime-go v1.84.0
|
||||
github.com/dghubble/gologin/v2 v2.4.0
|
||||
|
||||
2
go.sum
2
go.sum
@ -152,6 +152,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko
|
||||
github.com/RoaringBitmap/roaring v1.3.0 h1:aQmu9zQxDU0uhwR8SXOH/OrqEf+X8A0LQmwW3JX8Lcg=
|
||||
github.com/RoaringBitmap/roaring v1.3.0/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE=
|
||||
github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ=
|
||||
github.com/TwiN/go-away v1.6.12 h1:80AjDyeTjfQaSFYbALzRcDKMAmxKW0a5PoxwXKZlW2A=
|
||||
github.com/TwiN/go-away v1.6.12/go.mod h1:MpvIC9Li3minq+CGgbgUDvQ9tDaeW35k5IXZrF9MVas=
|
||||
github.com/XSAM/otelsql v0.27.0 h1:i9xtxtdcqXV768a5C6SoT/RkG+ue3JTOgkYInzlTOqs=
|
||||
github.com/XSAM/otelsql v0.27.0/go.mod h1:0mFB3TvLa7NCuhm/2nU7/b2wEtsczkj8Rey8ygO7V+A=
|
||||
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
|
||||
|
||||
Loading…
Reference in New Issue
Block a user