diff --git a/src/random.cpp b/src/random.cpp index 5674b56865..6f57e582d2 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -14,16 +14,14 @@ #include #endif #include // for LogPrintf() +#include +#include #include // for Mutex #include // for GetTimeMicros() #include #include -#include - -#include - #ifndef WIN32 #include #endif @@ -582,16 +580,6 @@ static void ProcRand(unsigned char* out, int num, RNGLevel level) noexcept } } -std::chrono::microseconds GetRandMicros(std::chrono::microseconds duration_max) noexcept -{ - return std::chrono::microseconds{GetRand(duration_max.count())}; -} - -std::chrono::milliseconds GetRandMillis(std::chrono::milliseconds duration_max) noexcept -{ - return std::chrono::milliseconds{GetRand(duration_max.count())}; -} - void GetRandBytes(unsigned char* buf, int num) noexcept { ProcRand(buf, num, RNGLevel::FAST); } void GetStrongRandBytes(unsigned char* buf, int num) noexcept { ProcRand(buf, num, RNGLevel::SLOW); } void RandAddPeriodic() noexcept { ProcRand(nullptr, 0, RNGLevel::PERIODIC); } diff --git a/src/random.h b/src/random.h index b957b52c12..f1b31aae66 100644 --- a/src/random.h +++ b/src/random.h @@ -70,9 +70,21 @@ * Thread-safe. */ void GetRandBytes(unsigned char* buf, int num) noexcept; +/** Generate a uniform random integer in the range [0..range). Precondition: range > 0 */ uint64_t GetRand(uint64_t nMax) noexcept; -std::chrono::microseconds GetRandMicros(std::chrono::microseconds duration_max) noexcept; -std::chrono::milliseconds GetRandMillis(std::chrono::milliseconds duration_max) noexcept; +/** Generate a uniform random duration in the range [0..max). Precondition: max.count() > 0 */ +template +D GetRandomDuration(typename std::common_type::type max) noexcept +// Having the compiler infer the template argument from the function argument +// is dangerous, because the desired return value generally has a different +// type than the function argument. So std::common_type is used to force the +// call site to specify the type of the return value. +{ + assert(max.count() > 0); + return D{GetRand(max.count())}; +}; +constexpr auto GetRandMicros = GetRandomDuration; +constexpr auto GetRandMillis = GetRandomDuration; int GetRandInt(int nMax) noexcept; uint256 GetRandHash() noexcept; diff --git a/src/rpc/server.cpp b/src/rpc/server.cpp index a87c7e2263..6ccd7c278f 100644 --- a/src/rpc/server.cpp +++ b/src/rpc/server.cpp @@ -28,7 +28,8 @@ static std::string rpcWarmupStatus GUARDED_BY(g_rpc_warmup_mutex) = "RPC server /* Timer-creating functions */ static RPCTimerInterface* timerInterface = nullptr; /* Map of name to timer. */ -static std::map > deadlineTimers; +static Mutex g_deadline_timers_mutex; +static std::map > deadlineTimers GUARDED_BY(g_deadline_timers_mutex); static bool ExecuteCommand(const CRPCCommand& command, const JSONRPCRequest& request, UniValue& result, bool last_handler, std::multimap> mapPlatformRestrictions); // Any commands submitted by this user will have their commands filtered based on the mapPlatformRestrictions @@ -330,7 +331,7 @@ void InterruptRPC() void StopRPC() { LogPrint(BCLog::RPC, "Stopping RPC\n"); - deadlineTimers.clear(); + WITH_LOCK(g_deadline_timers_mutex, deadlineTimers.clear()); DeleteAuthCookie(); g_rpcSignals.Stopped(); } @@ -609,6 +610,7 @@ void RPCRunLater(const std::string& name, std::function func, int64_t nS { if (!timerInterface) throw JSONRPCError(RPC_INTERNAL_ERROR, "No timer handler registered for RPC"); + LOCK(g_deadline_timers_mutex); deadlineTimers.erase(name); LogPrint(BCLog::RPC, "queue run of timer %s in %i seconds (using %s)\n", name, nSeconds, timerInterface->Name()); deadlineTimers.emplace(name, std::unique_ptr(timerInterface->NewTimer(func, nSeconds*1000))); diff --git a/src/test/random_tests.cpp b/src/test/random_tests.cpp index e65603cd76..40a2378fe9 100644 --- a/src/test/random_tests.cpp +++ b/src/test/random_tests.cpp @@ -28,6 +28,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) for (int i = 10; i > 0; --i) { BOOST_CHECK_EQUAL(GetRand(std::numeric_limits::max()), uint64_t{10393729187455219830U}); BOOST_CHECK_EQUAL(GetRandInt(std::numeric_limits::max()), int{769702006}); + BOOST_CHECK_EQUAL(GetRandMicros(std::chrono::hours{1}).count(), 2917185654); + BOOST_CHECK_EQUAL(GetRandMillis(std::chrono::hours{1}).count(), 2144374); } { constexpr SteadySeconds time_point{1s}; @@ -67,6 +69,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) for (int i = 10; i > 0; --i) { BOOST_CHECK(GetRand(std::numeric_limits::max()) != uint64_t{10393729187455219830U}); BOOST_CHECK(GetRandInt(std::numeric_limits::max()) != int{769702006}); + BOOST_CHECK(GetRandMicros(std::chrono::hours{1}) != std::chrono::microseconds{2917185654}); + BOOST_CHECK(GetRandMillis(std::chrono::hours{1}) != std::chrono::milliseconds{2144374}); } { @@ -108,7 +112,7 @@ BOOST_AUTO_TEST_CASE(stdrandom_test) BOOST_CHECK(x >= 3); BOOST_CHECK(x <= 9); - std::vector test{1,2,3,4,5,6,7,8,9,10}; + std::vector test{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; std::shuffle(test.begin(), test.end(), ctx); for (int j = 1; j <= 10; ++j) { BOOST_CHECK(std::find(test.begin(), test.end(), j) != test.end()); @@ -118,7 +122,6 @@ BOOST_AUTO_TEST_CASE(stdrandom_test) BOOST_CHECK(std::find(test.begin(), test.end(), j) != test.end()); } } - } /** Test that Shuffle reaches every permutation with equal probability. */ diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 703da04873..752098cbf5 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -1927,6 +1927,9 @@ static UniValue walletpassphrase(const JSONRPCRequest& request) CWallet* const pwallet = wallet.get(); int64_t nSleepTime; + int64_t relock_time; + // Prevent concurrent calls to walletpassphrase with the same wallet. + LOCK(pwallet->m_unlock_mutex); { LOCK(pwallet->cs_wallet); @@ -1975,7 +1978,7 @@ static UniValue walletpassphrase(const JSONRPCRequest& request) pwallet->TopUpKeyPool(); pwallet->nRelockTime = GetTime() + nSleepTime; - + relock_time = pwallet->nRelockTime; } // rpcRunLater must be called without cs_wallet held otherwise a deadlock // can occur. The deadlock would happen when RPCRunLater removes the @@ -1986,9 +1989,11 @@ static UniValue walletpassphrase(const JSONRPCRequest& request) // wallet before the following callback is called. If a valid shared pointer // is acquired in the callback then the wallet is still loaded. std::weak_ptr weak_wallet = wallet; - pwallet->chain().rpcRunLater(strprintf("lockwallet(%s)", pwallet->GetName()), [weak_wallet] { + pwallet->chain().rpcRunLater(strprintf("lockwallet(%s)", pwallet->GetName()), [weak_wallet, relock_time] { if (auto shared_wallet = weak_wallet.lock()) { LOCK(shared_wallet->cs_wallet); + // Skip if this is not the most recent rpcRunLater callback. + if (shared_wallet->nRelockTime != relock_time) return; shared_wallet->Lock(); shared_wallet->nRelockTime = 0; } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index b816a2d183..e70c4b6050 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -965,8 +965,10 @@ public: std::vector GetDestValues(const std::string& prefix) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); //! Holds a timestamp at which point the wallet is scheduled (externally) to be relocked. Caller must arrange for actual relocking to occur via Lock(). - int64_t nRelockTime = 0; + int64_t nRelockTime GUARDED_BY(cs_wallet){0}; + // Used to prevent concurrent calls to walletpassphrase RPC. + Mutex m_unlock_mutex; bool Unlock(const SecureString& strWalletPassphrase, bool fForMixingOnly = false, bool accept_no_keys = false); bool ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, const SecureString& strNewWalletPassphrase); bool EncryptWallet(const SecureString& strWalletPassphrase);