From 5dfdb3ae2e6f5debcac0bba7b2ad6c4796c3533e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 2 Jul 2026 09:47:54 +0100 Subject: [PATCH] Fix incorrect deduplication of flood packets --- src/Mesh.cpp | 73 ++++++++++-------- src/Mesh.h | 2 + src/helpers/SimpleMeshTables.h | 77 ++++++++++++++----- .../test_simple_mesh_tables.cpp | 20 +++++ 4 files changed, 122 insertions(+), 50 deletions(-) diff --git a/src/Mesh.cpp b/src/Mesh.cpp index c11f37cacf..c865eac36e 100644 --- a/src/Mesh.cpp +++ b/src/Mesh.cpp @@ -120,11 +120,13 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { memcpy(&ack_crc, &pkt->payload[i], 4); i += 4; if (i > pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete ACK packet", getLogDateTime()); - } else if (!_tables->wasSeen(pkt)) { + break; + } + if (!_tables->wasSeen(pkt)) { _tables->markSeen(pkt); onAckRecv(pkt, ack_crc); - action = routeRecvPacket(pkt); } + action = routeRecvPacket(pkt); break; } case PAYLOAD_TYPE_PATH: @@ -138,8 +140,9 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + CIPHER_MAC_SIZE >= pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime()); - } else if (!_tables->wasSeen(pkt)) { - _tables->markSeen(pkt); + } else { + bool already_dispatched = _tables->wasSeen(pkt); + if (!already_dispatched) _tables->markSeen(pkt); // NOTE: this is a 'first packet wins' impl. When receiving from multiple paths, the first to arrive wins. // For flood mode, the path may not be the 'best' in terms of hops. // FUTURE: could send back multiple paths, using createPathReturn(), and let sender choose which to use(?) @@ -170,14 +173,14 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t extra_type = data[k++] & 0x0F; // upper 4 bits reserved for future use uint8_t* extra = &data[k]; uint8_t extra_len = len - k; // remainder of packet (may be padded with zeroes!) - if (onPeerPathRecv(pkt, j, secret, path, path_len, extra_type, extra, extra_len)) { + if (!already_dispatched && onPeerPathRecv(pkt, j, secret, path, path_len, extra_type, extra, extra_len)) { if (pkt->isRouteFlood()) { // send a reciprocal return path to sender, but send DIRECTLY! mesh::Packet* rpath = createPathReturn(&src_hash, secret, pkt->path, pkt->path_len, 0, NULL, 0); if (rpath) sendDirect(rpath, path, path_len, 500); } } - } else { + } else if (!already_dispatched) { onPeerDataRecv(pkt, pkt->getPayloadType(), j, secret, data, len); } found = true; @@ -186,7 +189,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { } if (found) { pkt->markDoNotRetransmit(); // packet was for this node, so don't retransmit - } else { + } else if (!already_dispatched) { MESH_DEBUG_PRINTLN("%s recv matches no peers, src_hash=%02X", getLogDateTime(), (uint32_t)src_hash); } } @@ -202,8 +205,9 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime()); - } else if (!_tables->wasSeen(pkt)) { - _tables->markSeen(pkt); + } else { + bool already_dispatched = _tables->wasSeen(pkt); + if (!already_dispatched) _tables->markSeen(pkt); if (self_id.isHashMatch(&dest_hash)) { Identity sender(sender_pub_key); @@ -214,7 +218,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t data[MAX_PACKET_PAYLOAD]; int len = Utils::MACThenDecrypt(secret, data, macAndData, pkt->payload_len - i); if (len > 0) { // success! - onAnonDataRecv(pkt, secret, sender, data, len); + if (!already_dispatched) onAnonDataRecv(pkt, secret, sender, data, len); pkt->markDoNotRetransmit(); } } @@ -230,23 +234,27 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime()); - } else if (!_tables->wasSeen(pkt)) { - _tables->markSeen(pkt); - // scan channels DB, for all matching hashes of 'channel_hash' (max 4 matches supported ATM) - GroupChannel channels[4]; - int num = searchChannelsByHash(&channel_hash, channels, 4); - // for each matching channel, try to decrypt data - for (int j = 0; j < num; j++) { - // decrypt, checking MAC is valid - uint8_t data[MAX_PACKET_PAYLOAD]; - int len = Utils::MACThenDecrypt(channels[j].secret, data, macAndData, pkt->payload_len - i); - if (len > 0) { // success! - onGroupDataRecv(pkt, pkt->getPayloadType(), channels[j], data, len); - break; - } - } + break; + } + if (_tables->wasSeen(pkt)) { action = routeRecvPacket(pkt); + break; } + _tables->markSeen(pkt); + // scan channels DB, for all matching hashes of 'channel_hash' (max 4 matches supported ATM) + GroupChannel channels[4]; + int num = searchChannelsByHash(&channel_hash, channels, 4); + // for each matching channel, try to decrypt data + for (int j = 0; j < num; j++) { + // decrypt, checking MAC is valid + uint8_t data[MAX_PACKET_PAYLOAD]; + int len = Utils::MACThenDecrypt(channels[j].secret, data, macAndData, pkt->payload_len - i); + if (len > 0) { // success! + onGroupDataRecv(pkt, pkt->getPayloadType(), channels[j], data, len); + break; + } + } + action = routeRecvPacket(pkt); break; } case PAYLOAD_TYPE_ADVERT: { @@ -262,8 +270,9 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete advertisement packet", getLogDateTime()); } else if (self_id.matches(id.pub_key)) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): receiving SELF advert packet", getLogDateTime()); - } else if (!_tables->wasSeen(pkt)) { - _tables->markSeen(pkt); + } else { + bool already_dispatched = _tables->wasSeen(pkt); + if (!already_dispatched) _tables->markSeen(pkt); uint8_t* app_data = &pkt->payload[i]; int app_data_len = pkt->payload_len - i; if (app_data_len > MAX_ADVERT_DATA_SIZE) { app_data_len = MAX_ADVERT_DATA_SIZE; } @@ -281,7 +290,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { } if (is_ok) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): valid advertisement received!", getLogDateTime()); - onAdvertRecv(pkt, id, timestamp, app_data, app_data_len); + if (!already_dispatched) onAdvertRecv(pkt, id, timestamp, app_data, app_data_len); action = routeRecvPacket(pkt); } else { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): received advertisement with forged signature! (app_data_len=%d)", getLogDateTime(), app_data_len); @@ -344,7 +353,9 @@ void Mesh::removeSelfFromPath(Packet* pkt) { DispatcherAction Mesh::routeRecvPacket(Packet* packet) { uint8_t n = packet->getPathHashCount(); if (packet->isRouteFlood() && !packet->isMarkedDoNotRetransmit() - && (n + 1)*packet->getPathHashSize() <= MAX_PATH_SIZE && allowPacketForward(packet)) { + && (n + 1)*packet->getPathHashSize() <= MAX_PATH_SIZE && allowPacketForward(packet) + && !_tables->wasForwarded(packet)) { + _tables->markForwarded(packet); // append this node's hash to 'path' self_id.copyHashTo(&packet->path[n * packet->getPathHashSize()], packet->getPathHashSize()); packet->setPathHashCount(n + 1); @@ -649,6 +660,7 @@ void Mesh::sendFlood(Packet* packet, uint32_t delay_millis, uint8_t path_hash_si packet->setPathHashSizeAndCount(path_hash_size, 0); _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markForwarded(packet); // and do not forward rebroadcast copies of our own flood uint8_t pri; if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) { @@ -678,6 +690,7 @@ void Mesh::sendFlood(Packet* packet, uint16_t* transport_codes, uint32_t delay_m packet->setPathHashSizeAndCount(path_hash_size, 0); _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markForwarded(packet); // and do not forward rebroadcast copies of our own flood uint8_t pri; if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) { @@ -738,4 +751,4 @@ void Mesh::sendZeroHop(Packet* packet, uint16_t* transport_codes, uint32_t delay sendPacket(packet, 0, delay_millis); } -} \ No newline at end of file +} diff --git a/src/Mesh.h b/src/Mesh.h index 49a299a6a4..fa3ba741a5 100644 --- a/src/Mesh.h +++ b/src/Mesh.h @@ -17,6 +17,8 @@ class MeshTables { public: virtual bool wasSeen(const Packet* packet) = 0; virtual void markSeen(const Packet* packet) = 0; + virtual bool wasForwarded(const Packet* packet) = 0; + virtual void markForwarded(const Packet* packet) = 0; virtual void clear(const Packet* packet) = 0; // remove this packet hash from table }; diff --git a/src/helpers/SimpleMeshTables.h b/src/helpers/SimpleMeshTables.h index 956f36faa6..45f1e50212 100644 --- a/src/helpers/SimpleMeshTables.h +++ b/src/helpers/SimpleMeshTables.h @@ -9,13 +9,21 @@ #define MAX_PACKET_HASHES (128+32) class SimpleMeshTables : public mesh::MeshTables { + enum SeenType : uint8_t { + SEEN_DISPATCH = 0x01, + SEEN_FORWARD = 0x02, + SEEN_BOTH = SEEN_DISPATCH | SEEN_FORWARD, + }; + uint8_t _hashes[MAX_PACKET_HASHES*MAX_HASH_SIZE]; + uint8_t _seen_types[MAX_PACKET_HASHES]; int _next_idx; uint32_t _direct_dups, _flood_dups; public: SimpleMeshTables() { memset(_hashes, 0, sizeof(_hashes)); + memset(_seen_types, 0, sizeof(_seen_types)); _next_idx = 0; _direct_dups = _flood_dups = 0; } @@ -23,6 +31,7 @@ class SimpleMeshTables : public mesh::MeshTables { #ifdef ESP32 void restoreFrom(File f) { f.read(_hashes, sizeof(_hashes)); + memset(_seen_types, SEEN_BOTH, sizeof(_seen_types)); f.read((uint8_t *) &_next_idx, sizeof(_next_idx)); } void saveTo(File f) { @@ -31,41 +40,69 @@ class SimpleMeshTables : public mesh::MeshTables { } #endif - bool wasSeen(const mesh::Packet* packet) override { + int findHash(const uint8_t hash[]) const { + const uint8_t* sp = _hashes; + for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { + if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) return i; + } + return -1; + } + + bool hasType(const mesh::Packet* packet, uint8_t seen_type, bool count_dup) { uint8_t hash[MAX_HASH_SIZE]; packet->calculatePacketHash(hash); - const uint8_t* sp = _hashes; - for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { - if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) { - if (packet->isRouteDirect()) { - _direct_dups++; - } else { - _flood_dups++; - } - return true; + int idx = findHash(hash); + if (idx < 0 || (_seen_types[idx] & seen_type) == 0) return false; + + if (count_dup) { + if (packet->isRouteDirect()) { + _direct_dups++; + } else { + _flood_dups++; } } - return false; + return true; } - void markSeen(const mesh::Packet* packet) override { + void markType(const mesh::Packet* packet, uint8_t seen_type) { uint8_t hash[MAX_HASH_SIZE]; packet->calculatePacketHash(hash); - memcpy(&_hashes[_next_idx * MAX_HASH_SIZE], hash, MAX_HASH_SIZE); - _next_idx = (_next_idx + 1) % MAX_PACKET_HASHES; + + int idx = findHash(hash); + if (idx < 0) { + idx = _next_idx; + memcpy(&_hashes[idx * MAX_HASH_SIZE], hash, MAX_HASH_SIZE); + _seen_types[idx] = 0; + _next_idx = (_next_idx + 1) % MAX_PACKET_HASHES; + } + _seen_types[idx] |= seen_type; + } + + bool wasSeen(const mesh::Packet* packet) override { + return hasType(packet, SEEN_DISPATCH, true); + } + + void markSeen(const mesh::Packet* packet) override { + markType(packet, SEEN_DISPATCH); + } + + bool wasForwarded(const mesh::Packet* packet) override { + return hasType(packet, SEEN_FORWARD, false); + } + + void markForwarded(const mesh::Packet* packet) override { + markType(packet, SEEN_FORWARD); } void clear(const mesh::Packet* packet) override { uint8_t hash[MAX_HASH_SIZE]; packet->calculatePacketHash(hash); - uint8_t* sp = _hashes; - for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { - if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) { - memset(sp, 0, MAX_HASH_SIZE); - break; - } + int idx = findHash(hash); + if (idx >= 0) { + memset(&_hashes[idx * MAX_HASH_SIZE], 0, MAX_HASH_SIZE); + _seen_types[idx] = 0; } } diff --git a/test/test_mesh_tables/test_simple_mesh_tables.cpp b/test/test_mesh_tables/test_simple_mesh_tables.cpp index 46b477d944..8c4131116f 100644 --- a/test/test_mesh_tables/test_simple_mesh_tables.cpp +++ b/test/test_mesh_tables/test_simple_mesh_tables.cpp @@ -66,6 +66,23 @@ TEST(SimpleMeshTables, QueryThenMark_WorksCorrectly) { EXPECT_TRUE(t.wasSeen(&p)); } +// ── forwarding state ───────────────────────────────────────────────────────── + +TEST(SimpleMeshTables, MarkSeen_DoesNotMarkForwarded) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + t.markSeen(&p); + EXPECT_FALSE(t.wasForwarded(&p)); +} + +TEST(SimpleMeshTables, MarkForwarded_DoesNotMarkSeen) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + t.markForwarded(&p); + EXPECT_TRUE(t.wasForwarded(&p)); + EXPECT_FALSE(t.wasSeen(&p)); +} + // ── dup stats ──────────────────────────────────────────────────────────────── TEST(SimpleMeshTables, WasSeen_IncrementsFloodDupStat) { @@ -92,9 +109,12 @@ TEST(SimpleMeshTables, Clear_RemovesSeenPacket) { SimpleMeshTables t; Packet p = makeFloodPacket(0x01); t.markSeen(&p); + t.markForwarded(&p); ASSERT_TRUE(t.wasSeen(&p)); + ASSERT_TRUE(t.wasForwarded(&p)); t.clear(&p); EXPECT_FALSE(t.wasSeen(&p)); + EXPECT_FALSE(t.wasForwarded(&p)); } int main(int argc, char** argv) {