From 582022d30111f2ff2eb03a9fb1a02c82bde4de64 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Tue, 6 Aug 2024 13:28:33 -0700 Subject: [PATCH] Return model IDs from GraphQL, not model Names (#64307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ::sigh:: the problem here is super in the weeds, but ultimately this fixes a problem introduced when using AWS Bedrock and Sourcegraph instances using the older style "completions" config. ## The problem AWS Bedrock has some LLM model names that contain a colon, e.g. `anthropic.claude-3-opus-20240229-v1:0`. Cody clients connecting to Sourcegraph instances using the older style "completions" config will obtain the available LLM models by using GraphGL. So the Cody client would see that the chat model is `anthropic.claude-3-opus-20240229-v1:0`. However, under the hood, the Sourcegraph instance will convert the site config into the newer `modelconfig` format. And during that conversion, we use a _different value_ for the **model ID** than what is in the site config. (The **model name** is what is sent to the LLM API, and is unmodified. The model ID is a stable, unique identifier but is sanitized so that it adheres to naming rules.) Because of this, we have a problem. When the Cody client makes a request to the HTTP completions API with the model name of `anthropic.claude-3-opus-20240229-v1:0` or `anthropic/anthropic.claude-3-opus-20240229-v1:0` it fails. Because there is no model with ID `...v1:0`. (We only have the sanitized version, `...v1_0`.) ## The fix There were a few ways we could fix this, but this goes with just having the GraphQL component return the model ID instead of the model name. So that when the Cody client passes that model ID to the completions API, everything works as it should. And, practically speaking, for 99.9% of cases, the model name and model ID will be identical. We only strip out non-URL safe characters and colons, which usually aren't used in model names. ## Potential bugs With this fix however, there is a specific combination of { client, server, and model name } where things could in theory break. Specifically: Client | Server | Modelname | Works | --- | --- | --- | --- | unaware-of-modelconfig | not-using-modelconfig | standard | 🟢 [1] | aware-of-modelconfig | not-using-modelconfig | standard | 🟢 [1] | unaware-of-modelconfig | using-modelconfig | standard | 🟢 [1] | aware-of-modelconfig | using-modelconfig | standard | 🟢 [3] | unaware-of-modelconfig | not-using-modelconfig | non-standard | 🔴 [2] | aware-of-modelconfig | not-using-modelconfig | non-standard | 🔴 [2] | unaware-of-modelconfig | using-modelconfig | non-standard | 🔴 [2] | aware-of-modelconfig | using-modelconfig | non-standard | 🟢 [3] | 1. If the model name is something that doesn't require sanitization, there is no problem. The model ID will be the same as the model name, and things will work like they do today. 2. If the model name gets sanitized, then IFF the Cody client were to make a decision based on that exact model name, it wouldn't work. Because it would receive the sanitized name, and not the real one. As long as the Cody client is only passing that model name onto the Sourcegraph backend which will recognize the sanitized model name / ID, all is well. 3. If the client and server are new, and using model config, then this shouldn't be a problem because the client would use a different API to fetch the Sourcegraph instance's supported models. And within the client, natively refer to the model ID instead of the model name. Fixes [PRIME-464](https://linear.app/sourcegraph/issue/PRIME-464/aws-bedrock-x-completions-config-does-not-work-if-model-name-has-a). ## Test plan Added some unit tests. ## Changelog NA --- cmd/frontend/internal/modelconfig/resolver.go | 18 +++-- .../internal/modelconfig/resolver_test.go | 65 +++++++++++++++---- .../modelconfig/siteconfig_completions.go | 19 +++++- .../siteconfig_completions_test.go | 2 + 4 files changed, 83 insertions(+), 21 deletions(-) diff --git a/cmd/frontend/internal/modelconfig/resolver.go b/cmd/frontend/internal/modelconfig/resolver.go index c7abef8f275..49cc2438170 100644 --- a/cmd/frontend/internal/modelconfig/resolver.go +++ b/cmd/frontend/internal/modelconfig/resolver.go @@ -69,7 +69,7 @@ type completionsConfigResolver struct { } func (c *completionsConfigResolver) ChatModel() (string, error) { - return c.config.ChatModel, nil + return convertLegacyModelNameToModelID(c.config.ChatModel), nil } func (c *completionsConfigResolver) ChatModelMaxTokens() (*int32, error) { @@ -92,7 +92,7 @@ func (c *completionsConfigResolver) DisableClientConfigAPI() bool { } func (c *completionsConfigResolver) FastChatModel() (string, error) { - return c.config.FastChatModel, nil + return convertLegacyModelNameToModelID(c.config.FastChatModel), nil } func (c *completionsConfigResolver) FastChatModelMaxTokens() (*int32, error) { @@ -108,7 +108,7 @@ func (c *completionsConfigResolver) Provider() string { } func (c *completionsConfigResolver) CompletionModel() (string, error) { - return c.config.CompletionModel, nil + return convertLegacyModelNameToModelID(c.config.CompletionModel), nil } func (c *completionsConfigResolver) CompletionModelMaxTokens() (*int32, error) { @@ -254,12 +254,18 @@ func (r *modelconfigResolver) CompletionModelMaxTokens() (*int32, error) { // the provider as needed to match older behavior. (See unit tests and convertProviderID for // more information.) func (r *modelconfigResolver) toLegacyModelRef(model modelconfigSDK.Model) string { + modelID := model.ModelRef.ModelID() providerID := model.ModelRef.ProviderID() legacyProviderName := r.convertProviderID(providerID) - // For compatibility, we are returning the model _name_ instead of the model _id_. - // So the client will see "claude-3-xxxx" not the shortened model ID like "claude-3". - return fmt.Sprintf("%s/%s", legacyProviderName, model.ModelName) + // Potential issue: Older Cody clients calling the GraphQL may expect to see the model **name** + // such as "claude-3-sonnet-20240229". But it is important that we only return the model **ID** + // because that is what the HTTP completions API is expecting to see from the client. + // + // So when using older Cody clients, unaware of the newer modelconfig system, this could lead + // to some errors. (But newer clients won't be using this GraphQL endpoint at all and instead + // just use the newer modelconfig system, so hopefully this won't be a major concern.) + return fmt.Sprintf("%s/%s", legacyProviderName, modelID) } // convertProviderID returns the _API Provider_ for the referenced modelconfig provider. diff --git a/cmd/frontend/internal/modelconfig/resolver_test.go b/cmd/frontend/internal/modelconfig/resolver_test.go index 0c4c2f33d2e..3ab4644d41b 100644 --- a/cmd/frontend/internal/modelconfig/resolver_test.go +++ b/cmd/frontend/internal/modelconfig/resolver_test.go @@ -59,6 +59,45 @@ func TestCompletionsResolver(t *testing.T) { model, err = testResolver.CompletionModel() assert.EqualValues(t, siteConfigData.CompletionModel, model) assert.NoError(t, err) + + // In the GraphQL resolver we are returning the model name expressed in + // the site config, but the HTTP completions API only accepts model IDs. + // For the "completions" config, these are 99% identical, but in some cases + // may differ. + // + // In the completions API (see get_model.go) we lookup a model by its mref + // or model ID, and then use the unmodified model name when making the API + // request. + t.Run("Sanitization", func(t *testing.T) { + // Copy and introduce more challenging model names. + updatedSiteConfigData := *siteConfigData + updatedSiteConfigData.ChatModel = "anthropic.claude-3-opus-20240229-v1:0/so:many:colons" + updatedSiteConfigData.FastChatModel = "all/sorts@of;special_chars&but!no#sanitization" + updatedSiteConfigData.CompletionModel = "other invalid tokens 😭😭😭" + + updatedResolver := &completionsConfigResolver{ + config: &updatedSiteConfigData, + } + + var ( + model string + err error + ) + model, err = updatedResolver.ChatModel() + assert.NotEqualValues(t, updatedSiteConfigData.ChatModel, model) + assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0/so_many_colons", model) + assert.NoError(t, err) + + // Fast chat had wonky characters, but none required sanitizing. + model, err = updatedResolver.FastChatModel() + assert.EqualValues(t, updatedSiteConfigData.FastChatModel, model) + assert.NoError(t, err) + + model, err = updatedResolver.CompletionModel() + assert.NotEqualValues(t, updatedSiteConfigData.CompletionModel, model) + assert.EqualValues(t, "other_invalid_tokens_____________", model) + assert.NoError(t, err) + }) }) } @@ -75,8 +114,8 @@ func TestModelConfigResolver(t *testing.T) { }, } awsBedrockModel := modelconfigSDK.Model{ - ModelRef: modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::test-model_aws-bedrock"), - ModelName: "aws-bedrock-model-name", + ModelRef: modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::aws-bedrock_model-id"), + ModelName: "aws-bedrock_model-name", } // Azure OpenAI provider and model. @@ -89,8 +128,8 @@ func TestModelConfigResolver(t *testing.T) { }, } azureOpenAIModel := modelconfigSDK.Model{ - ModelRef: modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::test-model_azure-openai"), - ModelName: "azure-openai-model-name", + ModelRef: modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::azure-openai_model-id"), + ModelName: "azure-openai_model-name", } // Cody Gateway provider and model. @@ -103,8 +142,8 @@ func TestModelConfigResolver(t *testing.T) { }, } codyGatewayModel := modelconfigSDK.Model{ - ModelRef: modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::test-model_cody-gateway"), - ModelName: "cody-gateway-model-name", + ModelRef: modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::cody-gateway_model-id"), + ModelName: "cody-gateway_model-name", } modelconfigData := modelconfigSDK.ModelConfiguration{ @@ -146,24 +185,24 @@ func TestModelConfigResolver(t *testing.T) { }) t.Run("Models", func(t *testing.T) { - // Note that for all these cases the returned string doesn't match - // either the Provider ID nor the Model ID. Instead, it is the name - // of the API Provider (e.g. "sourcegraph" if using Cody Gateway), - // and we return the model name. + // The models returned here are kinda confusing: + // We replace the "provider" with whatever underlying API is used for serving responses. + // However, we return the model IDs (rather than model Names) since that's what the + // completions API expects. var ( model string err error ) model, err = testResolver.ChatModel() - assert.Equal(t, "aws-bedrock/aws-bedrock-model-name", model) + assert.Equal(t, "aws-bedrock/aws-bedrock_model-id", model) assert.NoError(t, err) model, err = testResolver.CompletionModel() - assert.Equal(t, "azure-openai/azure-openai-model-name", model) + assert.Equal(t, "azure-openai/azure-openai_model-id", model) assert.NoError(t, err) model, err = testResolver.FastChatModel() - assert.Equal(t, "sourcegraph/cody-gateway-model-name", model) + assert.Equal(t, "sourcegraph/cody-gateway_model-id", model) assert.NoError(t, err) }) } diff --git a/cmd/frontend/internal/modelconfig/siteconfig_completions.go b/cmd/frontend/internal/modelconfig/siteconfig_completions.go index 4d4fd4ad2eb..704d530557f 100644 --- a/cmd/frontend/internal/modelconfig/siteconfig_completions.go +++ b/cmd/frontend/internal/modelconfig/siteconfig_completions.go @@ -49,6 +49,21 @@ type legacyModelRef struct { serverSideConfig *types.ServerSideModelConfig } +// convertLegacyModelNameToModelID returns the ID that should be used for a model name +// defined in the "completions" site config. +// +// When sending LLM models to the client, it expects to see the exact value specified in the site +// configuration. So the client sees the model **name**. However, internally, this Sourcegraph +// instance converts the site configuration into a modelconfigSDK.ModelConfigruation, which may +// have a slightly different model **ID** from model name. +// +// When converting older-style completions config, we just keep these identical for 99.9% of +// cases. (No need to differ.) But we need to have model IDs adhear to naming rules. So we +// need to sanitize the results. +func convertLegacyModelNameToModelID(model string) string { + return modelconfig.SanitizeResourceName(model) +} + // parseLegacyModelRef takes a reference to a model from the site configuration in the "legacy format", // and infers all the surrounding data. e.g. "claude-instant", "openai/gpt-4o". func parseLegacyModelRef( @@ -83,7 +98,7 @@ func parseLegacyModelRef( // The model ID may contain colons or other invalid characters. So we strip those out here, // so that the Model's mref is valid. // But the model NAME remains unchanged. As that's what is sent to AWS. - modelID = modelconfig.SanitizeResourceName(bedrockModelRef.Model) + modelID = convertLegacyModelNameToModelID(bedrockModelRef.Model) modelName = bedrockModelRef.Model if bedrockModelRef.ProvisionedCapacity != nil { @@ -122,7 +137,7 @@ func parseLegacyModelRef( modelID = modelNameFromConfig[kind] } // Finally, sanitize the user-supplied model ID to ensure it is valid. - modelID = modelconfig.SanitizeResourceName(modelID) + modelID = convertLegacyModelNameToModelID(modelID) default: // No other processing is needed. diff --git a/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go b/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go index c433ef0a178..d6c9ecafa83 100644 --- a/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go +++ b/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go @@ -169,6 +169,8 @@ func TestConvertCompletionsConfig(t *testing.T) { } { m := siteModelConfig.ModelOverrides[0] + // Notice how the model ID has been sanitized. (No colon.) But the model name is the same + // from the site config. (Since that's how the model is identified in its backing API.) assert.EqualValues(t, "anthropic::unknown::anthropic.claude-3-opus-20240229-v1_0", m.ModelRef) assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0", m.ModelRef.ModelID()) // Unlike the Model's ID, the Name is unchanged, as this is what AWS expects in the API call.