mirror of
https://github.com/onedr0p/exportarr.git
synced 2026-02-06 10:57:32 +00:00
(1/3) Add Options for Auth via Form Data (#110)
Co-authored-by: Devin Buhl <onedr0p@users.noreply.github.com>
This commit is contained in:
parent
9237ee5d7c
commit
5834fa69fb
@ -286,6 +286,9 @@ func validation(config *cli.Context) error {
|
||||
if config.String("url") == "" && !apiKeyIsSet && config.String("config") == "" {
|
||||
return cli.Exit("url and api-key or config must be set, not none of them", 1)
|
||||
}
|
||||
if config.Bool("form-auth") && (config.String("auth-username") == "" || config.String("auth-password") == "") {
|
||||
return cli.Exit("username and password must be set if form-auth is set", 1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -343,16 +346,25 @@ func flags(arr string) []cli.Flag {
|
||||
EnvVars: []string{"DISABLE_SSL_VERIFY"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "basic-auth-username",
|
||||
Usage: "Provide the username for basic auth",
|
||||
Name: "auth-username",
|
||||
Aliases: []string{"basic-auth-username"},
|
||||
Usage: "Provide the username for basic or form auth",
|
||||
Required: false,
|
||||
EnvVars: []string{"BASIC_AUTH_USERNAME"},
|
||||
EnvVars: []string{"AUTH_USERNAME", "BASIC_AUTH_USERNAME"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "basic-auth-password",
|
||||
Usage: "Provide the password for basic auth",
|
||||
Name: "auth-password",
|
||||
Aliases: []string{"basic-auth-password"},
|
||||
Usage: "Provide the password for basic or form auth",
|
||||
Required: false,
|
||||
EnvVars: []string{"BASIC_AUTH_PASSWORD"},
|
||||
EnvVars: []string{"AUTH_PASSWORD", "BASIC_AUTH_PASSWORD"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "form-auth",
|
||||
Usage: "Use form authentication rather than basic auth",
|
||||
Value: false,
|
||||
Required: false,
|
||||
EnvVars: []string{"FORM_AUTH"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "enable-unknown-queue-items",
|
||||
|
||||
@ -21,11 +21,8 @@ type Client struct {
|
||||
|
||||
// NewClient method initializes a new Radarr client.
|
||||
func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
|
||||
var apiKey string
|
||||
var baseURL *url.URL
|
||||
auth := AuthConfig{
|
||||
Username: c.String("basic-auth-username"),
|
||||
Password: c.String("basic-auth-password"),
|
||||
}
|
||||
|
||||
apiVersion := cf.ApiVersion
|
||||
|
||||
@ -35,8 +32,8 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Couldn't parse URL: %w", err)
|
||||
}
|
||||
baseURL = baseURL.JoinPath(cf.UrlBase, "api", apiVersion)
|
||||
auth.ApiKey = cf.ApiKey
|
||||
baseURL = baseURL.JoinPath(cf.UrlBase)
|
||||
apiKey = cf.ApiKey
|
||||
|
||||
} else {
|
||||
// Otherwise use the value provided in the api-key flag
|
||||
@ -45,22 +42,45 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Couldn't parse URL: %w", err)
|
||||
}
|
||||
baseURL = baseURL.JoinPath("api", apiVersion)
|
||||
|
||||
if c.String("api-key") != "" {
|
||||
auth.ApiKey = c.String("api-key")
|
||||
apiKey = c.String("api-key")
|
||||
} else if c.String("api-key-file") != "" {
|
||||
data, err := os.ReadFile(c.String("api-key-file"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Couldn't Read API Key file %w", err)
|
||||
}
|
||||
auth.ApiKey = string(data)
|
||||
|
||||
apiKey = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
baseTransport := http.DefaultTransport
|
||||
if c.Bool("disable-ssl-verify") {
|
||||
baseTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
var auth Authenticator
|
||||
if c.Bool("form-auth") {
|
||||
auth = &FormAuth{
|
||||
Username: c.String("auth-username"),
|
||||
Password: c.String("auth-password"),
|
||||
ApiKey: apiKey,
|
||||
AuthBaseURL: baseURL,
|
||||
Transport: baseTransport,
|
||||
}
|
||||
} else if c.String("username") != "" && c.String("password") != "" {
|
||||
auth = &BasicAuth{
|
||||
Username: c.String("auth-username"),
|
||||
Password: c.String("auth-password"),
|
||||
ApiKey: apiKey,
|
||||
}
|
||||
} else {
|
||||
auth = &ApiKeyAuth{
|
||||
ApiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
@ -68,7 +88,7 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) {
|
||||
},
|
||||
Transport: NewArrTransport(auth, baseTransport),
|
||||
},
|
||||
URL: *baseURL,
|
||||
URL: *baseURL.JoinPath("api", apiVersion),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@ -3,21 +3,22 @@ package client
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Authenticator interface {
|
||||
Auth(req *http.Request) error
|
||||
}
|
||||
|
||||
// ArrTransport is a http.RoundTripper that adds authentication to requests
|
||||
type ArrTransport struct {
|
||||
inner http.RoundTripper
|
||||
auth AuthConfig
|
||||
auth Authenticator
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
func NewArrTransport(auth AuthConfig, inner http.RoundTripper) *ArrTransport {
|
||||
func NewArrTransport(auth Authenticator, inner http.RoundTripper) *ArrTransport {
|
||||
return &ArrTransport{
|
||||
inner: inner,
|
||||
auth: auth,
|
||||
@ -25,10 +26,10 @@ func NewArrTransport(auth AuthConfig, inner http.RoundTripper) *ArrTransport {
|
||||
}
|
||||
|
||||
func (t *ArrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if t.auth.Username != "" && t.auth.Password != "" {
|
||||
req.SetBasicAuth(t.auth.Username, t.auth.Password)
|
||||
err := t.auth.Auth(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error authenticating request: %w", err)
|
||||
}
|
||||
req.Header.Add("X-Api-Key", t.auth.ApiKey)
|
||||
|
||||
resp, err := t.inner.RoundTrip(req)
|
||||
if err != nil || resp.StatusCode >= 500 {
|
||||
@ -57,3 +58,84 @@ func (t *ArrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type ApiKeyAuth struct {
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
func (a *ApiKeyAuth) Auth(req *http.Request) error {
|
||||
req.Header.Add("X-Api-Key", a.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
type BasicAuth struct {
|
||||
Username string
|
||||
Password string
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
func (a *BasicAuth) Auth(req *http.Request) error {
|
||||
req.SetBasicAuth(a.Username, a.Password)
|
||||
req.Header.Add("X-Api-Key", a.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
type FormAuth struct {
|
||||
Username string
|
||||
Password string
|
||||
ApiKey string
|
||||
AuthBaseURL *url.URL
|
||||
Transport http.RoundTripper
|
||||
cookie *http.Cookie
|
||||
}
|
||||
|
||||
func (a *FormAuth) Auth(req *http.Request) error {
|
||||
if a.cookie == nil || a.cookie.Expires.Before(time.Now().Add(-5*time.Minute)) {
|
||||
form := url.Values{
|
||||
"username": {a.Username},
|
||||
"password": {a.Password},
|
||||
"rememberMe": {"on"},
|
||||
}
|
||||
|
||||
u := a.AuthBaseURL.JoinPath("login")
|
||||
u.Query().Add("ReturnUrl", "/general/settings")
|
||||
|
||||
authReq, err := http.NewRequest("POST", u.String(), strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to renew FormAuth Cookie: %w", err)
|
||||
}
|
||||
|
||||
authReq.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
authReq.Header.Add("Content-Length", fmt.Sprintf("%d", len(form.Encode())))
|
||||
|
||||
client := &http.Client{Transport: a.Transport, CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if req.URL.Query().Get("loginFailed") == "true" {
|
||||
return fmt.Errorf("Failed to renew FormAuth Cookie: Login Failed")
|
||||
}
|
||||
return http.ErrUseLastResponse
|
||||
}}
|
||||
|
||||
authResp, err := client.Do(authReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to renew FormAuth Cookie: %w", err)
|
||||
}
|
||||
|
||||
if authResp.StatusCode != 302 {
|
||||
return fmt.Errorf("Failed to renew FormAuth Cookie: Received Status Code %d", authResp.StatusCode)
|
||||
}
|
||||
|
||||
for _, cookie := range authResp.Cookies() {
|
||||
if strings.HasSuffix(cookie.Name, "arrAuth") {
|
||||
copy := *cookie
|
||||
a.cookie = ©
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("Failed to renew FormAuth Cookie: No Cookie with suffix 'arrAuth' found")
|
||||
}
|
||||
}
|
||||
|
||||
req.AddCookie(a.cookie)
|
||||
req.Header.Add("X-Api-Key", a.ApiKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -4,7 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -22,30 +25,30 @@ func (t testRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func TestRoundTrip_Auth(t *testing.T) {
|
||||
require := require.New(t)
|
||||
parameters := []struct {
|
||||
name string
|
||||
auth AuthConfig
|
||||
auth Authenticator
|
||||
testFunc func(req *http.Request) (*http.Response, error)
|
||||
}{
|
||||
{
|
||||
name: "BasicAuth",
|
||||
auth: AuthConfig{
|
||||
auth: &BasicAuth{
|
||||
Username: TEST_USER,
|
||||
Password: TEST_PASS,
|
||||
ApiKey: TEST_KEY,
|
||||
},
|
||||
testFunc: func(req *http.Request) (*http.Response, error) {
|
||||
require.NotNil(t, req, "Request should not be nil")
|
||||
require.NotNil(t, req.Header, "Request header should not be nil")
|
||||
require.NotEmpty(t, req.Header.Get("Authorization"), "Authorization header should be set")
|
||||
require.NotNil(req, "Request should not be nil")
|
||||
require.NotNil(req.Header, "Request header should not be nil")
|
||||
require.NotEmpty(req.Header.Get("Authorization"), "Authorization header should be set")
|
||||
require.Equal(
|
||||
t,
|
||||
"Basic "+base64.StdEncoding.EncodeToString([]byte(TEST_USER+":"+TEST_PASS)),
|
||||
req.Header.Get("Authorization"),
|
||||
"Authorization Header set to wrong value",
|
||||
)
|
||||
require.NotEmpty(t, req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
|
||||
require.Equal(t, TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
|
||||
require.NotEmpty(req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
|
||||
require.Equal(TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: nil,
|
||||
@ -55,17 +58,15 @@ func TestRoundTrip_Auth(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "ApiKey",
|
||||
auth: AuthConfig{
|
||||
Username: "",
|
||||
Password: "",
|
||||
ApiKey: TEST_KEY,
|
||||
auth: &ApiKeyAuth{
|
||||
ApiKey: TEST_KEY,
|
||||
},
|
||||
testFunc: func(req *http.Request) (*http.Response, error) {
|
||||
require.NotNil(t, req, "Request should not be nil")
|
||||
require.NotNil(t, req.Header, "Request header should not be nil")
|
||||
require.Empty(t, req.Header.Get("Authorization"), "Authorization header should be empty")
|
||||
require.NotEmpty(t, req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
|
||||
require.Equal(t, TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
|
||||
require.NotNil(req, "Request should not be nil")
|
||||
require.NotNil(req.Header, "Request header should not be nil")
|
||||
require.Empty(req.Header.Get("Authorization"), "Authorization header should be empty")
|
||||
require.NotEmpty(req.Header.Get("X-Api-Key"), "X-Api-Key header should be set")
|
||||
require.Equal(TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value")
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: nil,
|
||||
@ -79,13 +80,89 @@ func TestRoundTrip_Auth(t *testing.T) {
|
||||
transport := NewArrTransport(param.auth, testRoundTripFunc(param.testFunc))
|
||||
client := &http.Client{Transport: transport}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
require.Nil(t, err, "Error creating request: %s", err)
|
||||
require.NoError(err, "Error creating request: %s", err)
|
||||
_, err = client.Do(req)
|
||||
require.Nil(t, err, "Error sending request: %s", err)
|
||||
require.NoError(err, "Error sending request: %s", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTrip_FormAuth(t *testing.T) {
|
||||
require := require.New(t)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NotNil(r, "Request should not be nil")
|
||||
require.NotNil(r.Header, "Request header should not be nil")
|
||||
require.Empty(r.Header.Get("Authorization"), "Authorization header should be empty")
|
||||
require.Equal("POST", r.Method, "Request method should be POST")
|
||||
require.Equal("/login", r.URL.Path, "Request URL should be /login")
|
||||
require.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type"), "Content-Type should be application/x-www-form-urlencoded")
|
||||
require.Equal(TEST_USER, r.FormValue("username"), "Username should be %s", TEST_USER)
|
||||
require.Equal(TEST_PASS, r.FormValue("password"), "Password should be %s", TEST_PASS)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "RadarrAuth",
|
||||
Value: "abcdef1234567890abcdef1234567890",
|
||||
Expires: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
w.WriteHeader(http.StatusFound)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
tsUrl, _ := url.Parse(ts.URL)
|
||||
auth := &FormAuth{
|
||||
Username: TEST_USER,
|
||||
Password: TEST_PASS,
|
||||
ApiKey: TEST_KEY,
|
||||
AuthBaseURL: tsUrl,
|
||||
Transport: http.DefaultTransport,
|
||||
}
|
||||
transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
require.NotNil(req, "Request should not be nil")
|
||||
require.NotNil(req.Header, "Request header should not be nil")
|
||||
cookie, err := req.Cookie("RadarrAuth")
|
||||
require.NoError(err, "Cookie should be set")
|
||||
require.Equal(cookie.Value, "abcdef1234567890abcdef1234567890", "Cookie should be set")
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}))
|
||||
client := &http.Client{Transport: transport}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
require.NoError(err, "Error creating request: %s", err)
|
||||
_, err = client.Do(req)
|
||||
require.NoError(err, "Error sending request: %s", err)
|
||||
}
|
||||
|
||||
func TestRoundTrip_FormAuthFailure(t *testing.T) {
|
||||
require := require.New(t)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/?loginFailed=true", http.StatusFound)
|
||||
}))
|
||||
u, _ := url.Parse(ts.URL)
|
||||
auth := &FormAuth{
|
||||
Username: TEST_USER,
|
||||
Password: TEST_PASS,
|
||||
ApiKey: TEST_KEY,
|
||||
AuthBaseURL: u,
|
||||
Transport: http.DefaultTransport,
|
||||
}
|
||||
transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}))
|
||||
client := &http.Client{Transport: transport}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
require.NoError(err, "Error creating request: %s", err)
|
||||
require.NotPanics(func() {
|
||||
_, err = client.Do(req)
|
||||
}, "Form Auth should not panic on auth failure")
|
||||
require.Error(err, "Form Auth Transport should throw an error when auth fails")
|
||||
}
|
||||
|
||||
func TestRoundTrip_Retries(t *testing.T) {
|
||||
parameters := []struct {
|
||||
name string
|
||||
@ -111,7 +188,7 @@ func TestRoundTrip_Retries(t *testing.T) {
|
||||
for _, param := range parameters {
|
||||
t.Run(param.name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
auth := AuthConfig{
|
||||
auth := &ApiKeyAuth{
|
||||
ApiKey: TEST_KEY,
|
||||
}
|
||||
attempts := 0
|
||||
@ -121,9 +198,9 @@ func TestRoundTrip_Retries(t *testing.T) {
|
||||
}))
|
||||
client := &http.Client{Transport: transport}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
require.Nil(err, "Error creating request: %s", err)
|
||||
require.NoError(err, "Error creating request: %s", err)
|
||||
_, err = client.Do(req)
|
||||
require.NotNil(err, "Error should be returned from Do()")
|
||||
require.Error(err, "Error should be returned from Do()")
|
||||
require.Equal(3, attempts, "Should retry 3 times")
|
||||
})
|
||||
}
|
||||
@ -134,7 +211,7 @@ func TestRoundTrip_StatusCodes(t *testing.T) {
|
||||
for _, param := range parameters {
|
||||
t.Run(fmt.Sprintf("%d", param), func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
auth := AuthConfig{
|
||||
auth := &ApiKeyAuth{
|
||||
ApiKey: TEST_KEY,
|
||||
}
|
||||
transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
@ -149,9 +226,9 @@ func TestRoundTrip_StatusCodes(t *testing.T) {
|
||||
require.Nil(err, "Error creating request: %s", err)
|
||||
_, err = client.Do(req)
|
||||
if param >= 200 && param < 300 {
|
||||
require.Nil(err, "Should Not error on 2XX: %s", err)
|
||||
require.NoError(err, "Should Not error on 2XX: %s", err)
|
||||
} else {
|
||||
require.NotNil(err, "Should error on non-2XX")
|
||||
require.Error(err, "Should error on non-2XX")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ type Config struct {
|
||||
ApiKey string `xml:"ApiKey"`
|
||||
Port string `xml:"Port"`
|
||||
UrlBase string `xml:"UrlBase"`
|
||||
ApiVersion string `ml:"ApiVersion"`
|
||||
ApiVersion string `xml:"ApiVersion"`
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user