Skip to main content

stackforge_core/flow/
tcp_state.rs

1use crate::TcpLayer;
2
3use super::config::FlowConfig;
4use super::error::FlowError;
5use super::key::FlowDirection;
6use super::tcp_reassembly::TcpReassembler;
7
/// The eleven TCP connection states of the RFC 793 state machine
/// (unchanged in its successor, RFC 9293).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TcpConnectionState {
    /// Waiting for a connection request (SYN).
    Listen,
    /// A SYN has been sent; waiting for the matching SYN-ACK.
    SynSent,
    /// SYN and SYN-ACK exchanged; waiting for the final handshake ACK.
    SynRcvd,
    /// Handshake complete; the connection is open.
    Established,
    /// The local side has sent a FIN that is not yet acknowledged.
    FinWait1,
    /// The local FIN has been acknowledged; waiting for the remote FIN.
    FinWait2,
    /// The remote side has sent a FIN; waiting for the local side to close.
    CloseWait,
    /// Both sides sent FINs before either was acknowledged.
    Closing,
    /// The remote closed first and the local FIN is awaiting its ACK.
    LastAck,
    /// Both FINs have been exchanged and acknowledged.
    TimeWait,
    /// No connection.
    Closed,
}
23
24impl TcpConnectionState {
25    /// Human-readable state name.
26    #[must_use]
27    pub fn name(&self) -> &'static str {
28        match self {
29            Self::Listen => "LISTEN",
30            Self::SynSent => "SYN_SENT",
31            Self::SynRcvd => "SYN_RCVD",
32            Self::Established => "ESTABLISHED",
33            Self::FinWait1 => "FIN_WAIT_1",
34            Self::FinWait2 => "FIN_WAIT_2",
35            Self::CloseWait => "CLOSE_WAIT",
36            Self::Closing => "CLOSING",
37            Self::LastAck => "LAST_ACK",
38            Self::TimeWait => "TIME_WAIT",
39            Self::Closed => "CLOSED",
40        }
41    }
42
43    /// Whether this is a terminal/closed state.
44    #[must_use]
45    pub fn is_closed(&self) -> bool {
46        matches!(self, Self::Closed | Self::TimeWait)
47    }
48
49    /// Whether this is a half-open state (not yet established).
50    #[must_use]
51    pub fn is_half_open(&self) -> bool {
52        matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
53    }
54}
55
56impl std::fmt::Display for TcpConnectionState {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.write_str(self.name())
59    }
60}
61
/// Sequence-number bookkeeping for one side of a TCP conversation.
#[derive(Clone, Debug)]
pub struct TcpEndpointState {
    /// Sequence number expected to follow the latest tracked segment sent by
    /// this endpoint; a SYN or FIN each consume one sequence number.
    pub next_expected_seq: u32,
    /// Acknowledgment number of the latest tracked segment sent by this endpoint.
    pub last_ack: u32,
    /// Raw 16-bit receive window from this endpoint's most recent segment.
    pub window_size: u16,
    /// Initial sequence number, recorded from this endpoint's SYN or SYN-ACK.
    pub initial_seq: Option<u32>,
}
74
75impl TcpEndpointState {
76    #[must_use]
77    pub fn new() -> Self {
78        Self {
79            next_expected_seq: 0,
80            last_ack: 0,
81            window_size: 0,
82            initial_seq: None,
83        }
84    }
85}
86
87impl Default for TcpEndpointState {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93/// Complete TCP conversation state including connection tracking,
94/// per-endpoint sequence state, and stream reassembly.
95#[derive(Debug)]
96pub struct TcpConversationState {
97    /// Current connection state (RFC 793 state machine).
98    pub conn_state: TcpConnectionState,
99    /// Sequence tracking for the forward direction (`addr_a` → `addr_b`).
100    pub forward_endpoint: TcpEndpointState,
101    /// Sequence tracking for the reverse direction (`addr_b` → `addr_a`).
102    pub reverse_endpoint: TcpEndpointState,
103    /// Stream reassembly for forward direction.
104    pub reassembler_fwd: TcpReassembler,
105    /// Stream reassembly for reverse direction.
106    pub reassembler_rev: TcpReassembler,
107    /// Number of forward segments dropped due to buffer/fragment limits.
108    pub dropped_segments_fwd: u64,
109    /// Number of reverse segments dropped due to buffer/fragment limits.
110    pub dropped_segments_rev: u64,
111}
112
113impl TcpConversationState {
114    #[must_use]
115    pub fn new() -> Self {
116        Self {
117            conn_state: TcpConnectionState::Listen,
118            forward_endpoint: TcpEndpointState::new(),
119            reverse_endpoint: TcpEndpointState::new(),
120            reassembler_fwd: TcpReassembler::new(),
121            reassembler_rev: TcpReassembler::new(),
122            dropped_segments_fwd: 0,
123            dropped_segments_rev: 0,
124        }
125    }
126
127    /// Total dropped segments across both directions.
128    #[must_use]
129    pub fn total_dropped_segments(&self) -> u64 {
130        self.dropped_segments_fwd + self.dropped_segments_rev
131    }
132
133    /// Process a TCP packet, updating connection state and reassembly buffers.
134    ///
135    /// `direction` indicates whether this packet is Forward (`addr_a` → `addr_b`)
136    /// or Reverse (`addr_b` → `addr_a`) relative to the canonical key.
137    /// `tcp` is the TCP layer view, `buf` is the full packet buffer.
138    pub fn process_packet(
139        &mut self,
140        direction: FlowDirection,
141        tcp: &TcpLayer,
142        buf: &[u8],
143        config: &FlowConfig,
144    ) -> Result<(), FlowError> {
145        let flags = tcp
146            .flags(buf)
147            .map_err(|e| FlowError::PacketError(e.into()))?;
148        let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
149        let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
150        let window = tcp
151            .window(buf)
152            .map_err(|e| FlowError::PacketError(e.into()))?;
153
154        // Determine payload boundaries
155        let data_offset = tcp
156            .data_offset(buf)
157            .map_err(|e| FlowError::PacketError(e.into()))?;
158        let header_bytes = (data_offset as usize) * 4;
159        let tcp_start = tcp.index.start;
160        // TCP payload starts after the TCP header. Since the TCP layer's
161        // index.end marks the header boundary, the payload is everything
162        // from header end to the end of the packet buffer.
163        let payload_start = tcp_start + header_bytes;
164        let payload = if payload_start < buf.len() {
165            &buf[payload_start..buf.len()]
166        } else {
167            &[]
168        };
169
170        // Get mutable refs to endpoint and reassembler for this direction
171        let (sender, _receiver, reassembler) = match direction {
172            FlowDirection::Forward => (
173                &mut self.forward_endpoint,
174                &mut self.reverse_endpoint,
175                &mut self.reassembler_fwd,
176            ),
177            FlowDirection::Reverse => (
178                &mut self.reverse_endpoint,
179                &mut self.forward_endpoint,
180                &mut self.reassembler_rev,
181            ),
182        };
183
184        // Update endpoint state
185        sender.window_size = window;
186
187        // State machine transitions
188        if flags.rst {
189            self.conn_state = TcpConnectionState::Closed;
190            return Ok(());
191        }
192
193        match self.conn_state {
194            TcpConnectionState::Listen => {
195                if flags.syn && !flags.ack {
196                    // SYN from initiator
197                    sender.initial_seq = Some(seq);
198                    sender.next_expected_seq = seq.wrapping_add(1); // SYN consumes 1 seq
199                    self.conn_state = TcpConnectionState::SynSent;
200                }
201            },
202            TcpConnectionState::SynSent => {
203                if flags.syn && flags.ack {
204                    // SYN-ACK from responder
205                    sender.initial_seq = Some(seq);
206                    sender.next_expected_seq = seq.wrapping_add(1);
207                    sender.last_ack = ack;
208                    self.conn_state = TcpConnectionState::SynRcvd;
209                }
210            },
211            TcpConnectionState::SynRcvd => {
212                if flags.ack && !flags.syn {
213                    // Final ACK of 3-way handshake
214                    sender.last_ack = ack;
215                    self.conn_state = TcpConnectionState::Established;
216                    // Initialize reassemblers with ISN+1 (after SYN)
217                    if !self.reassembler_fwd.is_initialized()
218                        && let Some(isn) = self.forward_endpoint.initial_seq
219                    {
220                        self.reassembler_fwd.initialize(isn.wrapping_add(1));
221                    }
222                    if !self.reassembler_rev.is_initialized()
223                        && let Some(isn) = self.reverse_endpoint.initial_seq
224                    {
225                        self.reassembler_rev.initialize(isn.wrapping_add(1));
226                    }
227                }
228            },
229            TcpConnectionState::Established => {
230                sender.last_ack = ack;
231
232                // Process payload through reassembler
233                if !payload.is_empty() {
234                    if reassembler.process_segment(seq, payload, config).is_err() {
235                        match direction {
236                            FlowDirection::Forward => self.dropped_segments_fwd += 1,
237                            FlowDirection::Reverse => self.dropped_segments_rev += 1,
238                        }
239                    }
240                }
241
242                if flags.fin {
243                    sender.next_expected_seq =
244                        seq.wrapping_add(payload.len() as u32).wrapping_add(1); // FIN consumes 1 seq
245                    match direction {
246                        FlowDirection::Forward => {
247                            self.conn_state = TcpConnectionState::FinWait1;
248                        },
249                        FlowDirection::Reverse => {
250                            self.conn_state = TcpConnectionState::CloseWait;
251                        },
252                    }
253                } else {
254                    sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
255                }
256            },
257            TcpConnectionState::FinWait1 => {
258                if flags.fin && flags.ack {
259                    // Simultaneous close
260                    self.conn_state = TcpConnectionState::TimeWait;
261                } else if flags.ack {
262                    self.conn_state = TcpConnectionState::FinWait2;
263                } else if flags.fin {
264                    self.conn_state = TcpConnectionState::Closing;
265                }
266            },
267            TcpConnectionState::FinWait2 => {
268                if flags.fin {
269                    self.conn_state = TcpConnectionState::TimeWait;
270                }
271            },
272            TcpConnectionState::CloseWait => {
273                if flags.fin {
274                    self.conn_state = TcpConnectionState::LastAck;
275                }
276            },
277            TcpConnectionState::Closing => {
278                if flags.ack {
279                    self.conn_state = TcpConnectionState::TimeWait;
280                }
281            },
282            TcpConnectionState::LastAck => {
283                if flags.ack {
284                    self.conn_state = TcpConnectionState::Closed;
285                }
286            },
287            TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
288                // Terminal states — no further transitions
289            },
290        }
291
292        Ok(())
293    }
294}
295
296impl Default for TcpConversationState {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
// Unit tests for the TCP connection state machine and per-direction reassembly.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Builds an Ethernet/IPv4/TCP packet with the given ports, sequence and
    /// acknowledgment numbers, and optional payload (appended as a raw layer).
    ///
    /// `flags` is a string of flag letters: `S` = SYN, `A` = ACK, `F` = FIN,
    /// `R` = RST, `P` = PSH; any other character is ignored. The window is
    /// always 65535.
    ///
    /// The IP addresses are always 10.0.0.1 → 10.0.0.2. Only the ports change
    /// between directions, because the tests pass the direction to
    /// `process_packet` explicitly.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let mut builder = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);

        for c in flags.chars() {
            builder = match c {
                'S' => builder.syn(),
                'A' => builder.ack(),
                'F' => builder.fin(),
                'R' => builder.rst(),
                'P' => builder.psh(),
                _ => builder,
            };
        }

        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(crate::MacAddress::BROADCAST)
                    .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
                    .dst(std::net::Ipv4Addr::new(10, 0, 0, 2)),
            ))
            .push(LayerStackEntry::Tcp(builder));

        // Only add a raw payload layer when there is data to carry.
        let stack = if !payload.is_empty() {
            stack.push(LayerStackEntry::Raw(payload.to_vec()))
        } else {
            stack
        };

        stack.build_packet()
    }

    /// Returns the packet's TCP layer view together with its full byte buffer,
    /// the two inputs `process_packet` needs. Panics if there is no TCP layer.
    fn get_tcp_and_buf(pkt: &crate::Packet) -> (TcpLayer, &[u8]) {
        let tcp = pkt.tcp().unwrap();
        let buf = pkt.as_bytes();
        (tcp, buf)
    }

    /// SYN → SYN-ACK → ACK walks Listen → SynSent → SynRcvd → Established.
    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        // SYN (client → server, forward)
        let pkt = make_tcp_packet(12345, 80, 1000, 0, "S", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Forward, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::SynSent);

        // SYN-ACK (server → client, reverse)
        let pkt = make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Reverse, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::SynRcvd);

        // ACK (client → server, forward)
        let pkt = make_tcp_packet(12345, 80, 1001, 2001, "A", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Forward, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::Established);
    }

    /// An RST moves an established connection straight to Closed.
    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        // Start mid-connection; no handshake is replayed.
        state.conn_state = TcpConnectionState::Established;

        let pkt = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Forward, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    /// Forward-initiated close: FIN (fwd) → ACK (rev) → FIN (rev), ending in
    /// TimeWait. The reverse ACK number 1001 covers the forward FIN at seq 1000.
    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        // Start mid-connection; no handshake is replayed.
        state.conn_state = TcpConnectionState::Established;

        // FIN from forward direction
        let pkt = make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Forward, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::FinWait1);

        // ACK of FIN from reverse
        let pkt = make_tcp_packet(80, 12345, 2000, 1001, "A", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Reverse, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::FinWait2);

        // FIN from reverse
        let pkt = make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]);
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Reverse, &tcp, buf, &config)
            .unwrap();
        assert_eq!(state.conn_state, TcpConnectionState::TimeWait);
    }

    /// Payload of an established forward segment lands in the forward
    /// reassembler's output.
    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Initialize forward reassembler
        // (ISN 999, so the first data byte is at seq 1000).
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        // Data from forward direction
        let pkt = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        let (tcp, buf) = get_tcp_and_buf(&pkt);
        state
            .process_packet(FlowDirection::Forward, &tcp, buf, &config)
            .unwrap();

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    /// State labels and the `is_closed` / `is_half_open` predicates.
    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}