Merge #19046: Replace CWallet::Set* functions that use memonly with Add/Load variants

3a9aba21a49a6d80bd187940d5e26893937b6832 Split SetWalletFlags into Add/LoadWalletFlags (Andrew Chow)
d9cd095b5965fc20c09f401370e7ba99446663e3 Split SetActiveScriptPubKeyMan into Add/LoadActiveScriptPubKeyMan (Andrew Chow)
0122fbab4c340b23ae56173de6c5ab866ba25ab8 Split SetHDChain into AddHDChain and LoadHDChain (Andrew Chow)

Pull request description:

  `SetHDChaiin`, `SetActiveScriptPubKeyMan`, and `SetWalletFlags` have a `memonly` argument which is kind of confusing, as noted in https://github.com/bitcoin/bitcoin/pull/17681#discussion_r427633081. This PR replaces those functions with `Add*` and `Load*` variants so that they follow the pattern used elsewhere in the wallet.

  `AddHDChain`, `AddActiveScriptPubKeyMan`, and `AddWalletFlags` both set their respective variables in `CWallet` and writes them to disk. These functions are used by the actions which modify the wallet such as `sethdseed`, `importdescriptors`, and creating a new wallet.

  `LoadHDChain`, `LoadActiveScriptPubKeyMan`, and `LoadWalletFlags` just set the `CWallet` variables. These functions are used by `LoadWallet` when loading the wallet from disk.

ACKs for top commit:
  jnewbery:
    Code review ACK 3a9aba21a49a6d80bd187940d5e26893937b6832
  ryanofsky:
    Code review ACK 3a9aba21a49a6d80bd187940d5e26893937b6832. Only changes since last review tweaks making m_wallet_flags updates more safe
  meshcollider:
    utACK 3a9aba21a49a6d80bd187940d5e26893937b6832

Tree-SHA512: 365aeaafc5ba42879c0eb797ec3beb29ab70e27f917dc880763f743420b3be6ddf797240996beed8a9ad70fb212c2590253c6b44c9dc244529c3939d9538983f
This commit is contained in:
Andrew Chow 2020-05-21 23:15:41 -04:00 committed by Konstantin Akimov
parent 2c0d5b7c71
commit 63895fde23
No known key found for this signature in database
GPG Key ID: 2176C4A5D01EA524
7 changed files with 75 additions and 57 deletions

View File

@ -1766,7 +1766,7 @@ static UniValue ProcessDescriptorImport(CWallet * const pwallet, const UniValue&
if (!w_desc.descriptor->GetOutputType()) { if (!w_desc.descriptor->GetOutputType()) {
warnings.push_back("Unknown output type, cannot set descriptor to active."); warnings.push_back("Unknown output type, cannot set descriptor to active.");
} else { } else {
pwallet->SetActiveScriptPubKeyMan(spk_manager->GetID(), internal); pwallet->AddActiveScriptPubKeyMan(spk_manager->GetID(), internal);
} }
} }

View File

@ -4335,8 +4335,8 @@ static RPCHelpMan sethdseed()
if (!newHdChain.SetSeed(SecureVector(key.begin(), key.end()), true)) { if (!newHdChain.SetSeed(SecureVector(key.begin(), key.end()), true)) {
throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key: SetSeed failed"); throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key: SetSeed failed");
} }
if (!spk_man.SetHDChainSingle(newHdChain, false)) { if (!spk_man.AddHDChainSingle(newHdChain)) {
throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key: SetHDChainSingle failed"); throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key: AddHDChainSingle failed");
} }
// add default account // add default account
newHdChain.AddAccount(); newHdChain.AddAccount();

View File

@ -268,7 +268,7 @@ bool LegacyScriptPubKeyMan::Encrypt(const CKeyingMaterial& master_key, WalletBat
if (!hdChainCurrent.IsNull()) { if (!hdChainCurrent.IsNull()) {
assert(EncryptHDChain(master_key, m_hd_chain)); assert(EncryptHDChain(master_key, m_hd_chain));
assert(SetHDChain(m_hd_chain)); assert(LoadHDChain(m_hd_chain));
CHDChain hdChainCrypted; CHDChain hdChainCrypted;
assert(GetHDChain(hdChainCrypted)); assert(GetHDChain(hdChainCrypted));
@ -277,7 +277,7 @@ bool LegacyScriptPubKeyMan::Encrypt(const CKeyingMaterial& master_key, WalletBat
assert(hdChainCurrent.GetID() == hdChainCrypted.GetID()); assert(hdChainCurrent.GetID() == hdChainCrypted.GetID());
assert(hdChainCurrent.GetSeedHash() != hdChainCrypted.GetSeedHash()); assert(hdChainCurrent.GetSeedHash() != hdChainCrypted.GetSeedHash());
assert(SetHDChain(*encrypted_batch, hdChainCrypted, false)); assert(AddHDChain(*encrypted_batch, hdChainCrypted));
} }
encrypted_batch = nullptr; encrypted_batch = nullptr;
@ -396,7 +396,7 @@ void LegacyScriptPubKeyMan::GenerateNewCryptedHDChain(const SecureString& secure
CHDChain hdChainPrev = hdChainTmp; CHDChain hdChainPrev = hdChainTmp;
bool res = EncryptHDChain(vMasterKey, hdChainTmp); bool res = EncryptHDChain(vMasterKey, hdChainTmp);
assert(res); assert(res);
res = SetHDChain(hdChainTmp); res = LoadHDChain(hdChainTmp);
assert(res); assert(res);
CHDChain hdChainCrypted; CHDChain hdChainCrypted;
@ -407,8 +407,8 @@ void LegacyScriptPubKeyMan::GenerateNewCryptedHDChain(const SecureString& secure
assert(hdChainPrev.GetID() == hdChainCrypted.GetID()); assert(hdChainPrev.GetID() == hdChainCrypted.GetID());
assert(hdChainPrev.GetSeedHash() != hdChainCrypted.GetSeedHash()); assert(hdChainPrev.GetSeedHash() != hdChainCrypted.GetSeedHash());
if (!SetHDChainSingle(hdChainCrypted, false)) { if (!AddHDChainSingle(hdChainCrypted)) {
throw std::runtime_error(std::string(__func__) + ": SetHDChainSingle failed"); throw std::runtime_error(std::string(__func__) + ": AddHDChainSingle failed");
} }
} }
@ -426,8 +426,8 @@ void LegacyScriptPubKeyMan::GenerateNewHDChain(const SecureString& secureMnemoni
// add default account // add default account
newHdChain.AddAccount(); newHdChain.AddAccount();
if (!SetHDChainSingle(newHdChain, false)) { if (!AddHDChainSingle(newHdChain)) {
throw std::runtime_error(std::string(__func__) + ": SetHDChainSingle failed"); throw std::runtime_error(std::string(__func__) + ": AddHDChainSingle failed");
} }
if (!NewKeyPool()) { if (!NewKeyPool()) {
@ -435,14 +435,24 @@ void LegacyScriptPubKeyMan::GenerateNewHDChain(const SecureString& secureMnemoni
} }
} }
bool LegacyScriptPubKeyMan::SetHDChain(WalletBatch &batch, const CHDChain& chain, bool memonly) bool LegacyScriptPubKeyMan::LoadHDChain(const CHDChain& chain)
{ {
LOCK(cs_KeyStore); LOCK(cs_KeyStore);
if (!SetHDChain(chain)) if (m_storage.HasEncryptionKeys() != chain.IsCrypted()) return false;
m_hd_chain = chain;
return true;
}
bool LegacyScriptPubKeyMan::AddHDChain(WalletBatch &batch, const CHDChain& chain)
{
LOCK(cs_KeyStore);
if (!LoadHDChain(chain))
return false; return false;
if (!memonly) { {
if (chain.IsCrypted() && encrypted_batch) { if (chain.IsCrypted() && encrypted_batch) {
if (!encrypted_batch->WriteHDChain(chain)) if (!encrypted_batch->WriteHDChain(chain))
throw std::runtime_error(std::string(__func__) + ": WriteHDChain failed for encrypted batch"); throw std::runtime_error(std::string(__func__) + ": WriteHDChain failed for encrypted batch");
@ -458,10 +468,10 @@ bool LegacyScriptPubKeyMan::SetHDChain(WalletBatch &batch, const CHDChain& chain
return true; return true;
} }
bool LegacyScriptPubKeyMan::SetHDChainSingle(const CHDChain& chain, bool memonly) bool LegacyScriptPubKeyMan::AddHDChainSingle(const CHDChain& chain)
{ {
WalletBatch batch(m_storage.GetDatabase()); WalletBatch batch(m_storage.GetDatabase());
return SetHDChain(batch, chain, memonly); return AddHDChain(batch, chain);
} }
bool LegacyScriptPubKeyMan::GetDecryptedHDChain(CHDChain& hdChainRet) bool LegacyScriptPubKeyMan::GetDecryptedHDChain(CHDChain& hdChainRet)
@ -1090,16 +1100,6 @@ bool LegacyScriptPubKeyMan::AddWatchOnly(const CScript& dest, int64_t nCreateTim
return AddWatchOnly(dest); return AddWatchOnly(dest);
} }
bool LegacyScriptPubKeyMan::SetHDChain(const CHDChain& chain)
{
LOCK(cs_KeyStore);
if (m_storage.HasEncryptionKeys() != chain.IsCrypted()) return false;
m_hd_chain = chain;
return true;
}
bool LegacyScriptPubKeyMan::HaveHDKey(const CKeyID &address, CHDChain& hdChainCurrent) const bool LegacyScriptPubKeyMan::HaveHDKey(const CKeyID &address, CHDChain& hdChainCurrent) const
{ {
LOCK(cs_KeyStore); LOCK(cs_KeyStore);
@ -1322,8 +1322,8 @@ void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata&
if (!hdChainCurrent.SetAccount(nAccountIndex, acc)) if (!hdChainCurrent.SetAccount(nAccountIndex, acc))
throw std::runtime_error(std::string(__func__) + ": SetAccount failed"); throw std::runtime_error(std::string(__func__) + ": SetAccount failed");
if (!SetHDChain(batch, hdChainCurrent, false)) { if (!AddHDChain(batch, hdChainCurrent)) {
throw std::runtime_error(std::string(__func__) + ": SetHDChain failed"); throw std::runtime_error(std::string(__func__) + ": AddHDChain failed");
} }
if (!AddHDPubKey(batch, childKey.Neuter(), fInternal)) if (!AddHDPubKey(batch, childKey.Neuter(), fInternal))

View File

@ -278,12 +278,8 @@ private:
/** Add a KeyOriginInfo to the wallet */ /** Add a KeyOriginInfo to the wallet */
bool AddKeyOriginWithDB(WalletBatch& batch, const CPubKey& pubkey, const KeyOriginInfo& info); bool AddKeyOriginWithDB(WalletBatch& batch, const CPubKey& pubkey, const KeyOriginInfo& info);
/* Set the HD chain model (chain child index counters) */
bool SetHDChain(WalletBatch &batch, const CHDChain& chain, bool memonly);
bool EncryptHDChain(const CKeyingMaterial& vMasterKeyIn, CHDChain& chain); bool EncryptHDChain(const CKeyingMaterial& vMasterKeyIn, CHDChain& chain);
bool DecryptHDChain(const CKeyingMaterial& vMasterKeyIn, CHDChain& hdChainRet) const; bool DecryptHDChain(const CKeyingMaterial& vMasterKeyIn, CHDChain& hdChainRet) const;
bool SetHDChain(const CHDChain& chain);
/* the HD chain data model (external chain counters) */ /* the HD chain data model (external chain counters) */
CHDChain m_hd_chain GUARDED_BY(cs_KeyStore); CHDChain m_hd_chain GUARDED_BY(cs_KeyStore);
@ -398,11 +394,15 @@ public:
//! Generate a new key //! Generate a new key
CPubKey GenerateNewKey(WalletBatch& batch, uint32_t nAccountIndex, bool fInternal /*= false*/) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); CPubKey GenerateNewKey(WalletBatch& batch, uint32_t nAccountIndex, bool fInternal /*= false*/) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore);
/* Set the HD chain model (chain child index counters) and writes it to the database */
bool AddHDChain(WalletBatch &batch, const CHDChain& chain);
//! Load a HD chain model (used by LoadWallet)
bool LoadHDChain(const CHDChain& chain);
/** /**
* Set the HD chain model (chain child index counters) using temporary wallet db object * Set the HD chain model (chain child index counters) using temporary wallet db object
* which causes db flush every time these methods are used * which causes db flush every time these methods are used
*/ */
bool SetHDChainSingle(const CHDChain& chain, bool memonly); bool AddHDChainSingle(const CHDChain& chain);
//! Adds a watch-only address to the store, without saving it to disk (used by LoadWallet) //! Adds a watch-only address to the store, without saving it to disk (used by LoadWallet)
bool LoadWatchOnly(const CScript &dest); bool LoadWatchOnly(const CScript &dest);

View File

@ -1684,19 +1684,28 @@ bool CWallet::IsWalletFlagSet(uint64_t flag) const
return (m_wallet_flags & flag); return (m_wallet_flags & flag);
} }
bool CWallet::SetWalletFlags(uint64_t overwriteFlags, bool memonly) bool CWallet::LoadWalletFlags(uint64_t flags)
{ {
LOCK(cs_wallet); LOCK(cs_wallet);
m_wallet_flags = overwriteFlags; if (((flags & KNOWN_WALLET_FLAGS) >> 32) ^ (flags >> 32)) {
if (((overwriteFlags & KNOWN_WALLET_FLAGS) >> 32) ^ (overwriteFlags >> 32)) {
// contains unknown non-tolerable wallet flags // contains unknown non-tolerable wallet flags
return false; return false;
} }
if (!memonly && !WalletBatch(GetDatabase()).WriteWalletFlags(m_wallet_flags)) { m_wallet_flags = flags;
return true;
}
bool CWallet::AddWalletFlags(uint64_t flags)
{
LOCK(cs_wallet);
// We should never be writing unknown non-tolerable wallet flags
assert(!(((flags & KNOWN_WALLET_FLAGS) >> 32) ^ (flags >> 32)));
if (!WalletBatch(GetDatabase()).WriteWalletFlags(flags)) {
throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed"); throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed");
} }
return true; return LoadWalletFlags(flags);
} }
int64_t CWalletTx::GetTxTime() const int64_t CWalletTx::GetTxTime() const
@ -4575,7 +4584,8 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain& chain, interfaces::C
if (fFirstRun) if (fFirstRun)
{ {
walletInstance->SetMaxVersion(FEATURE_LATEST); walletInstance->SetMaxVersion(FEATURE_LATEST);
walletInstance->SetWalletFlags(wallet_creation_flags, false);
walletInstance->AddWalletFlags(wallet_creation_flags);
// Only create LegacyScriptPubKeyMan when not descriptor wallet // Only create LegacyScriptPubKeyMan when not descriptor wallet
if (!walletInstance->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) { if (!walletInstance->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) {
@ -4600,8 +4610,8 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain& chain, interfaces::C
} }
LOCK(walletInstance->cs_wallet); LOCK(walletInstance->cs_wallet);
if (auto spk_man = walletInstance->GetLegacyScriptPubKeyMan()) { if (auto spk_man = walletInstance->GetLegacyScriptPubKeyMan()) {
if (!spk_man->SetHDChainSingle(newHdChain, false)) { if (!spk_man->AddHDChainSingle(newHdChain)) {
error = strprintf(_("%s failed"), "SetHDChainSingle"); error = strprintf(_("%s failed"), "AddHDChainSingle");
return nullptr; return nullptr;
} }
} }
@ -5649,12 +5659,21 @@ void CWallet::SetupDescriptorScriptPubKeyMans()
spk_manager->SetupDescriptorGeneration(master_key); spk_manager->SetupDescriptorGeneration(master_key);
uint256 id = spk_manager->GetID(); uint256 id = spk_manager->GetID();
m_spk_managers[id] = std::move(spk_manager); m_spk_managers[id] = std::move(spk_manager);
SetActiveScriptPubKeyMan(id, internal); AddActiveScriptPubKeyMan(id, internal);
} }
} }
} }
void CWallet::SetActiveScriptPubKeyMan(uint256 id, bool internal, bool memonly) void CWallet::AddActiveScriptPubKeyMan(uint256 id, bool internal)
{
WalletBatch batch(GetDatabase());
if (!batch.WriteActiveScriptPubKeyMan(id, internal)) {
throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed");
}
LoadActiveScriptPubKeyMan(id, internal);
}
void CWallet::LoadActiveScriptPubKeyMan(uint256 id, bool internal)
{ {
WalletLogPrintf("Setting spkMan to active: id = %s, type = %d, internal = %d\n", id.ToString(), static_cast<int>(OutputType::LEGACY), static_cast<int>(internal)); WalletLogPrintf("Setting spkMan to active: id = %s, type = %d, internal = %d\n", id.ToString(), static_cast<int>(OutputType::LEGACY), static_cast<int>(internal));
auto& spk_mans = internal ? m_internal_spk_managers : m_external_spk_managers; auto& spk_mans = internal ? m_internal_spk_managers : m_external_spk_managers;
@ -5662,12 +5681,6 @@ void CWallet::SetActiveScriptPubKeyMan(uint256 id, bool internal, bool memonly)
spk_man->SetInternal(internal); spk_man->SetInternal(internal);
spk_mans = spk_man; spk_mans = spk_man;
if (!memonly) {
WalletBatch batch(GetDatabase());
if (!batch.WriteActiveScriptPubKeyMan(id, internal)) {
throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed");
}
}
NotifyCanGetAddressesChanged(); NotifyCanGetAddressesChanged();
} }

View File

@ -1334,7 +1334,9 @@ public:
/** overwrite all flags by the given uint64_t /** overwrite all flags by the given uint64_t
returns false if unknown, non-tolerable flags are present */ returns false if unknown, non-tolerable flags are present */
bool SetWalletFlags(uint64_t overwriteFlags, bool memOnly); bool AddWalletFlags(uint64_t flags);
/** Loads the flags into the wallet. (used by LoadWallet) */
bool LoadWalletFlags(uint64_t flags);
/** Determine if we are a legacy wallet */ /** Determine if we are a legacy wallet */
bool IsLegacy() const; bool IsLegacy() const;
@ -1415,12 +1417,15 @@ public:
//! Instantiate a descriptor ScriptPubKeyMan from the WalletDescriptor and load it //! Instantiate a descriptor ScriptPubKeyMan from the WalletDescriptor and load it
void LoadDescriptorScriptPubKeyMan(uint256 id, WalletDescriptor& desc); void LoadDescriptorScriptPubKeyMan(uint256 id, WalletDescriptor& desc);
//! Sets the active ScriptPubKeyMan for the specified type and internal //! Adds the active ScriptPubKeyMan for the specified type and internal. Writes it to the wallet file
//! @param[in] id The unique id for the ScriptPubKeyMan //! @param[in] id The unique id for the ScriptPubKeyMan
//! @param[in] type The OutputType this ScriptPubKeyMan provides addresses for
//! @param[in] internal Whether this ScriptPubKeyMan provides change addresses //! @param[in] internal Whether this ScriptPubKeyMan provides change addresses
//! @param[in] memonly Whether to record this update to the database. Set to true for wallet loading, normally false when actually updating the wallet. void AddActiveScriptPubKeyMan(uint256 id, bool internal);
void SetActiveScriptPubKeyMan(uint256 id, bool internal, bool memonly = false);
//! Loads an active ScriptPubKeyMan for the specified type and internal. (used by LoadWallet)
//! @param[in] id The unique id for the ScriptPubKeyMan
//! @param[in] internal Whether this ScriptPubKeyMan provides change addresses
void LoadActiveScriptPubKeyMan(uint256 id, bool internal);
//! Create new DescriptorScriptPubKeyMans and add them to the wallet //! Create new DescriptorScriptPubKeyMans and add them to the wallet
void SetupDescriptorScriptPubKeyMans(); void SetupDescriptorScriptPubKeyMans();

View File

@ -516,7 +516,7 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue,
CHDChain chain; CHDChain chain;
ssValue >> chain; ssValue >> chain;
assert ((strType == DBKeys::CRYPTED_HDCHAIN) == chain.IsCrypted()); assert ((strType == DBKeys::CRYPTED_HDCHAIN) == chain.IsCrypted());
if (!pwallet->GetOrCreateLegacyScriptPubKeyMan()->SetHDChainSingle(chain, true)) if (!pwallet->GetOrCreateLegacyScriptPubKeyMan()->LoadHDChain(chain))
{ {
strErr = "Error reading wallet database: SetHDChain failed"; strErr = "Error reading wallet database: SetHDChain failed";
return false; return false;
@ -557,7 +557,7 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue,
} else if (strType == DBKeys::FLAGS) { } else if (strType == DBKeys::FLAGS) {
uint64_t flags; uint64_t flags;
ssValue >> flags; ssValue >> flags;
if (!pwallet->SetWalletFlags(flags, true)) { if (!pwallet->LoadWalletFlags(flags)) {
strErr = "Error reading wallet database: Unknown non-tolerable wallet flags found"; strErr = "Error reading wallet database: Unknown non-tolerable wallet flags found";
return false; return false;
} }
@ -768,10 +768,10 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet)
// Set the active ScriptPubKeyMans // Set the active ScriptPubKeyMans
for (auto spk_man : wss.m_active_external_spks) { for (auto spk_man : wss.m_active_external_spks) {
pwallet->SetActiveScriptPubKeyMan(spk_man.second, /* internal */ false, /* memonly */ true); pwallet->LoadActiveScriptPubKeyMan(spk_man.second, /* internal */ false);
} }
for (auto spk_man : wss.m_active_internal_spks) { for (auto spk_man : wss.m_active_internal_spks) {
pwallet->SetActiveScriptPubKeyMan(spk_man.second, /* internal */ true, /* memonly */ true); pwallet->LoadActiveScriptPubKeyMan(spk_man.second, /* internal */ true);
} }
// Set the descriptor caches // Set the descriptor caches