Fix Azure completions API (#63491)

[Linear Issue](https://linear.app/sourcegraph/issue/CODY-2586/fix-completions-models-api-for-azure-to-use-the-right-model-with-the)

The purpose of this PR is to provide a backwards-compatible solution so that the Azure completions logic in our codebase supports both the older Completions API and the newer Chat Completions API. This way we can use models from both APIs with autocomplete.
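
The diff below wires this up through a new `AzureUseDeprecatedCompletionsAPIForOldModels` request parameter that flows from site config down to the Azure client. Condensed from the `completeAutocomplete` change further down (imports and the token-accounting code omitted), the autocomplete dispatch looks roughly like this:

```go
// Simplified restatement of the dispatch added in this PR: deployments of old
// models can be opted into the deprecated Completions API via site config;
// everything else goes through the newer Chat Completions API.
func completeAutocomplete(
	ctx context.Context,
	client CompletionsClient,
	requestParams types.CompletionRequestParameters,
	log log.Logger,
) (*types.CompletionResponse, error) {
	if requestParams.AzureUseDeprecatedCompletionsAPIForOldModels {
		return doCompletionsAPIAutocomplete(ctx, client, requestParams, log)
	}
	return doChatCompletionsAPIAutocomplete(ctx, client, requestParams, log)
}
```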

Note: We can't figure out which model we are using because Azure exposes the deployment name instead of the model name, so we can't decide up front which API to use for a given model. Instead, we try both APIs, cache the API that works for that model, and use the cached result to choose the API for subsequent completion calls. This way we can use either of the APIs without adding latency to completions.
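
The per-deployment caching described above is not visible in the hunks included in this commit view, so the sketch below is only an illustration of the idea, not the actual implementation: the `apiByDeployment` cache, the `tryChat`/`tryLegacy` callbacks, and the package name are hypothetical, while `isOperationNotSupportedError` mirrors the helper added in this commit.

```go
package azureopenai // assumed package name, for illustration only

import (
	"context"
	"errors"
	"net/http"
	"sync"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// apiByDeployment caches which API worked for a given Azure deployment name,
// so only the first request for a deployment pays for a failed attempt.
// (Hypothetical cache; the real code may structure this differently.)
var apiByDeployment sync.Map // deployment name -> "chat" or "completions"

// completeWithFallback tries the newer Chat Completions API first, falls back
// to the legacy Completions API when Azure reports the operation as
// unsupported, and remembers the outcome for subsequent completion calls.
func completeWithFallback(
	ctx context.Context,
	deployment string,
	tryChat, tryLegacy func(context.Context) error, // hypothetical callbacks
) error {
	if api, ok := apiByDeployment.Load(deployment); ok && api == "completions" {
		return tryLegacy(ctx)
	}
	err := tryChat(ctx)
	switch {
	case err == nil:
		apiByDeployment.Store(deployment, "chat")
		return nil
	case isOperationNotSupportedError(err):
		// The deployment rejected /chat/completions; use /completions instead
		// and cache that decision so later requests skip the failed attempt.
		if legacyErr := tryLegacy(ctx); legacyErr != nil {
			return legacyErr
		}
		apiByDeployment.Store(deployment, "completions")
		return nil
	default:
		return err
	}
}

// isOperationNotSupportedError mirrors the helper added in this commit: a 400
// response with error code "OperationNotSupported" means the wrong API was
// used for the deployment's model.
func isOperationNotSupportedError(err error) bool {
	var responseError *azcore.ResponseError
	if errors.As(err, &responseError) {
		return responseError.StatusCode == http.StatusBadRequest &&
			responseError.ErrorCode == "OperationNotSupported"
	}
	return false
}
```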


## Test plan
I used the Azure keys to try out the different deployment models that we have with both the old and the new API.
Old API -> Completions (gpt-3.5-turbo-instruct, gpt-3.5-turbo(301), gpt-3.5-turbo(613))
New API -> Chat Completions (gpt-3.5-turbo(301), gpt-4o, gpt-3.5-turbo(613), gpt-3.5-turbo-16k)

Note: both sets of models work seamlessly with this PR.



## Changelog

Commit f5d5deceb0 (parent b25cf26b05), authored by Ara on 2024-06-28 21:41:17 +02:00 and committed by GitHub.
7 changed files with 199 additions and 27 deletions

View File

@@ -120,6 +120,7 @@ func newCompletionsHandler(
requestParams.User = completionsConfig.User
requestParams.AzureChatModel = completionsConfig.AzureChatModel
requestParams.AzureCompletionModel = completionsConfig.AzureCompletionModel
requestParams.AzureUseDeprecatedCompletionsAPIForOldModels = completionsConfig.AzureUseDeprecatedCompletionsAPIForOldModels
if err != nil {
// NOTE: We return the raw error to the user assuming that it contains relevant
// user-facing diagnostic information, and doesn't leak any internal details.

View File

@@ -38,7 +38,6 @@ var authProxyURL = os.Getenv("CODY_AZURE_OPENAI_IDENTITY_HTTP_PROXY")
// it will acquire a short lived token and reusing the client
// prevents acquiring a new token on every request.
// The client will refresh the token as needed.
var apiClient completionsClient
type completionsClient struct {
@@ -139,9 +138,9 @@ func (c *azureCompletionClient) Complete(
switch feature {
case types.CompletionsFeatureCode:
return completeAutocomplete(ctx, c.client, requestParams)
return completeAutocomplete(ctx, c.client, requestParams, log)
case types.CompletionsFeatureChat:
return completeChat(ctx, c.client, requestParams)
return completeChat(ctx, c.client, requestParams, log)
default:
return nil, errors.New("invalid completions feature")
}
@@ -151,6 +150,54 @@ func completeAutocomplete(
ctx context.Context,
client CompletionsClient,
requestParams types.CompletionRequestParameters,
log log.Logger,
) (*types.CompletionResponse, error) {
if requestParams.AzureUseDeprecatedCompletionsAPIForOldModels {
return doCompletionsAPIAutocomplete(ctx, client, requestParams, log)
}
return doChatCompletionsAPIAutocomplete(ctx, client, requestParams, log)
}
func doChatCompletionsAPIAutocomplete(
ctx context.Context,
client CompletionsClient,
requestParams types.CompletionRequestParameters,
logger log.Logger,
) (*types.CompletionResponse, error) {
response, err := client.GetChatCompletions(ctx, getChatOptions(requestParams), nil)
if err != nil {
return nil, toStatusCodeError(err)
}
if !hasValidFirstChatChoice(response.Choices) {
return &types.CompletionResponse{}, nil
}
tokenManager := tokenusage.NewManager()
inputTokens, err := NumTokensFromAzureOpenAiMessages(requestParams.Messages, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
outputTokens, err := NumTokensFromAzureOpenAiResponseString(*response.Choices[0].Delta.Content, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
// Note: If we had an error calculating input/output tokens, that is unfortunate, the
// best thing we can do is record zero token usage which would be our hint to look at
// the logs for errors.
err = tokenManager.UpdateTokenCountsFromModelUsage(inputTokens, outputTokens, tokenizer.AzureModel+"/"+requestParams.Model, "code_completions", tokenusage.AzureOpenAI)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
return &types.CompletionResponse{
Completion: *response.Choices[0].Delta.Content,
StopReason: string(*response.Choices[0].FinishReason),
}, nil
}
func doCompletionsAPIAutocomplete(
ctx context.Context,
client CompletionsClient,
requestParams types.CompletionRequestParameters,
logger log.Logger,
) (*types.CompletionResponse, error) {
options, err := getCompletionsOptions(requestParams)
if err != nil {
@@ -160,7 +207,22 @@ func completeAutocomplete(
if err != nil {
return nil, toStatusCodeError(err)
}
tokenManager := tokenusage.NewManager()
inputTokens, err := NumTokensFromAzureOpenAiMessages(requestParams.Messages, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
outputTokens, err := NumTokensFromAzureOpenAiResponseString(*response.Choices[0].Text, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
// Note: If we had an error calculating input/output tokens, that is unfortunate, the
// best thing we can do is record zero token usage which would be our hint to look at
// the logs for errors.
err = tokenManager.UpdateTokenCountsFromModelUsage(inputTokens, outputTokens, tokenizer.AzureModel+"/"+requestParams.Model, "code_completions", tokenusage.AzureOpenAI)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
// Text and FinishReason are documented as REQUIRED but checking just to be safe
if !hasValidFirstCompletionsChoice(response.Choices) {
return &types.CompletionResponse{}, nil
@@ -175,6 +237,7 @@ func completeChat(
ctx context.Context,
client CompletionsClient,
requestParams types.CompletionRequestParameters,
logger log.Logger,
) (*types.CompletionResponse, error) {
response, err := client.GetChatCompletions(ctx, getChatOptions(requestParams), nil)
if err != nil {
@@ -183,6 +246,22 @@ func completeChat(
if !hasValidFirstChatChoice(response.Choices) {
return &types.CompletionResponse{}, nil
}
tokenManager := tokenusage.NewManager()
inputTokens, err := NumTokensFromAzureOpenAiMessages(requestParams.Messages, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
outputTokens, err := NumTokensFromAzureOpenAiResponseString(*response.Choices[0].Delta.Content, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
// Note: If we had an error calculating input/output tokens, that is unfortunate, the
// best thing we can do is record zero token usage which would be our hint to look at
// the logs for errors.
err = tokenManager.UpdateTokenCountsFromModelUsage(inputTokens, outputTokens, tokenizer.AzureModel+"/"+requestParams.Model, "code_completions", tokenusage.AzureOpenAI)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
return &types.CompletionResponse{
Completion: *response.Choices[0].Delta.Content,
StopReason: string(*response.Choices[0].FinishReason),
@@ -262,6 +341,78 @@ func streamAutocomplete(
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
logger log.Logger,
) error {
if requestParams.AzureUseDeprecatedCompletionsAPIForOldModels {
return doStreamCompletionsAPI(ctx, client, requestParams, sendEvent, logger)
}
return doStreamChatCompletionsAPI(ctx, client, requestParams, sendEvent, logger)
}
// Streaming with ChatCompletions API
func doStreamChatCompletionsAPI(
ctx context.Context,
client CompletionsClient,
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
logger log.Logger,
) error {
resp, err := client.GetChatCompletionsStream(ctx, getChatOptions(requestParams), nil)
if err != nil {
return err
}
defer resp.ChatCompletionsStream.Close()
var content string
for {
entry, err := resp.ChatCompletionsStream.Read()
if errors.Is(err, io.EOF) {
tokenManager := tokenusage.NewManager()
inputTokens, err := NumTokensFromAzureOpenAiMessages(requestParams.Messages, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count input tokens with the token manager %w ", log.Error(err))
}
outputTokens, err := NumTokensFromAzureOpenAiResponseString(content, requestParams.AzureChatModel)
if err != nil {
logger.Warn("Failed to count output tokens with the token manager %w ", log.Error(err))
}
// Note: If we had an error calculating input/output tokens, that is unfortunate, the
// best thing we can do is record zero token usage which would be our hint to look at
// the logs for errors.
err = tokenManager.UpdateTokenCountsFromModelUsage(inputTokens, outputTokens, tokenizer.AzureModel+"/"+requestParams.Model, "code_completions", tokenusage.AzureOpenAI)
if err != nil {
logger.Warn("Failed to count tokens with the token manager %w ", log.Error(err))
}
return nil
}
if err != nil {
return err
}
if hasValidFirstChatChoice(entry.Choices) {
content += *entry.Choices[0].Delta.Content
finish := ""
if entry.Choices[0].FinishReason != nil {
finish = string(*entry.Choices[0].FinishReason)
}
ev := types.CompletionResponse{
Completion: content,
StopReason: finish,
}
err := sendEvent(ev)
if err != nil {
return err
}
}
}
}
// Streaming with Completions API
func doStreamCompletionsAPI(
ctx context.Context,
client CompletionsClient,
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
logger log.Logger,
) error {
options, err := getCompletionsOptions(requestParams)
if err != nil {
@@ -302,7 +453,6 @@ func streamAutocomplete(
if err != nil {
return err
}
// hasValidFirstCompletionsChoice checks for a valid 1st choice which has text
if hasValidFirstCompletionsChoice(entry.Choices) {
content += *entry.Choices[0].Text
@@ -322,6 +472,17 @@ func streamAutocomplete(
}
}
// isOperationNotSupportedError checks if the error is due to using the wrong API for a model.
// Detecting this error helps in choosing the correct API.
func isOperationNotSupportedError(err error) bool {
var responseError *azcore.ResponseError
if errors.As(err, &responseError) {
return responseError.StatusCode == http.StatusBadRequest &&
responseError.ErrorCode == "OperationNotSupported"
}
return false
}
func streamChat(
ctx context.Context,
client CompletionsClient,

View File

@@ -55,19 +55,20 @@ type CompletionRequestParameters struct {
// Prompt exists only for backwards compatibility. Do not use it in new
// implementations. It will be removed once we are reasonably sure 99%
// of VSCode extension installations are upgraded to a new Cody version.
Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
MaxTokensToSample int `json:"maxTokensToSample,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
TopK int `json:"topK,omitempty"`
TopP float32 `json:"topP,omitempty"`
Model string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
Logprobs *uint8 `json:"logprobs"`
User string `json:"user,omitempty"`
AzureChatModel string `json:"azureChatModel,omitempty"`
AzureCompletionModel string `json:"azureCompletionModel,omitempty"`
Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
MaxTokensToSample int `json:"maxTokensToSample,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
TopK int `json:"topK,omitempty"`
TopP float32 `json:"topP,omitempty"`
Model string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
Logprobs *uint8 `json:"logprobs"`
User string `json:"user,omitempty"`
AzureChatModel string `json:"azureChatModel,omitempty"`
AzureCompletionModel string `json:"azureCompletionModel,omitempty"`
AzureUseDeprecatedCompletionsAPIForOldModels bool `json:"azureUseDeprecatedCompletionsAPIForOldModels,omitempty"`
}
// IsStream returns whether a streaming response is requested. For backwards

View File

@@ -932,13 +932,14 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
}
computedConfig := &conftypes.CompletionsConfig{
Provider: conftypes.CompletionsProviderName(completionsConfig.Provider),
AccessToken: completionsConfig.AccessToken,
ChatModel: completionsConfig.ChatModel,
ChatModelMaxTokens: completionsConfig.ChatModelMaxTokens,
SmartContextWindow: completionsConfig.SmartContextWindow,
FastChatModel: completionsConfig.FastChatModel,
FastChatModelMaxTokens: completionsConfig.FastChatModelMaxTokens,
Provider: conftypes.CompletionsProviderName(completionsConfig.Provider),
AccessToken: completionsConfig.AccessToken,
ChatModel: completionsConfig.ChatModel,
ChatModelMaxTokens: completionsConfig.ChatModelMaxTokens,
SmartContextWindow: completionsConfig.SmartContextWindow,
FastChatModel: completionsConfig.FastChatModel,
FastChatModelMaxTokens: completionsConfig.FastChatModelMaxTokens,
AzureUseDeprecatedCompletionsAPIForOldModels: completionsConfig.AzureUseDeprecatedCompletionsAPIForOldModels,
CompletionModel: completionsConfig.CompletionModel,
CompletionModelMaxTokens: completionsConfig.CompletionModelMaxTokens,
Endpoint: completionsConfig.Endpoint,

View File

@@ -14,8 +14,9 @@ type CompletionsConfig struct {
CompletionModel string
CompletionModelMaxTokens int
AzureCompletionModel string
AzureChatModel string
AzureCompletionModel string
AzureChatModel string
AzureUseDeprecatedCompletionsAPIForOldModels bool
AccessToken string
Provider CompletionsProviderName

View File

@@ -668,6 +668,8 @@ type Completions struct {
AzureChatModel string `json:"azureChatModel,omitempty"`
// AzureCompletionModel description: Optional: Specify the Azure OpenAI model name for chat completions. This is only needed when you want to count tokens associated with your azure model
AzureCompletionModel string `json:"azureCompletionModel,omitempty"`
// AzureUseDeprecatedCompletionsAPIForOldModels description: Enables the use of the older completions API for select Azure OpenAI models.
AzureUseDeprecatedCompletionsAPIForOldModels bool `json:"azureUseDeprecatedCompletionsAPIForOldModels,omitempty"`
// ChatModel description: The model used for chat completions. If using the default provider 'sourcegraph', a reasonable default model will be set.
// NOTE: The Anthropic messages API does not support model names like claude-2 or claude-instant-1 where only the major version is specified as they are retired. We recommend using a specific model identifier as specified here https://docs.anthropic.com/claude/docs/models-overview#model-comparison
ChatModel string `json:"chatModel,omitempty"`

View File

@@ -2890,6 +2890,11 @@
"enum": ["claude-2", "claude-instant-1"]
}
},
"azureUseDeprecatedCompletionsAPIForOldModels": {
"description": "Enables the use of the older completions API for select Azure OpenAI models.",
"type": "boolean",
"default": false
},
"fastChatModelMaxTokens": {
"description": "The maximum number of tokens to use as client when talking to fastChatModel. If not set, clients need to set their own limit.",
"type": "integer"