Adding HTTP Requests support for Azure OpenAI (#64116)

This PR adds special support for http requests to azure OpenAI and
changes special customer configs to use HTTP instead of HTTPS.


## Test plan
Tested this PR locally 
<!-- REQUIRED; info at
https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles
-->

## Changelog

<!-- OPTIONAL; info at
https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c
-->

---------

Co-authored-by: Vincent <evict@users.noreply.github.com>
This commit is contained in:
Ara 2024-07-30 17:45:17 +02:00 committed by GitHub
parent 7a3da57188
commit 5ce2eead9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 43 deletions

View File

@ -138,23 +138,8 @@ func (ps *ProxyServer) getAccessToken() (string, error) {
return token, nil
}
func (ps *ProxyServer) validateApiKey(req *http.Request) bool {
proxyAccessToken, err := ps.readSecretFile("/run/secrets/proxy_access_token")
if err != nil {
return false
}
incomingAccessToken := req.Header.Get("Api-Key")
// Compare the incoming Api-Key with the environment variable
return incomingAccessToken == proxyAccessToken
}
func (ps *ProxyServer) handleProxy(w http.ResponseWriter, req *http.Request) {
target := ps.azureEndpoint.ResolveReference(req.URL)
if !ps.validateApiKey(req) {
http.Error(w, "Invalid Proxy Password", http.StatusUnauthorized)
return
}
// Create a proxy request
proxyReq, err := http.NewRequest(req.Method, target.String(), req.Body)
if err != nil {
@ -232,8 +217,8 @@ func main() {
ps.initializeAzureEndpoint()
go ps.updateAccessToken()
http.HandleFunc("/", ps.handleProxy)
logger.Info("HTTPS Proxy server is running on port 8443")
if err := http.ListenAndServeTLS(":8443", "/run/secrets/cert.pem", "/run/secrets/key.pem", nil); err != nil {
logger.Fatal("Failed to start HTTPS server: %v", log.Error(err))
logger.Info("HTTP Proxy server is running on port 8080")
if err := http.ListenAndServe(":8080", nil); err != nil {
logger.Fatal("Failed to start HTTP server: %v", log.Error(err))
}
}

View File

@ -124,23 +124,8 @@ func (ps *Proxy) getAccessToken() (string, error) {
return accessToken, nil
}
func (ps *Proxy) validateApiKey(req *http.Request) bool {
proxyAccessToken, err := ps.readSecretFile("/run/secrets/proxy_access_token")
if err != nil {
return false
}
incomingAccessToken := req.Header.Get("Api-Key")
// Compare the incoming Api-Key with the environment variable
return incomingAccessToken == proxyAccessToken
}
func (ps *Proxy) handleProxy(w http.ResponseWriter, req *http.Request) {
target := ps.azureEndpoint.ResolveReference(req.URL)
if !ps.validateApiKey(req) {
http.Error(w, "Invalid Proxy Password", http.StatusUnauthorized)
return
}
// Create a proxy request
proxyReq, err := http.NewRequest(req.Method, target.String(), req.Body)
if err != nil {
@ -212,8 +197,8 @@ func main() {
ps.initializeAzureEndpoint()
go ps.updateAccessToken()
http.HandleFunc("/", ps.handleProxy)
logger.Info("HTTPS Proxy server is running on port 8443")
if err := http.ListenAndServeTLS(":8443", "/run/secrets/cert.pem", "/run/secrets/key.pem", nil); err != nil {
logger.Fatal("Failed to start HTTPS server: %v", log.Error(err))
logger.Info("HTTP Proxy server is running on port 8080")
if err := http.ListenAndServe(":8080", nil); err != nil {
logger.Fatal("Failed to start HTTP server: %v", log.Error(err))
}
}

View File

@ -83,7 +83,17 @@ func GetAPIClient(endpoint, accessToken string) (CompletionsClient, error) {
var err error
if accessToken != "" {
credential := azcore.NewKeyCredential(accessToken)
var credential *azcore.KeyCredential
// Note: HTTP connection can be useful if customers need to run e.g. an auth proxy
// between Sourcegraph and their Azure OpenAI endpoint.
// The Azure client will prohibit sending HTTP requests if the request would contain
// credentials, so we remove credentials if the admin's intention is to send HTTP
// and not HTTPS.
if strings.HasPrefix(endpoint, "http://") {
credential = nil
} else {
credential = azcore.NewKeyCredential(accessToken)
}
apiClient.client, err = azopenai.NewClientWithKeyCredential(endpoint, credential, clientOpts)
} else {
var opts *azidentity.DefaultAzureCredentialOptions
@ -91,13 +101,16 @@ func GetAPIClient(endpoint, accessToken string) (CompletionsClient, error) {
if err != nil {
return nil, err
}
credential, credErr := azidentity.NewDefaultAzureCredential(opts)
if credErr != nil {
return nil, credErr
}
apiClient.endpoint = endpoint
apiClient.client, err = azopenai.NewClient(endpoint, credential, clientOpts)
if strings.HasPrefix(endpoint, "http://") {
apiClient.client, err = azopenai.NewClient(endpoint, nil, clientOpts)
} else {
credential, credErr := azidentity.NewDefaultAzureCredential(opts)
if credErr != nil {
return nil, credErr
}
apiClient.client, err = azopenai.NewClient(endpoint, credential, clientOpts)
}
}
return apiClient.client, err