1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14pub struct ConversationTable {
21 conversations: DashMap<CanonicalKey, ConversationState>,
22 config: FlowConfig,
23 memory_tracker: Arc<MemoryTracker>,
24}
25
26impl ConversationTable {
27 #[must_use]
29 pub fn new(config: FlowConfig) -> Self {
30 let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
31 Self {
32 conversations: DashMap::new(),
33 config,
34 memory_tracker,
35 }
36 }
37
38 #[must_use]
40 pub fn with_default_config() -> Self {
41 Self::new(FlowConfig::default())
42 }
43
44 #[must_use]
46 pub fn conversation_count(&self) -> usize {
47 self.conversations.len()
48 }
49
50 pub fn ingest_packet(
56 &self,
57 packet: &Packet,
58 timestamp: Duration,
59 packet_index: usize,
60 ) -> Result<(), FlowError> {
61 let (key, direction) = match extract_key(packet) {
62 Ok(result) => result,
63 Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
64 return Ok(());
66 },
67 Err(e) => return Err(e),
68 };
69
70 let byte_count = packet.as_bytes().len() as u64;
71
72 let mut entry = self
74 .conversations
75 .entry(key.clone())
76 .or_insert_with(|| ConversationState::new(key, timestamp));
77
78 let conv = entry.value_mut();
79
80 conv.record_packet(
82 direction,
83 byte_count,
84 timestamp,
85 packet_index,
86 self.config.track_max_packet_len,
87 self.config.track_max_flow_len,
88 );
89
90 let buf = packet.as_bytes();
92 match &mut conv.protocol_state {
93 ProtocolState::Tcp(tcp_state) => {
94 if let Some(tcp) = packet.tcp() {
95 tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
96 }
97 },
98 ProtocolState::Udp(udp_state) => {
99 udp_state.process_packet();
100 },
101 ProtocolState::Icmp(icmp_state) => {
102 if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
104 if buf.len() >= icmp_layer.start + 2 {
105 let icmp_type = buf[icmp_layer.start];
106 let icmp_code = buf[icmp_layer.start + 1];
107 icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
108 }
109 }
110 },
111 ProtocolState::Icmpv6(icmpv6_state) => {
112 if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
114 if buf.len() >= icmpv6_layer.start + 2 {
115 let icmpv6_type = buf[icmpv6_layer.start];
116 let icmpv6_code = buf[icmpv6_layer.start + 1];
117 icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
118 }
119 }
120 },
121 ProtocolState::ZWave(_) => {},
122 ProtocolState::Other => {},
123 }
124
125 conv.update_status();
127
128 if self.memory_tracker.has_budget() {
130 if matches!(conv.protocol_state, ProtocolState::Tcp(_)) {
131 let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
133 let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
134 let payload_start = tcp.index.start + data_offset;
135 buf.len().saturating_sub(payload_start)
136 });
137 if tcp_payload_len > 0 {
138 self.memory_tracker.add(tcp_payload_len);
139 }
140 }
141 }
142
143 drop(entry);
145
146 if self.memory_tracker.is_over_budget() {
148 self.maybe_spill();
149 }
150
151 Ok(())
152 }
153
154 fn maybe_spill(&self) {
156 for mut entry in self.conversations.iter_mut() {
157 if !self.memory_tracker.is_over_budget() {
158 break;
159 }
160 if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
161 let freed_fwd = tcp_state
162 .reassembler_fwd
163 .spill(self.config.spill_dir.as_deref())
164 .unwrap_or(0);
165 let freed_rev = tcp_state
166 .reassembler_rev
167 .spill(self.config.spill_dir.as_deref())
168 .unwrap_or(0);
169 let total_freed = freed_fwd + freed_rev;
170 if total_freed > 0 {
171 self.memory_tracker.subtract(total_freed);
172 }
173 }
174 }
175 }
176
177 #[must_use]
179 pub fn get_conversation(
180 &self,
181 key: &CanonicalKey,
182 ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
183 self.conversations.get(key)
184 }
185
186 #[must_use]
190 pub fn evict_idle(&self, now: Duration) -> usize {
191 let mut evicted = 0;
192 self.conversations.retain(|_, conv| {
193 if conv.is_timed_out(now, &self.config) {
194 evicted += 1;
195 false
196 } else {
197 true
198 }
199 });
200 evicted
201 }
202
203 #[must_use]
205 pub fn into_conversations(self) -> Vec<ConversationState> {
206 let mut conversations: Vec<ConversationState> =
207 self.conversations.into_iter().map(|(_, v)| v).collect();
208 conversations.sort_by_key(|c| c.start_time);
209 conversations
210 }
211
212 #[must_use]
214 pub fn config(&self) -> &FlowConfig {
215 &self.config
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    const HOST_A: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
    const HOST_B: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 2);
    const HOST_C: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 3);

    /// Wraps `transport` in a fixed Ethernet header and an IPv4 header, then
    /// builds the packet.
    fn build_ipv4_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, transport: LayerStackEntry) -> Packet {
        let ethernet = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(transport)
            .build_packet()
    }

    /// Builds a TCP packet. Each character of `flags` sets one flag
    /// ('S' SYN, 'A' ACK, 'F' FIN, 'R' RST); any other character is ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let base = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);

        let tcp = flags.chars().fold(base, |builder, flag| match flag {
            'S' => builder.syn(),
            'A' => builder.ack(),
            'F' => builder.fin(),
            'R' => builder.rst(),
            _ => builder,
        });

        build_ipv4_packet(src_ip, dst_ip, LayerStackEntry::Tcp(tcp))
    }

    /// Builds a UDP packet with an empty payload.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);
        build_ipv4_packet(src_ip, dst_ip, LayerStackEntry::Udp(udp))
    }

    /// Ingests `packet` at `secs` seconds and panics if ingestion fails.
    fn ingest(table: &ConversationTable, packet: &Packet, secs: u64, index: usize) {
        table
            .ingest_packet(packet, Duration::from_secs(secs), index)
            .unwrap();
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");

        ingest(&table, &syn, 1, 0);

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // Client -> server SYN, then server -> client SYN/ACK.
        ingest(&table, &make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S"), 1, 0);
        ingest(&table, &make_tcp_packet(HOST_B, HOST_A, 80, 12345, "SA"), 2, 1);

        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let conv = &conversations[0];
        assert_eq!(conv.total_packets(), 2);
        assert_eq!(conv.forward.packets, 1);
        assert_eq!(conv.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();
        let to_http = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");
        let to_https = make_tcp_packet(HOST_A, HOST_C, 12345, 443, "S");

        ingest(&table, &to_http, 1, 0);
        ingest(&table, &to_https, 2, 1);

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();
        ingest(&table, &make_udp_packet(HOST_A, HOST_B, 12345, 53), 1, 0);

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        ingest(&table, &make_udp_packet(HOST_A, HOST_B, 12345, 53), 1, 0);
        assert_eq!(table.conversation_count(), 1);

        // Four seconds after the last packet: still within the 10 s timeout.
        assert_eq!(table.evict_idle(Duration::from_secs(5)), 0);
        assert_eq!(table.conversation_count(), 1);

        // Nineteen seconds after the last packet: timed out.
        assert_eq!(table.evict_idle(Duration::from_secs(20)), 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();
        let early = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");
        let late = make_tcp_packet(HOST_A, HOST_C, 12345, 443, "S");

        // Ingest out of chronological order; the output must still be sorted.
        ingest(&table, &late, 5, 1);
        ingest(&table, &early, 1, 0);

        let conversations = table.into_conversations();
        assert!(conversations[0].start_time <= conversations[1].start_time);
    }
}