// stackforge_core/flow/table.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14/// Thread-safe conversation tracking table backed by `DashMap`.
15///
16/// Supports concurrent packet ingestion from multiple threads while
17/// maintaining per-conversation state including TCP state machines
18/// and stream reassembly. Optionally tracks memory usage and spills
19/// reassembly buffers to disk when a budget is exceeded.
20pub struct ConversationTable {
21    conversations: DashMap<CanonicalKey, ConversationState>,
22    config: FlowConfig,
23    memory_tracker: Arc<MemoryTracker>,
24    spill_count: std::sync::atomic::AtomicUsize,
25}
26
27impl ConversationTable {
28    /// Create a new table with the given configuration.
29    #[must_use]
30    pub fn new(config: FlowConfig) -> Self {
31        let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
32        Self {
33            conversations: DashMap::new(),
34            config,
35            memory_tracker,
36            spill_count: std::sync::atomic::AtomicUsize::new(0),
37        }
38    }
39
40    /// Create a new table with default configuration.
41    #[must_use]
42    pub fn with_default_config() -> Self {
43        Self::new(FlowConfig::default())
44    }
45
46    /// Number of tracked conversations.
47    #[must_use]
48    pub fn conversation_count(&self) -> usize {
49        self.conversations.len()
50    }
51
52    /// Ingest a single parsed packet, updating or creating conversation state.
53    ///
54    /// `timestamp` is the packet capture timestamp (from PCAP metadata).
55    /// `packet_index` is the index of this packet in the original capture
56    /// (used for cross-referencing).
57    pub fn ingest_packet(
58        &self,
59        packet: &Packet,
60        timestamp: Duration,
61        packet_index: usize,
62    ) -> Result<(), FlowError> {
63        let (key, direction) = match extract_key(packet) {
64            Ok(result) => result,
65            Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
66                // Skip non-IP or non-TCP/UDP packets silently
67                return Ok(());
68            },
69            Err(e) => return Err(e),
70        };
71
72        let byte_count = packet.as_bytes().len() as u64;
73
74        // Use DashMap entry API for atomic get-or-insert + update
75        let mut entry = self
76            .conversations
77            .entry(key.clone())
78            .or_insert_with(|| ConversationState::new(key, timestamp));
79
80        let conv = entry.value_mut();
81
82        // Record packet stats
83        conv.record_packet(
84            direction,
85            byte_count,
86            timestamp,
87            packet_index,
88            self.config.track_max_packet_len,
89            self.config.track_max_flow_len,
90            self.config.store_packet_indices,
91        );
92
93        // Process protocol-specific state
94        let buf = packet.as_bytes();
95        match &mut conv.protocol_state {
96            ProtocolState::Tcp(tcp_state) => {
97                if let Some(tcp) = packet.tcp() {
98                    tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
99                }
100            },
101            ProtocolState::Udp(udp_state) => {
102                udp_state.process_packet();
103            },
104            ProtocolState::Icmp(icmp_state) => {
105                // Get ICMP type and code from buffer
106                if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
107                    if buf.len() >= icmp_layer.start + 2 {
108                        let icmp_type = buf[icmp_layer.start];
109                        let icmp_code = buf[icmp_layer.start + 1];
110                        icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
111                    }
112                }
113            },
114            ProtocolState::Icmpv6(icmpv6_state) => {
115                // Get ICMPv6 type and code from buffer
116                if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
117                    if buf.len() >= icmpv6_layer.start + 2 {
118                        let icmpv6_type = buf[icmpv6_layer.start];
119                        let icmpv6_code = buf[icmpv6_layer.start + 1];
120                        icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
121                    }
122                }
123            },
124            ProtocolState::ZWave(_) => {},
125            ProtocolState::Other => {},
126        }
127
128        // Update conversation status from protocol state
129        conv.update_status();
130
131        // Track memory for TCP reassembly buffers (only for in-memory data)
132        if self.memory_tracker.has_budget() {
133            if let ProtocolState::Tcp(ref tcp_state) = conv.protocol_state {
134                // Only track if at least one reassembler is still in memory
135                let fwd_spilled = tcp_state.reassembler_fwd.is_spilled();
136                let rev_spilled = tcp_state.reassembler_rev.is_spilled();
137                if !fwd_spilled || !rev_spilled {
138                    let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
139                        let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
140                        let payload_start = tcp.index.start + data_offset;
141                        buf.len().saturating_sub(payload_start)
142                    });
143                    if tcp_payload_len > 0 {
144                        self.memory_tracker.add(tcp_payload_len);
145                    }
146                }
147            }
148        }
149
150        // Drop the entry lock before spilling (which needs iter_mut)
151        drop(entry);
152
153        // Spill if over budget
154        if self.memory_tracker.is_over_budget() {
155            self.maybe_spill();
156        }
157
158        Ok(())
159    }
160
161    /// Spill reassembly buffers to disk until under budget.
162    ///
163    /// Limits the number of entries scanned per call to avoid O(n) full-table
164    /// walks that become expensive as the flow count grows.
165    fn maybe_spill(&self) {
166        let mut scanned = 0;
167        let max_scan = 256; // Cap per-call work to avoid stalls
168        for mut entry in self.conversations.iter_mut() {
169            if !self.memory_tracker.is_over_budget() || scanned >= max_scan {
170                break;
171            }
172            scanned += 1;
173            if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
174                // Only try to spill buffers that are still in memory
175                if tcp_state.reassembler_fwd.is_spilled() && tcp_state.reassembler_rev.is_spilled()
176                {
177                    continue;
178                }
179                let freed_fwd = tcp_state
180                    .reassembler_fwd
181                    .spill(self.config.spill_dir.as_deref())
182                    .unwrap_or(0);
183                let freed_rev = tcp_state
184                    .reassembler_rev
185                    .spill(self.config.spill_dir.as_deref())
186                    .unwrap_or(0);
187                let total_freed = freed_fwd + freed_rev;
188                if total_freed > 0 {
189                    self.memory_tracker.subtract(total_freed);
190                    self.spill_count
191                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
192                }
193            }
194        }
195    }
196
197    /// Estimated memory usage of the flow table (tracked reassembly buffers).
198    #[must_use]
199    pub fn memory_usage(&self) -> usize {
200        self.memory_tracker.current_usage()
201    }
202
203    /// Number of spill operations performed.
204    #[must_use]
205    pub fn spill_count(&self) -> usize {
206        self.spill_count.load(std::sync::atomic::Ordering::Relaxed)
207    }
208
209    /// Get a read reference to a specific conversation.
210    #[must_use]
211    pub fn get_conversation(
212        &self,
213        key: &CanonicalKey,
214    ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
215        self.conversations.get(key)
216    }
217
218    /// Evict conversations that have exceeded their idle timeout.
219    ///
220    /// Returns the number of evicted conversations.
221    #[must_use]
222    pub fn evict_idle(&self, now: Duration) -> usize {
223        let mut evicted = 0;
224        self.conversations.retain(|_, conv| {
225            if conv.is_timed_out(now, &self.config) {
226                evicted += 1;
227                false
228            } else {
229                true
230            }
231        });
232        evicted
233    }
234
235    /// Consume the table and return all conversations sorted by start time.
236    #[must_use]
237    pub fn into_conversations(self) -> Vec<ConversationState> {
238        let mut conversations: Vec<ConversationState> =
239            self.conversations.into_iter().map(|(_, v)| v).collect();
240        conversations.sort_by_key(|c| c.start_time);
241        conversations
242    }
243
244    /// Get a reference to the configuration.
245    #[must_use]
246    pub fn config(&self) -> &FlowConfig {
247        &self.config
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Address `10.0.0.<last_octet>`; every test flow lives in this subnet.
    fn host(last_octet: u8) -> Ipv4Addr {
        Ipv4Addr::new(10, 0, 0, last_octet)
    }

    /// Build an Ethernet/IPv4/TCP frame. `flags` is a string of flag letters
    /// (`S`, `A`, `F`, `R`); any other character is ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let base = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);
        let tcp = flags.chars().fold(base, |b, flag| match flag {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            _ => b,
        });

        let eth = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(eth))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp))
            .build_packet()
    }

    /// Build an Ethernet/IPv4/UDP frame with the given endpoints.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let eth = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(eth))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Udp(udp))
            .build_packet()
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(host(1), host(2), 12345, 80, "S");

        table.ingest_packet(&syn, Duration::from_secs(1), 0).unwrap();

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // (packet, capture second): client SYN, then server SYN-ACK.
        let handshake = [
            (make_tcp_packet(host(1), host(2), 12345, 80, "S"), 1),
            (make_tcp_packet(host(2), host(1), 80, 12345, "SA"), 2),
        ];
        for (index, (pkt, second)) in handshake.iter().enumerate() {
            table
                .ingest_packet(pkt, Duration::from_secs(*second), index)
                .unwrap();
        }

        // Both directions must fold into a single conversation.
        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let conv = &conversations[0];
        assert_eq!(conv.total_packets(), 2);
        assert_eq!(conv.forward.packets, 1);
        assert_eq!(conv.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();

        let flows = [
            make_tcp_packet(host(1), host(2), 12345, 80, "S"),
            make_tcp_packet(host(1), host(3), 12345, 443, "S"),
        ];
        for (index, pkt) in flows.iter().enumerate() {
            let second = index as u64 + 1;
            table
                .ingest_packet(pkt, Duration::from_secs(second), index)
                .unwrap();
        }

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();
        let query = make_udp_packet(host(1), host(2), 12345, 53);

        table.ingest_packet(&query, Duration::from_secs(1), 0).unwrap();

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(conversations[0].protocol_state, ProtocolState::Udp(_)));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let query = make_udp_packet(host(1), host(2), 12345, 53);
        table.ingest_packet(&query, Duration::from_secs(1), 0).unwrap();
        assert_eq!(table.conversation_count(), 1);

        // (now in seconds, expected evictions, expected survivors):
        // first still within the timeout, then well past it.
        for (now, expected_evicted, expected_remaining) in [(5, 0, 1), (20, 1, 0)] {
            assert_eq!(table.evict_idle(Duration::from_secs(now)), expected_evicted);
            assert_eq!(table.conversation_count(), expected_remaining);
        }
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();

        let early = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        let late = make_tcp_packet(host(1), host(3), 12345, 443, "S");

        // Ingest the later flow first so the sort has something to reorder.
        table.ingest_packet(&late, Duration::from_secs(5), 1).unwrap();
        table.ingest_packet(&early, Duration::from_secs(1), 0).unwrap();

        let conversations = table.into_conversations();
        assert!(conversations[0].start_time <= conversations[1].start_time);
    }
}