Several fixes around merging modelconfig, and the current Cody Gateway data (#63814)

While testing the modelconfig system working end-to-end with the data
coming from the site configuration, I ran into a handful of minor
issues.

They are all kinda subtle so I'll just leave comments to explain the
what and why.

## Test plan

Added new unit tests.

## Changelog

NA
This commit is contained in:
Chris Smith 2024-07-15 10:14:28 -07:00 committed by GitHub
parent c1efb92196
commit 8d4e5b52f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 388 additions and 92 deletions

View File

@ -169,7 +169,7 @@ func getGoogleModels() []types.Model {
newModel(
modelIdentity{
MRef: mRef(geminiV1, "gemini-1.5-flash-latest"),
Name: "google/gemini-1.5-flash-latest",
Name: "gemini-1.5-flash-latest",
DisplayName: "Gemini 1.5 Flash",
},
modelMetadata{
@ -183,8 +183,12 @@ func getGoogleModels() []types.Model {
}
func getMistralModels() []types.Model {
// Not sure if there is a canonical API reference, since Mixtral offers 3rd
// party LLMs as a service.
// NOTE: These are all kinda fubar, since we are offering Mixtral models
// via the Fireworks API provider.
//
// So the ModelNames do need to have this odd format. Because there isn't an
// actual "Mistral API Provider" in our backend, we route all of these to
// Fireworks.
// https://deepinfra.com/mistralai/Mixtral-8x22B-Instruct-v0.1/api
// https://readme.fireworks.ai
const mistralV1 = "mistral::v1"
@ -193,7 +197,7 @@ func getMistralModels() []types.Model {
newModel(
modelIdentity{
MRef: mRef(mistralV1, "mixtral-8x7b-instruct"),
Name: "mixtral-8x7b-instruct",
Name: "accounts/fireworks/models/mixtral-8x7b-instruct",
DisplayName: "Mixtral 8x7B",
},
modelMetadata{
@ -206,7 +210,7 @@ func getMistralModels() []types.Model {
newModel(
modelIdentity{
MRef: mRef(mistralV1, "mixtral-8x22b-instruct"),
Name: "mixtral-8x22b-instruct",
Name: "accounts/fireworks/models/mixtral-8x22b-instruct",
DisplayName: "Mixtral 8x22B",
},
modelMetadata{

View File

@ -186,7 +186,7 @@ func getChatModelFn(db database.DB) getModelFn {
return legacyMRef.ToModelRef(), nil
}
errModelNotAllowed := errors.Errorf(
"the requested model is not available (%q, onProTier=%v)",
"the requested chat model is not available (%q, onProTier=%v)",
requestParams.RequestedModel, subscription.ApplyProRateLimits)
return "", errModelNotAllowed
}
@ -228,7 +228,7 @@ func getChatModelFn(db database.DB) getModelFn {
}
err := errors.Errorf(
"unsupported code completion model %q (default %q)",
"unsupported chat model %q (default %q)",
initialRequestedModel, cfg.DefaultModels.Chat)
return "", err
}

View File

@ -355,7 +355,7 @@ func TestGetChatModelFn(t *testing.T) {
},
}
_, err := getModelFn(ctx, reqParams, &modelConfig)
require.ErrorContains(t, err, `unsupported code completion model "some-model-not-in-config"`)
require.ErrorContains(t, err, `unsupported chat model "some-model-not-in-config"`)
})
})

View File

@ -23,6 +23,7 @@ go_library(
"//internal/conf",
"//internal/conf/conftypes",
"//internal/database",
"//internal/license",
"//internal/modelconfig",
"//internal/modelconfig/embedded",
"//internal/modelconfig/types",

View File

@ -4,6 +4,8 @@ import (
"fmt"
"slices"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/modelconfig"
"github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -51,7 +53,8 @@ func (b *builder) build() (*types.ModelConfiguration, error) {
// If no site configuration data is supplied, then just use Sourcegraph
// supplied data.
if b.siteConfigData == nil {
return deepCopy(baseConfig)
routeUnconfiguredProvidersToCodyGateway(baseConfig)
return baseConfig, nil
}
// But if we are using site config data, ensure it is valid before applying it.
if vErr := modelconfig.ValidateSiteConfig(b.siteConfigData); vErr != nil {
@ -106,28 +109,9 @@ func applySiteConfig(baseConfig *types.ModelConfiguration, siteConfig *types.Sit
if modelFilters := sgModelConfig.ModelFilters; modelFilters != nil {
var filteredModels []types.Model
for _, baseConfigModel := range mergedConfig.Models {
// Status filter.
if modelFilters.StatusFilter != nil {
if !slices.Contains(modelFilters.StatusFilter, string(baseConfigModel.Category)) {
continue
}
if isModelAllowed(baseConfigModel, *modelFilters) {
filteredModels = append(filteredModels, baseConfigModel)
}
// Allow list. If not specified, include all models. Otherwise, ONLY include the model
// IFF it matches one of the allow rules.
if len(modelFilters.Allow) > 0 {
if !filterListMatches(baseConfigModel.ModelRef, modelFilters.Allow) {
continue
}
}
// Deny list. Filter the model if it matches any deny rules.
if len(modelFilters.Deny) > 0 {
if filterListMatches(baseConfigModel.ModelRef, modelFilters.Deny) {
continue
}
}
filteredModels = append(filteredModels, baseConfigModel)
}
// Replace the base config models with the filtered set.
@ -211,6 +195,13 @@ func applySiteConfig(baseConfig *types.ModelConfiguration, siteConfig *types.Sit
}
}
// For any providers still missing server-side config, wire them to send
// requests to Cody Gateway. We do this BEFORE merging in model/provider
// overrides so that new providers defined by the admin don't get wired
// to Cody Gateway. (And instead will fail at runtime because no server-side
// config is available.)
routeUnconfiguredProvidersToCodyGateway(mergedConfig)
// Any remaining keys in `modelOverrideLookup` are for ModelRefs that were NOT found in the
// base configuration. In that case we add those as "entirely new" models that were only
// defined in the site config and weren't referenced in the base config.
@ -245,69 +236,18 @@ func applySiteConfig(baseConfig *types.ModelConfiguration, siteConfig *types.Sit
mergedConfig.Models = append(mergedConfig.Models, *newModel)
}
// Use the DefaultModels from the site config. Otherwise, we need to pick something randomly
// to ensure they are at least defined.
// Have any admin-supplied default models overwrite what may have been set by
// the base config.
if siteConfig.DefaultModels != nil {
mergedConfig.DefaultModels.Chat = siteConfig.DefaultModels.Chat
mergedConfig.DefaultModels.CodeCompletion = siteConfig.DefaultModels.CodeCompletion
mergedConfig.DefaultModels.FastChat = siteConfig.DefaultModels.FastChat
} else {
// getModelWithRequirements returns the the first model available with the specific capability and a matching
// category. Returns nil if no such model is found.
getModelWithRequirements := func(
wantCapability types.ModelCapability, wantCategories ...types.ModelCategory) *types.ModelRef {
for _, model := range mergedConfig.Models {
// Check if the model can be used for that purpose.
var hasCapability bool
for _, gotCapability := range model.Capabilities {
if gotCapability == wantCapability {
hasCapability = true
break
}
}
if !hasCapability {
return nil
}
}
// Check if the model has a matching category.
for _, wantCategory := range wantCategories {
if model.Category == wantCategory {
return &model.ModelRef
}
}
}
return nil
}
const (
accuracy = types.ModelCategoryAccuracy
balanced = types.ModelCategoryBalanced
speed = types.ModelCategorySpeed
)
// Infer the default models to used based on category. This is probably not going to lead to great
// results. But :shrug: it's better than just crash looping because the config is under-specified.
if mergedConfig.DefaultModels.Chat == "" {
validModel := getModelWithRequirements(types.ModelCapabilityAutocomplete, accuracy, balanced)
if validModel == nil {
return nil, errors.New("no suitable model found for Chat")
}
mergedConfig.DefaultModels.Chat = *validModel
}
if mergedConfig.DefaultModels.FastChat == "" {
validModel := getModelWithRequirements(types.ModelCapabilityAutocomplete, speed, balanced)
if validModel == nil {
return nil, errors.New("no suitable model found for FastChat")
}
mergedConfig.DefaultModels.FastChat = *validModel
}
if mergedConfig.DefaultModels.CodeCompletion == "" {
validModel := getModelWithRequirements(types.ModelCapabilityAutocomplete, speed, balanced)
if validModel == nil {
return nil, errors.New("no suitable model found for Chat")
}
mergedConfig.DefaultModels.CodeCompletion = *validModel
}
// But we still need to confirm that all the default models are actually valid.
// e.g. not filtered out because of the model filter, or have an invalid ModelRef.
if err := maybeFixDefaultModels(&mergedConfig.DefaultModels, mergedConfig.Models); err != nil {
return nil, err
}
// Validate the resulting configuration.
@ -316,3 +256,136 @@ func applySiteConfig(baseConfig *types.ModelConfiguration, siteConfig *types.Sit
}
return mergedConfig, nil
}
// isModelAllowed returns whether or not the model should be supported as per the
// supplied model filter configuration.
//
// A model is kept only if it passes ALL applicable filters: the status filter
// (when set), the allow list (when non-empty), and the deny list (when non-empty).
func isModelAllowed(m types.Model, filter types.ModelFilters) bool {
	// Status filter: when present, the model's status must be one of the
	// required values.
	if filter.StatusFilter != nil && !slices.Contains(filter.StatusFilter, string(m.Status)) {
		return false
	}
	// Allow list: an empty list admits every model; otherwise the model must
	// match at least one allow rule.
	if len(filter.Allow) > 0 && !filterListMatches(m.ModelRef, filter.Allow) {
		return false
	}
	// Deny list: matching any deny rule excludes the model.
	if len(filter.Deny) > 0 && filterListMatches(m.ModelRef, filter.Deny) {
		return false
	}
	return true
}
// maybeFixDefaultModels will verify that the supplied DefaultModels set is valid, possibly
// modifying a model in case the existing value was incorrect (e.g. unset, or referring to
// a model that is no longer present in allModels).
func maybeFixDefaultModels(defaultModels *types.DefaultModels, allModels []types.Model) error {
	// exists reports whether mref refers to a model present in allModels.
	exists := func(mref types.ModelRef) bool {
		for i := range allModels {
			if allModels[i].ModelRef == mref {
				return true
			}
		}
		return false
	}
	// firstModelMatching returns the first model available with the specific capability
	// and a matching category. Returns nil if no such model is found.
	firstModelMatching := func(
		wantCapability types.ModelCapability, wantCategories ...types.ModelCategory) *types.ModelRef {
		for i := range allModels {
			model := &allModels[i]
			// Check if the model can be used for that purpose.
			hasCapability := false
			for _, gotCapability := range model.Capabilities {
				if gotCapability == wantCapability {
					hasCapability = true
					break
				}
			}
			if !hasCapability {
				continue
			}
			// Check if the model has a matching category.
			for _, wantCategory := range wantCategories {
				if model.Category == wantCategory {
					return &model.ModelRef
				}
			}
		}
		return nil
	}
	const (
		accuracy = types.ModelCategoryAccuracy
		balanced = types.ModelCategoryBalanced
		speed    = types.ModelCategorySpeed
	)
	// Infer any missing or invalid default models based on category. This is probably not
	// going to lead to great results. But :shrug: it's better than just crash looping
	// because the config is under-specified.
	if defaultModels.Chat == "" || !exists(defaultModels.Chat) {
		ref := firstModelMatching(types.ModelCapabilityChat, accuracy, balanced)
		if ref == nil {
			return errors.Errorf("no suitable model found for Chat (%d candidates)", len(allModels))
		}
		defaultModels.Chat = *ref
	}
	if defaultModels.FastChat == "" || !exists(defaultModels.FastChat) {
		ref := firstModelMatching(types.ModelCapabilityChat, speed, balanced)
		if ref == nil {
			return errors.Errorf("no suitable model found for FastChat (%d candidates)", len(allModels))
		}
		defaultModels.FastChat = *ref
	}
	if defaultModels.CodeCompletion == "" || !exists(defaultModels.CodeCompletion) {
		ref := firstModelMatching(types.ModelCapabilityAutocomplete, speed, balanced)
		if ref == nil {
			return errors.Errorf("no suitable model found for CodeCompletion (%d candidates)", len(allModels))
		}
		defaultModels.CodeCompletion = *ref
	}
	return nil
}
// routeUnconfiguredProvidersToCodyGateway verifies that each provider has server-side configuration
// data present. (So that it can actually serve LLM traffic.) However, for any providers missing
// server-side configuration, they will be updated to send traffic to Cody Gateway.
func routeUnconfiguredProvidersToCodyGateway(config *types.ModelConfiguration) {
	// Start from the production Cody Gateway endpoint, with an access token
	// derived from this instance's license key.
	siteConfig := conf.Get()
	endpoint := conf.CodyGatewayProdEndpoint
	accessToken := license.GenerateLicenseKeyBasedAccessToken(siteConfig.LicenseKey)

	// Apply any overrides if present in the site config.
	if mc := siteConfig.ModelConfiguration; mc != nil && mc.Sourcegraph != nil {
		if tok := mc.Sourcegraph.AccessToken; tok != nil {
			accessToken = *tok
		}
		if ep := mc.Sourcegraph.Endpoint; ep != nil {
			endpoint = *ep
		}
	}

	// Wire every provider that has no server-side config to Cody Gateway.
	for i := range config.Providers {
		if config.Providers[i].ServerSideConfig != nil {
			continue
		}
		config.Providers[i].ServerSideConfig = &types.ServerSideProviderConfig{
			SourcegraphProvider: &types.SourcegraphProviderConfig{
				AccessToken: accessToken,
				Endpoint:    endpoint,
			},
		}
	}
}

View File

@ -2,13 +2,16 @@ package modelconfig
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/sourcegraph/sourcegraph/schema"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/licensing"
"github.com/sourcegraph/sourcegraph/internal/modelconfig"
"github.com/sourcegraph/sourcegraph/internal/modelconfig/embedded"
"github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
"github.com/sourcegraph/sourcegraph/lib/pointers"
@ -221,3 +224,210 @@ func TestModelConfigBuilder(t *testing.T) {
})
})
}
// TestApplySiteConfig covers merging admin-supplied site configuration (model
// filters, model overrides, default models) on top of a Sourcegraph-supplied
// base model configuration.
func TestApplySiteConfig(t *testing.T) {
	// validModelWith returns a new, unique Model and applies the given override.
	// The RNG-suffixed model ID keeps each generated model distinct.
	rng := rand.New(rand.NewSource(time.Now().Unix()))
	validModelWith := func(override types.ModelOverride) types.Model {
		modelID := fmt.Sprintf("test-model-%x", rng.Uint64())
		m := types.Model{
			ModelRef:     types.ModelRef(fmt.Sprintf("test-provider::v1::%s", modelID)),
			ModelName:    modelID,
			Category:     types.ModelCategoryBalanced,
			Capabilities: []types.ModelCapability{types.ModelCapabilityChat, types.ModelCapabilityAutocomplete},
			ContextWindow: types.ContextWindow{
				MaxInputTokens:  1,
				MaxOutputTokens: 1,
			},
		}
		err := modelconfig.ApplyModelOverride(&m, override)
		require.NoError(t, err)
		return m
	}
	// toModelOverride converts a Model into the equivalent site-config ModelOverride,
	// so a test model can be "re-introduced" via the site configuration.
	toModelOverride := func(m types.Model) types.ModelOverride {
		return types.ModelOverride{
			ModelRef:      m.ModelRef,
			ModelName:     m.ModelName,
			Capabilities:  m.Capabilities,
			ContextWindow: m.ContextWindow,
			Category:      m.Category,
		}
	}
	t.Run("SourcegraphSuppliedModels", func(t *testing.T) {
		t.Run("StatusFilter", func(t *testing.T) {
			// The source config contains four models, one with each status.
			sourcegraphSuppliedConfig := types.ModelConfiguration{
				Providers: []types.Provider{
					{
						ID: types.ProviderID("test-provider"),
					},
				},
				Models: []types.Model{
					validModelWith(types.ModelOverride{
						Status: types.ModelStatusExperimental,
					}),
					validModelWith(types.ModelOverride{
						Status: types.ModelStatusBeta,
					}),
					validModelWith(types.ModelOverride{
						Status: types.ModelStatusStable,
					}),
					validModelWith(types.ModelOverride{
						Status: types.ModelStatusDeprecated,
					}),
				},
			}
			// The site configuration filters out all but "beta" and "stable".
			siteConfig := types.SiteModelConfiguration{
				SourcegraphModelConfig: &types.SourcegraphModelConfig{
					ModelFilters: &types.ModelFilters{
						StatusFilter: []string{"beta", "stable"},
					},
				},
			}
			gotConfig, err := applySiteConfig(&sourcegraphSuppliedConfig, &siteConfig)
			require.NoError(t, err)
			// Count the final models after the filter was applied.
			statusCounts := map[types.ModelStatus]int{}
			for _, model := range gotConfig.Models {
				statusCounts[model.Status]++
			}
			assert.Equal(t, 2, len(gotConfig.Models))
			assert.Equal(t, 1, statusCounts[types.ModelStatusBeta])
			assert.Equal(t, 1, statusCounts[types.ModelStatusStable])
		})
	})
	// This test covers the situation where the default models from the base configuration
	// are removed due to a model filter, but the site config doesn't provide valid values.
	t.Run("ReplacedDefaultModels", func(t *testing.T) {
		// testModel builds a model with a fixed, predictable ModelRef suffix (id),
		// the given capabilities, and category.
		testModel := func(id string, capabilities []types.ModelCapability, category types.ModelCategory) types.Model {
			m := validModelWith(types.ModelOverride{
				Capabilities: capabilities,
				Category:     category,
			})
			m.ModelRef = types.ModelRef(fmt.Sprintf("test-provider::v1::%s", id))
			return m
		}
		// getValidBaseConfig returns a minimal-but-valid base config with one chat
		// model and one autocomplete model, both referenced as defaults.
		getValidBaseConfig := func() types.ModelConfiguration {
			chatModel := testModel("chat", []types.ModelCapability{types.ModelCapabilityChat}, types.ModelCategoryAccuracy)
			codeModel := testModel("code", []types.ModelCapability{types.ModelCapabilityAutocomplete}, types.ModelCategorySpeed)
			return types.ModelConfiguration{
				Providers: []types.Provider{
					{
						ID: types.ProviderID("test-provider"),
					},
				},
				Models: []types.Model{
					chatModel,
					codeModel,
				},
				DefaultModels: types.DefaultModels{
					Chat:           chatModel.ModelRef,
					CodeCompletion: codeModel.ModelRef,
					FastChat:       chatModel.ModelRef,
				},
			}
		}
		t.Run("Base", func(t *testing.T) {
			// Sanity check: an empty site config leaves the valid base config valid.
			baseConfig := getValidBaseConfig()
			_, err := applySiteConfig(&baseConfig, &types.SiteModelConfiguration{
				SourcegraphModelConfig: &types.SourcegraphModelConfig{}, // i.e. use the base config.
			})
			require.NoError(t, err)
		})
		t.Run("ErrorNoChatModelAvail", func(t *testing.T) {
			// Now have the site config reject the chat model that was used as the default model.
			// This will now fail because there is nothing suitable.
			baseConfig := getValidBaseConfig()
			_, err := applySiteConfig(&baseConfig, &types.SiteModelConfiguration{
				SourcegraphModelConfig: &types.SourcegraphModelConfig{
					ModelFilters: &types.ModelFilters{
						Deny: []string{"*chat"},
					},
				},
			})
			assert.ErrorContains(t, err, "no suitable model found for Chat (1 candidates)")
		})
		t.Run("AlternativeUsed", func(t *testing.T) {
			t.Run("ErrorUnsuitableCandidate", func(t *testing.T) {
				// We add a new model from the site config, but the capability and category
				// make it unsuitable as the default chat model.
				modelInSiteConfig := testModel(
					"err-from-site-config", []types.ModelCapability{types.ModelCapabilityAutocomplete}, types.ModelCategorySpeed)
				baseConfig := getValidBaseConfig()
				_, err := applySiteConfig(&baseConfig, &types.SiteModelConfiguration{
					SourcegraphModelConfig: &types.SourcegraphModelConfig{
						ModelFilters: &types.ModelFilters{
							Deny: []string{"*chat"},
						},
					},
					ModelOverrides: []types.ModelOverride{
						toModelOverride(modelInSiteConfig),
					},
				})
				assert.ErrorContains(t, err, "no suitable model found for Chat (2 candidates)")
			})
			t.Run("ErrorStillNoSuitableCandidate", func(t *testing.T) {
				// This time the Chat default works, because of the model's capabilities
				// and category.
				//
				// However, we still get an error because there is no valid model for
				// the *fast chat*. Because "accuracy" isn't viable for fast chat,
				// it needs to be "speed" or "balanced".
				fromSiteConfig1 := testModel(
					"from-site-config1", []types.ModelCapability{types.ModelCapabilityChat}, types.ModelCategoryAccuracy)
				fromSiteConfig2 := testModel(
					"from-site-config2", []types.ModelCapability{types.ModelCapabilityAutocomplete}, types.ModelCategoryBalanced)
				baseConfig := getValidBaseConfig()
				_, err := applySiteConfig(&baseConfig, &types.SiteModelConfiguration{
					SourcegraphModelConfig: &types.SourcegraphModelConfig{
						ModelFilters: &types.ModelFilters{
							Deny: []string{"*chat"},
						},
					},
					ModelOverrides: []types.ModelOverride{
						toModelOverride(fromSiteConfig1),
						toModelOverride(fromSiteConfig2),
					},
				})
				assert.ErrorContains(t, err, "no suitable model found for FastChat (3 candidates)")
			})
			t.Run("Works", func(t *testing.T) {
				// This time it all works, because the new model is "balanced".
				modelInSiteConfig := testModel(
					"from-site-config", []types.ModelCapability{types.ModelCapabilityChat}, types.ModelCategoryBalanced)
				baseConfig := getValidBaseConfig()
				gotConfig, err := applySiteConfig(&baseConfig, &types.SiteModelConfiguration{
					SourcegraphModelConfig: &types.SourcegraphModelConfig{
						ModelFilters: &types.ModelFilters{
							Deny: []string{"*chat"},
						},
					},
					ModelOverrides: []types.ModelOverride{
						toModelOverride(modelInSiteConfig),
					},
				})
				require.NoError(t, err)
				assert.EqualValues(t, modelInSiteConfig.ModelRef, gotConfig.DefaultModels.Chat)
				assert.EqualValues(t, modelInSiteConfig.ModelRef, gotConfig.DefaultModels.FastChat)
			})
		})
	})
}

View File

@ -119,6 +119,12 @@ func (c *codyGatewayClient) clientForParams(logger log.Logger, feature types.Com
client := anthropic.NewClient(doer, "", "", true, c.tokenManager)
return client, nil
case "mistral":
// Annoying legacy hack: We expose Mistral models (e.g. "mixtral-8x22b-instruct") but have only
// ever offered them via the Fireworks API provider. So when switching to the newer modelconfig
// format, this is a situation where there wasn't a "Mistral API provider" for these models.
// Instead, we just send these to Fireworks.
fallthrough
case conftypes.CompletionsProviderNameFireworks:
doer := gatewayDoer(
c.upstream, feature, c.gatewayURL, c.accessToken,

View File

@ -23,6 +23,8 @@ import (
"github.com/sourcegraph/sourcegraph/schema"
)
const CodyGatewayProdEndpoint = "https://cody-gateway.sourcegraph.com"
func init() {
deployType := deploy.Type()
if !deploy.IsValidDeployType(deployType) {
@ -726,7 +728,7 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
if completionsConfig.Provider == string(conftypes.CompletionsProviderNameSourcegraph) {
// If no endpoint is configured, use a default value.
if completionsConfig.Endpoint == "" {
completionsConfig.Endpoint = "https://cody-gateway.sourcegraph.com"
completionsConfig.Endpoint = CodyGatewayProdEndpoint
}
// Set the access token, either use the configured one, or generate one for the platform.

View File

@ -127,7 +127,7 @@
{
"modelRef": "google::v1::gemini-1.5-flash-latest",
"displayName": "Gemini 1.5 Flash",
"modelName": "google/gemini-1.5-flash-latest",
"modelName": "gemini-1.5-flash-latest",
"capabilities": ["autocomplete", "chat"],
"category": "speed",
"status": "stable",
@ -140,7 +140,7 @@
{
"modelRef": "mistral::v1::mixtral-8x7b-instruct",
"displayName": "Mixtral 8x7B",
"modelName": "mixtral-8x7b-instruct",
"modelName": "accounts/fireworks/models/mixtral-8x7b-instruct",
"capabilities": ["autocomplete", "chat"],
"category": "speed",
"status": "stable",
@ -153,7 +153,7 @@
{
"modelRef": "mistral::v1::mixtral-8x22b-instruct",
"displayName": "Mixtral 8x22B",
"modelName": "mixtral-8x22b-instruct",
"modelName": "accounts/fireworks/models/mixtral-8x22b-instruct",
"capabilities": ["autocomplete", "chat"],
"category": "accuracy",
"status": "stable",