diff --git a/src/wallet/scriptpubkeyman.cpp b/src/wallet/scriptpubkeyman.cpp index 9d84a04b6a..52d07ff7fa 100644 --- a/src/wallet/scriptpubkeyman.cpp +++ b/src/wallet/scriptpubkeyman.cpp @@ -342,8 +342,11 @@ void LegacyScriptPubKeyMan::UpgradeKeyMetadata() CHDChain hdChainCurrent; if (!GetHDChain(hdChainCurrent)) throw std::runtime_error(std::string(__func__) + ": GetHDChain failed"); - if (!DecryptHDChain(m_storage.GetEncryptionKey(), hdChainCurrent)) + if (!m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return DecryptHDChain(encryption_key, hdChainCurrent); + })) { throw std::runtime_error(std::string(__func__) + ": DecryptHDChain failed"); + } CExtKey masterKey; SecureVector vchSeed = hdChainCurrent.GetSeed(); @@ -474,8 +477,11 @@ bool LegacyScriptPubKeyMan::GetDecryptedHDChain(CHDChain& hdChainRet) return false; } - if (!DecryptHDChain(m_storage.GetEncryptionKey(), hdChainTmp)) + if (!m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return DecryptHDChain(encryption_key, hdChainTmp); + })) { return false; + } // make sure seed matches this chain if (hdChainTmp.GetID() != hdChainTmp.GetSeedHash()) @@ -911,7 +917,9 @@ bool LegacyScriptPubKeyMan::AddKeyPubKeyInner(const CKey& key, const CPubKey &pu std::vector vchCryptedSecret; CKeyingMaterial vchSecret(key.begin(), key.end()); - if (!EncryptSecret(m_storage.GetEncryptionKey(), vchSecret, pubkey.GetHash(), vchCryptedSecret)) { + if (!m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return EncryptSecret(encryption_key, vchSecret, pubkey.GetHash(), vchCryptedSecret); + })) { return false; } @@ -933,7 +941,9 @@ bool LegacyScriptPubKeyMan::GetKeyInner(const CKeyID &address, CKey& keyOut) con { const CPubKey &vchPubKey = (*mi).second.first; const std::vector &vchCryptedSecret = (*mi).second.second; - return DecryptKey(m_storage.GetEncryptionKey(), vchCryptedSecret, vchPubKey, keyOut); + return m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return DecryptKey(encryption_key, vchCryptedSecret, vchPubKey, keyOut); + }); } return false; } @@ -1164,8 +1174,11 @@ bool LegacyScriptPubKeyMan::GetKey(const CKeyID &address, CKey& keyOut) const CHDChain hdChainCurrent; if (!GetHDChain(hdChainCurrent)) throw std::runtime_error(std::string(__func__) + ": GetHDChain failed"); - if (!DecryptHDChain(m_storage.GetEncryptionKey(), hdChainCurrent)) + if (!m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return DecryptHDChain(encryption_key, hdChainCurrent); + })) { throw std::runtime_error(std::string(__func__) + ": DecryptHDChain failed"); + } // make sure seed matches this chain if (hdChainCurrent.GetID() != hdChainCurrent.GetSeedHash()) throw std::runtime_error(std::string(__func__) + ": Wrong HD chain!"); @@ -1276,8 +1289,11 @@ void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata& throw std::runtime_error(std::string(__func__) + ": GetHDChain failed"); } - if (!DecryptHDChain(m_storage.GetEncryptionKey(), hdChainTmp)) + if (!m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return DecryptHDChain(encryption_key, hdChainTmp); + })) { throw std::runtime_error(std::string(__func__) + ": DecryptHDChain failed"); + } // make sure seed matches this chain if (hdChainTmp.GetID() != hdChainTmp.GetSeedHash()) throw std::runtime_error(std::string(__func__) + ": Wrong HD chain!"); @@ -1897,7 +1913,9 @@ std::map DescriptorScriptPubKeyMan::GetKeys() const const CPubKey& pubkey = key_pair.second.first; const std::vector& crypted_secret = key_pair.second.second; CKey key; - DecryptKey(m_storage.GetEncryptionKey(), crypted_secret, pubkey, key); + m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return DecryptKey(encryption_key, crypted_secret, pubkey, key); + }); keys[pubkey.GetID()] = key; } return keys; @@ -2010,7 +2028,9 @@ bool DescriptorScriptPubKeyMan::AddDescriptorKeyWithDB(WalletBatch& batch, const std::vector crypted_secret; CKeyingMaterial secret(key.begin(), key.end()); - if (!EncryptSecret(m_storage.GetEncryptionKey(), secret, pubkey.GetHash(), crypted_secret)) { + if (!m_storage.WithEncryptionKey([&](const CKeyingMaterial& encryption_key) { + return EncryptSecret(encryption_key, secret, pubkey.GetHash(), crypted_secret); + })) { return false; } diff --git a/src/wallet/scriptpubkeyman.h b/src/wallet/scriptpubkeyman.h index 0d6450d671..05f2452d2f 100644 --- a/src/wallet/scriptpubkeyman.h +++ b/src/wallet/scriptpubkeyman.h @@ -21,6 +21,7 @@ #include +#include #include // Wallet storage things that ScriptPubKeyMans need in order to be able to store things to the wallet database. @@ -38,7 +39,8 @@ public: virtual void UnsetBlankWalletFlag(WalletBatch&) = 0; virtual bool CanSupportFeature(enum WalletFeature) const = 0; virtual void SetMinVersion(enum WalletFeature, WalletBatch* = nullptr) = 0; - virtual const CKeyingMaterial& GetEncryptionKey() const = 0; + //! Pass the encryption key to cb(). + virtual bool WithEncryptionKey(std::function cb) const = 0; virtual bool HasEncryptionKeys() const = 0; virtual bool IsLocked(bool fForMixing) const = 0; diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index b6ac4595a8..00d872582a 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -5743,9 +5743,10 @@ void CWallet::SetupLegacyScriptPubKeyMan() m_spk_managers[spk_manager->GetID()] = std::move(spk_manager); } -const CKeyingMaterial& CWallet::GetEncryptionKey() const +bool CWallet::WithEncryptionKey(std::function cb) const { - return vMasterKey; + LOCK(cs_wallet); + return cb(vMasterKey); } bool CWallet::HasEncryptionKeys() const diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 411367ac2b..72c4c747ba 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -1468,7 +1468,8 @@ public: //! Make a LegacyScriptPubKeyMan and set it for all types, internal, and external. void SetupLegacyScriptPubKeyMan(); - const CKeyingMaterial& GetEncryptionKey() const override; + bool WithEncryptionKey(std::function cb) const override; + bool HasEncryptionKeys() const override; /** Get last block processed height */