1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14pub struct ConversationTable {
21 conversations: DashMap<CanonicalKey, ConversationState>,
22 config: FlowConfig,
23 memory_tracker: Arc<MemoryTracker>,
24 spill_count: std::sync::atomic::AtomicUsize,
25}
26
27impl ConversationTable {
28 #[must_use]
30 pub fn new(config: FlowConfig) -> Self {
31 let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
32 Self {
33 conversations: DashMap::new(),
34 config,
35 memory_tracker,
36 spill_count: std::sync::atomic::AtomicUsize::new(0),
37 }
38 }
39
40 #[must_use]
42 pub fn with_default_config() -> Self {
43 Self::new(FlowConfig::default())
44 }
45
46 #[must_use]
48 pub fn conversation_count(&self) -> usize {
49 self.conversations.len()
50 }
51
52 pub fn ingest_packet(
58 &self,
59 packet: &Packet,
60 timestamp: Duration,
61 packet_index: usize,
62 ) -> Result<(), FlowError> {
63 let (key, direction) = match extract_key(packet) {
64 Ok(result) => result,
65 Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
66 return Ok(());
68 },
69 Err(e) => return Err(e),
70 };
71
72 let byte_count = packet.as_bytes().len() as u64;
73
74 let mut entry = self
76 .conversations
77 .entry(key.clone())
78 .or_insert_with(|| ConversationState::new(key, timestamp));
79
80 let conv = entry.value_mut();
81
82 conv.record_packet(
84 direction,
85 byte_count,
86 timestamp,
87 packet_index,
88 self.config.track_max_packet_len,
89 self.config.track_max_flow_len,
90 self.config.store_packet_indices,
91 );
92
93 let buf = packet.as_bytes();
95 match &mut conv.protocol_state {
96 ProtocolState::Tcp(tcp_state) => {
97 if let Some(tcp) = packet.tcp() {
98 tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
99 }
100 },
101 ProtocolState::Udp(udp_state) => {
102 udp_state.process_packet();
103 },
104 ProtocolState::Icmp(icmp_state) => {
105 if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
107 if buf.len() >= icmp_layer.start + 2 {
108 let icmp_type = buf[icmp_layer.start];
109 let icmp_code = buf[icmp_layer.start + 1];
110 icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
111 }
112 }
113 },
114 ProtocolState::Icmpv6(icmpv6_state) => {
115 if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
117 if buf.len() >= icmpv6_layer.start + 2 {
118 let icmpv6_type = buf[icmpv6_layer.start];
119 let icmpv6_code = buf[icmpv6_layer.start + 1];
120 icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
121 }
122 }
123 },
124 ProtocolState::ZWave(_) => {},
125 ProtocolState::Other => {},
126 }
127
128 conv.update_status();
130
131 if self.memory_tracker.has_budget() {
133 if let ProtocolState::Tcp(ref tcp_state) = conv.protocol_state {
134 let fwd_spilled = tcp_state.reassembler_fwd.is_spilled();
136 let rev_spilled = tcp_state.reassembler_rev.is_spilled();
137 if !fwd_spilled || !rev_spilled {
138 let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
139 let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
140 let payload_start = tcp.index.start + data_offset;
141 buf.len().saturating_sub(payload_start)
142 });
143 if tcp_payload_len > 0 {
144 self.memory_tracker.add(tcp_payload_len);
145 }
146 }
147 }
148 }
149
150 drop(entry);
152
153 if self.memory_tracker.is_over_budget() {
155 self.maybe_spill();
156 }
157
158 Ok(())
159 }
160
161 fn maybe_spill(&self) {
166 let mut scanned = 0;
167 let max_scan = 256; for mut entry in self.conversations.iter_mut() {
169 if !self.memory_tracker.is_over_budget() || scanned >= max_scan {
170 break;
171 }
172 scanned += 1;
173 if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
174 if tcp_state.reassembler_fwd.is_spilled() && tcp_state.reassembler_rev.is_spilled()
176 {
177 continue;
178 }
179 let freed_fwd = tcp_state
180 .reassembler_fwd
181 .spill(self.config.spill_dir.as_deref())
182 .unwrap_or(0);
183 let freed_rev = tcp_state
184 .reassembler_rev
185 .spill(self.config.spill_dir.as_deref())
186 .unwrap_or(0);
187 let total_freed = freed_fwd + freed_rev;
188 if total_freed > 0 {
189 self.memory_tracker.subtract(total_freed);
190 self.spill_count
191 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
192 }
193 }
194 }
195 }
196
197 #[must_use]
199 pub fn memory_usage(&self) -> usize {
200 self.memory_tracker.current_usage()
201 }
202
203 #[must_use]
205 pub fn spill_count(&self) -> usize {
206 self.spill_count.load(std::sync::atomic::Ordering::Relaxed)
207 }
208
209 #[must_use]
211 pub fn get_conversation(
212 &self,
213 key: &CanonicalKey,
214 ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
215 self.conversations.get(key)
216 }
217
218 #[must_use]
222 pub fn evict_idle(&self, now: Duration) -> usize {
223 let mut evicted = 0;
224 self.conversations.retain(|_, conv| {
225 if conv.is_timed_out(now, &self.config) {
226 evicted += 1;
227 false
228 } else {
229 true
230 }
231 });
232 evicted
233 }
234
235 #[must_use]
237 pub fn into_conversations(self) -> Vec<ConversationState> {
238 let mut conversations: Vec<ConversationState> =
239 self.conversations.into_iter().map(|(_, v)| v).collect();
240 conversations.sort_by_key(|c| c.start_time);
241 conversations
242 }
243
244 #[must_use]
246 pub fn config(&self) -> &FlowConfig {
247 &self.config
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Initiating host shared by every test flow.
    const CLIENT: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
    /// Responding host of the primary flow.
    const SERVER: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 2);
    /// Responding host used to create a second, distinct flow.
    const OTHER_SERVER: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 3);
    /// Ephemeral source port used by the client in every flow.
    const CLIENT_PORT: u16 = 12345;

    /// Builds an Ethernet/IPv4/TCP packet.
    ///
    /// `flags` lists one letter per TCP flag to set: `S`, `A`, `F` or `R`.
    /// Any other character is ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let base = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);

        let tcp = flags.chars().fold(base, |acc, flag| match flag {
            'S' => acc.syn(),
            'A' => acc.ack(),
            'F' => acc.fin(),
            'R' => acc.rst(),
            _ => acc,
        });

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
            ))
            .push(LayerStackEntry::Ipv4(Ipv4Builder::new().src(src_ip).dst(dst_ip)))
            .push(LayerStackEntry::Tcp(tcp))
            .build_packet()
    }

    /// Builds an Ethernet/IPv4/UDP packet with the given endpoints.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);
        let ip = Ipv4Builder::new().src(src_ip).dst(dst_ip);
        let ethernet = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ip))
            .push(LayerStackEntry::Udp(udp))
            .build_packet()
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(CLIENT, SERVER, CLIENT_PORT, 80, "S");

        table.ingest_packet(&syn, Duration::from_secs(1), 0).unwrap();

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // Client -> server SYN.
        let syn = make_tcp_packet(CLIENT, SERVER, CLIENT_PORT, 80, "S");
        table.ingest_packet(&syn, Duration::from_secs(1), 0).unwrap();

        // Server -> client SYN-ACK: same flow, opposite direction.
        let syn_ack = make_tcp_packet(SERVER, CLIENT, 80, CLIENT_PORT, "SA");
        table.ingest_packet(&syn_ack, Duration::from_secs(2), 1).unwrap();

        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let only = &conversations[0];
        assert_eq!(only.total_packets(), 2);
        assert_eq!(only.forward.packets, 1);
        assert_eq!(only.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();
        let packets = [
            make_tcp_packet(CLIENT, SERVER, CLIENT_PORT, 80, "S"),
            make_tcp_packet(CLIENT, OTHER_SERVER, CLIENT_PORT, 443, "S"),
        ];

        // Packet i is ingested at t = i + 1 seconds with index i.
        for (index, packet) in packets.iter().enumerate() {
            let at = Duration::from_secs(index as u64 + 1);
            table.ingest_packet(packet, at, index).unwrap();
        }

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();
        let query = make_udp_packet(CLIENT, SERVER, CLIENT_PORT, 53);

        table.ingest_packet(&query, Duration::from_secs(1), 0).unwrap();

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        let is_udp = matches!(conversations[0].protocol_state, ProtocolState::Udp(_));
        assert!(is_udp);
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let query = make_udp_packet(CLIENT, SERVER, CLIENT_PORT, 53);
        table.ingest_packet(&query, Duration::from_secs(1), 0).unwrap();
        assert_eq!(table.conversation_count(), 1);

        // 4 s after the last packet: still inside the 10 s UDP timeout.
        let still_fresh = table.evict_idle(Duration::from_secs(5));
        assert_eq!(still_fresh, 0);
        assert_eq!(table.conversation_count(), 1);

        // 19 s after the last packet: past the timeout, so it is evicted.
        let expired = table.evict_idle(Duration::from_secs(20));
        assert_eq!(expired, 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();
        let early = make_tcp_packet(CLIENT, SERVER, CLIENT_PORT, 80, "S");
        let late = make_tcp_packet(CLIENT, OTHER_SERVER, CLIENT_PORT, 443, "S");

        // Ingest out of chronological order to exercise the sort.
        table.ingest_packet(&late, Duration::from_secs(5), 1).unwrap();
        table.ingest_packet(&early, Duration::from_secs(1), 0).unwrap();

        let conversations = table.into_conversations();
        let (first, second) = (&conversations[0], &conversations[1]);
        assert!(first.start_time <= second.start_time);
    }
}