Skip to main content

stackforge_core/flow/
table.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14/// Thread-safe conversation tracking table backed by `DashMap`.
15///
16/// Supports concurrent packet ingestion from multiple threads while
17/// maintaining per-conversation state including TCP state machines
18/// and stream reassembly. Optionally tracks memory usage and spills
19/// reassembly buffers to disk when a budget is exceeded.
20pub struct ConversationTable {
21    conversations: DashMap<CanonicalKey, ConversationState>,
22    config: FlowConfig,
23    memory_tracker: Arc<MemoryTracker>,
24    spill_count: std::sync::atomic::AtomicUsize,
25}
26
27impl ConversationTable {
28    /// Create a new table with the given configuration.
29    #[must_use]
30    pub fn new(config: FlowConfig) -> Self {
31        let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
32        Self {
33            conversations: DashMap::new(),
34            config,
35            memory_tracker,
36            spill_count: std::sync::atomic::AtomicUsize::new(0),
37        }
38    }
39
40    /// Create a new table with default configuration.
41    #[must_use]
42    pub fn with_default_config() -> Self {
43        Self::new(FlowConfig::default())
44    }
45
46    /// Number of tracked conversations.
47    #[must_use]
48    pub fn conversation_count(&self) -> usize {
49        self.conversations.len()
50    }
51
52    /// Ingest a single parsed packet, updating or creating conversation state.
53    ///
54    /// `timestamp` is the packet capture timestamp (from PCAP metadata).
55    /// `packet_index` is the index of this packet in the original capture
56    /// (used for cross-referencing).
57    pub fn ingest_packet(
58        &self,
59        packet: &Packet,
60        timestamp: Duration,
61        packet_index: usize,
62    ) -> Result<(), FlowError> {
63        let (key, direction) = match extract_key(packet) {
64            Ok(result) => result,
65            Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
66                // Skip non-IP or non-TCP/UDP packets silently
67                return Ok(());
68            },
69            Err(e) => return Err(e),
70        };
71
72        let byte_count = packet.as_bytes().len() as u64;
73
74        // Use DashMap entry API for atomic get-or-insert + update
75        let mut entry = self
76            .conversations
77            .entry(key.clone())
78            .or_insert_with(|| ConversationState::new(key, timestamp));
79
80        let conv = entry.value_mut();
81
82        // Record packet stats
83        conv.record_packet(
84            direction,
85            byte_count,
86            timestamp,
87            packet_index,
88            self.config.track_max_packet_len,
89            self.config.track_max_flow_len,
90            self.config.store_packet_indices,
91        );
92
93        // Process protocol-specific state
94        let buf = packet.as_bytes();
95        match &mut conv.protocol_state {
96            ProtocolState::Tcp(tcp_state) => {
97                if let Some(tcp) = packet.tcp() {
98                    tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
99                }
100            },
101            ProtocolState::Udp(udp_state) => {
102                udp_state.process_packet();
103            },
104            ProtocolState::Icmp(icmp_state) => {
105                // Get ICMP type and code from buffer
106                if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
107                    if buf.len() >= icmp_layer.start + 2 {
108                        let icmp_type = buf[icmp_layer.start];
109                        let icmp_code = buf[icmp_layer.start + 1];
110                        icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
111                    }
112                }
113            },
114            ProtocolState::Icmpv6(icmpv6_state) => {
115                // Get ICMPv6 type and code from buffer
116                if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
117                    if buf.len() >= icmpv6_layer.start + 2 {
118                        let icmpv6_type = buf[icmpv6_layer.start];
119                        let icmpv6_code = buf[icmpv6_layer.start + 1];
120                        icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
121                    }
122                }
123            },
124            ProtocolState::ZWave(_) => {},
125            ProtocolState::Other => {},
126        }
127
128        // Update conversation status from protocol state
129        conv.update_status();
130
131        // Track memory for TCP reassembly buffers (only for in-memory data)
132        if self.memory_tracker.has_budget() {
133            if let ProtocolState::Tcp(ref tcp_state) = conv.protocol_state {
134                // Only track if at least one reassembler is still in memory
135                let fwd_spilled = tcp_state.reassembler_fwd.is_spilled();
136                let rev_spilled = tcp_state.reassembler_rev.is_spilled();
137                if !fwd_spilled || !rev_spilled {
138                    let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
139                        let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
140                        let payload_start = tcp.index.start + data_offset;
141                        buf.len().saturating_sub(payload_start)
142                    });
143                    if tcp_payload_len > 0 {
144                        self.memory_tracker.add(tcp_payload_len);
145                    }
146                }
147            }
148        }
149
150        // Drop the entry lock before spilling (which needs iter_mut)
151        drop(entry);
152
153        // Spill if over budget
154        if self.memory_tracker.is_over_budget() {
155            self.maybe_spill();
156        }
157
158        Ok(())
159    }
160
161    /// Spill reassembly buffers to disk until under budget.
162    ///
163    /// Two limits prevent runaway iteration:
164    /// - `max_spills`: stop after spilling this many buffers (actual work done)
165    /// - `max_skip`: stop after skipping this many already-spilled/non-TCP entries
166    ///   without finding anything to free (avoids scanning the entire table when
167    ///   most flows are already on disk)
168    fn maybe_spill(&self) {
169        let mut spills = 0;
170        let max_spills = 64;
171        let mut consecutive_skips = 0;
172        let max_skip = 512;
173
174        for mut entry in self.conversations.iter_mut() {
175            if !self.memory_tracker.is_over_budget() || spills >= max_spills {
176                break;
177            }
178            if consecutive_skips >= max_skip {
179                // Most nearby entries are already spilled — stop scanning
180                break;
181            }
182
183            if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
184                // Skip buffers already on disk
185                if tcp_state.reassembler_fwd.is_spilled() && tcp_state.reassembler_rev.is_spilled()
186                {
187                    consecutive_skips += 1;
188                    continue;
189                }
190                let freed_fwd = tcp_state
191                    .reassembler_fwd
192                    .spill(self.config.spill_dir.as_deref())
193                    .unwrap_or(0);
194                let freed_rev = tcp_state
195                    .reassembler_rev
196                    .spill(self.config.spill_dir.as_deref())
197                    .unwrap_or(0);
198                let total_freed = freed_fwd + freed_rev;
199                if total_freed > 0 {
200                    self.memory_tracker.subtract(total_freed);
201                    self.spill_count
202                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
203                    spills += 1;
204                    consecutive_skips = 0; // Reset — we found something useful
205                } else {
206                    consecutive_skips += 1;
207                }
208            } else {
209                consecutive_skips += 1;
210            }
211        }
212    }
213
214    /// Estimated memory usage of the flow table (tracked reassembly buffers).
215    #[must_use]
216    pub fn memory_usage(&self) -> usize {
217        self.memory_tracker.current_usage()
218    }
219
220    /// Number of spill operations performed.
221    #[must_use]
222    pub fn spill_count(&self) -> usize {
223        self.spill_count.load(std::sync::atomic::Ordering::Relaxed)
224    }
225
226    /// Get a read reference to a specific conversation.
227    #[must_use]
228    pub fn get_conversation(
229        &self,
230        key: &CanonicalKey,
231    ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
232        self.conversations.get(key)
233    }
234
235    /// Evict conversations that have exceeded their idle timeout.
236    ///
237    /// Returns the number of evicted conversations.
238    #[must_use]
239    pub fn evict_idle(&self, now: Duration) -> usize {
240        let mut evicted = 0;
241        self.conversations.retain(|_, conv| {
242            if conv.is_timed_out(now, &self.config) {
243                evicted += 1;
244                false
245            } else {
246                true
247            }
248        });
249        evicted
250    }
251
252    /// Consume the table and return all conversations sorted by start time.
253    #[must_use]
254    pub fn into_conversations(self) -> Vec<ConversationState> {
255        let mut conversations: Vec<ConversationState> =
256            self.conversations.into_iter().map(|(_, v)| v).collect();
257        conversations.sort_by_key(|c| c.start_time);
258        conversations
259    }
260
261    /// Get a reference to the configuration.
262    #[must_use]
263    pub fn config(&self) -> &FlowConfig {
264        &self.config
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    const HOST_A: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
    const HOST_B: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 2);
    const HOST_C: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 3);
    const CLIENT_PORT: u16 = 12345;

    /// Link layer shared by every test frame: broadcast destination and a
    /// fixed source MAC.
    fn ethernet_entry() -> LayerStackEntry {
        LayerStackEntry::Ethernet(
            EthernetBuilder::new()
                .dst(MacAddress::BROADCAST)
                .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
        )
    }

    /// IPv4 layer from `src` to `dst` with builder defaults for everything else.
    fn ipv4_entry(src: Ipv4Addr, dst: Ipv4Addr) -> LayerStackEntry {
        LayerStackEntry::Ipv4(Ipv4Builder::new().src(src).dst(dst))
    }

    /// Build an Ethernet/IPv4/TCP frame.
    ///
    /// `flags` is a string of flag letters (`S`, `A`, `F`, `R`); any other
    /// character is ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let base = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);

        let tcp = flags.chars().fold(base, |tcp, flag| match flag {
            'S' => tcp.syn(),
            'A' => tcp.ack(),
            'F' => tcp.fin(),
            'R' => tcp.rst(),
            _ => tcp,
        });

        LayerStack::new()
            .push(ethernet_entry())
            .push(ipv4_entry(src_ip, dst_ip))
            .push(LayerStackEntry::Tcp(tcp))
            .build_packet()
    }

    /// Build an Ethernet/IPv4/UDP frame.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);
        LayerStack::new()
            .push(ethernet_entry())
            .push(ipv4_entry(src_ip, dst_ip))
            .push(LayerStackEntry::Udp(udp))
            .build_packet()
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(HOST_A, HOST_B, CLIENT_PORT, 80, "S");

        table
            .ingest_packet(&syn, Duration::from_secs(1), 0)
            .unwrap();

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // Client -> server SYN, then server -> client SYN/ACK.
        let syn = make_tcp_packet(HOST_A, HOST_B, CLIENT_PORT, 80, "S");
        let syn_ack = make_tcp_packet(HOST_B, HOST_A, 80, CLIENT_PORT, "SA");

        table
            .ingest_packet(&syn, Duration::from_secs(1), 0)
            .unwrap();
        table
            .ingest_packet(&syn_ack, Duration::from_secs(2), 1)
            .unwrap();

        // Both directions must fold into a single conversation.
        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let conv = &conversations[0];
        assert_eq!(conv.total_packets(), 2);
        assert_eq!(conv.forward.packets, 1);
        assert_eq!(conv.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();

        let to_http = make_tcp_packet(HOST_A, HOST_B, CLIENT_PORT, 80, "S");
        let to_https = make_tcp_packet(HOST_A, HOST_C, CLIENT_PORT, 443, "S");

        table
            .ingest_packet(&to_http, Duration::from_secs(1), 0)
            .unwrap();
        table
            .ingest_packet(&to_https, Duration::from_secs(2), 1)
            .unwrap();

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();
        let dns_query = make_udp_packet(HOST_A, HOST_B, CLIENT_PORT, 53);

        table
            .ingest_packet(&dns_query, Duration::from_secs(1), 0)
            .unwrap();

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let dns_query = make_udp_packet(HOST_A, HOST_B, CLIENT_PORT, 53);
        table
            .ingest_packet(&dns_query, Duration::from_secs(1), 0)
            .unwrap();
        assert_eq!(table.conversation_count(), 1);

        // (eviction time in seconds, expected evictions, conversations left):
        // 5s is still inside the 10s UDP timeout; 20s is past it.
        for (now_secs, expected_evicted, expected_left) in [(5, 0, 1), (20, 1, 0)] {
            let evicted = table.evict_idle(Duration::from_secs(now_secs));
            assert_eq!(evicted, expected_evicted);
            assert_eq!(table.conversation_count(), expected_left);
        }
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();

        let early = make_tcp_packet(HOST_A, HOST_B, CLIENT_PORT, 80, "S");
        let late = make_tcp_packet(HOST_A, HOST_C, CLIENT_PORT, 443, "S");

        // Ingest the later-starting flow first so sorting has work to do.
        table
            .ingest_packet(&late, Duration::from_secs(5), 1)
            .unwrap();
        table
            .ingest_packet(&early, Duration::from_secs(1), 0)
            .unwrap();

        let conversations = table.into_conversations();
        let (first, second) = (&conversations[0], &conversations[1]);
        assert!(first.start_time <= second.start_time);
    }
}