// stackforge_core/flow/table.rs

use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;

use crate::Packet;
use crate::layer::LayerKind;

use super::config::FlowConfig;
use super::error::FlowError;
use super::key::{CanonicalKey, extract_key};
use super::spill::MemoryTracker;
use super::state::{ConversationState, ProtocolState};

14/// Thread-safe conversation tracking table backed by `DashMap`.
15///
16/// Supports concurrent packet ingestion from multiple threads while
17/// maintaining per-conversation state including TCP state machines
18/// and stream reassembly. Optionally tracks memory usage and spills
19/// reassembly buffers to disk when a budget is exceeded.
20pub struct ConversationTable {
21    conversations: DashMap<CanonicalKey, ConversationState>,
22    config: FlowConfig,
23    memory_tracker: Arc<MemoryTracker>,
24}
25
26impl ConversationTable {
27    /// Create a new table with the given configuration.
28    #[must_use]
29    pub fn new(config: FlowConfig) -> Self {
30        let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
31        Self {
32            conversations: DashMap::new(),
33            config,
34            memory_tracker,
35        }
36    }
37
38    /// Create a new table with default configuration.
39    #[must_use]
40    pub fn with_default_config() -> Self {
41        Self::new(FlowConfig::default())
42    }
43
44    /// Number of tracked conversations.
45    #[must_use]
46    pub fn conversation_count(&self) -> usize {
47        self.conversations.len()
48    }
49
50    /// Ingest a single parsed packet, updating or creating conversation state.
51    ///
52    /// `timestamp` is the packet capture timestamp (from PCAP metadata).
53    /// `packet_index` is the index of this packet in the original capture
54    /// (used for cross-referencing).
55    pub fn ingest_packet(
56        &self,
57        packet: &Packet,
58        timestamp: Duration,
59        packet_index: usize,
60    ) -> Result<(), FlowError> {
61        let (key, direction) = match extract_key(packet) {
62            Ok(result) => result,
63            Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
64                // Skip non-IP or non-TCP/UDP packets silently
65                return Ok(());
66            },
67            Err(e) => return Err(e),
68        };
69
70        let byte_count = packet.as_bytes().len() as u64;
71
72        // Use DashMap entry API for atomic get-or-insert + update
73        let mut entry = self
74            .conversations
75            .entry(key.clone())
76            .or_insert_with(|| ConversationState::new(key, timestamp));
77
78        let conv = entry.value_mut();
79
80        // Record packet stats
81        conv.record_packet(
82            direction,
83            byte_count,
84            timestamp,
85            packet_index,
86            self.config.track_max_packet_len,
87            self.config.track_max_flow_len,
88        );
89
90        // Process protocol-specific state
91        let buf = packet.as_bytes();
92        match &mut conv.protocol_state {
93            ProtocolState::Tcp(tcp_state) => {
94                if let Some(tcp) = packet.tcp() {
95                    tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
96                }
97            },
98            ProtocolState::Udp(udp_state) => {
99                udp_state.process_packet();
100            },
101            ProtocolState::Icmp(icmp_state) => {
102                // Get ICMP type and code from buffer
103                if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
104                    if buf.len() >= icmp_layer.start + 2 {
105                        let icmp_type = buf[icmp_layer.start];
106                        let icmp_code = buf[icmp_layer.start + 1];
107                        icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
108                    }
109                }
110            },
111            ProtocolState::Icmpv6(icmpv6_state) => {
112                // Get ICMPv6 type and code from buffer
113                if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
114                    if buf.len() >= icmpv6_layer.start + 2 {
115                        let icmpv6_type = buf[icmpv6_layer.start];
116                        let icmpv6_code = buf[icmpv6_layer.start + 1];
117                        icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
118                    }
119                }
120            },
121            ProtocolState::ZWave(_) => {},
122            ProtocolState::Other => {},
123        }
124
125        // Update conversation status from protocol state
126        conv.update_status();
127
128        // Track memory for TCP reassembly buffers (only for in-memory data)
129        if self.memory_tracker.has_budget() {
130            if let ProtocolState::Tcp(ref tcp_state) = conv.protocol_state {
131                // Only track if at least one reassembler is still in memory
132                let fwd_spilled = tcp_state.reassembler_fwd.is_spilled();
133                let rev_spilled = tcp_state.reassembler_rev.is_spilled();
134                if !fwd_spilled || !rev_spilled {
135                    let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
136                        let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
137                        let payload_start = tcp.index.start + data_offset;
138                        buf.len().saturating_sub(payload_start)
139                    });
140                    if tcp_payload_len > 0 {
141                        self.memory_tracker.add(tcp_payload_len);
142                    }
143                }
144            }
145        }
146
147        // Drop the entry lock before spilling (which needs iter_mut)
148        drop(entry);
149
150        // Spill if over budget
151        if self.memory_tracker.is_over_budget() {
152            self.maybe_spill();
153        }
154
155        Ok(())
156    }
157
158    /// Spill the largest reassembly buffers to disk until under budget.
159    fn maybe_spill(&self) {
160        for mut entry in self.conversations.iter_mut() {
161            if !self.memory_tracker.is_over_budget() {
162                break;
163            }
164            if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
165                let freed_fwd = tcp_state
166                    .reassembler_fwd
167                    .spill(self.config.spill_dir.as_deref())
168                    .unwrap_or(0);
169                let freed_rev = tcp_state
170                    .reassembler_rev
171                    .spill(self.config.spill_dir.as_deref())
172                    .unwrap_or(0);
173                let total_freed = freed_fwd + freed_rev;
174                if total_freed > 0 {
175                    self.memory_tracker.subtract(total_freed);
176                }
177            }
178        }
179    }
180
181    /// Get a read reference to a specific conversation.
182    #[must_use]
183    pub fn get_conversation(
184        &self,
185        key: &CanonicalKey,
186    ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
187        self.conversations.get(key)
188    }
189
190    /// Evict conversations that have exceeded their idle timeout.
191    ///
192    /// Returns the number of evicted conversations.
193    #[must_use]
194    pub fn evict_idle(&self, now: Duration) -> usize {
195        let mut evicted = 0;
196        self.conversations.retain(|_, conv| {
197            if conv.is_timed_out(now, &self.config) {
198                evicted += 1;
199                false
200            } else {
201                true
202            }
203        });
204        evicted
205    }
206
207    /// Consume the table and return all conversations sorted by start time.
208    #[must_use]
209    pub fn into_conversations(self) -> Vec<ConversationState> {
210        let mut conversations: Vec<ConversationState> =
211            self.conversations.into_iter().map(|(_, v)| v).collect();
212        conversations.sort_by_key(|c| c.start_time);
213        conversations
214    }
215
216    /// Get a reference to the configuration.
217    #[must_use]
218    pub fn config(&self) -> &FlowConfig {
219        &self.config
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Endpoints used throughout the tests.
    const HOST_A: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
    const HOST_B: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 2);
    const HOST_C: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 3);

    /// Ethernet layer shared by every test packet.
    fn ethernet_layer() -> LayerStackEntry {
        LayerStackEntry::Ethernet(
            EthernetBuilder::new()
                .dst(MacAddress::BROADCAST)
                .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
        )
    }

    /// IPv4 layer for the given endpoints.
    fn ipv4_layer(src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> LayerStackEntry {
        LayerStackEntry::Ipv4(Ipv4Builder::new().src(src_ip).dst(dst_ip))
    }

    /// Build an Ethernet/IPv4/TCP packet.
    ///
    /// `flags` is a string of flag letters: `S` (SYN), `A` (ACK), `F` (FIN),
    /// `R` (RST). Unknown letters are ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let base = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);

        let builder = flags.chars().fold(base, |builder, flag| match flag {
            'S' => builder.syn(),
            'A' => builder.ack(),
            'F' => builder.fin(),
            'R' => builder.rst(),
            _ => builder,
        });

        LayerStack::new()
            .push(ethernet_layer())
            .push(ipv4_layer(src_ip, dst_ip))
            .push(LayerStackEntry::Tcp(builder))
            .build_packet()
    }

    /// Build an Ethernet/IPv4/UDP packet.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        LayerStack::new()
            .push(ethernet_layer())
            .push(ipv4_layer(src_ip, dst_ip))
            .push(LayerStackEntry::Udp(
                UdpBuilder::new().src_port(sport).dst_port(dport),
            ))
            .build_packet()
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let pkt = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");

        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();
        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        // Forward packet
        let pkt_fwd = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");
        table
            .ingest_packet(&pkt_fwd, Duration::from_secs(1), 0)
            .unwrap();

        // Reverse packet
        let pkt_rev = make_tcp_packet(HOST_B, HOST_A, 80, 12345, "SA");
        table
            .ingest_packet(&pkt_rev, Duration::from_secs(2), 1)
            .unwrap();

        // Should be one conversation, not two
        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert_eq!(conversations[0].total_packets(), 2);
        assert_eq!(conversations[0].forward.packets, 1);
        assert_eq!(conversations[0].reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();

        let pkt1 = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");
        let pkt2 = make_tcp_packet(HOST_A, HOST_C, 12345, 443, "S");

        table
            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
            .unwrap();
        table
            .ingest_packet(&pkt2, Duration::from_secs(2), 1)
            .unwrap();

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();

        let pkt = make_udp_packet(HOST_A, HOST_B, 12345, 53);
        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let pkt = make_udp_packet(HOST_A, HOST_B, 12345, 53);
        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();
        assert_eq!(table.conversation_count(), 1);

        // Not yet timed out
        let evicted = table.evict_idle(Duration::from_secs(5));
        assert_eq!(evicted, 0);
        assert_eq!(table.conversation_count(), 1);

        // Now timed out
        let evicted = table.evict_idle(Duration::from_secs(20));
        assert_eq!(evicted, 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();

        let pkt1 = make_tcp_packet(HOST_A, HOST_B, 12345, 80, "S");
        let pkt2 = make_tcp_packet(HOST_A, HOST_C, 12345, 443, "S");

        // Insert second flow first (later timestamp)
        table
            .ingest_packet(&pkt2, Duration::from_secs(5), 1)
            .unwrap();
        table
            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
            .unwrap();

        let conversations = table.into_conversations();
        // Guard the indexing below so a missing flow fails with a clear message.
        assert_eq!(conversations.len(), 2);
        assert!(conversations[0].start_time <= conversations[1].start_time);
    }
}