Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 43 additions & 30 deletions src/Mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,13 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
memcpy(&ack_crc, &pkt->payload[i], 4); i += 4;
if (i > pkt->payload_len) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete ACK packet", getLogDateTime());
} else if (!_tables->wasSeen(pkt)) {
break;
}
if (!_tables->wasSeen(pkt)) {
_tables->markSeen(pkt);
onAckRecv(pkt, ack_crc);
action = routeRecvPacket(pkt);
}
action = routeRecvPacket(pkt);
break;
}
case PAYLOAD_TYPE_PATH:
Expand All @@ -138,8 +140,9 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data
if (i + CIPHER_MAC_SIZE >= pkt->payload_len) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime());
} else if (!_tables->wasSeen(pkt)) {
_tables->markSeen(pkt);
} else {
bool already_dispatched = _tables->wasSeen(pkt);
if (!already_dispatched) _tables->markSeen(pkt);
// NOTE: this is a 'first packet wins' impl. When receiving from multiple paths, the first to arrive wins.
// For flood mode, the path may not be the 'best' in terms of hops.
// FUTURE: could send back multiple paths, using createPathReturn(), and let sender choose which to use(?)
Expand Down Expand Up @@ -170,14 +173,14 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
uint8_t extra_type = data[k++] & 0x0F; // upper 4 bits reserved for future use
uint8_t* extra = &data[k];
uint8_t extra_len = len - k; // remainder of packet (may be padded with zeroes!)
if (onPeerPathRecv(pkt, j, secret, path, path_len, extra_type, extra, extra_len)) {
if (!already_dispatched && onPeerPathRecv(pkt, j, secret, path, path_len, extra_type, extra, extra_len)) {
if (pkt->isRouteFlood()) {
// send a reciprocal return path to sender, but send DIRECTLY!
mesh::Packet* rpath = createPathReturn(&src_hash, secret, pkt->path, pkt->path_len, 0, NULL, 0);
if (rpath) sendDirect(rpath, path, path_len, 500);
}
}
} else {
} else if (!already_dispatched) {
onPeerDataRecv(pkt, pkt->getPayloadType(), j, secret, data, len);
}
found = true;
Expand All @@ -186,7 +189,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
}
if (found) {
pkt->markDoNotRetransmit(); // packet was for this node, so don't retransmit
} else {
} else if (!already_dispatched) {
MESH_DEBUG_PRINTLN("%s recv matches no peers, src_hash=%02X", getLogDateTime(), (uint32_t)src_hash);
}
}
Expand All @@ -202,8 +205,9 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data
if (i + 2 >= pkt->payload_len) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime());
} else if (!_tables->wasSeen(pkt)) {
_tables->markSeen(pkt);
} else {
bool already_dispatched = _tables->wasSeen(pkt);
if (!already_dispatched) _tables->markSeen(pkt);
if (self_id.isHashMatch(&dest_hash)) {
Identity sender(sender_pub_key);

Expand All @@ -214,7 +218,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
uint8_t data[MAX_PACKET_PAYLOAD];
int len = Utils::MACThenDecrypt(secret, data, macAndData, pkt->payload_len - i);
if (len > 0) { // success!
onAnonDataRecv(pkt, secret, sender, data, len);
if (!already_dispatched) onAnonDataRecv(pkt, secret, sender, data, len);
pkt->markDoNotRetransmit();
}
}
Expand All @@ -230,23 +234,27 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data
if (i + 2 >= pkt->payload_len) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime());
} else if (!_tables->wasSeen(pkt)) {
_tables->markSeen(pkt);
// scan channels DB, for all matching hashes of 'channel_hash' (max 4 matches supported ATM)
GroupChannel channels[4];
int num = searchChannelsByHash(&channel_hash, channels, 4);
// for each matching channel, try to decrypt data
for (int j = 0; j < num; j++) {
// decrypt, checking MAC is valid
uint8_t data[MAX_PACKET_PAYLOAD];
int len = Utils::MACThenDecrypt(channels[j].secret, data, macAndData, pkt->payload_len - i);
if (len > 0) { // success!
onGroupDataRecv(pkt, pkt->getPayloadType(), channels[j], data, len);
break;
}
}
break;
}
if (_tables->wasSeen(pkt)) {
action = routeRecvPacket(pkt);
break;
}
_tables->markSeen(pkt);
// scan channels DB, for all matching hashes of 'channel_hash' (max 4 matches supported ATM)
GroupChannel channels[4];
int num = searchChannelsByHash(&channel_hash, channels, 4);
// for each matching channel, try to decrypt data
for (int j = 0; j < num; j++) {
// decrypt, checking MAC is valid
uint8_t data[MAX_PACKET_PAYLOAD];
int len = Utils::MACThenDecrypt(channels[j].secret, data, macAndData, pkt->payload_len - i);
if (len > 0) { // success!
onGroupDataRecv(pkt, pkt->getPayloadType(), channels[j], data, len);
break;
}
}
action = routeRecvPacket(pkt);
break;
}
case PAYLOAD_TYPE_ADVERT: {
Expand All @@ -262,8 +270,9 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete advertisement packet", getLogDateTime());
} else if (self_id.matches(id.pub_key)) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): receiving SELF advert packet", getLogDateTime());
} else if (!_tables->wasSeen(pkt)) {
_tables->markSeen(pkt);
} else {
bool already_dispatched = _tables->wasSeen(pkt);
if (!already_dispatched) _tables->markSeen(pkt);
uint8_t* app_data = &pkt->payload[i];
int app_data_len = pkt->payload_len - i;
if (app_data_len > MAX_ADVERT_DATA_SIZE) { app_data_len = MAX_ADVERT_DATA_SIZE; }
Expand All @@ -281,7 +290,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) {
}
if (is_ok) {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): valid advertisement received!", getLogDateTime());
onAdvertRecv(pkt, id, timestamp, app_data, app_data_len);
if (!already_dispatched) onAdvertRecv(pkt, id, timestamp, app_data, app_data_len);
action = routeRecvPacket(pkt);
} else {
MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): received advertisement with forged signature! (app_data_len=%d)", getLogDateTime(), app_data_len);
Expand Down Expand Up @@ -344,7 +353,9 @@ void Mesh::removeSelfFromPath(Packet* pkt) {
DispatcherAction Mesh::routeRecvPacket(Packet* packet) {
uint8_t n = packet->getPathHashCount();
if (packet->isRouteFlood() && !packet->isMarkedDoNotRetransmit()
&& (n + 1)*packet->getPathHashSize() <= MAX_PATH_SIZE && allowPacketForward(packet)) {
&& (n + 1)*packet->getPathHashSize() <= MAX_PATH_SIZE && allowPacketForward(packet)
&& !_tables->wasForwarded(packet)) {
_tables->markForwarded(packet);
// append this node's hash to 'path'
self_id.copyHashTo(&packet->path[n * packet->getPathHashSize()], packet->getPathHashSize());
packet->setPathHashCount(n + 1);
Expand Down Expand Up @@ -649,6 +660,7 @@ void Mesh::sendFlood(Packet* packet, uint32_t delay_millis, uint8_t path_hash_si
packet->setPathHashSizeAndCount(path_hash_size, 0);

_tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us
_tables->markForwarded(packet); // and do not forward rebroadcast copies of our own flood

uint8_t pri;
if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) {
Expand Down Expand Up @@ -678,6 +690,7 @@ void Mesh::sendFlood(Packet* packet, uint16_t* transport_codes, uint32_t delay_m
packet->setPathHashSizeAndCount(path_hash_size, 0);

_tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us
_tables->markForwarded(packet); // and do not forward rebroadcast copies of our own flood

uint8_t pri;
if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) {
Expand Down Expand Up @@ -738,4 +751,4 @@ void Mesh::sendZeroHop(Packet* packet, uint16_t* transport_codes, uint32_t delay
sendPacket(packet, 0, delay_millis);
}

}
}
2 changes: 2 additions & 0 deletions src/Mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class MeshTables {
public:
virtual bool wasSeen(const Packet* packet) = 0;
virtual void markSeen(const Packet* packet) = 0;
virtual bool wasForwarded(const Packet* packet) = 0;
virtual void markForwarded(const Packet* packet) = 0;
virtual void clear(const Packet* packet) = 0; // remove this packet hash from table
};

Expand Down
77 changes: 57 additions & 20 deletions src/helpers/SimpleMeshTables.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,29 @@
#define MAX_PACKET_HASHES (128+32)

class SimpleMeshTables : public mesh::MeshTables {
enum SeenType : uint8_t {
SEEN_DISPATCH = 0x01,
SEEN_FORWARD = 0x02,
SEEN_BOTH = SEEN_DISPATCH | SEEN_FORWARD,
};

uint8_t _hashes[MAX_PACKET_HASHES*MAX_HASH_SIZE];
uint8_t _seen_types[MAX_PACKET_HASHES];
int _next_idx;
uint32_t _direct_dups, _flood_dups;

public:
SimpleMeshTables() {
memset(_hashes, 0, sizeof(_hashes));
memset(_seen_types, 0, sizeof(_seen_types));
_next_idx = 0;
_direct_dups = _flood_dups = 0;
}

#ifdef ESP32
void restoreFrom(File f) {
f.read(_hashes, sizeof(_hashes));
memset(_seen_types, SEEN_BOTH, sizeof(_seen_types));
f.read((uint8_t *) &_next_idx, sizeof(_next_idx));
}
void saveTo(File f) {
Expand All @@ -31,41 +40,69 @@ class SimpleMeshTables : public mesh::MeshTables {
}
#endif

bool wasSeen(const mesh::Packet* packet) override {
int findHash(const uint8_t hash[]) const {
const uint8_t* sp = _hashes;
for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) {
if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) return i;
}
return -1;
}

bool hasType(const mesh::Packet* packet, uint8_t seen_type, bool count_dup) {
uint8_t hash[MAX_HASH_SIZE];
packet->calculatePacketHash(hash);

const uint8_t* sp = _hashes;
for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) {
if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) {
if (packet->isRouteDirect()) {
_direct_dups++;
} else {
_flood_dups++;
}
return true;
int idx = findHash(hash);
if (idx < 0 || (_seen_types[idx] & seen_type) == 0) return false;

if (count_dup) {
if (packet->isRouteDirect()) {
_direct_dups++;
} else {
_flood_dups++;
}
}
return false;
return true;
}

void markSeen(const mesh::Packet* packet) override {
void markType(const mesh::Packet* packet, uint8_t seen_type) {
uint8_t hash[MAX_HASH_SIZE];
packet->calculatePacketHash(hash);
memcpy(&_hashes[_next_idx * MAX_HASH_SIZE], hash, MAX_HASH_SIZE);
_next_idx = (_next_idx + 1) % MAX_PACKET_HASHES;

int idx = findHash(hash);
if (idx < 0) {
idx = _next_idx;
memcpy(&_hashes[idx * MAX_HASH_SIZE], hash, MAX_HASH_SIZE);
_seen_types[idx] = 0;
_next_idx = (_next_idx + 1) % MAX_PACKET_HASHES;
}
_seen_types[idx] |= seen_type;
}

bool wasSeen(const mesh::Packet* packet) override {
return hasType(packet, SEEN_DISPATCH, true);
}

void markSeen(const mesh::Packet* packet) override {
markType(packet, SEEN_DISPATCH);
}

bool wasForwarded(const mesh::Packet* packet) override {
return hasType(packet, SEEN_FORWARD, false);
}

void markForwarded(const mesh::Packet* packet) override {
markType(packet, SEEN_FORWARD);
}

void clear(const mesh::Packet* packet) override {
uint8_t hash[MAX_HASH_SIZE];
packet->calculatePacketHash(hash);

uint8_t* sp = _hashes;
for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) {
if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) {
memset(sp, 0, MAX_HASH_SIZE);
break;
}
int idx = findHash(hash);
if (idx >= 0) {
memset(&_hashes[idx * MAX_HASH_SIZE], 0, MAX_HASH_SIZE);
_seen_types[idx] = 0;
}
}

Expand Down
20 changes: 20 additions & 0 deletions test/test_mesh_tables/test_simple_mesh_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ TEST(SimpleMeshTables, QueryThenMark_WorksCorrectly) {
EXPECT_TRUE(t.wasSeen(&p));
}

// ── forwarding state ─────────────────────────────────────────────────────────

TEST(SimpleMeshTables, MarkSeen_DoesNotMarkForwarded) {
SimpleMeshTables t;
Packet p = makeFloodPacket(0x01);
t.markSeen(&p);
EXPECT_FALSE(t.wasForwarded(&p));
}

TEST(SimpleMeshTables, MarkForwarded_DoesNotMarkSeen) {
SimpleMeshTables t;
Packet p = makeFloodPacket(0x01);
t.markForwarded(&p);
EXPECT_TRUE(t.wasForwarded(&p));
EXPECT_FALSE(t.wasSeen(&p));
}

// ── dup stats ────────────────────────────────────────────────────────────────

TEST(SimpleMeshTables, WasSeen_IncrementsFloodDupStat) {
Expand All @@ -92,9 +109,12 @@ TEST(SimpleMeshTables, Clear_RemovesSeenPacket) {
SimpleMeshTables t;
Packet p = makeFloodPacket(0x01);
t.markSeen(&p);
t.markForwarded(&p);
ASSERT_TRUE(t.wasSeen(&p));
ASSERT_TRUE(t.wasForwarded(&p));
t.clear(&p);
EXPECT_FALSE(t.wasSeen(&p));
EXPECT_FALSE(t.wasForwarded(&p));
}

int main(int argc, char** argv) {
Expand Down