merge bitcoin#24357: make setsockopt() and SetSocketNoDelay() mockable/testable

This commit is contained in:
Kittywhiskers Van Gogh 2024-04-29 11:48:02 +00:00
parent 9c751ef9d6
commit 6b159f1b87
No known key found for this signature in database
GPG Key ID: 30CD0C065E5C4AAD
8 changed files with 62 additions and 20 deletions

View File

@ -1310,7 +1310,11 @@ void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
// According to the internet TCP_NODELAY is not carried into accepted sockets // According to the internet TCP_NODELAY is not carried into accepted sockets
// on all platforms. Set it again here just to be sure. // on all platforms. Set it again here just to be sure.
SetSocketNoDelay(sock->Get()); const int on{1};
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
LogPrint(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n",
addr.ToString());
}
// Don't accept connections from banned peers. // Don't accept connections from banned peers.
bool banned = m_banman && m_banman->IsBanned(addr); bool banned = m_banman && m_banman->IsBanned(addr);
@ -3219,17 +3223,26 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
// Allow binding if the port is still in TIME_WAIT state after // Allow binding if the port is still in TIME_WAIT state after
// the program was closed and restarted. // the program was closed and restarted.
setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) {
strError = strprintf(Untranslated("Error setting SO_REUSEADDR on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
}
// some systems don't have IPV6_V6ONLY but are always v6only; others do have the option // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option
// and enable it by default or not. Try to enable it, if possible. // and enable it by default or not. Try to enable it, if possible.
if (addrBind.IsIPv6()) { if (addrBind.IsIPv6()) {
#ifdef IPV6_V6ONLY #ifdef IPV6_V6ONLY
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) {
strError = strprintf(Untranslated("Error setting IPV6_V6ONLY on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
}
#endif #endif
#ifdef WIN32 #ifdef WIN32
int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED;
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)) == SOCKET_ERROR) {
strError = strprintf(Untranslated("Error setting IPV6_PROTECTION_LEVEL on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
}
#endif #endif
} }

View File

@ -498,10 +498,11 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family)
return nullptr; return nullptr;
} }
auto sock = std::make_unique<Sock>(hSocket);
// Ensure that waiting for I/O on this socket won't result in undefined // Ensure that waiting for I/O on this socket won't result in undefined
// behavior. // behavior.
if (!IsSelectableSocket(hSocket)) { if (!IsSelectableSocket(sock->Get())) {
CloseSocket(hSocket);
LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n");
return nullptr; return nullptr;
} }
@ -510,19 +511,24 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family)
int set = 1; int set = 1;
// Set the no-sigpipe option on the socket for BSD systems, other UNIXes // Set the no-sigpipe option on the socket for BSD systems, other UNIXes
// should use the MSG_NOSIGNAL flag for every send. // should use the MSG_NOSIGNAL flag for every send.
setsockopt(hSocket, SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); if (sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)) == SOCKET_ERROR) {
LogPrintf("Error setting SO_NOSIGPIPE on socket: %s, continuing anyway\n",
NetworkErrorString(WSAGetLastError()));
}
#endif #endif
// Set the no-delay option (disable Nagle's algorithm) on the TCP socket. // Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
SetSocketNoDelay(hSocket); const int on{1};
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
}
// Set the non-blocking option on the socket. // Set the non-blocking option on the socket.
if (!SetSocketNonBlocking(hSocket)) { if (!SetSocketNonBlocking(sock->Get())) {
CloseSocket(hSocket);
LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError())); LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError()));
return nullptr; return nullptr;
} }
return std::make_unique<Sock>(hSocket); return sock;
} }
std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP; std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP;
@ -729,13 +735,6 @@ bool SetSocketNonBlocking(const SOCKET& hSocket)
return true; return true;
} }
bool SetSocketNoDelay(const SOCKET& hSocket)
{
int set = 1;
int rc = setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int));
return rc == 0;
}
void InterruptSocks5(bool interrupt) void InterruptSocks5(bool interrupt)
{ {
interruptSocks5Recv = interrupt; interruptSocks5Recv = interrupt;

View File

@ -227,8 +227,6 @@ bool ConnectThroughProxy(const Proxy& proxy, const std::string& strDest, uint16_
/** Enable non-blocking mode for a socket */ /** Enable non-blocking mode for a socket */
bool SetSocketNonBlocking(const SOCKET& hSocket); bool SetSocketNonBlocking(const SOCKET& hSocket);
/** Set the TCP_NODELAY flag on a socket */
bool SetSocketNoDelay(const SOCKET& hSocket);
void InterruptSocks5(bool interrupt); void InterruptSocks5(bool interrupt);
/** /**

View File

@ -190,6 +190,19 @@ int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* op
return 0; return 0;
} }
int FuzzedSock::SetSockOpt(int, int, const void*, socklen_t) const
{
constexpr std::array setsockopt_errnos{
ENOMEM,
ENOBUFS,
};
if (m_fuzzed_data_provider.ConsumeBool()) {
SetFuzzedErrNo(m_fuzzed_data_provider, setsockopt_errnos);
return -1;
}
return 0;
}
bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
{ {
constexpr std::array wait_errnos{ constexpr std::array wait_errnos{

View File

@ -70,6 +70,8 @@ public:
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override; int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override;
int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const override;
bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override;
bool IsConnected(std::string& errmsg) const override; bool IsConnected(std::string& errmsg) const override;

View File

@ -150,6 +150,8 @@ public:
return 0; return 0;
} }
int SetSockOpt(int, int, const void*, socklen_t) const override { return 0; }
bool Wait(std::chrono::milliseconds timeout, bool Wait(std::chrono::milliseconds timeout,
Event requested, Event requested,
Event* occurred = nullptr) const override Event* occurred = nullptr) const override

View File

@ -105,6 +105,11 @@ int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len)
return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len); return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
} }
int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
{
return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
}
bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
{ {
#ifdef USE_POLL #ifdef USE_POLL

View File

@ -163,6 +163,16 @@ public:
void* opt_val, void* opt_val,
socklen_t* opt_len) const; socklen_t* opt_len) const;
/**
* setsockopt(2) wrapper. Equivalent to
* `setsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int SetSockOpt(int level,
int opt_name,
const void* opt_val,
socklen_t opt_len) const;
using Event = uint8_t; using Event = uint8_t;
/** /**