diff --git a/tls.go b/tls.go index 951ac7e8708a0bf2dd51736bd13d55f209d43545..e8a541c0bba06e666f9b132f6713031cd8f9d784 100644 --- a/tls.go +++ b/tls.go @@ -35,33 +35,40 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) } if !settings.HasSetting(config.SocketPrivateKeyFile) && !settings.HasSetting(config.SocketCertificateFile) { - if allowSkipClientCerts { - tlsConfig = defaultTLSConfig() - tlsConfig.ServerName = serverName - tlsConfig.InsecureSkipVerify = insecureSkipVerify - setMinVersionExplicit(settings, tlsConfig) + if !allowSkipClientCerts { + return } - return - } - - privateKeyFile, err := settings.Setting(config.SocketPrivateKeyFile) - if err != nil { - return - } - - certificateFile, err := settings.Setting(config.SocketCertificateFile) - if err != nil { - return } tlsConfig = defaultTLSConfig() - tlsConfig.Certificates = make([]tls.Certificate, 1) tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify setMinVersionExplicit(settings, tlsConfig) - if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { - return + if settings.HasSetting(config.SocketPrivateKeyFile) || settings.HasSetting(config.SocketCertificateFile) { + + var privateKeyFile string + var certificateFile string + + privateKeyFile, err = settings.Setting(config.SocketPrivateKeyFile) + if err != nil { + return + } + + certificateFile, err = settings.Setting(config.SocketCertificateFile) + if err != nil { + return + } + + tlsConfig.Certificates = make([]tls.Certificate, 1) + + if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { + return + } + } + + if !allowSkipClientCerts { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert } if !settings.HasSetting(config.SocketCAFile) { @@ -86,7 +93,6 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) tlsConfig.RootCAs = certPool tlsConfig.ClientCAs = certPool - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert return } diff --git a/tls_test.go b/tls_test.go index fe6745a19197efe2c3a16ab6f5db0d09fede79c7..f1f17dfbb76bcafe55f6b8ab5780b2eb19145ed0 100644 --- a/tls_test.go +++ b/tls_test.go @@ -60,7 +60,7 @@ func (s *TLSTestSuite) TestLoadTLSNoCA() { s.Len(tlsConfig.Certificates, 1) s.Nil(tlsConfig.RootCAs) s.Nil(tlsConfig.ClientCAs) - s.Equal(tls.NoClientCert, tlsConfig.ClientAuth) + s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) } func (s *TLSTestSuite) TestLoadTLSWithBadCA() { @@ -87,6 +87,36 @@ func (s *TLSTestSuite) TestLoadTLSWithCA() { s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) } +func (s *TLSTestSuite) TestLoadTLSWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + s.settings.GlobalSettings().Set(config.SocketCAFile, s.CAFile) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.NotNil(tlsConfig.RootCAs) + s.NotNil(tlsConfig.ClientCAs) +} + +func (s *TLSTestSuite) TestLoadTLSWithoutSSLWithOnlyCA() { + s.settings.GlobalSettings().Set(config.SocketCAFile, s.CAFile) + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.Nil(tlsConfig) +} + +func (s *TLSTestSuite) TestLoadTLSAllowSkipClientCerts() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + + s.Equal(tls.NoClientCert, tlsConfig.ClientAuth) +} + func (s *TLSTestSuite) TestServerNameUseSSL() { s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameUseSSL")