// TCP transport for Librescue: length-prefixed message framing over stream
// sockets (TcpConnection) and a polling accept loop that runs on its own
// pthread (TcpServer).
//
// Wire format (TcpConnection::send / TcpConnection::receive): each message is
// a 4-byte big-endian payload length followed by exactly that many bytes.

#include "tcp.h"
#include "error.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <fcntl.h>
#include <poll.h>
#include <pthread.h>
#include <unistd.h>
#include <errno.h>
#include <limits.h>
#include <string.h>

#include <algorithm>
#include <exception>
#include <string>

using namespace std;

namespace Librescue {

namespace {

// Largest number of payload bytes handed to a single ::send() call.
const int kMaxTcpSend = 4000;

// Where the platform supports it, suppress SIGPIPE on writes to a peer that
// has gone away, so send() fails with EPIPE instead of killing the process.
#ifdef MSG_NOSIGNAL
const int kSendFlags = MSG_NOSIGNAL;
#else
const int kSendFlags = 0;
#endif

// Sets SO_LINGER on `fd` with l_onoff = 1 and l_linger = 0 (abortive close:
// unsent data is discarded when the socket is closed).
// Returns false, with errno left as set by setsockopt(), on failure.
bool enableAbortiveClose(int fd)
{
    struct linger option;
    option.l_onoff = 1;
    option.l_linger = 0;
    return setsockopt(fd, SOL_SOCKET, SO_LINGER, &option, sizeof(option)) == 0;
}

// Writes all `length` bytes starting at `data` to `fd` using send() calls of
// at most `maxChunk` bytes each. A single send() may write fewer bytes than
// requested, so this keeps going until everything is written; EINTR is
// retried. Returns false as soon as send() fails for any other reason.
bool sendAll(int fd, const Byte* data, int length, int maxChunk)
{
    int sent = 0;
    while (sent < length) {
        int chunk = min(maxChunk, length - sent);
        ssize_t written = ::send(fd, data + sent, chunk, kSendFlags);
        if (written < 0) {
            if (errno == EINTR)
                continue;
            return false;
        }
        sent += static_cast<int>(written);
    }
    return true;
}

// Decodes the 4-byte big-endian length prefix at `bytes`. Every byte is
// widened to unsigned int *before* shifting: a byte promoted to (signed) int
// and shifted left by 24 overflows when its top bit is set.
unsigned int decodeLength(const Byte* bytes)
{
    return (static_cast<unsigned int>(static_cast<unsigned char>(bytes[0])) << 24)
         | (static_cast<unsigned int>(static_cast<unsigned char>(bytes[1])) << 16)
         | (static_cast<unsigned int>(static_cast<unsigned char>(bytes[2])) << 8)
         |  static_cast<unsigned int>(static_cast<unsigned char>(bytes[3]));
}

} // namespace

// ---------------------------------------------------------------------------
// TcpConnection
// ---------------------------------------------------------------------------

// Opens a client connection to host:port. Throws std::string on failure.
TcpConnection::TcpConnection(const char* host, int port)
{
    Address address(host, port);
    init(&address);
}

// Opens a client connection to `target`. Throws std::string on failure.
TcpConnection::TcpConnection(const Address& target)
{
    init(&target);
}

// Wraps an already-connected descriptor (e.g. one returned by accept()).
// `from` is recorded as the peer address; no new socket is opened.
TcpConnection::TcpConnection(const Address& from, int socketDescriptor)
    : m_socketDescriptor(socketDescriptor)
{
    m_from = from;
    init(NULL);
}

TcpConnection::~TcpConnection()
{
    close();
}

// Shared constructor logic.
// If `target` is non-NULL, creates a socket, enables abortive close and
// connects to `target`, recording it as the peer address. Throws std::string
// if any of those steps fail. In every case it then records the local address
// of m_socketDescriptor in m_local.
void TcpConnection::init(const Address* target)
{
    if (target) {
        m_socketDescriptor = socket(PF_INET, SOCK_STREAM, 0);
        // Check the descriptor before using it for anything else.
        if (m_socketDescriptor == -1)
            throw string("Could not open TCP server socket");

        if (!enableAbortiveClose(m_socketDescriptor)) {
            LOG_ERROR("Could not set SO_LINGER: %s", strerror(errno));
            close();
            throw string("Could not set SO_LINGER");
        }

        sockaddr_in remote;
        memset(&remote, 0, sizeof(remote));
        target->fill(&remote);
        if (connect(m_socketDescriptor, reinterpret_cast<sockaddr*>(&remote), sizeof(remote))) {
            close();
            LOG_ERROR("Could not connect to server %s", target->toString());
            throw string("Error connecting to server");
        }
        m_from = *target;
    }

    sockaddr_in localStruct;
    memset(&localStruct, 0, sizeof(localStruct));
    socklen_t localLength = sizeof(localStruct);
    getsockname(m_socketDescriptor, reinterpret_cast<sockaddr*>(&localStruct), &localLength);
    Address local(localStruct);
    m_local = local;
}

// Reads one length-prefixed message and stores it in `result`.
// `result` is reset to an empty message if the header or body cannot be read
// (EOF, socket error, connection closed) or if the announced length does not
// fit in an int. A zero-length message also yields an empty result.
void TcpConnection::receive(Input& result)
{
    Bytes message;
    Byte header[4];
    if (!receiveData(header, 4)) {
        result.reset(message);
        return;
    }

    unsigned int length = decodeLength(header);

    // receiveData() counts in int. A larger length would turn negative,
    // skip the read entirely and hand back uninitialised bytes.
    if (length > static_cast<unsigned int>(INT_MAX)) {
        LOG_ERROR("Refusing %u-byte TCP message from %s", length, m_from.toString());
        result.reset(message);
        return;
    }

    if (length > 0) {
        // Read straight into the message; no temporary buffer or extra copy.
        message.resize(length);
        if (!receiveData(&message[0], static_cast<int>(length)))
            message.clear();
    }
    result.reset(message);
}

// Blocks until exactly `length` bytes have been read into `buffer`.
// Polls the socket in 100 ms slices with no overall deadline.
// Returns false if the connection is closed, on EOF, or on a poll/recv error.
bool TcpConnection::receiveData(Byte* buffer, int length)
{
    // poll() ignores negative descriptors, so without this check a closed
    // connection would spin here forever.
    if (m_socketDescriptor < 0)
        return false;

    int total = 0;
    while (total < length) {
        pollfd pollRequest;
        pollRequest.fd = m_socketDescriptor;
        pollRequest.events = POLLIN;
        pollRequest.revents = 0;

        int ready = poll(&pollRequest, 1, 100);
        if (ready < 0) {
            if (errno == EINTR)
                continue; // interrupted by a signal, not a socket failure
            return false;
        }
        if (ready == 0)
            continue; // timed out; keep waiting

        if (pollRequest.revents & POLLNVAL) {
            LOG_ERROR("Poll file descriptor not open");
            return false;
        }
        if (pollRequest.revents & POLLERR) {
            LOG_ERROR("Poll error");
            return false;
        }
        // Hang-up is only fatal once nothing is left to read. If POLLIN is
        // also set, drain the pending bytes first; recv() reports EOF with 0.
        if ((pollRequest.revents & POLLHUP) && !(pollRequest.revents & POLLIN)) {
            LOG_ERROR("Poll hangup");
            return false;
        }

        ssize_t amount = recv(m_socketDescriptor, buffer + total, length - total, 0);
        // poll() said the socket is readable, so 0 bytes can only mean EOF.
        if (amount == 0) {
            LOG_ERROR("Unexpected EOF from %s", m_from.toString());
            return false;
        }
        if (amount < 0) {
            if (errno == EINTR)
                continue;
            return false;
        }
        total += static_cast<int>(amount);
    }
    return true;
}

// Sends output.buffer() as one length-prefixed message (4-byte big-endian
// size, then the payload in chunks of at most kMaxTcpSend bytes).
// Returns false if any part of the message cannot be written.
bool TcpConnection::send(const Output& output)
{
    const Bytes& body = output.buffer();
    if (body.size() > static_cast<size_t>(INT_MAX))
        return false; // the length would not fit the 32-bit signed prefix
    int size = static_cast<int>(body.size());

    Byte header[4];
    header[0] = static_cast<Byte>((size >> 24) & 0xFF);
    header[1] = static_cast<Byte>((size >> 16) & 0xFF);
    header[2] = static_cast<Byte>((size >> 8) & 0xFF);
    header[3] = static_cast<Byte>(size & 0xFF);

    if (!sendAll(m_socketDescriptor, header, 4, 4))
        return false;
    if (size == 0)
        return true;
    return sendAll(m_socketDescriptor, &body[0], size, kMaxTcpSend);
}

// Returns true if poll() reports the socket readable within `timeout`
// milliseconds. Returns false on timeout, on error, or if the connection is
// closed.
bool TcpConnection::isDataAvailable(int timeout)
{
    pollfd pollRequest;
    pollRequest.fd = m_socketDescriptor;
    pollRequest.events = POLLIN;
    pollRequest.revents = 0;
    return poll(&pollRequest, 1, timeout) > 0;
}

// Shuts down and releases the socket. Safe to call more than once.
void TcpConnection::close()
{
    if (m_socketDescriptor >= 0) {
        // shutdown() may fail (e.g. ENOTCONN on a socket that never
        // connected); that is harmless here, because the descriptor itself
        // must be released regardless.
        shutdown(m_socketDescriptor, SHUT_RDWR);
        ::close(m_socketDescriptor);
        m_socketDescriptor = -1;
    }
}

bool TcpConnection::isOpen()
{
    return m_socketDescriptor >= 0;
}

const Address& TcpConnection::addressReceivedFrom() const
{
    return m_from;
}

// ---------------------------------------------------------------------------
// TcpServer
// ---------------------------------------------------------------------------

// Creates a socket bound to INADDR_ANY:listenPort with abortive close
// enabled and non-blocking mode set. Listening starts on the listener thread
// (see loop()). Throws std::string if the socket cannot be created, configured
// or bound; in those cases the socket is closed first.
TcpServer::TcpServer(int listenPort, TCPConnectionCallback& callback)
    : m_callback(callback)
{
    // No listener thread yet. stop() (and therefore the destructor) tests
    // this handle, so it must not be left indeterminate.
    m_listenThread = 0;
    running = false;

    m_socket = socket(PF_INET, SOCK_STREAM, 0);
    if (m_socket == -1) {
        logError("Could not open TCP server socket");
        throw string("Could not open TCP server socket");
    }

    if (!enableAbortiveClose(m_socket)) {
        logError("Could not set SO_LINGER on TCP server socket");
        ::close(m_socket);
        throw string("Could not set SO_LINGER");
    }

    sockaddr_in sockStruct;
    memset(&sockStruct, 0, sizeof(sockStruct));
    sockStruct.sin_family = AF_INET;
    sockStruct.sin_addr.s_addr = htonl(INADDR_ANY);
    sockStruct.sin_port = htons(listenPort);
    if (::bind(m_socket, reinterpret_cast<sockaddr*>(&sockStruct), sizeof(sockStruct))) {
        logError("Couldn't bind on requested port");
        ::close(m_socket);
        throw string("Error binding on desired port");
    }

    // Add O_NONBLOCK without clobbering any other file status flags.
    int flags = fcntl(m_socket, F_GETFL, 0);
    if (flags == -1 || fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == -1)
        logError("Could not make TCP server socket non-blocking");

    pthread_mutex_init(&m_runLock, NULL);

    socklen_t localLength = sizeof(sockStruct);
    getsockname(m_socket, reinterpret_cast<sockaddr*>(&sockStruct), &localLength);
    Address local(sockStruct);
    LOG_INFO("Listening for TCP connections on %s", local.toString());
}

// Stops and joins the listener thread, then releases the listening socket and
// the run-state mutex.
TcpServer::~TcpServer()
{
    stop();
    // close(), not shutdown(): shutdown() on a listening socket fails with
    // ENOTCONN and never releases the descriptor.
    if (::close(m_socket))
        logError("Error closing socket");
    pthread_mutex_destroy(&m_runLock);
}

// Marks the server as running and spawns the listener thread.
// Throws a C string (const char*) if the thread cannot be created.
void TcpServer::start()
{
    pthread_mutex_lock(&m_runLock);
    running = true;
    pthread_mutex_unlock(&m_runLock);

    if (pthread_create(&m_listenThread, NULL, &TcpServer::threadFunc, static_cast<void*>(this))) {
        // Roll back so that stop()/the destructor do not join a thread that
        // was never created.
        pthread_mutex_lock(&m_runLock);
        running = false;
        pthread_mutex_unlock(&m_runLock);
        m_listenThread = 0;
        logError("Couldn't create TCP server thread");
        // Thrown type (const char*) kept as-is for existing catch sites.
        throw "Could not start TCP server";
    }
}

// Clears the running flag and joins the listener thread, if one was started.
// loop() notices the flag within one 100 ms poll interval.
void TcpServer::stop()
{
    pthread_mutex_lock(&m_runLock);
    running = false;
    pthread_mutex_unlock(&m_runLock);

    if (m_listenThread) {
        pthread_join(m_listenThread, NULL);
        m_listenThread = 0;
    }
}

// pthread entry point: runs loop() on the TcpServer passed as `args`.
// Exceptions must not escape a thread entry point (that calls
// std::terminate), so they are logged here and the thread ends.
void* TcpServer::threadFunc(void* args)
{
    TcpServer* server = static_cast<TcpServer*>(args);
    try {
        server->loop();
    } catch (const string& error) {
        LOG_ERROR("TCP server thread stopped: %s", error.c_str());
    } catch (const std::exception& error) {
        LOG_ERROR("TCP server thread stopped: %s", error.what());
    } catch (const char* error) {
        LOG_ERROR("TCP server thread stopped: %s", error);
    }
    return server;
}

// Listener body: starts listening (backlog 100), then polls for pending
// connections every 100 ms until the running flag is cleared. Each accepted
// socket is wrapped in a heap-allocated TcpConnection and handed to the
// callback. This loop never deletes it; presumably the callback takes
// ownership (verify against TCPConnectionCallback implementations).
// Throws std::string if listen() fails.
void TcpServer::loop()
{
    if (listen(m_socket, 100)) {
        LOG_ERROR("TCP server couldn't listen: %s", strerror(errno));
        throw string("Couldn't start listening");
    }

    for (;;) {
        pthread_mutex_lock(&m_runLock);
        bool finish = !running;
        pthread_mutex_unlock(&m_runLock);
        if (finish)
            return;

        pollfd pollRequest;
        pollRequest.fd = m_socket;
        pollRequest.events = POLLIN;
        pollRequest.revents = 0;
        if (poll(&pollRequest, 1, 100) <= 0)
            continue; // nothing pending (or interrupted); re-check running

        sockaddr_in from;
        socklen_t fromLength = sizeof(from);
        int clientSocket = accept(m_socket, reinterpret_cast<sockaddr*>(&from), &fromLength);
        if (clientSocket < 0) {
            logError("Error accepting connection");
            continue;
        }

        Address fromAddress(from);
        TcpConnection* connection = new TcpConnection(fromAddress, clientSocket);
        m_callback.incomingTCPConnection(connection, fromAddress);
    }
}

} // namespace Librescue