diff --git a/config.go b/config.go index d0613f3a..73c8e404 100644 --- a/config.go +++ b/config.go @@ -70,6 +70,7 @@ var ( defaultRPCKeyFile = filepath.Join(defaultHomeDir, "rpc.key") defaultRPCCertFile = filepath.Join(defaultHomeDir, "rpc.cert") defaultLogDir = filepath.Join(defaultHomeDir, defaultLogDirname) + defaultAltDNSNames = []string{} ) // runServiceCommand is only set to a real function on Windows. It is used @@ -166,6 +167,7 @@ type config struct { PipeRx uint `long:"piperx" description:"File descriptor of read end pipe to enable parent -> child process communication"` PipeTx uint `long:"pipetx" description:"File descriptor of write end pipe to enable parent <- child process communication"` LifetimeEvents bool `long:"lifetimeevents" description:"Send lifetime notifications over the TX pipe"` + AltDNSNames []string `long:"altdnsnames" description:"Specify additional dns names to use when generating the rpc server certificate" env:"DCRD_ALT_DNSNAMES" env-delim:","` onionlookup func(string) ([]net.IP, error) lookup func(string) ([]net.IP, error) oniondial func(string, string) (net.Conn, error) @@ -457,6 +459,7 @@ func loadConfig() (*config, []string, error) { AllowOldVotes: defaultAllowOldVotes, NoExistsAddrIndex: defaultNoExistsAddrIndex, NoCFilters: defaultNoCFilters, + AltDNSNames: defaultAltDNSNames, } // Service options which are only added on Windows. diff --git a/config_test.go b/config_test.go new file mode 100644 index 00000000..82c2482d --- /dev/null +++ b/config_test.go @@ -0,0 +1,68 @@ +// Copyright (c) 2018 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package main + +import ( + "flag" + "io/ioutil" + "os" + "strings" + "testing" +) + +// in order to test command line arguments and environment variables +// you will need to append the flags to the os.Args variable like so +// os.Args = append(os.Args, "--altdnsnames=\"hostname1,hostname2\"") +// For environment variables you can use the +// os.Setenv("DCRD_ALT_DNSNAMES", "hostname1,hostname2") to set the variable +// before loadConfig() is called +// These args and env variables will then get parsed by loadConfig() + +func setup() { + // Temp config file is used to ensure there are no external influences + // from previously set env variables or default config files. + file, _ := ioutil.TempFile("", "dcrd_test_file.cfg") + defer os.Remove(file.Name()) + + // Parse the -test.* flags before removing them from the command line + // arguments list, which we do to allow go-flags to succeed. + flag.Parse() + os.Args = os.Args[:1] +} + +func TestLoadConfig(t *testing.T) { + _, _, err := loadConfig() + if err != nil { + t.Errorf("Failed to load dcrd config: %s\n", err.Error()) + } +} + +func TestDefaultAltDNSNames(t *testing.T) { + cfg, _, _ := loadConfig() + if len(cfg.AltDNSNames) != 0 { + t.Errorf("Invalid default value for altdnsnames: %s\n", cfg.AltDNSNames) + } +} + +func TestAltDNSNamesWithEnv(t *testing.T) { + os.Setenv("DCRD_ALT_DNSNAMES", "hostname1,hostname2") + cfg, _, _ := loadConfig() + hostnames := strings.Join(cfg.AltDNSNames, ",") + if hostnames != "hostname1,hostname2" { + t.Errorf("altDNSNames should be %s but was %s", "hostname1,hostname2", hostnames) + } +} + +func TestAltDNSNamesWithArg(t *testing.T) { + setup() + old := os.Args + os.Args = append(os.Args, "--altdnsnames=\"hostname1,hostname2\"") + cfg, _, _ := loadConfig() + hostnames := strings.Join(cfg.AltDNSNames, ",") + if hostnames != "hostname1,hostname2" { + t.Errorf("altDNSNames should be %s but was %s", "hostname1,hostname2", hostnames) + } + os.Args = old +} diff --git a/rpcserver.go b/rpcserver.go index 463edd06..152eb0c8 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6343,13 +6343,13 @@ func (s *rpcServer) Start() { } // genCertPair generates a key/cert pair to the paths provided. -func genCertPair(certFile, keyFile string) error { +func genCertPair(certFile, keyFile string, altDNSNames []string) error { rpcsLog.Infof("Generating TLS certificates...") org := "dcrd autogenerated cert" validUntil := time.Now().Add(10 * 365 * 24 * time.Hour) cert, key, err := certgen.NewTLSCertPair(elliptic.P521(), org, - validUntil, nil) + validUntil, altDNSNames) if err != nil { return err } @@ -6401,7 +6401,7 @@ func newRPCServer(listenAddrs []string, generator *BlkTmplGenerator, s *server) // Generate the TLS cert and key file if both don't already // exist. if !fileExists(cfg.RPCKey) && !fileExists(cfg.RPCCert) { - err := genCertPair(cfg.RPCCert, cfg.RPCKey) + err := genCertPair(cfg.RPCCert, cfg.RPCKey, cfg.AltDNSNames) if err != nil { return nil, err } diff --git a/rpcserver_test.go b/rpcserver_test.go index af5660f8..20e759bb 100644 --- a/rpcserver_test.go +++ b/rpcserver_test.go @@ -10,7 +10,11 @@ package main import ( "bytes" + "crypto/x509" + "encoding/pem" + "flag" "fmt" + "io/ioutil" "os" "runtime/debug" "testing" @@ -106,6 +110,12 @@ var primaryHarness *rpctest.Harness func TestMain(m *testing.M) { var err error + // Parse the -test.* flags before removing them from the command line + // arguments list, which we do to allow go-flags to succeed. + // See config_test.go for more info + flag.Parse() + os.Args = os.Args[:1] + // In order to properly test scenarios on as if we were on mainnet, // ensure that non-standard transactions aren't accepted into the // mempool or relayed. @@ -165,3 +175,54 @@ func TestRpcServer(t *testing.T) { currentTestNum++ } } + +func TestCertCreationWithHosts(t *testing.T) { + certfile, err := ioutil.TempFile("", "certfile") + if err != nil { + t.Fatalf("Unable to create temp certfile: %s", err) + } + keyfile, err := ioutil.TempFile("", "keyfile") + if err != nil { + t.Fatalf("Unable to create temp keyfile: %s", err) + } + hostnames := []string{"hostname1", "hostname2"} + defer os.Remove(keyfile.Name()) + defer os.Remove(certfile.Name()) + err = genCertPair(certfile.Name(), keyfile.Name(), hostnames) + if err != nil { + t.Fatalf("certifcate was not created correctly: %s", err) + } + certBytes, err := ioutil.ReadFile(certfile.Name()) + if err != nil { + t.Fatalf("Unable to read the certfile: %s", err) + } + pemCert, _ := pem.Decode(certBytes) + x509Cert, err := x509.ParseCertificate(pemCert.Bytes) + if err != nil { + t.Fatalf("Unable to parse the certificate: %s", err) + } + // Ensure the specified extra hosts are present. + for _, host := range hostnames { + err := x509Cert.VerifyHostname(host) + if err != nil { + t.Fatalf("failed to verify extra host '%s'", host) + } + } +} + +func TestCertCreationWithOutHosts(t *testing.T) { + certfile, err := ioutil.TempFile("", "certfile") + if err != nil { + t.Fatalf("Unable to create temp certfile: %s", err) + } + keyfile, err := ioutil.TempFile("", "keyfile") + if err != nil { + t.Fatalf("Unable to create temp keyfile: %s", err) + } + defer os.Remove(keyfile.Name()) + defer os.Remove(certfile.Name()) + err = genCertPair(certfile.Name(), keyfile.Name(), []string{}) + if err != nil { + t.Fatalf("certifcate was not created correctly: %s", err) + } +}