diff --git a/luprex/core/cpp/driver-linux.cpp b/luprex/core/cpp/driver-linux.cpp index b6701174..e65239f9 100644 --- a/luprex/core/cpp/driver-linux.cpp +++ b/luprex/core/cpp/driver-linux.cpp @@ -325,25 +325,47 @@ public: } void handle_socket_input_output(int mstimeout) { - fd_set rfds, wfds, efds; - int nbytes; const char *bytes; - int nfds = calc_select_sets(rfds, wfds, efds); - struct timeval timeout; - timeout.tv_sec = mstimeout / 1000; - timeout.tv_usec = (mstimeout - (timeout.tv_sec*1000)) * 1000; - int status = select(nfds, &rfds, &wfds, &efds, &timeout); + // Construct the pollfd vector. + std::vector pollvec; + pollvec.resize(listen_sockets_.size() + chans_.size() + 1); + int index = 0; + for (const auto &p : listen_sockets_) { + struct pollfd &pfd = pollvec[index++]; + pfd.fd = p.second; + pfd.events = POLLIN; + } + for (const ChanInfo &chan : chans_) { + struct pollfd &pfd = pollvec[index++]; + pfd.fd = chan.socket; + pfd.events = POLLIN; + if (!driven_->drv_outgoing_empty(chan.chid)) { + pfd.events |= POLLOUT; + } + } + struct pollfd &stdiopoll = pollvec[index++]; + stdiopoll.fd = 0; + stdiopoll.events = POLLIN; + + // Do the poll. + int status = poll(&pollvec[0], pollvec.size(), mstimeout); assert(status >= 0); + // Check listening sockets. + index = 0; for (auto &p : listen_sockets_) { - if (FD_ISSET(p.second, &rfds) || FD_ISSET(p.second, &efds)) { + struct pollfd &pfd = pollvec[index++]; + if (pfd.revents & (POLLIN | POLLERR)) { accept_connections(p.first, p.second); } } + // Transfer bytes wherever possible. for (ChanInfo &chan : chans_) { + struct pollfd &pfd = pollvec[index++]; SOCKET sock = chan.socket; if (sock == INVALID_SOCKET) continue; - if (FD_ISSET(sock, &wfds)) { + if (pfd.revents & POLLOUT) { + int nbytes; const char *bytes; chan.state = CHAN_OPEN; driven_->drv_peek_outgoing(chan.chid, &nbytes, &bytes); if (nbytes > 0) { @@ -356,7 +378,7 @@ public: } } } - if (FD_ISSET(sock, &rfds) || FD_ISSET(sock, &efds)) { + if (pfd.revents & (POLLIN | POLLERR)) { // Someday, find a way to avoid this copy. int nrecv = recv(sock, chbuf.get(), 65536, 0); if (nrecv <= 0) { @@ -368,6 +390,8 @@ public: } } } + + // Delete any newly-inactive channels cleanup_channels(); }