mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 17:11:49 +00:00
Return model IDs from GraphQL, not model Names (#64307)
::sigh:: the problem here is super in the weeds, but ultimately this
fixes a problem introduced when using AWS Bedrock and Sourcegraph
instances using the older style "completions" config.
## The problem
AWS Bedrock has some LLM model names that contain a colon, e.g.
`anthropic.claude-3-opus-20240229-v1:0`. Cody clients connecting to
Sourcegraph instances using the older style "completions" config will
obtain the available LLM models by using GraphQL.
So the Cody client would see that the chat model is
`anthropic.claude-3-opus-20240229-v1:0`.
However, under the hood, the Sourcegraph instance will convert the site
config into the newer `modelconfig` format. And during that conversion,
we use a _different value_ for the **model ID** than what is in the site
config. (The **model name** is what is sent to the LLM API, and is
unmodified. The model ID is a stable, unique identifier but is sanitized
so that it adheres to naming rules.)
Because of this, we have a problem.
When the Cody client makes a request to the HTTP completions API with
the model name of `anthropic.claude-3-opus-20240229-v1:0` or
`anthropic/anthropic.claude-3-opus-20240229-v1:0` it fails. Because
there is no model with ID `...v1:0`. (We only have the sanitized
version, `...v1_0`.)
## The fix
There were a few ways we could fix this, but this goes with just having
the GraphQL component return the model ID instead of the model name. So
that when the Cody client passes that model ID to the completions API,
everything works as it should.
And, practically speaking, for 99.9% of cases, the model name and model
ID will be identical. We only strip out non-URL safe characters and
colons, which usually aren't used in model names.
## Potential bugs
With this fix however, there is a specific combination of { client,
server, and model name } where things could in theory break.
Specifically:
Client | Server | Model name | Works |
--- | --- | --- | --- |
unaware-of-modelconfig | not-using-modelconfig | standard | 🟢 [1] |
aware-of-modelconfig | not-using-modelconfig | standard | 🟢 [1] |
unaware-of-modelconfig | using-modelconfig | standard | 🟢 [1] |
aware-of-modelconfig | using-modelconfig | standard | 🟢 [3] |
unaware-of-modelconfig | not-using-modelconfig | non-standard | 🔴 [2] |
aware-of-modelconfig | not-using-modelconfig | non-standard | 🔴 [2] |
unaware-of-modelconfig | using-modelconfig | non-standard | 🔴 [2] |
aware-of-modelconfig | using-modelconfig | non-standard | 🟢 [3] |
1. If the model name is something that doesn't require sanitization,
there is no problem. The model ID will be the same as the model name,
and things will work like they do today.
2. If the model name gets sanitized, then IFF the Cody client were to
make a decision based on that exact model name, it wouldn't work.
Because it would receive the sanitized name, and not the real one. As
long as the Cody client is only passing that model name onto the
Sourcegraph backend which will recognize the sanitized model name / ID,
all is well.
3. If the client and server are new, and using model config, then this
shouldn't be a problem because the client would use a different API to
fetch the Sourcegraph instance's supported models. And within the
client, natively refer to the model ID instead of the model name.
Fixes
[PRIME-464](https://linear.app/sourcegraph/issue/PRIME-464/aws-bedrock-x-completions-config-does-not-work-if-model-name-has-a).
## Test plan
Added some unit tests.
## Changelog
NA
This commit is contained in:
parent
c414477bee
commit
582022d301
@ -69,7 +69,7 @@ type completionsConfigResolver struct {
|
||||
}
|
||||
|
||||
func (c *completionsConfigResolver) ChatModel() (string, error) {
|
||||
return c.config.ChatModel, nil
|
||||
return convertLegacyModelNameToModelID(c.config.ChatModel), nil
|
||||
}
|
||||
|
||||
func (c *completionsConfigResolver) ChatModelMaxTokens() (*int32, error) {
|
||||
@ -92,7 +92,7 @@ func (c *completionsConfigResolver) DisableClientConfigAPI() bool {
|
||||
}
|
||||
|
||||
func (c *completionsConfigResolver) FastChatModel() (string, error) {
|
||||
return c.config.FastChatModel, nil
|
||||
return convertLegacyModelNameToModelID(c.config.FastChatModel), nil
|
||||
}
|
||||
|
||||
func (c *completionsConfigResolver) FastChatModelMaxTokens() (*int32, error) {
|
||||
@ -108,7 +108,7 @@ func (c *completionsConfigResolver) Provider() string {
|
||||
}
|
||||
|
||||
func (c *completionsConfigResolver) CompletionModel() (string, error) {
|
||||
return c.config.CompletionModel, nil
|
||||
return convertLegacyModelNameToModelID(c.config.CompletionModel), nil
|
||||
}
|
||||
|
||||
func (c *completionsConfigResolver) CompletionModelMaxTokens() (*int32, error) {
|
||||
@ -254,12 +254,18 @@ func (r *modelconfigResolver) CompletionModelMaxTokens() (*int32, error) {
|
||||
// the provider as needed to match older behavior. (See unit tests and convertProviderID for
|
||||
// more information.)
|
||||
func (r *modelconfigResolver) toLegacyModelRef(model modelconfigSDK.Model) string {
|
||||
modelID := model.ModelRef.ModelID()
|
||||
providerID := model.ModelRef.ProviderID()
|
||||
legacyProviderName := r.convertProviderID(providerID)
|
||||
|
||||
// For compatibility, we are returning the model _name_ instead of the model _id_.
|
||||
// So the client will see "claude-3-xxxx" not the shortened model ID like "claude-3".
|
||||
return fmt.Sprintf("%s/%s", legacyProviderName, model.ModelName)
|
||||
// Potential issue: Older Cody clients calling the GraphQL may expect to see the model **name**
|
||||
// such as "claude-3-sonnet-20240229". But it is important that we only return the model **ID**
|
||||
// because that is what the HTTP completions API is expecting to see from the client.
|
||||
//
|
||||
// So when using older Cody clients, unaware of the newer modelconfig system, this could lead
|
||||
// to some errors. (But newer clients won't be using this GraphQL endpoint at all and instead
|
||||
// just use the newer modelconfig system, so hopefully this won't be a major concern.)
|
||||
return fmt.Sprintf("%s/%s", legacyProviderName, modelID)
|
||||
}
|
||||
|
||||
// convertProviderID returns the _API Provider_ for the referenced modelconfig provider.
|
||||
|
||||
@ -59,6 +59,45 @@ func TestCompletionsResolver(t *testing.T) {
|
||||
model, err = testResolver.CompletionModel()
|
||||
assert.EqualValues(t, siteConfigData.CompletionModel, model)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// In the GraphQL resolver we are returning the model name expressed in
|
||||
// the site config, but the HTTP completions API only accepts model IDs.
|
||||
// For the "completions" config, these are 99% identical, but in some cases
|
||||
// may differ.
|
||||
//
|
||||
// In the completions API (see get_model.go) we lookup a model by its mref
|
||||
// or model ID, and then use the unmodified model name when making the API
|
||||
// request.
|
||||
t.Run("Sanitization", func(t *testing.T) {
|
||||
// Copy and introduce more challenging model names.
|
||||
updatedSiteConfigData := *siteConfigData
|
||||
updatedSiteConfigData.ChatModel = "anthropic.claude-3-opus-20240229-v1:0/so:many:colons"
|
||||
updatedSiteConfigData.FastChatModel = "all/sorts@of;special_chars&but!no#sanitization"
|
||||
updatedSiteConfigData.CompletionModel = "other invalid tokens 😭😭😭"
|
||||
|
||||
updatedResolver := &completionsConfigResolver{
|
||||
config: &updatedSiteConfigData,
|
||||
}
|
||||
|
||||
var (
|
||||
model string
|
||||
err error
|
||||
)
|
||||
model, err = updatedResolver.ChatModel()
|
||||
assert.NotEqualValues(t, updatedSiteConfigData.ChatModel, model)
|
||||
assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0/so_many_colons", model)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Fast chat had wonky characters, but none required sanitizing.
|
||||
model, err = updatedResolver.FastChatModel()
|
||||
assert.EqualValues(t, updatedSiteConfigData.FastChatModel, model)
|
||||
assert.NoError(t, err)
|
||||
|
||||
model, err = updatedResolver.CompletionModel()
|
||||
assert.NotEqualValues(t, updatedSiteConfigData.CompletionModel, model)
|
||||
assert.EqualValues(t, "other_invalid_tokens_____________", model)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -75,8 +114,8 @@ func TestModelConfigResolver(t *testing.T) {
|
||||
},
|
||||
}
|
||||
awsBedrockModel := modelconfigSDK.Model{
|
||||
ModelRef: modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::test-model_aws-bedrock"),
|
||||
ModelName: "aws-bedrock-model-name",
|
||||
ModelRef: modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::aws-bedrock_model-id"),
|
||||
ModelName: "aws-bedrock_model-name",
|
||||
}
|
||||
|
||||
// Azure OpenAI provider and model.
|
||||
@ -89,8 +128,8 @@ func TestModelConfigResolver(t *testing.T) {
|
||||
},
|
||||
}
|
||||
azureOpenAIModel := modelconfigSDK.Model{
|
||||
ModelRef: modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::test-model_azure-openai"),
|
||||
ModelName: "azure-openai-model-name",
|
||||
ModelRef: modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::azure-openai_model-id"),
|
||||
ModelName: "azure-openai_model-name",
|
||||
}
|
||||
|
||||
// Cody Gateway provider and model.
|
||||
@ -103,8 +142,8 @@ func TestModelConfigResolver(t *testing.T) {
|
||||
},
|
||||
}
|
||||
codyGatewayModel := modelconfigSDK.Model{
|
||||
ModelRef: modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::test-model_cody-gateway"),
|
||||
ModelName: "cody-gateway-model-name",
|
||||
ModelRef: modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::cody-gateway_model-id"),
|
||||
ModelName: "cody-gateway_model-name",
|
||||
}
|
||||
|
||||
modelconfigData := modelconfigSDK.ModelConfiguration{
|
||||
@ -146,24 +185,24 @@ func TestModelConfigResolver(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Models", func(t *testing.T) {
|
||||
// Note that for all these cases the returned string doesn't match
|
||||
// either the Provider ID nor the Model ID. Instead, it is the name
|
||||
// of the API Provider (e.g. "sourcegraph" if using Cody Gateway),
|
||||
// and we return the model name.
|
||||
// The models returned here are kinda confusing:
|
||||
// We replace the "provider" with whatever underlying API is used for serving responses.
|
||||
// However, we return the model IDs (rather than model Names) since that's what the
|
||||
// completions API expects.
|
||||
var (
|
||||
model string
|
||||
err error
|
||||
)
|
||||
model, err = testResolver.ChatModel()
|
||||
assert.Equal(t, "aws-bedrock/aws-bedrock-model-name", model)
|
||||
assert.Equal(t, "aws-bedrock/aws-bedrock_model-id", model)
|
||||
assert.NoError(t, err)
|
||||
|
||||
model, err = testResolver.CompletionModel()
|
||||
assert.Equal(t, "azure-openai/azure-openai-model-name", model)
|
||||
assert.Equal(t, "azure-openai/azure-openai_model-id", model)
|
||||
assert.NoError(t, err)
|
||||
|
||||
model, err = testResolver.FastChatModel()
|
||||
assert.Equal(t, "sourcegraph/cody-gateway-model-name", model)
|
||||
assert.Equal(t, "sourcegraph/cody-gateway_model-id", model)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@ -49,6 +49,21 @@ type legacyModelRef struct {
|
||||
serverSideConfig *types.ServerSideModelConfig
|
||||
}
|
||||
|
||||
// convertLegacyModelNameToModelID returns the ID that should be used for a model name
|
||||
// defined in the "completions" site config.
|
||||
//
|
||||
// When sending LLM models to the client, it expects to see the exact value specified in the site
|
||||
// configuration. So the client sees the model **name**. However, internally, this Sourcegraph
|
||||
// instance converts the site configuration into a modelconfigSDK.ModelConfiguration, which may
|
||||
// have a slightly different model **ID** from model name.
|
||||
//
|
||||
// When converting older-style completions config, we just keep these identical for 99.9% of
|
||||
// cases. (No need to differ.) But we need to have model IDs adhere to naming rules. So we
|
||||
// need to sanitize the results.
|
||||
func convertLegacyModelNameToModelID(model string) string {
|
||||
return modelconfig.SanitizeResourceName(model)
|
||||
}
|
||||
|
||||
// parseLegacyModelRef takes a reference to a model from the site configuration in the "legacy format",
|
||||
// and infers all the surrounding data. e.g. "claude-instant", "openai/gpt-4o".
|
||||
func parseLegacyModelRef(
|
||||
@ -83,7 +98,7 @@ func parseLegacyModelRef(
|
||||
// The model ID may contain colons or other invalid characters. So we strip those out here,
|
||||
// so that the Model's mref is valid.
|
||||
// But the model NAME remains unchanged. As that's what is sent to AWS.
|
||||
modelID = modelconfig.SanitizeResourceName(bedrockModelRef.Model)
|
||||
modelID = convertLegacyModelNameToModelID(bedrockModelRef.Model)
|
||||
modelName = bedrockModelRef.Model
|
||||
|
||||
if bedrockModelRef.ProvisionedCapacity != nil {
|
||||
@ -122,7 +137,7 @@ func parseLegacyModelRef(
|
||||
modelID = modelNameFromConfig[kind]
|
||||
}
|
||||
// Finally, sanitize the user-supplied model ID to ensure it is valid.
|
||||
modelID = modelconfig.SanitizeResourceName(modelID)
|
||||
modelID = convertLegacyModelNameToModelID(modelID)
|
||||
|
||||
default:
|
||||
// No other processing is needed.
|
||||
|
||||
@ -169,6 +169,8 @@ func TestConvertCompletionsConfig(t *testing.T) {
|
||||
}
|
||||
{
|
||||
m := siteModelConfig.ModelOverrides[0]
|
||||
// Notice how the model ID has been sanitized. (No colon.) But the model name is the same
|
||||
// from the site config. (Since that's how the model is identified in its backing API.)
|
||||
assert.EqualValues(t, "anthropic::unknown::anthropic.claude-3-opus-20240229-v1_0", m.ModelRef)
|
||||
assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0", m.ModelRef.ModelID())
|
||||
// Unlike the Model's ID, the Name is unchanged, as this is what AWS expects in the API call.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user