diff --git a/cmd/exportarr/main.go b/cmd/exportarr/main.go index cb89414..5b51bdf 100644 --- a/cmd/exportarr/main.go +++ b/cmd/exportarr/main.go @@ -286,6 +286,9 @@ func validation(config *cli.Context) error { if config.String("url") == "" && !apiKeyIsSet && config.String("config") == "" { return cli.Exit("url and api-key or config must be set, not none of them", 1) } + if config.Bool("form-auth") && (config.String("auth-username") == "" || config.String("auth-password") == "") { + return cli.Exit("username and password must be set if form-auth is set", 1) + } return nil } @@ -343,16 +346,25 @@ func flags(arr string) []cli.Flag { EnvVars: []string{"DISABLE_SSL_VERIFY"}, }, &cli.StringFlag{ - Name: "basic-auth-username", - Usage: "Provide the username for basic auth", + Name: "auth-username", + Aliases: []string{"basic-auth-username"}, + Usage: "Provide the username for basic or form auth", Required: false, - EnvVars: []string{"BASIC_AUTH_USERNAME"}, + EnvVars: []string{"AUTH_USERNAME", "BASIC_AUTH_USERNAME"}, }, &cli.StringFlag{ - Name: "basic-auth-password", - Usage: "Provide the password for basic auth", + Name: "auth-password", + Aliases: []string{"basic-auth-password"}, + Usage: "Provide the password for basic or form auth", Required: false, - EnvVars: []string{"BASIC_AUTH_PASSWORD"}, + EnvVars: []string{"AUTH_PASSWORD", "BASIC_AUTH_PASSWORD"}, + }, + &cli.BoolFlag{ + Name: "form-auth", + Usage: "Use form authentication rather than basic auth", + Value: false, + Required: false, + EnvVars: []string{"FORM_AUTH"}, }, &cli.BoolFlag{ Name: "enable-unknown-queue-items", diff --git a/internal/client/client.go b/internal/client/client.go index c51fc32..bc98682 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -21,11 +21,8 @@ type Client struct { // NewClient method initializes a new Radarr client. func NewClient(c *cli.Context, cf *model.Config) (*Client, error) { + var apiKey string var baseURL *url.URL - auth := AuthConfig{ - Username: c.String("basic-auth-username"), - Password: c.String("basic-auth-password"), - } apiVersion := cf.ApiVersion @@ -35,8 +32,8 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) { if err != nil { return nil, fmt.Errorf("Couldn't parse URL: %w", err) } - baseURL = baseURL.JoinPath(cf.UrlBase, "api", apiVersion) - auth.ApiKey = cf.ApiKey + baseURL = baseURL.JoinPath(cf.UrlBase) + apiKey = cf.ApiKey } else { // Otherwise use the value provided in the api-key flag @@ -45,22 +42,45 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) { if err != nil { return nil, fmt.Errorf("Couldn't parse URL: %w", err) } - baseURL = baseURL.JoinPath("api", apiVersion) if c.String("api-key") != "" { - auth.ApiKey = c.String("api-key") + apiKey = c.String("api-key") } else if c.String("api-key-file") != "" { data, err := os.ReadFile(c.String("api-key-file")) if err != nil { return nil, fmt.Errorf("Couldn't Read API Key file %w", err) } - auth.ApiKey = string(data) + + apiKey = string(data) } } + baseTransport := http.DefaultTransport if c.Bool("disable-ssl-verify") { baseTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } + + var auth Authenticator + if c.Bool("form-auth") { + auth = &FormAuth{ + Username: c.String("auth-username"), + Password: c.String("auth-password"), + ApiKey: apiKey, + AuthBaseURL: baseURL, + Transport: baseTransport, + } + } else if c.String("username") != "" && c.String("password") != "" { + auth = &BasicAuth{ + Username: c.String("auth-username"), + Password: c.String("auth-password"), + ApiKey: apiKey, + } + } else { + auth = &ApiKeyAuth{ + ApiKey: apiKey, + } + } + return &Client{ httpClient: http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -68,7 +88,7 @@ func NewClient(c *cli.Context, cf *model.Config) (*Client, error) { }, Transport: NewArrTransport(auth, baseTransport), }, - URL: *baseURL, + URL: *baseURL.JoinPath("api", apiVersion), }, nil } diff --git a/internal/client/transport.go b/internal/client/transport.go index 1382a1b..0223761 100644 --- a/internal/client/transport.go +++ b/internal/client/transport.go @@ -3,21 +3,22 @@ package client import ( "fmt" "net/http" + "net/url" + "strings" + "time" ) +type Authenticator interface { + Auth(req *http.Request) error +} + // ArrTransport is a http.RoundTripper that adds authentication to requests type ArrTransport struct { inner http.RoundTripper - auth AuthConfig + auth Authenticator } -type AuthConfig struct { - Username string - Password string - ApiKey string -} - -func NewArrTransport(auth AuthConfig, inner http.RoundTripper) *ArrTransport { +func NewArrTransport(auth Authenticator, inner http.RoundTripper) *ArrTransport { return &ArrTransport{ inner: inner, auth: auth, @@ -25,10 +26,10 @@ func NewArrTransport(auth AuthConfig, inner http.RoundTripper) *ArrTransport { } func (t *ArrTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.auth.Username != "" && t.auth.Password != "" { - req.SetBasicAuth(t.auth.Username, t.auth.Password) + err := t.auth.Auth(req) + if err != nil { + return nil, fmt.Errorf("Error authenticating request: %w", err) } - req.Header.Add("X-Api-Key", t.auth.ApiKey) resp, err := t.inner.RoundTrip(req) if err != nil || resp.StatusCode >= 500 { @@ -57,3 +58,84 @@ func (t *ArrTransport) RoundTrip(req *http.Request) (*http.Response, error) { } return resp, nil } + +type ApiKeyAuth struct { + ApiKey string +} + +func (a *ApiKeyAuth) Auth(req *http.Request) error { + req.Header.Add("X-Api-Key", a.ApiKey) + return nil +} + +type BasicAuth struct { + Username string + Password string + ApiKey string +} + +func (a *BasicAuth) Auth(req *http.Request) error { + req.SetBasicAuth(a.Username, a.Password) + req.Header.Add("X-Api-Key", a.ApiKey) + return nil +} + +type FormAuth struct { + Username string + Password string + ApiKey string + AuthBaseURL *url.URL + Transport http.RoundTripper + cookie *http.Cookie +} + +func (a *FormAuth) Auth(req *http.Request) error { + if a.cookie == nil || a.cookie.Expires.Before(time.Now().Add(-5*time.Minute)) { + form := url.Values{ + "username": {a.Username}, + "password": {a.Password}, + "rememberMe": {"on"}, + } + + u := a.AuthBaseURL.JoinPath("login") + u.Query().Add("ReturnUrl", "/general/settings") + + authReq, err := http.NewRequest("POST", u.String(), strings.NewReader(form.Encode())) + if err != nil { + return fmt.Errorf("Failed to renew FormAuth Cookie: %w", err) + } + + authReq.Header.Add("Content-Type", "application/x-www-form-urlencoded") + authReq.Header.Add("Content-Length", fmt.Sprintf("%d", len(form.Encode()))) + + client := &http.Client{Transport: a.Transport, CheckRedirect: func(req *http.Request, via []*http.Request) error { + if req.URL.Query().Get("loginFailed") == "true" { + return fmt.Errorf("Failed to renew FormAuth Cookie: Login Failed") + } + return http.ErrUseLastResponse + }} + + authResp, err := client.Do(authReq) + if err != nil { + return fmt.Errorf("Failed to renew FormAuth Cookie: %w", err) + } + + if authResp.StatusCode != 302 { + return fmt.Errorf("Failed to renew FormAuth Cookie: Received Status Code %d", authResp.StatusCode) + } + + for _, cookie := range authResp.Cookies() { + if strings.HasSuffix(cookie.Name, "arrAuth") { + copy := *cookie + a.cookie = © + break + } + return fmt.Errorf("Failed to renew FormAuth Cookie: No Cookie with suffix 'arrAuth' found") + } + } + + req.AddCookie(a.cookie) + req.Header.Add("X-Api-Key", a.ApiKey) + + return nil +} diff --git a/internal/client/transport_test.go b/internal/client/transport_test.go index 63b2183..5ba964c 100644 --- a/internal/client/transport_test.go +++ b/internal/client/transport_test.go @@ -4,7 +4,10 @@ import ( "encoding/base64" "fmt" "net/http" + "net/http/httptest" + "net/url" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -22,30 +25,30 @@ func (t testRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) } func TestRoundTrip_Auth(t *testing.T) { + require := require.New(t) parameters := []struct { name string - auth AuthConfig + auth Authenticator testFunc func(req *http.Request) (*http.Response, error) }{ { name: "BasicAuth", - auth: AuthConfig{ + auth: &BasicAuth{ Username: TEST_USER, Password: TEST_PASS, ApiKey: TEST_KEY, }, testFunc: func(req *http.Request) (*http.Response, error) { - require.NotNil(t, req, "Request should not be nil") - require.NotNil(t, req.Header, "Request header should not be nil") - require.NotEmpty(t, req.Header.Get("Authorization"), "Authorization header should be set") + require.NotNil(req, "Request should not be nil") + require.NotNil(req.Header, "Request header should not be nil") + require.NotEmpty(req.Header.Get("Authorization"), "Authorization header should be set") require.Equal( - t, "Basic "+base64.StdEncoding.EncodeToString([]byte(TEST_USER+":"+TEST_PASS)), req.Header.Get("Authorization"), "Authorization Header set to wrong value", ) - require.NotEmpty(t, req.Header.Get("X-Api-Key"), "X-Api-Key header should be set") - require.Equal(t, TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value") + require.NotEmpty(req.Header.Get("X-Api-Key"), "X-Api-Key header should be set") + require.Equal(TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value") return &http.Response{ StatusCode: 200, Body: nil, @@ -55,17 +58,15 @@ func TestRoundTrip_Auth(t *testing.T) { }, { name: "ApiKey", - auth: AuthConfig{ - Username: "", - Password: "", - ApiKey: TEST_KEY, + auth: &ApiKeyAuth{ + ApiKey: TEST_KEY, }, testFunc: func(req *http.Request) (*http.Response, error) { - require.NotNil(t, req, "Request should not be nil") - require.NotNil(t, req.Header, "Request header should not be nil") - require.Empty(t, req.Header.Get("Authorization"), "Authorization header should be empty") - require.NotEmpty(t, req.Header.Get("X-Api-Key"), "X-Api-Key header should be set") - require.Equal(t, TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value") + require.NotNil(req, "Request should not be nil") + require.NotNil(req.Header, "Request header should not be nil") + require.Empty(req.Header.Get("Authorization"), "Authorization header should be empty") + require.NotEmpty(req.Header.Get("X-Api-Key"), "X-Api-Key header should be set") + require.Equal(TEST_KEY, req.Header.Get("X-Api-Key"), "X-Api-Key Header set to wrong value") return &http.Response{ StatusCode: 200, Body: nil, @@ -79,13 +80,89 @@ func TestRoundTrip_Auth(t *testing.T) { transport := NewArrTransport(param.auth, testRoundTripFunc(param.testFunc)) client := &http.Client{Transport: transport} req, err := http.NewRequest("GET", "http://example.com", nil) - require.Nil(t, err, "Error creating request: %s", err) + require.NoError(err, "Error creating request: %s", err) _, err = client.Do(req) - require.Nil(t, err, "Error sending request: %s", err) + require.NoError(err, "Error sending request: %s", err) }) } } +func TestRoundTrip_FormAuth(t *testing.T) { + require := require.New(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NotNil(r, "Request should not be nil") + require.NotNil(r.Header, "Request header should not be nil") + require.Empty(r.Header.Get("Authorization"), "Authorization header should be empty") + require.Equal("POST", r.Method, "Request method should be POST") + require.Equal("/login", r.URL.Path, "Request URL should be /login") + require.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type"), "Content-Type should be application/x-www-form-urlencoded") + require.Equal(TEST_USER, r.FormValue("username"), "Username should be %s", TEST_USER) + require.Equal(TEST_PASS, r.FormValue("password"), "Password should be %s", TEST_PASS) + http.SetCookie(w, &http.Cookie{ + Name: "RadarrAuth", + Value: "abcdef1234567890abcdef1234567890", + Expires: time.Now().Add(24 * time.Hour), + }) + w.WriteHeader(http.StatusFound) + w.Write([]byte("OK")) + })) + defer ts.Close() + tsUrl, _ := url.Parse(ts.URL) + auth := &FormAuth{ + Username: TEST_USER, + Password: TEST_PASS, + ApiKey: TEST_KEY, + AuthBaseURL: tsUrl, + Transport: http.DefaultTransport, + } + transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) { + require.NotNil(req, "Request should not be nil") + require.NotNil(req.Header, "Request header should not be nil") + cookie, err := req.Cookie("RadarrAuth") + require.NoError(err, "Cookie should be set") + require.Equal(cookie.Value, "abcdef1234567890abcdef1234567890", "Cookie should be set") + return &http.Response{ + StatusCode: http.StatusOK, + Body: nil, + Header: make(http.Header), + }, nil + })) + client := &http.Client{Transport: transport} + req, err := http.NewRequest("GET", "http://example.com", nil) + require.NoError(err, "Error creating request: %s", err) + _, err = client.Do(req) + require.NoError(err, "Error sending request: %s", err) +} + +func TestRoundTrip_FormAuthFailure(t *testing.T) { + require := require.New(t) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/?loginFailed=true", http.StatusFound) + })) + u, _ := url.Parse(ts.URL) + auth := &FormAuth{ + Username: TEST_USER, + Password: TEST_PASS, + ApiKey: TEST_KEY, + AuthBaseURL: u, + Transport: http.DefaultTransport, + } + transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: nil, + Header: make(http.Header), + }, nil + })) + client := &http.Client{Transport: transport} + req, err := http.NewRequest("GET", "http://example.com", nil) + require.NoError(err, "Error creating request: %s", err) + require.NotPanics(func() { + _, err = client.Do(req) + }, "Form Auth should not panic on auth failure") + require.Error(err, "Form Auth Transport should throw an error when auth fails") +} + func TestRoundTrip_Retries(t *testing.T) { parameters := []struct { name string @@ -111,7 +188,7 @@ func TestRoundTrip_Retries(t *testing.T) { for _, param := range parameters { t.Run(param.name, func(t *testing.T) { require := require.New(t) - auth := AuthConfig{ + auth := &ApiKeyAuth{ ApiKey: TEST_KEY, } attempts := 0 @@ -121,9 +198,9 @@ func TestRoundTrip_Retries(t *testing.T) { })) client := &http.Client{Transport: transport} req, err := http.NewRequest("GET", "http://example.com", nil) - require.Nil(err, "Error creating request: %s", err) + require.NoError(err, "Error creating request: %s", err) _, err = client.Do(req) - require.NotNil(err, "Error should be returned from Do()") + require.Error(err, "Error should be returned from Do()") require.Equal(3, attempts, "Should retry 3 times") }) } @@ -134,7 +211,7 @@ func TestRoundTrip_StatusCodes(t *testing.T) { for _, param := range parameters { t.Run(fmt.Sprintf("%d", param), func(t *testing.T) { require := require.New(t) - auth := AuthConfig{ + auth := &ApiKeyAuth{ ApiKey: TEST_KEY, } transport := NewArrTransport(auth, testRoundTripFunc(func(req *http.Request) (*http.Response, error) { @@ -149,9 +226,9 @@ func TestRoundTrip_StatusCodes(t *testing.T) { require.Nil(err, "Error creating request: %s", err) _, err = client.Do(req) if param >= 200 && param < 300 { - require.Nil(err, "Should Not error on 2XX: %s", err) + require.NoError(err, "Should Not error on 2XX: %s", err) } else { - require.NotNil(err, "Should error on non-2XX") + require.Error(err, "Should error on non-2XX") } }) } diff --git a/internal/model/config.go b/internal/model/config.go index e275629..59faabc 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -9,7 +9,7 @@ type Config struct { ApiKey string `xml:"ApiKey"` Port string `xml:"Port"` UrlBase string `xml:"UrlBase"` - ApiVersion string `ml:"ApiVersion"` + ApiVersion string `xml:"ApiVersion"` } func NewConfig() *Config {