(1/3) Add Options for Auth via Form Data (#110)

Co-authored-by: Devin Buhl <onedr0p@users.noreply.github.com>
This commit is contained in:
Russell Troxel 2023-03-17 11:53:13 -07:00 committed by GitHub
parent 9237ee5d7c
commit 5834fa69fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 244 additions and 53 deletions

View File

@ -286,6 +286,9 @@ func validation(config *cli.Context) error {
if config.String("url") == "" && !apiKeyIsSet && config.String("config") == "" {
return cli.Exit("url and api-key or config must be set, not none of them", 1)
}
if config.Bool("form-auth") && (config.String("auth-username") == "" || config.String("auth-password") == "") {
return cli.Exit("username and password must be set if form-auth is set", 1)
}
return nil
}
@ -343,16 +346,25 @@ func flags(arr string) []cli.Flag {
EnvVars: []string{"DISABLE_SSL_VERIFY"},
},
&cli.StringFlag{
Name: "basic-auth-username",
Usage: "Provide the username for basic auth",
Name: "auth-username",
Aliases: []string{"basic-auth-username"},
Usage: "Provide the username for basic or form auth",
Required: false,
EnvVars: []string{"BASIC_AUTH_USERNAME"},
EnvVars: []string{"AUTH_USERNAME", "BASIC_AUTH_USERNAME"},
},
&cli.StringFlag{
Name: "basic-auth-password",
Usage: "Provide the password for basic auth",
Name: "auth-password",
Aliases: []string{"basic-auth-password"},
Usage: "Provide the password for basic or form auth",
Required: false,
EnvVars: []string{"BASIC_AUTH_PASSWORD"},
EnvVars: []string{"AUTH_PASSWORD", "BASIC_AUTH_PASSWORD"},
},
&cli.BoolFlag{
Name: "form-auth",
Usage: "Use form authentication rather than basic auth",
Value: false,
Required: false,
EnvVars: []string{"FORM_AUTH"},
},
&cli.BoolFlag{
Name: "enable-unknown-queue-items",

View File

@ -21,11 +21,8 @@ type Client struct {
// NewClient method initializes a new Radarr client.
func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
var apiKey string
var baseURL *url.URL
auth := AuthConfig{
Username: c.String("basic-auth-username"),
Password: c.String("basic-auth-password"),
}
apiVersion := cf.ApiVersion
@ -35,8 +32,8 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("Couldn't parse URL: %w", err)
}
baseURL = baseURL.JoinPath(cf.UrlBase, "api", apiVersion)
auth.ApiKey = cf.ApiKey
baseURL = baseURL.JoinPath(cf.UrlBase)
apiKey = cf.ApiKey
} else {
// Otherwise use the value provided in the api-key flag
@ -45,22 +42,45 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("Couldn't parse URL: %w", err)
}
baseURL = baseURL.JoinPath("api", apiVersion)
if c.String("api-key") != "" {
auth.ApiKey = c.String("api-key")
apiKey = c.String("api-key")
} else if c.String("api-key-file") != "" {
data, err := os.ReadFile(c.String("api-key-file"))
if err != nil {
return nil, fmt.Errorf("Couldn't Read API Key file %w", err)
}
auth.ApiKey = string(data)
apiKey = string(data)
}
}
baseTransport := http.DefaultTransport
if c.Bool("disable-ssl-verify") {
baseTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
var auth Authenticator
if c.Bool("form-auth") {
auth = &FormAuth{
Username: c.String("auth-username"),
Password: c.String("auth-password"),
ApiKey: apiKey,
AuthBaseURL: baseURL,
Transport: baseTransport,
}
} else if c.String("username") != "" && c.String("password") != "" {
auth = &BasicAuth{
Username: c.String("auth-username"),
Password: c.String("auth-password"),
ApiKey: apiKey,
}
} else {
auth = &ApiKeyAuth{
ApiKey: apiKey,
}
}
return &Client{
httpClient: http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
@ -68,7 +88,7 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
},
Transport: NewArrTransport(auth, baseTransport),
},
URL: *baseURL,
URL: *baseURL.JoinPath("api", apiVersion),
}, nil
}

View File

@ -3,21 +3,22 @@ package client
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
type Authenticator interface {
Auth(req *http.Request) error
}
// ArrTransport is a http.RoundTripper that adds authentication to requests
type ArrTransport struct {
inner http.RoundTripper
auth AuthConfig
auth Authenticator
}
type AuthConfig struct {
Username string
Password string
ApiKey string
}
func NewArrTransport(auth AuthConfig, inner http.RoundTripper) *ArrTransport {
func NewArrTransport(auth Authenticator, inner http.RoundTripper) *ArrTransport {
return &ArrTransport{
inner: inner,
auth: auth,
@ -25,10 +26,10 @@ func NewArrTransport(auth AuthConfig, inner http.RoundTripper) *ArrTransport {
}
func (t *ArrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.auth.Username != "" && t.auth.Password != "" {
req.SetBasicAuth(t.auth.Username, t.auth.Password)
err := t.auth.Auth(req)
if err != nil {
return nil, fmt.Errorf("Error authenticating request: %w", err)
}
req.Header.Add("X-Api-Key", t.auth.ApiKey)
resp, err := t.inner.RoundTrip(req)
if err != nil || resp.StatusCode >= 500 {
@ -57,3 +58,84 @@ func (t *ArrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
return resp, nil
}
type ApiKeyAuth struct {
ApiKey string
}
func (a *ApiKeyAuth) Auth(req *http.Request) error {
req.Header.Add("X-Api-Key", a.ApiKey)
return nil
}
type BasicAuth struct {
Username string
Password string
ApiKey string
}
func (a *BasicAuth) Auth(req *http.Request) error {
req.SetBasicAuth(a.Username, a.Password)
req.Header.Add("X-Api-Key", a.ApiKey)
return nil
}
type FormAuth struct {
Username string
Password string
ApiKey string
AuthBaseURL *url.URL
Transport http.RoundTripper
cookie *http.Cookie
}
func (a *FormAuth) Auth(req *http.Request) error {
if a.cookie == nil || a.cookie.Expires.Before(time.Now().Add(-5*time.Minute)) {
form := url.Values{
"username": {a.Username},
"password": {a.Password},
"rememberMe": {"on"},
}
u := a.AuthBaseURL.JoinPath("login")
u.Query().Add("ReturnUrl", "/general/settings")
authReq, err := http.NewRequest("POST", u.String(), strings.NewReader(form.Encode()))
if err != nil {
return fmt.Errorf("Failed to renew FormAuth Cookie: %w", err)
}
authReq.Header.Add("Content-Type", "application/x-www-form-urlencoded")
authReq.Header.Add("Content-Length", fmt.Sprintf("%d", len(form.Encode())))
client := &http.Client{Transport: a.Transport, CheckRedirect: func(req *http.Request, via []*http.Request) error {
if req.URL.Query().Get("loginFailed") == "true" {
return fmt.Errorf("Failed to renew FormAuth Cookie: Login Failed")
}
return http.ErrUseLastResponse
}}
authResp, err := client.Do(authReq)
if err != nil {
return fmt.Errorf("Failed to renew FormAuth Cookie: %w", err)
}
if authResp.StatusCode != 302 {
return fmt.Errorf("Failed to renew FormAuth Cookie: Received Status Code %d", authResp.StatusCode)
}
for _, cookie := range authResp.Cookies() {
if strings.HasSuffix(cookie.Name, "arrAuth") {
copy := *cookie
a.cookie = &copy
break
}
return fmt.Errorf("Failed to renew FormAuth Cookie: No Cookie with suffix 'arrAuth' found")
}
}
req.AddCookie(a.cookie)
req.Header.Add("X-Api-Key", a.ApiKey)
return nil
}

View File

@ -4,7 +4,10 @@ import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/require"
)
@ -22,30 +25,30 @@ func (t testRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error)
}
func TestRoundTrip_Auth(t *testing.T) {
require := require.New(t)
parameters := []struct {
name string
auth AuthConfig
auth Authenticator
testFunc func(req *http.Request) (*http.Response, error)
}{
{
name: "BasicAuth",
auth: AuthConfig{
auth: &BasicAuth{
Username: TEST_USER,
Password: TEST_PASS,
ApiKey: TEST_KEY,
},
testFunc: func(req *http.Request) (*http.Response, error) {
require.NotNil(t, req, "Request should not be nil")
require.NotNil(t, req.Header, "Request header should not be nil")
require.NotEmpty(t, req.Header.Get("Authorization"), "Authorization header should be set")
require.NotNil(req, "Request should not be nil")
require.NotNil(req.Header, "Request header should not be nil")
require.NotEmpty(req.Header.Get("Authorization"), "Authorization header should be set")
require.Equal(
t,
"Basic "+base64.StdEncoding.EncodeToString([]byte(TEST_USER+":"+TEST_PASS)),
req.Header.Get("Authorization"),
"Authorization Header set to wrong value",
)
require.NotEmpty(t, req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
require.Equal(t, TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
require.NotEmpty(req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
require.Equal(TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
return &http.Response{
StatusCode: 200,
Body: nil,
@ -55,17 +58,15 @@ func TestRoundTrip_Auth(t *testing.T) {
},
{
name: "ApiKey",
auth: AuthConfig{
Username: "",
Password: "",
ApiKey: TEST_KEY,
auth: &ApiKeyAuth{
ApiKey: TEST_KEY,
},
testFunc: func(req *http.Request) (*http.Response, error) {
require.NotNil(t, req, "Request should not be nil")
require.NotNil(t, req.Header, "Request header should not be nil")
require.Empty(t, req.Header.Get("Authorization"), "Authorization header should be empty")
require.NotEmpty(t, req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
require.Equal(t, TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
require.NotNil(req, "Request should not be nil")
require.NotNil(req.Header, "Request header should not be nil")
require.Empty(req.Header.Get("Authorization"), "Authorization header should be empty")
require.NotEmpty(req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
require.Equal(TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
return &http.Response{
StatusCode: 200,
Body: nil,
@ -79,13 +80,89 @@ func TestRoundTrip_Auth(t *testing.T) {
transport := NewArrTransport(param.auth, testRoundTripFunc(param.testFunc))
client := &http.Client{Transport: transport}
req, err := http.NewRequest("GET", "http://example.com", nil)
require.Nil(t, err, "Error creating request: %s", err)
require.NoError(err, "Error creating request: %s", err)
_, err = client.Do(req)
require.Nil(t, err, "Error sending request: %s", err)
require.NoError(err, "Error sending request: %s", err)
})
}
}
func TestRoundTrip_FormAuth(t *testing.T) {
require := require.New(t)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NotNil(r, "Request should not be nil")
require.NotNil(r.Header, "Request header should not be nil")
require.Empty(r.Header.Get("Authorization"), "Authorization header should be empty")
require.Equal("POST", r.Method, "Request method should be POST")
require.Equal("/login", r.URL.Path, "Request URL should be /login")
require.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type"), "Content-Type should be application/x-www-form-urlencoded")
require.Equal(TEST_USER, r.FormValue("username"), "Username should be %s", TEST_USER)
require.Equal(TEST_PASS, r.FormValue("password"), "Password should be %s", TEST_PASS)
http.SetCookie(w, &http.Cookie{
Name: "RadarrAuth",
Value: "abcdef1234567890abcdef1234567890",
Expires: time.Now().Add(24 * time.Hour),
})
w.WriteHeader(http.StatusFound)
w.Write([]byte("OK"))
}))
defer ts.Close()
tsUrl, _ := url.Parse(ts.URL)
auth := &FormAuth{
Username: TEST_USER,
Password: TEST_PASS,
ApiKey: TEST_KEY,
AuthBaseURL: tsUrl,
Transport: http.DefaultTransport,
}
transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) {
require.NotNil(req, "Request should not be nil")
require.NotNil(req.Header, "Request header should not be nil")
cookie, err := req.Cookie("RadarrAuth")
require.NoError(err, "Cookie should be set")
require.Equal(cookie.Value, "abcdef1234567890abcdef1234567890", "Cookie should be set")
return &http.Response{
StatusCode: http.StatusOK,
Body: nil,
Header: make(http.Header),
}, nil
}))
client := &http.Client{Transport: transport}
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err, "Error creating request: %s", err)
_, err = client.Do(req)
require.NoError(err, "Error sending request: %s", err)
}
func TestRoundTrip_FormAuthFailure(t *testing.T) {
require := require.New(t)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/?loginFailed=true", http.StatusFound)
}))
u, _ := url.Parse(ts.URL)
auth := &FormAuth{
Username: TEST_USER,
Password: TEST_PASS,
ApiKey: TEST_KEY,
AuthBaseURL: u,
Transport: http.DefaultTransport,
}
transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: nil,
Header: make(http.Header),
}, nil
}))
client := &http.Client{Transport: transport}
req, err := http.NewRequest("GET", "http://example.com", nil)
require.NoError(err, "Error creating request: %s", err)
require.NotPanics(func() {
_, err = client.Do(req)
}, "Form Auth should not panic on auth failure")
require.Error(err, "Form Auth Transport should throw an error when auth fails")
}
func TestRoundTrip_Retries(t *testing.T) {
parameters := []struct {
name string
@ -111,7 +188,7 @@ func TestRoundTrip_Retries(t *testing.T) {
for _, param := range parameters {
t.Run(param.name, func(t *testing.T) {
require := require.New(t)
auth := AuthConfig{
auth := &ApiKeyAuth{
ApiKey: TEST_KEY,
}
attempts := 0
@ -121,9 +198,9 @@ func TestRoundTrip_Retries(t *testing.T) {
}))
client := &http.Client{Transport: transport}
req, err := http.NewRequest("GET", "http://example.com", nil)
require.Nil(err, "Error creating request: %s", err)
require.NoError(err, "Error creating request: %s", err)
_, err = client.Do(req)
require.NotNil(err, "Error should be returned from Do()")
require.Error(err, "Error should be returned from Do()")
require.Equal(3, attempts, "Should retry 3 times")
})
}
@ -134,7 +211,7 @@ func TestRoundTrip_StatusCodes(t *testing.T) {
for _, param := range parameters {
t.Run(fmt.Sprintf("%d", param), func(t *testing.T) {
require := require.New(t)
auth := AuthConfig{
auth := &ApiKeyAuth{
ApiKey: TEST_KEY,
}
transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) {
@ -149,9 +226,9 @@ func TestRoundTrip_StatusCodes(t *testing.T) {
require.Nil(err, "Error creating request: %s", err)
_, err = client.Do(req)
if param >= 200 && param < 300 {
require.Nil(err, "Should Not error on 2XX: %s", err)
require.NoError(err, "Should Not error on 2XX: %s", err)
} else {
require.NotNil(err, "Should error on non-2XX")
require.Error(err, "Should error on non-2XX")
}
})
}

View File

@ -9,7 +9,7 @@ type Config struct {
ApiKey string `xml:"ApiKey"`
Port string `xml:"Port"`
UrlBase string `xml:"UrlBase"`
ApiVersion string `ml:"ApiVersion"`
ApiVersion string `xml:"ApiVersion"`
}
func NewConfig() *Config {