diff --git a/src/interfaces/wallet.h b/src/interfaces/wallet.h index 3b068f16be..32fb6b28e8 100644 --- a/src/interfaces/wallet.h +++ b/src/interfaces/wallet.h @@ -346,6 +346,9 @@ public: //! Return default wallet directory. virtual std::string getWalletDir() = 0; + //! Restore backup wallet + virtual std::unique_ptr restoreWallet(const std::string& backup_file, const std::string& wallet_name, bilingual_str& error, std::vector& warnings) = 0; + //! Return available wallets in wallet directory. virtual std::vector listWalletDir() = 0; diff --git a/src/rpc/protocol.h b/src/rpc/protocol.h index 6a567a9ae1..dc7846586e 100644 --- a/src/rpc/protocol.h +++ b/src/rpc/protocol.h @@ -81,6 +81,7 @@ enum RPCErrorCode RPC_WALLET_NOT_FOUND = -18, //!< Invalid wallet specified RPC_WALLET_NOT_SPECIFIED = -19, //!< No wallet specified (error when there are multiple wallets loaded) RPC_WALLET_ALREADY_LOADED = -35, //!< This same wallet is already loaded + RPC_WALLET_ALREADY_EXISTS = -36, //!< There is already a wallet with the same name //! Backwards compatible aliases diff --git a/src/wallet/db.h b/src/wallet/db.h index 5530c6c8cf..aed850241d 100644 --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -222,6 +222,7 @@ enum class DatabaseStatus { FAILED_LOAD, FAILED_VERIFY, FAILED_ENCRYPT, + FAILED_INVALID_BACKUP_FILE, }; /** Recursively list database paths in directory. */ diff --git a/src/wallet/interfaces.cpp b/src/wallet/interfaces.cpp index ac422a870d..50f446863e 100644 --- a/src/wallet/interfaces.cpp +++ b/src/wallet/interfaces.cpp @@ -607,6 +607,12 @@ public: assert(m_context.m_coinjoin_loader); return MakeWallet(LoadWallet(*m_context.chain, *m_context.m_coinjoin_loader, name, true /* load_on_start */, options, status, error, warnings)); } + std::unique_ptr restoreWallet(const std::string& backup_file, const std::string& wallet_name, bilingual_str& error, std::vector& warnings) override + { + DatabaseStatus status; + assert(m_context.m_coinjoin_loader); + return MakeWallet(RestoreWallet(*m_context.chain, *m_context.m_coinjoin_loader, backup_file, wallet_name, /*load_on_start=*/true, status, error, warnings)); + } std::string getWalletDir() override { return GetWalletDir().string(); diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index ec3bb4901f..6efcbdc2d8 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -2716,16 +2716,8 @@ static RPCHelpMan listwallets() }; } -static std::tuple, std::vector> LoadWalletHelper(WalletContext& context, UniValue load_on_start_param, const std::string wallet_name) +void HandleWalletError(const std::shared_ptr wallet, DatabaseStatus& status, bilingual_str& error) { - DatabaseOptions options; - DatabaseStatus status; - options.require_existing = true; - bilingual_str error; - std::vector warnings; - std::optional load_on_start = load_on_start_param.isNull() ? std::nullopt : std::optional(load_on_start_param.get_bool()); - std::shared_ptr const wallet = LoadWallet(*context.chain, *context.m_coinjoin_loader, wallet_name, load_on_start, options, status, error, warnings); - if (!wallet) { // Map bad format to not found, since bad format is returned when the // wallet directory exists, but doesn't contain a data file. @@ -2738,13 +2730,17 @@ static std::tuple, std::vector> LoadWall case DatabaseStatus::FAILED_ALREADY_LOADED: code = RPC_WALLET_ALREADY_LOADED; break; + case DatabaseStatus::FAILED_ALREADY_EXISTS: + code = RPC_WALLET_ALREADY_EXISTS; + break; + case DatabaseStatus::FAILED_INVALID_BACKUP_FILE: + code = RPC_INVALID_PARAMETER; + break; default: // RPC_WALLET_ERROR is returned for all other cases. break; } throw JSONRPCError(code, error.original); } - - return { wallet, warnings }; } static RPCHelpMan upgradetohd() @@ -2872,7 +2868,15 @@ static RPCHelpMan loadwallet() WalletContext& context = EnsureWalletContext(request.context); const std::string name(request.params[0].get_str()); - auto [wallet, warnings] = LoadWalletHelper(context, request.params[1], name); + DatabaseOptions options; + DatabaseStatus status; + options.require_existing = true; + bilingual_str error; + std::vector warnings; + std::optional load_on_start = request.params[1].isNull() ? std::nullopt : std::optional(request.params[1].get_bool()); + std::shared_ptr const wallet = LoadWallet(*context.chain, *context.m_coinjoin_loader, name, load_on_start, options, status, error, warnings); + + HandleWalletError(wallet, status, error); UniValue obj(UniValue::VOBJ); obj.pushKV("name", wallet->GetName()); @@ -3072,27 +3076,17 @@ static RPCHelpMan restorewallet() std::string backup_file = request.params[1].get_str(); - if (!fs::exists(backup_file)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Backup file does not exist"); - } - std::string wallet_name = request.params[0].get_str(); - const fs::path wallet_path = fsbridge::AbsPathJoin(GetWalletDir(), wallet_name); + std::optional load_on_start = request.params[2].isNull() ? std::nullopt : std::optional(request.params[2].get_bool()); - if (fs::exists(wallet_path)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Wallet name already exists."); - } + DatabaseStatus status; + bilingual_str error; + std::vector warnings; - if (!TryCreateDirectories(wallet_path)) { - throw JSONRPCError(RPC_WALLET_ERROR, strprintf("Failed to create database path '%s'. Database already exists.", wallet_path.string())); - } + const std::shared_ptr wallet = RestoreWallet(*context.chain, *context.m_coinjoin_loader, backup_file, wallet_name, load_on_start, status, error, warnings); - auto wallet_file = wallet_path / "wallet.dat"; - - fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists); - - auto [wallet, warnings] = LoadWalletHelper(context, request.params[2], wallet_name); + HandleWalletError(wallet, status, error); UniValue obj(UniValue::VOBJ); obj.pushKV("name", wallet->GetName()); diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 0506193232..c3e9aa36bd 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -365,6 +365,38 @@ std::shared_ptr CreateWallet(interfaces::Chain& chain, interfaces::Coin return wallet; } +std::shared_ptr RestoreWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +{ + DatabaseOptions options; + options.require_existing = true; + + if (!fs::exists(backup_file)) { + error = Untranslated("Backup file does not exist"); + status = DatabaseStatus::FAILED_INVALID_BACKUP_FILE; + return nullptr; + } + + const fs::path wallet_path = fsbridge::AbsPathJoin(GetWalletDir(), wallet_name); + + if (fs::exists(wallet_path) || !TryCreateDirectories(wallet_path)) { + error = Untranslated(strprintf("Failed to create database path '%s'. Database already exists.", wallet_path.string())); + status = DatabaseStatus::FAILED_ALREADY_EXISTS; + return nullptr; + } + + auto wallet_file = wallet_path / "wallet.dat"; + fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists); + + auto wallet = LoadWallet(chain, coinjoin_loader, wallet_name, load_on_start, options, status, error, warnings); + + if (!wallet) { + fs::remove(wallet_file); + fs::remove(wallet_path); + } + + return wallet; +} + /** @defgroup mapWallet * * @{ diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 4afa5e1623..35d00e4bae 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -62,6 +62,7 @@ std::vector> GetWallets(); std::shared_ptr GetWallet(const std::string& name); std::shared_ptr LoadWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); std::shared_ptr CreateWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::shared_ptr RestoreWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); std::unique_ptr HandleLoadWallet(LoadWalletFn load_wallet); std::unique_ptr MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error); diff --git a/test/functional/wallet_backup.py b/test/functional/wallet_backup.py index 35b3a25921..cdb772b9ca 100755 --- a/test/functional/wallet_backup.py +++ b/test/functional/wallet_backup.py @@ -107,17 +107,32 @@ class WalletBackupTest(BitcoinTestFramework): os.remove(os.path.join(self.nodes[1].datadir, self.chain, 'wallets', self.default_wallet_name, self.wallet_data_filename)) os.remove(os.path.join(self.nodes[2].datadir, self.chain, 'wallets', self.default_wallet_name, self.wallet_data_filename)) + def restore_invalid_wallet(self): + node = self.nodes[3] + invalid_wallet_file = os.path.join(self.nodes[0].datadir, 'invalid_wallet_file.bak') + open(invalid_wallet_file, 'a', encoding="utf8").write('invald wallet') + wallet_name = "res0" + not_created_wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name) + error_message = "Wallet file verification failed. Failed to load database path '{}'. Data is not in recognized format.".format(not_created_wallet_file) + assert_raises_rpc_error(-18, error_message, node.restorewallet, wallet_name, invalid_wallet_file) + assert not os.path.exists(not_created_wallet_file) + def restore_nonexistent_wallet(self): node = self.nodes[3] nonexistent_wallet_file = os.path.join(self.nodes[0].datadir, 'nonexistent_wallet.bak') wallet_name = "res0" assert_raises_rpc_error(-8, "Backup file does not exist", node.restorewallet, wallet_name, nonexistent_wallet_file) + not_created_wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name) + assert not os.path.exists(not_created_wallet_file) def restore_wallet_existent_name(self): node = self.nodes[3] - wallet_file = os.path.join(self.nodes[0].datadir, 'wallet.bak') + backup_file = os.path.join(self.nodes[0].datadir, 'wallet.bak') wallet_name = "res0" - assert_raises_rpc_error(-8, "Wallet name already exists.", node.restorewallet, wallet_name, wallet_file) + wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name) + error_message = "Failed to create database path '{}'. Database already exists.".format(wallet_file) + assert_raises_rpc_error(-36, error_message, node.restorewallet, wallet_name, backup_file) + assert os.path.exists(wallet_file) def init_three(self): self.init_wallet(0) @@ -179,6 +194,7 @@ class WalletBackupTest(BitcoinTestFramework): ## self.log.info("Restoring wallets on node 3 using backup files") + self.restore_invalid_wallet() self.restore_nonexistent_wallet() backup_file_0 = os.path.join(self.nodes[0].datadir, 'wallet.bak') @@ -189,6 +205,10 @@ class WalletBackupTest(BitcoinTestFramework): self.nodes[3].restorewallet("res1", backup_file_1) self.nodes[3].restorewallet("res2", backup_file_2) + assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res0")) + assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res1")) + assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res2")) + res0_rpc = self.nodes[3].get_wallet_rpc("res0") res1_rpc = self.nodes[3].get_wallet_rpc("res1") res2_rpc = self.nodes[3].get_wallet_rpc("res2")