diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 36cdf3f171..53702f6c52 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -2757,21 +2757,34 @@ const CTxOut& CWallet::FindNonChangeParentOutput(const CTransaction& tx, int out return ptx->vout[n]; } -void CWallet::InitCoinJoinSalt() +const uint256& CWallet::GetCoinJoinSalt() +{ + if (nCoinJoinSalt.IsNull()) { + InitCJSaltFromDb(); + } + return nCoinJoinSalt; +} + +void CWallet::InitCJSaltFromDb() { - // Avoid fetching it multiple times assert(nCoinJoinSalt.IsNull()); WalletBatch batch(GetDatabase()); if (!batch.ReadCoinJoinSalt(nCoinJoinSalt) && batch.ReadCoinJoinSalt(nCoinJoinSalt, true)) { + // Migrate salt stored with legacy key batch.WriteCoinJoinSalt(nCoinJoinSalt); } +} - while (nCoinJoinSalt.IsNull()) { - // We never generated/saved it - nCoinJoinSalt = GetRandHash(); - batch.WriteCoinJoinSalt(nCoinJoinSalt); +bool CWallet::SetCoinJoinSalt(const uint256& cj_salt) +{ + WalletBatch batch(GetDatabase()); + // Only store new salt in CWallet if database write is successful + if (batch.WriteCoinJoinSalt(cj_salt)) { + nCoinJoinSalt = cj_salt; + return true; } + return false; } struct CompareByPriority @@ -3942,11 +3955,14 @@ DBErrors CWallet::LoadWallet(bool& fFirstRunRet) } } - InitCoinJoinSalt(); - if (nLoadWalletRet != DBErrors::LOAD_OK) return nLoadWalletRet; + /* If the CoinJoin salt is not set, try to set a new random hash as the salt */ + if (GetCoinJoinSalt().IsNull() && !SetCoinJoinSalt(GetRandHash())) { + return DBErrors::LOAD_FAIL; + } + return DBErrors::LOAD_OK; } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 0ac9666989..d73af772df 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -817,16 +817,16 @@ private: */ uint256 m_last_block_processed GUARDED_BY(cs_wallet); - /** Pulled from wallet DB ("ps_salt") and used when mixing a random number of rounds. + /** Pulled from wallet DB ("cj_salt") and used when mixing a random number of rounds. * This salt is needed to prevent an attacker from learning how many extra times * the input was mixed based only on information in the blockchain. */ uint256 nCoinJoinSalt; /** - * Fetches CoinJoin salt from database or generates and saves a new one if no salt was found in the db + * Populates nCoinJoinSalt with value from database (and migrates salt stored with legacy key). */ - void InitCoinJoinSalt(); + void InitCJSaltFromDb(); /** Height of last block processed is used by wallet to know depth of transactions * without relying on Chain interface beyond asynchronous updates. For safety, we @@ -872,6 +872,19 @@ public: */ const std::string& GetName() const { return m_name; } + /** + * Get an existing CoinJoin salt. Will attempt to read database (and migrate legacy salts) if + * nCoinJoinSalt is empty but will skip database read if nCoinJoinSalt is populated. + **/ + const uint256& GetCoinJoinSalt(); + + /** + * Write a new CoinJoin salt. This will directly write the new salt value into the wallet database. + * Ensuring that undesirable behaviour like overwriting the salt of a wallet that already uses CoinJoin + * is the responsibility of the caller. + **/ + bool SetCoinJoinSalt(const uint256& cj_salt); + // Map from governance object hash to governance object, they are added by gobject_prepare. std::map m_gobjects;