1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14pub struct ConversationTable {
21 conversations: DashMap<CanonicalKey, ConversationState>,
22 config: FlowConfig,
23 memory_tracker: Arc<MemoryTracker>,
24}
25
26impl ConversationTable {
27 #[must_use]
29 pub fn new(config: FlowConfig) -> Self {
30 let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
31 Self {
32 conversations: DashMap::new(),
33 config,
34 memory_tracker,
35 }
36 }
37
38 #[must_use]
40 pub fn with_default_config() -> Self {
41 Self::new(FlowConfig::default())
42 }
43
44 #[must_use]
46 pub fn conversation_count(&self) -> usize {
47 self.conversations.len()
48 }
49
50 pub fn ingest_packet(
56 &self,
57 packet: &Packet,
58 timestamp: Duration,
59 packet_index: usize,
60 ) -> Result<(), FlowError> {
61 let (key, direction) = match extract_key(packet) {
62 Ok(result) => result,
63 Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
64 return Ok(());
66 },
67 Err(e) => return Err(e),
68 };
69
70 let byte_count = packet.as_bytes().len() as u64;
71
72 let mut entry = self
74 .conversations
75 .entry(key.clone())
76 .or_insert_with(|| ConversationState::new(key, timestamp));
77
78 let conv = entry.value_mut();
79
80 conv.record_packet(
82 direction,
83 byte_count,
84 timestamp,
85 packet_index,
86 self.config.track_max_packet_len,
87 self.config.track_max_flow_len,
88 );
89
90 let buf = packet.as_bytes();
92 match &mut conv.protocol_state {
93 ProtocolState::Tcp(tcp_state) => {
94 if let Some(tcp) = packet.tcp() {
95 tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
96 }
97 },
98 ProtocolState::Udp(udp_state) => {
99 udp_state.process_packet();
100 },
101 ProtocolState::Icmp(icmp_state) => {
102 if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
104 if buf.len() >= icmp_layer.start + 2 {
105 let icmp_type = buf[icmp_layer.start];
106 let icmp_code = buf[icmp_layer.start + 1];
107 icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
108 }
109 }
110 },
111 ProtocolState::Icmpv6(icmpv6_state) => {
112 if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
114 if buf.len() >= icmpv6_layer.start + 2 {
115 let icmpv6_type = buf[icmpv6_layer.start];
116 let icmpv6_code = buf[icmpv6_layer.start + 1];
117 icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
118 }
119 }
120 },
121 ProtocolState::ZWave(_) => {},
122 ProtocolState::Other => {},
123 }
124
125 conv.update_status();
127
128 if self.memory_tracker.has_budget() {
130 if let ProtocolState::Tcp(ref tcp_state) = conv.protocol_state {
131 let fwd_spilled = tcp_state.reassembler_fwd.is_spilled();
133 let rev_spilled = tcp_state.reassembler_rev.is_spilled();
134 if !fwd_spilled || !rev_spilled {
135 let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
136 let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
137 let payload_start = tcp.index.start + data_offset;
138 buf.len().saturating_sub(payload_start)
139 });
140 if tcp_payload_len > 0 {
141 self.memory_tracker.add(tcp_payload_len);
142 }
143 }
144 }
145 }
146
147 drop(entry);
149
150 if self.memory_tracker.is_over_budget() {
152 self.maybe_spill();
153 }
154
155 Ok(())
156 }
157
158 fn maybe_spill(&self) {
160 for mut entry in self.conversations.iter_mut() {
161 if !self.memory_tracker.is_over_budget() {
162 break;
163 }
164 if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
165 let freed_fwd = tcp_state
166 .reassembler_fwd
167 .spill(self.config.spill_dir.as_deref())
168 .unwrap_or(0);
169 let freed_rev = tcp_state
170 .reassembler_rev
171 .spill(self.config.spill_dir.as_deref())
172 .unwrap_or(0);
173 let total_freed = freed_fwd + freed_rev;
174 if total_freed > 0 {
175 self.memory_tracker.subtract(total_freed);
176 }
177 }
178 }
179 }
180
181 #[must_use]
183 pub fn get_conversation(
184 &self,
185 key: &CanonicalKey,
186 ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
187 self.conversations.get(key)
188 }
189
190 #[must_use]
194 pub fn evict_idle(&self, now: Duration) -> usize {
195 let mut evicted = 0;
196 self.conversations.retain(|_, conv| {
197 if conv.is_timed_out(now, &self.config) {
198 evicted += 1;
199 false
200 } else {
201 true
202 }
203 });
204 evicted
205 }
206
207 #[must_use]
209 pub fn into_conversations(self) -> Vec<ConversationState> {
210 let mut conversations: Vec<ConversationState> =
211 self.conversations.into_iter().map(|(_, v)| v).collect();
212 conversations.sort_by_key(|c| c.start_time);
213 conversations
214 }
215
216 #[must_use]
218 pub fn config(&self) -> &FlowConfig {
219 &self.config
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Returns `10.0.0.<last_octet>`, the subnet every test packet lives on.
    fn host(last_octet: u8) -> Ipv4Addr {
        Ipv4Addr::new(10, 0, 0, last_octet)
    }

    /// Puts Ethernet (broadcast destination) and IPv4 `src_ip -> dst_ip` headers in
    /// front of `transport` and builds the resulting packet.
    fn build_over_ipv4(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, transport: LayerStackEntry) -> Packet {
        let ethernet = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);
        LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(transport)
            .build_packet()
    }

    /// Builds an IPv4 TCP packet with seq 1000, ack 0, window 65535 and the flags named in
    /// `flags` (`S`, `A`, `F`, `R`; any other character is ignored).
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let unflagged = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);
        let tcp = flags.chars().fold(unflagged, |tcp, flag| match flag {
            'S' => tcp.syn(),
            'A' => tcp.ack(),
            'F' => tcp.fin(),
            'R' => tcp.rst(),
            _ => tcp,
        });
        build_over_ipv4(src_ip, dst_ip, LayerStackEntry::Tcp(tcp))
    }

    /// Builds an IPv4 UDP packet `src_ip:sport -> dst_ip:dport`.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);
        build_over_ipv4(src_ip, dst_ip, LayerStackEntry::Udp(udp))
    }

    /// Ingests `packet` at `secs` seconds with capture index `index`, failing the test on error.
    fn ingest_at(table: &ConversationTable, packet: &Packet, secs: u64, index: usize) {
        table
            .ingest_packet(packet, Duration::from_secs(secs), index)
            .unwrap();
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(host(1), host(2), 12345, 80, "S");

        ingest_at(&table, &syn, 1, 0);

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // Client SYN, then the server's SYN-ACK on the reversed 5-tuple.
        ingest_at(&table, &make_tcp_packet(host(1), host(2), 12345, 80, "S"), 1, 0);
        ingest_at(&table, &make_tcp_packet(host(2), host(1), 80, 12345, "SA"), 2, 1);

        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let only = &conversations[0];
        assert_eq!(only.total_packets(), 2);
        assert_eq!(only.forward.packets, 1);
        assert_eq!(only.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();
        let flows = [(host(2), 80), (host(3), 443)];

        for (index, &(dst_ip, dport)) in flows.iter().enumerate() {
            let syn = make_tcp_packet(host(1), dst_ip, 12345, dport, "S");
            ingest_at(&table, &syn, index as u64 + 1, index);
        }

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();

        ingest_at(&table, &make_udp_packet(host(1), host(2), 12345, 53), 1, 0);

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        ingest_at(&table, &make_udp_packet(host(1), host(2), 12345, 53), 1, 0);
        assert_eq!(table.conversation_count(), 1);

        // Four seconds after the only packet: expected to survive a 10s UDP timeout.
        assert_eq!(table.evict_idle(Duration::from_secs(5)), 0);
        assert_eq!(table.conversation_count(), 1);

        // Nineteen seconds after: expected to be evicted.
        assert_eq!(table.evict_idle(Duration::from_secs(20)), 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();
        let early = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        let late = make_tcp_packet(host(1), host(3), 12345, 443, "S");

        // Ingest the later flow first so insertion order differs from start-time order.
        ingest_at(&table, &late, 5, 1);
        ingest_at(&table, &early, 1, 0);

        let conversations = table.into_conversations();
        assert!(conversations[0].start_time <= conversations[1].start_time);
    }
}