Ólafur Páll Geirsson 2024-08-14 22:09:58 +02:00 committed by GitHub
commit cd762476a5
3 changed files with 64 additions and 24 deletions


@@ -74,7 +74,7 @@ func (c *openAIChatCompletionStreamClient) Complete(
 		logger.Warn("Failed to count tokens with the token manager %w ", log.Error(err))
 	}
 	return &types.CompletionResponse{
-		Completion: response.Choices[0].Text,
+		Completion: response.Choices[0].Message.Content,
 		StopReason: response.Choices[0].FinishReason,
 	}, nil
 }
@@ -138,7 +138,7 @@ func (c *openAIChatCompletionStreamClient) Stream(
 		if len(event.Choices) > 0 {
 			if request.Feature == types.CompletionsFeatureCode {
-				content += event.Choices[0].Text
+				content += event.Choices[0].Message.Content
 			} else {
 				content += event.Choices[0].Delta.Content
 			}

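For context on the rename above: the legacy /v1/completions response carries generated text in choices[0].text, while a /v1/chat/completions response nests it under choices[0].message.content (streaming chunks use choices[0].delta.content instead), so reading .Text from a chat response yields an empty completion. The following is a minimal, self-contained sketch of that difference; the sketchChoice/sketchMessage names are illustrative only, not types from this repository.

package main

import (
	"encoding/json"
	"fmt"
)

type sketchMessage struct {
	Content string `json:"content"`
}

type sketchChoice struct {
	Text    string        `json:"text"`    // populated by legacy /v1/completions
	Message sketchMessage `json:"message"` // populated by /v1/chat/completions
}

func main() {
	// A trimmed-down chat-completions body, similar to the one the new test below uses.
	chatBody := []byte(`{"choices":[{"message":{"content":"yes"},"finish_reason":"stop"}]}`)

	var resp struct {
		Choices []sketchChoice `json:"choices"`
	}
	if err := json.Unmarshal(chatBody, &resp); err != nil {
		panic(err)
	}

	// The legacy field stays empty for chat responses; the chat field holds the answer.
	fmt.Printf("Text=%q Message.Content=%q\n", resp.Choices[0].Text, resp.Choices[0].Message.Content)
}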

@@ -25,32 +25,35 @@ func (c *mockDoer) Do(r *http.Request) (*http.Response, error) {
 	return c.do(r)
 }
 
-func TestErrStatusNotOK(t *testing.T) {
-	tokenManager := tokenusage.NewManager()
-	mockClient := NewClient(&mockDoer{
+var compRequest = types.CompletionRequest{
+	Feature: types.CompletionsFeatureChat,
+	Version: types.CompletionsVersionLegacy,
+	ModelConfigInfo: types.ModelConfigInfo{
+		Provider: modelconfigSDK.Provider{
+			ID: modelconfigSDK.ProviderID("xxx-provider-id-xxx"),
+		},
+		Model: modelconfigSDK.Model{
+			ModelRef: modelconfigSDK.ModelRef("provider::apiversion::test-model"),
+		},
+	},
+	Parameters: types.CompletionRequestParameters{
+		RequestedModel: "xxx-requested-model-xxx",
+	},
+}
+
+func NewMockClient(statusCode int, response string) types.CompletionsClient {
+	return NewClient(&mockDoer{
 		func(r *http.Request) (*http.Response, error) {
 			return &http.Response{
-				StatusCode: http.StatusTooManyRequests,
-				Body:       io.NopCloser(bytes.NewReader([]byte("oh no, please slow down!"))),
+				StatusCode: statusCode,
+				Body:       io.NopCloser(bytes.NewReader([]byte(response))),
 			}, nil
 		},
-	}, "", "", *tokenManager)
+	}, "", "", *tokenusage.NewManager())
+}
 
-	compRequest := types.CompletionRequest{
-		Feature: types.CompletionsFeatureChat,
-		Version: types.CompletionsVersionLegacy,
-		ModelConfigInfo: types.ModelConfigInfo{
-			Provider: modelconfigSDK.Provider{
-				ID: modelconfigSDK.ProviderID("xxx-provider-id-xxx"),
-			},
-			Model: modelconfigSDK.Model{
-				ModelRef: modelconfigSDK.ModelRef("provider::apiversion::test-model"),
-			},
-		},
-		Parameters: types.CompletionRequestParameters{
-			RequestedModel: "xxx-requested-model-xxx",
-		},
-	}
+func TestErrStatusNotOK(t *testing.T) {
+	mockClient := NewMockClient(http.StatusTooManyRequests, "oh no, please slow down!")
 
 	t.Run("Complete", func(t *testing.T) {
 		logger := log.Scoped("completions")
@@ -74,3 +77,36 @@ func TestErrStatusNotOK(t *testing.T) {
 		assert.True(t, ok)
 	})
 }
+
+func TestNonStreamingResponseParsing(t *testing.T) {
+	mockClient := NewMockClient(http.StatusOK, `{
+  "id": "chatcmpl-9wEJ9hnLdPcCLrfdZLrRPGOz48Pmo",
+  "object": "chat.completion",
+  "created": 1723665051,
+  "model": "gpt-4o-mini-2024-07-18",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "yes",
+        "refusal": null
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 15,
+    "completion_tokens": 1,
+    "total_tokens": 16
+  },
+  "system_fingerprint": "fp_48196bc67a"
+}`)
+
+	logger := log.Scoped("completions")
+	resp, err := mockClient.Complete(context.Background(), logger, compRequest)
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+	autogold.Expect(&types.CompletionResponse{Completion: "yes", StopReason: "stop"}).Equal(t, resp)
+}
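A note on the golden value: autogold expectations such as the one in autogold.Expect above are normally regenerated by running go test -update in the test's package; treating that flag as available here is an assumption about the repository's autogold setup, not something shown in this diff.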


@@ -50,10 +50,14 @@ type openaiChoiceDelta struct {
 	Content string `json:"content"`
 }
 
+type openaiMessage struct {
+	Content string `json:"content"`
+}
 type openaiChoice struct {
 	Delta        openaiChoiceDelta `json:"delta"`
+	Message      openaiMessage     `json:"message"`
 	Role         string            `json:"role"`
 	Text         string            `json:"text"`
 	FinishReason string            `json:"finish_reason"`
 }
 
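With the Message field added, the single openaiChoice type appears to cover every payload shape the client above decodes: Text for legacy completions, Delta.Content for streaming chat chunks, and Message.Content for non-streaming chat responses, which is the field the updated Complete method now reads.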