1use std::time::Duration;
2
3use super::config::FlowConfig;
4use super::key::{CanonicalKey, FlowDirection, TransportProtocol};
5use super::tcp_state::TcpConversationState;
6use super::udp_state::UdpFlowState;
7
/// Lifecycle status of a tracked conversation.
///
/// `ConversationState::new` starts every conversation as [`ConversationStatus::Active`];
/// `ConversationState::update_status` later refreshes it from the protocol tracker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConversationStatus {
    /// Initial status of every new conversation.
    Active,
    /// A TCP conversation whose connection is in FinWait1, FinWait2, CloseWait, Closing or
    /// LastAck (as checked by `ConversationState::update_status`).
    HalfClosed,
    /// A TCP conversation whose connection state reports closed.
    Closed,
    /// Conversation considered idle past its timeout. Not assigned directly by
    /// `ConversationState` (a UDP tracker's own status may carry it); presumably set by the
    /// caller once `ConversationState::is_timed_out` returns true — TODO confirm.
    TimedOut,
}

impl ConversationStatus {
    /// Returns the variant name as a static string; this is also the `Display` output.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Active => "Active",
            Self::HalfClosed => "HalfClosed",
            Self::Closed => "Closed",
            Self::TimedOut => "TimedOut",
        }
    }
}

impl std::fmt::Display for ConversationStatus {
    /// Writes [`ConversationStatus::name`].
    ///
    /// Uses `Formatter::pad` rather than `write_str` so width, fill, alignment and precision
    /// flags (e.g. `{:<10}` in table output) are honoured like they are for `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
37
/// Packet and byte counters for one direction of a conversation.
#[derive(Debug, Clone)]
pub struct DirectionStats {
    /// Number of packets recorded in this direction.
    pub packets: u64,
    /// Sum of the `byte_count` values passed to [`DirectionStats::record_packet`].
    pub bytes: u64,
    /// Earliest packet timestamp recorded in this direction. Until the first packet is
    /// recorded it holds the provisional timestamp passed to [`DirectionStats::new`].
    pub first_seen: Duration,
    /// Latest packet timestamp recorded in this direction. Until the first packet is
    /// recorded it holds the provisional timestamp passed to [`DirectionStats::new`].
    pub last_seen: Duration,
}

impl DirectionStats {
    /// Creates empty counters, provisionally stamped with `timestamp`.
    pub fn new(timestamp: Duration) -> Self {
        Self {
            packets: 0,
            bytes: 0,
            first_seen: timestamp,
            last_seen: timestamp,
        }
    }

    /// Records one packet of `byte_count` bytes observed at `timestamp`.
    ///
    /// The first recorded packet replaces the provisional timestamps from `new`. A direction
    /// that stays silent for a while therefore reports when its own traffic actually started.
    /// For example, the reverse direction of a conversation created by
    /// `ConversationState::new` is stamped with the conversation's start time.
    ///
    /// Later packets only widen the `[first_seen, last_seen]` window, so out-of-order
    /// timestamps cannot shrink it. Counters saturate at `u64::MAX` instead of overflowing.
    pub fn record_packet(&mut self, byte_count: u64, timestamp: Duration) {
        if self.packets == 0 {
            self.first_seen = timestamp;
            self.last_seen = timestamp;
        } else {
            self.first_seen = self.first_seen.min(timestamp);
            self.last_seen = self.last_seen.max(timestamp);
        }
        self.packets = self.packets.saturating_add(1);
        self.bytes = self.bytes.saturating_add(byte_count);
    }
}
68
69#[derive(Debug)]
71pub enum ProtocolState {
72 Tcp(TcpConversationState),
74 Udp(UdpFlowState),
76 Other,
78}
79
80#[derive(Debug)]
86pub struct ConversationState {
87 pub key: CanonicalKey,
89 pub status: ConversationStatus,
91 pub start_time: Duration,
93 pub last_seen: Duration,
95 pub forward: DirectionStats,
97 pub reverse: DirectionStats,
99 pub packet_indices: Vec<usize>,
101 pub protocol_state: ProtocolState,
103}
104
105impl ConversationState {
106 pub fn new(key: CanonicalKey, timestamp: Duration) -> Self {
108 let protocol_state = match key.protocol {
109 TransportProtocol::Tcp => ProtocolState::Tcp(TcpConversationState::new()),
110 TransportProtocol::Udp => ProtocolState::Udp(UdpFlowState::new()),
111 _ => ProtocolState::Other,
112 };
113
114 Self {
115 key,
116 status: ConversationStatus::Active,
117 start_time: timestamp,
118 last_seen: timestamp,
119 forward: DirectionStats::new(timestamp),
120 reverse: DirectionStats::new(timestamp),
121 packet_indices: Vec::new(),
122 protocol_state,
123 }
124 }
125
126 pub fn total_packets(&self) -> u64 {
128 self.forward.packets + self.reverse.packets
129 }
130
131 pub fn total_bytes(&self) -> u64 {
133 self.forward.bytes + self.reverse.bytes
134 }
135
136 pub fn duration(&self) -> Duration {
138 self.last_seen.saturating_sub(self.start_time)
139 }
140
141 pub fn record_packet(
143 &mut self,
144 direction: FlowDirection,
145 byte_count: u64,
146 timestamp: Duration,
147 packet_index: usize,
148 ) {
149 self.last_seen = timestamp;
150 self.packet_indices.push(packet_index);
151
152 match direction {
153 FlowDirection::Forward => self.forward.record_packet(byte_count, timestamp),
154 FlowDirection::Reverse => self.reverse.record_packet(byte_count, timestamp),
155 }
156 }
157
158 pub fn update_status(&mut self) {
160 match &self.protocol_state {
161 ProtocolState::Tcp(tcp) => {
162 if tcp.conn_state.is_closed() {
163 self.status = ConversationStatus::Closed;
164 } else if matches!(
165 tcp.conn_state,
166 super::tcp_state::TcpConnectionState::FinWait1
167 | super::tcp_state::TcpConnectionState::FinWait2
168 | super::tcp_state::TcpConnectionState::CloseWait
169 | super::tcp_state::TcpConnectionState::Closing
170 | super::tcp_state::TcpConnectionState::LastAck
171 ) {
172 self.status = ConversationStatus::HalfClosed;
173 }
174 },
175 ProtocolState::Udp(udp) => {
176 self.status = udp.status;
177 },
178 ProtocolState::Other => {},
179 }
180 }
181
182 pub fn is_timed_out(&self, now: Duration, config: &FlowConfig) -> bool {
184 let elapsed = now.saturating_sub(self.last_seen);
185 match &self.protocol_state {
186 ProtocolState::Tcp(tcp) => {
187 if tcp.conn_state.is_closed() {
188 false } else if tcp.conn_state.is_half_open() {
190 elapsed > config.tcp_half_open_timeout
191 } else {
192 elapsed > config.tcp_established_timeout
193 }
194 },
195 ProtocolState::Udp(_) => elapsed > config.udp_timeout,
196 ProtocolState::Other => elapsed > config.udp_timeout,
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Canonical key for 10.0.0.1:12345 <-> 10.0.0.2 over `protocol`.
    /// The server port is 53 for UDP and 80 otherwise.
    fn key_for(protocol: TransportProtocol) -> CanonicalKey {
        let client = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let server = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        let server_port = match protocol {
            TransportProtocol::Udp => 53,
            _ => 80,
        };
        CanonicalKey::new(client, server, 12345, server_port, protocol, None).0
    }

    #[test]
    fn test_conversation_state_new() {
        let conv = ConversationState::new(key_for(TransportProtocol::Tcp), Duration::from_secs(1));

        assert_eq!(conv.status, ConversationStatus::Active);
        assert_eq!(conv.total_packets(), 0);
        assert_eq!(conv.total_bytes(), 0);
        assert!(matches!(conv.protocol_state, ProtocolState::Tcp(_)));
    }

    #[test]
    fn test_record_packet() {
        let at = Duration::from_secs;
        let mut conv = ConversationState::new(key_for(TransportProtocol::Tcp), at(1));

        conv.record_packet(FlowDirection::Forward, 100, at(1), 0);
        conv.record_packet(FlowDirection::Reverse, 200, at(2), 1);
        conv.record_packet(FlowDirection::Forward, 50, at(3), 2);

        assert_eq!(conv.forward.packets, 2);
        assert_eq!(conv.reverse.packets, 1);
        assert_eq!(conv.total_packets(), 3);
        assert_eq!(conv.total_bytes(), 350);
        assert_eq!(conv.packet_indices, [0, 1, 2]);
        assert_eq!(conv.duration(), at(2));
    }

    #[test]
    fn test_udp_conversation() {
        let conv = ConversationState::new(key_for(TransportProtocol::Udp), Duration::from_secs(0));
        assert!(matches!(conv.protocol_state, ProtocolState::Udp(_)));
    }

    #[test]
    fn test_timeout_check() {
        use super::super::tcp_state::TcpConnectionState;

        let mut conv = ConversationState::new(key_for(TransportProtocol::Tcp), Duration::from_secs(0));
        conv.last_seen = Duration::from_secs(100);
        if let ProtocolState::Tcp(tcp) = &mut conv.protocol_state {
            tcp.conn_state = TcpConnectionState::Established;
        }
        let config = FlowConfig::default();

        // With `last_seen` at 100s these probes are 86_399s and 86_401s idle.
        // They presumably straddle a default established timeout of 86_400s (24 h).
        assert!(!conv.is_timed_out(Duration::from_secs(86_499), &config));
        assert!(conv.is_timed_out(Duration::from_secs(86_501), &config));
    }
}