1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// Connection state of an observed TCP conversation, using the RFC 793 state
/// names. A single state is tracked for the conversation as a whole.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpConnectionState {
    /// Initial state: no bare SYN has been observed yet.
    Listen,
    /// A SYN without ACK has been observed.
    SynSent,
    /// A SYN-ACK has been observed after the SYN.
    SynRcvd,
    /// The ACK completing the three-way handshake has been observed.
    Established,
    /// The forward endpoint sent a FIN while established.
    FinWait1,
    /// In `FinWait1`, an ACK without FIN was observed.
    FinWait2,
    /// The reverse endpoint sent a FIN while established.
    CloseWait,
    /// In `FinWait1`, a FIN without ACK was observed.
    Closing,
    /// In `CloseWait`, a FIN was observed.
    LastAck,
    /// Both FINs have been seen (entered from `FinWait1`, `FinWait2` or `Closing`).
    TimeWait,
    /// Entered on any RST, or on the final ACK in `LastAck`.
    Closed,
}
23
24impl TcpConnectionState {
25 #[must_use]
27 pub fn name(&self) -> &'static str {
28 match self {
29 Self::Listen => "LISTEN",
30 Self::SynSent => "SYN_SENT",
31 Self::SynRcvd => "SYN_RCVD",
32 Self::Established => "ESTABLISHED",
33 Self::FinWait1 => "FIN_WAIT_1",
34 Self::FinWait2 => "FIN_WAIT_2",
35 Self::CloseWait => "CLOSE_WAIT",
36 Self::Closing => "CLOSING",
37 Self::LastAck => "LAST_ACK",
38 Self::TimeWait => "TIME_WAIT",
39 Self::Closed => "CLOSED",
40 }
41 }
42
43 #[must_use]
45 pub fn is_closed(&self) -> bool {
46 matches!(self, Self::Closed | Self::TimeWait)
47 }
48
49 #[must_use]
51 pub fn is_half_open(&self) -> bool {
52 matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
53 }
54}
55
56impl std::fmt::Display for TcpConnectionState {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.write_str(self.name())
59 }
60}
61
/// Sequence/acknowledgement bookkeeping for one side of a TCP conversation,
/// updated from the segments that side sends.
#[derive(Debug, Clone)]
pub struct TcpEndpointState {
    /// Sequence number this endpoint is expected to use next: ISN + 1 after
    /// its SYN, then advanced past payload (and past a FIN) on later segments.
    pub next_expected_seq: u32,
    /// Acknowledgement number from this endpoint's most recently tracked segment.
    pub last_ack: u32,
    /// Window value from this endpoint's most recent segment, as read from the
    /// header (no window scaling is applied).
    pub window_size: u16,
    /// Initial sequence number from this endpoint's SYN or SYN-ACK, once observed.
    pub initial_seq: Option<u32>,
}
74
75impl TcpEndpointState {
76 #[must_use]
77 pub fn new() -> Self {
78 Self {
79 next_expected_seq: 0,
80 last_ack: 0,
81 window_size: 0,
82 initial_seq: None,
83 }
84 }
85}
86
87impl Default for TcpEndpointState {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93#[derive(Debug)]
96pub struct TcpConversationState {
97 pub conn_state: TcpConnectionState,
99 pub forward_endpoint: TcpEndpointState,
101 pub reverse_endpoint: TcpEndpointState,
103 pub reassembler_fwd: TcpReassembler,
105 pub reassembler_rev: TcpReassembler,
107}
108
109impl TcpConversationState {
110 #[must_use]
111 pub fn new() -> Self {
112 Self {
113 conn_state: TcpConnectionState::Listen,
114 forward_endpoint: TcpEndpointState::new(),
115 reverse_endpoint: TcpEndpointState::new(),
116 reassembler_fwd: TcpReassembler::new(),
117 reassembler_rev: TcpReassembler::new(),
118 }
119 }
120
121 pub fn process_packet(
127 &mut self,
128 direction: FlowDirection,
129 tcp: &TcpLayer,
130 buf: &[u8],
131 config: &FlowConfig,
132 ) -> Result<(), FlowError> {
133 let flags = tcp
134 .flags(buf)
135 .map_err(|e| FlowError::PacketError(e.into()))?;
136 let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
137 let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
138 let window = tcp
139 .window(buf)
140 .map_err(|e| FlowError::PacketError(e.into()))?;
141
142 let data_offset = tcp
144 .data_offset(buf)
145 .map_err(|e| FlowError::PacketError(e.into()))?;
146 let header_bytes = (data_offset as usize) * 4;
147 let tcp_start = tcp.index.start;
148 let payload_start = tcp_start + header_bytes;
152 let payload = if payload_start < buf.len() {
153 &buf[payload_start..buf.len()]
154 } else {
155 &[]
156 };
157
158 let (sender, _receiver, reassembler) = match direction {
160 FlowDirection::Forward => (
161 &mut self.forward_endpoint,
162 &mut self.reverse_endpoint,
163 &mut self.reassembler_fwd,
164 ),
165 FlowDirection::Reverse => (
166 &mut self.reverse_endpoint,
167 &mut self.forward_endpoint,
168 &mut self.reassembler_rev,
169 ),
170 };
171
172 sender.window_size = window;
174
175 if flags.rst {
177 self.conn_state = TcpConnectionState::Closed;
178 return Ok(());
179 }
180
181 match self.conn_state {
182 TcpConnectionState::Listen => {
183 if flags.syn && !flags.ack {
184 sender.initial_seq = Some(seq);
186 sender.next_expected_seq = seq.wrapping_add(1); self.conn_state = TcpConnectionState::SynSent;
188 }
189 },
190 TcpConnectionState::SynSent => {
191 if flags.syn && flags.ack {
192 sender.initial_seq = Some(seq);
194 sender.next_expected_seq = seq.wrapping_add(1);
195 sender.last_ack = ack;
196 self.conn_state = TcpConnectionState::SynRcvd;
197 }
198 },
199 TcpConnectionState::SynRcvd => {
200 if flags.ack && !flags.syn {
201 sender.last_ack = ack;
203 self.conn_state = TcpConnectionState::Established;
204 if !self.reassembler_fwd.is_initialized()
206 && let Some(isn) = self.forward_endpoint.initial_seq
207 {
208 self.reassembler_fwd.initialize(isn.wrapping_add(1));
209 }
210 if !self.reassembler_rev.is_initialized()
211 && let Some(isn) = self.reverse_endpoint.initial_seq
212 {
213 self.reassembler_rev.initialize(isn.wrapping_add(1));
214 }
215 }
216 },
217 TcpConnectionState::Established => {
218 sender.last_ack = ack;
219
220 if !payload.is_empty() {
222 let _ = reassembler.process_segment(seq, payload, config);
225 }
226
227 if flags.fin {
228 sender.next_expected_seq =
229 seq.wrapping_add(payload.len() as u32).wrapping_add(1); match direction {
231 FlowDirection::Forward => {
232 self.conn_state = TcpConnectionState::FinWait1;
233 },
234 FlowDirection::Reverse => {
235 self.conn_state = TcpConnectionState::CloseWait;
236 },
237 }
238 } else {
239 sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
240 }
241 },
242 TcpConnectionState::FinWait1 => {
243 if flags.fin && flags.ack {
244 self.conn_state = TcpConnectionState::TimeWait;
246 } else if flags.ack {
247 self.conn_state = TcpConnectionState::FinWait2;
248 } else if flags.fin {
249 self.conn_state = TcpConnectionState::Closing;
250 }
251 },
252 TcpConnectionState::FinWait2 => {
253 if flags.fin {
254 self.conn_state = TcpConnectionState::TimeWait;
255 }
256 },
257 TcpConnectionState::CloseWait => {
258 if flags.fin {
259 self.conn_state = TcpConnectionState::LastAck;
260 }
261 },
262 TcpConnectionState::Closing => {
263 if flags.ack {
264 self.conn_state = TcpConnectionState::TimeWait;
265 }
266 },
267 TcpConnectionState::LastAck => {
268 if flags.ack {
269 self.conn_state = TcpConnectionState::Closed;
270 }
271 },
272 TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
273 },
275 }
276
277 Ok(())
278 }
279}
280
281impl Default for TcpConversationState {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Builds an Ethernet/IPv4/TCP packet (10.0.0.1 -> 10.0.0.2) for the tests.
    ///
    /// `flags` is read one character at a time: `S` = SYN, `A` = ACK,
    /// `F` = FIN, `R` = RST, `P` = PSH; any other character is ignored. A
    /// non-empty `payload` is appended as a raw layer after the TCP header.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let unflagged = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);
        let tcp = flags.chars().fold(unflagged, |b, letter| match letter {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            'P' => b.psh(),
            _ => b,
        });

        let ethernet = EthernetBuilder::new()
            .dst(crate::MacAddress::BROADCAST)
            .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new()
            .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
            .dst(std::net::Ipv4Addr::new(10, 0, 0, 2));

        let mut stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp));
        if !payload.is_empty() {
            stack = stack.push(LayerStackEntry::Raw(payload.to_vec()));
        }
        stack.build_packet()
    }

    /// Returns the TCP layer and raw bytes of `pkt`; panics without a TCP layer.
    fn get_tcp_and_buf(pkt: &crate::Packet) -> (TcpLayer, &[u8]) {
        (pkt.tcp().unwrap(), pkt.as_bytes())
    }

    /// Runs `pkt` through `state` as a segment travelling in `direction`.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let (tcp, buf) = get_tcp_and_buf(pkt);
        state
            .process_packet(direction, &tcp, buf, config)
            .unwrap();
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        let syn = make_tcp_packet(12345, 80, 1000, 0, "S", &[]);
        feed(&mut state, FlowDirection::Forward, &syn, &config);
        assert_eq!(state.conn_state, TcpConnectionState::SynSent);

        let syn_ack = make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]);
        feed(&mut state, FlowDirection::Reverse, &syn_ack, &config);
        assert_eq!(state.conn_state, TcpConnectionState::SynRcvd);

        let ack = make_tcp_packet(12345, 80, 1001, 2001, "A", &[]);
        feed(&mut state, FlowDirection::Forward, &ack, &config);
        assert_eq!(state.conn_state, TcpConnectionState::Established);
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let rst = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        feed(&mut state, FlowDirection::Forward, &rst, &config);
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let client_fin = make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]);
        feed(&mut state, FlowDirection::Forward, &client_fin, &config);
        assert_eq!(state.conn_state, TcpConnectionState::FinWait1);

        let server_ack = make_tcp_packet(80, 12345, 2000, 1001, "A", &[]);
        feed(&mut state, FlowDirection::Reverse, &server_ack, &config);
        assert_eq!(state.conn_state, TcpConnectionState::FinWait2);

        let server_fin = make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]);
        feed(&mut state, FlowDirection::Reverse, &server_fin, &config);
        assert_eq!(state.conn_state, TcpConnectionState::TimeWait);
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Pretend the handshake was seen with client ISN 999.
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        let request = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        feed(&mut state, FlowDirection::Forward, &request, &config);

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}