diff --git a/cmd/frontend/internal/modelconfig/resolver.go b/cmd/frontend/internal/modelconfig/resolver.go
index c7abef8f275..49cc2438170 100644
--- a/cmd/frontend/internal/modelconfig/resolver.go
+++ b/cmd/frontend/internal/modelconfig/resolver.go
@@ -69,7 +69,7 @@ type completionsConfigResolver struct {
 }
 
 func (c *completionsConfigResolver) ChatModel() (string, error) {
-	return c.config.ChatModel, nil
+	return convertLegacyModelNameToModelID(c.config.ChatModel), nil
 }
 
 func (c *completionsConfigResolver) ChatModelMaxTokens() (*int32, error) {
@@ -92,7 +92,7 @@ func (c *completionsConfigResolver) DisableClientConfigAPI() bool {
 }
 
 func (c *completionsConfigResolver) FastChatModel() (string, error) {
-	return c.config.FastChatModel, nil
+	return convertLegacyModelNameToModelID(c.config.FastChatModel), nil
 }
 
 func (c *completionsConfigResolver) FastChatModelMaxTokens() (*int32, error) {
@@ -108,7 +108,7 @@ func (c *completionsConfigResolver) Provider() string {
 }
 
 func (c *completionsConfigResolver) CompletionModel() (string, error) {
-	return c.config.CompletionModel, nil
+	return convertLegacyModelNameToModelID(c.config.CompletionModel), nil
 }
 
 func (c *completionsConfigResolver) CompletionModelMaxTokens() (*int32, error) {
@@ -254,12 +254,18 @@ func (r *modelconfigResolver) CompletionModelMaxTokens() (*int32, error) {
 // the provider as needed to match older behavior. (See unit tests and convertProviderID for
 // more information.)
 func (r *modelconfigResolver) toLegacyModelRef(model modelconfigSDK.Model) string {
+	modelID := model.ModelRef.ModelID()
 	providerID := model.ModelRef.ProviderID()
 	legacyProviderName := r.convertProviderID(providerID)
 
-	// For compatibility, we are returning the model _name_ instead of the model _id_.
-	// So the client will see "claude-3-xxxx" not the shortened model ID like "claude-3".
-	return fmt.Sprintf("%s/%s", legacyProviderName, model.ModelName)
+	// Potential issue: Older Cody clients calling this GraphQL API may expect to see the model **name**,
+	// such as "claude-3-sonnet-20240229". But it is important that we only return the model **ID**,
+	// because that is what the HTTP completions API expects to receive from the client.
+	//
+	// So when using older Cody clients, unaware of the newer modelconfig system, this could lead
+	// to some errors. (But newer clients won't be using this GraphQL endpoint at all and instead
+	// just use the newer modelconfig system, so hopefully this won't be a major concern.)
+	return fmt.Sprintf("%s/%s", legacyProviderName, modelID)
 }
 
 // convertProviderID returns the _API Provider_ for the referenced modelconfig provider.
diff --git a/cmd/frontend/internal/modelconfig/resolver_test.go b/cmd/frontend/internal/modelconfig/resolver_test.go
index 0c4c2f33d2e..3ab4644d41b 100644
--- a/cmd/frontend/internal/modelconfig/resolver_test.go
+++ b/cmd/frontend/internal/modelconfig/resolver_test.go
@@ -59,6 +59,45 @@ func TestCompletionsResolver(t *testing.T) {
 		model, err = testResolver.CompletionModel()
 		assert.EqualValues(t, siteConfigData.CompletionModel, model)
 		assert.NoError(t, err)
+
+		// In the GraphQL resolver we return the model name expressed in
+		// the site config, but the HTTP completions API only accepts model IDs.
+		// For the "completions" config these are 99% identical, but in some
+		// cases they may differ.
+		//
+		// In the completions API (see get_model.go) we look up a model by its mref
+		// or model ID, and then use the unmodified model name when making the API
+		// request.
+ t.Run("Sanitization", func(t *testing.T) { + // Copy and introduce more challenging model names. + updatedSiteConfigData := *siteConfigData + updatedSiteConfigData.ChatModel = "anthropic.claude-3-opus-20240229-v1:0/so:many:colons" + updatedSiteConfigData.FastChatModel = "all/sorts@of;special_chars&but!no#sanitization" + updatedSiteConfigData.CompletionModel = "other invalid tokens 😭😭😭" + + updatedResolver := &completionsConfigResolver{ + config: &updatedSiteConfigData, + } + + var ( + model string + err error + ) + model, err = updatedResolver.ChatModel() + assert.NotEqualValues(t, updatedSiteConfigData.ChatModel, model) + assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0/so_many_colons", model) + assert.NoError(t, err) + + // Fast chat had wonky characters, but none required sanitizing. + model, err = updatedResolver.FastChatModel() + assert.EqualValues(t, updatedSiteConfigData.FastChatModel, model) + assert.NoError(t, err) + + model, err = updatedResolver.CompletionModel() + assert.NotEqualValues(t, updatedSiteConfigData.CompletionModel, model) + assert.EqualValues(t, "other_invalid_tokens_____________", model) + assert.NoError(t, err) + }) }) } @@ -75,8 +114,8 @@ func TestModelConfigResolver(t *testing.T) { }, } awsBedrockModel := modelconfigSDK.Model{ - ModelRef: modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::test-model_aws-bedrock"), - ModelName: "aws-bedrock-model-name", + ModelRef: modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::aws-bedrock_model-id"), + ModelName: "aws-bedrock_model-name", } // Azure OpenAI provider and model. @@ -89,8 +128,8 @@ func TestModelConfigResolver(t *testing.T) { }, } azureOpenAIModel := modelconfigSDK.Model{ - ModelRef: modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::test-model_azure-openai"), - ModelName: "azure-openai-model-name", + ModelRef: modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::azure-openai_model-id"), + ModelName: "azure-openai_model-name", } // Cody Gateway provider and model. @@ -103,8 +142,8 @@ func TestModelConfigResolver(t *testing.T) { }, } codyGatewayModel := modelconfigSDK.Model{ - ModelRef: modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::test-model_cody-gateway"), - ModelName: "cody-gateway-model-name", + ModelRef: modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::cody-gateway_model-id"), + ModelName: "cody-gateway_model-name", } modelconfigData := modelconfigSDK.ModelConfiguration{ @@ -146,24 +185,24 @@ func TestModelConfigResolver(t *testing.T) { }) t.Run("Models", func(t *testing.T) { - // Note that for all these cases the returned string doesn't match - // either the Provider ID nor the Model ID. Instead, it is the name - // of the API Provider (e.g. "sourcegraph" if using Cody Gateway), - // and we return the model name. + // The models returned here are kinda confusing: + // We replace the "provider" with whatever underlying API is used for serving responses. + // However, we return the model IDs (rather than model Names) since that's what the + // completions API expects. 
 		var (
 			model string
 			err   error
 		)
 		model, err = testResolver.ChatModel()
-		assert.Equal(t, "aws-bedrock/aws-bedrock-model-name", model)
+		assert.Equal(t, "aws-bedrock/aws-bedrock_model-id", model)
 		assert.NoError(t, err)
 
 		model, err = testResolver.CompletionModel()
-		assert.Equal(t, "azure-openai/azure-openai-model-name", model)
+		assert.Equal(t, "azure-openai/azure-openai_model-id", model)
 		assert.NoError(t, err)
 
 		model, err = testResolver.FastChatModel()
-		assert.Equal(t, "sourcegraph/cody-gateway-model-name", model)
+		assert.Equal(t, "sourcegraph/cody-gateway_model-id", model)
 		assert.NoError(t, err)
 	})
 }
diff --git a/cmd/frontend/internal/modelconfig/siteconfig_completions.go b/cmd/frontend/internal/modelconfig/siteconfig_completions.go
index 4d4fd4ad2eb..704d530557f 100644
--- a/cmd/frontend/internal/modelconfig/siteconfig_completions.go
+++ b/cmd/frontend/internal/modelconfig/siteconfig_completions.go
@@ -49,6 +49,21 @@ type legacyModelRef struct {
 	serverSideConfig *types.ServerSideModelConfig
 }
 
+// convertLegacyModelNameToModelID returns the ID that should be used for a model name
+// defined in the "completions" site config.
+//
+// When sending LLM models to the client, the client expects to see the exact value
+// specified in the site configuration. So the client sees the model **name**. However,
+// internally, this Sourcegraph instance converts the site configuration into a
+// modelconfigSDK.ModelConfiguration, whose model **ID** may differ slightly from the model name.
+//
+// When converting older-style completions config, we keep these identical for 99.9% of
+// cases. (There is no need for them to differ.) But model IDs must adhere to naming rules,
+// so we need to sanitize the results.
+func convertLegacyModelNameToModelID(model string) string {
+	return modelconfig.SanitizeResourceName(model)
+}
+
 // parseLegacyModelRef takes a reference to a model from the site configuration in the "legacy format",
 // and infers all the surrounding data. e.g. "claude-instant", "openai/gpt-4o".
 func parseLegacyModelRef(
@@ -83,7 +98,7 @@ func parseLegacyModelRef(
 	// The model ID may contain colons or other invalid characters. So we strip those out here,
 	// so that the Model's mref is valid.
 	// But the model NAME remains unchanged. As that's what is sent to AWS.
-	modelID = modelconfig.SanitizeResourceName(bedrockModelRef.Model)
+	modelID = convertLegacyModelNameToModelID(bedrockModelRef.Model)
 	modelName = bedrockModelRef.Model
 
 	if bedrockModelRef.ProvisionedCapacity != nil {
@@ -122,7 +137,7 @@ func parseLegacyModelRef(
 		modelID = modelNameFromConfig[kind]
 	}
 	// Finally, sanitize the user-supplied model ID to ensure it is valid.
-	modelID = modelconfig.SanitizeResourceName(modelID)
+	modelID = convertLegacyModelNameToModelID(modelID)
 
 	default:
 		// No other processing is needed.
diff --git a/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go b/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go
index c433ef0a178..d6c9ecafa83 100644
--- a/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go
+++ b/cmd/frontend/internal/modelconfig/siteconfig_completions_test.go
@@ -169,6 +169,8 @@ func TestConvertCompletionsConfig(t *testing.T) {
 	}
 	{
 		m := siteModelConfig.ModelOverrides[0]
+		// Notice how the model ID has been sanitized. (No colon.) But the model name is kept
+		// exactly as written in the site config. (Since that's how the model is identified in its backing API.)
assert.EqualValues(t, "anthropic::unknown::anthropic.claude-3-opus-20240229-v1_0", m.ModelRef) assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0", m.ModelRef.ModelID()) // Unlike the Model's ID, the Name is unchanged, as this is what AWS expects in the API call.