1use std::time::Duration;
2
3use dashmap::DashMap;
4
5use crate::Packet;
6
7use super::config::FlowConfig;
8use super::error::FlowError;
9use super::key::{CanonicalKey, extract_key};
10use super::state::{ConversationState, ProtocolState};
11
12pub struct ConversationTable {
18 conversations: DashMap<CanonicalKey, ConversationState>,
19 config: FlowConfig,
20}
21
22impl ConversationTable {
23 pub fn new(config: FlowConfig) -> Self {
25 Self {
26 conversations: DashMap::new(),
27 config,
28 }
29 }
30
31 pub fn with_default_config() -> Self {
33 Self::new(FlowConfig::default())
34 }
35
36 pub fn conversation_count(&self) -> usize {
38 self.conversations.len()
39 }
40
41 pub fn ingest_packet(
47 &self,
48 packet: &Packet,
49 timestamp: Duration,
50 packet_index: usize,
51 ) -> Result<(), FlowError> {
52 let (key, direction) = match extract_key(packet) {
53 Ok(result) => result,
54 Err(FlowError::NoIpLayer | FlowError::NoTransportLayer) => {
55 return Ok(());
57 },
58 Err(e) => return Err(e),
59 };
60
61 let byte_count = packet.as_bytes().len() as u64;
62
63 let mut entry = self
65 .conversations
66 .entry(key.clone())
67 .or_insert_with(|| ConversationState::new(key, timestamp));
68
69 let conv = entry.value_mut();
70
71 conv.record_packet(direction, byte_count, timestamp, packet_index);
73
74 let buf = packet.as_bytes();
76 match &mut conv.protocol_state {
77 ProtocolState::Tcp(tcp_state) => {
78 if let Some(tcp) = packet.tcp() {
79 tcp_state.process_packet(direction, &tcp, buf, &self.config)?;
80 }
81 },
82 ProtocolState::Udp(udp_state) => {
83 udp_state.process_packet();
84 },
85 ProtocolState::Other => {},
86 }
87
88 conv.update_status();
90
91 Ok(())
92 }
93
94 pub fn get_conversation(
96 &self,
97 key: &CanonicalKey,
98 ) -> Option<dashmap::mapref::one::Ref<'_, CanonicalKey, ConversationState>> {
99 self.conversations.get(key)
100 }
101
102 pub fn evict_idle(&self, now: Duration) -> usize {
106 let mut evicted = 0;
107 self.conversations.retain(|_, conv| {
108 if conv.is_timed_out(now, &self.config) {
109 evicted += 1;
110 false
111 } else {
112 true
113 }
114 });
115 evicted
116 }
117
118 pub fn into_conversations(self) -> Vec<ConversationState> {
120 let mut conversations: Vec<ConversationState> =
121 self.conversations.into_iter().map(|(_, v)| v).collect();
122 conversations.sort_by_key(|c| c.start_time);
123 conversations
124 }
125
126 pub fn config(&self) -> &FlowConfig {
128 &self.config
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, MacAddress, TcpBuilder, UdpBuilder};
    use std::net::Ipv4Addr;

    /// Builds an Ethernet/IPv4/TCP packet with fixed seq (1000), ack (0) and
    /// window (65535) values.
    ///
    /// `flags` lists TCP flags by letter: `S` = SYN, `A` = ACK, `F` = FIN,
    /// `R` = RST. Any other character is ignored.
    fn make_tcp_packet(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        sport: u16,
        dport: u16,
        flags: &str,
    ) -> Packet {
        let mut builder = TcpBuilder::new()
            .src_port(sport)
            .dst_port(dport)
            .seq(1000)
            .ack_num(0)
            .window(65535);

        for c in flags.chars() {
            builder = match c {
                'S' => builder.syn(),
                'A' => builder.ack(),
                'F' => builder.fin(),
                'R' => builder.rst(),
                _ => builder,
            };
        }

        LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new().src(src_ip).dst(dst_ip),
            ))
            .push(LayerStackEntry::Tcp(builder))
            .build_packet()
    }

    /// Builds an Ethernet/IPv4/UDP packet between the given endpoints.
    fn make_udp_packet(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, sport: u16, dport: u16) -> Packet {
        LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0, 1, 2, 3, 4, 5])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new().src(src_ip).dst(dst_ip),
            ))
            .push(LayerStackEntry::Udp(
                UdpBuilder::new().src_port(sport).dst_port(dport),
            ))
            .build_packet()
    }

    #[test]
    fn test_empty_table() {
        let table = ConversationTable::with_default_config();
        assert_eq!(table.conversation_count(), 0);
        assert_eq!(table.evict_idle(Duration::from_secs(100)), 0);
        assert!(table.into_conversations().is_empty());
    }

    #[test]
    fn test_ingest_creates_conversation() {
        let table = ConversationTable::with_default_config();
        let pkt = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            80,
            "S",
        );

        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();
        assert_eq!(table.conversation_count(), 1);
    }

    #[test]
    fn test_get_conversation_by_key() {
        let table = ConversationTable::with_default_config();
        let pkt = make_udp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            53,
        );
        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();

        let key = match extract_key(&pkt) {
            Ok((key, _)) => key,
            Err(_) => panic!("UDP packet should yield a flow key"),
        };
        let conv = table
            .get_conversation(&key)
            .expect("ingested conversation should be retrievable by its key");
        assert_eq!(conv.total_packets(), 1);
    }

    #[test]
    fn test_bidirectional_same_conversation() {
        let table = ConversationTable::with_default_config();

        let pkt_fwd = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            80,
            "S",
        );
        table
            .ingest_packet(&pkt_fwd, Duration::from_secs(1), 0)
            .unwrap();

        let pkt_rev = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 2),
            Ipv4Addr::new(10, 0, 0, 1),
            80,
            12345,
            "SA",
        );
        table
            .ingest_packet(&pkt_rev, Duration::from_secs(2), 1)
            .unwrap();

        assert_eq!(table.conversation_count(), 1);

        let conversations = table.into_conversations();
        assert_eq!(conversations[0].total_packets(), 2);
        assert_eq!(conversations[0].forward.packets, 1);
        assert_eq!(conversations[0].reverse.packets, 1);
    }

    #[test]
    fn test_different_flows_different_conversations() {
        let table = ConversationTable::with_default_config();

        let pkt1 = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            80,
            "S",
        );
        let pkt2 = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 3),
            12345,
            443,
            "S",
        );

        table
            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
            .unwrap();
        table
            .ingest_packet(&pkt2, Duration::from_secs(2), 1)
            .unwrap();

        assert_eq!(table.conversation_count(), 2);
    }

    #[test]
    fn test_udp_conversation() {
        let table = ConversationTable::with_default_config();

        let pkt = make_udp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            53,
        );
        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 1);
        assert!(matches!(
            conversations[0].protocol_state,
            ProtocolState::Udp(_)
        ));
    }

    #[test]
    fn test_evict_idle() {
        let mut config = FlowConfig::default();
        config.udp_timeout = Duration::from_secs(10);
        let table = ConversationTable::new(config);

        let pkt = make_udp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            53,
        );
        table
            .ingest_packet(&pkt, Duration::from_secs(1), 0)
            .unwrap();
        assert_eq!(table.conversation_count(), 1);

        // 4s idle is well under the 10s UDP timeout.
        let evicted = table.evict_idle(Duration::from_secs(5));
        assert_eq!(evicted, 0);
        assert_eq!(table.conversation_count(), 1);

        // 19s idle is well over the 10s UDP timeout.
        let evicted = table.evict_idle(Duration::from_secs(20));
        assert_eq!(evicted, 1);
        assert_eq!(table.conversation_count(), 0);
    }

    #[test]
    fn test_into_conversations_sorted() {
        let table = ConversationTable::with_default_config();

        let pkt1 = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 2),
            12345,
            80,
            "S",
        );
        let pkt2 = make_tcp_packet(
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 3),
            12345,
            443,
            "S",
        );

        // Ingest the later-timestamped flow first, so ordering cannot come
        // from insertion order.
        table
            .ingest_packet(&pkt2, Duration::from_secs(5), 1)
            .unwrap();
        table
            .ingest_packet(&pkt1, Duration::from_secs(1), 0)
            .unwrap();

        let conversations = table.into_conversations();
        assert_eq!(conversations.len(), 2);
        // The start times differ (1s vs 5s), so the order must be strict; a
        // non-strict check would also pass if start_time were recorded wrongly.
        assert!(conversations[0].start_time < conversations[1].start_time);
    }
}