diff --git a/src/bench/gcs_filter.cpp b/src/bench/gcs_filter.cpp index b26698ed5b..fa7ba0eb26 100644 --- a/src/bench/gcs_filter.cpp +++ b/src/bench/gcs_filter.cpp @@ -17,7 +17,7 @@ static void ConstructGCSFilter(benchmark::Bench& bench) uint64_t siphash_k0 = 0; bench.batch(elements.size()).unit("elem").run([&] { - GCSFilter filter(siphash_k0, 0, 20, 1 << 20, elements); + GCSFilter filter({siphash_k0, 0, 20, 1 << 20}, elements); siphash_k0++; }); } @@ -31,7 +31,7 @@ static void MatchGCSFilter(benchmark::Bench& bench) element[1] = static_cast(i >> 8); elements.insert(std::move(element)); } - GCSFilter filter(0, 0, 20, 1 << 20, elements); + GCSFilter filter({0, 0, 20, 1 << 20}, elements); bench.unit("elem").run([&] { filter.Match(GCSFilter::Element()); diff --git a/src/blockfilter.cpp b/src/blockfilter.cpp index 3669581023..5b2d4f07da 100644 --- a/src/blockfilter.cpp +++ b/src/blockfilter.cpp @@ -79,7 +79,7 @@ static uint64_t MapIntoRange(uint64_t x, uint64_t n) uint64_t GCSFilter::HashToRange(const Element& element) const { - uint64_t hash = CSipHasher(m_siphash_k0, m_siphash_k1) + uint64_t hash = CSipHasher(m_params.m_siphash_k0, m_params.m_siphash_k1) .Write(element.data(), element.size()) .Finalize(); return MapIntoRange(hash, m_F); @@ -96,16 +96,13 @@ std::vector GCSFilter::BuildHashedSet(const ElementSet& elements) cons return hashed_elements; } -GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M) - : m_siphash_k0(siphash_k0), m_siphash_k1(siphash_k1), m_P(P), m_M(M), m_N(0), m_F(0) +GCSFilter::GCSFilter(const Params& params) + : m_params(params), m_N(0), m_F(0), m_encoded{0} {} -GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M, - std::vector encoded_filter) - : GCSFilter(siphash_k0, siphash_k1, P, M) +GCSFilter::GCSFilter(const Params& params, std::vector encoded_filter) + : m_params(params), m_encoded(std::move(encoded_filter)) { - m_encoded = std::move(encoded_filter); - VectorReader stream(GCS_SER_TYPE, GCS_SER_VERSION, m_encoded, 0); uint64_t N = ReadCompactSize(stream); @@ -113,29 +110,28 @@ GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32 if (m_N != N) { throw std::ios_base::failure("N must be <2^32"); } - m_F = static_cast(m_N) * static_cast(m_M); + m_F = static_cast(m_N) * static_cast(m_params.m_M); // Verify that the encoded filter contains exactly N elements. If it has too much or too little // data, a std::ios_base::failure exception will be raised. BitStreamReader bitreader(stream); for (uint64_t i = 0; i < m_N; ++i) { - GolombRiceDecode(bitreader, m_P); + GolombRiceDecode(bitreader, m_params.m_P); } if (!stream.empty()) { throw std::ios_base::failure("encoded_filter contains excess data"); } } -GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M, - const ElementSet& elements) - : GCSFilter(siphash_k0, siphash_k1, P, M) +GCSFilter::GCSFilter(const Params& params, const ElementSet& elements) + : m_params(params) { size_t N = elements.size(); m_N = static_cast(N); if (m_N != N) { throw std::invalid_argument("N must be <2^32"); } - m_F = static_cast(m_N) * static_cast(m_M); + m_F = static_cast(m_N) * static_cast(m_params.m_M); CVectorWriter stream(GCS_SER_TYPE, GCS_SER_VERSION, m_encoded, 0); @@ -150,7 +146,7 @@ GCSFilter::GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32 uint64_t last_value = 0; for (uint64_t value : BuildHashedSet(elements)) { uint64_t delta = value - last_value; - GolombRiceEncode(bitwriter, m_P, delta); + GolombRiceEncode(bitwriter, m_params.m_P, delta); last_value = value; } @@ -170,7 +166,7 @@ bool GCSFilter::MatchInternal(const uint64_t* element_hashes, size_t size) const uint64_t value = 0; size_t hashes_index = 0; for (uint32_t i = 0; i < m_N; ++i) { - uint64_t delta = GolombRiceDecode(bitreader, m_P); + uint64_t delta = GolombRiceDecode(bitreader, m_params.m_P); value += delta; while (true) { @@ -225,19 +221,39 @@ static GCSFilter::ElementSet BasicFilterElements(const CBlock& block, return elements; } +BlockFilter::BlockFilter(BlockFilterType filter_type, const uint256& block_hash, + std::vector filter) + : m_filter_type(filter_type), m_block_hash(block_hash) +{ + GCSFilter::Params params; + if (!BuildParams(params)) { + throw std::invalid_argument("unknown filter_type"); + } + m_filter = GCSFilter(params, std::move(filter)); +} + BlockFilter::BlockFilter(BlockFilterType filter_type, const CBlock& block, const CBlockUndo& block_undo) : m_filter_type(filter_type), m_block_hash(block.GetHash()) { - switch (m_filter_type) { - case BlockFilterType::BASIC_FILTER: - m_filter = GCSFilter(m_block_hash.GetUint64(0), m_block_hash.GetUint64(1), - BASIC_FILTER_P, BASIC_FILTER_M, - BasicFilterElements(block, block_undo)); - break; - - default: + GCSFilter::Params params; + if (!BuildParams(params)) { throw std::invalid_argument("unknown filter_type"); } + m_filter = GCSFilter(params, BasicFilterElements(block, block_undo)); +} + +bool BlockFilter::BuildParams(GCSFilter::Params& params) const +{ + switch (m_filter_type) { + case BlockFilterType::BASIC_FILTER: + params.m_siphash_k0 = m_block_hash.GetUint64(0); + params.m_siphash_k1 = m_block_hash.GetUint64(1); + params.m_P = BASIC_FILTER_P; + params.m_M = BASIC_FILTER_M; + return true; + } + + return false; } uint256 BlockFilter::GetHash() const diff --git a/src/blockfilter.h b/src/blockfilter.h index 667530738a..948f30682a 100644 --- a/src/blockfilter.h +++ b/src/blockfilter.h @@ -25,11 +25,20 @@ public: typedef std::vector Element; typedef std::unordered_set ElementSet; + struct Params + { + uint64_t m_siphash_k0; + uint64_t m_siphash_k1; + uint8_t m_P; //!< Golomb-Rice coding parameter + uint32_t m_M; //!< Inverse false positive rate + + Params(uint64_t siphash_k0 = 0, uint64_t siphash_k1 = 0, uint8_t P = 0, uint32_t M = 1) + : m_siphash_k0(siphash_k0), m_siphash_k1(siphash_k1), m_P(P), m_M(M) + {} + }; + private: - uint64_t m_siphash_k0; - uint64_t m_siphash_k1; - uint8_t m_P; //!< Golomb-Rice coding parameter - uint32_t m_M; //!< Inverse false positive rate + Params m_params; uint32_t m_N; //!< Number of elements in the filter uint64_t m_F; //!< Range of element hashes, F = N * M std::vector m_encoded; @@ -45,19 +54,16 @@ private: public: /** Constructs an empty filter. */ - GCSFilter(uint64_t siphash_k0 = 0, uint64_t siphash_k1 = 0, uint8_t P = 0, uint32_t M = 0); + explicit GCSFilter(const Params& params = Params()); /** Reconstructs an already-created filter from an encoding. */ - GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M, - std::vector encoded_filter); + GCSFilter(const Params& params, std::vector encoded_filter); /** Builds a new filter from the params and set of elements. */ - GCSFilter(uint64_t siphash_k0, uint64_t siphash_k1, uint8_t P, uint32_t M, - const ElementSet& elements); + GCSFilter(const Params& params, const ElementSet& elements); - uint8_t GetP() const { return m_P; } uint32_t GetN() const { return m_N; } - uint32_t GetM() const { return m_M; } + const Params& GetParams() const { return m_params; } const std::vector& GetEncoded() const { return m_encoded; } /** @@ -93,13 +99,21 @@ private: uint256 m_block_hash; GCSFilter m_filter; + bool BuildParams(GCSFilter::Params& params) const; + public: - // Construct a new BlockFilter of the specified type from a block. + BlockFilter() = default; + + //! Reconstruct a BlockFilter from parts. + BlockFilter(BlockFilterType filter_type, const uint256& block_hash, + std::vector filter); + + //! Construct a new BlockFilter of the specified type from a block. BlockFilter(BlockFilterType filter_type, const CBlock& block, const CBlockUndo& block_undo); BlockFilterType GetFilterType() const { return m_filter_type; } - + const uint256& GetBlockHash() const { return m_block_hash; } const GCSFilter& GetFilter() const { return m_filter; } const std::vector& GetEncodedFilter() const @@ -107,10 +121,10 @@ public: return m_filter.GetEncoded(); } - // Compute the filter hash. + //! Compute the filter hash. uint256 GetHash() const; - // Compute the filter header given the previous one. + //! Compute the filter header given the previous one. uint256 ComputeHeader(const uint256& prev_header) const; template @@ -131,15 +145,11 @@ public: m_filter_type = static_cast(filter_type); - switch (m_filter_type) { - case BlockFilterType::BASIC_FILTER: - m_filter = GCSFilter(m_block_hash.GetUint64(0), m_block_hash.GetUint64(1), - BASIC_FILTER_P, BASIC_FILTER_M, std::move(encoded_filter)); - break; - - default: + GCSFilter::Params params; + if (!BuildParams(params)) { throw std::ios_base::failure("unknown filter_type"); } + m_filter = GCSFilter(params, std::move(encoded_filter)); } }; diff --git a/src/test/blockfilter_tests.cpp b/src/test/blockfilter_tests.cpp index f888593989..904f7b43b4 100644 --- a/src/test/blockfilter_tests.cpp +++ b/src/test/blockfilter_tests.cpp @@ -29,7 +29,7 @@ BOOST_AUTO_TEST_CASE(gcsfilter_test) excluded_elements.insert(std::move(element2)); } - GCSFilter filter(0, 0, 10, 1 << 10, included_elements); + GCSFilter filter({0, 0, 10, 1 << 10}, included_elements); for (const auto& element : included_elements) { BOOST_CHECK(filter.Match(element)); @@ -39,6 +39,19 @@ BOOST_AUTO_TEST_CASE(gcsfilter_test) } } +BOOST_AUTO_TEST_CASE(gcsfilter_default_constructor) +{ + GCSFilter filter; + BOOST_CHECK_EQUAL(filter.GetN(), 0); + BOOST_CHECK_EQUAL(filter.GetEncoded().size(), 1); + + const GCSFilter::Params& params = filter.GetParams(); + BOOST_CHECK_EQUAL(params.m_siphash_k0, 0); + BOOST_CHECK_EQUAL(params.m_siphash_k1, 0); + BOOST_CHECK_EQUAL(params.m_P, 0); + BOOST_CHECK_EQUAL(params.m_M, 1); +} + BOOST_AUTO_TEST_CASE(blockfilter_basic_test) { CScript included_scripts[5], excluded_scripts[3]; @@ -88,6 +101,17 @@ BOOST_AUTO_TEST_CASE(blockfilter_basic_test) for (const CScript& script : excluded_scripts) { BOOST_CHECK(!filter.Match(GCSFilter::Element(script.begin(), script.end()))); } + + // Test serialization/unserialization. + BlockFilter block_filter2; + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << block_filter; + stream >> block_filter2; + + BOOST_CHECK_EQUAL(block_filter.GetFilterType(), block_filter2.GetFilterType()); + BOOST_CHECK_EQUAL(block_filter.GetBlockHash(), block_filter2.GetBlockHash()); + BOOST_CHECK(block_filter.GetEncodedFilter() == block_filter2.GetEncodedFilter()); } BOOST_AUTO_TEST_CASE(blockfilters_json_test)