1use std::time::Duration;
2
3use super::config::FlowConfig;
4use super::icmp_state::IcmpFlowState;
5use super::key::{CanonicalKey, FlowDirection, TransportProtocol};
6use super::tcp_state::TcpConversationState;
7use super::udp_state::UdpFlowState;
8
/// Lifecycle stage of a tracked conversation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConversationStatus {
    /// The conversation is open; this is the status assigned at construction.
    Active,
    /// One side has begun closing while the conversation is not yet closed.
    HalfClosed,
    /// The conversation has been closed.
    Closed,
    /// The conversation was given up on after a period of inactivity.
    TimedOut,
}

impl ConversationStatus {
    /// Returns the variant's identifier as a static string
    /// (for example `"HalfClosed"`).
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            ConversationStatus::Active => "Active",
            ConversationStatus::HalfClosed => "HalfClosed",
            ConversationStatus::Closed => "Closed",
            ConversationStatus::TimedOut => "TimedOut",
        }
    }
}

impl std::fmt::Display for ConversationStatus {
    /// Writes the same text as [`ConversationStatus::name`]; width and
    /// alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.name();
        f.write_str(label)
    }
}
39
/// Packet and byte counters for one direction of a conversation.
#[derive(Debug, Clone)]
pub struct DirectionStats {
    /// Number of packets recorded in this direction (saturates at `u64::MAX`).
    pub packets: u64,
    /// Sum of the byte counts recorded in this direction (saturates at `u64::MAX`).
    pub bytes: u64,
    /// Timestamp of the first packet recorded in this direction. Until a
    /// packet is recorded it holds the timestamp given to [`DirectionStats::new`].
    pub first_seen: Duration,
    /// Timestamp of the most recently recorded packet. Until a packet is
    /// recorded it holds the timestamp given to [`DirectionStats::new`].
    pub last_seen: Duration,
    /// Largest byte count among packets recorded with length tracking
    /// enabled; `None` while no such packet has been recorded.
    pub max_packet_len: Option<u64>,
}

impl DirectionStats {
    /// Creates empty statistics whose `first_seen` and `last_seen` are both
    /// initialised to `timestamp`.
    #[must_use]
    pub fn new(timestamp: Duration) -> Self {
        Self {
            packets: 0,
            bytes: 0,
            first_seen: timestamp,
            last_seen: timestamp,
            max_packet_len: None,
        }
    }

    /// Accounts for one packet of `byte_count` bytes observed at `timestamp`.
    ///
    /// The first packet recorded moves `first_seen` to its own timestamp, so
    /// a direction that only becomes active later does not report the
    /// construction timestamp. When `track_max_len` is `true`,
    /// `max_packet_len` is raised to `byte_count` if that is larger than the
    /// current maximum.
    pub fn record_packet(&mut self, byte_count: u64, timestamp: Duration, track_max_len: bool) {
        if self.packets == 0 {
            self.first_seen = timestamp;
        }

        // Saturate instead of overflowing: plain `+=` panics in debug builds
        // and silently wraps in release builds.
        self.packets = self.packets.saturating_add(1);
        self.bytes = self.bytes.saturating_add(byte_count);
        self.last_seen = timestamp;

        if track_max_len {
            let largest = self
                .max_packet_len
                .map_or(byte_count, |seen| seen.max(byte_count));
            self.max_packet_len = Some(largest);
        }
    }
}
77
78#[derive(Debug)]
80pub enum ProtocolState {
81 Tcp(TcpConversationState),
83 Udp(UdpFlowState),
85 Icmp(IcmpFlowState),
87 Icmpv6(IcmpFlowState),
89 ZWave(ZWaveFlowState),
91 Other,
93}
94
/// Counters kept for a Z-Wave flow.
#[derive(Clone, Debug)]
pub struct ZWaveFlowState {
    /// Home ID of the Z-Wave network the flow belongs to.
    pub home_id: u32,
    /// Command counter. NOTE(review): presumably the number of command
    /// frames seen; nothing in this file updates it, so confirm against the
    /// caller.
    pub command_count: u64,
    /// Acknowledgement counter. NOTE(review): presumably the number of ACK
    /// frames seen; nothing in this file updates it, so confirm against the
    /// caller.
    pub ack_count: u64,
}
105
106#[derive(Debug)]
112pub struct ConversationState {
113 pub key: CanonicalKey,
115 pub status: ConversationStatus,
117 pub start_time: Duration,
119 pub last_seen: Duration,
121 pub forward: DirectionStats,
123 pub reverse: DirectionStats,
125 pub packet_indices: Vec<usize>,
127 pub protocol_state: ProtocolState,
129 pub max_flow_len: Option<u64>,
131}
132
133impl ConversationState {
134 #[must_use]
136 pub fn new(key: CanonicalKey, timestamp: Duration) -> Self {
137 let protocol_state = match key.protocol {
138 TransportProtocol::Tcp => ProtocolState::Tcp(TcpConversationState::new()),
139 TransportProtocol::Udp => ProtocolState::Udp(UdpFlowState::new()),
140 TransportProtocol::Icmp => ProtocolState::Icmp(IcmpFlowState::new(0, 0)),
141 TransportProtocol::Icmpv6 => ProtocolState::Icmpv6(IcmpFlowState::new(0, 0)),
142 _ => ProtocolState::Other,
143 };
144
145 Self {
146 key,
147 status: ConversationStatus::Active,
148 start_time: timestamp,
149 last_seen: timestamp,
150 forward: DirectionStats::new(timestamp),
151 reverse: DirectionStats::new(timestamp),
152 packet_indices: Vec::new(),
153 protocol_state,
154 max_flow_len: None,
155 }
156 }
157
158 #[must_use]
163 pub fn new_zwave(zwave_key: super::key::ZWaveKey, timestamp: Duration) -> Self {
164 use std::net::{IpAddr, Ipv4Addr};
165
166 let (key, _) = CanonicalKey::new(
170 IpAddr::V4(Ipv4Addr::UNSPECIFIED),
171 IpAddr::V4(Ipv4Addr::UNSPECIFIED),
172 u16::from(zwave_key.node_a),
173 u16::from(zwave_key.node_b),
174 TransportProtocol::Other(0),
175 None,
176 );
177
178 Self {
179 key,
180 status: ConversationStatus::Active,
181 start_time: timestamp,
182 last_seen: timestamp,
183 forward: DirectionStats::new(timestamp),
184 reverse: DirectionStats::new(timestamp),
185 packet_indices: Vec::new(),
186 protocol_state: ProtocolState::ZWave(ZWaveFlowState {
187 home_id: zwave_key.home_id,
188 command_count: 0,
189 ack_count: 0,
190 }),
191 max_flow_len: None,
192 }
193 }
194
195 #[must_use]
197 pub fn total_packets(&self) -> u64 {
198 self.forward.packets + self.reverse.packets
199 }
200
201 #[must_use]
203 pub fn total_bytes(&self) -> u64 {
204 self.forward.bytes + self.reverse.bytes
205 }
206
207 #[must_use]
209 pub fn duration(&self) -> Duration {
210 self.last_seen.saturating_sub(self.start_time)
211 }
212
213 pub fn record_packet(
215 &mut self,
216 direction: FlowDirection,
217 byte_count: u64,
218 timestamp: Duration,
219 packet_index: usize,
220 track_max_packet_len: bool,
221 track_max_flow_len: bool,
222 store_packet_indices: bool,
223 ) {
224 self.last_seen = timestamp;
225 if store_packet_indices {
226 self.packet_indices.push(packet_index);
227 }
228
229 match direction {
230 FlowDirection::Forward => {
231 self.forward
232 .record_packet(byte_count, timestamp, track_max_packet_len);
233 },
234 FlowDirection::Reverse => {
235 self.reverse
236 .record_packet(byte_count, timestamp, track_max_packet_len);
237 },
238 }
239
240 if track_max_flow_len {
242 self.max_flow_len = Some(self.max_flow_len.unwrap_or(0).max(byte_count));
243 }
244 }
245
246 pub fn update_status(&mut self) {
248 match &self.protocol_state {
249 ProtocolState::Tcp(tcp) => {
250 if tcp.conn_state.is_closed() {
251 self.status = ConversationStatus::Closed;
252 } else if matches!(
253 tcp.conn_state,
254 super::tcp_state::TcpConnectionState::FinWait1
255 | super::tcp_state::TcpConnectionState::FinWait2
256 | super::tcp_state::TcpConnectionState::CloseWait
257 | super::tcp_state::TcpConnectionState::Closing
258 | super::tcp_state::TcpConnectionState::LastAck
259 ) {
260 self.status = ConversationStatus::HalfClosed;
261 }
262 },
263 ProtocolState::Udp(udp) => {
264 self.status = udp.status;
265 },
266 ProtocolState::Icmp(icmp) => {
267 self.status = icmp.status;
268 },
269 ProtocolState::Icmpv6(icmpv6) => {
270 self.status = icmpv6.status;
271 },
272 ProtocolState::ZWave(_) => {},
273 ProtocolState::Other => {},
274 }
275 }
276
277 #[must_use]
279 pub fn is_timed_out(&self, now: Duration, config: &FlowConfig) -> bool {
280 let elapsed = now.saturating_sub(self.last_seen);
281 match &self.protocol_state {
282 ProtocolState::Tcp(tcp) => {
283 if tcp.conn_state.is_closed() {
284 false } else if tcp.conn_state.is_half_open() {
286 elapsed > config.tcp_half_open_timeout
287 } else {
288 elapsed > config.tcp_established_timeout
289 }
290 },
291 ProtocolState::Udp(_) => elapsed > config.udp_timeout,
292 ProtocolState::Icmp(_) | ProtocolState::Icmpv6(_) => elapsed > config.udp_timeout,
293 ProtocolState::ZWave(_) => elapsed > config.udp_timeout,
294 ProtocolState::Other => elapsed > config.udp_timeout,
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// TCP key between 10.0.0.1:12345 and 10.0.0.2:80.
    fn test_key() -> CanonicalKey {
        let client = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let server = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        let (key, _) = CanonicalKey::new(client, server, 12345, 80, TransportProtocol::Tcp, None);
        key
    }

    /// Records one packet with both length trackers off and index storage on.
    fn record(
        state: &mut ConversationState,
        direction: FlowDirection,
        bytes: u64,
        secs: u64,
        index: usize,
    ) {
        state.record_packet(
            direction,
            bytes,
            Duration::from_secs(secs),
            index,
            false,
            false,
            true,
        );
    }

    #[test]
    fn test_conversation_state_new() {
        let state = ConversationState::new(test_key(), Duration::from_secs(1));

        assert_eq!(state.status, ConversationStatus::Active);
        assert_eq!(state.total_packets(), 0);
        assert_eq!(state.total_bytes(), 0);
        assert!(matches!(state.protocol_state, ProtocolState::Tcp(_)));
    }

    #[test]
    fn test_record_packet() {
        let mut state = ConversationState::new(test_key(), Duration::from_secs(1));

        record(&mut state, FlowDirection::Forward, 100, 1, 0);
        record(&mut state, FlowDirection::Reverse, 200, 2, 1);
        record(&mut state, FlowDirection::Forward, 50, 3, 2);

        assert_eq!(state.total_packets(), 3);
        assert_eq!(state.total_bytes(), 350);
        assert_eq!(state.forward.packets, 2);
        assert_eq!(state.reverse.packets, 1);
        assert_eq!(state.packet_indices, vec![0, 1, 2]);
        assert_eq!(state.duration(), Duration::from_secs(2));
    }

    #[test]
    fn test_udp_conversation() {
        let client = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let resolver = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        let (key, _) = CanonicalKey::new(client, resolver, 12345, 53, TransportProtocol::Udp, None);

        let state = ConversationState::new(key, Duration::from_secs(0));
        assert!(matches!(state.protocol_state, ProtocolState::Udp(_)));
    }

    #[test]
    fn test_timeout_check() {
        use super::super::tcp_state::TcpConnectionState;

        let mut state = ConversationState::new(test_key(), Duration::from_secs(0));
        state.last_seen = Duration::from_secs(100);
        let config = FlowConfig::default();

        // Force the established state so the established timeout applies.
        if let ProtocolState::Tcp(tcp) = &mut state.protocol_state {
            tcp.conn_state = TcpConnectionState::Established;
        }

        // 86499 - 100 = 86399 s idle: expected to be within the timeout.
        assert!(!state.is_timed_out(Duration::from_secs(86499), &config));

        // 86501 - 100 = 86401 s idle: expected to exceed the timeout.
        assert!(state.is_timed_out(Duration::from_secs(86501), &config));
    }
}