diff --git a/cmd/frontend/internal/modelconfig/siteconfig.go b/cmd/frontend/internal/modelconfig/siteconfig.go index 992a9c37427..06454131384 100644 --- a/cmd/frontend/internal/modelconfig/siteconfig.go +++ b/cmd/frontend/internal/modelconfig/siteconfig.go @@ -173,13 +173,18 @@ func convertServerSideProviderConfig(cfg *schema.ServerSideProviderConfig) *type Endpoint: v.Endpoint, }, } - } else if v := cfg.Openaicompatible; v != nil { - // TODO(slimsag): self-hosted-models: map this to OpenAICompatibleProviderConfig in the future + } else if v := cfg.HuggingfaceTgi; v != nil { return &types.ServerSideProviderConfig{ - GenericProvider: &types.GenericProviderConfig{ - ServiceName: types.GenericServiceProviderOpenAI, - AccessToken: v.AccessToken, - Endpoint: v.Endpoint, + OpenAICompatible: &types.OpenAICompatibleProviderConfig{ + Endpoints: convertOpenAICompatibleEndpoints(v.Endpoints), + EnableVerboseLogs: v.EnableVerboseLogs, + }, + } + } else if v := cfg.Openaicompatible; v != nil { + return &types.ServerSideProviderConfig{ + OpenAICompatible: &types.OpenAICompatibleProviderConfig{ + Endpoints: convertOpenAICompatibleEndpoints(v.Endpoints), + EnableVerboseLogs: v.EnableVerboseLogs, }, } } else if v := cfg.Sourcegraph; v != nil { @@ -194,13 +199,57 @@ func convertServerSideProviderConfig(cfg *schema.ServerSideProviderConfig) *type } } +func convertOpenAICompatibleEndpoints(configEndpoints []*schema.OpenAICompatibleEndpoint) []types.OpenAICompatibleEndpoint { + var endpoints []types.OpenAICompatibleEndpoint + for _, e := range configEndpoints { + endpoints = append(endpoints, types.OpenAICompatibleEndpoint{ + URL: e.Url, + AccessToken: e.AccessToken, + }) + } + return endpoints +} + func convertClientSideModelConfig(v *schema.ClientSideModelConfig) *types.ClientSideModelConfig { if v == nil { return nil } - return &types.ClientSideModelConfig{ - // We currently do not have any known client-side model configuration. 
+ cfg := &types.ClientSideModelConfig{} + if o := v.Openaicompatible; o != nil { + cfg.OpenAICompatible = &types.ClientSideModelConfigOpenAICompatible{ + StopSequences: o.StopSequences, + EndOfText: o.EndOfText, + ContextSizeHintTotalCharacters: intPtrToUintPtr(o.ContextSizeHintTotalCharacters), + ContextSizeHintPrefixCharacters: intPtrToUintPtr(o.ContextSizeHintPrefixCharacters), + ContextSizeHintSuffixCharacters: intPtrToUintPtr(o.ContextSizeHintSuffixCharacters), + ChatPreInstruction: o.ChatPreInstruction, + EditPostInstruction: o.EditPostInstruction, + AutocompleteSinglelineTimeout: uint(o.AutocompleteSinglelineTimeout), + AutocompleteMultilineTimeout: uint(o.AutocompleteMultilineTimeout), + ChatTopK: float32(o.ChatTopK), + ChatTopP: float32(o.ChatTopP), + ChatTemperature: float32(o.ChatTemperature), + ChatMaxTokens: uint(o.ChatMaxTokens), + AutoCompleteTopK: float32(o.AutoCompleteTopK), + AutoCompleteTopP: float32(o.AutoCompleteTopP), + AutoCompleteTemperature: float32(o.AutoCompleteTemperature), + AutoCompleteSinglelineMaxTokens: uint(o.AutoCompleteSinglelineMaxTokens), + AutoCompleteMultilineMaxTokens: uint(o.AutoCompleteMultilineMaxTokens), + EditTopK: float32(o.EditTopK), + EditTopP: float32(o.EditTopP), + EditTemperature: float32(o.EditTemperature), + EditMaxTokens: uint(o.EditMaxTokens), + } } + return cfg +} + +func intPtrToUintPtr(v *int) *uint { + if v == nil { + return nil + } + ptr := uint(*v) + return &ptr } func convertServerSideModelConfig(cfg *schema.ServerSideModelConfig) *types.ServerSideModelConfig { @@ -213,6 +262,12 @@ func convertServerSideModelConfig(cfg *schema.ServerSideModelConfig) *types.Serv ARN: v.Arn, }, } + } else if v := cfg.Openaicompatible; v != nil { + return &types.ServerSideModelConfig{ + OpenAICompatible: &types.ServerSideModelConfigOpenAICompatible{ + APIModel: v.ApiModel, + }, + } } else { panic(fmt.Sprintf("illegal state: %+v", v)) } @@ -262,19 +317,14 @@ func convertModelCapabilities(capabilities []string) []types.ModelCapability { // // It would specify these equivalent options for them under `modelOverrides`: var recommendedSettings = map[types.ModelRef]types.ModelOverride{ - "bigcode::v1::starcoder2-3b": recommendedSettingsStarcoder2("bigcode::v1::starcoder2-3b", "Starcoder2 3B", "starcoder2-3b"), "bigcode::v1::starcoder2-7b": recommendedSettingsStarcoder2("bigcode::v1::starcoder2-7b", "Starcoder2 7B", "starcoder2-7b"), "bigcode::v1::starcoder2-15b": recommendedSettingsStarcoder2("bigcode::v1::starcoder2-15b", "Starcoder2 15B", "starcoder2-15b"), - "mistral::v1::mistral-7b": recommendedSettingsMistral("mistral::v1::mistral-7b", "Mistral 7B", "mistral-7b"), "mistral::v1::mistral-7b-instruct": recommendedSettingsMistral("mistral::v1::mistral-7b-instruct", "Mistral 7B Instruct", "mistral-7b-instruct"), - "mistral::v1::mixtral-8x7b": recommendedSettingsMistral("mistral::v1::mixtral-8x7b", "Mixtral 8x7B", "mixtral-8x7b"), - "mistral::v1::mixtral-8x22b": recommendedSettingsMistral("mistral::v1::mixtral-8x22b", "Mixtral 8x22B", "mixtral-8x22b"), "mistral::v1::mixtral-8x7b-instruct": recommendedSettingsMistral("mistral::v1::mixtral-8x7b-instruct", "Mixtral 8x7B Instruct", "mixtral-8x7b-instruct"), "mistral::v1::mixtral-8x22b-instruct": recommendedSettingsMistral("mistral::v1::mixtral-8x22b", "Mixtral 8x22B", "mixtral-8x22b-instruct"), } func recommendedSettingsStarcoder2(modelRef, displayName, modelName string) types.ModelOverride { - // TODO(slimsag): self-hosted-models: tune these further based on testing return types.ModelOverride{ ModelRef: 
types.ModelRef(modelRef), DisplayName: displayName, @@ -285,15 +335,18 @@ func recommendedSettingsStarcoder2(modelRef, displayName, modelName string) type Tier: types.ModelTierEnterprise, ContextWindow: types.ContextWindow{ MaxInputTokens: 8192, - MaxOutputTokens: 4000, + MaxOutputTokens: 4096, + }, + ClientSideConfig: &types.ClientSideModelConfig{ + OpenAICompatible: &types.ClientSideModelConfigOpenAICompatible{ + StopSequences: []string{"<|endoftext|>", ""}, + EndOfText: "<|endoftext|>", + }, }, - ClientSideConfig: nil, - ServerSideConfig: nil, } } func recommendedSettingsMistral(modelRef, displayName, modelName string) types.ModelOverride { - // TODO(slimsag): self-hosted-models: tune these further based on testing return types.ModelOverride{ ModelRef: types.ModelRef(modelRef), DisplayName: displayName, @@ -304,9 +357,10 @@ func recommendedSettingsMistral(modelRef, displayName, modelName string) types.M Tier: types.ModelTierEnterprise, ContextWindow: types.ContextWindow{ MaxInputTokens: 8192, - MaxOutputTokens: 4000, + MaxOutputTokens: 4096, + }, + ClientSideConfig: &types.ClientSideModelConfig{ + OpenAICompatible: &types.ClientSideModelConfigOpenAICompatible{}, }, - ClientSideConfig: nil, - ServerSideConfig: nil, } } diff --git a/cmd/frontend/internal/modelconfig/siteconfig_completions.go b/cmd/frontend/internal/modelconfig/siteconfig_completions.go index a1de1f760c4..4d4fd4ad2eb 100644 --- a/cmd/frontend/internal/modelconfig/siteconfig_completions.go +++ b/cmd/frontend/internal/modelconfig/siteconfig_completions.go @@ -160,8 +160,8 @@ func getProviderConfiguration(siteConfig *conftypes.CompletionsConfig) *types.Se Endpoint: siteConfig.Endpoint, } - // For all the other types of providers you can define in the site configuration, we - // just use a generic config. Rather than creating one for Anthropic, Fireworks, Google, etc. + // For all the other types of providers you can define in the legacy "completions" site configuration, + // we just use a generic config. Rather than creating one for Anthropic, Fireworks, Google, etc. // We'll add those when needed, when we expose the newer style configuration in the site-config. 
default: serverSideConfig.GenericProvider = &types.GenericProviderConfig{ diff --git a/deps.bzl b/deps.bzl index 2c81abb5d82..5e6609586e8 100644 --- a/deps.bzl +++ b/deps.bzl @@ -6237,6 +6237,13 @@ def go_dependencies(): sum = "h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=", version = "v0.6.1", ) + go_repository( + name = "com_github_tmaxmax_go_sse", + build_file_proto_mode = "disable_global", + importpath = "github.com/tmaxmax/go-sse", + sum = "h1:pPpTgyyi1r7vG2o6icebnpGEh3ebcnBXqDWkb7aTofs=", + version = "v0.8.0", + ) go_repository( name = "com_github_tmc_dot", build_file_proto_mode = "disable_global", diff --git a/go.mod b/go.mod index 8ca86cb6b69..03d8d4d6938 100644 --- a/go.mod +++ b/go.mod @@ -318,6 +318,7 @@ require ( github.com/sourcegraph/sourcegraph/lib v0.0.0-20240524140455-2589fef13ea8 github.com/sourcegraph/sourcegraph/lib/managedservicesplatform v0.0.0-00010101000000-000000000000 github.com/sourcegraph/sourcegraph/monitoring v0.0.0-00010101000000-000000000000 + github.com/tmaxmax/go-sse v0.8.0 github.com/vektah/gqlparser/v2 v2.4.5 github.com/vvakame/gcplogurl v0.2.0 go.opentelemetry.io/collector/config/confighttp v0.103.0 diff --git a/go.sum b/go.sum index ba16b955f05..7fed469c51d 100644 --- a/go.sum +++ b/go.sum @@ -2410,6 +2410,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/tmaxmax/go-sse v0.8.0 h1:pPpTgyyi1r7vG2o6icebnpGEh3ebcnBXqDWkb7aTofs= +github.com/tmaxmax/go-sse v0.8.0/go.mod h1:HLoxqxdH+7oSUItjtnpxjzJedfr/+Rrm/dNWBcTxJFM= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= diff --git a/internal/completions/client/BUILD.bazel b/internal/completions/client/BUILD.bazel index c1c405a217a..9e4be6e1291 100644 --- a/internal/completions/client/BUILD.bazel +++ b/internal/completions/client/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//internal/completions/client/fireworks", "//internal/completions/client/google", "//internal/completions/client/openai", + "//internal/completions/client/openaicompatible", "//internal/completions/tokenusage", "//internal/completions/types", "//internal/httpcli", diff --git a/internal/completions/client/client.go b/internal/completions/client/client.go index 83c1b8215d0..0f5364812e2 100644 --- a/internal/completions/client/client.go +++ b/internal/completions/client/client.go @@ -10,6 +10,7 @@ import ( "github.com/sourcegraph/sourcegraph/internal/completions/client/fireworks" "github.com/sourcegraph/sourcegraph/internal/completions/client/google" "github.com/sourcegraph/sourcegraph/internal/completions/client/openai" + "github.com/sourcegraph/sourcegraph/internal/completions/client/openaicompatible" "github.com/sourcegraph/sourcegraph/internal/completions/tokenusage" "github.com/sourcegraph/sourcegraph/internal/completions/types" "github.com/sourcegraph/sourcegraph/internal/httpcli" @@ -64,6 +65,11 @@ func getAPIProvider(modelConfigInfo types.ModelConfigInfo) (types.CompletionsCli return client, errors.Wrap(err, "getting api 
provider") } + // OpenAI Compatible + if openAICompatibleCfg := ssConfig.OpenAICompatible; openAICompatibleCfg != nil { + return openaicompatible.NewClient(httpcli.UncachedExternalClient, *tokenManager), nil + } + // The "GenericProvider" is an escape hatch for a set of API Providers not needing any additional configuration. if genProviderCfg := ssConfig.GenericProvider; genProviderCfg != nil { token := genProviderCfg.AccessToken diff --git a/internal/completions/client/openaicompatible/BUILD.bazel b/internal/completions/client/openaicompatible/BUILD.bazel new file mode 100644 index 00000000000..2c38c9aea39 --- /dev/null +++ b/internal/completions/client/openaicompatible/BUILD.bazel @@ -0,0 +1,20 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "openaicompatible", + srcs = [ + "openaicompatible.go", + "types.go", + ], + importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/openaicompatible", + visibility = ["//:__subpackages__"], + deps = [ + "//internal/completions/tokenizer", + "//internal/completions/tokenusage", + "//internal/completions/types", + "//internal/modelconfig/types", + "//lib/errors", + "@com_github_sourcegraph_log//:log", + "@com_github_tmaxmax_go_sse//:go-sse", + ], +) diff --git a/internal/completions/client/openaicompatible/openaicompatible.go b/internal/completions/client/openaicompatible/openaicompatible.go new file mode 100644 index 00000000000..3e40ca21afa --- /dev/null +++ b/internal/completions/client/openaicompatible/openaicompatible.go @@ -0,0 +1,501 @@ +package openaicompatible + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "net/url" + "path" + "reflect" + "strings" + "time" + + "github.com/sourcegraph/log" + + sse "github.com/tmaxmax/go-sse" + + "github.com/sourcegraph/sourcegraph/internal/completions/tokenizer" + "github.com/sourcegraph/sourcegraph/internal/completions/tokenusage" + "github.com/sourcegraph/sourcegraph/internal/completions/types" + modelconfigSDK "github.com/sourcegraph/sourcegraph/internal/modelconfig/types" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +func NewClient( + cli *http.Client, + tokenManager tokenusage.Manager, +) types.CompletionsClient { + return &client{ + cli: cli, + tokenManager: tokenManager, + rng: rand.New(rand.NewSource(time.Now().Unix())), + } +} + +type client struct { + cli *http.Client + tokenManager tokenusage.Manager + rng *rand.Rand +} + +func (c *client) Complete( + ctx context.Context, + logger log.Logger, + request types.CompletionRequest, +) (*types.CompletionResponse, error) { + logger = logger.Scoped("OpenAICompatible") + + var resp *http.Response + defer (func() { + if resp != nil { + resp.Body.Close() + } + })() + + var ( + req *http.Request + reqBody string + err error + ) + if request.Feature == types.CompletionsFeatureCode { + req, reqBody, err = c.makeCompletionRequest(ctx, request, false) + } else { + req, reqBody, err = c.makeChatRequest(ctx, request, false) + } + if err != nil { + return nil, errors.Wrap(err, "making request") + } + + requestID := c.rng.Uint32() + providerConfig := request.ModelConfigInfo.Provider.ServerSideConfig.OpenAICompatible + if providerConfig.EnableVerboseLogs { + logger.Info("request", + log.Uint32("id", requestID), + log.String("kind", "non-streaming"), + log.String("method", req.Method), + log.String("url", req.URL.String()), + // Note: log package will automatically redact token + log.String("headers", fmt.Sprint(req.Header)), + log.String("body", reqBody), + 
) + } + start := time.Now() + resp, err = c.cli.Do(req) + if err != nil { + logger.Error("request error", + log.Uint32("id", requestID), + log.Error(err), + ) + return nil, errors.Wrap(err, "performing request") + } + if resp.StatusCode != http.StatusOK { + err := types.NewErrStatusNotOK("OpenAI", resp) + logger.Error("request error", + log.Uint32("id", requestID), + log.Error(err), + ) + return nil, err + } + defer resp.Body.Close() + + var response openaiResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + logger.Error("request error, decoding response", + log.Uint32("id", requestID), + log.Error(err), + ) + return nil, errors.Wrap(err, "decoding response") + } + if providerConfig.EnableVerboseLogs { + // When debugging connections, log more verbose information like the actual completion we got back. + completion := "" + if len(response.Choices) > 0 { + completion = response.Choices[0].Text + } + logger.Info("request success", + log.Uint32("id", requestID), + log.Duration("time", time.Since(start)), + log.String("response_model", response.Model), + log.String("url", req.URL.String()), + log.String("system_fingerprint", response.SystemFingerprint), + log.String("finish_reason", response.maybeGetFinishReason()), + log.String("completion", completion), + ) + } else { + logger.Info("request success", + log.Uint32("id", requestID), + log.Duration("time", time.Since(start)), + log.String("response_model", response.Model), + log.String("url", req.URL.String()), + log.String("system_fingerprint", response.SystemFingerprint), + log.String("finish_reason", response.maybeGetFinishReason()), + ) + } + + if len(response.Choices) == 0 { + // Empty response. + return &types.CompletionResponse{}, nil + } + + modelID := request.ModelConfigInfo.Model.ModelRef.ModelID() + err = c.tokenManager.UpdateTokenCountsFromModelUsage( + response.Usage.PromptTokens, + response.Usage.CompletionTokens, + tokenizer.OpenAIModel+"/"+string(modelID), + string(request.Feature), + tokenusage.OpenAICompatible) + if err != nil { + logger.Warn("Failed to count tokens with the token manager %w ", log.Error(err)) + } + return &types.CompletionResponse{ + Completion: response.Choices[0].Text, + StopReason: response.Choices[0].FinishReason, + }, nil +} + +func (c *client) Stream( + ctx context.Context, + logger log.Logger, + request types.CompletionRequest, + sendEvent types.SendCompletionEvent, +) error { + logger = logger.Scoped("OpenAICompatible") + + var ( + req *http.Request + reqBody string + err error + ) + if request.Feature == types.CompletionsFeatureCode { + req, reqBody, err = c.makeCompletionRequest(ctx, request, true) + } else { + req, reqBody, err = c.makeChatRequest(ctx, request, true) + } + if err != nil { + return errors.Wrap(err, "making request") + } + + sseClient := &sse.Client{ + HTTPClient: c.cli, + ResponseValidator: sse.DefaultValidator, + Backoff: sse.Backoff{ + // Note: go-sse has a bug with retry logic (https://github.com/tmaxmax/go-sse/pull/38) + // where it will get stuck in an infinite retry loop due to an io.EOF error + // depending on how the server behaves. For now, we just do not expose retry/backoff + // logic. It's not really useful for these types of requests anyway given their + // short-lived nature. 
+ MaxRetries: -1, + }, + } + ctx, cancel := context.WithCancel(ctx) + conn := sseClient.NewConnection(req.WithContext(ctx)) + + var ( + content string + ev types.CompletionResponse + promptTokens, completionTokens int + streamErr error + finishReason string + ) + unsubscribe := conn.SubscribeMessages(func(event sse.Event) { + // Ignore any data that is not JSON-like + if !strings.HasPrefix(event.Data, "{") { + return + } + + var resp openaiResponse + if err := json.Unmarshal([]byte(event.Data), &resp); err != nil { + streamErr = errors.Errorf("failed to decode event payload: %w - body: %s", err, event.Data) + cancel() + return + } + + if reflect.DeepEqual(resp, openaiResponse{}) { + // Empty response, it may be an error payload then + var errResp openaiErrorResponse + if err := json.Unmarshal([]byte(event.Data), &errResp); err != nil { + streamErr = errors.Errorf("failed to decode error event payload: %w - body: %s", err, event.Data) + cancel() + return + } + if errResp.Error != "" || errResp.ErrorType != "" { + streamErr = errors.Errorf("SSE error: %s: %s", errResp.ErrorType, errResp.Error) + cancel() + return + } + } + + // These are only included in the last message, so we're not worried about overwriting + if resp.Usage.PromptTokens > 0 { + promptTokens = resp.Usage.PromptTokens + } + if resp.Usage.CompletionTokens > 0 { + completionTokens = resp.Usage.CompletionTokens + } + + if len(resp.Choices) > 0 { + if request.Feature == types.CompletionsFeatureCode { + content += resp.Choices[0].Text + } else { + content += resp.Choices[0].Delta.Content + } + ev = types.CompletionResponse{ + Completion: content, + StopReason: resp.Choices[0].FinishReason, + } + err = sendEvent(ev) + if err != nil { + streamErr = errors.Errorf("failed to send event: %w", err) + cancel() + return + } + for _, choice := range resp.Choices { + if choice.FinishReason != "" { + // End of stream + finishReason = choice.FinishReason + streamErr = nil + cancel() + return + } + } + } + }) + defer unsubscribe() + + requestID := c.rng.Uint32() + providerConfig := request.ModelConfigInfo.Provider.ServerSideConfig.OpenAICompatible + if providerConfig.EnableVerboseLogs { + logger.Info("request", + log.Uint32("id", requestID), + log.String("kind", "streaming"), + log.String("method", req.Method), + log.String("url", req.URL.String()), + // Note: log package will automatically redact token + log.String("headers", fmt.Sprint(req.Header)), + log.String("body", reqBody), + ) + } + start := time.Now() + err = conn.Connect() + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + // go-sse will return io.EOF on successful close of the connection, since it expects the + // connection to be long-lived. In our case, we expect the connection to close on success + // and be short lived, so this is a non-error. + err = nil + } + if streamErr != nil { + err = errors.Append(err, streamErr) + } + if err == nil && finishReason == "" { + // At this point, we successfully streamed the response to the client. But we need to make + // sure the client gets a non-empty StopReason at the very end, otherwise it would think + // the streamed response it got is partial / incomplete and may not display the completion + // to the user as a result. + err = sendEvent(types.CompletionResponse{ + Completion: content, + StopReason: "stop_sequence", // pretend we hit a stop sequence (we did!) 
+ }) + } + if err != nil { + logger.Error("request error", + log.Uint32("id", requestID), + log.Error(err), + ) + return errors.Wrap(err, "NewConnection") + } + + if providerConfig.EnableVerboseLogs { + // When debugging connections, log more verbose information like the actual completion we got back. + logger.Info("request success", + log.Uint32("id", requestID), + log.Duration("time", time.Since(start)), + log.String("url", req.URL.String()), + log.String("finish_reason", finishReason), + log.String("completion", content), + ) + } else { + logger.Info("request success", + log.Uint32("id", requestID), + log.Duration("time", time.Since(start)), + log.String("url", req.URL.String()), + log.String("finish_reason", finishReason), + ) + } + + modelID := request.ModelConfigInfo.Model.ModelRef.ModelID() + err = c.tokenManager.UpdateTokenCountsFromModelUsage( + promptTokens, + completionTokens, + tokenizer.OpenAIModel+"/"+string(modelID), + string(request.Feature), + tokenusage.OpenAICompatible, + ) + if err != nil { + logger.Warn("Failed to count tokens with the token manager %w", log.Error(err)) + } + return nil +} + +func (c *client) makeChatRequest( + ctx context.Context, + request types.CompletionRequest, + stream bool, +) (*http.Request, string, error) { + requestParams := request.Parameters + if requestParams.TopK < 0 { + requestParams.TopK = 0 + } + if requestParams.TopP < 0 { + requestParams.TopP = 0 + } + + payload := openAIChatCompletionsRequestParameters{ + Model: getAPIModel(request), + Temperature: requestParams.Temperature, + TopP: requestParams.TopP, + N: requestParams.TopK, + Stream: stream, + MaxTokens: requestParams.MaxTokensToSample, + Stop: requestParams.StopSequences, + } + for _, m := range requestParams.Messages { + var role string + switch m.Speaker { + case types.SYSTEM_MESSAGE_SPEAKER: + role = "system" + case types.HUMAN_MESSAGE_SPEAKER: + role = "user" + case types.ASSISTANT_MESSAGE_SPEAKER: + role = "assistant" + default: + role = strings.ToLower(role) + } + payload.Messages = append(payload.Messages, message{ + Role: role, + Content: m.Text, + }) + } + + reqBody, err := json.Marshal(payload) + if err != nil { + return nil, "", errors.Wrap(err, "Marshal") + } + + endpoint, err := getEndpoint(request, c.rng) + if err != nil { + return nil, "", errors.Wrap(err, "getEndpoint") + } + url, err := getEndpointURL(endpoint, "chat/completions") + if err != nil { + return nil, "", errors.Wrap(err, "getEndpointURL") + } + req, err := http.NewRequestWithContext(ctx, "POST", url.String(), bytes.NewReader(reqBody)) + if err != nil { + return nil, "", errors.Wrap(err, "NewRequestWithContext") + } + + req.Header.Set("Content-Type", "application/json") + if endpoint.AccessToken != "" { + req.Header.Set("Authorization", "Bearer "+endpoint.AccessToken) + } + return req, string(reqBody), nil +} + +func (c *client) makeCompletionRequest( + ctx context.Context, + request types.CompletionRequest, + stream bool, +) (*http.Request, string, error) { + requestParams := request.Parameters + if requestParams.TopK < 0 { + requestParams.TopK = 0 + } + if requestParams.TopP < 0 { + requestParams.TopP = 0 + } + + prompt, err := getPrompt(requestParams.Messages) + if err != nil { + return nil, "", errors.Wrap(err, "getPrompt") + } + + payload := openAICompletionsRequestParameters{ + Model: getAPIModel(request), + Temperature: requestParams.Temperature, + TopP: requestParams.TopP, + N: requestParams.TopK, + Stream: stream, + MaxTokens: requestParams.MaxTokensToSample, + Stop: 
requestParams.StopSequences, + Prompt: prompt, + } + + reqBody, err := json.Marshal(payload) + if err != nil { + return nil, "", errors.Wrap(err, "Marshal") + } + + endpoint, err := getEndpoint(request, c.rng) + if err != nil { + return nil, "", errors.Wrap(err, "getEndpoint") + } + url, err := getEndpointURL(endpoint, "completions") + if err != nil { + return nil, "", errors.Wrap(err, "getEndpointURL") + } + + req, err := http.NewRequestWithContext(ctx, "POST", url.String(), bytes.NewReader(reqBody)) + if err != nil { + return nil, "", errors.Wrap(err, "NewRequestWithContext") + } + + req.Header.Set("Content-Type", "application/json") + if endpoint.AccessToken != "" { + req.Header.Set("Authorization", "Bearer "+endpoint.AccessToken) + } + return req, string(reqBody), nil +} + +func getPrompt(messages []types.Message) (string, error) { + if l := len(messages); l == 0 { + return "", errors.New("found zero messages in prompt") + } + return messages[0].Text, nil +} + +func getAPIModel(request types.CompletionRequest) string { + ssConfig := request.ModelConfigInfo.Model.ServerSideConfig + if ssConfig != nil && ssConfig.OpenAICompatible != nil && ssConfig.OpenAICompatible.APIModel != "" { + return ssConfig.OpenAICompatible.APIModel + } + // Default to model name if not specified + return request.ModelConfigInfo.Model.ModelName +} + +func getEndpoint(request types.CompletionRequest, rng *rand.Rand) (modelconfigSDK.OpenAICompatibleEndpoint, error) { + providerConfig := request.ModelConfigInfo.Provider.ServerSideConfig.OpenAICompatible + if len(providerConfig.Endpoints) == 0 { + return modelconfigSDK.OpenAICompatibleEndpoint{}, errors.New("no openaicompatible endpoint configured") + } + if len(providerConfig.Endpoints) == 1 { + return providerConfig.Endpoints[0], nil + } + randPick := rng.Intn(len(providerConfig.Endpoints)) + return providerConfig.Endpoints[randPick], nil +} + +func getEndpointURL(endpoint modelconfigSDK.OpenAICompatibleEndpoint, relativePath string) (*url.URL, error) { + url, err := url.Parse(endpoint.URL) + if err != nil { + return nil, errors.Newf("failed to parse endpoint URL: %q", endpoint.URL) + } + if url.Scheme == "" || url.Host == "" { + return nil, errors.Newf("unable to build URL, bad endpoint: %q", endpoint.URL) + } + url.Path = path.Join(url.Path, relativePath) + return url, nil +} diff --git a/internal/completions/client/openaicompatible/types.go b/internal/completions/client/openaicompatible/types.go new file mode 100644 index 00000000000..ac10d7f2d09 --- /dev/null +++ b/internal/completions/client/openaicompatible/types.go @@ -0,0 +1,77 @@ +package openaicompatible + +// openAIChatCompletionsRequestParameters request object for openAI chat endpoint https://platform.openai.com/docs/api-reference/chat/create +type openAIChatCompletionsRequestParameters struct { + Model string `json:"model"` // request.Model + Messages []message `json:"messages"` // request.Messages + Temperature float32 `json:"temperature,omitempty"` // request.Temperature + TopP float32 `json:"top_p,omitempty"` // request.TopP + N int `json:"n,omitempty"` // always 1 + Stream bool `json:"stream,omitempty"` // request.Stream + Stop []string `json:"stop,omitempty"` // request.StopSequences + MaxTokens int `json:"max_tokens,omitempty"` // request.MaxTokensToSample + PresencePenalty float32 `json:"presence_penalty,omitempty"` // unused + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // unused + LogitBias map[string]float32 `json:"logit_bias,omitempty"` // unused + User string 
`json:"user,omitempty"` // unused +} + +// openAICompletionsRequestParameters payload for openAI completions endpoint https://platform.openai.com/docs/api-reference/completions/create +type openAICompletionsRequestParameters struct { + Model string `json:"model"` // request.Model + Prompt string `json:"prompt"` // request.Messages[0] - formatted prompt expected to be the only message + Temperature float32 `json:"temperature,omitempty"` // request.Temperature + TopP float32 `json:"top_p,omitempty"` // request.TopP + N int `json:"n,omitempty"` // always 1 + Stream bool `json:"stream,omitempty"` // request.Stream + Stop []string `json:"stop,omitempty"` // request.StopSequences + MaxTokens int `json:"max_tokens,omitempty"` // request.MaxTokensToSample + PresencePenalty float32 `json:"presence_penalty,omitempty"` // unused + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // unused + LogitBias map[string]float32 `json:"logit_bias,omitempty"` // unused + Suffix string `json:"suffix,omitempty"` // unused + User string `json:"user,omitempty"` // unused +} + +type message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type openaiUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type openaiChoiceDelta struct { + Content string `json:"content"` +} + +type openaiChoice struct { + Delta openaiChoiceDelta `json:"delta"` + Role string `json:"role"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +type openaiResponse struct { + // Usage is only available for non-streaming requests. + Usage openaiUsage `json:"usage"` + Model string `json:"model"` + Choices []openaiChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` +} + +func (r *openaiResponse) maybeGetFinishReason() string { + if len(r.Choices) == 0 { + return "" + } + return r.Choices[len(r.Choices)-1].FinishReason +} + +// e.g. {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. 
Given: 159 `inputs` tokens and 4000 `max_new_tokens`","error_type":"validation"} +type openaiErrorResponse struct { + Error string `json:"error"` + ErrorType string `json:"error_type"` +} diff --git a/internal/completions/tokenizer/BUILD.bazel b/internal/completions/tokenizer/BUILD.bazel index f0b61a41ae0..249213ff087 100644 --- a/internal/completions/tokenizer/BUILD.bazel +++ b/internal/completions/tokenizer/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "//internal/completions/client/anthropic:__pkg__", "//internal/completions/client/azureopenai:__pkg__", "//internal/completions/client/openai:__pkg__", + "//internal/completions/client/openaicompatible:__pkg__", "//internal/completions/tokenusage:__pkg__", ], deps = [ diff --git a/internal/completions/tokenusage/BUILD.bazel b/internal/completions/tokenusage/BUILD.bazel index 941fac99642..dea5064dcb9 100644 --- a/internal/completions/tokenusage/BUILD.bazel +++ b/internal/completions/tokenusage/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//internal/completions/client/azureopenai:__pkg__", "//internal/completions/client/codygateway:__pkg__", "//internal/completions/client/openai:__pkg__", + "//internal/completions/client/openaicompatible:__pkg__", "//internal/updatecheck:__pkg__", ], deps = [ diff --git a/internal/completions/tokenusage/tokenusage.go b/internal/completions/tokenusage/tokenusage.go index 40c2bfed9ed..9750fec36ee 100644 --- a/internal/completions/tokenusage/tokenusage.go +++ b/internal/completions/tokenusage/tokenusage.go @@ -27,10 +27,11 @@ func NewManager() *Manager { type Provider string const ( - OpenAI Provider = "openai" - AzureOpenAI Provider = "azureopenai" - AwsBedrock Provider = "awsbedrock" - Anthropic Provider = "anthropic" + OpenAI Provider = "openai" + OpenAICompatible Provider = "openaicompatible" + AzureOpenAI Provider = "azureopenai" + AwsBedrock Provider = "awsbedrock" + Anthropic Provider = "anthropic" ) func (m *Manager) UpdateTokenCountsFromModelUsage(inputTokens, outputTokens int, model, feature string, provider Provider) error { diff --git a/internal/conf/conftypes/consts.go b/internal/conf/conftypes/consts.go index 0b999b85bb2..d44595e8db9 100644 --- a/internal/conf/conftypes/consts.go +++ b/internal/conf/conftypes/consts.go @@ -45,13 +45,14 @@ type ConfigFeatures struct { type CompletionsProviderName string const ( - CompletionsProviderNameAnthropic CompletionsProviderName = "anthropic" - CompletionsProviderNameOpenAI CompletionsProviderName = "openai" - CompletionsProviderNameGoogle CompletionsProviderName = "google" - CompletionsProviderNameAzureOpenAI CompletionsProviderName = "azure-openai" - CompletionsProviderNameSourcegraph CompletionsProviderName = "sourcegraph" - CompletionsProviderNameFireworks CompletionsProviderName = "fireworks" - CompletionsProviderNameAWSBedrock CompletionsProviderName = "aws-bedrock" + CompletionsProviderNameAnthropic CompletionsProviderName = "anthropic" + CompletionsProviderNameOpenAI CompletionsProviderName = "openai" + CompletionsProviderNameGoogle CompletionsProviderName = "google" + CompletionsProviderNameAzureOpenAI CompletionsProviderName = "azure-openai" + CompletionsProviderNameOpenAICompatible CompletionsProviderName = "openai-compatible" + CompletionsProviderNameSourcegraph CompletionsProviderName = "sourcegraph" + CompletionsProviderNameFireworks CompletionsProviderName = "fireworks" + CompletionsProviderNameAWSBedrock CompletionsProviderName = "aws-bedrock" ) type EmbeddingsConfig struct { diff --git a/internal/modelconfig/types/configuration.go 
b/internal/modelconfig/types/configuration.go index 14637ff6328..745c3f5fa87 100644 --- a/internal/modelconfig/types/configuration.go +++ b/internal/modelconfig/types/configuration.go @@ -72,6 +72,28 @@ type GenericProviderConfig struct { Endpoint string `json:"endpoint"` } +// OpenAICompatibleProvider is a provider for connecting to OpenAI-compatible API endpoints +// supplied by various third-party software. +// +// Because many of these third-party providers provide slightly different semantics for the OpenAI API +// protocol, the Sourcegraph instance exposes this provider configuration which allows for much more +// extensive configuration than would be needed for the official OpenAI API. +type OpenAICompatibleProviderConfig struct { + // Endpoints where this API can be reached. If multiple are present, Sourcegraph will distribute + // load between them as it sees fit. + Endpoints []OpenAICompatibleEndpoint `json:"endpoints,omitempty"` + + // Whether to enable verbose logging of requests, allowing for grepping the logs for "OpenAICompatible" + // and seeing e.g. what requests Cody is actually sending to your API endpoint. + EnableVerboseLogs bool `json:"enableVerboseLogs,omitempty"` +} + +// A single API endpoint for an OpenAI-compatible API. +type OpenAICompatibleEndpoint struct { + URL string `json:"url"` + AccessToken string `json:"accessToken"` +} + // SourcegraphProviderConfig is the configuration blog for configuring a provider // to be use Sourcegraph's Cody Gateway for requests. type SourcegraphProviderConfig struct { @@ -82,25 +104,99 @@ type SourcegraphProviderConfig struct { // The "Provider" is conceptually a namespace for models. The server-side provider configuration // is needed to describe the API endpoint needed to serve its models. type ServerSideProviderConfig struct { - AWSBedrock *AWSBedrockProviderConfig `json:"awsBedrock,omitempty"` - AzureOpenAI *AzureOpenAIProviderConfig `json:"azureOpenAi,omitempty"` - GenericProvider *GenericProviderConfig `json:"genericProvider,omitempty"` - SourcegraphProvider *SourcegraphProviderConfig `json:"sourcegraphProvider,omitempty"` + AWSBedrock *AWSBedrockProviderConfig `json:"awsBedrock,omitempty"` + AzureOpenAI *AzureOpenAIProviderConfig `json:"azureOpenAi,omitempty"` + OpenAICompatible *OpenAICompatibleProviderConfig `json:"openAICompatible,omitempty"` + GenericProvider *GenericProviderConfig `json:"genericProvider,omitempty"` + SourcegraphProvider *SourcegraphProviderConfig `json:"sourcegraphProvider,omitempty"` } // ======================================================== // Client-side Model Configuration Data // ======================================================== +// Anything that needs to be provided to Cody clients at the model-level can go here. +// +// For example, allowing the server to customize/override the LLM +// prompt used. Or describe how clients should upload context to +// remote servers, etc. Or "hints", like "this model is great when +// working with 'C' code.". type ClientSideModelConfig struct { - // We currently do not have any known client-side model configuration. - // But later, if anything needs to be provided to Cody clients at the - // model-level it will go here. + OpenAICompatible *ClientSideModelConfigOpenAICompatible `json:"openAICompatible,omitempty"` +} + +// Client-side model configuration used when the model is backed by an OpenAI-compatible API +// provider. +type ClientSideModelConfigOpenAICompatible struct { + // (optional) List of stop sequences to use for this model. 
+ StopSequences []string `json:"stopSequences,omitempty"` + + // (optional) EndOfText identifier used by the model. e.g. "<|endoftext|>", "" + EndOfText string `json:"endOfText,omitempty"` + + // (optional) A hint the client should use when producing context to send to the LLM. + // The maximum length of all context (prefix + suffix + snippets), in characters. + ContextSizeHintTotalCharacters *uint `json:"contextSizeHintTotalCharacters,omitempty"` + + // (optional) A hint the client should use when producing context to send to the LLM. + // The maximum length of the document prefix (text before the cursor) to include, in characters. + ContextSizeHintPrefixCharacters *uint `json:"contextSizeHintPrefixCharacters,omitempty"` + + // (optional) A hint the client should use when producing context to send to the LLM. + // The maximum length of the document suffix (text after the cursor) to include, in characters. + ContextSizeHintSuffixCharacters *uint `json:"contextSizeHintSuffixCharacters,omitempty"` + + // (optional) Custom instruction to be included at the start of all chat messages + // when using this model, e.g. "Answer all questions in Spanish." // - // For example, allowing the server to customize/override the LLM - // prompt used. Or describe how clients should upload context to - // remote servers, etc. Or "hints", like "this model is great when - // working with 'C' code.". + // Note: similar to Cody client config option `cody.chat.preInstruction`; if user has + // configured that it will be used instead of this. + ChatPreInstruction string `json:"chatPreInstruction,omitempty"` + + // (optional) Custom instruction to be included at the end of all edit commands + // when using this model, e.g. "Write all unit tests with Jest instead of detected framework." + // + // Note: similar to Cody client config option `cody.edit.preInstruction`; if user has + // configured that it will be respected instead of this. + EditPostInstruction string `json:"editPostInstruction,omitempty"` + + // (optional) How long the client should wait for autocomplete results to come back (milliseconds), + // before giving up and not displaying an autocomplete result at all. + // + // This applies on single-line completions, e.g. `var i = ` + // + // Note: similar to hidden Cody client config option `cody.autocomplete.advanced.timeout.singleline` + // If user has configured that, it will be respected instead of this. + AutocompleteSinglelineTimeout uint `json:"autocompleteSinglelineTimeout,omitempty"` + + // (optional) How long the client should wait for autocomplete results to come back (milliseconds), + // before giving up and not displaying an autocomplete result at all. + // + // This applies on multi-line completions, which are based on intent-detection when e.g. a code block + // is being completed, e.g. `func parseURL(url string) {` + // + // Note: similar to hidden Cody client config option `cody.autocomplete.advanced.timeout.multiline` + // If user has configured that, it will be respected instead of this. 
+ AutocompleteMultilineTimeout uint `json:"autocompleteMultilineTimeout,omitempty"` + + // (optional) model parameters to use for the chat feature + ChatTopK float32 `json:"chatTopK,omitempty"` + ChatTopP float32 `json:"chatTopP,omitempty"` + ChatTemperature float32 `json:"chatTemperature,omitempty"` + ChatMaxTokens uint `json:"chatMaxTokens,omitempty"` + + // (optional) model parameters to use for the autocomplete feature + AutoCompleteTopK float32 `json:"autoCompleteTopK,omitempty"` + AutoCompleteTopP float32 `json:"autoCompleteTopP,omitempty"` + AutoCompleteTemperature float32 `json:"autoCompleteTemperature,omitempty"` + AutoCompleteSinglelineMaxTokens uint `json:"autoCompleteSinglelineMaxTokens,omitempty"` + AutoCompleteMultilineMaxTokens uint `json:"autoCompleteMultilineMaxTokens,omitempty"` + + // (optional) model parameters to use for the edit feature + EditTopK float32 `json:"editTopK,omitempty"` + EditTopP float32 `json:"editTopP,omitempty"` + EditTemperature float32 `json:"editTemperature,omitempty"` + EditMaxTokens uint `json:"editMaxTokens,omitempty"` } // ======================================================== @@ -116,6 +212,34 @@ type AWSBedrockProvisionedThroughput struct { ARN string `json:"arn"` } -type ServerSideModelConfig struct { - AWSBedrockProvisionedThroughput *AWSBedrockProvisionedThroughput `json:"awsBedrockProvisionedThroughput"` +type ServerSideModelConfigOpenAICompatible struct { + // APIModel is value actually sent to the OpenAI-compatible API in the "model" field. This + // is less like a "model name" or "model identifier", and more like "an opaque, potentially + // secret string." + // + // Much software that claims to 'implement the OpenAI API' actually overrides this field with + // other information NOT related to the model name, either making it _ineffective_ as a + // model name/identifier (e.g. you must send "tgi" or "AUTODETECT" irrespective of which model + // you want to use) OR using it to smuggle other (potentially sensitive) information like the + // name of the deployment, which cannot be shared with clients. + // + // If this field is not an empty string, we treat it as an opaque string to be sent with API + // requests (similar to an access token) and use it for nothing else. If this field is not + // specified, we default to the Model.ModelName. + // + // Examples (these would be sent in the OpenAI /chat/completions `"model"` field): + // + // * Huggingface TGI: "tgi" + // * NVIDIA NIM: "meta/llama3-70b-instruct" + // * AWS LISA (v2): "AUTODETECT" + // * AWS LISA (v1): "mistralai/Mistral7b-v0.3-Instruct ecs.textgen.tgi" + // * Ollama: "llama2" + // * Others: "" + // + APIModel string `json:"apiModel,omitempty"` +} + +type ServerSideModelConfig struct { + AWSBedrockProvisionedThroughput *AWSBedrockProvisionedThroughput `json:"awsBedrockProvisionedThroughput,omitempty"` + OpenAICompatible *ServerSideModelConfigOpenAICompatible `json:"openAICompatible,omitempty"` } diff --git a/internal/modelconfig/updates_test.go b/internal/modelconfig/updates_test.go index 2bebe687ddf..31bcb2a3bc5 100644 --- a/internal/modelconfig/updates_test.go +++ b/internal/modelconfig/updates_test.go @@ -1,7 +1,6 @@ package modelconfig import ( - "reflect" "testing" "github.com/stretchr/testify/assert" @@ -64,17 +63,6 @@ func TestApplyModelOverrides(t *testing.T) { // The configuration data is applied too, but it isn't a copy rather we just update the pointers // to point to the original data. 
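
The internal model-level types added in `internal/modelconfig/types/configuration.go` come together in a single `types.ModelOverride`, which is also what the `recommendedSettings*` helpers in `siteconfig.go` produce. Below is a minimal sketch of a fully populated override, using only fields that appear in this diff; the model ref, hint values, and timeouts are illustrative, and the `"tgi"` string is the Huggingface TGI example from the `APIModel` doc comment.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
)

func main() {
	hint := uint(16000) // illustrative context-size hint, in characters
	override := types.ModelOverride{
		ModelRef:    types.ModelRef("bigcode::v1::starcoder2-7b"),
		DisplayName: "Starcoder2 7B",
		ContextWindow: types.ContextWindow{
			MaxInputTokens:  8192,
			MaxOutputTokens: 4096,
		},
		ClientSideConfig: &types.ClientSideModelConfig{
			OpenAICompatible: &types.ClientSideModelConfigOpenAICompatible{
				StopSequences:                  []string{"<|endoftext|>"},
				EndOfText:                      "<|endoftext|>",
				ContextSizeHintTotalCharacters: &hint,
				AutocompleteSinglelineTimeout:  5000, // milliseconds
			},
		},
		ServerSideConfig: &types.ServerSideModelConfig{
			OpenAICompatible: &types.ServerSideModelConfigOpenAICompatible{
				// Huggingface TGI ignores the model name and expects the
				// literal string "tgi" (see the APIModel doc comment above).
				APIModel: "tgi",
			},
		},
	}
	out, _ := json.MarshalIndent(override, "", "  ")
	fmt.Println(string(out))
}
```
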
t.Run("ConfigPointers", func(t *testing.T) { - { - // This test skips validation for the `model.ClientSideConfig` value because there isn't a - // reliable way to actually confirm the pointer was changed. Since the size of the data type - // is 0, the Go compiler can do all sorts of optimization schenanigans. - // - // When this scenario fails when we finally add a field to the ClientSideConfig struct, just - // uncomment the relevant parts of the code below. - clientSideConfig := types.ClientSideModelConfig{} - assert.EqualValues(t, 0, reflect.TypeOf(clientSideConfig).Size(), "See comment in the code...") - } - mod := getValidModel() origClientCfg := mod.ClientSideConfig origServerCfg := mod.ServerSideConfig @@ -90,8 +78,7 @@ func TestApplyModelOverrides(t *testing.T) { } // Confirm the override has different pointers for the model config. - // require.True(t, origClientCfg != override.ClientSideConfig, "orig = %p, override = %p", origClientCfg, override.ClientSideConfig) - // ^-- 0-byte type schenanigans... + require.True(t, origClientCfg != override.ClientSideConfig, "orig = %p, override = %p", origClientCfg, override.ClientSideConfig) require.True(t, origServerCfg != override.ServerSideConfig) err := ApplyModelOverride(&mod, override) @@ -100,8 +87,7 @@ func TestApplyModelOverrides(t *testing.T) { assert.NotNil(t, mod.ClientSideConfig) assert.NotNil(t, mod.ServerSideConfig) - // assert.True(t, mod.ClientSideConfig != origClientCfg) - // ^-- 0-byte type schenanigans... + assert.True(t, mod.ClientSideConfig != origClientCfg) assert.True(t, mod.ServerSideConfig != origServerCfg) assert.True(t, mod.ClientSideConfig == override.ClientSideConfig) diff --git a/schema/schema.go b/schema/schema.go index 7310abaa194..52755863394 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -605,6 +605,58 @@ type ChangesetTemplate struct { // ClientSideModelConfig description: No client-side model configuration is currently available. type ClientSideModelConfig struct { + Openaicompatible *ClientSideModelConfigOpenAICompatible `json:"openaicompatible,omitempty"` +} + +// ClientSideModelConfigOpenAICompatible description: Advanced configuration options that are only respected if the model is provided by an openaicompatible provider. +type ClientSideModelConfigOpenAICompatible struct { + AutoCompleteMultilineMaxTokens int `json:"autoCompleteMultilineMaxTokens,omitempty"` + AutoCompleteSinglelineMaxTokens int `json:"autoCompleteSinglelineMaxTokens,omitempty"` + AutoCompleteTemperature float64 `json:"autoCompleteTemperature,omitempty"` + AutoCompleteTopK float64 `json:"autoCompleteTopK,omitempty"` + AutoCompleteTopP float64 `json:"autoCompleteTopP,omitempty"` + // AutocompleteMultilineTimeout description: How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all. + // + // This applies on multi-line completions, which are based on intent-detection when e.g. a code block is being completed, e.g. 'func parseURL(url string) {' + // + // Note: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.multiline' If user has configured that, it will be respected instead of this. + AutocompleteMultilineTimeout int `json:"autocompleteMultilineTimeout,omitempty"` + // AutocompleteSinglelineTimeout description: How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all. 
+ // + // This applies on single-line completions, e.g. 'var i = ' + // + // Note: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.singleline' If user has configured that, it will be respected instead of this. + AutocompleteSinglelineTimeout int `json:"autocompleteSinglelineTimeout,omitempty"` + ChatMaxTokens int `json:"chatMaxTokens,omitempty"` + // ChatPreInstruction description: Custom instruction to be included at the start of all chat messages + // when using this model, e.g. 'Answer all questions in Spanish.' + // + // Note: similar to Cody client config option 'cody.chat.preInstruction'; if user has configured that it will be used instead of this. + ChatPreInstruction string `json:"chatPreInstruction,omitempty"` + ChatTemperature float64 `json:"chatTemperature,omitempty"` + ChatTopK float64 `json:"chatTopK,omitempty"` + ChatTopP float64 `json:"chatTopP,omitempty"` + // ContextSizeHintPrefixCharacters description: A hint the client should use when producing context to send to the LLM. + // The maximum length of the document prefix (text before the cursor) to include, in characters. + ContextSizeHintPrefixCharacters *int `json:"contextSizeHintPrefixCharacters,omitempty"` + // ContextSizeHintSuffixCharacters description: A hint the client should use when producing context to send to the LLM. + // The maximum length of the document suffix (text after the cursor) to include, in characters. + ContextSizeHintSuffixCharacters *int `json:"contextSizeHintSuffixCharacters,omitempty"` + // ContextSizeHintTotalCharacters description: A hint the client should use when producing context to send to the LLM. + // The maximum length of all context (prefix + suffix + snippets), in characters. + ContextSizeHintTotalCharacters *int `json:"contextSizeHintTotalCharacters,omitempty"` + EditMaxTokens int `json:"editMaxTokens,omitempty"` + // EditPostInstruction description: Custom instruction to be included at the end of all edit commands + // when using this model, e.g. 'Write all unit tests with Jest instead of detected framework.' + // + // Note: similar to Cody client config option 'cody.edit.preInstruction'; if user has configured that it will be respected instead of this. + EditPostInstruction string `json:"editPostInstruction,omitempty"` + EditTemperature float64 `json:"editTemperature,omitempty"` + EditTopK float64 `json:"editTopK,omitempty"` + EditTopP float64 `json:"editTopP,omitempty"` + // EndOfText description: End of text identifier used by the model. + EndOfText string `json:"endOfText,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` } // ClientSideProviderConfig description: No client-side provider configuration is currently available. 
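
The generated `schema.ClientSideModelConfigOpenAICompatible` mirrors the internal type field for field, and `convertClientSideModelConfig` (earlier in this diff) copies it across, so the JSON an admin writes can be previewed by marshaling the generated struct. A small sketch follows, assuming the block sits under a model override's `clientSideConfig` with the new `openaicompatible` key; all values are illustrative.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/sourcegraph/sourcegraph/schema"
)

func main() {
	totalHint := 16000 // characters
	cfg := schema.ClientSideModelConfig{
		Openaicompatible: &schema.ClientSideModelConfigOpenAICompatible{
			StopSequences:                  []string{"<|endoftext|>"},
			EndOfText:                      "<|endoftext|>",
			ContextSizeHintTotalCharacters: &totalHint,
			AutocompleteSinglelineTimeout:  5000, // milliseconds
			ChatTemperature:                0.2,
		},
	}
	// The emitted keys follow the json tags above: "openaicompatible",
	// "stopSequences", "endOfText", "contextSizeHintTotalCharacters",
	// "autocompleteSinglelineTimeout", "chatTemperature".
	out, _ := json.MarshalIndent(cfg, "", "  ")
	fmt.Println(string(out))
}
```
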
@@ -2008,6 +2060,10 @@ type OnboardingTourConfiguration struct { DefaultSnippets map[string]any `json:"defaultSnippets,omitempty"` Tasks []*OnboardingTask `json:"tasks"` } +type OpenAICompatibleEndpoint struct { + AccessToken string `json:"accessToken,omitempty"` + Url string `json:"url"` +} type OpenCodeGraphAnnotation struct { Item OpenCodeGraphItemRef `json:"item"` Range OpenCodeGraphRange `json:"range"` @@ -2534,6 +2590,7 @@ type Sentry struct { } type ServerSideModelConfig struct { AwsBedrockProvisionedThroughput *ServerSideModelConfigAwsBedrockProvisionedThroughput + Openaicompatible *ServerSideModelConfigOpenAICompatible Unused *DoNotUsePhonyDiscriminantType } @@ -2541,6 +2598,9 @@ func (v ServerSideModelConfig) MarshalJSON() ([]byte, error) { if v.AwsBedrockProvisionedThroughput != nil { return json.Marshal(v.AwsBedrockProvisionedThroughput) } + if v.Openaicompatible != nil { + return json.Marshal(v.Openaicompatible) + } if v.Unused != nil { return json.Marshal(v.Unused) } @@ -2556,10 +2616,12 @@ func (v *ServerSideModelConfig) UnmarshalJSON(data []byte) error { switch d.DiscriminantProperty { case "awsBedrockProvisionedThroughput": return json.Unmarshal(data, &v.AwsBedrockProvisionedThroughput) + case "openaicompatible": + return json.Unmarshal(data, &v.Openaicompatible) case "unused": return json.Unmarshal(data, &v.Unused) } - return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrockProvisionedThroughput", "unused"}) + return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrockProvisionedThroughput", "openaicompatible", "unused"}) } type ServerSideModelConfigAwsBedrockProvisionedThroughput struct { @@ -2567,6 +2629,13 @@ type ServerSideModelConfigAwsBedrockProvisionedThroughput struct { Arn string `json:"arn"` Type string `json:"type"` } + +// ServerSideModelConfigOpenAICompatible description: Configuration that is only respected if the model is provided by an openaicompatible provider. +type ServerSideModelConfigOpenAICompatible struct { + // ApiModel description: The literal string value of the 'model' field that will be sent to the /chat/completions API, for example. If set, Sourcegraph treats this as an opaque string and sends it directly to the API, inferring no information from it. By default, the configured model name is sent. 
+ ApiModel string `json:"apiModel,omitempty"` + Type string `json:"type"` +} type ServerSideProviderConfig struct { AwsBedrock *ServerSideProviderConfigAWSBedrock AzureOpenAI *ServerSideProviderConfigAzureOpenAI @@ -2574,6 +2643,7 @@ type ServerSideProviderConfig struct { Fireworks *ServerSideProviderConfigFireworksProvider Google *ServerSideProviderConfigGoogleProvider Openai *ServerSideProviderConfigOpenAIProvider + HuggingfaceTgi *ServerSideProviderConfigHuggingfaceTGIProvider Openaicompatible *ServerSideProviderConfigOpenAICompatibleProvider Sourcegraph *ServerSideProviderConfigSourcegraphProvider Unused *DoNotUsePhonyDiscriminantType @@ -2598,6 +2668,9 @@ func (v ServerSideProviderConfig) MarshalJSON() ([]byte, error) { if v.Openai != nil { return json.Marshal(v.Openai) } + if v.HuggingfaceTgi != nil { + return json.Marshal(v.HuggingfaceTgi) + } if v.Openaicompatible != nil { return json.Marshal(v.Openaicompatible) } @@ -2627,6 +2700,8 @@ func (v *ServerSideProviderConfig) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &v.Fireworks) case "google": return json.Unmarshal(data, &v.Google) + case "huggingface-tgi": + return json.Unmarshal(data, &v.HuggingfaceTgi) case "openai": return json.Unmarshal(data, &v.Openai) case "openaicompatible": @@ -2636,7 +2711,7 @@ func (v *ServerSideProviderConfig) UnmarshalJSON(data []byte) error { case "unused": return json.Unmarshal(data, &v.Unused) } - return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrock", "azureOpenAI", "anthropic", "fireworks", "google", "openai", "openaicompatible", "sourcegraph", "unused"}) + return fmt.Errorf("tagged union type must have a %q property whose value is one of %s", "type", []string{"awsBedrock", "azureOpenAI", "anthropic", "fireworks", "google", "openai", "huggingface-tgi", "openaicompatible", "sourcegraph", "unused"}) } type ServerSideProviderConfigAWSBedrock struct { @@ -2674,10 +2749,17 @@ type ServerSideProviderConfigGoogleProvider struct { Endpoint string `json:"endpoint"` Type string `json:"type"` } +type ServerSideProviderConfigHuggingfaceTGIProvider struct { + // EnableVerboseLogs description: Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint. + EnableVerboseLogs bool `json:"enableVerboseLogs,omitempty"` + Endpoints []*OpenAICompatibleEndpoint `json:"endpoints"` + Type string `json:"type"` +} type ServerSideProviderConfigOpenAICompatibleProvider struct { - AccessToken string `json:"accessToken"` - Endpoint string `json:"endpoint"` - Type string `json:"type"` + // EnableVerboseLogs description: Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint. 
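
On the provider side, the new `huggingface-tgi` and `openaicompatible` union members share the same payload shape: a `type` discriminant, an `endpoints` list, and the optional `enableVerboseLogs` flag, and `convertServerSideProviderConfig` maps both onto the shared `types.OpenAICompatibleProviderConfig`. Here is a hedged sketch of that JSON using the generated structs; the endpoint URLs and token are made up.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/sourcegraph/sourcegraph/schema"
)

func main() {
	provider := schema.ServerSideProviderConfig{
		Openaicompatible: &schema.ServerSideProviderConfigOpenAICompatibleProvider{
			Type: "openaicompatible",
			Endpoints: []*schema.OpenAICompatibleEndpoint{
				{Url: "http://llm-1.internal:8080/v1", AccessToken: "REDACTED"},
				{Url: "http://llm-2.internal:8080/v1", AccessToken: "REDACTED"},
			},
			EnableVerboseLogs: true,
		},
	}
	// ServerSideProviderConfig.MarshalJSON emits only the populated member.
	out, _ := json.MarshalIndent(provider, "", "  ")
	fmt.Println(string(out))
}
```

At request time, `getEndpoint` in `openaicompatible.go` picks one of the configured endpoints at random whenever more than one is listed.
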
+ EnableVerboseLogs bool `json:"enableVerboseLogs,omitempty"` + Endpoints []*OpenAICompatibleEndpoint `json:"endpoints"` + Type string `json:"type"` } type ServerSideProviderConfigOpenAIProvider struct { AccessToken string `json:"accessToken"` diff --git a/schema/site.schema.json b/schema/site.schema.json index 49b87fe9d4a..af0802bc354 100644 --- a/schema/site.schema.json +++ b/schema/site.schema.json @@ -3191,13 +3191,9 @@ "items": { "type": "string", "enum": [ - "bigcode::v1::starcoder2-3b", "bigcode::v1::starcoder2-7b", "bigcode::v1::starcoder2-15b", - "mistral::v1::mistral-7b", "mistral::v1::mistral-7b-instruct", - "mistral::v1::mixtral-8x7b", - "mistral::v1::mixtral-8x22b", "mistral::v1::mixtral-8x7b-instruct", "mistral::v1::mixtral-8x22b-instruct" ] @@ -3466,6 +3462,7 @@ "fireworks", "google", "openai", + "huggingface-tgi", "openaicompatible", "sourcegraph" ] @@ -3490,6 +3487,9 @@ { "$ref": "#/definitions/ServerSideProviderConfigOpenAIProvider" }, + { + "$ref": "#/definitions/ServerSideProviderConfigHuggingfaceTGIProvider" + }, { "$ref": "#/definitions/ServerSideProviderConfigOpenAICompatibleProvider" }, @@ -3613,19 +3613,56 @@ } } }, + "ServerSideProviderConfigHuggingfaceTGIProvider": { + "type": "object", + "required": ["type", "endpoints"], + "properties": { + "type": { + "type": "string", + "const": "huggingface-tgi" + }, + "endpoints": { + "$ref": "#/definitions/OpenAICompatibleEndpoint" + }, + "enableVerboseLogs": { + "description": "Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint.", + "type": "boolean", + "default": false + } + } + }, "ServerSideProviderConfigOpenAICompatibleProvider": { "type": "object", - "required": ["type", "accessToken", "endpoint"], + "required": ["type", "endpoints"], "properties": { "type": { "type": "string", "const": "openaicompatible" }, - "accessToken": { - "type": "string" + "endpoints": { + "$ref": "#/definitions/OpenAICompatibleEndpoint" }, - "endpoint": { - "type": "string" + "enableVerboseLogs": { + "description": "Whether to enable verbose logging of requests. When enabled, grep for 'OpenAICompatible' in the frontend container logs to see the requests Cody makes to the endpoint.", + "type": "boolean", + "default": false + } + } + }, + "OpenAICompatibleEndpoint": { + "description": "Endpoints to connect to. 
If multiple are specified, Sourcegraph will randomly distribute requests between them.", + "type": "array", + "items": { + "minLength": 1, + "type": "object", + "required": ["url"], + "properties": { + "url": { + "type": "string" + }, + "accessToken": { + "type": "string" + } } } }, @@ -3652,7 +3689,81 @@ }, "default": null, "description": "No client-side model configuration is currently available.", - "properties": {} + "properties": { + "openaicompatible": { + "$ref": "#/definitions/ClientSideModelConfigOpenAICompatible" + } + } + }, + "ClientSideModelConfigOpenAICompatible": { + "type": "object", + "!go": { + "pointer": true + }, + "default": null, + "description": "Advanced configuration options that are only respected if the model is provided by an openaicompatible provider.", + "properties": { + "stopSequences": { + "type": "array", + "items": { + "type": "string", + "description": "List of stop sequences to use for this model.", + "examples": ["\n"] + } + }, + "endOfText": { + "type": "string", + "description": "End of text identifier used by the model.", + "examples": ["<|endoftext|>", ""] + }, + "contextSizeHintTotalCharacters": { + "!go": { "pointer": true }, + "default": null, + "type": "integer", + "description": "A hint the client should use when producing context to send to the LLM.\nThe maximum length of all context (prefix + suffix + snippets), in characters." + }, + "contextSizeHintPrefixCharacters": { + "!go": { "pointer": true }, + "default": null, + "type": "integer", + "description": "A hint the client should use when producing context to send to the LLM.\nThe maximum length of the document prefix (text before the cursor) to include, in characters." + }, + "contextSizeHintSuffixCharacters": { + "!go": { "pointer": true }, + "default": null, + "type": "integer", + "description": "A hint the client should use when producing context to send to the LLM.\nThe maximum length of the document suffix (text after the cursor) to include, in characters." + }, + "chatPreInstruction": { + "type": "string", + "description": "Custom instruction to be included at the start of all chat messages\nwhen using this model, e.g. 'Answer all questions in Spanish.'\n\nNote: similar to Cody client config option 'cody.chat.preInstruction'; if user has configured that it will be used instead of this." + }, + "editPostInstruction": { + "type": "string", + "description": "Custom instruction to be included at the end of all edit commands\nwhen using this model, e.g. 'Write all unit tests with Jest instead of detected framework.'\n\nNote: similar to Cody client config option 'cody.edit.preInstruction'; if user has configured that it will be respected instead of this." + }, + "autocompleteSinglelineTimeout": { + "type": "integer", + "description": "How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all.\n\nThis applies on single-line completions, e.g. 'var i = '\n\nNote: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.singleline' If user has configured that, it will be respected instead of this." + }, + "autocompleteMultilineTimeout": { + "type": "integer", + "description": "How long the client should wait for autocomplete results to come back (milliseconds), before giving up and not displaying an autocomplete result at all.\n\nThis applies on multi-line completions, which are based on intent-detection when e.g. a code block is being completed, e.g. 
'func parseURL(url string) {'\n\nNote: similar to hidden Cody client config option 'cody.autocomplete.advanced.timeout.multiline' If user has configured that, it will be respected instead of this." + }, + "chatTopK": { "type": "number" }, + "chatTopP": { "type": "number" }, + "chatTemperature": { "type": "number" }, + "chatMaxTokens": { "type": "integer" }, + "autoCompleteTopK": { "type": "number" }, + "autoCompleteTopP": { "type": "number" }, + "autoCompleteTemperature": { "type": "number" }, + "autoCompleteSinglelineMaxTokens": { "type": "integer" }, + "autoCompleteMultilineMaxTokens": { "type": "integer" }, + "editTopK": { "type": "number" }, + "editTopP": { "type": "number" }, + "editTemperature": { "type": "number" }, + "editMaxTokens": { "type": "integer" } + } }, "ServerSideModelConfig": { "type": "object", @@ -3665,13 +3776,16 @@ "properties": { "type": { "type": "string", - "enum": ["awsBedrockProvisionedThroughput"] + "enum": ["awsBedrockProvisionedThroughput", "openaicompatible"] } }, "oneOf": [ { "$ref": "#/definitions/ServerSideModelConfigAwsBedrockProvisionedThroughput" }, + { + "$ref": "#/definitions/ServerSideModelConfigOpenAICompatible" + }, { "$ref": "#/definitions/DoNotUsePhonyDiscriminantType" } @@ -3691,6 +3805,21 @@ } } }, + "ServerSideModelConfigOpenAICompatible": { + "description": "Configuration that is only respected if the model is provided by an openaicompatible provider.", + "type": "object", + "required": ["type"], + "properties": { + "type": { + "type": "string", + "const": "openaicompatible" + }, + "apiModel": { + "description": "The literal string value of the 'model' field that will be sent to the /chat/completions API, for example. If set, Sourcegraph treats this as an opaque string and sends it directly to the API, inferring no information from it. By default, the configured model name is sent.", + "type": "string" + } + } + }, "DoNotUsePhonyDiscriminantType": { "type": "object", "required": ["type"],