From 162d3836dab3d3a6b6fd0a0da6a2f296c96d2c15 Mon Sep 17 00:00:00 2001 From: Ara Date: Wed, 31 Jul 2024 18:19:19 +0200 Subject: [PATCH] Backport 5ce2eea to 5.5.x (#64166) This is a backport PR to add changes from https://github.com/sourcegraph/sourcegraph/pull/64116 to v5.5.x to main to create a release of the frontend. ## Test plan ## Changelog --------- Co-authored-by: Vincent --- cmd/customer-2315/BUILD.bazel | 65 +++++ cmd/customer-2315/image_test.yaml | 15 ++ cmd/customer-2315/main.go | 224 ++++++++++++++++++ cmd/customer-4512/BUILD.bazel | 62 +++++ cmd/customer-4512/image_test.yaml | 15 ++ cmd/customer-4512/main.go | 204 ++++++++++++++++ .../completions/client/azureopenai/openai.go | 27 ++- 7 files changed, 605 insertions(+), 7 deletions(-) create mode 100644 cmd/customer-2315/BUILD.bazel create mode 100644 cmd/customer-2315/image_test.yaml create mode 100644 cmd/customer-2315/main.go create mode 100644 cmd/customer-4512/BUILD.bazel create mode 100644 cmd/customer-4512/image_test.yaml create mode 100644 cmd/customer-4512/main.go diff --git a/cmd/customer-2315/BUILD.bazel b/cmd/customer-2315/BUILD.bazel new file mode 100644 index 00000000000..00f83643caf --- /dev/null +++ b/cmd/customer-2315/BUILD.bazel @@ -0,0 +1,65 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") +load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("@container_structure_test//:defs.bzl", "container_structure_test") +load("//dev:oci_defs.bzl", "image_repository", "oci_image", "oci_push", "oci_tarball") + +go_library( + name = "customer-2315_lib", + srcs = ["main.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/customer-2315", + tags = [TAG_CODY_PRIME], + visibility = ["//visibility:private"], + deps = [ + "@com_github_google_uuid//:uuid", + "@com_github_sourcegraph_log//:log", + ], +) + +go_binary( + name = "customer-2315", + embed = [":customer-2315_lib"], + tags = [TAG_CODY_PRIME], + visibility = ["//visibility:public"], +) + +pkg_tar( + name = "tar_customer-2315", + srcs = [":customer-2315"], +) + +oci_image( + name = "image", + base = "//wolfi-images/sourcegraph-base:base_image", + entrypoint = [ + "/sbin/tini", + "--", + "/customer-2315", + ], + tars = [":tar_customer-2315"], + user = "sourcegraph", +) + +oci_tarball( + name = "image_tarball", + image = ":image", + repo_tags = ["customer-2315:candidate"], +) + +container_structure_test( + name = "image_test", + timeout = "short", + configs = ["image_test.yaml"], + driver = "docker", + image = ":image", + tags = [ + "exclusive", + "requires-network", + TAG_CODY_PRIME, + ], +) + +oci_push( + name = "candidate_push", + image = ":image", + repository = image_repository("customer-2315"), +) diff --git a/cmd/customer-2315/image_test.yaml b/cmd/customer-2315/image_test.yaml new file mode 100644 index 00000000000..34d5175ff8a --- /dev/null +++ b/cmd/customer-2315/image_test.yaml @@ -0,0 +1,15 @@ +schemaVersion: "2.0.0" + +commandTests: + - name: "not running as root" + command: "/usr/bin/id" + args: + - -u + excludedOutput: ["^0"] + exitCode: 0 + - name: "validate /customer-2315 file exists and is executable" + command: "test" + args: + - "-x" + - "/customer-2315" + exitCode: 0 diff --git a/cmd/customer-2315/main.go b/cmd/customer-2315/main.go new file mode 100644 index 00000000000..576f0b0ec64 --- /dev/null +++ b/cmd/customer-2315/main.go @@ -0,0 +1,224 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/sourcegraph/log" + + "github.com/google/uuid" +) + +type ProxyServer struct { + accessToken string + tokenMutex sync.RWMutex + client *http.Client + azureEndpoint *url.URL + logger log.Logger +} + +func (ps *ProxyServer) readSecretFile(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +func (ps *ProxyServer) generateHeaders(bearerToken string) map[string]string { + return map[string]string{ + "correlationId": uuid.New().String(), + "dataClassification": "sensitive", + "dataSource": "internet", + "Authorization": "Bearer " + bearerToken, + } +} + +func (ps *ProxyServer) updateAccessToken() { + for { + token, err := ps.getAccessToken() + if err != nil { + ps.logger.Fatal("Error getting access token: %v", log.Error(err)) + } else { + ps.tokenMutex.Lock() + ps.accessToken = token + ps.tokenMutex.Unlock() + ps.logger.Info("Access token updated") + } + time.Sleep(1 * time.Minute) + } +} + +func (ps *ProxyServer) initializeAzureEndpoint() { + var err error + azure_endpoint, err := ps.readSecretFile("/run/secrets/azure_endpoint") + if err != nil { + ps.logger.Fatal("error reading OAUTH_URL: %v", log.Error(err)) + } + ps.azureEndpoint, err = url.Parse(azure_endpoint) + if err != nil { + ps.logger.Fatal("Invalid AZURE_ENDPOINT: %v", log.Error(err)) + } +} + +func (ps *ProxyServer) initializeClient() { + ps.client = &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 400, + MaxIdleConnsPerHost: 400, + IdleConnTimeout: 90 * time.Second, + DisableKeepAlives: false, + }, + Timeout: 30 * time.Second, + } +} + +func (ps *ProxyServer) getAccessToken() (string, error) { + url, err := ps.readSecretFile("/run/secrets/oauth_url") + if err != nil { + return "", fmt.Errorf("error reading OAUTH_URL: %v", err) + } + clientID, err := ps.readSecretFile("/run/secrets/client_id") + if err != nil { + return "", fmt.Errorf("error reading CLIENT_ID: %v", err) + } + clientSecret, err := ps.readSecretFile("/run/secrets/client_secret") + if err != nil { + return "", fmt.Errorf("error reading CLIENT_SECRET: %v", err) + } + + data := map[string]string{ + "client_id": clientID, + "client_secret": clientSecret, + "scope": "azureopenai-readwrite", + "grant_type": "client_credentials", + } + + jsonData, err := json.Marshal(data) + if err != nil { + return "", fmt.Errorf("error marshalling JSON: %v", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := ps.client.Do(req) + if err != nil { + return "", fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("request failed with status: %v", resp.Status) + } + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("error decoding response: %v", err) + } + + token, ok := result["access_token"].(string) + if !ok { + return "", fmt.Errorf("access token not found in response") + } + + return token, nil +} + +func (ps *ProxyServer) handleProxy(w http.ResponseWriter, req *http.Request) { + target := ps.azureEndpoint.ResolveReference(req.URL) + // Create a proxy request + proxyReq, err := http.NewRequest(req.Method, target.String(), req.Body) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Copy headers from the original request + for header, values := range req.Header { + for _, value := range values { + proxyReq.Header.Add(header, value) + } + } + + ps.tokenMutex.RLock() + bearerToken := ps.accessToken + ps.tokenMutex.RUnlock() + // Add generated headers + headers := ps.generateHeaders(bearerToken) + for key, value := range headers { + proxyReq.Header.Set(key, value) + } + proxyReq.Header.Set("Api-Key", bearerToken) + + resp, err := ps.client.Do(proxyReq) + if err != nil { + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Write the headers and status code from the response to the client + for header, values := range resp.Header { + for _, value := range values { + w.Header().Add(header, value) + } + } + w.WriteHeader(resp.StatusCode) + + // Stream the response body to the client + reader := bufio.NewReader(resp.Body) + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if err != nil && err != io.EOF { + ps.logger.Error("Error reading response body: %v", log.Error(err)) + http.Error(w, "Error reading response from upstream server", http.StatusBadGateway) + return + } + if n == 0 { + break + } + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + ps.logger.Fatal("Error writing response: %v", log.Error(writeErr)) + break + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } +} + +func main() { + liblog := log.Init(log.Resource{ + Name: "Special Oauth Server", + }) + defer liblog.Sync() + + logger := log.Scoped("server") + + ps := &ProxyServer{ + logger: logger, + } + ps.initializeClient() + ps.initializeAzureEndpoint() + go ps.updateAccessToken() + http.HandleFunc("/", ps.handleProxy) + logger.Info("HTTP Proxy server is running on port 8080") + if err := http.ListenAndServe(":8080", nil); err != nil { + logger.Fatal("Failed to start HTTP server: %v", log.Error(err)) + } +} diff --git a/cmd/customer-4512/BUILD.bazel b/cmd/customer-4512/BUILD.bazel new file mode 100644 index 00000000000..c8370965341 --- /dev/null +++ b/cmd/customer-4512/BUILD.bazel @@ -0,0 +1,62 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") +load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("@container_structure_test//:defs.bzl", "container_structure_test") +load("//dev:oci_defs.bzl", "image_repository", "oci_image", "oci_push", "oci_tarball") + +go_library( + name = "customer-4512_lib", + srcs = ["main.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/customer-4512", + tags = [TAG_CODY_PRIME], + visibility = ["//visibility:private"], + deps = ["@com_github_sourcegraph_log//:log"], +) + +go_binary( + name = "customer-4512", + embed = [":customer-4512_lib"], + tags = [TAG_CODY_PRIME], + visibility = ["//visibility:public"], +) + +pkg_tar( + name = "tar_customer-4512", + srcs = [":customer-4512"], +) + +oci_image( + name = "image", + base = "//wolfi-images/sourcegraph-base:base_image", + entrypoint = [ + "/sbin/tini", + "--", + "/customer-4512", + ], + tars = [":tar_customer-4512"], + user = "sourcegraph", +) + +oci_tarball( + name = "image_tarball", + image = ":image", + repo_tags = ["customer-4512:candidate"], +) + +container_structure_test( + name = "image_test", + timeout = "short", + configs = ["image_test.yaml"], + driver = "docker", + image = ":image", + tags = [ + "exclusive", + "requires-network", + TAG_CODY_PRIME, + ], +) + +oci_push( + name = "candidate_push", + image = ":image", + repository = image_repository("customer-4512"), +) diff --git a/cmd/customer-4512/image_test.yaml b/cmd/customer-4512/image_test.yaml new file mode 100644 index 00000000000..7fa532eeb00 --- /dev/null +++ b/cmd/customer-4512/image_test.yaml @@ -0,0 +1,15 @@ +schemaVersion: "2.0.0" + +commandTests: + - name: "not running as root" + command: "/usr/bin/id" + args: + - -u + excludedOutput: ["^0"] + exitCode: 0 + - name: "validate /customer-4512 file exists and is executable" + command: "test" + args: + - "-x" + - "/customer-4512" + exitCode: 0 diff --git a/cmd/customer-4512/main.go b/cmd/customer-4512/main.go new file mode 100644 index 00000000000..766ee554b00 --- /dev/null +++ b/cmd/customer-4512/main.go @@ -0,0 +1,204 @@ +package main + +import ( + "bufio" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/sourcegraph/log" +) + +type Proxy struct { + accessToken string + tokenMutex sync.RWMutex + client *http.Client + azureEndpoint *url.URL + logger log.Logger +} + +func (ps *Proxy) readSecretFile(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +func (ps *Proxy) updateAccessToken() { + for { + token, err := ps.getAccessToken() + if err != nil { + ps.logger.Fatal("Error getting access token: %v", log.Error(err)) + } else { + ps.tokenMutex.Lock() + ps.accessToken = token + ps.tokenMutex.Unlock() + ps.logger.Info("Access token updated") + } + time.Sleep(1 * time.Minute) + } +} + +func (ps *Proxy) initializeAzureEndpoint() { + var err error + azure_endpoint, err := ps.readSecretFile("/run/secrets/azure_endpoint") + if err != nil { + ps.logger.Fatal("error reading OAUTH_URL: %v", log.Error(err)) + } + ps.azureEndpoint, err = url.Parse(azure_endpoint) + if err != nil { + ps.logger.Fatal("Invalid AZURE_ENDPOINT: %v", log.Error(err)) + } +} + +func (ps *Proxy) initializeClient() { + ps.client = &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 400, + MaxIdleConnsPerHost: 400, + IdleConnTimeout: 90 * time.Second, + DisableKeepAlives: false, + }, + Timeout: 30 * time.Second, + } +} + +func (ps *Proxy) getAccessToken() (string, error) { + oauth_url, err := ps.readSecretFile("/run/secrets/oauth_url") + if err != nil { + return "", fmt.Errorf("error reading OAUTH_URL: %v", err) + } + clientID, err := ps.readSecretFile("/run/secrets/client_id") + if err != nil { + return "", fmt.Errorf("error reading CLIENT_ID: %v", err) + } + clientSecret, err := ps.readSecretFile("/run/secrets/client_secret") + if err != nil { + return "", fmt.Errorf("error reading CLIENT_SECRET: %v", err) + } + + authKey := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientSecret))) + + data := url.Values{} + data.Set("grant_type", "client_credentials") + + req, err := http.NewRequest("POST", oauth_url, io.NopCloser(strings.NewReader(data.Encode()))) + if err != nil { + return "", fmt.Errorf("Failed to create request: %v", err) + } + + req.Header.Add("Authorization", "Basic "+authKey) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err := ps.client.Do(req) + if err != nil { + return "", fmt.Errorf("Failed to retrieve token: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("Failed to retrieve token: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("Failed to read response body: %v", err) + } + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + ps.logger.Fatal("Failed to unmarshal response body: %v", log.Error(err)) + } + + accessToken, ok := result["access_token"].(string) + if !ok { + ps.logger.Fatal("Failed to retrieve access token from response body") + } + return accessToken, nil +} + +func (ps *Proxy) handleProxy(w http.ResponseWriter, req *http.Request) { + target := ps.azureEndpoint.ResolveReference(req.URL) + // Create a proxy request + proxyReq, err := http.NewRequest(req.Method, target.String(), req.Body) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Copy headers from the original request + for header, values := range req.Header { + for _, value := range values { + proxyReq.Header.Add(header, value) + } + } + + ps.tokenMutex.RLock() + bearerToken := ps.accessToken + ps.tokenMutex.RUnlock() + + // Add accesstoken headers + proxyReq.Header.Set("Api-Key", bearerToken) + resp, err := ps.client.Do(proxyReq) + if err != nil { + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Write the headers and status code from the response to the client + for header, values := range resp.Header { + for _, value := range values { + w.Header().Add(header, value) + } + } + w.WriteHeader(resp.StatusCode) + + // Stream the response body to the client + reader := bufio.NewReader(resp.Body) + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if err != nil && err != io.EOF { + ps.logger.Error("Error reading response body: %v", log.Error(err)) + http.Error(w, "Error reading response from upstream server", http.StatusBadGateway) + return + } + if n == 0 { + break + } + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + ps.logger.Fatal("Error writing response: %v", log.Error(writeErr)) + break + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } +} + +func main() { + liblog := log.Init(log.Resource{ + Name: "Cody OAuth Proxy", + }) + defer liblog.Sync() + + logger := log.Scoped("server") + + ps := &Proxy{logger: logger} + ps.initializeClient() + ps.initializeAzureEndpoint() + go ps.updateAccessToken() + http.HandleFunc("/", ps.handleProxy) + logger.Info("HTTP Proxy server is running on port 8080") + if err := http.ListenAndServe(":8080", nil); err != nil { + logger.Fatal("Failed to start HTTP server: %v", log.Error(err)) + } +} diff --git a/internal/completions/client/azureopenai/openai.go b/internal/completions/client/azureopenai/openai.go index 5bd6c4908da..18baad28691 100644 --- a/internal/completions/client/azureopenai/openai.go +++ b/internal/completions/client/azureopenai/openai.go @@ -72,7 +72,17 @@ func GetAPIClient(endpoint, accessToken string) (CompletionsClient, error) { } var err error if accessToken != "" { - credential := azcore.NewKeyCredential(accessToken) + var credential *azcore.KeyCredential + // Note: HTTP connection can be useful if customers need to run e.g. an auth proxy + // between Sourcegraph and their Azure OpenAI endpoint. + // The Azure client will prohibit sending HTTP requests if the request would contain + // credentials, so we remove credentials if the admin's intention is to send HTTP + // and not HTTPS. + if strings.HasPrefix(endpoint, "http://") { + credential = nil + } else { + credential = azcore.NewKeyCredential(accessToken) + } apiClient.client, err = azopenai.NewClientWithKeyCredential(endpoint, credential, clientOpts) } else { var opts *azidentity.DefaultAzureCredentialOptions @@ -80,13 +90,16 @@ func GetAPIClient(endpoint, accessToken string) (CompletionsClient, error) { if err != nil { return nil, err } - credential, credErr := azidentity.NewDefaultAzureCredential(opts) - if credErr != nil { - return nil, credErr - } apiClient.endpoint = endpoint - - apiClient.client, err = azopenai.NewClient(endpoint, credential, clientOpts) + if strings.HasPrefix(endpoint, "http://") { + apiClient.client, err = azopenai.NewClient(endpoint, nil, clientOpts) + } else { + credential, credErr := azidentity.NewDefaultAzureCredential(opts) + if credErr != nil { + return nil, credErr + } + apiClient.client, err = azopenai.NewClient(endpoint, credential, clientOpts) + } } return apiClient.client, err