diff --git a/src/lib/net/TCPSocket.cpp b/src/lib/net/TCPSocket.cpp index cf5a81d0..ce2cdd09 100644 --- a/src/lib/net/TCPSocket.cpp +++ b/src/lib/net/TCPSocket.cpp @@ -326,77 +326,39 @@ TCPSocket::init() TCPSocket::EJobResult TCPSocket::doRead() { - try { - static UInt8 buffer[4096]; - memset(buffer, 0, sizeof(buffer)); - int bytesRead = 0; - int status = 0; + UInt8 buffer[4096]; + memset(buffer, 0, sizeof(buffer)); + size_t bytesRead = 0; + + bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer)); + + if (bytesRead > 0) { + bool wasEmpty = (m_inputBuffer.getSize() == 0); - if (isSecure()) { - if (isSecureReady()) { - status = secureRead(buffer, sizeof(buffer), bytesRead); - if (status < 0) { - return kBreak; - } - else if (status == 0) { - return kNew; - } - } - else { - return kRetry; - } - } - else { - bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer)); - } + // slurp up as much as possible + do { + m_inputBuffer.write(buffer, bytesRead); + + bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); + } while (bytesRead > 0); - if (bytesRead > 0) { - bool wasEmpty = (m_inputBuffer.getSize() == 0); - - // slurp up as much as possible - do { - m_inputBuffer.write(buffer, bytesRead); - - if (isSecure() && isSecureReady()) { - status = secureRead(buffer, sizeof(buffer), bytesRead); - if (status < 0) { - return kBreak; - } - } - else { - bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer)); - } - - } while (bytesRead > 0 || status > 0); - - // send input ready if input buffer was empty - if (wasEmpty) { - sendEvent(m_events->forIStream().inputReady()); - } - } - else { - // remote write end of stream hungup. our input side - // has therefore shutdown but don't flush our buffer - // since there's still data to be read. - sendEvent(m_events->forIStream().inputShutdown()); - if (!m_writable && m_inputBuffer.getSize() == 0) { - sendEvent(m_events->forISocket().disconnected()); - m_connected = false; - } - m_readable = false; - return kNew; + // send input ready if input buffer was empty + if (wasEmpty) { + sendEvent(m_events->forIStream().inputReady()); } } - catch (XArchNetworkDisconnected&) { - // stream hungup - sendEvent(m_events->forISocket().disconnected()); - onDisconnected(); + else { + // remote write end of stream hungup. our input side + // has therefore shutdown but don't flush our buffer + // since there's still data to be read. + sendEvent(m_events->forIStream().inputShutdown()); + if (!m_writable && m_inputBuffer.getSize() == 0) { + sendEvent(m_events->forISocket().disconnected()); + m_connected = false; + } + m_readable = false; return kNew; } - catch (XArchNetwork& e) { - // ignore other read error - LOG((CLOG_WARN "error reading socket: %s", e.what())); - } return kRetry; } @@ -404,92 +366,16 @@ TCPSocket::doRead() TCPSocket::EJobResult TCPSocket::doWrite() { - static bool s_retry = false; - static int s_retrySize = 0; - static void* s_staticBuffer = NULL; - - try { - // write data - int bufferSize = 0; - int bytesWrote = 0; - int status = 0; - - if (s_retry) { - bufferSize = s_retrySize; - } - else { - bufferSize = m_outputBuffer.getSize(); - s_staticBuffer = malloc(bufferSize); - memcpy(s_staticBuffer, m_outputBuffer.peek(bufferSize), bufferSize); - } - - if (bufferSize == 0) { - return kRetry; - } - - if (isSecure()) { - if (isSecureReady()) { - status = secureWrite(s_staticBuffer, bufferSize, bytesWrote); - if (status > 0) { - s_retry = false; - bufferSize = 0; - free(s_staticBuffer); - s_staticBuffer = NULL; - } - else if (status < 0) { - return kBreak; - } - else if (status == 0) { - s_retry = true; - s_retrySize = bufferSize; - return kNew; - } - } - else { - return kRetry; - } - } - else { - bytesWrote = (UInt32)ARCH->writeSocket(m_socket, s_staticBuffer, bufferSize); - bufferSize = 0; - free(s_staticBuffer); - s_staticBuffer = NULL; - } - - // discard written data - if (bytesWrote > 0) { - m_outputBuffer.pop(bytesWrote); - if (m_outputBuffer.getSize() == 0) { - sendEvent(m_events->forIStream().outputFlushed()); - m_flushed = true; - m_flushed.broadcast(); - return kNew; - } - } - } - catch (XArchNetworkShutdown&) { - // remote read end of stream hungup. our output side - // has therefore shutdown. - onOutputShutdown(); - sendEvent(m_events->forIStream().outputShutdown()); - if (!m_readable && m_inputBuffer.getSize() == 0) { - sendEvent(m_events->forISocket().disconnected()); - m_connected = false; - } - return kNew; - } - catch (XArchNetworkDisconnected&) { - // stream hungup - onDisconnected(); - sendEvent(m_events->forISocket().disconnected()); - return kNew; - } - catch (XArchNetwork& e) { - // other write error - LOG((CLOG_WARN "error writing socket: %s", e.what())); - onDisconnected(); - sendEvent(m_events->forIStream().outputError()); - sendEvent(m_events->forISocket().disconnected()); + // write data + UInt32 bufferSize = 0; + int bytesWrote = 0; + + bufferSize = m_outputBuffer.getSize(); + const void* buffer = m_outputBuffer.peek(bufferSize); + bytesWrote = (UInt32)ARCH->writeSocket(m_socket, buffer, bufferSize); + + if (bytesWrote > 0) { + discardWrittenData(bytesWrote); return kNew; } @@ -550,6 +436,17 @@ TCPSocket::sendEvent(Event::Type type) m_events->addEvent(Event(type, getEventTarget(), NULL)); } +void +TCPSocket::discardWrittenData(int bytesWrote) +{ + m_outputBuffer.pop(bytesWrote); + if (m_outputBuffer.getSize() == 0) { + sendEvent(m_events->forIStream().outputFlushed()); + m_flushed = true; + m_flushed.broadcast(); + } +} + void TCPSocket::onConnected() { @@ -643,11 +540,50 @@ TCPSocket::serviceConnected(ISocketMultiplexerJob* job, EJobResult result = kRetry; if (write) { - result = doWrite(); + try { + result = doWrite(); + } + catch (XArchNetworkShutdown&) { + // remote read end of stream hungup. our output side + // has therefore shutdown. + onOutputShutdown(); + sendEvent(m_events->forIStream().outputShutdown()); + if (!m_readable && m_inputBuffer.getSize() == 0) { + sendEvent(m_events->forISocket().disconnected()); + m_connected = false; + } + result = kNew; + } + catch (XArchNetworkDisconnected&) { + // stream hungup + onDisconnected(); + sendEvent(m_events->forISocket().disconnected()); + result = kNew; + } + catch (XArchNetwork& e) { + // other write error + LOG((CLOG_WARN "error writing socket: %s", e.what())); + onDisconnected(); + sendEvent(m_events->forIStream().outputError()); + sendEvent(m_events->forISocket().disconnected()); + result = kNew; + } } if (read && m_readable) { - result = doRead(); + try { + result = doRead(); + } + catch (XArchNetworkDisconnected&) { + // stream hungup + sendEvent(m_events->forISocket().disconnected()); + onDisconnected(); + result = kNew; + } + catch (XArchNetwork& e) { + // ignore other read error + LOG((CLOG_WARN "error reading socket: %s", e.what())); + } } return result == kBreak ? NULL : result == kNew ? newJob() : job; diff --git a/src/lib/net/TCPSocket.h b/src/lib/net/TCPSocket.h index e4565385..181e7f30 100644 --- a/src/lib/net/TCPSocket.h +++ b/src/lib/net/TCPSocket.h @@ -89,6 +89,7 @@ protected: Mutex& getMutex() { return m_mutex; } void sendEvent(Event::Type); + void discardWrittenData(int bytesWrote); private: void init(); @@ -111,12 +112,12 @@ protected: bool m_writable; bool m_connected; IEventQueue* m_events; + StreamBuffer m_inputBuffer; + StreamBuffer m_outputBuffer; private: Mutex m_mutex; ArchSocket m_socket; - StreamBuffer m_inputBuffer; - StreamBuffer m_outputBuffer; CondVar m_flushed; SocketMultiplexer* m_socketMultiplexer; }; diff --git a/src/lib/plugin/ns/SecureSocket.cpp b/src/lib/plugin/ns/SecureSocket.cpp index 6e0e8f66..9f4df441 100644 --- a/src/lib/plugin/ns/SecureSocket.cpp +++ b/src/lib/plugin/ns/SecureSocket.cpp @@ -140,6 +140,115 @@ SecureSocket::secureAccept() getSocket(), isReadable(), isWritable())); } +TCPSocket::EJobResult +SecureSocket::doRead() +{ + static UInt8 buffer[4096]; + memset(buffer, 0, sizeof(buffer)); + int bytesRead = 0; + int status = 0; + + if (isSecureReady()) { + status = secureRead(buffer, sizeof(buffer), bytesRead); + if (status < 0) { + return kBreak; + } + else if (status == 0) { + return kNew; + } + } + else { + return kRetry; + } + + if (bytesRead > 0) { + bool wasEmpty = (m_inputBuffer.getSize() == 0); + + // slurp up as much as possible + do { + m_inputBuffer.write(buffer, bytesRead); + + status = secureRead(buffer, sizeof(buffer), bytesRead); + if (status < 0) { + return kBreak; + } + } while (bytesRead > 0 || status > 0); + + // send input ready if input buffer was empty + if (wasEmpty) { + sendEvent(m_events->forIStream().inputReady()); + } + } + else { + // remote write end of stream hungup. our input side + // has therefore shutdown but don't flush our buffer + // since there's still data to be read. + sendEvent(m_events->forIStream().inputShutdown()); + if (!m_writable && m_inputBuffer.getSize() == 0) { + sendEvent(m_events->forISocket().disconnected()); + m_connected = false; + } + m_readable = false; + return kNew; + } + + return kRetry; +} + +TCPSocket::EJobResult +SecureSocket::doWrite() +{ + static bool s_retry = false; + static int s_retrySize = 0; + static void* s_staticBuffer = NULL; + + // write data + int bufferSize = 0; + int bytesWrote = 0; + int status = 0; + + if (s_retry) { + bufferSize = s_retrySize; + } + else { + bufferSize = m_outputBuffer.getSize(); + s_staticBuffer = malloc(bufferSize); + memcpy(s_staticBuffer, m_outputBuffer.peek(bufferSize), bufferSize); + } + + if (bufferSize == 0) { + return kRetry; + } + + if (isSecureReady()) { + status = secureWrite(s_staticBuffer, bufferSize, bytesWrote); + if (status > 0) { + s_retry = false; + bufferSize = 0; + free(s_staticBuffer); + s_staticBuffer = NULL; + } + else if (status < 0) { + return kBreak; + } + else if (status == 0) { + s_retry = true; + s_retrySize = bufferSize; + return kNew; + } + } + else { + return kRetry; + } + + if (bytesWrote > 0) { + discardWrittenData(bytesWrote); + return kNew; + } + + return kRetry; +} + int SecureSocket::secureRead(void* buffer, int size, int& read) { diff --git a/src/lib/plugin/ns/SecureSocket.h b/src/lib/plugin/ns/SecureSocket.h index 5be68f7d..732ad399 100644 --- a/src/lib/plugin/ns/SecureSocket.h +++ b/src/lib/plugin/ns/SecureSocket.h @@ -48,6 +48,8 @@ public: newJob(); void secureConnect(); void secureAccept(); + EJobResult doRead(); + EJobResult doWrite(); bool isReady() const { return m_secureReady; } bool isFatal() const { return m_fatal; } void isFatal(bool b) { m_fatal = b; }