Implement session management based on session ids and announcements

This commit is contained in:
Alexander Block 2019-02-26 08:42:53 +01:00
parent 7372f6f10b
commit 34e3f8eb53
2 changed files with 114 additions and 11 deletions

View File

@ -93,21 +93,73 @@ CSigSharesInv CBatchedSigShares::ToInv(Consensus::LLMQType llmqType) const
return inv;
}
CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSession(Consensus::LLMQType llmqType, const uint256& signHash)
template<typename T>
static void InitSession(CSigSharesNodeState::Session& s, const uint256& signHash, T& from)
{
auto& s = sessions[signHash];
s.llmqType = (Consensus::LLMQType)from.llmqType;
s.quorumHash = from.quorumHash;
s.id = from.id;
s.msgHash = from.msgHash;
s.signHash = signHash;
s.announced.Init((Consensus::LLMQType)from.llmqType);
s.requested.Init((Consensus::LLMQType)from.llmqType);
s.knows.Init((Consensus::LLMQType)from.llmqType);
}
CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSessionFromShare(const llmq::CSigShare& sigShare)
{
auto& s = sessions[sigShare.GetSignHash()];
if (s.announced.inv.empty()) {
s.announced.Init(llmqType, signHash);
s.requested.Init(llmqType, signHash);
s.knows.Init(llmqType, signHash);
} else {
assert(s.announced.llmqType == llmqType);
assert(s.requested.llmqType == llmqType);
assert(s.knows.llmqType == llmqType);
InitSession(s, sigShare.GetSignHash(), sigShare);
}
return s;
}
CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSessionFromAnn(const llmq::CSigSesAnn& ann)
{
auto signHash = CLLMQUtils::BuildSignHash((Consensus::LLMQType)ann.llmqType, ann.quorumHash, ann.id, ann.msgHash);
auto& s = sessions[signHash];
if (s.announced.inv.empty()) {
InitSession(s, signHash, ann);
}
return s;
}
CSigSharesNodeState::Session* CSigSharesNodeState::GetSessionBySignHash(const uint256& signHash)
{
auto it = sessions.find(signHash);
if (it == sessions.end()) {
return nullptr;
}
return &it->second;
}
CSigSharesNodeState::Session* CSigSharesNodeState::GetSessionByRecvId(uint32_t sessionId)
{
auto it = sessionByRecvId.find(sessionId);
if (it == sessionByRecvId.end()) {
return nullptr;
}
return it->second;
}
bool CSigSharesNodeState::GetSessionInfoByRecvId(uint32_t sessionId, SessionInfo& retInfo)
{
auto s = GetSessionByRecvId(sessionId);
if (!s) {
return false;
}
retInfo.recvSessionId = sessionId;
retInfo.llmqType = s->llmqType;
retInfo.quorumHash = s->quorumHash;
retInfo.id = s->id;
retInfo.msgHash = s->msgHash;
retInfo.signHash = s->signHash;
retInfo.quorum = s->quorum;
return true;
}
void CSigSharesNodeState::MarkAnnounced(const uint256& signHash, const CSigSharesInv& inv)
{
GetOrCreateSession((Consensus::LLMQType)inv.llmqType, signHash).announced.Merge(inv);
@ -140,7 +192,11 @@ void CSigSharesNodeState::MarkKnows(Consensus::LLMQType llmqType, const uint256&
void CSigSharesNodeState::RemoveSession(const uint256& signHash)
{
sessions.erase(signHash);
auto it = sessions.find(signHash);
if (it != sessions.end()) {
sessionByRecvId.erase(it->second.recvSessionId);
sessions.erase(it);
}
requestedSigShares.EraseAllForSignHash(signHash);
pendingIncomingSigShares.EraseAllForSignHash(signHash);
}
@ -244,6 +300,15 @@ void CSigSharesManager::ProcessMessageSigSesAnn(CNode* pfrom, const CSigSesAnn&
}
auto signHash = CLLMQUtils::BuildSignHash(llmqType, ann.quorumHash, ann.id, ann.msgHash);
LOCK(cs);
auto& nodeState = nodeStates[pfrom->id];
auto& session = nodeState.GetOrCreateSessionFromAnn(ann);
nodeState.sessionByRecvId.erase(session.recvSessionId);
nodeState.sessionByRecvId.erase(ann.sessionId);
session.recvSessionId = ann.sessionId;
session.quorum = quorum;
nodeState.sessionByRecvId.emplace(ann.sessionId, &session);
}
bool CSigSharesManager::VerifySigSharesInv(NodeId from, const CSigSharesInv& inv)
@ -980,6 +1045,12 @@ bool CSigSharesManager::SendMessages()
return didSend;
}
bool CSigSharesManager::GetSessionInfoByRecvId(NodeId nodeId, uint32_t sessionId, CSigSharesNodeState::SessionInfo& retInfo)
{
LOCK(cs);
return nodeStates[nodeId].GetSessionInfoByRecvId(sessionId, retInfo);
}
CSigShare CSigSharesManager::RebuildSigShare(const CSigSharesNodeState::SessionInfo& session, const CBatchedSigShares& batchedSigShares, size_t idx)
{
assert(idx < batchedSigShares.sigShares.size());

View File

@ -278,7 +278,31 @@ public:
class CSigSharesNodeState
{
public:
// Used to avoid holding locks too long
struct SessionInfo
{
uint32_t recvSessionId;
Consensus::LLMQType llmqType;
uint256 quorumHash;
uint256 id;
uint256 msgHash;
uint256 signHash;
CQuorumCPtr quorum;
};
struct Session {
uint32_t recvSessionId{(uint32_t)-1};
uint32_t sendSessionId{(uint32_t)-1};
Consensus::LLMQType llmqType;
uint256 quorumHash;
uint256 id;
uint256 msgHash;
uint256 signHash;
CQuorumCPtr quorum;
CSigSharesInv announced;
CSigSharesInv requested;
CSigSharesInv knows;
@ -286,6 +310,9 @@ public:
// TODO limit number of sessions per node
std::unordered_map<uint256, Session, StaticSaltedHasher> sessions;
std::unordered_map<uint32_t, Session*> sessionByRecvId;
uint32_t nextSendSessionId{1};
SigShareMap<CSigShare> pendingIncomingSigShares;
SigShareMap<int64_t> requestedSigShares;
@ -295,7 +322,11 @@ public:
bool banned{false};
Session& GetOrCreateSession(Consensus::LLMQType llmqType, const uint256& signHash);
Session& GetOrCreateSessionFromShare(const CSigShare& sigShare);
Session& GetOrCreateSessionFromAnn(const CSigSesAnn& ann);
Session* GetSessionBySignHash(const uint256& signHash);
Session* GetSessionByRecvId(uint32_t sessionId);
bool GetSessionInfoByRecvId(uint32_t sessionId, SessionInfo& retInfo);
void MarkAnnounced(const uint256& signHash, const CSigSharesInv& inv);
void MarkRequested(const uint256& signHash, const CSigSharesInv& inv);
@ -377,6 +408,7 @@ private:
void TryRecoverSig(const CQuorumCPtr& quorum, const uint256& id, const uint256& msgHash, CConnman& connman);
private:
bool GetSessionInfoByRecvId(NodeId nodeId, uint32_t sessionId, CSigSharesNodeState::SessionInfo& retInfo);
CSigShare RebuildSigShare(const CSigSharesNodeState::SessionInfo& session, const CBatchedSigShares& batchedSigShares, size_t idx);
void Cleanup();