Skip to main content

stackforge_core/flow/
table.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use dashmap::DashMap;
5
6use crate::Packet;
7
8use super::config::FlowConfig;
9use super::error::FlowError;
10use super::key::{CanonicalKey, extract_key};
11use super::spill::MemoryTracker;
12use super::state::{ConversationState, ProtocolState};
13
14/// Thread-safe conversation tracking table backed by `DashMap`.
15///
16/// Supports concurrent packet ingestion from multiple threads while
17/// maintaining per-conversation state including TCP state machines
18/// and stream reassembly. Optionally tracks memory usage and spills
19/// reassembly buffers to disk when a budget is exceeded.
20pub struct ConversationTable {
21    conversations: DashMap<CanonicalKey, ConversationState>,
22    config: FlowConfig,
23    memory_tracker: Arc<MemoryTracker>,
24}
25
26impl ConversationTable {
27    /// Create a new table with the given configuration.
28    #[must_use]
29    pub fn new(config: FlowConfig) -> Self {
30        let memory_tracker = Arc::new(MemoryTracker::new(config.memory_budget));
31        Self {
32            conversations: DashMap::new(),
33            config,
34            memory_tracker,
35        }
36    }
37
38    /// Create a new table with default configuration.
39    #[must_use]
40    pub fn with_default_config() -> Self {
41        Self::new(FlowConfig::default())
42    }
43
44    /// Number of tracked conversations.
45    #[must_use]
46    pub fn conversation_count(&self) -> usize {
47        self.conversations.len()
48    }
49
50    /// Ingest a single parsed packet, updating or creating conversation state.
51    ///
52    /// `timestamp` is the packet capture timestamp (from PCAP metadata).
53    /// `packet_index` is the index of this packet in the original capture
54    /// (used for cross-referencing).
55    pub fn ingest_packet(
56        &self,
57        packet: &Packet,
58        timestamp: Duration,
59        packet_index: usize,
60    ) -> Result<(), FlowError> {
61        let (key, direction) = match extract_key(packet) {
62            Ok(result) => result,
63            Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
64                // Skip non-IP or non-TCP/UDP packets silently
65                return Ok(());
66            },
67            Err(e) => return Err(e),
68        };
69
70        let byte_count = packet.as_bytes().len() as u64;
71
72        // Use DashMap entry API for atomic get-or-insert + update
73        let mut entry = self
74            .conversations
75            .entry(key.clone())
76            .or_insert_with(|| ConversationState::new(key, timestamp));
77
78        let conv = entry.value_mut();
79
80        // Record packet stats
81        conv.record_packet(
82            direction,
83            byte_count,
84            timestamp,
85            packet_index,
86            self.config.track_max_packet_len,
87            self.config.track_max_flow_len,
88        );
89
90        // Process protocol-specific state
91        let buf = packet.as_bytes();
92        match &mut conv.protocol_state {
93            ProtocolState::Tcp(tcp_state) => {
94                if let Some(tcp) = packet.tcp() {
95                    tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
96                }
97            },
98            ProtocolState::Udp(udp_state) => {
99                udp_state.process_packet();
100            },
101            ProtocolState::Icmp(icmp_state) => {
102                // Get ICMP type and code from buffer
103                if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
104                    if buf.len() >= icmp_layer.start + 2 {
105                        let icmp_type = buf[icmp_layer.start];
106                        let icmp_code = buf[icmp_layer.start + 1];
107                        icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
108                    }
109                }
110            },
111            ProtocolState::Icmpv6(icmpv6_state) => {
112                // Get ICMPv6 type and code from buffer
113                if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
114                    if buf.len() >= icmpv6_layer.start + 2 {
115                        let icmpv6_type = buf[icmpv6_layer.start];
116                        let icmpv6_code = buf[icmpv6_layer.start + 1];
117                        icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
118                    }
119                }
120            },
121            ProtocolState::ZWave(_) => {},
122            ProtocolState::Other => {},
123        }
124
125        // Update conversation status from protocol state
126        conv.update_status();
127
128        // Track memory for TCP reassembly buffers
129        if self.memory_tracker.has_budget() {
130            if matches!(conv.protocol_state, ProtocolState::Tcp(_)) {
131                // Track bytes that TCP payload may have added
132                let tcp_payload_len = packet.tcp().map_or(0, |tcp| {
133                    let data_offset = tcp.data_offset(buf).unwrap_or(5) as usize * 4;
134                    let payload_start = tcp.index.start + data_offset;
135                    buf.len().saturating_sub(payload_start)
136                });
137                if tcp_payload_len > 0 {
138                    self.memory_tracker.add(tcp_payload_len);
139                }
140            }
141        }
142
143        // Drop the entry lock before spilling (which needs iter_mut)
144        drop(entry);
145
146        // Spill if over budget
147        if self.memory_tracker.is_over_budget() {
148            self.maybe_spill();
149        }
150
151        Ok(())
152    }
153
154    /// Spill the largest reassembly buffers to disk until under budget.
155    fn maybe_spill(&self) {
156        for mut entry in self.conversations.iter_mut() {
157            if !self.memory_tracker.is_over_budget() {
158                break;
159            }
160            if let ProtocolState::Tcp(ref mut tcp_state) = entry.value_mut().protocol_state {
161                let freed_fwd = tcp_state
162                    .reassembler_fwd
163                    .spill(self.config.spill_dir.as_deref())
164                    .unwrap_or(0);
165                let freed_rev = tcp_state
166                    .reassembler_rev
167                    .spill(self.config.spill_dir.as_deref())
168                    .unwrap_or(0);
169                let total_freed = freed_fwd + freed_rev;
170                if total_freed > 0 {
171                    self.memory_tracker.subtract(total_freed);
172                }
173            }
174        }
175    }
176
177    /// Get a read reference to a specific conversation.
178    #[must_use]
179    pub fn get_conversation(
180        &self,
181        key: &CanonicalKey,
182    ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
183        self.conversations.get(key)
184    }
185
186    /// Evict conversations that have exceeded their idle timeout.
187    ///
188    /// Returns the number of evicted conversations.
189    #[must_use]
190    pub fn evict_idle(&self, now: Duration) -> usize {
191        let mut evicted = 0;
192        self.conversations.retain(|_, conv| {
193            if conv.is_timed_out(now, &self.config) {
194                evicted += 1;
195                false
196            } else {
197                true
198            }
199        });
200        evicted
201    }
202
203    /// Consume the table and return all conversations sorted by start time.
204    #[must_use]
205    pub fn into_conversations(self) -> Vec<ConversationState> {
206        let mut conversations: Vec<ConversationState> =
207            self.conversations.into_iter().map(|(_, v)| v).collect();
208        conversations.sort_by_key(|c| c.start_time);
209        conversations
210    }
211
212    /// Get a reference to the configuration.
213    #[must_use]
214    pub fn config(&self) -> &FlowConfig {
215        &self.config
216    }
217}
218
219#[cfg(test)]
220mod tests {
221    use super::*;
222    use crate::layer::stack::{LayerStack, LayerStackEntry};
223    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
224    use std::net::Ipv4Addr;
225
226    fn make_tcp_packet(
227        src_ip: Ipv4Addr,
228        dst_ip: Ipv4Addr,
229        sport: u16,
230        dport: u16,
231        flags: &str,
232    ) -> Packet {
233        let mut builder = TcpBuilder::new()
234            .src_port(sport)
235            .dst_port(dport)
236            .seq(1000)
237            .ack_num(0)
238            .window(65535);
239
240        for c in flags.chars() {
241            builder = match c {
242                'S' => builder.syn(),
243                'A' => builder.ack(),
244                'F' => builder.fin(),
245                'R' => builder.rst(),
246                _ => builder,
247            };
248        }
249
250        LayerStack::new()
251            .push(LayerStackEntry::Ethernet(
252                EthernetBuilder::new()
253                    .dst(MacAddress::BROADCAST)
254                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
255            ))
256            .push(LayerStackEntry::Ipv4(
257                Ipv4Builder::new().src(src_ip).dst(dst_ip),
258            ))
259            .push(LayerStackEntry::Tcp(builder))
260            .build_packet()
261    }
262
263    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
264        LayerStack::new()
265            .push(LayerStackEntry::Ethernet(
266                EthernetBuilder::new()
267                    .dst(MacAddress::BROADCAST)
268                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
269            ))
270            .push(LayerStackEntry::Ipv4(
271                Ipv4Builder::new().src(src_ip).dst(dst_ip),
272            ))
273            .push(LayerStackEntry::Udp(
274                UdpBuilder::new().src_port(sport).dst_port(dport),
275            ))
276            .build_packet()
277    }
278
279    #[test]
280    fn test_ingest_creates_conversation() {
281        let table = ConversationTable::with_default_config();
282        let pkt = make_tcp_packet(
283            Ipv4Addr::new(10, 0, 0, 1),
284            Ipv4Addr::new(10, 0, 0, 2),
285            12345,
286            80,
287            "S",
288        );
289
290        table
291            .ingest_packet(&pkt, Duration::from_secs(1), 0)
292            .unwrap();
293        assert_eq!(table.conversation_count(), 1);
294    }
295
296    #[test]
297    fn test_bidirectional_same_conversation() {
298        let table = ConversationTable::with_default_config();
299
300        // Forward packet
301        let pkt_fwd = make_tcp_packet(
302            Ipv4Addr::new(10, 0, 0, 1),
303            Ipv4Addr::new(10, 0, 0, 2),
304            12345,
305            80,
306            "S",
307        );
308        table
309            .ingest_packet(&pkt_fwd, Duration::from_secs(1), 0)
310            .unwrap();
311
312        // Reverse packet
313        let pkt_rev = make_tcp_packet(
314            Ipv4Addr::new(10, 0, 0, 2),
315            Ipv4Addr::new(10, 0, 0, 1),
316            80,
317            12345,
318            "SA",
319        );
320        table
321            .ingest_packet(&pkt_rev, Duration::from_secs(2), 1)
322            .unwrap();
323
324        // Should be one conversation, not two
325        assert_eq!(table.conversation_count(), 1);
326
327        let conversations = table.into_conversations();
328        assert_eq!(conversations[0].total_packets(), 2);
329        assert_eq!(conversations[0].forward.packets, 1);
330        assert_eq!(conversations[0].reverse.packets, 1);
331    }
332
333    #[test]
334    fn test_different_flows_different_conversations() {
335        let table = ConversationTable::with_default_config();
336
337        let pkt1 = make_tcp_packet(
338            Ipv4Addr::new(10, 0, 0, 1),
339            Ipv4Addr::new(10, 0, 0, 2),
340            12345,
341            80,
342            "S",
343        );
344        let pkt2 = make_tcp_packet(
345            Ipv4Addr::new(10, 0, 0, 1),
346            Ipv4Addr::new(10, 0, 0, 3),
347            12345,
348            443,
349            "S",
350        );
351
352        table
353            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
354            .unwrap();
355        table
356            .ingest_packet(&pkt2, Duration::from_secs(2), 1)
357            .unwrap();
358
359        assert_eq!(table.conversation_count(), 2);
360    }
361
362    #[test]
363    fn test_udp_conversation() {
364        let table = ConversationTable::with_default_config();
365
366        let pkt = make_udp_packet(
367            Ipv4Addr::new(10, 0, 0, 1),
368            Ipv4Addr::new(10, 0, 0, 2),
369            12345,
370            53,
371        );
372        table
373            .ingest_packet(&pkt, Duration::from_secs(1), 0)
374            .unwrap();
375
376        let conversations = table.into_conversations();
377        assert_eq!(conversations.len(), 1);
378        assert!(matches!(
379            conversations[0].protocol_state,
380            ProtocolState::Udp(_)
381        ));
382    }
383
384    #[test]
385    fn test_evict_idle() {
386        let mut config = FlowConfig::default();
387        config.udp_timeout = Duration::from_secs(10);
388        let table = ConversationTable::new(config);
389
390        let pkt = make_udp_packet(
391            Ipv4Addr::new(10, 0, 0, 1),
392            Ipv4Addr::new(10, 0, 0, 2),
393            12345,
394            53,
395        );
396        table
397            .ingest_packet(&pkt, Duration::from_secs(1), 0)
398            .unwrap();
399        assert_eq!(table.conversation_count(), 1);
400
401        // Not yet timed out
402        let evicted = table.evict_idle(Duration::from_secs(5));
403        assert_eq!(evicted, 0);
404        assert_eq!(table.conversation_count(), 1);
405
406        // Now timed out
407        let evicted = table.evict_idle(Duration::from_secs(20));
408        assert_eq!(evicted, 1);
409        assert_eq!(table.conversation_count(), 0);
410    }
411
412    #[test]
413    fn test_into_conversations_sorted() {
414        let table = ConversationTable::with_default_config();
415
416        let pkt1 = make_tcp_packet(
417            Ipv4Addr::new(10, 0, 0, 1),
418            Ipv4Addr::new(10, 0, 0, 2),
419            12345,
420            80,
421            "S",
422        );
423        let pkt2 = make_tcp_packet(
424            Ipv4Addr::new(10, 0, 0, 1),
425            Ipv4Addr::new(10, 0, 0, 3),
426            12345,
427            443,
428            "S",
429        );
430
431        // Insert second flow first (later timestamp)
432        table
433            .ingest_packet(&pkt2, Duration::from_secs(5), 1)
434            .unwrap();
435        table
436            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
437            .unwrap();
438
439        let conversations = table.into_conversations();
440        assert!(conversations[0].start_time <= conversations[1].start_time);
441    }
442}