1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14pub struct ConversationTable {
21 conversations: DashMap<CanonicalKey, ConversationState>,
22 config: FlowConfig,
23 memory_tracker: Arc<MemoryTracker>,
24 spill_count: std::sync::atomic::AtomicUsize,
25}
26
27impl ConversationTable {
28 #[must_use]
30 pub fn new(config: FlowConfig) -> Self {
31 let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
32 Self {
33 conversations: DashMap::new(),
34 config,
35 memory_tracker,
36 spill_count: std::sync::atomic::AtomicUsize::new(0),
37 }
38 }
39
40 #[must_use]
42 pub fn with_default_config() -> Self {
43 Self::new(FlowConfig::default())
44 }
45
46 #[must_use]
48 pub fn conversation_count(&self) -> usize {
49 self.conversations.len()
50 }
51
52 pub fn ingest_packet(
58 &self,
59 packet: &Packet,
60 timestamp: Duration,
61 packet_index: usize,
62 ) -> Result<(), FlowError> {
63 let (key, direction) = match extract_key(packet) {
64 Ok(result) => result,
65 Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
66 return Ok(());
68 },
69 Err(e) => return Err(e),
70 };
71
72 let byte_count = packet.as_bytes().len() as u64;
73
74 let mut entry = self
76 .conversations
77 .entry(key.clone())
78 .or_insert_with(|| ConversationState::new(key, timestamp));
79
80 let conv = entry.value_mut();
81
82 conv.record_packet(
84 direction,
85 byte_count,
86 timestamp,
87 packet_index,
88 self.config.track_max_packet_len,
89 self.config.track_max_flow_len,
90 self.config.store_packet_indices,
91 );
92
93 let buf = packet.as_bytes();
95 match &mut conv.protocol_state {
96 ProtocolState::Tcp(tcp_state) => {
97 if let Some(tcp) = packet.tcp() {
98 tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
99 }
100 },
101 ProtocolState::Udp(udp_state) => {
102 udp_state.process_packet();
103 },
104 ProtocolState::Icmp(icmp_state) => {
105 if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
107 if buf.len() >= icmp_layer.start + 2 {
108 let icmp_type = buf[icmp_layer.start];
109 let icmp_code = buf[icmp_layer.start + 1];
110 icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
111 }
112 }
113 },
114 ProtocolState::Icmpv6(icmpv6_state) => {
115 if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
117 if buf.len() >= icmpv6_layer.start + 2 {
118 let icmpv6_type = buf[icmpv6_layer.start];
119 let icmpv6_code = buf[icmpv6_layer.start + 1];
120 icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
121 }
122 }
123 },
124 ProtocolState::ZWave(_) => {},
125 ProtocolState::Other => {},
126 }
127
128 conv.update_status();
130
131 if self.memory_tracker.has_budget() {
133 if let ProtocolState::Tcp(ref tcp_state) = conv.protocol_state {
134 let fwd_spilled = tcp_state.reassembler_fwd.is_spilled();
136 let rev_spilled = tcp_state.reassembler_rev.is_spilled();
137 if !fwd_spilled || !rev_spilled {
138 let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
139 let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
140 let payload_start = tcp.index.start + data_offset;
141 buf.len().saturating_sub(payload_start)
142 });
143 if tcp_payload_len > 0 {
144 self.memory_tracker.add(tcp_payload_len);
145 }
146 }
147 }
148 }
149
150 drop(entry);
152
153 if self.memory_tracker.is_over_budget() {
155 self.maybe_spill();
156 }
157
158 Ok(())
159 }
160
161 fn maybe_spill(&self) {
169 let mut spills = 0;
170 let max_spills = 64;
171 let mut consecutive_skips = 0;
172 let max_skip = 512;
173
174 for mut entry in self.conversations.iter_mut() {
175 if !self.memory_tracker.is_over_budget() || spills >= max_spills {
176 break;
177 }
178 if consecutive_skips >= max_skip {
179 break;
181 }
182
183 if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
184 if tcp_state.reassembler_fwd.is_spilled() && tcp_state.reassembler_rev.is_spilled()
186 {
187 consecutive_skips += 1;
188 continue;
189 }
190 let freed_fwd = tcp_state
191 .reassembler_fwd
192 .spill(self.config.spill_dir.as_deref())
193 .unwrap_or(0);
194 let freed_rev = tcp_state
195 .reassembler_rev
196 .spill(self.config.spill_dir.as_deref())
197 .unwrap_or(0);
198 let total_freed = freed_fwd + freed_rev;
199 if total_freed > 0 {
200 self.memory_tracker.subtract(total_freed);
201 self.spill_count
202 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
203 spills += 1;
204 consecutive_skips = 0; } else {
206 consecutive_skips += 1;
207 }
208 } else {
209 consecutive_skips += 1;
210 }
211 }
212 }
213
214 #[must_use]
216 pub fn memory_usage(&self) -> usize {
217 self.memory_tracker.current_usage()
218 }
219
220 #[must_use]
222 pub fn spill_count(&self) -> usize {
223 self.spill_count.load(std::sync::atomic::Ordering::Relaxed)
224 }
225
226 #[must_use]
228 pub fn get_conversation(
229 &self,
230 key: &CanonicalKey,
231 ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
232 self.conversations.get(key)
233 }
234
235 #[must_use]
239 pub fn evict_idle(&self, now: Duration) -> usize {
240 let mut evicted = 0;
241 self.conversations.retain(|_, conv| {
242 if conv.is_timed_out(now, &self.config) {
243 evicted += 1;
244 false
245 } else {
246 true
247 }
248 });
249 evicted
250 }
251
252 #[must_use]
254 pub fn into_conversations(self) -> Vec<ConversationState> {
255 let mut conversations: Vec<ConversationState> =
256 self.conversations.into_iter().map(|(_, v)| v).collect();
257 conversations.sort_by_key(|c| c.start_time);
258 conversations
259 }
260
261 #[must_use]
263 pub fn config(&self) -> &FlowConfig {
264 &self.config
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Address `10.0.0.<last>`; every test host lives in this /24.
    fn host(last: u8) -> Ipv4Addr {
        Ipv4Addr::new(10, 0, 0, last)
    }

    /// Puts `transport` inside a fixed Ethernet header and an IPv4 header from `src_ip` to
    /// `dst_ip`.
    fn ipv4_frame(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, transport: LayerStackEntry) -> Packet {
        let ethernet = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(transport)
            .build_packet()
    }

    /// Builds a TCP/IPv4 packet.
    ///
    /// Each of `S`, `A`, `F`, `R` in `flags` sets SYN, ACK, FIN, RST respectively; any other
    /// character is ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let base = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);

        let tcp = flags.chars().fold(base, |b, flag| match flag {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            _ => b,
        });

        ipv4_frame(src_ip, dst_ip, LayerStackEntry::Tcp(tcp))
    }

    /// Builds a UDP/IPv4 packet with an empty payload.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);
        ipv4_frame(src_ip, dst_ip, LayerStackEntry::Udp(udp))
    }

    /// Ingests `pkt` at `secs` seconds with capture index `index`, failing the test on error.
    fn ingest(table: &ConversationTable, pkt: &Packet, secs: u64, index: usize) {
        table
            .ingest_packet(pkt, Duration::from_secs(secs), index)
            .unwrap();
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(host(1), host(2), 12345, 80, "S");

        ingest(&table, &syn, 1, 0);

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // Client SYN, then the server's SYN-ACK on the reversed 4-tuple.
        let syn = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        ingest(&table, &syn, 1, 0);
        let syn_ack = make_tcp_packet(host(2), host(1), 80, 12345, "SA");
        ingest(&table, &syn_ack, 2, 1);

        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let only = &conversations[0];
        assert_eq!(only.total_packets(), 2);
        assert_eq!(only.forward.packets, 1);
        assert_eq!(only.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();

        let to_http = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        let to_https = make_tcp_packet(host(1), host(3), 12345, 443, "S");

        ingest(&table, &to_http, 1, 0);
        ingest(&table, &to_https, 2, 1);

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();

        let dns = make_udp_packet(host(1), host(2), 12345, 53);
        ingest(&table, &dns, 1, 0);

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let dns = make_udp_packet(host(1), host(2), 12345, 53);
        ingest(&table, &dns, 1, 0);
        assert_eq!(table.conversation_count(), 1);

        // 4s after the last packet: still inside the 10s UDP timeout.
        let evicted = table.evict_idle(Duration::from_secs(5));
        assert_eq!(evicted, 0);
        assert_eq!(table.conversation_count(), 1);

        // 19s after the last packet: past the timeout.
        let evicted = table.evict_idle(Duration::from_secs(20));
        assert_eq!(evicted, 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();

        let earlier = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        let later = make_tcp_packet(host(1), host(3), 12345, 443, "S");

        // Ingest out of chronological order so the sort has something to do.
        ingest(&table, &later, 5, 1);
        ingest(&table, &earlier, 1, 0);

        let conversations = table.into_conversations();
        assert!(conversations[0].start_time <= conversations[1].start_time);
    }
}