diff --git a/src/lib/base/String.cpp b/src/lib/base/String.cpp index 4ce38899..4674deff 100644 --- a/src/lib/base/String.cpp +++ b/src/lib/base/String.cpp @@ -27,6 +27,9 @@ #include #include #include +#include +#include +#include namespace synergy { namespace string { @@ -180,6 +183,30 @@ removeFileExt(String filename) return filename.substr(0, dot); } +void +toHex(String& subject, int width, const char fill) +{ + std::stringstream ss; + ss << std::hex; + for (unsigned int i = 0; i < subject.length(); i++) { + ss << std::setw(width) << std::setfill(fill) << (int)(unsigned char)subject[i]; + } + + subject = ss.str(); +} + +void +uppercase(String& subject) +{ + std::transform(subject.begin(), subject.end(), subject.begin(), ::toupper); +} + +void +removeChar(String& subject, const char c) +{ + subject.erase(std::remove(subject.begin(), subject.end(), c), subject.end()); +} + // // CaselessCmp // diff --git a/src/lib/base/String.h b/src/lib/base/String.h index 9d137143..dbe5c21b 100644 --- a/src/lib/base/String.h +++ b/src/lib/base/String.h @@ -70,6 +70,25 @@ Finds the last dot and remove all characters from the dot to the end */ String removeFileExt(String filename); +//! Convert into hexdecimal +/*! +Convert each character in \c subject into hexdecimal form with \c width +*/ +void toHex(String& subject, int width, const char fill = '0'); + +//! Convert to all uppercase +/*! +Convert each character in \c subject to uppercase +*/ +void uppercase(String& subject); + +//! Remove all specific char in suject +/*! +Remove all specific \c char in \c suject +*/ +void removeChar(String& subject, const char c); + + //! Case-insensitive comparisons /*! This class provides case-insensitve comparison functions. diff --git a/src/lib/client/Client.cpp b/src/lib/client/Client.cpp index 456260f0..329733d4 100644 --- a/src/lib/client/Client.cpp +++ b/src/lib/client/Client.cpp @@ -60,8 +60,7 @@ Client::Client( const String& name, const NetworkAddress& address, ISocketFactory* socketFactory, synergy::Screen* screen, - bool enableDragDrop, - bool enableCrypto) : + ClientArgs args) : m_mock(false), m_name(name), m_serverAddress(address), @@ -77,9 +76,9 @@ Client::Client( m_events(events), m_sendFileThread(NULL), m_writeToDropDirThread(NULL), - m_enableDragDrop(enableDragDrop), m_socket(NULL), - m_useSecureNetwork(false) + m_useSecureNetwork(false), + m_args(args) { assert(m_socketFactory != NULL); assert(m_screen != NULL); @@ -94,7 +93,7 @@ Client::Client( new TMethodEventJob(this, &Client::handleResume)); - if (m_enableDragDrop) { + if (m_args.m_enableDragDrop) { m_events->adoptHandler(m_events->forIScreen().fileChunkSending(), this, new TMethodEventJob(this, @@ -105,7 +104,7 @@ Client::Client( &Client::handleFileRecieveCompleted)); } - if (enableCrypto) { + if (m_args.m_enableCrypto) { m_useSecureNetwork = ARCH->plugin().exists(s_networkSecurity); if (m_useSecureNetwork == false) { LOG((CLOG_NOTE "crypto disabled because of ns plugin not available")); @@ -162,6 +161,7 @@ Client::connect() // create the socket IDataSocket* socket = m_socketFactory->create(m_useSecureNetwork); m_socket = dynamic_cast(socket); + m_socket->setFingerprintFilename(m_args.m_certFingerprintFilename); // filter socket messages, including a packetizing filter m_stream = socket; @@ -780,7 +780,8 @@ Client::fileChunkReceived(String data) void Client::dragInfoReceived(UInt32 fileNum, String data) { - if (!m_enableDragDrop) { + // TODO: fix duplicate function from CServer + if (!m_args.m_enableDragDrop) { LOG((CLOG_DEBUG "drag drop not enabled, ignoring drag info.")); return; } diff --git a/src/lib/client/Client.h b/src/lib/client/Client.h index f869ef64..e50b567b 100644 --- a/src/lib/client/Client.h +++ b/src/lib/client/Client.h @@ -23,6 +23,7 @@ #include "synergy/IClipboard.h" #include "synergy/DragInformation.h" #include "synergy/INode.h" +#include "synergy/ClientArgs.h" #include "net/NetworkAddress.h" #include "base/EventTypes.h" @@ -59,8 +60,7 @@ public: const String& name, const NetworkAddress& address, ISocketFactory* socketFactory, synergy::Screen* screen, - bool enableDragDrop, - bool enableCrypto); + ClientArgs args); ~Client(); #ifdef TEST_ENV @@ -224,7 +224,7 @@ private: String m_dragFileExt; Thread* m_sendFileThread; Thread* m_writeToDropDirThread; - bool m_enableDragDrop; TCPSocket* m_socket; bool m_useSecureNetwork; + ClientArgs m_args; }; diff --git a/src/lib/mt/Thread.cpp b/src/lib/mt/Thread.cpp index dd45d872..f28bb6b9 100644 --- a/src/lib/mt/Thread.cpp +++ b/src/lib/mt/Thread.cpp @@ -18,7 +18,6 @@ #include "mt/Thread.h" -#include "net/XSocket.h" #include "mt/XMT.h" #include "mt/XThread.h" #include "arch/Arch.h" @@ -158,11 +157,6 @@ Thread::threadFunc(void* vjob) job->run(); LOG((CLOG_DEBUG1 "thread 0x%08x exit", id)); } - - catch (XSocket& e) { - // client called cancel() - LOG((CLOG_DEBUG "%s", e.what())); - } catch (XThreadCancel&) { // client called cancel() LOG((CLOG_DEBUG1 "caught cancel on thread 0x%08x", id)); diff --git a/src/lib/net/TCPListenSocket.cpp b/src/lib/net/TCPListenSocket.cpp index cdbd73c4..f0f06ee3 100644 --- a/src/lib/net/TCPListenSocket.cpp +++ b/src/lib/net/TCPListenSocket.cpp @@ -112,27 +112,35 @@ TCPListenSocket::accept() try { socket = new TCPSocket(m_events, m_socketMultiplexer, ARCH->acceptSocket(m_socket, NULL)); if (socket != NULL) { - m_socketMultiplexer->addSocket(this, - new TSocketMultiplexerMethodJob( - this, &TCPListenSocket::serviceListening, - m_socket, true, false)); + setListeningJob(); } return socket; } catch (XArchNetwork&) { if (socket != NULL) { delete socket; + setListeningJob(); } return NULL; } catch (std::exception &ex) { if (socket != NULL) { delete socket; + setListeningJob(); } throw ex; } } +void +TCPListenSocket::setListeningJob() +{ + m_socketMultiplexer->addSocket(this, + new TSocketMultiplexerMethodJob( + this, &TCPListenSocket::serviceListening, + m_socket, true, false)); +} + ISocketMultiplexerJob* TCPListenSocket::serviceListening(ISocketMultiplexerJob* job, bool read, bool, bool error) diff --git a/src/lib/net/TCPListenSocket.h b/src/lib/net/TCPListenSocket.h index ef2687db..c769e170 100644 --- a/src/lib/net/TCPListenSocket.h +++ b/src/lib/net/TCPListenSocket.h @@ -45,6 +45,9 @@ public: accept(); virtual void deleteSocket(void*) { } +protected: + void setListeningJob(); + public: ISocketMultiplexerJob* serviceListening(ISocketMultiplexerJob*, diff --git a/src/lib/net/TCPSocket.h b/src/lib/net/TCPSocket.h index 832cd128..2e75f357 100644 --- a/src/lib/net/TCPSocket.h +++ b/src/lib/net/TCPSocket.h @@ -59,6 +59,7 @@ public: virtual void secureConnect() {} virtual void secureAccept() {} + virtual void setFingerprintFilename(String& f) {} protected: ArchSocket getSocket() { return m_socket; } diff --git a/src/lib/plugin/ns/SecureListenSocket.cpp b/src/lib/plugin/ns/SecureListenSocket.cpp index a0df2519..562a6899 100644 --- a/src/lib/plugin/ns/SecureListenSocket.cpp +++ b/src/lib/plugin/ns/SecureListenSocket.cpp @@ -54,10 +54,13 @@ SecureListenSocket::accept() m_events, m_socketMultiplexer, ARCH->acceptSocket(m_socket, NULL)); - + socket->initSsl(true); m_secureSocketSet.insert(socket); - socket->initSsl(true); + if (socket != NULL) { + setListeningJob(); + } + // TODO: customized certificate path String certificateFilename = ARCH->getProfileDirectory(); #if SYSAPI_WIN32 @@ -67,26 +70,27 @@ SecureListenSocket::accept() #endif certificateFilename.append(s_certificateFilename); - socket->loadCertificates(certificateFilename.c_str()); + bool loaded = socket->loadCertificates(certificateFilename); + if (!loaded) { + delete socket; + return NULL; + } + socket->secureAccept(); - if (socket != NULL) { - m_socketMultiplexer->addSocket(this, - new TSocketMultiplexerMethodJob( - this, &TCPListenSocket::serviceListening, - m_socket, true, false)); - } return dynamic_cast(socket); } catch (XArchNetwork&) { if (socket != NULL) { delete socket; + setListeningJob(); } return NULL; } catch (std::exception &ex) { if (socket != NULL) { delete socket; + setListeningJob(); } throw ex; } diff --git a/src/lib/plugin/ns/SecureSocket.cpp b/src/lib/plugin/ns/SecureSocket.cpp index cb6db76f..7f78554c 100644 --- a/src/lib/plugin/ns/SecureSocket.cpp +++ b/src/lib/plugin/ns/SecureSocket.cpp @@ -28,6 +28,7 @@ #include #include #include +#include // // SecureSocket @@ -44,7 +45,8 @@ SecureSocket::SecureSocket( IEventQueue* events, SocketMultiplexer* socketMultiplexer) : TCPSocket(events, socketMultiplexer), - m_secureReady(false) + m_secureReady(false), + m_certFingerprintFilename() { } @@ -149,24 +151,46 @@ SecureSocket::initSsl(bool server) initContext(server); } -void -SecureSocket::loadCertificates(const char* filename) +bool +SecureSocket::loadCertificates(String& filename) { - int r = 0; - r = SSL_CTX_use_certificate_file(m_ssl->m_context, filename, SSL_FILETYPE_PEM); - if (r <= 0) { - throwError("could not use ssl certificate"); + if (filename.empty()) { + showError("ssl certificate is not specified"); + return false; + } + else { + std::ifstream file(filename.c_str()); + bool exist = file.good(); + file.close(); + + if (!exist) { + String errorMsg("ssl certificate doesn't exist: "); + errorMsg.append(filename); + showError(errorMsg.c_str()); + return false; + } } - r = SSL_CTX_use_PrivateKey_file(m_ssl->m_context, filename, SSL_FILETYPE_PEM); + int r = 0; + r = SSL_CTX_use_certificate_file(m_ssl->m_context, filename.c_str(), SSL_FILETYPE_PEM); if (r <= 0) { - throwError("could not use ssl private key"); + showError("could not use ssl certificate"); + return false; + } + + r = SSL_CTX_use_PrivateKey_file(m_ssl->m_context, filename.c_str(), SSL_FILETYPE_PEM); + if (r <= 0) { + showError("could not use ssl private key"); + return false; } r = SSL_CTX_check_private_key(m_ssl->m_context); if (!r) { - throwError("could not verify ssl private key"); + showError("could not verify ssl private key"); + return false; } + + return true; } void @@ -256,20 +280,29 @@ SecureSocket::secureConnect(int socket) // tell user and sleep so the socket isn't hammered. LOG((CLOG_ERR "failed to connect secure socket")); LOG((CLOG_INFO "server connection may not be secure")); - ARCH->sleep(1); + disconnect(); + return false; } m_secureReady = !retry; if (m_secureReady) { - LOG((CLOG_INFO "connected to secure socket")); - showCertificate(); + if (verifyCertFingerprint()) { + LOG((CLOG_INFO "connected to secure socket")); + if (!showCertificate()) { + disconnect(); + } + } + else { + LOG((CLOG_ERR "failed to verity server certificate fingerprint")); + disconnect(); + } } return retry; } -void +bool SecureSocket::showCertificate() { X509* cert; @@ -284,8 +317,11 @@ SecureSocket::showCertificate() X509_free(cert); } else { - throwError("server has no ssl certificate"); + showError("server has no ssl certificate"); + return false; } + + return true; } void @@ -346,30 +382,20 @@ SecureSocket::checkResult(int n, bool& fatal, bool& retry) if (fatal) { showError(); - sendEvent(getEvents()->forISocket().disconnected()); - sendEvent(getEvents()->forIStream().inputShutdown()); + disconnect(); } } void -SecureSocket::showError() +SecureSocket::showError(const char* reason) { - String error = getError(); - if (!error.empty()) { - LOG((CLOG_ERR "secure socket error: %s", error.c_str())); + if (reason != NULL) { + LOG((CLOG_ERR "%s", reason)); } -} -void -SecureSocket::throwError(const char* reason) -{ String error = getError(); if (!error.empty()) { - throw XSocket(synergy::string::sprintf( - "%s: %s", reason, error.c_str())); - } - else { - throw XSocket(reason); + LOG((CLOG_ERR "%s", error.c_str())); } } @@ -388,6 +414,76 @@ SecureSocket::getError() } } +void +SecureSocket::disconnect() +{ + sendEvent(getEvents()->forISocket().disconnected()); + sendEvent(getEvents()->forIStream().inputShutdown()); +} + +void +SecureSocket::formatFingerprint(String& fingerprint, bool hex, bool separator) +{ + if (hex) { + // to hexidecimal + synergy::string::toHex(fingerprint, 2); + } + + // all uppercase + synergy::string::uppercase(fingerprint); + + if (separator) { + // add colon to separate each 2 charactors + size_t separators = fingerprint.size() / 2; + for (size_t i = 1; i < separators; i++) { + fingerprint.insert(i * 3 - 1, ":"); + } + } +} + +bool +SecureSocket::verifyCertFingerprint() +{ + if (m_certFingerprintFilename.empty()) { + return false; + } + + // calculate received certificate fingerprint + X509 *cert = cert = SSL_get_peer_certificate(m_ssl->m_ssl); + EVP_MD* tempDigest; + unsigned char tempFingerprint[EVP_MAX_MD_SIZE]; + unsigned int tempFingerprintLen; + tempDigest = (EVP_MD*)EVP_sha1(); + if (X509_digest(cert, tempDigest, tempFingerprint, &tempFingerprintLen) <= 0) { + return false; + } + + // format fingerprint into hexdecimal format with colon separator + String fingerprint(reinterpret_cast(tempFingerprint), tempFingerprintLen); + formatFingerprint(fingerprint); + LOG((CLOG_NOTE "server fingerprint: %s", fingerprint.c_str())); + + // check if this fingerprint exist + String fileLine; + std::ifstream file; + file.open(m_certFingerprintFilename.c_str()); + + bool isValid = false; + while (!file.eof()) { + getline(file,fileLine); + // example of a fingerprint:A1:B2:C3 + if (!fileLine.empty()) { + if (fileLine.compare(fingerprint) == 0) { + isValid = true; + break; + } + } + } + + file.close(); + return isValid; +} + ISocketMultiplexerJob* SecureSocket::serviceConnect(ISocketMultiplexerJob* job, bool, bool write, bool error) diff --git a/src/lib/plugin/ns/SecureSocket.h b/src/lib/plugin/ns/SecureSocket.h index d703d2ca..1bc64628 100644 --- a/src/lib/plugin/ns/SecureSocket.h +++ b/src/lib/plugin/ns/SecureSocket.h @@ -43,13 +43,14 @@ public: void secureConnect(); void secureAccept(); + void setFingerprintFilename(String& f) { m_certFingerprintFilename = f; } bool isReady() const { return m_secureReady; } bool isSecureReady(); bool isSecure() { return true; } UInt32 secureRead(void* buffer, UInt32 n); UInt32 secureWrite(const void* buffer, UInt32 n); void initSsl(bool server); - void loadCertificates(const char* CertFile); + bool loadCertificates(String& CertFile); private: // SSL @@ -57,11 +58,15 @@ private: void createSSL(); bool secureAccept(int s); bool secureConnect(int s); - void showCertificate(); + bool showCertificate(); void checkResult(int n, bool& fatal, bool& retry); - void showError(); - void throwError(const char* reason); + void showError(const char* reason = NULL); String getError(); + void disconnect(); + void formatFingerprint(String& fingerprint, + bool hex = true, + bool separator = true); + bool verifyCertFingerprint(); ISocketMultiplexerJob* serviceConnect(ISocketMultiplexerJob*, @@ -74,4 +79,5 @@ private: private: Ssl* m_ssl; bool m_secureReady; + String m_certFingerprintFilename; }; diff --git a/src/lib/synergy/ArgParser.cpp b/src/lib/synergy/ArgParser.cpp index bec60632..82186b3a 100644 --- a/src/lib/synergy/ArgParser.cpp +++ b/src/lib/synergy/ArgParser.cpp @@ -89,6 +89,10 @@ ArgParser::parseClientArgs(ClientArgs& args, int argc, const char* const* argv) // define scroll args.m_yscroll = atoi(argv[++i]); } + else if (isArg(i, argc, argv, NULL, "--certificate-fingerprint", 1)) { + // define scroll + args.m_certFingerprintFilename = argv[++i]; + } else { if (i + 1 == argc) { args.m_synergyAddress = argv[i]; diff --git a/src/lib/synergy/ArgsBase.h b/src/lib/synergy/ArgsBase.h index 125f9b51..743202e8 100644 --- a/src/lib/synergy/ArgsBase.h +++ b/src/lib/synergy/ArgsBase.h @@ -24,25 +24,27 @@ class ArgsBase { public: ArgsBase(); virtual ~ArgsBase(); - bool m_daemon; - bool m_backend; - bool m_restartable; - bool m_noHooks; - const char* m_pname; - const char* m_logFilter; - const char* m_logFile; - const char* m_display; - String m_name; - bool m_disableTray; - bool m_enableIpc; - bool m_enableDragDrop; + +public: + bool m_daemon; + bool m_backend; + bool m_restartable; + bool m_noHooks; + const char* m_pname; + const char* m_logFilter; + const char* m_logFile; + const char* m_display; + String m_name; + bool m_disableTray; + bool m_enableIpc; + bool m_enableDragDrop; #if SYSAPI_WIN32 - bool m_debugServiceWait; - bool m_pauseOnExit; - bool m_stopOnDeskSwitch; + bool m_debugServiceWait; + bool m_pauseOnExit; + bool m_stopOnDeskSwitch; #endif #if WINAPI_XWINDOWS - bool m_disableXInitThreads; + bool m_disableXInitThreads; #endif bool m_shouldExit; String m_synergyAddress; diff --git a/src/lib/synergy/ClientApp.cpp b/src/lib/synergy/ClientApp.cpp index b9c00414..f27a8331 100644 --- a/src/lib/synergy/ClientApp.cpp +++ b/src/lib/synergy/ClientApp.cpp @@ -342,8 +342,7 @@ ClientApp::openClient(const String& name, const NetworkAddress& address, address, new TCPSocketFactory(m_events, getSocketMultiplexer()), screen, - args().m_enableDragDrop, - args().m_enableCrypto); + args()); try { m_events->adoptHandler( diff --git a/src/lib/synergy/ClientArgs.cpp b/src/lib/synergy/ClientArgs.cpp index c997f3cc..fff34817 100644 --- a/src/lib/synergy/ClientArgs.cpp +++ b/src/lib/synergy/ClientArgs.cpp @@ -18,6 +18,7 @@ #include "synergy/ClientArgs.h" ClientArgs::ClientArgs() : - m_yscroll(0) + m_yscroll(0), + m_certFingerprintFilename() { } diff --git a/src/lib/synergy/ClientArgs.h b/src/lib/synergy/ClientArgs.h index db749b3e..093b0ccf 100644 --- a/src/lib/synergy/ClientArgs.h +++ b/src/lib/synergy/ClientArgs.h @@ -26,5 +26,6 @@ public: ClientArgs(); public: - int m_yscroll; + int m_yscroll; + String m_certFingerprintFilename; }; diff --git a/src/test/integtests/net/NetworkTests.cpp b/src/test/integtests/net/NetworkTests.cpp index 7e3d3bd5..b307ff51 100644 --- a/src/test/integtests/net/NetworkTests.cpp +++ b/src/test/integtests/net/NetworkTests.cpp @@ -140,7 +140,11 @@ TEST_F(NetworkTests, sendToClient_mockData) ON_CALL(clientScreen, getShape(_, _, _, _)).WillByDefault(Invoke(getScreenShape)); ON_CALL(clientScreen, getCursorPos(_, _)).WillByDefault(Invoke(getCursorPos)); - Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, true, false); + + ClientArgs args; + args.m_enableDragDrop = true; + args.m_enableCrypto = false; + Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, args); m_events.adoptHandler( m_events.forIScreen().fileRecieveCompleted(), &client, @@ -192,7 +196,11 @@ TEST_F(NetworkTests, sendToClient_mockFile) ON_CALL(clientScreen, getShape(_, _, _, _)).WillByDefault(Invoke(getScreenShape)); ON_CALL(clientScreen, getCursorPos(_, _)).WillByDefault(Invoke(getCursorPos)); - Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, true, false); + + ClientArgs args; + args.m_enableDragDrop = true; + args.m_enableCrypto = false; + Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, args); m_events.adoptHandler( m_events.forIScreen().fileRecieveCompleted(), &client, @@ -238,7 +246,10 @@ TEST_F(NetworkTests, sendToServer_mockData) ON_CALL(clientScreen, getShape(_, _, _, _)).WillByDefault(Invoke(getScreenShape)); ON_CALL(clientScreen, getCursorPos(_, _)).WillByDefault(Invoke(getCursorPos)); - Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, true, false); + ClientArgs args; + args.m_enableDragDrop = true; + args.m_enableCrypto = false; + Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, args); m_events.adoptHandler( m_events.forClientListener().connected(), &listener, @@ -290,8 +301,11 @@ TEST_F(NetworkTests, sendToServer_mockFile) ON_CALL(clientScreen, getShape(_, _, _, _)).WillByDefault(Invoke(getScreenShape)); ON_CALL(clientScreen, getCursorPos(_, _)).WillByDefault(Invoke(getCursorPos)); - Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, true, false); - + ClientArgs args; + args.m_enableDragDrop = true; + args.m_enableCrypto = false; + Client client(&m_events, "stub", serverAddress, clientSocketFactory, &clientScreen, args); + m_events.adoptHandler( m_events.forClientListener().connected(), &listener, new TMethodEventJob( diff --git a/src/test/unittests/base/StringTests.cpp b/src/test/unittests/base/StringTests.cpp index e0108413..2a461f05 100644 --- a/src/test/unittests/base/StringTests.cpp +++ b/src/test/unittests/base/StringTests.cpp @@ -53,3 +53,32 @@ TEST(StringTests, sprintf) EXPECT_EQ("answer=42", result); } + +TEST(StringTests, toHex) +{ + String subject = "foobar"; + int width = 2; + + string::toHex(subject, width); + + EXPECT_EQ("666f6f626172", subject); +} + +TEST(StringTests, uppercase) +{ + String subject = "12foo3BaR"; + + string::uppercase(subject); + + EXPECT_EQ("12FOO3BAR", subject); +} + +TEST(StringTests, removeChar) +{ + String subject = "foobar"; + const char c = 'o'; + + string::removeChar(subject, c); + + EXPECT_EQ("fbar", subject); +}