diff --git a/src/Makefile.test.include b/src/Makefile.test.include index c48eacf1c0..db14dade78 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -106,6 +106,7 @@ BITCOIN_TESTS =\ test/getarg_tests.cpp \ test/governance_validators_tests.cpp \ test/hash_tests.cpp \ + test/i2p_tests.cpp \ test/interfaces_tests.cpp \ test/key_io_tests.cpp \ test/key_tests.cpp \ @@ -248,6 +249,7 @@ test_fuzz_fuzz_SOURCES = \ test/fuzz/golomb_rice.cpp \ test/fuzz/hex.cpp \ test/fuzz/http_request.cpp \ + test/fuzz/i2p.cpp \ test/fuzz/integer.cpp \ test/fuzz/key.cpp \ test/fuzz/key_io.cpp \ diff --git a/src/i2p.cpp b/src/i2p.cpp index d16c620d88..a44f09f043 100644 --- a/src/i2p.cpp +++ b/src/i2p.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -115,7 +116,8 @@ namespace sam { Session::Session(const fs::path& private_key_file, const CService& control_host, CThreadInterrupt* interrupt) - : m_private_key_file(private_key_file), m_control_host(control_host), m_interrupt(interrupt) + : m_private_key_file(private_key_file), m_control_host(control_host), m_interrupt(interrupt), + m_control_sock(std::make_unique(INVALID_SOCKET)) { } @@ -145,7 +147,7 @@ bool Session::Accept(Connection& conn) try { while (!*m_interrupt) { Sock::Event occurred; - conn.sock.Wait(MAX_WAIT_FOR_IO, Sock::RECV, &occurred); + conn.sock->Wait(MAX_WAIT_FOR_IO, Sock::RECV, &occurred); if ((occurred & Sock::RECV) == 0) { // Timeout, no incoming connections within MAX_WAIT_FOR_IO. @@ -153,7 +155,7 @@ bool Session::Accept(Connection& conn) } const std::string& peer_dest = - conn.sock.RecvUntilTerminator('\n', MAX_WAIT_FOR_IO, *m_interrupt, MAX_MSG_SIZE); + conn.sock->RecvUntilTerminator('\n', MAX_WAIT_FOR_IO, *m_interrupt, MAX_MSG_SIZE); conn.peer = CService(DestB64ToAddr(peer_dest), Params().GetDefaultPort()); @@ -171,7 +173,7 @@ bool Session::Connect(const CService& to, Connection& conn, bool& proxy_error) proxy_error = true; std::string session_id; - Sock sock; + std::unique_ptr sock; conn.peer = to; try { @@ -184,12 +186,12 @@ bool Session::Connect(const CService& to, Connection& conn, bool& proxy_error) } const Reply& lookup_reply = - SendRequestAndGetReply(sock, strprintf("NAMING LOOKUP NAME=%s", to.ToStringIP())); + SendRequestAndGetReply(*sock, strprintf("NAMING LOOKUP NAME=%s", to.ToStringIP())); const std::string& dest = lookup_reply.Get("VALUE"); const Reply& connect_reply = SendRequestAndGetReply( - sock, strprintf("STREAM CONNECT ID=%s DESTINATION=%s SILENT=false", session_id, dest), + *sock, strprintf("STREAM CONNECT ID=%s DESTINATION=%s SILENT=false", session_id, dest), false); const std::string& result = connect_reply.Get("RESULT"); @@ -271,7 +273,7 @@ Session::Reply Session::SendRequestAndGetReply(const Sock& sock, return reply; } -Sock Session::Hello() const +std::unique_ptr Session::Hello() const { auto sock = CreateSock(m_control_host); @@ -279,13 +281,13 @@ Sock Session::Hello() const throw std::runtime_error("Cannot create socket"); } - if (!ConnectSocketDirectly(m_control_host, sock->Get(), nConnectTimeout, true)) { + if (!ConnectSocketDirectly(m_control_host, *sock, nConnectTimeout, true)) { throw std::runtime_error(strprintf("Cannot connect to %s", m_control_host.ToString())); } SendRequestAndGetReply(*sock, "HELLO VERSION MIN=3.1 MAX=3.1"); - return std::move(*sock); + return sock; } void Session::CheckControlSock() @@ -293,7 +295,7 @@ void Session::CheckControlSock() LOCK(m_mutex); std::string errmsg; - if (!m_control_sock.IsConnected(errmsg)) { + if (!m_control_sock->IsConnected(errmsg)) { Log("Control socket error: %s", errmsg); Disconnect(); } @@ -341,26 +343,26 @@ Binary Session::MyDestination() const void Session::CreateIfNotCreatedAlready() { std::string errmsg; - if (m_control_sock.IsConnected(errmsg)) { + if (m_control_sock->IsConnected(errmsg)) { return; } Log("Creating SAM session with %s", m_control_host.ToString()); - Sock sock = Hello(); + auto sock = Hello(); const auto& [read_ok, data] = ReadBinaryFile(m_private_key_file); if (read_ok) { m_private_key.assign(data.begin(), data.end()); } else { - GenerateAndSavePrivateKey(sock); + GenerateAndSavePrivateKey(*sock); } const std::string& session_id = GetRandHash().GetHex().substr(0, 10); // full is an overkill, too verbose in the logs const std::string& private_key_b64 = SwapBase64(EncodeBase64(m_private_key)); - SendRequestAndGetReply(sock, strprintf("SESSION CREATE STYLE=STREAM ID=%s DESTINATION=%s", - session_id, private_key_b64)); + SendRequestAndGetReply(*sock, strprintf("SESSION CREATE STYLE=STREAM ID=%s DESTINATION=%s", + session_id, private_key_b64)); m_my_addr = CService(DestBinToAddr(MyDestination()), Params().GetDefaultPort()); m_session_id = session_id; @@ -370,12 +372,12 @@ void Session::CreateIfNotCreatedAlready() m_my_addr.ToString()); } -Sock Session::StreamAccept() +std::unique_ptr Session::StreamAccept() { - Sock sock = Hello(); + auto sock = Hello(); const Reply& reply = SendRequestAndGetReply( - sock, strprintf("STREAM ACCEPT ID=%s SILENT=false", m_session_id), false); + *sock, strprintf("STREAM ACCEPT ID=%s SILENT=false", m_session_id), false); const std::string& result = reply.Get("RESULT"); @@ -393,14 +395,14 @@ Sock Session::StreamAccept() void Session::Disconnect() { - if (m_control_sock.Get() != INVALID_SOCKET) { + if (m_control_sock->Get() != INVALID_SOCKET) { if (m_session_id.empty()) { Log("Destroying incomplete session"); } else { Log("Destroying session %s", m_session_id); } } - m_control_sock.Reset(); + m_control_sock->Reset(); m_session_id.clear(); } } // namespace sam diff --git a/src/i2p.h b/src/i2p.h index 1ebe7d0329..cb2efedba8 100644 --- a/src/i2p.h +++ b/src/i2p.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -29,7 +30,7 @@ using Binary = std::vector; */ struct Connection { /** Connected socket. */ - Sock sock; + std::unique_ptr sock; /** Our I2P address. */ CService me; @@ -166,7 +167,7 @@ private: * @return a connected socket * @throws std::runtime_error if an error occurs */ - Sock Hello() const EXCLUSIVE_LOCKS_REQUIRED(m_mutex); + std::unique_ptr Hello() const EXCLUSIVE_LOCKS_REQUIRED(m_mutex); /** * Check the control socket for errors and possibly disconnect. @@ -204,10 +205,11 @@ private: /** * Open a new connection to the SAM proxy and issue "STREAM ACCEPT" request using the existing - * session id. Return the idle socket that is waiting for a peer to connect to us. + * session id. + * @return the idle socket that is waiting for a peer to connect to us * @throws std::runtime_error if an error occurs */ - Sock StreamAccept() EXCLUSIVE_LOCKS_REQUIRED(m_mutex); + std::unique_ptr StreamAccept() EXCLUSIVE_LOCKS_REQUIRED(m_mutex); /** * Destroy the session, closing the internally used sockets. @@ -248,7 +250,7 @@ private: * connections and make outgoing ones. * See https://geti2p.net/en/docs/api/samv3 */ - Sock m_control_sock GUARDED_BY(m_mutex); + std::unique_ptr m_control_sock GUARDED_BY(m_mutex); /** * Our .b32.i2p address. diff --git a/src/masternode/node.cpp b/src/masternode/node.cpp index b05cf1e5de..84a2005293 100644 --- a/src/masternode/node.cpp +++ b/src/masternode/node.cpp @@ -119,7 +119,7 @@ void CActiveMasternodeManager::Init(const CBlockIndex* pindex) LogPrintf("CActiveMasternodeManager::Init -- ERROR: %s\n", strError); return; } - bool fConnected = ConnectSocketDirectly(activeMasternodeInfo.service, sock->Get(), nConnectTimeout, true) && IsSelectableSocket(sock->Get()); + bool fConnected = ConnectSocketDirectly(activeMasternodeInfo.service, *sock, nConnectTimeout, true) && IsSelectableSocket(sock->Get()); sock->Reset(); if (!fConnected && Params().RequireRoutableExternalIP()) { diff --git a/src/net.cpp b/src/net.cpp index ad3ea5c549..3f1d5f0206 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -468,7 +468,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo i2p::Connection conn; if (m_i2p_sam_session->Connect(addrConnect, conn, proxyConnectionFailed)) { connected = true; - sock = std::make_unique(std::move(conn.sock)); + sock = std::move(conn.sock); addr_bind = CAddress{conn.me, NODE_NONE}; } } else if (GetProxy(addrConnect.GetNetwork(), proxy)) { @@ -484,7 +484,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo if (!sock) { return nullptr; } - connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout, manual_connection); + connected = ConnectSocketDirectly(addrConnect, *sock, nConnectTimeout, manual_connection); } if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to @@ -2854,7 +2854,7 @@ void CConnman::ThreadI2PAcceptIncoming() continue; } - CreateNodeFromAcceptedSocket(conn.sock.Release(), NetPermissionFlags::PF_NONE, + CreateNodeFromAcceptedSocket(conn.sock->Release(), NetPermissionFlags::PF_NONE, CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE}); } } diff --git a/src/netbase.cpp b/src/netbase.cpp index c7da4defa0..fc940b0501 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -536,12 +536,12 @@ static void LogConnectFailure(bool manual_connection, const char* fmt, const Arg } } -bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, int nTimeout, bool manual_connection) +bool ConnectSocketDirectly(const CService &addrConnect, const Sock& sock, int nTimeout, bool manual_connection) { // Create a sockaddr from the specified service. struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - if (hSocket == INVALID_SOCKET) { + if (sock.Get() == INVALID_SOCKET) { LogPrintf("Cannot connect to %s: invalid socket\n", addrConnect.ToString()); return false; } @@ -551,8 +551,7 @@ bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, i } // Connect to the addrConnect service on the hSocket socket. - if (connect(hSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) - { + if (sock.Connect(reinterpret_cast(&sockaddr), len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); // WSAEINVAL is here because some legacy version of winsock uses it if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) @@ -560,46 +559,34 @@ bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, i // Connection didn't actually fail, but is being established // asynchronously. Thus, use async I/O api (select/poll) // synchronously to check for successful connection with a timeout. -#ifdef USE_POLL - struct pollfd pollfd = {}; - pollfd.fd = hSocket; - pollfd.events = POLLIN | POLLOUT; - int nRet = poll(&pollfd, 1, nTimeout); -#else - struct timeval timeout = MillisToTimeval(nTimeout); - fd_set fdset; - FD_ZERO(&fdset); - FD_SET(hSocket, &fdset); - int nRet = select(hSocket + 1, nullptr, &fdset, nullptr, &timeout); -#endif - // Upon successful completion, both select and poll return the total - // number of file descriptors that have been selected. A value of 0 - // indicates that the call timed out and no file descriptors have - // been selected. - if (nRet == 0) - { - LogPrint(BCLog::NET, "connection to %s timeout\n", addrConnect.ToString()); + const Sock::Event requested = Sock::RECV | Sock::SEND; + Sock::Event occurred; + if (!sock.Wait(std::chrono::milliseconds{nTimeout}, requested, &occurred)) { + LogPrintf("wait for connect to %s failed: %s\n", + addrConnect.ToString(), + NetworkErrorString(WSAGetLastError())); return false; - } - if (nRet == SOCKET_ERROR) - { - LogPrintf("select() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); + } else if (occurred == 0) { + LogPrint(BCLog::NET, "connection attempt to %s timed out\n", addrConnect.ToString()); return false; } - // Even if the select/poll was successful, the connect might not + // Even if the wait was successful, the connect might not // have been successful. The reason for this failure is hidden away // in the SO_ERROR for the socket in modern systems. We read it into - // nRet here. - socklen_t nRetSize = sizeof(nRet); - if (getsockopt(hSocket, SOL_SOCKET, SO_ERROR, (sockopt_arg_type)&nRet, &nRetSize) == SOCKET_ERROR) - { + // sockerr here. + int sockerr; + socklen_t sockerr_len = sizeof(sockerr); + if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, (sockopt_arg_type)&sockerr, &sockerr_len) == + SOCKET_ERROR) { LogPrintf("getsockopt() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); return false; } - if (nRet != 0) - { - LogConnectFailure(manual_connection, "connect() to %s failed after select(): %s", addrConnect.ToString(), NetworkErrorString(nRet)); + if (sockerr != 0) { + LogConnectFailure(manual_connection, + "connect() to %s failed after wait: %s", + addrConnect.ToString(), + NetworkErrorString(sockerr)); return false; } } @@ -667,7 +654,7 @@ bool IsProxy(const CNetAddr &addr) { bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, uint16_t port, const Sock& sock, int nTimeout, bool& outProxyConnectionFailed) { // first connect to proxy server - if (!ConnectSocketDirectly(proxy.proxy, sock.Get(), nTimeout, true)) { + if (!ConnectSocketDirectly(proxy.proxy, sock, nTimeout, true)) { outProxyConnectionFailed = true; return false; } diff --git a/src/netbase.h b/src/netbase.h index 90f7f604e9..9f93a46149 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -178,7 +178,7 @@ extern std::function(const CService&)> CreateSock; * Try to connect to the specified service on the specified socket. * * @param addrConnect The service to which to connect. - * @param hSocket The socket on which to connect. + * @param sock The socket on which to connect. * @param nTimeout Wait this many milliseconds for the connection to be * established. * @param manual_connection Whether or not the connection was manually requested @@ -186,7 +186,7 @@ extern std::function(const CService&)> CreateSock; * * @returns Whether or not a connection was successfully made. */ -bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, int nTimeout, bool manual_connection); +bool ConnectSocketDirectly(const CService &addrConnect, const Sock& sock, int nTimeout, bool manual_connection); /** * Connect to a specified destination service through a SOCKS5 proxy by first diff --git a/src/test/fuzz/i2p.cpp b/src/test/fuzz/i2p.cpp new file mode 100644 index 0000000000..9f4e2bbf22 --- /dev/null +++ b/src/test/fuzz/i2p.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2020-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void initialize_i2p() +{ + InitializeFuzzingContext(); +} + +FUZZ_TARGET_INIT(i2p, initialize_i2p) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + + // Mock CreateSock() to create FuzzedSock. + auto CreateSockOrig = CreateSock; + CreateSock = [&fuzzed_data_provider](const CService&) { + return std::make_unique(fuzzed_data_provider); + }; + + const CService sam_proxy; + CThreadInterrupt interrupt; + + i2p::sam::Session sess{GetDataDir() / "fuzzed_i2p_private_key", sam_proxy, &interrupt}; + + i2p::Connection conn; + + if (sess.Listen(conn)) { + if (sess.Accept(conn)) { + try { + conn.sock->RecvUntilTerminator('\n', 10ms, interrupt, i2p::sam::MAX_MSG_SIZE); + } catch (const std::runtime_error&) { + } + } + } + + const CService to; + bool proxy_error; + + if (sess.Connect(to, conn, proxy_error)) { + try { + conn.sock->SendComplete("verack\n", 10ms, interrupt); + } catch (const std::runtime_error&) { + } + } + + CreateSock = CreateSockOrig; +} diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index fd3b06179d..bd5216e5e7 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -526,36 +526,37 @@ class FuzzedSock : public Sock { FuzzedDataProvider& m_fuzzed_data_provider; + /** + * Data to return when `MSG_PEEK` is used as a `Recv()` flag. + * If `MSG_PEEK` is used, then our `Recv()` returns some random data as usual, but on the next + * `Recv()` call we must return the same data, thus we remember it here. + */ + mutable std::optional m_peek_data; + public: explicit FuzzedSock(FuzzedDataProvider& fuzzed_data_provider) : m_fuzzed_data_provider{fuzzed_data_provider} { + m_socket = fuzzed_data_provider.ConsumeIntegral(); } ~FuzzedSock() override { + // Sock::~Sock() will be called after FuzzedSock::~FuzzedSock() and it will call + // Sock::Reset() (not FuzzedSock::Reset()!) which will call CloseSocket(m_socket). + // Avoid closing an arbitrary file descriptor (m_socket is just a random number which + // may concide with a real opened file descriptor). + Reset(); } FuzzedSock& operator=(Sock&& other) override { - assert(false && "Not implemented yet."); + assert(false && "Move of Sock into FuzzedSock not allowed."); return *this; } - SOCKET Get() const override - { - assert(false && "Not implemented yet."); - return INVALID_SOCKET; - } - - SOCKET Release() override - { - assert(false && "Not implemented yet."); - return INVALID_SOCKET; - } - void Reset() override { - assert(false && "Not implemented yet."); + m_socket = INVALID_SOCKET; } ssize_t Send(const void* data, size_t len, int flags) const override @@ -592,10 +593,13 @@ public: ssize_t Recv(void* buf, size_t len, int flags) const override { + // Have a permanent error at recv_errnos[0] because when the fuzzed data is exhausted + // SetFuzzedErrNo() will always return the first element and we want to avoid Recv() + // returning -1 and setting errno to EAGAIN repeatedly. constexpr std::array recv_errnos{ + ECONNREFUSED, EAGAIN, EBADF, - ECONNREFUSED, EFAULT, EINTR, EINVAL, @@ -612,8 +616,26 @@ public: } return r; } - const std::vector random_bytes = m_fuzzed_data_provider.ConsumeBytes( - m_fuzzed_data_provider.ConsumeIntegralInRange(0, len)); + std::vector random_bytes; + bool pad_to_len_bytes{m_fuzzed_data_provider.ConsumeBool()}; + if (m_peek_data.has_value()) { + // `MSG_PEEK` was used in the preceding `Recv()` call, return `m_peek_data`. + random_bytes.assign({m_peek_data.value()}); + if ((flags & MSG_PEEK) == 0) { + m_peek_data.reset(); + } + pad_to_len_bytes = false; + } else if ((flags & MSG_PEEK) != 0) { + // New call with `MSG_PEEK`. + random_bytes = m_fuzzed_data_provider.ConsumeBytes(1); + if (!random_bytes.empty()) { + m_peek_data = random_bytes[0]; + pad_to_len_bytes = false; + } + } else { + random_bytes = m_fuzzed_data_provider.ConsumeBytes( + m_fuzzed_data_provider.ConsumeIntegralInRange(0, len)); + } if (random_bytes.empty()) { const ssize_t r = m_fuzzed_data_provider.ConsumeBool() ? 0 : -1; if (r == -1) { @@ -622,7 +644,7 @@ public: return r; } std::memcpy(buf, random_bytes.data(), random_bytes.size()); - if (m_fuzzed_data_provider.ConsumeBool()) { + if (pad_to_len_bytes) { if (len > random_bytes.size()) { std::memset((char*)buf + random_bytes.size(), 0, len - random_bytes.size()); } @@ -634,10 +656,59 @@ public: return random_bytes.size(); } + int Connect(const sockaddr*, socklen_t) const override + { + // Have a permanent error at connect_errnos[0] because when the fuzzed data is exhausted + // SetFuzzedErrNo() will always return the first element and we want to avoid Connect() + // returning -1 and setting errno to EAGAIN repeatedly. + constexpr std::array connect_errnos{ + ECONNREFUSED, + EAGAIN, + ECONNRESET, + EHOSTUNREACH, + EINPROGRESS, + EINTR, + ENETUNREACH, + ETIMEDOUT, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, connect_errnos); + return -1; + } + return 0; + } + + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override + { + constexpr std::array getsockopt_errnos{ + ENOMEM, + ENOBUFS, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, getsockopt_errnos); + return -1; + } + if (opt_val == nullptr) { + return 0; + } + std::memcpy(opt_val, + ConsumeFixedLengthByteVector(m_fuzzed_data_provider, *opt_len).data(), + *opt_len); + return 0; + } + bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override { return m_fuzzed_data_provider.ConsumeBool(); } + + bool IsConnected(std::string& errmsg) const override { + if (m_fuzzed_data_provider.ConsumeBool()) { + return true; + } + errmsg = "disconnected at random by the fuzzer"; + return false; + } }; [[nodiscard]] inline FuzzedSock ConsumeSock(FuzzedDataProvider& fuzzed_data_provider) diff --git a/src/test/i2p_tests.cpp b/src/test/i2p_tests.cpp new file mode 100644 index 0000000000..1c2a6a433d --- /dev/null +++ b/src/test/i2p_tests.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2021-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +BOOST_FIXTURE_TEST_SUITE(i2p_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(unlimited_recv) +{ + auto CreateSockOrig = CreateSock; + + // Mock CreateSock() to create MockSock. + CreateSock = [](const CService&) { + return std::make_unique(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a')); + }; + + CThreadInterrupt interrupt; + i2p::sam::Session session(GetDataDir() / "test_i2p_private_key", CService{}, &interrupt); + + { + ASSERT_DEBUG_LOG("Creating SAM session"); + ASSERT_DEBUG_LOG("too many bytes without a terminator"); + + i2p::Connection conn; + bool proxy_error; + BOOST_REQUIRE(!session.Connect(CService{}, conn, proxy_error)); + } + + CreateSock = CreateSockOrig; +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/util/net.h b/src/test/util/net.h index d82e9594f3..03382c6df1 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -5,7 +5,13 @@ #ifndef BITCOIN_TEST_UTIL_NET_H #define BITCOIN_TEST_UTIL_NET_H +#include #include +#include + +#include +#include +#include struct ConnmanTestMsg : public CConnman { using CConnman::CConnman; @@ -53,4 +59,67 @@ constexpr NetPermissionFlags ALL_NET_PERMISSION_FLAGS[]{ NetPermissionFlags::PF_ALL, }; +/** + * A mocked Sock alternative that returns a statically contained data upon read and succeeds + * and ignores all writes. The data to be returned is given to the constructor and when it is + * exhausted an EOF is returned by further reads. + */ +class StaticContentsSock : public Sock +{ +public: + explicit StaticContentsSock(const std::string& contents) : m_contents{contents}, m_consumed{0} + { + // Just a dummy number that is not INVALID_SOCKET. + static_assert(INVALID_SOCKET != 1000); + m_socket = 1000; + } + + ~StaticContentsSock() override { Reset(); } + + StaticContentsSock& operator=(Sock&& other) override + { + assert(false && "Move of Sock into MockSock not allowed."); + return *this; + } + + void Reset() override + { + m_socket = INVALID_SOCKET; + } + + ssize_t Send(const void*, size_t len, int) const override { return len; } + + ssize_t Recv(void* buf, size_t len, int flags) const override + { + const size_t consume_bytes{std::min(len, m_contents.size() - m_consumed)}; + std::memcpy(buf, m_contents.data() + m_consumed, consume_bytes); + if ((flags & MSG_PEEK) == 0) { + m_consumed += consume_bytes; + } + return consume_bytes; + } + + int Connect(const sockaddr*, socklen_t) const override { return 0; } + + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override + { + std::memset(opt_val, 0x0, *opt_len); + return 0; + } + + bool Wait(std::chrono::milliseconds timeout, + Event requested, + Event* occurred = nullptr) const override + { + if (occurred != nullptr) { + *occurred = requested; + } + return true; + } + +private: + const std::string m_contents; + mutable size_t m_consumed; +}; + #endif // BITCOIN_TEST_UTIL_NET_H diff --git a/src/util/sock.cpp b/src/util/sock.cpp index f9ecfef5d4..0bc9795db3 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -66,6 +66,16 @@ ssize_t Sock::Recv(void* buf, size_t len, int flags) const return recv(m_socket, static_cast(buf), len, flags); } +int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const +{ + return connect(m_socket, addr, addr_len); +} + +int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const +{ + return getsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); +} + bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { #ifdef USE_POLL diff --git a/src/util/sock.h b/src/util/sock.h index 4b0618dcff..c4ad0cbc43 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -80,16 +80,29 @@ public: /** * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this - * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. */ virtual ssize_t Send(const void* data, size_t len, int flags) const; /** * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this - * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. */ virtual ssize_t Recv(void* buf, size_t len, int flags) const; + /** + * connect(2) wrapper. Equivalent to `connect(this->Get(), addr, addrlen)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + virtual int Connect(const sockaddr* addr, socklen_t addr_len) const; + + /** + * getsockopt(2) wrapper. Equivalent to + * `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + virtual int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const; + using Event = uint8_t; /** @@ -153,7 +166,7 @@ public: */ virtual bool IsConnected(std::string& errmsg) const; -private: +protected: /** * Contained socket. `INVALID_SOCKET` designates the object is empty. */