1use std::time::Duration;
2
3use dashmap::DashMap;
4
5use crate::Packet;
6
7use super::config::FlowConfig;
8use super::error::FlowError;
9use super::key::{CanonicalKey, extract_key};
10use super::state::{ConversationState, ProtocolState};
11
12pub struct ConversationTable {
18 conversations: DashMap<CanonicalKey, ConversationState>,
19 config: FlowConfig,
20}
21
22impl ConversationTable {
23 #[must_use]
25 pub fn new(config: FlowConfig) -> Self {
26 Self {
27 conversations: DashMap::new(),
28 config,
29 }
30 }
31
32 #[must_use]
34 pub fn with_default_config() -> Self {
35 Self::new(FlowConfig::default())
36 }
37
38 #[must_use]
40 pub fn conversation_count(&self) -> usize {
41 self.conversations.len()
42 }
43
44 pub fn ingest_packet(
50 &self,
51 packet: &Packet,
52 timestamp: Duration,
53 packet_index: usize,
54 ) -> Result<(), FlowError> {
55 let (key, direction) = match extract_key(packet) {
56 Ok(result) => result,
57 Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
58 return Ok(());
60 },
61 Err(e) => return Err(e),
62 };
63
64 let byte_count = packet.as_bytes().len() as u64;
65
66 let mut entry = self
68 .conversations
69 .entry(key.clone())
70 .or_insert_with(|| ConversationState::new(key, timestamp));
71
72 let conv = entry.value_mut();
73
74 conv.record_packet(
76 direction,
77 byte_count,
78 timestamp,
79 packet_index,
80 self.config.track_max_packet_len,
81 self.config.track_max_flow_len,
82 );
83
84 let buf = packet.as_bytes();
86 match &mut conv.protocol_state {
87 ProtocolState::Tcp(tcp_state) => {
88 if let Some(tcp) = packet.tcp() {
89 tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
90 }
91 },
92 ProtocolState::Udp(udp_state) => {
93 udp_state.process_packet();
94 },
95 ProtocolState::Icmp(icmp_state) => {
96 if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) {
98 if buf.len() >= icmp_layer.start + 2 {
99 let icmp_type = buf[icmp_layer.start];
100 let icmp_code = buf[icmp_layer.start + 1];
101 icmp_state.process_packet(packet, buf, icmp_type, icmp_code);
102 }
103 }
104 },
105 ProtocolState::Icmpv6(icmpv6_state) => {
106 if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) {
108 if buf.len() >= icmpv6_layer.start + 2 {
109 let icmpv6_type = buf[icmpv6_layer.start];
110 let icmpv6_code = buf[icmpv6_layer.start + 1];
111 icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code);
112 }
113 }
114 },
115 ProtocolState::ZWave(_) => {},
116 ProtocolState::Other => {},
117 }
118
119 conv.update_status();
121
122 Ok(())
123 }
124
125 #[must_use]
127 pub fn get_conversation(
128 &self,
129 key: &CanonicalKey,
130 ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
131 self.conversations.get(key)
132 }
133
134 #[must_use]
138 pub fn evict_idle(&self, now: Duration) -> usize {
139 let mut evicted = 0;
140 self.conversations.retain(|_, conv| {
141 if conv.is_timed_out(now, &self.config) {
142 evicted += 1;
143 false
144 } else {
145 true
146 }
147 });
148 evicted
149 }
150
151 #[must_use]
153 pub fn into_conversations(self) -> Vec<ConversationState> {
154 let mut conversations: Vec<ConversationState> =
155 self.conversations.into_iter().map(|(_, v)| v).collect();
156 conversations.sort_by_key(|c| c.start_time);
157 conversations
158 }
159
160 #[must_use]
162 pub fn config(&self) -> &FlowConfig {
163 &self.config
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Address `10.0.0.<last>` on the test subnet.
    fn host(last: u8) -> Ipv4Addr {
        Ipv4Addr::new(10, 0, 0, last)
    }

    /// Feeds `packet` into `table` at `secs` seconds as packet number `index`,
    /// failing the test if ingestion returns an error.
    fn ingest(table: &ConversationTable, packet: &Packet, secs: u64, index: usize) {
        table
            .ingest_packet(packet, Duration::from_secs(secs), index)
            .unwrap();
    }

    /// Builds an Ethernet/IPv4/TCP frame. Each character of `flags` sets one TCP
    /// flag: `S` (SYN), `A` (ACK), `F` (FIN), `R` (RST). Other characters are
    /// ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let unflagged = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);
        let tcp = flags.chars().fold(unflagged, |tcp, flag| match flag {
            'S' => tcp.syn(),
            'A' => tcp.ack(),
            'F' => tcp.fin(),
            'R' => tcp.rst(),
            _ => tcp,
        });

        let ethernet = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp))
            .build_packet()
    }

    /// Builds an Ethernet/IPv4/UDP frame between the given endpoints.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        let ethernet = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new().src(src_ip).dst(dst_ip);
        let udp = UdpBuilder::new().src_port(sport).dst_port(dport);

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Udp(udp))
            .build_packet()
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let syn = make_tcp_packet(host(1), host(2), 12345, 80, "S");

        ingest(&table, &syn, 1, 0);

        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        let syn = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        ingest(&table, &syn, 1, 0);

        let syn_ack = make_tcp_packet(host(2), host(1), 80, 12345, "SA");
        ingest(&table, &syn_ack, 2, 1);

        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        let only = &conversations[0];
        assert_eq!(only.total_packets(), 2);
        assert_eq!(only.forward.packets, 1);
        assert_eq!(only.reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();

        let to_http = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        let to_https = make_tcp_packet(host(1), host(3), 12345, 443, "S");

        ingest(&table, &to_http, 1, 0);
        ingest(&table, &to_https, 2, 1);

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();

        let query = make_udp_packet(host(1), host(2), 12345, 53);
        ingest(&table, &query, 1, 0);

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let query = make_udp_packet(host(1), host(2), 12345, 53);
        ingest(&table, &query, 1, 0);
        assert_eq!(table.conversation_count(), 1);

        // 4s after the only packet: expected to survive a 10s UDP timeout.
        let removed_early = table.evict_idle(Duration::from_secs(5));
        assert_eq!(removed_early, 0);
        assert_eq!(table.conversation_count(), 1);

        // 19s after the only packet: expected to be evicted.
        let removed_late = table.evict_idle(Duration::from_secs(20));
        assert_eq!(removed_late, 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();

        let early = make_tcp_packet(host(1), host(2), 12345, 80, "S");
        let late = make_tcp_packet(host(1), host(3), 12345, 443, "S");

        // Ingest the later flow first so insertion order disagrees with start time.
        ingest(&table, &late, 5, 1);
        ingest(&table, &early, 1, 0);

        let conversations = table.into_conversations();
        assert!(conversations[0].start_time <= conversations[1].start_time);
    }
}