Special oauth containers (#63880)

[See a description in Linear 

](https://linear.app/sourcegraph/issue/CODY-2845/supporting-special-azure-openai-configurations-for-enterprize#comment-6938baa0)

## Test plan
Tested with local builds and see linear for more. 

<!-- REQUIRED; info at
https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles
-->

## Changelog

<!-- OPTIONAL; info at
https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c
-->
This commit is contained in:
Ara 2024-07-22 17:31:03 +02:00 committed by GitHub
parent 874c9f6bd3
commit 1ca6385b1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 615 additions and 0 deletions

View File

@ -0,0 +1,65 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@container_structure_test//:defs.bzl", "container_structure_test")
load("//dev:oci_defs.bzl", "image_repository", "oci_image", "oci_push", "oci_tarball")
go_library(
name = "customer-2315_lib",
srcs = ["main.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/customer-2315",
tags = [TAG_CODY_PRIME],
visibility = ["//visibility:private"],
deps = [
"@com_github_google_uuid//:uuid",
"@com_github_sourcegraph_log//:log",
],
)
go_binary(
name = "customer-2315",
embed = [":customer-2315_lib"],
tags = [TAG_CODY_PRIME],
visibility = ["//visibility:public"],
)
pkg_tar(
name = "tar_customer-2315",
srcs = [":customer-2315"],
)
oci_image(
name = "image",
base = "//wolfi-images/sourcegraph-base:base_image",
entrypoint = [
"/sbin/tini",
"--",
"/customer-2315",
],
tars = [":tar_customer-2315"],
user = "sourcegraph",
)
oci_tarball(
name = "image_tarball",
image = ":image",
repo_tags = ["customer-2315:candidate"],
)
container_structure_test(
name = "image_test",
timeout = "short",
configs = ["image_test.yaml"],
driver = "docker",
image = ":image",
tags = [
"exclusive",
"requires-network",
TAG_CODY_PRIME,
],
)
oci_push(
name = "candidate_push",
image = ":image",
repository = image_repository("customer-2315"),
)

View File

@ -0,0 +1,15 @@
schemaVersion: "2.0.0"
commandTests:
- name: "not running as root"
command: "/usr/bin/id"
args:
- -u
excludedOutput: ["^0"]
exitCode: 0
- name: "validate /customer-2315 file exists and is executable"
command: "test"
args:
- "-x"
- "/customer-2315"
exitCode: 0

239
cmd/customer-2315/main.go Normal file
View File

@ -0,0 +1,239 @@
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/sourcegraph/log"
"github.com/google/uuid"
)
type ProxyServer struct {
accessToken string
tokenMutex sync.RWMutex
client *http.Client
azureEndpoint *url.URL
logger log.Logger
}
func (ps *ProxyServer) readSecretFile(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func (ps *ProxyServer) generateHeaders(bearerToken string) map[string]string {
return map[string]string{
"correlationId": uuid.New().String(),
"dataClassification": "sensitive",
"dataSource": "internet",
"Authorization": "Bearer " + bearerToken,
}
}
func (ps *ProxyServer) updateAccessToken() {
for {
token, err := ps.getAccessToken()
if err != nil {
ps.logger.Fatal("Error getting access token: %v", log.Error(err))
} else {
ps.tokenMutex.Lock()
ps.accessToken = token
ps.tokenMutex.Unlock()
ps.logger.Info("Access token updated")
}
time.Sleep(1 * time.Minute)
}
}
func (ps *ProxyServer) initializeAzureEndpoint() {
var err error
azure_endpoint, err := ps.readSecretFile("/run/secrets/azure_endpoint")
if err != nil {
ps.logger.Fatal("error reading OAUTH_URL: %v", log.Error(err))
}
ps.azureEndpoint, err = url.Parse(azure_endpoint)
if err != nil {
ps.logger.Fatal("Invalid AZURE_ENDPOINT: %v", log.Error(err))
}
}
func (ps *ProxyServer) initializeClient() {
ps.client = &http.Client{
Transport: &http.Transport{
MaxIdleConns: 400,
MaxIdleConnsPerHost: 400,
IdleConnTimeout: 90 * time.Second,
DisableKeepAlives: false,
},
Timeout: 30 * time.Second,
}
}
func (ps *ProxyServer) getAccessToken() (string, error) {
url, err := ps.readSecretFile("/run/secrets/oauth_url")
if err != nil {
return "", fmt.Errorf("error reading OAUTH_URL: %v", err)
}
clientID, err := ps.readSecretFile("/run/secrets/client_id")
if err != nil {
return "", fmt.Errorf("error reading CLIENT_ID: %v", err)
}
clientSecret, err := ps.readSecretFile("/run/secrets/client_secret")
if err != nil {
return "", fmt.Errorf("error reading CLIENT_SECRET: %v", err)
}
data := map[string]string{
"client_id": clientID,
"client_secret": clientSecret,
"scope": "azureopenai-readwrite",
"grant_type": "client_credentials",
}
jsonData, err := json.Marshal(data)
if err != nil {
return "", fmt.Errorf("error marshalling JSON: %v", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("error creating request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := ps.client.Do(req)
if err != nil {
return "", fmt.Errorf("error making request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("request failed with status: %v", resp.Status)
}
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("error decoding response: %v", err)
}
token, ok := result["access_token"].(string)
if !ok {
return "", fmt.Errorf("access token not found in response")
}
return token, nil
}
func (ps *ProxyServer) validateApiKey(req *http.Request) bool {
proxyAccessToken, err := ps.readSecretFile("/run/secrets/proxy_access_token")
if err != nil {
return false
}
incomingAccessToken := req.Header.Get("Api-Key")
// Compare the incoming Api-Key with the environment variable
return incomingAccessToken == proxyAccessToken
}
func (ps *ProxyServer) handleProxy(w http.ResponseWriter, req *http.Request) {
target := ps.azureEndpoint.ResolveReference(req.URL)
if !ps.validateApiKey(req) {
http.Error(w, "Invalid Proxy Password", http.StatusUnauthorized)
return
}
// Create a proxy request
proxyReq, err := http.NewRequest(req.Method, target.String(), req.Body)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Copy headers from the original request
for header, values := range req.Header {
for _, value := range values {
proxyReq.Header.Add(header, value)
}
}
ps.tokenMutex.RLock()
bearerToken := ps.accessToken
ps.tokenMutex.RUnlock()
// Add generated headers
headers := ps.generateHeaders(bearerToken)
for key, value := range headers {
proxyReq.Header.Set(key, value)
}
proxyReq.Header.Set("Api-Key", bearerToken)
resp, err := ps.client.Do(proxyReq)
if err != nil {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// Write the headers and status code from the response to the client
for header, values := range resp.Header {
for _, value := range values {
w.Header().Add(header, value)
}
}
w.WriteHeader(resp.StatusCode)
// Stream the response body to the client
reader := bufio.NewReader(resp.Body)
buf := make([]byte, 32*1024)
for {
n, err := reader.Read(buf)
if err != nil && err != io.EOF {
ps.logger.Error("Error reading response body: %v", log.Error(err))
http.Error(w, "Error reading response from upstream server", http.StatusBadGateway)
return
}
if n == 0 {
break
}
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
ps.logger.Fatal("Error writing response: %v", log.Error(writeErr))
break
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
}
func main() {
liblog := log.Init(log.Resource{
Name: "Special Oauth Server",
})
defer liblog.Sync()
logger := log.Scoped("server")
ps := &ProxyServer{
logger: logger,
}
ps.initializeClient()
ps.initializeAzureEndpoint()
go ps.updateAccessToken()
http.HandleFunc("/", ps.handleProxy)
logger.Info("HTTPS Proxy server is running on port 8443")
if err := http.ListenAndServeTLS(":8443", "/run/secrets/cert.pem", "/run/secrets/key.pem", nil); err != nil {
logger.Fatal("Failed to start HTTPS server: %v", log.Error(err))
}
}

View File

@ -0,0 +1,62 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@container_structure_test//:defs.bzl", "container_structure_test")
load("//dev:oci_defs.bzl", "image_repository", "oci_image", "oci_push", "oci_tarball")
go_library(
name = "customer-4512_lib",
srcs = ["main.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/customer-4512",
tags = [TAG_CODY_PRIME],
visibility = ["//visibility:private"],
deps = ["@com_github_sourcegraph_log//:log"],
)
go_binary(
name = "customer-4512",
embed = [":customer-4512_lib"],
tags = [TAG_CODY_PRIME],
visibility = ["//visibility:public"],
)
pkg_tar(
name = "tar_customer-4512",
srcs = [":customer-4512"],
)
oci_image(
name = "image",
base = "//wolfi-images/sourcegraph-base:base_image",
entrypoint = [
"/sbin/tini",
"--",
"/customer-4512",
],
tars = [":tar_customer-4512"],
user = "sourcegraph",
)
oci_tarball(
name = "image_tarball",
image = ":image",
repo_tags = ["customer-4512:candidate"],
)
container_structure_test(
name = "image_test",
timeout = "short",
configs = ["image_test.yaml"],
driver = "docker",
image = ":image",
tags = [
"exclusive",
"requires-network",
TAG_CODY_PRIME,
],
)
oci_push(
name = "candidate_push",
image = ":image",
repository = image_repository("customer-4512"),
)

View File

@ -0,0 +1,15 @@
schemaVersion: "2.0.0"
commandTests:
- name: "not running as root"
command: "/usr/bin/id"
args:
- -u
excludedOutput: ["^0"]
exitCode: 0
- name: "validate /customer-4512 file exists and is executable"
command: "test"
args:
- "-x"
- "/customer-4512"
exitCode: 0

219
cmd/customer-4512/main.go Normal file
View File

@ -0,0 +1,219 @@
package main
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/sourcegraph/log"
)
type Proxy struct {
accessToken string
tokenMutex sync.RWMutex
client *http.Client
azureEndpoint *url.URL
logger log.Logger
}
func (ps *Proxy) readSecretFile(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(data)), nil
}
func (ps *Proxy) updateAccessToken() {
for {
token, err := ps.getAccessToken()
if err != nil {
ps.logger.Fatal("Error getting access token: %v", log.Error(err))
} else {
ps.tokenMutex.Lock()
ps.accessToken = token
ps.tokenMutex.Unlock()
ps.logger.Info("Access token updated")
}
time.Sleep(1 * time.Minute)
}
}
func (ps *Proxy) initializeAzureEndpoint() {
var err error
azure_endpoint, err := ps.readSecretFile("/run/secrets/azure_endpoint")
if err != nil {
ps.logger.Fatal("error reading OAUTH_URL: %v", log.Error(err))
}
ps.azureEndpoint, err = url.Parse(azure_endpoint)
if err != nil {
ps.logger.Fatal("Invalid AZURE_ENDPOINT: %v", log.Error(err))
}
}
func (ps *Proxy) initializeClient() {
ps.client = &http.Client{
Transport: &http.Transport{
MaxIdleConns: 400,
MaxIdleConnsPerHost: 400,
IdleConnTimeout: 90 * time.Second,
DisableKeepAlives: false,
},
Timeout: 30 * time.Second,
}
}
func (ps *Proxy) getAccessToken() (string, error) {
oauth_url, err := ps.readSecretFile("/run/secrets/oauth_url")
if err != nil {
return "", fmt.Errorf("error reading OAUTH_URL: %v", err)
}
clientID, err := ps.readSecretFile("/run/secrets/client_id")
if err != nil {
return "", fmt.Errorf("error reading CLIENT_ID: %v", err)
}
clientSecret, err := ps.readSecretFile("/run/secrets/client_secret")
if err != nil {
return "", fmt.Errorf("error reading CLIENT_SECRET: %v", err)
}
authKey := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientSecret)))
data := url.Values{}
data.Set("grant_type", "client_credentials")
req, err := http.NewRequest("POST", oauth_url, io.NopCloser(strings.NewReader(data.Encode())))
if err != nil {
return "", fmt.Errorf("Failed to create request: %v", err)
}
req.Header.Add("Authorization", "Basic "+authKey)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := ps.client.Do(req)
if err != nil {
return "", fmt.Errorf("Failed to retrieve token: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("Failed to retrieve token: %s", resp.Status)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("Failed to read response body: %v", err)
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
ps.logger.Fatal("Failed to unmarshal response body: %v", log.Error(err))
}
accessToken, ok := result["access_token"].(string)
if !ok {
ps.logger.Fatal("Failed to retrieve access token from response body")
}
return accessToken, nil
}
func (ps *Proxy) validateApiKey(req *http.Request) bool {
proxyAccessToken, err := ps.readSecretFile("/run/secrets/proxy_access_token")
if err != nil {
return false
}
incomingAccessToken := req.Header.Get("Api-Key")
// Compare the incoming Api-Key with the environment variable
return incomingAccessToken == proxyAccessToken
}
func (ps *Proxy) handleProxy(w http.ResponseWriter, req *http.Request) {
target := ps.azureEndpoint.ResolveReference(req.URL)
if !ps.validateApiKey(req) {
http.Error(w, "Invalid Proxy Password", http.StatusUnauthorized)
return
}
// Create a proxy request
proxyReq, err := http.NewRequest(req.Method, target.String(), req.Body)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Copy headers from the original request
for header, values := range req.Header {
for _, value := range values {
proxyReq.Header.Add(header, value)
}
}
ps.tokenMutex.RLock()
bearerToken := ps.accessToken
ps.tokenMutex.RUnlock()
// Add accesstoken headers
proxyReq.Header.Set("Api-Key", bearerToken)
resp, err := ps.client.Do(proxyReq)
if err != nil {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// Write the headers and status code from the response to the client
for header, values := range resp.Header {
for _, value := range values {
w.Header().Add(header, value)
}
}
w.WriteHeader(resp.StatusCode)
// Stream the response body to the client
reader := bufio.NewReader(resp.Body)
buf := make([]byte, 32*1024)
for {
n, err := reader.Read(buf)
if err != nil && err != io.EOF {
ps.logger.Error("Error reading response body: %v", log.Error(err))
http.Error(w, "Error reading response from upstream server", http.StatusBadGateway)
return
}
if n == 0 {
break
}
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
ps.logger.Fatal("Error writing response: %v", log.Error(writeErr))
break
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
}
func main() {
liblog := log.Init(log.Resource{
Name: "Cody OAuth Proxy",
})
defer liblog.Sync()
logger := log.Scoped("server")
ps := &Proxy{logger: logger}
ps.initializeClient()
ps.initializeAzureEndpoint()
go ps.updateAccessToken()
http.HandleFunc("/", ps.handleProxy)
logger.Info("HTTPS Proxy server is running on port 8443")
if err := http.ListenAndServeTLS(":8443", "/run/secrets/cert.pem", "/run/secrets/key.pem", nil); err != nil {
logger.Fatal("Failed to start HTTPS server: %v", log.Error(err))
}
}