1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// Connection state tracked by [`TcpConversationState`].
///
/// Once the handshake completes, states are named from the forward endpoint's
/// point of view: a FIN sent in the forward direction moves ESTABLISHED to
/// FIN_WAIT_1, while a FIN sent in the reverse direction moves it to CLOSE_WAIT.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TcpConnectionState {
    /// Initial state: no SYN (without ACK) has been observed yet.
    Listen,
    /// The opening SYN has been seen.
    SynSent,
    /// The SYN-ACK has been seen; waiting for the handshake-completing ACK.
    SynRcvd,
    /// The three-way handshake has completed.
    Established,
    /// The forward side has sent a FIN.
    FinWait1,
    /// An ACK followed the forward side's FIN; waiting for the reverse side's FIN.
    FinWait2,
    /// The reverse side sent the first FIN.
    CloseWait,
    /// A FIN without ACK followed the forward side's FIN (simultaneous close).
    Closing,
    /// The forward side sent its FIN after CLOSE_WAIT; waiting for the final ACK.
    LastAck,
    /// Both sides' FINs have been seen on the forward-initiated close path.
    TimeWait,
    /// The connection was reset (RST) or the final ACK after LAST_ACK was seen.
    Closed,
}
23
24impl TcpConnectionState {
25 pub fn name(&self) -> &'static str {
27 match self {
28 Self::Listen => "LISTEN",
29 Self::SynSent => "SYN_SENT",
30 Self::SynRcvd => "SYN_RCVD",
31 Self::Established => "ESTABLISHED",
32 Self::FinWait1 => "FIN_WAIT_1",
33 Self::FinWait2 => "FIN_WAIT_2",
34 Self::CloseWait => "CLOSE_WAIT",
35 Self::Closing => "CLOSING",
36 Self::LastAck => "LAST_ACK",
37 Self::TimeWait => "TIME_WAIT",
38 Self::Closed => "CLOSED",
39 }
40 }
41
42 pub fn is_closed(&self) -> bool {
44 matches!(self, Self::Closed | Self::TimeWait)
45 }
46
47 pub fn is_half_open(&self) -> bool {
49 matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
50 }
51}
52
53impl std::fmt::Display for TcpConnectionState {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.write_str(self.name())
56 }
57}
58
/// Per-direction bookkeeping for one side of a TCP conversation.
///
/// All values are taken from segments *sent* by this endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpEndpointState {
    /// Sequence number just past the most recent segment seen from this side
    /// (payload length plus one for a SYN or FIN).
    pub next_expected_seq: u32,
    /// Most recent acknowledgement number sent by this side.
    pub last_ack: u32,
    /// Most recent window advertised by this side, as carried in the 16-bit
    /// header field (no window scaling applied).
    pub window_size: u16,
    /// Initial sequence number from this side's SYN or SYN-ACK, if one was seen.
    pub initial_seq: Option<u32>,
}
71
72impl TcpEndpointState {
73 pub fn new() -> Self {
74 Self {
75 next_expected_seq: 0,
76 last_ack: 0,
77 window_size: 0,
78 initial_seq: None,
79 }
80 }
81}
82
83impl Default for TcpEndpointState {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
/// Bidirectional TCP state for one conversation: the connection state plus
/// per-direction endpoint bookkeeping and byte-stream reassemblers.
#[derive(Debug)]
pub struct TcpConversationState {
    /// Connection state; after the handshake it is named from the forward
    /// endpoint's point of view.
    pub conn_state: TcpConnectionState,
    /// Bookkeeping for segments sent in the forward direction.
    pub forward_endpoint: TcpEndpointState,
    /// Bookkeeping for segments sent in the reverse direction.
    pub reverse_endpoint: TcpEndpointState,
    /// Reassembler for payload sent in the forward direction.
    pub reassembler_fwd: TcpReassembler,
    /// Reassembler for payload sent in the reverse direction.
    pub reassembler_rev: TcpReassembler,
}
104
105impl TcpConversationState {
106 pub fn new() -> Self {
107 Self {
108 conn_state: TcpConnectionState::Listen,
109 forward_endpoint: TcpEndpointState::new(),
110 reverse_endpoint: TcpEndpointState::new(),
111 reassembler_fwd: TcpReassembler::new(),
112 reassembler_rev: TcpReassembler::new(),
113 }
114 }
115
116 pub fn process_packet(
122 &mut self,
123 direction: FlowDirection,
124 tcp: &TcpLayer,
125 buf: &[u8],
126 config: &FlowConfig,
127 ) -> Result<(), FlowError> {
128 let flags = tcp
129 .flags(buf)
130 .map_err(|e| FlowError::PacketError(e.into()))?;
131 let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
132 let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
133 let window = tcp
134 .window(buf)
135 .map_err(|e| FlowError::PacketError(e.into()))?;
136
137 let data_offset = tcp
139 .data_offset(buf)
140 .map_err(|e| FlowError::PacketError(e.into()))?;
141 let header_bytes = (data_offset as usize) * 4;
142 let tcp_start = tcp.index.start;
143 let payload_start = tcp_start + header_bytes;
147 let payload = if payload_start < buf.len() {
148 &buf[payload_start..buf.len()]
149 } else {
150 &[]
151 };
152
153 let (sender, _receiver, reassembler) = match direction {
155 FlowDirection::Forward => (
156 &mut self.forward_endpoint,
157 &mut self.reverse_endpoint,
158 &mut self.reassembler_fwd,
159 ),
160 FlowDirection::Reverse => (
161 &mut self.reverse_endpoint,
162 &mut self.forward_endpoint,
163 &mut self.reassembler_rev,
164 ),
165 };
166
167 sender.window_size = window;
169
170 if flags.rst {
172 self.conn_state = TcpConnectionState::Closed;
173 return Ok(());
174 }
175
176 match self.conn_state {
177 TcpConnectionState::Listen => {
178 if flags.syn && !flags.ack {
179 sender.initial_seq = Some(seq);
181 sender.next_expected_seq = seq.wrapping_add(1); self.conn_state = TcpConnectionState::SynSent;
183 }
184 },
185 TcpConnectionState::SynSent => {
186 if flags.syn && flags.ack {
187 sender.initial_seq = Some(seq);
189 sender.next_expected_seq = seq.wrapping_add(1);
190 sender.last_ack = ack;
191 self.conn_state = TcpConnectionState::SynRcvd;
192 }
193 },
194 TcpConnectionState::SynRcvd => {
195 if flags.ack && !flags.syn {
196 sender.last_ack = ack;
198 self.conn_state = TcpConnectionState::Established;
199 if !self.reassembler_fwd.is_initialized() {
201 if let Some(isn) = self.forward_endpoint.initial_seq {
202 self.reassembler_fwd.initialize(isn.wrapping_add(1));
203 }
204 }
205 if !self.reassembler_rev.is_initialized() {
206 if let Some(isn) = self.reverse_endpoint.initial_seq {
207 self.reassembler_rev.initialize(isn.wrapping_add(1));
208 }
209 }
210 }
211 },
212 TcpConnectionState::Established => {
213 sender.last_ack = ack;
214
215 if !payload.is_empty() {
217 let _ = reassembler.process_segment(seq, payload, config);
220 }
221
222 if flags.fin {
223 sender.next_expected_seq =
224 seq.wrapping_add(payload.len() as u32).wrapping_add(1); match direction {
226 FlowDirection::Forward => {
227 self.conn_state = TcpConnectionState::FinWait1;
228 },
229 FlowDirection::Reverse => {
230 self.conn_state = TcpConnectionState::CloseWait;
231 },
232 }
233 } else {
234 sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
235 }
236 },
237 TcpConnectionState::FinWait1 => {
238 if flags.fin && flags.ack {
239 self.conn_state = TcpConnectionState::TimeWait;
241 } else if flags.ack {
242 self.conn_state = TcpConnectionState::FinWait2;
243 } else if flags.fin {
244 self.conn_state = TcpConnectionState::Closing;
245 }
246 },
247 TcpConnectionState::FinWait2 => {
248 if flags.fin {
249 self.conn_state = TcpConnectionState::TimeWait;
250 }
251 },
252 TcpConnectionState::CloseWait => {
253 if flags.fin {
254 self.conn_state = TcpConnectionState::LastAck;
255 }
256 },
257 TcpConnectionState::Closing => {
258 if flags.ack {
259 self.conn_state = TcpConnectionState::TimeWait;
260 }
261 },
262 TcpConnectionState::LastAck => {
263 if flags.ack {
264 self.conn_state = TcpConnectionState::Closed;
265 }
266 },
267 TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
268 },
270 }
271
272 Ok(())
273 }
274}
275
276impl Default for TcpConversationState {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Builds an Ethernet/IPv4/TCP packet from 10.0.0.1 to 10.0.0.2 with a
    /// 65535-byte window.
    ///
    /// `flags` lists TCP flags by letter (`S`, `A`, `F`, `R`, `P`); any other
    /// character is ignored.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let mut builder = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);

        for c in flags.chars() {
            builder = match c {
                'S' => builder.syn(),
                'A' => builder.ack(),
                'F' => builder.fin(),
                'R' => builder.rst(),
                'P' => builder.psh(),
                _ => builder,
            };
        }

        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(crate::MacAddress::BROADCAST)
                    .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
                    .dst(std::net::Ipv4Addr::new(10, 0, 0, 2)),
            ))
            .push(LayerStackEntry::Tcp(builder));

        let stack = if !payload.is_empty() {
            stack.push(LayerStackEntry::Raw(payload.to_vec()))
        } else {
            stack
        };

        stack.build_packet()
    }

    fn get_tcp_and_buf(pkt: &crate::Packet) -> (TcpLayer, &[u8]) {
        let tcp = pkt.tcp().unwrap();
        let buf = pkt.as_bytes();
        (tcp, buf)
    }

    /// Feeds `pkt` into `state` as a segment travelling in `direction`,
    /// panicking if processing fails.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let (tcp, buf) = get_tcp_and_buf(pkt);
        state
            .process_packet(direction, &tcp, buf, config)
            .unwrap();
    }

    /// Runs a complete handshake (client ISN 1000, server ISN 2000), checking
    /// each intermediate state, and returns the ESTABLISHED conversation.
    fn handshake(config: &FlowConfig) -> TcpConversationState {
        let mut state = TcpConversationState::new();

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1000, 0, "S", &[]),
            config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::SynSent);

        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]),
            config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::SynRcvd);

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1001, 2001, "A", &[]),
            config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::Established);

        state
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let state = handshake(&config);

        assert_eq!(state.forward_endpoint.initial_seq, Some(1000));
        assert_eq!(state.reverse_endpoint.initial_seq, Some(2000));
        assert!(state.reassembler_fwd.is_initialized());
        assert!(state.reassembler_rev.is_initialized());
    }

    #[test]
    fn test_data_after_handshake_is_reassembled_per_direction() {
        let config = FlowConfig::default();
        let mut state = handshake(&config);

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1001, 2001, "PA", b"GET /"),
            &config,
        );
        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2001, 1006, "PA", b"OK"),
            &config,
        );

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
        assert_eq!(state.reassembler_rev.reassembled_data(), b"OK");
        assert_eq!(state.forward_endpoint.next_expected_seq, 1006);
        assert_eq!(state.reverse_endpoint.next_expected_seq, 2003);
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1000, 0, "R", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_rst_from_reverse_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2000, 0, "R", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::FinWait1);

        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2000, 1001, "A", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::FinWait2);

        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::TimeWait);
    }

    #[test]
    fn test_passive_close() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2000, 1000, "FA", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::CloseWait);

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1000, 2001, "FA", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::LastAck);

        feed(
            &mut state,
            FlowDirection::Reverse,
            &make_tcp_packet(80, 12345, 2001, 1001, "A", &[]),
            &config,
        );
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        feed(
            &mut state,
            FlowDirection::Forward,
            &make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /"),
            &config,
        );

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        // `Display` must agree with `name()`.
        assert_eq!(TcpConnectionState::FinWait1.to_string(), "FIN_WAIT_1");
        assert_eq!(TcpConnectionState::Closed.to_string(), "CLOSED");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(!TcpConnectionState::Established.is_closed());
        assert!(TcpConnectionState::Listen.is_half_open());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(TcpConnectionState::SynRcvd.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}