1use std::time::Duration;
2
3use super::config::FlowConfig;
4use super::icmp_state::IcmpFlowState;
5use super::key::{CanonicalKey, FlowDirection, TransportProtocol};
6use super::tcp_state::TcpConversationState;
7use super::udp_state::UdpFlowState;
8
/// Lifecycle status of a tracked conversation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConversationStatus {
    /// Traffic is flowing (the status every new conversation starts with).
    Active,
    /// One side has started closing (TCP FIN/close-handshake states).
    HalfClosed,
    /// The conversation has fully closed.
    Closed,
    /// The conversation was considered idle for too long.
    TimedOut,
}

impl ConversationStatus {
    /// Returns the variant name as a static string (also used by `Display`).
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            ConversationStatus::Active => "Active",
            ConversationStatus::HalfClosed => "HalfClosed",
            ConversationStatus::Closed => "Closed",
            ConversationStatus::TimedOut => "TimedOut",
        }
    }
}

impl std::fmt::Display for ConversationStatus {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.name();
        formatter.write_str(label)
    }
}
39
/// Packet/byte counters and timing for one direction of a conversation.
#[derive(Debug, Clone)]
pub struct DirectionStats {
    /// Number of packets recorded in this direction.
    pub packets: u64,
    /// Sum of the `byte_count` values recorded in this direction.
    pub bytes: u64,
    /// Earliest packet timestamp recorded in this direction.
    ///
    /// Before any packet is recorded this holds the timestamp given to
    /// [`DirectionStats::new`].
    pub first_seen: Duration,
    /// Latest packet timestamp recorded in this direction.
    ///
    /// Before any packet is recorded this holds the timestamp given to
    /// [`DirectionStats::new`].
    pub last_seen: Duration,
    /// Largest single `byte_count` recorded while length tracking was requested;
    /// `None` if no packet has been recorded with tracking enabled.
    pub max_packet_len: Option<u64>,
}

impl DirectionStats {
    /// Creates empty stats with both timestamps set to `timestamp`.
    #[must_use]
    pub fn new(timestamp: Duration) -> Self {
        Self {
            packets: 0,
            bytes: 0,
            first_seen: timestamp,
            last_seen: timestamp,
            max_packet_len: None,
        }
    }

    /// Records one packet of `byte_count` bytes observed at `timestamp`.
    ///
    /// `first_seen`/`last_seen` always bracket the recorded packets, even when
    /// timestamps arrive out of order. When `track_max_len` is true,
    /// `max_packet_len` is raised to `byte_count` if it is larger.
    pub fn record_packet(&mut self, byte_count: u64, timestamp: Duration, track_max_len: bool) {
        if self.packets == 0 {
            // Both directions are created together with the conversation, so the
            // placeholder timestamp from `new` can predate this direction's first
            // packet (always true for the side that speaks second). Replace it.
            self.first_seen = timestamp;
            self.last_seen = timestamp;
        } else {
            // Capture timestamps are not guaranteed to be monotonic; keep the window
            // as min/max instead of letting a late packet move `last_seen` backwards.
            self.first_seen = self.first_seen.min(timestamp);
            self.last_seen = self.last_seen.max(timestamp);
        }
        self.packets += 1;
        self.bytes += byte_count;
        if track_max_len {
            self.max_packet_len =
                Some(self.max_packet_len.map_or(byte_count, |prev| prev.max(byte_count)));
        }
    }
}
77
/// Protocol-specific tracking state attached to a conversation.
///
/// `ConversationState::new` picks the variant from the key's transport protocol;
/// `ConversationState::new_zwave` always uses `ZWave`.
#[derive(Debug)]
pub enum ProtocolState {
    /// TCP connection tracking state.
    Tcp(TcpConversationState),
    /// UDP flow state.
    Udp(UdpFlowState),
    /// ICMP flow state.
    Icmp(IcmpFlowState),
    /// ICMPv6 flow state (same state type as `Icmp`).
    Icmpv6(IcmpFlowState),
    /// Z-Wave flow state.
    ZWave(ZWaveFlowState),
    /// Any other transport; no protocol-specific state is kept.
    Other,
}
94
/// Z-Wave specific state for a conversation created by
/// `ConversationState::new_zwave`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZWaveFlowState {
    /// Z-Wave home ID copied from the conversation's `ZWaveKey`.
    pub home_id: u32,
    /// Command counter; `ConversationState::new_zwave` initializes it to 0.
    /// NOTE(review): presumably counts command frames — confirm against the updater.
    pub command_count: u64,
    /// Ack counter; `ConversationState::new_zwave` initializes it to 0.
    /// NOTE(review): presumably counts acknowledgement frames — confirm against the updater.
    pub ack_count: u64,
}
105
/// A bidirectional conversation identified by a canonical key, with
/// per-direction statistics and protocol-specific state.
#[derive(Debug)]
pub struct ConversationState {
    /// Canonical key identifying this conversation.
    pub key: CanonicalKey,
    /// Current lifecycle status; starts as `Active` and is refreshed from
    /// `protocol_state` by `update_status`.
    pub status: ConversationStatus,
    /// Timestamp the conversation was created with.
    pub start_time: Duration,
    /// Timestamp of the most recent activity; updated by `record_packet`.
    pub last_seen: Duration,
    /// Statistics for packets recorded as `FlowDirection::Forward`.
    pub forward: DirectionStats,
    /// Statistics for packets recorded as `FlowDirection::Reverse`.
    pub reverse: DirectionStats,
    /// Packet indices passed to `record_packet`, in the order they were recorded.
    pub packet_indices: Vec<usize>,
    /// Protocol-specific tracking state.
    pub protocol_state: ProtocolState,
    /// Largest single `byte_count` recorded (either direction) while
    /// `track_max_flow_len` was set; `None` if never tracked.
    pub max_flow_len: Option<u64>,
}
132
133impl ConversationState {
134 #[must_use]
136 pub fn new(key: CanonicalKey, timestamp: Duration) -> Self {
137 let protocol_state = match key.protocol {
138 TransportProtocol::Tcp => ProtocolState::Tcp(TcpConversationState::new()),
139 TransportProtocol::Udp => ProtocolState::Udp(UdpFlowState::new()),
140 TransportProtocol::Icmp => ProtocolState::Icmp(IcmpFlowState::new(0, 0)),
141 TransportProtocol::Icmpv6 => ProtocolState::Icmpv6(IcmpFlowState::new(0, 0)),
142 _ => ProtocolState::Other,
143 };
144
145 Self {
146 key,
147 status: ConversationStatus::Active,
148 start_time: timestamp,
149 last_seen: timestamp,
150 forward: DirectionStats::new(timestamp),
151 reverse: DirectionStats::new(timestamp),
152 packet_indices: Vec::new(),
153 protocol_state,
154 max_flow_len: None,
155 }
156 }
157
158 #[must_use]
163 pub fn new_zwave(zwave_key: super::key::ZWaveKey, timestamp: Duration) -> Self {
164 use std::net::{IpAddr, Ipv4Addr};
165
166 let (key, _) = CanonicalKey::new(
170 IpAddr::V4(Ipv4Addr::UNSPECIFIED),
171 IpAddr::V4(Ipv4Addr::UNSPECIFIED),
172 u16::from(zwave_key.node_a),
173 u16::from(zwave_key.node_b),
174 TransportProtocol::Other(0),
175 None,
176 );
177
178 Self {
179 key,
180 status: ConversationStatus::Active,
181 start_time: timestamp,
182 last_seen: timestamp,
183 forward: DirectionStats::new(timestamp),
184 reverse: DirectionStats::new(timestamp),
185 packet_indices: Vec::new(),
186 protocol_state: ProtocolState::ZWave(ZWaveFlowState {
187 home_id: zwave_key.home_id,
188 command_count: 0,
189 ack_count: 0,
190 }),
191 max_flow_len: None,
192 }
193 }
194
195 #[must_use]
197 pub fn total_packets(&self) -> u64 {
198 self.forward.packets + self.reverse.packets
199 }
200
201 #[must_use]
203 pub fn total_bytes(&self) -> u64 {
204 self.forward.bytes + self.reverse.bytes
205 }
206
207 #[must_use]
209 pub fn duration(&self) -> Duration {
210 self.last_seen.saturating_sub(self.start_time)
211 }
212
213 pub fn record_packet(
215 &mut self,
216 direction: FlowDirection,
217 byte_count: u64,
218 timestamp: Duration,
219 packet_index: usize,
220 track_max_packet_len: bool,
221 track_max_flow_len: bool,
222 ) {
223 self.last_seen = timestamp;
224 self.packet_indices.push(packet_index);
225
226 match direction {
227 FlowDirection::Forward => {
228 self.forward
229 .record_packet(byte_count, timestamp, track_max_packet_len);
230 },
231 FlowDirection::Reverse => {
232 self.reverse
233 .record_packet(byte_count, timestamp, track_max_packet_len);
234 },
235 }
236
237 if track_max_flow_len {
239 self.max_flow_len = Some(self.max_flow_len.unwrap_or(0).max(byte_count));
240 }
241 }
242
243 pub fn update_status(&mut self) {
245 match &self.protocol_state {
246 ProtocolState::Tcp(tcp) => {
247 if tcp.conn_state.is_closed() {
248 self.status = ConversationStatus::Closed;
249 } else if matches!(
250 tcp.conn_state,
251 super::tcp_state::TcpConnectionState::FinWait1
252 | super::tcp_state::TcpConnectionState::FinWait2
253 | super::tcp_state::TcpConnectionState::CloseWait
254 | super::tcp_state::TcpConnectionState::Closing
255 | super::tcp_state::TcpConnectionState::LastAck
256 ) {
257 self.status = ConversationStatus::HalfClosed;
258 }
259 },
260 ProtocolState::Udp(udp) => {
261 self.status = udp.status;
262 },
263 ProtocolState::Icmp(icmp) => {
264 self.status = icmp.status;
265 },
266 ProtocolState::Icmpv6(icmpv6) => {
267 self.status = icmpv6.status;
268 },
269 ProtocolState::ZWave(_) => {},
270 ProtocolState::Other => {},
271 }
272 }
273
274 #[must_use]
276 pub fn is_timed_out(&self, now: Duration, config: &FlowConfig) -> bool {
277 let elapsed = now.saturating_sub(self.last_seen);
278 match &self.protocol_state {
279 ProtocolState::Tcp(tcp) => {
280 if tcp.conn_state.is_closed() {
281 false } else if tcp.conn_state.is_half_open() {
283 elapsed > config.tcp_half_open_timeout
284 } else {
285 elapsed > config.tcp_established_timeout
286 }
287 },
288 ProtocolState::Udp(_) => elapsed > config.udp_timeout,
289 ProtocolState::Icmp(_) | ProtocolState::Icmpv6(_) => elapsed > config.udp_timeout,
290 ProtocolState::ZWave(_) => elapsed > config.udp_timeout,
291 ProtocolState::Other => elapsed > config.udp_timeout,
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// A TCP key 10.0.0.1:12345 <-> 10.0.0.2:80.
    fn test_key() -> CanonicalKey {
        let (key, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
            12345,
            80,
            TransportProtocol::Tcp,
            None,
        );
        key
    }

    #[test]
    fn test_conversation_state_new() {
        let state = ConversationState::new(test_key(), Duration::from_secs(1));
        assert_eq!(state.status, ConversationStatus::Active);
        assert_eq!(state.total_packets(), 0);
        assert_eq!(state.total_bytes(), 0);
        assert!(matches!(state.protocol_state, ProtocolState::Tcp(_)));
    }

    #[test]
    fn test_record_packet() {
        let mut state = ConversationState::new(test_key(), Duration::from_secs(1));

        // `record_packet` takes the two tracking flags as well; the original calls
        // omitted them and did not compile.
        state.record_packet(FlowDirection::Forward, 100, Duration::from_secs(1), 0, false, false);
        state.record_packet(FlowDirection::Reverse, 200, Duration::from_secs(2), 1, false, false);
        state.record_packet(FlowDirection::Forward, 50, Duration::from_secs(3), 2, false, false);

        assert_eq!(state.total_packets(), 3);
        assert_eq!(state.total_bytes(), 350);
        assert_eq!(state.forward.packets, 2);
        assert_eq!(state.reverse.packets, 1);
        assert_eq!(state.packet_indices, vec![0, 1, 2]);
        assert_eq!(state.duration(), Duration::from_secs(2));
        assert_eq!(state.forward.max_packet_len, None);
        assert_eq!(state.max_flow_len, None);
    }

    #[test]
    fn test_record_packet_tracks_max_lengths() {
        let mut state = ConversationState::new(test_key(), Duration::from_secs(1));

        state.record_packet(FlowDirection::Forward, 100, Duration::from_secs(1), 0, true, true);
        state.record_packet(FlowDirection::Reverse, 300, Duration::from_secs(2), 1, true, true);
        // Untracked packet must not affect either maximum.
        state.record_packet(FlowDirection::Forward, 900, Duration::from_secs(3), 2, false, false);

        assert_eq!(state.forward.max_packet_len, Some(100));
        assert_eq!(state.reverse.max_packet_len, Some(300));
        assert_eq!(state.max_flow_len, Some(300));
    }

    #[test]
    fn test_udp_conversation() {
        let (key, _) = CanonicalKey::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
            12345,
            53,
            TransportProtocol::Udp,
            None,
        );
        let state = ConversationState::new(key, Duration::from_secs(0));
        assert!(matches!(state.protocol_state, ProtocolState::Udp(_)));
    }

    #[test]
    fn test_status_display() {
        assert_eq!(ConversationStatus::Active.to_string(), "Active");
        assert_eq!(ConversationStatus::HalfClosed.to_string(), "HalfClosed");
        assert_eq!(ConversationStatus::Closed.to_string(), "Closed");
        assert_eq!(ConversationStatus::TimedOut.to_string(), "TimedOut");
    }

    #[test]
    fn test_timeout_check() {
        let mut state = ConversationState::new(test_key(), Duration::from_secs(0));
        state.last_seen = Duration::from_secs(100);
        let config = FlowConfig::default();

        if let ProtocolState::Tcp(ref mut tcp) = state.protocol_state {
            tcp.conn_state = super::super::tcp_state::TcpConnectionState::Established;
        }

        // Assumes the default established timeout is 86_400 s: 86_399 s idle is
        // still alive, 86_401 s idle has timed out.
        assert!(!state.is_timed_out(Duration::from_secs(86499), &config));

        assert!(state.is_timed_out(Duration::from_secs(86501), &config));
    }
}