1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// Lifecycle state of a tracked TCP conversation, inferred from the segments
/// observed so far. Variant names follow the classic TCP state names.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TcpConnectionState {
    /// No connection-opening SYN has been observed yet.
    Listen,
    /// A SYN without ACK has been observed.
    SynSent,
    /// The SYN-ACK reply has been observed.
    SynRcvd,
    /// The three-way handshake has completed.
    Established,
    /// The forward endpoint sent the first FIN.
    FinWait1,
    /// The first FIN, sent by the forward endpoint, has been acknowledged.
    FinWait2,
    /// The reverse endpoint sent the first FIN.
    CloseWait,
    /// Both endpoints sent FINs; an acknowledgment is still outstanding.
    Closing,
    /// The forward endpoint answered the reverse FIN with its own FIN and
    /// awaits the final acknowledgment.
    LastAck,
    /// Both FINs have been exchanged.
    TimeWait,
    /// The connection was reset or has fully closed.
    Closed,
}
23
24impl TcpConnectionState {
25 #[must_use]
27 pub fn name(&self) -> &'static str {
28 match self {
29 Self::Listen => "LISTEN",
30 Self::SynSent => "SYN_SENT",
31 Self::SynRcvd => "SYN_RCVD",
32 Self::Established => "ESTABLISHED",
33 Self::FinWait1 => "FIN_WAIT_1",
34 Self::FinWait2 => "FIN_WAIT_2",
35 Self::CloseWait => "CLOSE_WAIT",
36 Self::Closing => "CLOSING",
37 Self::LastAck => "LAST_ACK",
38 Self::TimeWait => "TIME_WAIT",
39 Self::Closed => "CLOSED",
40 }
41 }
42
43 #[must_use]
45 pub fn is_closed(&self) -> bool {
46 matches!(self, Self::Closed | Self::TimeWait)
47 }
48
49 #[must_use]
51 pub fn is_half_open(&self) -> bool {
52 matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
53 }
54}
55
56impl std::fmt::Display for TcpConnectionState {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.write_str(self.name())
59 }
60}
61
/// Bookkeeping for one side (forward or reverse) of a TCP conversation.
/// Every field describes segments *sent by* this endpoint.
#[derive(Debug)]
#[derive(Clone)]
pub struct TcpEndpointState {
    /// Sequence number this endpoint is expected to use next (after its SYN,
    /// payload and FIN, each of which consumes sequence space).
    pub next_expected_seq: u32,
    /// Most recent acknowledgment number recorded from this endpoint.
    pub last_ack: u32,
    /// Most recent raw 16-bit window field from this endpoint's header.
    pub window_size: u16,
    /// Sequence number carried by this endpoint's SYN, once one is seen.
    pub initial_seq: Option<u32>,
}
74
75impl TcpEndpointState {
76 #[must_use]
77 pub fn new() -> Self {
78 Self {
79 next_expected_seq: 0,
80 last_ack: 0,
81 window_size: 0,
82 initial_seq: None,
83 }
84 }
85}
86
87impl Default for TcpEndpointState {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93#[derive(Debug)]
96pub struct TcpConversationState {
97 pub conn_state: TcpConnectionState,
99 pub forward_endpoint: TcpEndpointState,
101 pub reverse_endpoint: TcpEndpointState,
103 pub reassembler_fwd: TcpReassembler,
105 pub reassembler_rev: TcpReassembler,
107 pub dropped_segments_fwd: u64,
109 pub dropped_segments_rev: u64,
111}
112
113impl TcpConversationState {
114 #[must_use]
115 pub fn new() -> Self {
116 Self {
117 conn_state: TcpConnectionState::Listen,
118 forward_endpoint: TcpEndpointState::new(),
119 reverse_endpoint: TcpEndpointState::new(),
120 reassembler_fwd: TcpReassembler::new(),
121 reassembler_rev: TcpReassembler::new(),
122 dropped_segments_fwd: 0,
123 dropped_segments_rev: 0,
124 }
125 }
126
127 #[must_use]
129 pub fn total_dropped_segments(&self) -> u64 {
130 self.dropped_segments_fwd + self.dropped_segments_rev
131 }
132
133 pub fn process_packet(
139 &mut self,
140 direction: FlowDirection,
141 tcp: &TcpLayer,
142 buf: &[u8],
143 config: &FlowConfig,
144 ) -> Result<(), FlowError> {
145 let flags = tcp
146 .flags(buf)
147 .map_err(|e| FlowError::PacketError(e.into()))?;
148 let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
149 let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
150 let window = tcp
151 .window(buf)
152 .map_err(|e| FlowError::PacketError(e.into()))?;
153
154 let data_offset = tcp
156 .data_offset(buf)
157 .map_err(|e| FlowError::PacketError(e.into()))?;
158 let header_bytes = (data_offset as usize) * 4;
159 let tcp_start = tcp.index.start;
160 let payload_start = tcp_start + header_bytes;
164 let payload = if payload_start < buf.len() {
165 &buf[payload_start..buf.len()]
166 } else {
167 &[]
168 };
169
170 let (sender, _receiver, reassembler) = match direction {
172 FlowDirection::Forward => (
173 &mut self.forward_endpoint,
174 &mut self.reverse_endpoint,
175 &mut self.reassembler_fwd,
176 ),
177 FlowDirection::Reverse => (
178 &mut self.reverse_endpoint,
179 &mut self.forward_endpoint,
180 &mut self.reassembler_rev,
181 ),
182 };
183
184 sender.window_size = window;
186
187 if flags.rst {
189 self.conn_state = TcpConnectionState::Closed;
190 return Ok(());
191 }
192
193 match self.conn_state {
194 TcpConnectionState::Listen => {
195 if flags.syn && !flags.ack {
196 sender.initial_seq = Some(seq);
198 sender.next_expected_seq = seq.wrapping_add(1); self.conn_state = TcpConnectionState::SynSent;
200 }
201 },
202 TcpConnectionState::SynSent => {
203 if flags.syn && flags.ack {
204 sender.initial_seq = Some(seq);
206 sender.next_expected_seq = seq.wrapping_add(1);
207 sender.last_ack = ack;
208 self.conn_state = TcpConnectionState::SynRcvd;
209 }
210 },
211 TcpConnectionState::SynRcvd => {
212 if flags.ack && !flags.syn {
213 sender.last_ack = ack;
215 self.conn_state = TcpConnectionState::Established;
216 if !self.reassembler_fwd.is_initialized()
218 && let Some(isn) = self.forward_endpoint.initial_seq
219 {
220 self.reassembler_fwd.initialize(isn.wrapping_add(1));
221 }
222 if !self.reassembler_rev.is_initialized()
223 && let Some(isn) = self.reverse_endpoint.initial_seq
224 {
225 self.reassembler_rev.initialize(isn.wrapping_add(1));
226 }
227 }
228 },
229 TcpConnectionState::Established => {
230 sender.last_ack = ack;
231
232 if !payload.is_empty() {
234 if reassembler.process_segment(seq, payload, config).is_err() {
235 match direction {
236 FlowDirection::Forward => self.dropped_segments_fwd += 1,
237 FlowDirection::Reverse => self.dropped_segments_rev += 1,
238 }
239 }
240 }
241
242 if flags.fin {
243 sender.next_expected_seq =
244 seq.wrapping_add(payload.len() as u32).wrapping_add(1); match direction {
246 FlowDirection::Forward => {
247 self.conn_state = TcpConnectionState::FinWait1;
248 },
249 FlowDirection::Reverse => {
250 self.conn_state = TcpConnectionState::CloseWait;
251 },
252 }
253 } else {
254 sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
255 }
256 },
257 TcpConnectionState::FinWait1 => {
258 if flags.fin && flags.ack {
259 self.conn_state = TcpConnectionState::TimeWait;
261 } else if flags.ack {
262 self.conn_state = TcpConnectionState::FinWait2;
263 } else if flags.fin {
264 self.conn_state = TcpConnectionState::Closing;
265 }
266 },
267 TcpConnectionState::FinWait2 => {
268 if flags.fin {
269 self.conn_state = TcpConnectionState::TimeWait;
270 }
271 },
272 TcpConnectionState::CloseWait => {
273 if flags.fin {
274 self.conn_state = TcpConnectionState::LastAck;
275 }
276 },
277 TcpConnectionState::Closing => {
278 if flags.ack {
279 self.conn_state = TcpConnectionState::TimeWait;
280 }
281 },
282 TcpConnectionState::LastAck => {
283 if flags.ack {
284 self.conn_state = TcpConnectionState::Closed;
285 }
286 },
287 TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
288 },
290 }
291
292 Ok(())
293 }
294}
295
296impl Default for TcpConversationState {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Builds an Ethernet/IPv4/TCP packet from 10.0.0.1 to 10.0.0.2.
    ///
    /// `flags` is a string of flag letters: `S` (SYN), `A` (ACK), `F` (FIN),
    /// `R` (RST) and `P` (PSH); other characters are ignored. A non-empty
    /// `payload` is appended as a raw layer after the TCP header.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let base = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);
        let tcp = flags.chars().fold(base, |b, flag| match flag {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            'P' => b.psh(),
            _ => b,
        });

        let ethernet = EthernetBuilder::new()
            .dst(crate::MacAddress::BROADCAST)
            .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new()
            .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
            .dst(std::net::Ipv4Addr::new(10, 0, 0, 2));

        let mut stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp));
        if !payload.is_empty() {
            stack = stack.push(LayerStackEntry::Raw(payload.to_vec()));
        }
        stack.build_packet()
    }

    /// Runs one packet through `state`, panicking if processing fails.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let tcp = pkt.tcp().unwrap();
        state
            .process_packet(direction, &tcp, pkt.as_bytes(), config)
            .unwrap();
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        let steps = [
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 0, "S", &[]),
                TcpConnectionState::SynSent,
            ),
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]),
                TcpConnectionState::SynRcvd,
            ),
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1001, 2001, "A", &[]),
                TcpConnectionState::Established,
            ),
        ];
        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let rst = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        feed(&mut state, FlowDirection::Forward, &rst, &config);
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let steps = [
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]),
                TcpConnectionState::FinWait1,
            ),
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "A", &[]),
                TcpConnectionState::FinWait2,
            ),
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]),
                TcpConnectionState::TimeWait,
            ),
        ];
        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Pretend the forward SYN carried ISN 999, so data starts at 1000.
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        let data = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        feed(&mut state, FlowDirection::Forward, &data, &config);

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}