From 7b1bc10a30689b9d40e6d9354177f2559d4379e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=93lafur=20P=C3=A1ll=20Geirsson?= Date: Tue, 13 Aug 2024 13:48:15 +0200 Subject: [PATCH] chore/API: speed up edit/test feedback loop for llmapi module (#64437) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, it took ~6 seconds for a single edit/test/debug feedback loop in the `llmapi` module. After this change, it's now 1-2s. The reason the feedback loop was slow was that we depended on the `//cmd/frontend/internal/modelconfig` target, which transitively brings in `graphqlbackend` and all the migration code, which adds huge overhead to Go link times. It was relatively easy to untangle this dependency so I went ahead and removed it to boost my local feedback loop. ## Test plan Green CI. To measure the timing, I ran the tests, made a tiny change and ran the tests against to measure the total time to build+test. ``` # Before ❯ time go test -timeout 30s github.com/sourcegraph/sourcegraph/cmd/frontend/internal/llmapi ok github.com/sourcegraph/sourcegraph/cmd/frontend/internal/llmapi 2.394s go test -timeout 30s 4.26s user 4.73s system 166% cpu 5.393 total # After ❯ time go test -timeout 30s github.com/sourcegraph/sourcegraph/cmd/frontend/internal/llmapi ok github.com/sourcegraph/sourcegraph/cmd/frontend/internal/llmapi 0.862s go test -timeout 30s 1.20s user 1.21s system 135% cpu 1.774 total ``` ## Changelog --- cmd/frontend/internal/httpapi/BUILD.bazel | 1 + cmd/frontend/internal/httpapi/httpapi.go | 3 ++- cmd/frontend/internal/llmapi/BUILD.bazel | 5 ---- .../llmapi/chat_completions_handler.go | 8 ++++--- .../llmapi/chat_completions_handler_test.go | 24 +++++-------------- cmd/frontend/internal/llmapi/httpapi.go | 8 +++++-- cmd/frontend/internal/llmapi/utils_test.go | 4 ++-- 7 files changed, 22 insertions(+), 31 deletions(-) diff --git a/cmd/frontend/internal/httpapi/BUILD.bazel b/cmd/frontend/internal/httpapi/BUILD.bazel index 0027df2d539..4a983253e23 100644 --- a/cmd/frontend/internal/httpapi/BUILD.bazel +++ b/cmd/frontend/internal/httpapi/BUILD.bazel @@ -58,6 +58,7 @@ go_library( "//internal/gitserver/gitdomain", "//internal/httpcli", "//internal/licensing", + "//internal/modelconfig/types", "//internal/opencodegraph", "//internal/repoupdater", "//internal/sams", diff --git a/cmd/frontend/internal/httpapi/httpapi.go b/cmd/frontend/internal/httpapi/httpapi.go index 90ef7a69dad..8fe24eb789b 100644 --- a/cmd/frontend/internal/httpapi/httpapi.go +++ b/cmd/frontend/internal/httpapi/httpapi.go @@ -39,6 +39,7 @@ import ( "github.com/sourcegraph/sourcegraph/internal/encryption/keyring" "github.com/sourcegraph/sourcegraph/internal/env" "github.com/sourcegraph/sourcegraph/internal/gitserver" + "github.com/sourcegraph/sourcegraph/internal/modelconfig/types" "github.com/sourcegraph/sourcegraph/internal/sams" "github.com/sourcegraph/sourcegraph/internal/search" "github.com/sourcegraph/sourcegraph/internal/search/searchcontexts" @@ -322,7 +323,7 @@ func NewHandler( repo.Path("/refresh").Methods("POST").Handler(jsonHandler(serveRepoRefresh(db))) llm := m.PathPrefix("/llm/").Subrouter() - llmapi.RegisterHandlers(llm, m) + llmapi.RegisterHandlers(llm, m, func() (*types.ModelConfiguration, error) { return modelconfig.Get().Get() }) m.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("API no route: %s %s from %s", r.Method, r.URL, r.Referer()) diff --git a/cmd/frontend/internal/llmapi/BUILD.bazel b/cmd/frontend/internal/llmapi/BUILD.bazel index 32829cb0d7b..d1ef2947150 100644 --- a/cmd/frontend/internal/llmapi/BUILD.bazel +++ b/cmd/frontend/internal/llmapi/BUILD.bazel @@ -11,7 +11,6 @@ go_library( importpath = "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/llmapi", visibility = ["//cmd/frontend:__subpackages__"], deps = [ - "//cmd/frontend/internal/modelconfig", "//internal/completions/types", "//internal/modelconfig/types", "//lib/errors", @@ -31,15 +30,11 @@ go_test( data = glob(["golly-recordings/**"]), embed = [":llmapi"], deps = [ - "//cmd/frontend/internal/modelconfig", - "//internal/conf", "//internal/golly", "//internal/httpcli", "//internal/modelconfig/types", - "//schema", "@com_github_gorilla_mux//:mux", "@com_github_hexops_autogold_v2//:autogold", "@com_github_stretchr_testify//assert", - "@com_github_stretchr_testify//require", ], ) diff --git a/cmd/frontend/internal/llmapi/chat_completions_handler.go b/cmd/frontend/internal/llmapi/chat_completions_handler.go index fe0e209eb81..1dd03e6fa88 100644 --- a/cmd/frontend/internal/llmapi/chat_completions_handler.go +++ b/cmd/frontend/internal/llmapi/chat_completions_handler.go @@ -14,7 +14,6 @@ import ( "github.com/sourcegraph/log" sglog "github.com/sourcegraph/log" - "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/modelconfig" "github.com/sourcegraph/sourcegraph/lib/errors" completions "github.com/sourcegraph/sourcegraph/internal/completions/types" @@ -31,8 +30,12 @@ type chatCompletionsHandler struct { // would have an in-house service we can use instead of going via HTTP but using HTTP // simplifies a lof of things (including testing). apiHandler http.Handler + + GetModelConfig GetModelConfigurationFunc } +type GetModelConfigurationFunc func() (*types.ModelConfiguration, error) + var _ http.Handler = (*chatCompletionsHandler)(nil) func (h *chatCompletionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -45,8 +48,7 @@ func (h *chatCompletionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques decoder := json.NewDecoder(io.NopCloser(bytes.NewBuffer(body))) - modelConfigSvc := modelconfig.Get() - currentModelConfig, err := modelConfigSvc.Get() + currentModelConfig, err := h.GetModelConfig() if err != nil { http.Error(w, fmt.Sprintf("modelConfigSvc.Get: %v", err), http.StatusInternalServerError) return diff --git a/cmd/frontend/internal/llmapi/chat_completions_handler_test.go b/cmd/frontend/internal/llmapi/chat_completions_handler_test.go index bb941a83826..d601e1c5aa8 100644 --- a/cmd/frontend/internal/llmapi/chat_completions_handler_test.go +++ b/cmd/frontend/internal/llmapi/chat_completions_handler_test.go @@ -8,28 +8,16 @@ import ( "github.com/hexops/autogold/v2" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/modelconfig" - "github.com/sourcegraph/sourcegraph/internal/conf" types "github.com/sourcegraph/sourcegraph/internal/modelconfig/types" - "github.com/sourcegraph/sourcegraph/schema" ) -func SetSiteConfig(t *testing.T, siteConfig schema.SiteConfiguration) { - conf.Mock(&conf.Unified{SiteConfiguration: siteConfig}) - if err := modelconfig.ResetMock(); err != nil { - require.NoError(t, err) - } -} - func TestChatCompletionsHandler(t *testing.T) { - c := newTest(t) - chatModels := c.getChatModels() - assert.NoError(t, modelconfig.InitMock()) - assert.NoError(t, modelconfig.ResetMockWithStaticData(&types.ModelConfiguration{ - Models: chatModels, - })) + var c *publicrestTest + c = newTest(t, func() (*types.ModelConfiguration, error) { + chatModels := c.getChatModels() + return &types.ModelConfiguration{Models: chatModels}, nil + }) t.Run("/.api/llm/chat/completions (400 stream=true)", func(t *testing.T) { rr := c.chatCompletions(t, `{ @@ -144,7 +132,7 @@ func TestChatCompletionsHandler(t *testing.T) { }`).Equal(t, body) }) - for _, model := range chatModels { + for _, model := range c.getChatModels() { if model.DisplayName == "starcoder" { // Skip starcoder because it's not a chat model even if it has the "chat" capability // per the /.api/modelconfig/supported-models.json endpoint. Context: diff --git a/cmd/frontend/internal/llmapi/httpapi.go b/cmd/frontend/internal/llmapi/httpapi.go index fbf8b2c545c..038182e540d 100644 --- a/cmd/frontend/internal/llmapi/httpapi.go +++ b/cmd/frontend/internal/llmapi/httpapi.go @@ -7,8 +7,12 @@ import ( sglog "github.com/sourcegraph/log" ) -func RegisterHandlers(m *mux.Router, apiHandler http.Handler) { +func RegisterHandlers(m *mux.Router, apiHandler http.Handler, getModelConfigFunc GetModelConfigurationFunc) { logger := sglog.Scoped("llmapi") - m.Path("/chat/completions").Methods("POST").Handler(&chatCompletionsHandler{logger: logger, apiHandler: apiHandler}) + m.Path("/chat/completions").Methods("POST").Handler(&chatCompletionsHandler{ + logger: logger, + apiHandler: apiHandler, + GetModelConfig: getModelConfigFunc, + }) } diff --git a/cmd/frontend/internal/llmapi/utils_test.go b/cmd/frontend/internal/llmapi/utils_test.go index d328ba7af25..a4fa3ecffee 100644 --- a/cmd/frontend/internal/llmapi/utils_test.go +++ b/cmd/frontend/internal/llmapi/utils_test.go @@ -25,13 +25,13 @@ type publicrestTest struct { HttpClient http.Handler } -func newTest(t *testing.T) *publicrestTest { +func newTest(t *testing.T, getModelConfigFunc GetModelConfigurationFunc) *publicrestTest { MockUUID = "mocked-llmapi-uuid" gollyDoer := golly.NewGollyDoer(t, httpcli.TestExternalClient) recordReplayHandler := newRecordReplayHandler(gollyDoer, gollyDoer.DotcomCredentials()) apiHandler := mux.NewRouter().PathPrefix("/.api/llm/").Subrouter() - RegisterHandlers(apiHandler, recordReplayHandler) + RegisterHandlers(apiHandler, recordReplayHandler, getModelConfigFunc) return &publicrestTest{ t: t,