Skip to main content

stackforge_core/flow/
table.rs

1use std::time::Duration;
2
3use dashmap::DashMap;
4
5use crate::Packet;
6
7use super::config::FlowConfig;
8use super::error::FlowError;
9use super::key::{CanonicalKey, extract_key};
10use super::state::{ConversationState, ProtocolState};
11
/// Thread-safe conversation tracking table backed by `DashMap`.
///
/// Supports concurrent packet ingestion from multiple threads while
/// maintaining per-conversation state including TCP state machines
/// and stream reassembly.
pub struct ConversationTable {
    /// Per-conversation state, keyed by the canonical flow key that
    /// `extract_key` derives from each packet.
    conversations: DashMap<CanonicalKey, ConversationState>,
    /// Settings passed along to packet recording, TCP processing and
    /// idle-timeout checks.
    config: FlowConfig,
}
21
22impl ConversationTable {
23    /// Create a new table with the given configuration.
24    #[must_use]
25    pub fn new(config: FlowConfig) -> Self {
26        Self {
27            conversations: DashMap::new(),
28            config,
29        }
30    }
31
32    /// Create a new table with default configuration.
33    #[must_use]
34    pub fn with_default_config() -> Self {
35        Self::new(FlowConfig::default())
36    }
37
38    /// Number of tracked conversations.
39    #[must_use]
40    pub fn conversation_count(&self) -> usize {
41        self.conversations.len()
42    }
43
44    /// Ingest a single parsed packet, updating or creating conversation state.
45    ///
46    /// `timestamp` is the packet capture timestamp (from PCAP metadata).
47    /// `packet_index` is the index of this packet in the original capture
48    /// (used for cross-referencing).
49    pub fn ingest_packet(
50        &self,
51        packet: &Packet,
52        timestamp: Duration,
53        packet_index: usize,
54    ) -> Result<(), FlowError> {
55        let (key, direction) = match extract_key(packet) {
56            Ok(result) => result,
57            Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
58                // Skip non-IP or non-TCP/UDP packets silently
59                return Ok(());
60            },
61            Err(e) => return Err(e),
62        };
63
64        let byte_count = packet.as_bytes().len() as u64;
65
66        // Use DashMap entry API for atomic get-or-insert + update
67        let mut entry = self
68            .conversations
69            .entry(key.clone())
70            .or_insert_with(|| ConversationState::new(key, timestamp));
71
72        let conv = entry.value_mut();
73
74        // Record packet stats
75        conv.record_packet(
76            direction,
77            byte_count,
78            timestamp,
79            packet_index,
80            self.config.track_max_packet_len,
81            self.config.track_max_flow_len,
82        );
83
84        // Process protocol-specific state
85        let buf = packet.as_bytes();
86        match &mut conv.protocol_state {
87            ProtocolState::Tcp(tcp_state) => {
88                if let Some(tcp) = packet.tcp() {
89                    tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
90                }
91            },
92            ProtocolState::Udp(udp_state) => {
93                udp_state.process_packet();
94            },
95            ProtocolState::Icmp(icmp_state) => {
96                // Get ICMP type and code from buffer
97                if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
98                    if buf.len() >= icmp_layer.start + 2 {
99                        let icmp_type = buf[icmp_layer.start];
100                        let icmp_code = buf[icmp_layer.start + 1];
101                        icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
102                    }
103                }
104            },
105            ProtocolState::Icmpv6(icmpv6_state) => {
106                // Get ICMPv6 type and code from buffer
107                if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
108                    if buf.len() >= icmpv6_layer.start + 2 {
109                        let icmpv6_type = buf[icmpv6_layer.start];
110                        let icmpv6_code = buf[icmpv6_layer.start + 1];
111                        icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
112                    }
113                }
114            },
115            ProtocolState::ZWave(_) => {},
116            ProtocolState::Other => {},
117        }
118
119        // Update conversation status from protocol state
120        conv.update_status();
121
122        Ok(())
123    }
124
125    /// Get a read reference to a specific conversation.
126    #[must_use]
127    pub fn get_conversation(
128        &self,
129        key: &CanonicalKey,
130    ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
131        self.conversations.get(key)
132    }
133
134    /// Evict conversations that have exceeded their idle timeout.
135    ///
136    /// Returns the number of evicted conversations.
137    #[must_use]
138    pub fn evict_idle(&self, now: Duration) -> usize {
139        let mut evicted = 0;
140        self.conversations.retain(|_, conv| {
141            if conv.is_timed_out(now, &self.config) {
142                evicted += 1;
143                false
144            } else {
145                true
146            }
147        });
148        evicted
149    }
150
151    /// Consume the table and return all conversations sorted by start time.
152    #[must_use]
153    pub fn into_conversations(self) -> Vec<ConversationState> {
154        let mut conversations: Vec<ConversationState> =
155            self.conversations.into_iter().map(|(_, v)| v).collect();
156        conversations.sort_by_key(|c| c.start_time);
157        conversations
158    }
159
160    /// Get a reference to the configuration.
161    #[must_use]
162    pub fn config(&self) -> &FlowConfig {
163        &self.config
164    }
165}
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170    use crate::layer::stack::{LayerStack, LayerStackEntry};
171    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
172    use std::net::Ipv4Addr;
173
174    fn make_tcp_packet(
175        src_ip: Ipv4Addr,
176        dst_ip: Ipv4Addr,
177        sport: u16,
178        dport: u16,
179        flags: &str,
180    ) -> Packet {
181        let mut builder = TcpBuilder::new()
182            .src_port(sport)
183            .dst_port(dport)
184            .seq(1000)
185            .ack_num(0)
186            .window(65535);
187
188        for c in flags.chars() {
189            builder = match c {
190                'S' => builder.syn(),
191                'A' => builder.ack(),
192                'F' => builder.fin(),
193                'R' => builder.rst(),
194                _ => builder,
195            };
196        }
197
198        LayerStack::new()
199            .push(LayerStackEntry::Ethernet(
200                EthernetBuilder::new()
201                    .dst(MacAddress::BROADCAST)
202                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
203            ))
204            .push(LayerStackEntry::Ipv4(
205                Ipv4Builder::new().src(src_ip).dst(dst_ip),
206            ))
207            .push(LayerStackEntry::Tcp(builder))
208            .build_packet()
209    }
210
211    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
212        LayerStack::new()
213            .push(LayerStackEntry::Ethernet(
214                EthernetBuilder::new()
215                    .dst(MacAddress::BROADCAST)
216                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
217            ))
218            .push(LayerStackEntry::Ipv4(
219                Ipv4Builder::new().src(src_ip).dst(dst_ip),
220            ))
221            .push(LayerStackEntry::Udp(
222                UdpBuilder::new().src_port(sport).dst_port(dport),
223            ))
224            .build_packet()
225    }
226
227    #[test]
228    fn test_ingest_creates_conversation() {
229        let table = ConversationTable::with_default_config();
230        let pkt = make_tcp_packet(
231            Ipv4Addr::new(10, 0, 0, 1),
232            Ipv4Addr::new(10, 0, 0, 2),
233            12345,
234            80,
235            "S",
236        );
237
238        table
239            .ingest_packet(&pkt, Duration::from_secs(1), 0)
240            .unwrap();
241        assert_eq!(table.conversation_count(), 1);
242    }
243
244    #[test]
245    fn test_bidirectional_same_conversation() {
246        let table = ConversationTable::with_default_config();
247
248        // Forward packet
249        let pkt_fwd = make_tcp_packet(
250            Ipv4Addr::new(10, 0, 0, 1),
251            Ipv4Addr::new(10, 0, 0, 2),
252            12345,
253            80,
254            "S",
255        );
256        table
257            .ingest_packet(&pkt_fwd, Duration::from_secs(1), 0)
258            .unwrap();
259
260        // Reverse packet
261        let pkt_rev = make_tcp_packet(
262            Ipv4Addr::new(10, 0, 0, 2),
263            Ipv4Addr::new(10, 0, 0, 1),
264            80,
265            12345,
266            "SA",
267        );
268        table
269            .ingest_packet(&pkt_rev, Duration::from_secs(2), 1)
270            .unwrap();
271
272        // Should be one conversation, not two
273        assert_eq!(table.conversation_count(), 1);
274
275        let conversations = table.into_conversations();
276        assert_eq!(conversations[0].total_packets(), 2);
277        assert_eq!(conversations[0].forward.packets, 1);
278        assert_eq!(conversations[0].reverse.packets, 1);
279    }
280
281    #[test]
282    fn test_different_flows_different_conversations() {
283        let table = ConversationTable::with_default_config();
284
285        let pkt1 = make_tcp_packet(
286            Ipv4Addr::new(10, 0, 0, 1),
287            Ipv4Addr::new(10, 0, 0, 2),
288            12345,
289            80,
290            "S",
291        );
292        let pkt2 = make_tcp_packet(
293            Ipv4Addr::new(10, 0, 0, 1),
294            Ipv4Addr::new(10, 0, 0, 3),
295            12345,
296            443,
297            "S",
298        );
299
300        table
301            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
302            .unwrap();
303        table
304            .ingest_packet(&pkt2, Duration::from_secs(2), 1)
305            .unwrap();
306
307        assert_eq!(table.conversation_count(), 2);
308    }
309
310    #[test]
311    fn test_udp_conversation() {
312        let table = ConversationTable::with_default_config();
313
314        let pkt = make_udp_packet(
315            Ipv4Addr::new(10, 0, 0, 1),
316            Ipv4Addr::new(10, 0, 0, 2),
317            12345,
318            53,
319        );
320        table
321            .ingest_packet(&pkt, Duration::from_secs(1), 0)
322            .unwrap();
323
324        let conversations = table.into_conversations();
325        assert_eq!(conversations.len(), 1);
326        assert!(matches!(
327            conversations[0].protocol_state,
328            ProtocolState::Udp(_)
329        ));
330    }
331
332    #[test]
333    fn test_evict_idle() {
334        let mut config = FlowConfig::default();
335        config.udp_timeout = Duration::from_secs(10);
336        let table = ConversationTable::new(config);
337
338        let pkt = make_udp_packet(
339            Ipv4Addr::new(10, 0, 0, 1),
340            Ipv4Addr::new(10, 0, 0, 2),
341            12345,
342            53,
343        );
344        table
345            .ingest_packet(&pkt, Duration::from_secs(1), 0)
346            .unwrap();
347        assert_eq!(table.conversation_count(), 1);
348
349        // Not yet timed out
350        let evicted = table.evict_idle(Duration::from_secs(5));
351        assert_eq!(evicted, 0);
352        assert_eq!(table.conversation_count(), 1);
353
354        // Now timed out
355        let evicted = table.evict_idle(Duration::from_secs(20));
356        assert_eq!(evicted, 1);
357        assert_eq!(table.conversation_count(), 0);
358    }
359
360    #[test]
361    fn test_into_conversations_sorted() {
362        let table = ConversationTable::with_default_config();
363
364        let pkt1 = make_tcp_packet(
365            Ipv4Addr::new(10, 0, 0, 1),
366            Ipv4Addr::new(10, 0, 0, 2),
367            12345,
368            80,
369            "S",
370        );
371        let pkt2 = make_tcp_packet(
372            Ipv4Addr::new(10, 0, 0, 1),
373            Ipv4Addr::new(10, 0, 0, 3),
374            12345,
375            443,
376            "S",
377        );
378
379        // Insert second flow first (later timestamp)
380        table
381            .ingest_packet(&pkt2, Duration::from_secs(5), 1)
382            .unwrap();
383        table
384            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
385            .unwrap();
386
387        let conversations = table.into_conversations();
388        assert!(conversations[0].start_time <= conversations[1].start_time);
389    }
390}