self hosted models (#63899)

This PR is stacked on top of all the prior work @chrsmith has done for
shuffling configuration data around; it implements the new "Self hosted
models" functionality.

## Configuration

Configuring a Sourcegraph instance to use self-hosted models involves
adding configuration like this to the site config (setting
`modelConfiguration` opts you in to the new system, which is in early
access):

```jsonc
  // Setting this field means we are opting into the new Cody model configuration system.
  "modelConfiguration": {
    // Disable use of Sourcegraph's servers for model discovery
    "sourcegraph": null,

    // Create two model providers
    "providerOverrides": [
      {
        // Our first model provider "mistral" will be a Huggingface TGI deployment which hosts our
        // mistral model for chat functionality.
        "id": "mistral",
        "displayName": "Mistral",
        "serverSideConfig": {
          "type": "huggingface-tgi",
          "endpoints": [{"url": "https://mistral.example.com/v1"}]
        }
      },
      {
        // Our second model provider "bigcode" will be a Huggingface TGI deployment which hosts our
        // bigcode/starcoder model for code completion functionality.
        "id": "bigcode",
        "displayName": "Bigcode",
        "serverSideConfig": {
          "type": "huggingface-tgi",
          "endpoints": [{"url": "http://starcoder.example.com/v1"}]
        }
      }
    ],

    // Make these two models available to Cody users
    "modelOverridesRecommendedSettings": [
      "mistral::v1::mixtral-8x7b-instruct",
      "bigcode::v1::starcoder2-7b"
    ],

    // Configure which models Cody will use by default
    "defaultModels": {
      "chat": "mistral::v1::mixtral-8x7b-instruct",
      "fastChat": "mistral::v1::mixtral-8x7b-instruct",
      "codeCompletion": "bigcode::v1::starcoder2-7b"
    }
  }
```

More advanced configurations are possible; the above is our blessed
configuration for today.
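
For illustration, here is a rough sketch of what a more advanced setup could
look like, defining a model explicitly under `modelOverrides` (inside the same
`modelConfiguration` block) rather than relying on
`modelOverridesRecommendedSettings`, and using the generic `openaicompatible`
provider type plus the per-model `apiModel` override. The provider and model
`serverSideConfig` fields are the ones added in this PR; the provider id, model
name, endpoint URL, and the exact field names inside the `modelOverrides` entry
are illustrative assumptions:

```jsonc
// Illustrative sketch only - an explicitly-defined model served by a generic
// OpenAI-compatible API (provider id, model name, and URL are hypothetical).
"providerOverrides": [
  {
    "id": "other",
    "displayName": "Other OpenAI-compatible API",
    "serverSideConfig": {
      "type": "openaicompatible",
      "endpoints": [{"url": "https://llm.example.com/v1", "accessToken": "REDACTED"}],
      "enableVerboseLogs": false
    }
  }
],
"modelOverrides": [
  {
    // Field names inside this entry follow the new model configuration schema
    // and are shown here as an assumption.
    "modelRef": "other::v1::my-model",
    "displayName": "My Model",
    "modelName": "my-model",
    "contextWindow": { "maxInputTokens": 8192, "maxOutputTokens": 4096 },
    "serverSideConfig": {
      // Some OpenAI-compatible gateways require a fixed or deployment-specific
      // value in the API's "model" field; apiModel is sent there verbatim.
      "type": "openaicompatible",
      "apiModel": "tgi"
    }
  }
],
"defaultModels": {
  "chat": "other::v1::my-model"
}
```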

## Hosting models

Another major component of this work is starting to build up
recommendations around how to self-host models, which ones to use, how
to configure them, etc.

For now, we've been testing with the following on a machine with dual A100s:

* Huggingface TGI (a Docker container for model inference which provides an
  OpenAI-compatible API, and is widely popular)
* Two models:
  * Starcoder2 for code completion; specifically `bigcode/starcoder2-15b`
    with `eetq` 8-bit quantization.
  * Mixtral 8x7B Instruct for chat; specifically
    `casperhansen/mixtral-instruct-awq`, which uses `awq` 4-bit quantization.

This is our 'starter' configuration. Other models - specifically other
Starcoder2 and Mixtral Instruct variants - certainly work too, and
higher-parameter versions may of course provide better results.

Documentation on how to deploy Huggingface TGI, suggested configuration,
and debugging tips is coming soon.

## Advanced configuration

As part of this effort, I have added a fairly extensive set of
configuration knobs to the client-side model configuration (see `type
ClientSideModelConfigOpenAICompatible` in this PR).

Some of these configuration options are needed for things to work at a
basic level, while others (e.g. prompt customization) are not strictly
required, but are very important for customers interested in
self-hosting their own models.
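
As a rough illustration, a model override's `clientSideConfig` for an
`openaicompatible`-provided model could look like the following. The field
names come from the `ClientSideModelConfigOpenAICompatible` schema added in
this PR; the values themselves, and the placement under a `modelOverrides`
entry, are illustrative rather than recommendations:

```jsonc
// Illustrative values only - these are knobs, not recommendations.
"clientSideConfig": {
  "openaicompatible": {
    // Generation stop/end markers used by the model.
    "stopSequences": ["<|endoftext|>", "<file_sep>"],
    "endOfText": "<|endoftext|>",

    // Hints for how much context the client should gather, in characters.
    "contextSizeHintTotalCharacters": 8192,
    "contextSizeHintPrefixCharacters": 4096,
    "contextSizeHintSuffixCharacters": 2048,

    // Prompt customization.
    "chatPreInstruction": "Answer all questions in Spanish.",
    "editPostInstruction": "Write all unit tests with Jest instead of detected framework.",

    // Autocomplete timeouts, in milliseconds.
    "autocompleteSinglelineTimeout": 5000,
    "autocompleteMultilineTimeout": 10000,

    // Per-feature sampling parameters.
    "chatTemperature": 0.2,
    "chatMaxTokens": 4000,
    "autoCompleteTemperature": 0.2,
    "autoCompleteSinglelineMaxTokens": 256,
    "autoCompleteMultilineMaxTokens": 512,
    "editTemperature": 0.2,
    "editMaxTokens": 4000
  }
}
```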

Today, Cody clients have a number of different _autocomplete provider
implementations_, each of which ties the model-specific logic that
enables autocomplete to a particular provider. For example, if you use a
GPT model through Azure OpenAI, the autocomplete provider for that is
entirely different from the one you'd get using a GPT model through
OpenAI directly. This can lead to some subtle issues for us, so it is
worth exploring a _generalized autocomplete provider_. Since self-hosted
models force us to address this problem, these configuration knobs fed
to the client from the server are a pathway to doing that - initially
just for self-hosted models, but possibly generalized to other providers
in the future.

## Debugging facilities

Working with customers on OpenAI-compatible APIs in the past, we've
learned that debugging can be quite a pain: if you can't see what
requests the Sourcegraph backend is making and what it is getting back,
it is very hard to figure out what's going wrong.

This PR implements quite extensive logging, and a `debugConnections`
flag which can be turned on to enable logging of the actual request
payloads and responses. This is critical when a customer is trying to
add support for a new model, their own custom OpenAI API service, etc.
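
In the provider schema added in this PR, this surfaces as the
`enableVerboseLogs` field on the provider's `serverSideConfig`. A rough sketch
of turning it on for the `huggingface-tgi` provider from the example above
(the endpoint URL is the same illustrative one):

```jsonc
"serverSideConfig": {
  "type": "huggingface-tgi",
  "endpoints": [{"url": "https://mistral.example.com/v1"}],
  // With this enabled, grep the frontend container logs for "OpenAICompatible"
  // to see the requests Cody sends to the endpoint and the responses it gets back.
  "enableVerboseLogs": true
}
```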

## Robustness

Working with customers in the past, we also learned that various parts
of our backend `openai` provider were not very robust. For example, [if
more than one message was present it was a fatal
error](https://github.com/sourcegraph/sourcegraph/blob/main/internal/completions/client/openai/openai.go#L305),
and if the SSE stream yielded `{"error"}` payloads, they were silently
ignored. Similarly, the SSE event stream parser we use is heavily
tailored towards [the exact response
structure](https://github.com/sourcegraph/sourcegraph/blob/main/internal/completions/client/openai/decoder.go#L15-L19)
which OpenAI's official API returns, and is therefore quite brittle when
connecting to a different SSE stream.

For this work, I have _started by forking_ our
`internal/completions/client/openai` client and made a number of major
improvements to make it more robust, handle errors better, etc.

I have also replaced the custom SSE event stream parser - which was not
spec compliant and was brittle - with a proper SSE event stream parser
that recently appeared in the Go community:
https://github.com/tmaxmax/go-sse
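
For a sense of how this fits together, here is a minimal, self-contained
sketch (simplified from the new client in this PR, not a drop-in piece of it)
of consuming an OpenAI-compatible SSE stream with go-sse - surfacing inline
`{"error": ...}` payloads and treating a clean `io.EOF` close as success:

```go
// Rough sketch, simplified from the new client: stream deltas from an
// OpenAI-compatible /chat/completions SSE response using go-sse.
package sketch

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	sse "github.com/tmaxmax/go-sse"
)

// chunk is the subset of an OpenAI-compatible streaming payload we care about.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Error string `json:"error"` // some servers report errors inline on the stream
}

func streamCompletion(ctx context.Context, req *http.Request, onDelta func(string)) error {
	client := &sse.Client{
		HTTPClient:        http.DefaultClient,
		ResponseValidator: sse.DefaultValidator,
		// Completion streams are short-lived and one-shot; don't retry/backoff.
		Backoff: sse.Backoff{MaxRetries: -1},
	}
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var streamErr error
	conn := client.NewConnection(req.WithContext(ctx))
	unsubscribe := conn.SubscribeMessages(func(ev sse.Event) {
		if !strings.HasPrefix(ev.Data, "{") {
			return // ignore non-JSON data such as the "[DONE]" sentinel
		}
		var c chunk
		if err := json.Unmarshal([]byte(ev.Data), &c); err != nil {
			streamErr = fmt.Errorf("decoding event: %w (body: %s)", err, ev.Data)
			cancel()
			return
		}
		if c.Error != "" {
			streamErr = fmt.Errorf("SSE error payload: %s", c.Error)
			cancel()
			return
		}
		for _, choice := range c.Choices {
			onDelta(choice.Delta.Content)
			if choice.FinishReason != "" {
				cancel() // finish reason present: end of the stream
				return
			}
		}
	})
	defer unsubscribe()

	err := conn.Connect()
	// go-sse expects long-lived connections, so it reports io.EOF when the server
	// closes the stream; for a short-lived completion that is a successful finish.
	if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
		err = nil
	}
	if streamErr != nil {
		return streamErr
	}
	return err
}
```

Disabling retries (`MaxRetries: -1`) sidesteps a go-sse retry edge case and is
fine here because these completion streams are short-lived, one-shot requests.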

My intention is that, after more extensive testing, this new
`internal/completions/client/openaicompatible` provider will be more
robust, more correct, and all-around better than
`internal/completions/client/openai` (and possibly the Azure one), so
that we can supersede those with this new `openaicompatible` provider
entirely.

## Client implementation

Much of the work done in this PR is just "let the site admin configure
things, and broadcast that config to the client through the new model
config system."

Actually getting the clients to respect the new configuration is a task
I am tackling in future `sourcegraph/cody` PRs.

## Test plan

1. This change currently lacks any unit/regression tests; that is a
major noteworthy point. I will follow up with those in a future PR.
   * However, these changes are **incredibly** isolated, clearly only
     affecting customers who opt in to the new self-hosted models
     configuration.
   * Most of the heavy lifting (SSE streaming, shuffling data around) is
     done in other well-tested codebases.
2. Manual testing has played a big role here, specifically:
   * Running a dev instance with the new configuration, actually
     connected to Huggingface TGI deployed on a remote server.
   * Using the new `debugConnections` mechanism (which customers would
     use) to directly confirm requests are going to the right places,
     with the right data and payloads.
   * Confirming with a new client (changes not yet landed) that
     autocomplete and chat functionality work.

Can we use more testing? Hell yeah, and I'm going to add it soon. Does
it work quite well and leave little room for error? Also yes.

## Changelog

Cody Enterprise: added a new configuration for self-hosting models.
Reach out to support if you would like to use this feature, as it is in
early access.

---------

Signed-off-by: Stephen Gutekanst <stephen@sourcegraph.com>


@ -173,13 +173,18 @@ func convertServerSideProviderConfig(cfg *schema.ServerSideProviderConfig) *type
Endpoint: v.Endpoint,
},
}
} else if v := cfg.Openaicompatible; v != nil {
// TODO(slimsag): self-hosted-models: map this to OpenAICompatibleProviderConfig in the future
} else if v := cfg.HuggingfaceTgi; v != nil {
return &types.ServerSideProviderConfig{
GenericProvider: &types.GenericProviderConfig{
ServiceName: types.GenericServiceProviderOpenAI,
AccessToken: v.AccessToken,
Endpoint: v.Endpoint,
OpenAICompatible: &types.OpenAICompatibleProviderConfig{
Endpoints: convertOpenAICompatibleEndpoints(v.Endpoints),
EnableVerboseLogs: v.EnableVerboseLogs,
},
}
} else if v := cfg.Openaicompatible; v != nil {
return &types.ServerSideProviderConfig{
OpenAICompatible: &types.OpenAICompatibleProviderConfig{
Endpoints: convertOpenAICompatibleEndpoints(v.Endpoints),
EnableVerboseLogs: v.EnableVerboseLogs,
},
}
} else if v := cfg.Sourcegraph; v != nil {
@ -194,13 +199,57 @@ func convertServerSideProviderConfig(cfg *schema.ServerSideProviderConfig) *type
}
}
func convertOpenAICompatibleEndpoints(configEndpoints []*schema.OpenAICompatibleEndpoint) []types.OpenAICompatibleEndpoint {
var endpoints []types.OpenAICompatibleEndpoint
for _, e := range configEndpoints {
endpoints = append(endpoints, types.OpenAICompatibleEndpoint{
URL: e.Url,
AccessToken: e.AccessToken,
})
}
return endpoints
}
func convertClientSideModelConfig(v *schema.ClientSideModelConfig) *types.ClientSideModelConfig {
if v == nil {
return nil
}
return &types.ClientSideModelConfig{
// We currently do not have any known client-side model configuration.
cfg := &types.ClientSideModelConfig{}
if o := v.Openaicompatible; o != nil {
cfg.OpenAICompatible = &types.ClientSideModelConfigOpenAICompatible{
StopSequences: o.StopSequences,
EndOfText: o.EndOfText,
ContextSizeHintTotalCharacters: intPtrToUintPtr(o.ContextSizeHintTotalCharacters),
ContextSizeHintPrefixCharacters: intPtrToUintPtr(o.ContextSizeHintPrefixCharacters),
ContextSizeHintSuffixCharacters: intPtrToUintPtr(o.ContextSizeHintSuffixCharacters),
ChatPreInstruction: o.ChatPreInstruction,
EditPostInstruction: o.EditPostInstruction,
AutocompleteSinglelineTimeout: uint(o.AutocompleteSinglelineTimeout),
AutocompleteMultilineTimeout: uint(o.AutocompleteMultilineTimeout),
ChatTopK: float32(o.ChatTopK),
ChatTopP: float32(o.ChatTopP),
ChatTemperature: float32(o.ChatTemperature),
ChatMaxTokens: uint(o.ChatMaxTokens),
AutoCompleteTopK: float32(o.AutoCompleteTopK),
AutoCompleteTopP: float32(o.AutoCompleteTopP),
AutoCompleteTemperature: float32(o.AutoCompleteTemperature),
AutoCompleteSinglelineMaxTokens: uint(o.AutoCompleteSinglelineMaxTokens),
AutoCompleteMultilineMaxTokens: uint(o.AutoCompleteMultilineMaxTokens),
EditTopK: float32(o.EditTopK),
EditTopP: float32(o.EditTopP),
EditTemperature: float32(o.EditTemperature),
EditMaxTokens: uint(o.EditMaxTokens),
}
}
return cfg
}
func intPtrToUintPtr(v *int) *uint {
if v == nil {
return nil
}
ptr := uint(*v)
return &ptr
}
func convertServerSideModelConfig(cfg *schema.ServerSideModelConfig) *types.ServerSideModelConfig {
@ -213,6 +262,12 @@ func convertServerSideModelConfig(cfg *schema.ServerSideModelConfig) *types.Serv
ARN: v.Arn,
},
}
} else if v := cfg.Openaicompatible; v != nil {
return &types.ServerSideModelConfig{
OpenAICompatible: &types.ServerSideModelConfigOpenAICompatible{
APIModel: v.ApiModel,
},
}
} else {
panic(fmt.Sprintf("illegal state: %+v", v))
}
@ -262,19 +317,14 @@ func convertModelCapabilities(capabilities []string) []types.ModelCapability {
//
// It would specify these equivalent options for them under `modelOverrides`:
var recommendedSettings = map[types.ModelRef]types.ModelOverride{
"bigcode::v1::starcoder2-3b": recommendedSettingsStarcoder2("bigcode::v1::starcoder2-3b", "Starcoder2 3B", "starcoder2-3b"),
"bigcode::v1::starcoder2-7b": recommendedSettingsStarcoder2("bigcode::v1::starcoder2-7b", "Starcoder2 7B", "starcoder2-7b"),
"bigcode::v1::starcoder2-15b": recommendedSettingsStarcoder2("bigcode::v1::starcoder2-15b", "Starcoder2 15B", "starcoder2-15b"),
"mistral::v1::mistral-7b": recommendedSettingsMistral("mistral::v1::mistral-7b", "Mistral 7B", "mistral-7b"),
"mistral::v1::mistral-7b-instruct": recommendedSettingsMistral("mistral::v1::mistral-7b-instruct", "Mistral 7B Instruct", "mistral-7b-instruct"),
"mistral::v1::mixtral-8x7b": recommendedSettingsMistral("mistral::v1::mixtral-8x7b", "Mixtral 8x7B", "mixtral-8x7b"),
"mistral::v1::mixtral-8x22b": recommendedSettingsMistral("mistral::v1::mixtral-8x22b", "Mixtral 8x22B", "mixtral-8x22b"),
"mistral::v1::mixtral-8x7b-instruct": recommendedSettingsMistral("mistral::v1::mixtral-8x7b-instruct", "Mixtral 8x7B Instruct", "mixtral-8x7b-instruct"),
"mistral::v1::mixtral-8x22b-instruct": recommendedSettingsMistral("mistral::v1::mixtral-8x22b", "Mixtral 8x22B", "mixtral-8x22b-instruct"),
}
func recommendedSettingsStarcoder2(modelRef, displayName, modelName string) types.ModelOverride {
// TODO(slimsag): self-hosted-models: tune these further based on testing
return types.ModelOverride{
ModelRef: types.ModelRef(modelRef),
DisplayName: displayName,
@ -285,15 +335,18 @@ func recommendedSettingsStarcoder2(modelRef, displayName, modelName string) type
Tier: types.ModelTierEnterprise,
ContextWindow: types.ContextWindow{
MaxInputTokens: 8192,
MaxOutputTokens: 4000,
MaxOutputTokens: 4096,
},
ClientSideConfig: &types.ClientSideModelConfig{
OpenAICompatible: &types.ClientSideModelConfigOpenAICompatible{
StopSequences: []string{"<|endoftext|>", "<file_sep>"},
EndOfText: "<|endoftext|>",
},
},
ClientSideConfig: nil,
ServerSideConfig: nil,
}
}
func recommendedSettingsMistral(modelRef, displayName, modelName string) types.ModelOverride {
// TODO(slimsag): self-hosted-models: tune these further based on testing
return types.ModelOverride{
ModelRef: types.ModelRef(modelRef),
DisplayName: displayName,
@ -304,9 +357,10 @@ func recommendedSettingsMistral(modelRef, displayName, modelName string) types.M
Tier: types.ModelTierEnterprise,
ContextWindow: types.ContextWindow{
MaxInputTokens: 8192,
MaxOutputTokens: 4000,
MaxOutputTokens: 4096,
},
ClientSideConfig: &types.ClientSideModelConfig{
OpenAICompatible: &types.ClientSideModelConfigOpenAICompatible{},
},
ClientSideConfig: nil,
ServerSideConfig: nil,
}
}


@ -160,8 +160,8 @@ func getProviderConfiguration(siteConfig *conftypes.CompletionsConfig) *types.Se
Endpoint: siteConfig.Endpoint,
}
// For all the other types of providers you can define in the site configuration, we
// just use a generic config. Rather than creating one for Anthropic, Fireworks, Google, etc.
// For all the other types of providers you can define in the legacy "completions" site configuration,
// we just use a generic config. Rather than creating one for Anthropic, Fireworks, Google, etc.
// We'll add those when needed, when we expose the newer style configuration in the site-config.
default:
serverSideConfig.GenericProvider = &types.GenericProviderConfig{


@ -6237,6 +6237,13 @@ def go_dependencies():
sum = "h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=",
version = "v0.6.1",
)
go_repository(
name = "com_github_tmaxmax_go_sse",
build_file_proto_mode = "disable_global",
importpath = "github.com/tmaxmax/go-sse",
sum = "h1:pPpTgyyi1r7vG2o6icebnpGEh3ebcnBXqDWkb7aTofs=",
version = "v0.8.0",
)
go_repository(
name = "com_github_tmc_dot",
build_file_proto_mode = "disable_global",

go.mod

@ -318,6 +318,7 @@ require (
github.com/sourcegraph/sourcegraph/lib v0.0.0-20240524140455-2589fef13ea8
github.com/sourcegraph/sourcegraph/lib/managedservicesplatform v0.0.0-00010101000000-000000000000
github.com/sourcegraph/sourcegraph/monitoring v0.0.0-00010101000000-000000000000
github.com/tmaxmax/go-sse v0.8.0
github.com/vektah/gqlparser/v2 v2.4.5
github.com/vvakame/gcplogurl v0.2.0
go.opentelemetry.io/collector/config/confighttp v0.103.0

go.sum

@ -2410,6 +2410,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tmaxmax/go-sse v0.8.0 h1:pPpTgyyi1r7vG2o6icebnpGEh3ebcnBXqDWkb7aTofs=
github.com/tmaxmax/go-sse v0.8.0/go.mod h1:HLoxqxdH+7oSUItjtnpxjzJedfr/+Rrm/dNWBcTxJFM=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=


@ -17,6 +17,7 @@ go_library(
"//internal/completions/client/fireworks",
"//internal/completions/client/google",
"//internal/completions/client/openai",
"//internal/completions/client/openaicompatible",
"//internal/completions/tokenusage",
"//internal/completions/types",
"//internal/httpcli",


@ -10,6 +10,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks"
"github.com/sourcegraph/sourcegraph/internal/completions/client/google"
"github.com/sourcegraph/sourcegraph/internal/completions/client/openai"
"github.com/sourcegraph/sourcegraph/internal/completions/client/openaicompatible"
"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
@ -64,6 +65,11 @@ func getAPIProvider(modelConfigInfo types.ModelConfigInfo) (types.CompletionsCli
return client, errors.Wrap(err, "getting api provider")
}
// OpenAI Compatible
if openAICompatibleCfg := ssConfig.OpenAICompatible; openAICompatibleCfg != nil {
return openaicompatible.NewClient(httpcli.UncachedExternalClient, *tokenManager), nil
}
// The "GenericProvider" is an escape hatch for a set of API Providers not needing any additional configuration.
if genProviderCfg := ssConfig.GenericProvider; genProviderCfg != nil {
token := genProviderCfg.AccessToken


@ -0,0 +1,20 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "openaicompatible",
srcs = [
"openaicompatible.go",
"types.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/openaicompatible",
visibility = ["//:__subpackages__"],
deps = [
"//internal/completions/tokenizer",
"//internal/completions/tokenusage",
"//internal/completions/types",
"//internal/modelconfig/types",
"//lib/errors",
"@com_github_sourcegraph_log//:log",
"@com_github_tmaxmax_go_sse//:go-sse",
],
)


@ -0,0 +1,501 @@
package openaicompatible
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"path"
"reflect"
"strings"
"time"
"github.com/sourcegraph/log"
sse "github.com/tmaxmax/go-sse"
"github.com/sourcegraph/sourcegraph/internal/completions/tokenizer"
"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
modelconfigSDK "github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func NewClient(
cli *http.Client,
tokenManager tokenusage.Manager,
) types.CompletionsClient {
return &client{
cli: cli,
tokenManager: tokenManager,
rng: rand.New(rand.NewSource(time.Now().Unix())),
}
}
type client struct {
cli *http.Client
tokenManager tokenusage.Manager
rng *rand.Rand
}
func (c *client) Complete(
ctx context.Context,
logger log.Logger,
request types.CompletionRequest,
) (*types.CompletionResponse, error) {
logger = logger.Scoped("OpenAICompatible")
var resp *http.Response
defer (func() {
if resp != nil {
resp.Body.Close()
}
})()
var (
req *http.Request
reqBody string
err error
)
if request.Feature == types.CompletionsFeatureCode {
req, reqBody, err = c.makeCompletionRequest(ctx, request, false)
} else {
req, reqBody, err = c.makeChatRequest(ctx, request, false)
}
if err != nil {
return nil, errors.Wrap(err, "making request")
}
requestID := c.rng.Uint32()
providerConfig := request.ModelConfigInfo.Provider.ServerSideConfig.OpenAICompatible
if providerConfig.EnableVerboseLogs {
logger.Info("request",
log.Uint32("id", requestID),
log.String("kind", "non-streaming"),
log.String("method", req.Method),
log.String("url", req.URL.String()),
// Note: log package will automatically redact token
log.String("headers", fmt.Sprint(req.Header)),
log.String("body", reqBody),
)
}
start := time.Now()
resp, err = c.cli.Do(req)
if err != nil {
logger.Error("request error",
log.Uint32("id", requestID),
log.Error(err),
)
return nil, errors.Wrap(err, "performing request")
}
if resp.StatusCode != http.StatusOK {
err := types.NewErrStatusNotOK("OpenAI", resp)
logger.Error("request error",
log.Uint32("id", requestID),
log.Error(err),
)
return nil, err
}
defer resp.Body.Close()
var response openaiResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
logger.Error("request error, decoding response",
log.Uint32("id", requestID),
log.Error(err),
)
return nil, errors.Wrap(err, "decoding response")
}
if providerConfig.EnableVerboseLogs {
// When debugging connections, log more verbose information like the actual completion we got back.
completion := ""
if len(response.Choices) > 0 {
completion = response.Choices[0].Text
}
logger.Info("request success",
log.Uint32("id", requestID),
log.Duration("time", time.Since(start)),
log.String("response_model", response.Model),
log.String("url", req.URL.String()),
log.String("system_fingerprint", response.SystemFingerprint),
log.String("finish_reason", response.maybeGetFinishReason()),
log.String("completion", completion),
)
} else {
logger.Info("request success",
log.Uint32("id", requestID),
log.Duration("time", time.Since(start)),
log.String("response_model", response.Model),
log.String("url", req.URL.String()),
log.String("system_fingerprint", response.SystemFingerprint),
log.String("finish_reason", response.maybeGetFinishReason()),
)
}
if len(response.Choices) == 0 {
// Empty response.
return &types.CompletionResponse{}, nil
}
modelID := request.ModelConfigInfo.Model.ModelRef.ModelID()
err = c.tokenManager.UpdateTokenCountsFromModelUsage(
response.Usage.PromptTokens,
response.Usage.CompletionTokens,
tokenizer.OpenAIModel+"/"+string(modelID),
string(request.Feature),
tokenusage.OpenAICompatible)
if err != nil {
logger.Warn("Failed to count tokens with the token manager %w ", log.Error(err))
}
return &types.CompletionResponse{
Completion: response.Choices[0].Text,
StopReason: response.Choices[0].FinishReason,
}, nil
}
func (c *client) Stream(
ctx context.Context,
logger log.Logger,
request types.CompletionRequest,
sendEvent types.SendCompletionEvent,
) error {
logger = logger.Scoped("OpenAICompatible")
var (
req *http.Request
reqBody string
err error
)
if request.Feature == types.CompletionsFeatureCode {
req, reqBody, err = c.makeCompletionRequest(ctx, request, true)
} else {
req, reqBody, err = c.makeChatRequest(ctx, request, true)
}
if err != nil {
return errors.Wrap(err, "making request")
}
sseClient := &sse.Client{
HTTPClient: c.cli,
ResponseValidator: sse.DefaultValidator,
Backoff: sse.Backoff{
// Note: go-sse has a bug with retry logic (https://github.com/tmaxmax/go-sse/pull/38)
// where it will get stuck in an infinite retry loop due to an io.EOF error
// depending on how the server behaves. For now, we just do not expose retry/backoff
// logic. It's not really useful for these types of requests anyway given their
// short-lived nature.
MaxRetries: -1,
},
}
ctx, cancel := context.WithCancel(ctx)
conn := sseClient.NewConnection(req.WithContext(ctx))
var (
content string
ev types.CompletionResponse
promptTokens, completionTokens int
streamErr error
finishReason string
)
unsubscribe := conn.SubscribeMessages(func(event sse.Event) {
// Ignore any data that is not JSON-like
if !strings.HasPrefix(event.Data, "{") {
return
}
var resp openaiResponse
if err := json.Unmarshal([]byte(event.Data), &resp); err != nil {
streamErr = errors.Errorf("failed to decode event payload: %w - body: %s", err, event.Data)
cancel()
return
}
if reflect.DeepEqual(resp, openaiResponse{}) {
// Empty response, it may be an error payload then
var errResp openaiErrorResponse
if err := json.Unmarshal([]byte(event.Data), &errResp); err != nil {
streamErr = errors.Errorf("failed to decode error event payload: %w - body: %s", err, event.Data)
cancel()
return
}
if errResp.Error != "" || errResp.ErrorType != "" {
streamErr = errors.Errorf("SSE error: %s: %s", errResp.ErrorType, errResp.Error)
cancel()
return
}
}
// These are only included in the last message, so we're not worried about overwriting
if resp.Usage.PromptTokens > 0 {
promptTokens = resp.Usage.PromptTokens
}
if resp.Usage.CompletionTokens > 0 {
completionTokens = resp.Usage.CompletionTokens
}
if len(resp.Choices) > 0 {
if request.Feature == types.CompletionsFeatureCode {
content += resp.Choices[0].Text
} else {
content += resp.Choices[0].Delta.Content
}
ev = types.CompletionResponse{
Completion: content,
StopReason: resp.Choices[0].FinishReason,
}
err = sendEvent(ev)
if err != nil {
streamErr = errors.Errorf("failed to send event: %w", err)
cancel()
return
}
for _, choice := range resp.Choices {
if choice.FinishReason != "" {
// End of stream
finishReason = choice.FinishReason
streamErr = nil
cancel()
return
}
}
}
})
defer unsubscribe()
requestID := c.rng.Uint32()
providerConfig := request.ModelConfigInfo.Provider.ServerSideConfig.OpenAICompatible
if providerConfig.EnableVerboseLogs {
logger.Info("request",
log.Uint32("id", requestID),
log.String("kind", "streaming"),
log.String("method", req.Method),
log.String("url", req.URL.String()),
// Note: log package will automatically redact token
log.String("headers", fmt.Sprint(req.Header)),
log.String("body", reqBody),
)
}
start := time.Now()
err = conn.Connect()
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
// go-sse will return io.EOF on successful close of the connection, since it expects the
// connection to be long-lived. In our case, we expect the connection to close on success
// and be short lived, so this is a non-error.
err = nil
}
if streamErr != nil {
err = errors.Append(err, streamErr)
}
if err == nil && finishReason == "" {
// At this point, we successfully streamed the response to the client. But we need to make
// sure the client gets a non-empty StopReason at the very end, otherwise it would think
// the streamed response it got is partial / incomplete and may not display the completion
// to the user as a result.
err = sendEvent(types.CompletionResponse{
Completion: content,
StopReason: "stop_sequence", // pretend we hit a stop sequence (we did!)
})
}
if err != nil {
logger.Error("request error",
log.Uint32("id", requestID),
log.Error(err),
)
return errors.Wrap(err, "NewConnection")
}
if providerConfig.EnableVerboseLogs {
// When debugging connections, log more verbose information like the actual completion we got back.
logger.Info("request success",
log.Uint32("id", requestID),
log.Duration("time", time.Since(start)),
log.String("url", req.URL.String()),
log.String("finish_reason", finishReason),
log.String("completion", content),
)
} else {
logger.Info("request success",
log.Uint32("id", requestID),
log.Duration("time", time.Since(start)),
log.String("url", req.URL.String()),
log.String("finish_reason", finishReason),
)
}
modelID := request.ModelConfigInfo.Model.ModelRef.ModelID()
err = c.tokenManager.UpdateTokenCountsFromModelUsage(
promptTokens,
completionTokens,
tokenizer.OpenAIModel+"/"+string(modelID),
string(request.Feature),
tokenusage.OpenAICompatible,
)
if err != nil {
logger.Warn("Failed to count tokens with the token manager %w", log.Error(err))
}
return nil
}
func (c *client) makeChatRequest(
ctx context.Context,
request types.CompletionRequest,
stream bool,
) (*http.Request, string, error) {
requestParams := request.Parameters
if requestParams.TopK < 0 {
requestParams.TopK = 0
}
if requestParams.TopP < 0 {
requestParams.TopP = 0
}
payload := openAIChatCompletionsRequestParameters{
Model: getAPIModel(request),
Temperature: requestParams.Temperature,
TopP: requestParams.TopP,
N: requestParams.TopK,
Stream: stream,
MaxTokens: requestParams.MaxTokensToSample,
Stop: requestParams.StopSequences,
}
for _, m := range requestParams.Messages {
var role string
switch m.Speaker {
case types.SYSTEM_MESSAGE_SPEAKER:
role = "system"
case types.HUMAN_MESSAGE_SPEAKER:
role = "user"
case types.ASSISTANT_MESSAGE_SPEAKER:
role = "assistant"
default:
role = strings.ToLower(role)
}
payload.Messages = append(payload.Messages, message{
Role: role,
Content: m.Text,
})
}
reqBody, err := json.Marshal(payload)
if err != nil {
return nil, "", errors.Wrap(err, "Marshal")
}
endpoint, err := getEndpoint(request, c.rng)
if err != nil {
return nil, "", errors.Wrap(err, "getEndpoint")
}
url, err := getEndpointURL(endpoint, "chat/completions")
if err != nil {
return nil, "", errors.Wrap(err, "getEndpointURL")
}
req, err := http.NewRequestWithContext(ctx, "POST", url.String(), bytes.NewReader(reqBody))
if err != nil {
return nil, "", errors.Wrap(err, "NewRequestWithContext")
}
req.Header.Set("Content-Type", "application/json")
if endpoint.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+endpoint.AccessToken)
}
return req, string(reqBody), nil
}
func (c *client) makeCompletionRequest(
ctx context.Context,
request types.CompletionRequest,
stream bool,
) (*http.Request, string, error) {
requestParams := request.Parameters
if requestParams.TopK < 0 {
requestParams.TopK = 0
}
if requestParams.TopP < 0 {
requestParams.TopP = 0
}
prompt, err := getPrompt(requestParams.Messages)
if err != nil {
return nil, "", errors.Wrap(err, "getPrompt")
}
payload := openAICompletionsRequestParameters{
Model: getAPIModel(request),
Temperature: requestParams.Temperature,
TopP: requestParams.TopP,
N: requestParams.TopK,
Stream: stream,
MaxTokens: requestParams.MaxTokensToSample,
Stop: requestParams.StopSequences,
Prompt: prompt,
}
reqBody, err := json.Marshal(payload)
if err != nil {
return nil, "", errors.Wrap(err, "Marshal")
}
endpoint, err := getEndpoint(request, c.rng)
if err != nil {
return nil, "", errors.Wrap(err, "getEndpoint")
}
url, err := getEndpointURL(endpoint, "completions")
if err != nil {
return nil, "", errors.Wrap(err, "getEndpointURL")
}
req, err := http.NewRequestWithContext(ctx, "POST", url.String(), bytes.NewReader(reqBody))
if err != nil {
return nil, "", errors.Wrap(err, "NewRequestWithContext")
}
req.Header.Set("Content-Type", "application/json")
if endpoint.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+endpoint.AccessToken)
}
return req, string(reqBody), nil
}
func getPrompt(messages []types.Message) (string, error) {
if l := len(messages); l == 0 {
return "", errors.New("found zero messages in prompt")
}
return messages[0].Text, nil
}
func getAPIModel(request types.CompletionRequest) string {
ssConfig := request.ModelConfigInfo.Model.ServerSideConfig
if ssConfig != nil && ssConfig.OpenAICompatible != nil && ssConfig.OpenAICompatible.APIModel != "" {
return ssConfig.OpenAICompatible.APIModel
}
// Default to model name if not specified
return request.ModelConfigInfo.Model.ModelName
}
func getEndpoint(request types.CompletionRequest, rng *rand.Rand) (modelconfigSDK.OpenAICompatibleEndpoint, error) {
providerConfig := request.ModelConfigInfo.Provider.ServerSideConfig.OpenAICompatible
if len(providerConfig.Endpoints) == 0 {
return modelconfigSDK.OpenAICompatibleEndpoint{}, errors.New("no openaicompatible endpoint configured")
}
if len(providerConfig.Endpoints) == 1 {
return providerConfig.Endpoints[0], nil
}
randPick := rng.Intn(len(providerConfig.Endpoints))
return providerConfig.Endpoints[randPick], nil
}
func getEndpointURL(endpoint modelconfigSDK.OpenAICompatibleEndpoint, relativePath string) (*url.URL, error) {
url, err := url.Parse(endpoint.URL)
if err != nil {
return nil, errors.Newf("failed to parse endpoint URL: %q", endpoint.URL)
}
if url.Scheme == "" || url.Host == "" {
return nil, errors.Newf("unable to build URL, bad endpoint: %q", endpoint.URL)
}
url.Path = path.Join(url.Path, relativePath)
return url, nil
}


@ -0,0 +1,77 @@
package openaicompatible
// openAIChatCompletionsRequestParameters request object for openAI chat endpoint https://platform.openai.com/docs/api-reference/chat/create
type openAIChatCompletionsRequestParameters struct {
Model string `json:"model"` // request.Model
Messages []message `json:"messages"` // request.Messages
Temperature float32 `json:"temperature,omitempty"` // request.Temperature
TopP float32 `json:"top_p,omitempty"` // request.TopP
N int `json:"n,omitempty"` // always 1
Stream bool `json:"stream,omitempty"` // request.Stream
Stop []string `json:"stop,omitempty"` // request.StopSequences
MaxTokens int `json:"max_tokens,omitempty"` // request.MaxTokensToSample
PresencePenalty float32 `json:"presence_penalty,omitempty"` // unused
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // unused
LogitBias map[string]float32 `json:"logit_bias,omitempty"` // unused
User string `json:"user,omitempty"` // unused
}
// openAICompletionsRequestParameters payload for openAI completions endpoint https://platform.openai.com/docs/api-reference/completions/create
type openAICompletionsRequestParameters struct {
Model string `json:"model"` // request.Model
Prompt string `json:"prompt"` // request.Messages[0] - formatted prompt expected to be the only message
Temperature float32 `json:"temperature,omitempty"` // request.Temperature
TopP float32 `json:"top_p,omitempty"` // request.TopP
N int `json:"n,omitempty"` // always 1
Stream bool `json:"stream,omitempty"` // request.Stream
Stop []string `json:"stop,omitempty"` // request.StopSequences
MaxTokens int `json:"max_tokens,omitempty"` // request.MaxTokensToSample
PresencePenalty float32 `json:"presence_penalty,omitempty"` // unused
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // unused
LogitBias map[string]float32 `json:"logit_bias,omitempty"` // unused
Suffix string `json:"suffix,omitempty"` // unused
User string `json:"user,omitempty"` // unused
}
type message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type openaiUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type openaiChoiceDelta struct {
Content string `json:"content"`
}
type openaiChoice struct {
Delta openaiChoiceDelta `json:"delta"`
Role string `json:"role"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type openaiResponse struct {
// Usage is only available for non-streaming requests.
Usage openaiUsage `json:"usage"`
Model string `json:"model"`
Choices []openaiChoice `json:"choices"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
}
func (r *openaiResponse) maybeGetFinishReason() string {
if len(r.Choices) == 0 {
return ""
}
return r.Choices[len(r.Choices)-1].FinishReason
}
// e.g. {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 159 `inputs` tokens and 4000 `max_new_tokens`","error_type":"validation"}
type openaiErrorResponse struct {
Error string `json:"error"`
ErrorType string `json:"error_type"`
}


@ -11,6 +11,7 @@ go_library(
"//internal/completions/client/anthropic:__pkg__",
"//internal/completions/client/azureopenai:__pkg__",
"//internal/completions/client/openai:__pkg__",
"//internal/completions/client/openaicompatible:__pkg__",
"//internal/completions/tokenusage:__pkg__",
],
deps = [


@ -15,6 +15,7 @@ go_library(
"//internal/completions/client/azureopenai:__pkg__",
"//internal/completions/client/codygateway:__pkg__",
"//internal/completions/client/openai:__pkg__",
"//internal/completions/client/openaicompatible:__pkg__",
"//internal/updatecheck:__pkg__",
],
deps = [


@ -27,10 +27,11 @@ func NewManager() *Manager {
type Provider string
const (
OpenAI Provider = "openai"
AzureOpenAI Provider = "azureopenai"
AwsBedrock Provider = "awsbedrock"
Anthropic Provider = "anthropic"
OpenAI Provider = "openai"
OpenAICompatible Provider = "openaicompatible"
AzureOpenAI Provider = "azureopenai"
AwsBedrock Provider = "awsbedrock"
Anthropic Provider = "anthropic"
)
func (m *Manager) UpdateTokenCountsFromModelUsage(inputTokens, outputTokens int, model, feature string, provider Provider) error {


@ -45,13 +45,14 @@ type ConfigFeatures struct {
type CompletionsProviderName string
const (
CompletionsProviderNameAnthropic CompletionsProviderName = "anthropic"
CompletionsProviderNameOpenAI CompletionsProviderName = "openai"
CompletionsProviderNameGoogle CompletionsProviderName = "google"
CompletionsProviderNameAzureOpenAI CompletionsProviderName = "azure-openai"
CompletionsProviderNameSourcegraph CompletionsProviderName = "sourcegraph"
CompletionsProviderNameFireworks CompletionsProviderName = "fireworks"
CompletionsProviderNameAWSBedrock CompletionsProviderName = "aws-bedrock"
CompletionsProviderNameAnthropic CompletionsProviderName = "anthropic"
CompletionsProviderNameOpenAI CompletionsProviderName = "openai"
CompletionsProviderNameGoogle CompletionsProviderName = "google"
CompletionsProviderNameAzureOpenAI CompletionsProviderName = "azure-openai"
CompletionsProviderNameOpenAICompatible CompletionsProviderName = "openai-compatible"
CompletionsProviderNameSourcegraph CompletionsProviderName = "sourcegraph"
CompletionsProviderNameFireworks CompletionsProviderName = "fireworks"
CompletionsProviderNameAWSBedrock CompletionsProviderName = "aws-bedrock"
)
type EmbeddingsConfig struct {


@ -72,6 +72,28 @@ type GenericProviderConfig struct {
Endpoint string `json:"endpoint"`
}
// OpenAICompatibleProvider is a provider for connecting to OpenAI-compatible API endpoints
// supplied by various third-party software.
//
// Because many of these third-party providers provide slightly different semantics for the OpenAI API
// protocol, the Sourcegraph instance exposes this provider configuration which allows for much more
// extensive configuration than would be needed for the official OpenAI API.
type OpenAICompatibleProviderConfig struct {
// Endpoints where this API can be reached. If multiple are present, Sourcegraph will distribute
// load between them as it sees fit.
Endpoints []OpenAICompatibleEndpoint `json:"endpoints,omitempty"`
// Whether to enable verbose logging of requests, allowing for grepping the logs for "OpenAICompatible"
// and seeing e.g. what requests Cody is actually sending to your API endpoint.
EnableVerboseLogs bool `json:"enableVerboseLogs,omitempty"`
}
// A single API endpoint for an OpenAI-compatible API.
type OpenAICompatibleEndpoint struct {
URL string `json:"url"`
AccessToken string `json:"accessToken"`
}
// SourcegraphProviderConfig is the configuration blog for configuring a provider
// to be use Sourcegraph's Cody Gateway for requests.
type SourcegraphProviderConfig struct {
@ -82,25 +104,99 @@ type SourcegraphProviderConfig struct {
// The "Provider" is conceptually a namespace for models. The server-side provider configuration
// is needed to describe the API endpoint needed to serve its models.
type ServerSideProviderConfig struct {
AWSBedrock *AWSBedrockProviderConfig `json:"awsBedrock,omitempty"`
AzureOpenAI *AzureOpenAIProviderConfig `json:"azureOpenAi,omitempty"`
GenericProvider *GenericProviderConfig `json:"genericProvider,omitempty"`
SourcegraphProvider *SourcegraphProviderConfig `json:"sourcegraphProvider,omitempty"`
AWSBedrock *AWSBedrockProviderConfig `json:"awsBedrock,omitempty"`
AzureOpenAI *AzureOpenAIProviderConfig `json:"azureOpenAi,omitempty"`
OpenAICompatible *OpenAICompatibleProviderConfig `json:"openAICompatible,omitempty"`
GenericProvider *GenericProviderConfig `json:"genericProvider,omitempty"`
SourcegraphProvider *SourcegraphProviderConfig `json:"sourcegraphProvider,omitempty"`
}
// ========================================================
// Client-side Model Configuration Data
// ========================================================
// Anything that needs to be provided to Cody clients at the model-level can go here.
//
// For example, allowing the server to customize/override the LLM
// prompt used. Or describe how clients should upload context to
// remote servers, etc. Or "hints", like "this model is great when
// working with 'C' code.".
type ClientSideModelConfig struct {
// We currently do not have any known client-side model configuration.
// But later, if anything needs to be provided to Cody clients at the
// model-level it will go here.
OpenAICompatible *ClientSideModelConfigOpenAICompatible `json:"openAICompatible,omitempty"`
}
// Client-side model configuration used when the model is backed by an OpenAI-compatible API
// provider.
type ClientSideModelConfigOpenAICompatible struct {
// (optional) List of stop sequences to use for this model.
StopSequences []string `json:"stopSequences,omitempty"`
// (optional) EndOfText identifier used by the model. e.g. "<|endoftext|>", "<EOT>"
EndOfText string `json:"endOfText,omitempty"`
// (optional) A hint the client should use when producing context to send to the LLM.
// The maximum length of all context (prefix + suffix + snippets), in characters.
ContextSizeHintTotalCharacters *uint `json:"contextSizeHintTotalCharacters,omitempty"`
// (optional) A hint the client should use when producing context to send to the LLM.
// The maximum length of the document prefix (text before the cursor) to include, in characters.
ContextSizeHintPrefixCharacters *uint `json:"contextSizeHintPrefixCharacters,omitempty"`
// (optional) A hint the client should use when producing context to send to the LLM.
// The maximum length of the document suffix (text after the cursor) to include, in characters.
ContextSizeHintSuffixCharacters *uint `json:"contextSizeHintSuffixCharacters,omitempty"`
// (optional) Custom instruction to be included at the start of all chat messages
// when using this model, e.g. "Answer all questions in Spanish."
//
// For example, allowing the server to customize/override the LLM
// prompt used. Or describe how clients should upload context to
// remote servers, etc. Or "hints", like "this model is great when
// working with 'C' code.".
// Note: similar to Cody client config option `cody.chat.preInstruction`; if user has
// configured that it will be used instead of this.
ChatPreInstruction string `json:"chatPreInstruction,omitempty"`
// (optional) Custom instruction to be included at the end of all edit commands
// when using this model, e.g. "Write all unit tests with Jest instead of detected framework."
//
// Note: similar to Cody client config option `cody.edit.preInstruction`; if user has
// configured that it will be respected instead of this.
EditPostInstruction string `json:"editPostInstruction,omitempty"`
// (optional) How long the client should wait for autocomplete results to come back (milliseconds),
// before giving up and not displaying an autocomplete result at all.
//
// This applies on single-line completions, e.g. `var i = <completion>`
//
// Note: similar to hidden Cody client config option `cody.autocomplete.advanced.timeout.singleline`
// If user has configured that, it will be respected instead of this.
AutocompleteSinglelineTimeout uint `json:"autocompleteSinglelineTimeout,omitempty"`
// (optional) How long the client should wait for autocomplete results to come back (milliseconds),
// before giving up and not displaying an autocomplete result at all.
//
// This applies on multi-line completions, which are based on intent-detection when e.g. a code block
// is being completed, e.g. `func parseURL(url string) {<completion>`
//
// Note: similar to hidden Cody client config option `cody.autocomplete.advanced.timeout.multiline`
// If user has configured that, it will be respected instead of this.
AutocompleteMultilineTimeout uint `json:"autocompleteMultilineTimeout,omitempty"`
// (optional) model parameters to use for the chat feature
ChatTopK float32 `json:"chatTopK,omitempty"`
ChatTopP float32 `json:"chatTopP,omitempty"`
ChatTemperature float32 `json:"chatTemperature,omitempty"`
ChatMaxTokens uint `json:"chatMaxTokens,omitempty"`
// (optional) model parameters to use for the autocomplete feature
AutoCompleteTopK float32 `json:"autoCompleteTopK,omitempty"`
AutoCompleteTopP float32 `json:"autoCompleteTopP,omitempty"`
AutoCompleteTemperature float32 `json:"autoCompleteTemperature,omitempty"`
AutoCompleteSinglelineMaxTokens uint `json:"autoCompleteSinglelineMaxTokens,omitempty"`
AutoCompleteMultilineMaxTokens uint `json:"autoCompleteMultilineMaxTokens,omitempty"`
// (optional) model parameters to use for the edit feature
EditTopK float32 `json:"editTopK,omitempty"`
EditTopP float32 `json:"editTopP,omitempty"`
EditTemperature float32 `json:"editTemperature,omitempty"`
EditMaxTokens uint `json:"editMaxTokens,omitempty"`
}
// ========================================================
@ -116,6 +212,34 @@ type AWSBedrockProvisionedThroughput struct {
ARN string `json:"arn"`
}
type ServerSideModelConfig struct {
AWSBedrockProvisionedThroughput *AWSBedrockProvisionedThroughput `json:"awsBedrockProvisionedThroughput"`
type ServerSideModelConfigOpenAICompatible struct {
// APIModel is value actually sent to the OpenAI-compatible API in the "model" field. This
// is less like a "model name" or "model identifier", and more like "an opaque, potentially
// secret string."
//
// Much software that claims to 'implement the OpenAI API' actually overrides this field with
// other information NOT related to the model name, either making it _ineffective_ as a
// model name/identifier (e.g. you must send "tgi" or "AUTODETECT" irrespective of which model
// you want to use) OR using it to smuggle other (potentially sensitive) information like the
// name of the deployment, which cannot be shared with clients.
//
// If this field is not an empty string, we treat it as an opaque string to be sent with API
// requests (similar to an access token) and use it for nothing else. If this field is not
// specified, we default to the Model.ModelName.
//
// Examples (these would be sent in the OpenAI /chat/completions `"model"` field):
//
// * Huggingface TGI: "tgi"
// * NVIDIA NIM: "meta/llama3-70b-instruct"
// * AWS LISA (v2): "AUTODETECT"
// * AWS LISA (v1): "mistralai/Mistral7b-v0.3-Instruct ecs.textgen.tgi"
// * Ollama: "llama2"
// * Others: "<SECRET DEPLOYMENT NAME>"
//
APIModel string `json:"apiModel,omitempty"`
}
type ServerSideModelConfig struct {
AWSBedrockProvisionedThroughput *AWSBedrockProvisionedThroughput `json:"awsBedrockProvisionedThroughput,omitempty"`
OpenAICompatible *ServerSideModelConfigOpenAICompatible `json:"openAICompatible,omitempty"`
}


@ -1,7 +1,6 @@
package modelconfig
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@ -64,17 +63,6 @@ func TestApplyModelOverrides(t *testing.T) {
// The configuration data is applied too, but it isn't a copy rather we just update the pointers
// to point to the original data.
t.Run("ConfigPointers", func(t *testing.T) {
{
// This test skips validation for the `model.ClientSideConfig` value because there isn't a
// reliable way to actually confirm the pointer was changed. Since the size of the data type
// is 0, the Go compiler can do all sorts of optimization schenanigans.
//
// When this scenario fails when we finally add a field to the ClientSideConfig struct, just
// uncomment the relevant parts of the code below.
clientSideConfig := types.ClientSideModelConfig{}
assert.EqualValues(t, 0, reflect.TypeOf(clientSideConfig).Size(), "See comment in the code...")
}
mod := getValidModel()
origClientCfg := mod.ClientSideConfig
origServerCfg := mod.ServerSideConfig
@ -90,8 +78,7 @@ func TestApplyModelOverrides(t *testing.T) {
}
// Confirm the override has different pointers for the model config.
// require.True(t, origClientCfg != override.ClientSideConfig, "orig = %p, override = %p", origClientCfg, override.ClientSideConfig)
// ^-- 0-byte type schenanigans...
require.True(t, origClientCfg != override.ClientSideConfig, "orig = %p, override = %p", origClientCfg, override.ClientSideConfig)
require.True(t, origServerCfg != override.ServerSideConfig)
err := ApplyModelOverride(&mod, override)
@ -100,8 +87,7 @@ func TestApplyModelOverrides(t *testing.T) {
assert.NotNil(t, mod.ClientSideConfig)
assert.NotNil(t, mod.ServerSideConfig)
// assert.True(t, mod.ClientSideConfig != origClientCfg)
// ^-- 0-byte type schenanigans...
assert.True(t, mod.ClientSideConfig != origClientCfg)
assert.True(t, mod.ServerSideConfig != origServerCfg)
assert.True(t, mod.ClientSideConfig == override.ClientSideConfig)


@ -605,6 +605,58 @@ type ChangesetTemplate struct {
// ClientSideModelConfig description: No client-side model configuration is currently available.
type ClientSideModelConfig struct {
Openaicompatible *ClientSideModelConfigOpenAICompatible `json:"openaicompatible,omitempty"`
}
// ClientSideModelConfigOpenAICompatible description: Advanced configuration options that are only respected if the model is provided by an openaicompatible provider.
type ClientSideModelConfigOpenAICompatible struct {
AutoCompleteMultilineMaxTokens int `json:"autoCompleteMultilineMaxTokens,omitempty"`
AutoCompleteSinglelineMaxTokens int `json:"autoCompleteSinglelineMaxTokens,omitempty"`
AutoCompleteTemperature float64 `json:"autoCompleteTemperature,omitempty"`
AutoCompleteTopK float64 `json:"autoCompleteTopK,omitempty"`
AutoCompleteTopP float64 `json:"autoCompleteTopP,omitempty"`
// AutocompleteMultilineTimeout description: How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all.
//
// This applies on multi-line completions, which are based on intent-detection when e.g. a code block is being completed, e.g. 'func parseURL(url string) {<completion>'
//
// Note: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.multiline' If user has configured that, it will be respected instead of this.
AutocompleteMultilineTimeout int `json:"autocompleteMultilineTimeout,omitempty"`
// AutocompleteSinglelineTimeout description: How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all.
//
// This applies on single-line completions, e.g. 'var i = <completion>'
//
// Note: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.singleline' If user has configured that, it will be respected instead of this.
AutocompleteSinglelineTimeout int `json:"autocompleteSinglelineTimeout,omitempty"`
ChatMaxTokens int `json:"chatMaxTokens,omitempty"`
// ChatPreInstruction description: Custom instruction to be included at the start of all chat messages
// when using this model, e.g. 'Answer all questions in Spanish.'
//
// Note: similar to Cody client config option 'cody.chat.preInstruction'; if user has configured that it will be used instead of this.
ChatPreInstruction string `json:"chatPreInstruction,omitempty"`
ChatTemperature float64 `json:"chatTemperature,omitempty"`
ChatTopK float64 `json:"chatTopK,omitempty"`
ChatTopP float64 `json:"chatTopP,omitempty"`
// ContextSizeHintPrefixCharacters description: A hint the client should use when producing context to send to the LLM.
// The maximum length of the document prefix (text before the cursor) to include, in characters.
ContextSizeHintPrefixCharacters *int `json:"contextSizeHintPrefixCharacters,omitempty"`
// ContextSizeHintSuffixCharacters description: A hint the client should use when producing context to send to the LLM.
// The maximum length of the document suffix (text after the cursor) to include, in characters.
ContextSizeHintSuffixCharacters *int `json:"contextSizeHintSuffixCharacters,omitempty"`
// ContextSizeHintTotalCharacters description: A hint the client should use when producing context to send to the LLM.
// The maximum length of all context (prefix + suffix + snippets), in characters.
ContextSizeHintTotalCharacters *int `json:"contextSizeHintTotalCharacters,omitempty"`
EditMaxTokens int `json:"editMaxTokens,omitempty"`
// EditPostInstruction description: Custom instruction to be included at the end of all edit commands
// when using this model, e.g. 'Write all unit tests with Jest instead of detected framework.'
//
// Note: similar to Cody client config option 'cody.edit.preInstruction'; if user has configured that it will be respected instead of this.
EditPostInstruction string `json:"editPostInstruction,omitempty"`
EditTemperature float64 `json:"editTemperature,omitempty"`
EditTopK float64 `json:"editTopK,omitempty"`
EditTopP float64 `json:"editTopP,omitempty"`
// EndOfText description: End of text identifier used by the model.
EndOfText string `json:"endOfText,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// ClientSideProviderConfig description: No client-side provider configuration is currently available.
@ -2008,6 +2060,10 @@ type OnboardingTourConfiguration struct {
DefaultSnippets map[string]any `json:"defaultSnippets,omitempty"`
Tasks []*OnboardingTask `json:"tasks"`
}
type OpenAICompatibleEndpoint struct {
AccessToken string `json:"accessToken,omitempty"`
Url string `json:"url"`
}
type OpenCodeGraphAnnotation struct {
Item OpenCodeGraphItemRef `json:"item"`
Range OpenCodeGraphRange `json:"range"`
@ -2534,6 +2590,7 @@ type Sentry struct {
}
type ServerSideModelConfig struct {
AwsBedrockProvisionedThroughput *ServerSideModelConfigAwsBedrockProvisionedThroughput
Openaicompatible *ServerSideModelConfigOpenAICompatible
Unused *DoNotUsePhonyDiscriminantType
}
@ -2541,6 +2598,9 @@ func (v ServerSideModelConfig) MarshalJSON() ([]byte, error) {
if v.AwsBedrockProvisionedThroughput != nil {
return json.Marshal(v.AwsBedrockProvisionedThroughput)
}
if v.Openaicompatible != nil {
return json.Marshal(v.Openaicompatible)
}
if v.Unused != nil {
return json.Marshal(v.Unused)
}
@ -2556,10 +2616,12 @@ func (v *ServerSideModelConfig) UnmarshalJSON(data []byte) error {
switch d.DiscriminantProperty {
case "awsBedrockProvisionedThroughput":
return json.Unmarshal(data, &v.AwsBedrockProvisionedThroughput)
case "openaicompatible":
return json.Unmarshal(data, &v.Openaicompatible)
case "unused":
return json.Unmarshal(data, &v.Unused)
}
return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrockProvisionedThroughput", "unused"})
return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrockProvisionedThroughput", "openaicompatible", "unused"})
}
type ServerSideModelConfigAwsBedrockProvisionedThroughput struct {
@ -2567,6 +2629,13 @@ type ServerSideModelConfigAwsBedrockProvisionedThroughput struct {
Arn string `json:"arn"`
Type string `json:"type"`
}
// ServerSideModelConfigOpenAICompatible description: Configuration that is only respected if the model is provided by an openaicompatible provider.
type ServerSideModelConfigOpenAICompatible struct {
// ApiModel description: The literal string value of the 'model' field that will be sent to the /chat/completions API, for example. If set, Sourcegraph treats this as an opaque string and sends it directly to the API, inferring no information from it. By default, the configured model name is sent.
ApiModel string `json:"apiModel,omitempty"`
Type string `json:"type"`
}
type ServerSideProviderConfig struct {
AwsBedrock *ServerSideProviderConfigAWSBedrock
AzureOpenAI *ServerSideProviderConfigAzureOpenAI
@ -2574,6 +2643,7 @@ type ServerSideProviderConfig struct {
Fireworks *ServerSideProviderConfigFireworksProvider
Google *ServerSideProviderConfigGoogleProvider
Openai *ServerSideProviderConfigOpenAIProvider
HuggingfaceTgi *ServerSideProviderConfigHuggingfaceTGIProvider
Openaicompatible *ServerSideProviderConfigOpenAICompatibleProvider
Sourcegraph *ServerSideProviderConfigSourcegraphProvider
Unused *DoNotUsePhonyDiscriminantType
@ -2598,6 +2668,9 @@ func (v ServerSideProviderConfig) MarshalJSON() ([]byte, error) {
if v.Openai != nil {
return json.Marshal(v.Openai)
}
if v.HuggingfaceTgi != nil {
return json.Marshal(v.HuggingfaceTgi)
}
if v.Openaicompatible != nil {
return json.Marshal(v.Openaicompatible)
}
@ -2627,6 +2700,8 @@ func (v *ServerSideProviderConfig) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &v.Fireworks)
case "google":
return json.Unmarshal(data, &v.Google)
case "huggingface-tgi":
return json.Unmarshal(data, &v.HuggingfaceTgi)
case "openai":
return json.Unmarshal(data, &v.Openai)
case "openaicompatible":
@ -2636,7 +2711,7 @@ func (v *ServerSideProviderConfig) UnmarshalJSON(data []byte) error {
case "unused":
return json.Unmarshal(data, &v.Unused)
}
return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrock", "azureOpenAI", "anthropic", "fireworks", "google", "openai", "openaicompatible", "sourcegraph", "unused"})
return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrock", "azureOpenAI", "anthropic", "fireworks", "google", "openai", "huggingface-tgi", "openaicompatible", "sourcegraph", "unused"})
}
type ServerSideProviderConfigAWSBedrock struct {
@@ -2674,10 +2749,17 @@ type ServerSideProviderConfigGoogleProvider struct {
Endpoint string `json:"endpoint"`
Type string `json:"type"`
}
type ServerSideProviderConfigHuggingfaceTGIProvider struct {
// EnableVerboseLogs description: Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint.
EnableVerboseLogs bool `json:"enableVerboseLogs,omitempty"`
Endpoints []*OpenAICompatibleEndpoint `json:"endpoints"`
Type string `json:"type"`
}
type ServerSideProviderConfigOpenAICompatibleProvider struct {
AccessToken string `json:"accessToken"`
Endpoint string `json:"endpoint"`
Type string `json:"type"`
// EnableVerboseLogs description: Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint.
EnableVerboseLogs bool `json:"enableVerboseLogs,omitempty"`
Endpoints []*OpenAICompatibleEndpoint `json:"endpoints"`
Type string `json:"type"`
}
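
Since the generated marshalers above only appear as diff fragments, here is a minimal, self-contained sketch of the 'type'-discriminated union pattern they implement. The types below are simplified stand-ins covering only the huggingface-tgi arm (not the real generated `schema` package), and the endpoint URL is a placeholder.

```
package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the generated types; only the huggingface-tgi arm
// of the union is modeled here.
type OpenAICompatibleEndpoint struct {
	AccessToken string `json:"accessToken,omitempty"`
	Url         string `json:"url"`
}

type HuggingfaceTGIProvider struct {
	Type              string                      `json:"type"` // always "huggingface-tgi"
	Endpoints         []*OpenAICompatibleEndpoint `json:"endpoints"`
	EnableVerboseLogs bool                        `json:"enableVerboseLogs,omitempty"`
}

// ProviderConfig mirrors the tagged-union shape: at most one arm is non-nil.
type ProviderConfig struct {
	HuggingfaceTgi *HuggingfaceTGIProvider
}

// MarshalJSON emits whichever arm is set, flattening the union away.
func (v ProviderConfig) MarshalJSON() ([]byte, error) {
	if v.HuggingfaceTgi != nil {
		return json.Marshal(v.HuggingfaceTgi)
	}
	return nil, fmt.Errorf("no provider variant set")
}

// UnmarshalJSON peeks at the "type" discriminant, then fills the matching arm.
func (v *ProviderConfig) UnmarshalJSON(data []byte) error {
	var d struct {
		Type string `json:"type"`
	}
	if err := json.Unmarshal(data, &d); err != nil {
		return err
	}
	switch d.Type {
	case "huggingface-tgi":
		return json.Unmarshal(data, &v.HuggingfaceTgi)
	}
	return fmt.Errorf("unknown provider type %q", d.Type)
}

func main() {
	// Placeholder endpoint, purely illustrative.
	raw := `{"type":"huggingface-tgi","endpoints":[{"url":"http://tgi.internal.example.com/v1"}],"enableVerboseLogs":true}`

	var cfg ProviderConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.HuggingfaceTgi.Endpoints[0].Url) // http://tgi.internal.example.com/v1

	out, _ := json.Marshal(cfg)
	fmt.Println(string(out)) // round-trips to the same JSON shape
}
```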
type ServerSideProviderConfigOpenAIProvider struct {
AccessToken string `json:"accessToken"`

@@ -3191,13 +3191,9 @@
"items": {
"type": "string",
"enum": [
"bigcode::v1::starcoder2-3b",
"bigcode::v1::starcoder2-7b",
"bigcode::v1::starcoder2-15b",
"mistral::v1::mistral-7b",
"mistral::v1::mistral-7b-instruct",
"mistral::v1::mixtral-8x7b",
"mistral::v1::mixtral-8x22b",
"mistral::v1::mixtral-8x7b-instruct",
"mistral::v1::mixtral-8x22b-instruct"
]
@@ -3466,6 +3462,7 @@
"fireworks",
"google",
"openai",
"huggingface-tgi",
"openaicompatible",
"sourcegraph"
]
@@ -3490,6 +3487,9 @@
{
"$ref": "#/definitions/ServerSideProviderConfigOpenAIProvider"
},
{
"$ref": "#/definitions/ServerSideProviderConfigHuggingfaceTGIProvider"
},
{
"$ref": "#/definitions/ServerSideProviderConfigOpenAICompatibleProvider"
},
@@ -3613,19 +3613,56 @@
}
}
},
"ServerSideProviderConfigHuggingfaceTGIProvider": {
"type": "object",
"required": ["type", "endpoints"],
"properties": {
"type": {
"type": "string",
"const": "huggingface-tgi"
},
"endpoints": {
"$ref": "#/definitions/OpenAICompatibleEndpoint"
},
"enableVerboseLogs": {
"description": "Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint.",
"type": "boolean",
"default": false
}
}
},
"ServerSideProviderConfigOpenAICompatibleProvider": {
"type": "object",
"required": ["type", "accessToken", "endpoint"],
"required": ["type", "endpoints"],
"properties": {
"type": {
"type": "string",
"const": "openaicompatible"
},
"accessToken": {
"type": "string"
"endpoints": {
"$ref": "#/definitions/OpenAICompatibleEndpoint"
},
"endpoint": {
"type": "string"
"enableVerboseLogs": {
"description": "Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint.",
"type": "boolean",
"default": false
}
}
},
"OpenAICompatibleEndpoint": {
"description": "Endpoints to connect to. If multiple are specified, Sourcegraph will randomly distribute requests between them.",
"type": "array",
"items": {
"minLength": 1,
"type": "object",
"required": ["url"],
"properties": {
"url": {
"type": "string"
},
"accessToken": {
"type": "string"
}
}
}
},
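
For reference, a hedged sketch of what a provider override using the generic `openaicompatible` provider could look like against this schema; the provider id, hostnames, and access tokens are placeholders, not recommendations.

```
"providerOverrides": [
  {
    // Hypothetical provider id and endpoints, purely illustrative.
    "id": "my-openai-compatible",
    "displayName": "Self-hosted OpenAI-compatible service",
    "serverSideConfig": {
      "type": "openaicompatible",
      // Requests are distributed randomly across all listed endpoints.
      "endpoints": [
        { "url": "https://llm-1.internal.example.com/v1", "accessToken": "REDACTED" },
        { "url": "https://llm-2.internal.example.com/v1", "accessToken": "REDACTED" }
      ],
      // When true, grep for 'OpenAICompatible' in the frontend container logs
      // to see the requests Cody makes to these endpoints.
      "enableVerboseLogs": true
    }
  }
]
```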
@@ -3652,7 +3689,81 @@
},
"default": null,
"description": "No client-side model configuration is currently available.",
"properties": {}
"properties": {
"openaicompatible": {
"$ref": "#/definitions/ClientSideModelConfigOpenAICompatible"
}
}
},
"ClientSideModelConfigOpenAICompatible": {
"type": "object",
"!go": {
"pointer": true
},
"default": null,
"description": "Advanced configuration options that are only respected if the model is provided by an openaicompatible provider.",
"properties": {
"stopSequences": {
"type": "array",
"items": {
"type": "string",
"description": "List of stop sequences to use for this model.",
"examples": ["\n"]
}
},
"endOfText": {
"type": "string",
"description": "End of text identifier used by the model.",
"examples": ["<|endoftext|>", "<EOT>"]
},
"contextSizeHintTotalCharacters": {
"!go": { "pointer": true },
"default": null,
"type": "integer",
"description": "A hint the client should use when producing context to send to the LLM.\nThe maximum length of all context (prefix + suffix + snippets), in characters."
},
"contextSizeHintPrefixCharacters": {
"!go": { "pointer": true },
"default": null,
"type": "integer",
"description": "A hint the client should use when producing context to send to the LLM.\nThe maximum length of the document prefix (text before the cursor) to include, in characters."
},
"contextSizeHintSuffixCharacters": {
"!go": { "pointer": true },
"default": null,
"type": "integer",
"description": "A hint the client should use when producing context to send to the LLM.\nThe maximum length of the document suffix (text after the cursor) to include, in characters."
},
"chatPreInstruction": {
"type": "string",
"description": "Custom instruction to be included at the start of all chat messages\nwhen using this model, e.g. 'Answer all questions in Spanish.'\n\nNote: similar to Cody client config option 'cody.chat.preInstruction'; if user has configured that it will be used instead of this."
},
"editPostInstruction": {
"type": "string",
"description": "Custom instruction to be included at the end of all edit commands\nwhen using this model, e.g. 'Write all unit tests with Jest instead of detected framework.'\n\nNote: similar to Cody client config option 'cody.edit.preInstruction'; if user has configured that it will be respected instead of this."
},
"autocompleteSinglelineTimeout": {
"type": "integer",
"description": "How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all.\n\nThis applies on single-line completions, e.g. 'var i = <completion>'\n\nNote: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.singleline' If user has configured that, it will be respected instead of this."
},
"autocompleteMultilineTimeout": {
"type": "integer",
"description": "How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all.\n\nThis applies on multi-line completions, which are based on intent-detection when e.g. a code block is being completed, e.g. 'func parseURL(url string) {<completion>'\n\nNote: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.multiline' If user has configured that, it will be respected instead of this."
},
"chatTopK": { "type": "number" },
"chatTopP": { "type": "number" },
"chatTemperature": { "type": "number" },
"chatMaxTokens": { "type": "integer" },
"autoCompleteTopK": { "type": "number" },
"autoCompleteTopP": { "type": "number" },
"autoCompleteTemperature": { "type": "number" },
"autoCompleteSinglelineMaxTokens": { "type": "integer" },
"autoCompleteMultilineMaxTokens": { "type": "integer" },
"editTopK": { "type": "number" },
"editTopP": { "type": "number" },
"editTemperature": { "type": "number" },
"editMaxTokens": { "type": "integer" }
}
},
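
To show how these client-side knobs attach to a model, here is a hedged sketch of a model override carrying a `clientSideConfig.openaicompatible` block. The surrounding fields (`modelRef`, `modelName`, `displayName`, `capabilities`) are assumed from the broader model configuration system rather than defined in this hunk, and every value is illustrative.

```
"modelOverrides": [
  {
    // Assumed modelOverride shape; only clientSideConfig.openaicompatible is
    // defined by the schema added in this hunk.
    "modelRef": "mistral::v1::mixtral-8x7b-instruct",
    "modelName": "mixtral-8x7b-instruct",
    "displayName": "Mixtral 8x7B Instruct",
    "capabilities": ["chat"],
    "clientSideConfig": {
      "openaicompatible": {
        "endOfText": "<|endoftext|>",
        "stopSequences": ["\n\n"],
        // Cap the total context (prefix + suffix + snippets) the client sends.
        "contextSizeHintTotalCharacters": 8192,
        // Prepended to every chat message sent to this model.
        "chatPreInstruction": "Answer concisely.",
        "chatTemperature": 0.2,
        "chatMaxTokens": 4000
      }
    }
  }
]
```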
"ServerSideModelConfig": {
"type": "object",
@@ -3665,13 +3776,16 @@
"properties": {
"type": {
"type": "string",
"enum": ["awsBedrockProvisionedThroughput"]
"enum": ["awsBedrockProvisionedThroughput", "openaicompatible"]
}
},
"oneOf": [
{
"$ref": "#/definitions/ServerSideModelConfigAwsBedrockProvisionedThroughput"
},
{
"$ref": "#/definitions/ServerSideModelConfigOpenAICompatible"
},
{
"$ref": "#/definitions/DoNotUsePhonyDiscriminantType"
}
@@ -3691,6 +3805,21 @@
}
}
},
"ServerSideModelConfigOpenAICompatible": {
"description": "Configuration that is only respected if the model is provided by an openaicompatible provider.",
"type": "object",
"required": ["type"],
"properties": {
"type": {
"type": "string",
"const": "openaicompatible"
},
"apiModel": {
"description": "The literal string value of the 'model' field that will be sent to the /chat/completions API, for example. If set, Sourcegraph treats this as an opaque string and sends it directly to the API, inferring no information from it. By default, the configured model name is sent.",
"type": "string"
}
}
},
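
And the server-side counterpart: a sketch of a model override using `apiModel` to pin the literal 'model' string sent to the endpoint. The surrounding modelOverride fields are assumed, as above, and the model identifier shown is only an example.

```
"modelOverrides": [
  {
    "modelRef": "mistral::v1::mixtral-8x7b-instruct",
    "serverSideConfig": {
      "type": "openaicompatible",
      // Sent verbatim as the 'model' field of API requests, instead of the
      // configured model name; Sourcegraph infers nothing from it.
      "apiModel": "casperhansen/mixtral-instruct-awq"
    }
  }
]
```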
"DoNotUsePhonyDiscriminantType": {
"type": "object",
"required": ["type"],