diff --git a/src/llmq/quorums_signing_shares.cpp b/src/llmq/quorums_signing_shares.cpp index 06170ce91..06ec554f6 100644 --- a/src/llmq/quorums_signing_shares.cpp +++ b/src/llmq/quorums_signing_shares.cpp @@ -93,21 +93,73 @@ CSigSharesInv CBatchedSigShares::ToInv(Consensus::LLMQType llmqType) const return inv; } -CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSession(Consensus::LLMQType llmqType, const uint256& signHash) +template +static void InitSession(CSigSharesNodeState::Session& s, const uint256& signHash, T& from) { - auto& s = sessions[signHash]; + s.llmqType = (Consensus::LLMQType)from.llmqType; + s.quorumHash = from.quorumHash; + s.id = from.id; + s.msgHash = from.msgHash; + s.signHash = signHash; + s.announced.Init((Consensus::LLMQType)from.llmqType); + s.requested.Init((Consensus::LLMQType)from.llmqType); + s.knows.Init((Consensus::LLMQType)from.llmqType); +} + +CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSessionFromShare(const llmq::CSigShare& sigShare) +{ + auto& s = sessions[sigShare.GetSignHash()]; if (s.announced.inv.empty()) { - s.announced.Init(llmqType, signHash); - s.requested.Init(llmqType, signHash); - s.knows.Init(llmqType, signHash); - } else { - assert(s.announced.llmqType == llmqType); - assert(s.requested.llmqType == llmqType); - assert(s.knows.llmqType == llmqType); + InitSession(s, sigShare.GetSignHash(), sigShare); } return s; } +CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSessionFromAnn(const llmq::CSigSesAnn& ann) +{ + auto signHash = CLLMQUtils::BuildSignHash((Consensus::LLMQType)ann.llmqType, ann.quorumHash, ann.id, ann.msgHash); + auto& s = sessions[signHash]; + if (s.announced.inv.empty()) { + InitSession(s, signHash, ann); + } + return s; +} + +CSigSharesNodeState::Session* CSigSharesNodeState::GetSessionBySignHash(const uint256& signHash) +{ + auto it = sessions.find(signHash); + if (it == sessions.end()) { + return nullptr; + } + return &it->second; +} + +CSigSharesNodeState::Session* CSigSharesNodeState::GetSessionByRecvId(uint32_t sessionId) +{ + auto it = sessionByRecvId.find(sessionId); + if (it == sessionByRecvId.end()) { + return nullptr; + } + return it->second; +} + +bool CSigSharesNodeState::GetSessionInfoByRecvId(uint32_t sessionId, SessionInfo& retInfo) +{ + auto s = GetSessionByRecvId(sessionId); + if (!s) { + return false; + } + retInfo.recvSessionId = sessionId; + retInfo.llmqType = s->llmqType; + retInfo.quorumHash = s->quorumHash; + retInfo.id = s->id; + retInfo.msgHash = s->msgHash; + retInfo.signHash = s->signHash; + retInfo.quorum = s->quorum; + + return true; +} + void CSigSharesNodeState::MarkAnnounced(const uint256& signHash, const CSigSharesInv& inv) { GetOrCreateSession((Consensus::LLMQType)inv.llmqType, signHash).announced.Merge(inv); @@ -140,7 +192,11 @@ void CSigSharesNodeState::MarkKnows(Consensus::LLMQType llmqType, const uint256& void CSigSharesNodeState::RemoveSession(const uint256& signHash) { - sessions.erase(signHash); + auto it = sessions.find(signHash); + if (it != sessions.end()) { + sessionByRecvId.erase(it->second.recvSessionId); + sessions.erase(it); + } requestedSigShares.EraseAllForSignHash(signHash); pendingIncomingSigShares.EraseAllForSignHash(signHash); } @@ -244,6 +300,15 @@ void CSigSharesManager::ProcessMessageSigSesAnn(CNode* pfrom, const CSigSesAnn& } auto signHash = CLLMQUtils::BuildSignHash(llmqType, ann.quorumHash, ann.id, ann.msgHash); + + LOCK(cs); + auto& nodeState = nodeStates[pfrom->id]; + auto& session = nodeState.GetOrCreateSessionFromAnn(ann); + nodeState.sessionByRecvId.erase(session.recvSessionId); + nodeState.sessionByRecvId.erase(ann.sessionId); + session.recvSessionId = ann.sessionId; + session.quorum = quorum; + nodeState.sessionByRecvId.emplace(ann.sessionId, &session); } bool CSigSharesManager::VerifySigSharesInv(NodeId from, const CSigSharesInv& inv) @@ -980,6 +1045,12 @@ bool CSigSharesManager::SendMessages() return didSend; } +bool CSigSharesManager::GetSessionInfoByRecvId(NodeId nodeId, uint32_t sessionId, CSigSharesNodeState::SessionInfo& retInfo) +{ + LOCK(cs); + return nodeStates[nodeId].GetSessionInfoByRecvId(sessionId, retInfo); +} + CSigShare CSigSharesManager::RebuildSigShare(const CSigSharesNodeState::SessionInfo& session, const CBatchedSigShares& batchedSigShares, size_t idx) { assert(idx < batchedSigShares.sigShares.size()); diff --git a/src/llmq/quorums_signing_shares.h b/src/llmq/quorums_signing_shares.h index 9b37fde5e..d2eb941ab 100644 --- a/src/llmq/quorums_signing_shares.h +++ b/src/llmq/quorums_signing_shares.h @@ -278,7 +278,31 @@ public: class CSigSharesNodeState { public: + // Used to avoid holding locks too long + struct SessionInfo + { + uint32_t recvSessionId; + Consensus::LLMQType llmqType; + uint256 quorumHash; + uint256 id; + uint256 msgHash; + uint256 signHash; + + CQuorumCPtr quorum; + }; + struct Session { + uint32_t recvSessionId{(uint32_t)-1}; + uint32_t sendSessionId{(uint32_t)-1}; + + Consensus::LLMQType llmqType; + uint256 quorumHash; + uint256 id; + uint256 msgHash; + uint256 signHash; + + CQuorumCPtr quorum; + CSigSharesInv announced; CSigSharesInv requested; CSigSharesInv knows; @@ -286,6 +310,9 @@ public: // TODO limit number of sessions per node std::unordered_map sessions; + std::unordered_map sessionByRecvId; + uint32_t nextSendSessionId{1}; + SigShareMap pendingIncomingSigShares; SigShareMap requestedSigShares; @@ -295,7 +322,11 @@ public: bool banned{false}; - Session& GetOrCreateSession(Consensus::LLMQType llmqType, const uint256& signHash); + Session& GetOrCreateSessionFromShare(const CSigShare& sigShare); + Session& GetOrCreateSessionFromAnn(const CSigSesAnn& ann); + Session* GetSessionBySignHash(const uint256& signHash); + Session* GetSessionByRecvId(uint32_t sessionId); + bool GetSessionInfoByRecvId(uint32_t sessionId, SessionInfo& retInfo); void MarkAnnounced(const uint256& signHash, const CSigSharesInv& inv); void MarkRequested(const uint256& signHash, const CSigSharesInv& inv); @@ -377,6 +408,7 @@ private: void TryRecoverSig(const CQuorumCPtr& quorum, const uint256& id, const uint256& msgHash, CConnman& connman); private: + bool GetSessionInfoByRecvId(NodeId nodeId, uint32_t sessionId, CSigSharesNodeState::SessionInfo& retInfo); CSigShare RebuildSigShare(const CSigSharesNodeState::SessionInfo& session, const CBatchedSigShares& batchedSigShares, size_t idx); void Cleanup();