//! TCP connection state tracking and per-direction stream reassembly.
//!
//! File: `stackforge_core/flow/tcp_state.rs`.

use crate::TcpLayer;

use super::config::FlowConfig;
use super::error::FlowError;
use super::key::FlowDirection;
use super::tcp_reassembly::TcpReassembler;

/// TCP connection states per RFC 793.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpConnectionState {
    /// RFC 793 LISTEN: no connection activity has been observed yet.
    Listen,
    /// RFC 793 SYN-SENT: a SYN has been sent, awaiting the SYN-ACK.
    SynSent,
    /// RFC 793 SYN-RECEIVED: SYN and SYN-ACK seen, awaiting the final ACK.
    SynRcvd,
    /// RFC 793 ESTABLISHED: the three-way handshake has completed.
    Established,
    /// RFC 793 FIN-WAIT-1: a FIN has been sent, not yet acknowledged.
    FinWait1,
    /// RFC 793 FIN-WAIT-2: our FIN was acknowledged, awaiting the peer's FIN.
    FinWait2,
    /// RFC 793 CLOSE-WAIT: the peer's FIN was received, ours not yet sent.
    CloseWait,
    /// RFC 793 CLOSING: both sides sent FINs, ours not yet acknowledged.
    Closing,
    /// RFC 793 LAST-ACK: our FIN followed the peer's, awaiting its ACK.
    LastAck,
    /// RFC 793 TIME-WAIT: both FINs exchanged; treated as terminal here.
    TimeWait,
    /// RFC 793 CLOSED: connection torn down (orderly close or RST).
    Closed,
}

impl TcpConnectionState {
    /// Human-readable state name in RFC 793 upper-case style (e.g. `"SYN_SENT"`).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Listen => "LISTEN",
            Self::SynSent => "SYN_SENT",
            Self::SynRcvd => "SYN_RCVD",
            Self::Established => "ESTABLISHED",
            Self::FinWait1 => "FIN_WAIT_1",
            Self::FinWait2 => "FIN_WAIT_2",
            Self::CloseWait => "CLOSE_WAIT",
            Self::Closing => "CLOSING",
            Self::LastAck => "LAST_ACK",
            Self::TimeWait => "TIME_WAIT",
            Self::Closed => "CLOSED",
        }
    }

    /// Whether this is a terminal/closed state (`Closed` or `TimeWait`).
    #[must_use]
    pub const fn is_closed(&self) -> bool {
        matches!(self, Self::Closed | Self::TimeWait)
    }

    /// Whether the handshake has not completed yet (`Listen`, `SynSent`, `SynRcvd`).
    #[must_use]
    pub const fn is_half_open(&self) -> bool {
        matches!(self, Self::Listen | Self::SynSent | Self::SynRcvd)
    }
}

impl std::fmt::Display for TcpConnectionState {
    /// Writes [`name`](Self::name), honouring width/fill/alignment/precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) applies the caller's format spec, so
        // `{:<12}` can align state columns in tables and logs.
        f.pad(self.name())
    }
}

/// Sequence, acknowledgement and window bookkeeping for one side of a TCP
/// conversation.
#[derive(Clone, Debug)]
pub struct TcpEndpointState {
    /// Sequence number just past the latest segment this endpoint sent
    /// (SYN and FIN each occupy one sequence number).
    pub next_expected_seq: u32,
    /// Acknowledgement number most recently sent by this endpoint.
    pub last_ack: u32,
    /// Receive window most recently advertised by this endpoint (raw 16-bit
    /// header value; no window scaling is applied).
    pub window_size: u16,
    /// Initial sequence number, recorded from this endpoint's SYN / SYN-ACK.
    pub initial_seq: Option<u32>,
}

impl TcpEndpointState {
    /// Creates an endpoint with all counters at zero and no ISN recorded.
    #[must_use]
    pub fn new() -> Self {
        Self {
            initial_seq: None,
            window_size: 0,
            last_ack: 0,
            next_expected_seq: 0,
        }
    }
}

impl Default for TcpEndpointState {
    /// Equivalent to [`TcpEndpointState::new`].
    fn default() -> Self {
        TcpEndpointState::new()
    }
}

93/// Complete TCP conversation state including connection tracking,
94/// per-endpoint sequence state, and stream reassembly.
95#[derive(Debug)]
96pub struct TcpConversationState {
97    /// Current connection state (RFC 793 state machine).
98    pub conn_state: TcpConnectionState,
99    /// Sequence tracking for the forward direction (`addr_a` → `addr_b`).
100    pub forward_endpoint: TcpEndpointState,
101    /// Sequence tracking for the reverse direction (`addr_b` → `addr_a`).
102    pub reverse_endpoint: TcpEndpointState,
103    /// Stream reassembly for forward direction.
104    pub reassembler_fwd: TcpReassembler,
105    /// Stream reassembly for reverse direction.
106    pub reassembler_rev: TcpReassembler,
107}
108
109impl TcpConversationState {
110    #[must_use]
111    pub fn new() -> Self {
112        Self {
113            conn_state: TcpConnectionState::Listen,
114            forward_endpoint: TcpEndpointState::new(),
115            reverse_endpoint: TcpEndpointState::new(),
116            reassembler_fwd: TcpReassembler::new(),
117            reassembler_rev: TcpReassembler::new(),
118        }
119    }
120
121    /// Process a TCP packet, updating connection state and reassembly buffers.
122    ///
123    /// `direction` indicates whether this packet is Forward (`addr_a` → `addr_b`)
124    /// or Reverse (`addr_b` → `addr_a`) relative to the canonical key.
125    /// `tcp` is the TCP layer view, `buf` is the full packet buffer.
126    pub fn process_packet(
127        &mut self,
128        direction: FlowDirection,
129        tcp: &TcpLayer,
130        buf: &[u8],
131        config: &FlowConfig,
132    ) -> Result<(), FlowError> {
133        let flags = tcp
134            .flags(buf)
135            .map_err(|e| FlowError::PacketError(e.into()))?;
136        let seq = tcp.seq(buf).map_err(|e| FlowError::PacketError(e.into()))?;
137        let ack = tcp.ack(buf).map_err(|e| FlowError::PacketError(e.into()))?;
138        let window = tcp
139            .window(buf)
140            .map_err(|e| FlowError::PacketError(e.into()))?;
141
142        // Determine payload boundaries
143        let data_offset = tcp
144            .data_offset(buf)
145            .map_err(|e| FlowError::PacketError(e.into()))?;
146        let header_bytes = (data_offset as usize) * 4;
147        let tcp_start = tcp.index.start;
148        // TCP payload starts after the TCP header. Since the TCP layer's
149        // index.end marks the header boundary, the payload is everything
150        // from header end to the end of the packet buffer.
151        let payload_start = tcp_start + header_bytes;
152        let payload = if payload_start < buf.len() {
153            &buf[payload_start..buf.len()]
154        } else {
155            &[]
156        };
157
158        // Get mutable refs to endpoint and reassembler for this direction
159        let (sender, _receiver, reassembler) = match direction {
160            FlowDirection::Forward => (
161                &mut self.forward_endpoint,
162                &mut self.reverse_endpoint,
163                &mut self.reassembler_fwd,
164            ),
165            FlowDirection::Reverse => (
166                &mut self.reverse_endpoint,
167                &mut self.forward_endpoint,
168                &mut self.reassembler_rev,
169            ),
170        };
171
172        // Update endpoint state
173        sender.window_size = window;
174
175        // State machine transitions
176        if flags.rst {
177            self.conn_state = TcpConnectionState::Closed;
178            return Ok(());
179        }
180
181        match self.conn_state {
182            TcpConnectionState::Listen => {
183                if flags.syn && !flags.ack {
184                    // SYN from initiator
185                    sender.initial_seq = Some(seq);
186                    sender.next_expected_seq = seq.wrapping_add(1); // SYN consumes 1 seq
187                    self.conn_state = TcpConnectionState::SynSent;
188                }
189            },
190            TcpConnectionState::SynSent => {
191                if flags.syn && flags.ack {
192                    // SYN-ACK from responder
193                    sender.initial_seq = Some(seq);
194                    sender.next_expected_seq = seq.wrapping_add(1);
195                    sender.last_ack = ack;
196                    self.conn_state = TcpConnectionState::SynRcvd;
197                }
198            },
199            TcpConnectionState::SynRcvd => {
200                if flags.ack && !flags.syn {
201                    // Final ACK of 3-way handshake
202                    sender.last_ack = ack;
203                    self.conn_state = TcpConnectionState::Established;
204                    // Initialize reassemblers with ISN+1 (after SYN)
205                    if !self.reassembler_fwd.is_initialized()
206                        && let Some(isn) = self.forward_endpoint.initial_seq
207                    {
208                        self.reassembler_fwd.initialize(isn.wrapping_add(1));
209                    }
210                    if !self.reassembler_rev.is_initialized()
211                        && let Some(isn) = self.reverse_endpoint.initial_seq
212                    {
213                        self.reassembler_rev.initialize(isn.wrapping_add(1));
214                    }
215                }
216            },
217            TcpConnectionState::Established => {
218                sender.last_ack = ack;
219
220                // Process payload through reassembler
221                if !payload.is_empty() {
222                    // Ignore reassembly errors (buffer full, etc.) — they don't
223                    // affect connection state tracking
224                    let _ = reassembler.process_segment(seq, payload, config);
225                }
226
227                if flags.fin {
228                    sender.next_expected_seq =
229                        seq.wrapping_add(payload.len() as u32).wrapping_add(1); // FIN consumes 1 seq
230                    match direction {
231                        FlowDirection::Forward => {
232                            self.conn_state = TcpConnectionState::FinWait1;
233                        },
234                        FlowDirection::Reverse => {
235                            self.conn_state = TcpConnectionState::CloseWait;
236                        },
237                    }
238                } else {
239                    sender.next_expected_seq = seq.wrapping_add(payload.len() as u32);
240                }
241            },
242            TcpConnectionState::FinWait1 => {
243                if flags.fin && flags.ack {
244                    // Simultaneous close
245                    self.conn_state = TcpConnectionState::TimeWait;
246                } else if flags.ack {
247                    self.conn_state = TcpConnectionState::FinWait2;
248                } else if flags.fin {
249                    self.conn_state = TcpConnectionState::Closing;
250                }
251            },
252            TcpConnectionState::FinWait2 => {
253                if flags.fin {
254                    self.conn_state = TcpConnectionState::TimeWait;
255                }
256            },
257            TcpConnectionState::CloseWait => {
258                if flags.fin {
259                    self.conn_state = TcpConnectionState::LastAck;
260                }
261            },
262            TcpConnectionState::Closing => {
263                if flags.ack {
264                    self.conn_state = TcpConnectionState::TimeWait;
265                }
266            },
267            TcpConnectionState::LastAck => {
268                if flags.ack {
269                    self.conn_state = TcpConnectionState::Closed;
270                }
271            },
272            TcpConnectionState::TimeWait | TcpConnectionState::Closed => {
273                // Terminal states — no further transitions
274            },
275        }
276
277        Ok(())
278    }
279}
280
281impl Default for TcpConversationState {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::stack::{LayerStack, LayerStackEntry};
    use crate::{EthernetBuilder, Ipv4Builder, TcpBuilder};

    /// Builds an Ethernet/IPv4/TCP packet from 10.0.0.1 to 10.0.0.2.
    ///
    /// Each character of `flags` sets one TCP flag (`S`, `A`, `F`, `R`, `P`);
    /// other characters are ignored. A non-empty `payload` is appended as a
    /// raw layer after the TCP header.
    fn make_tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack_num: u32,
        flags: &str,
        payload: &[u8],
    ) -> crate::Packet {
        let base = TcpBuilder::new()
            .src_port(src_port)
            .dst_port(dst_port)
            .seq(seq)
            .ack_num(ack_num)
            .window(65535);
        let tcp = flags.chars().fold(base, |b, flag| match flag {
            'S' => b.syn(),
            'A' => b.ack(),
            'F' => b.fin(),
            'R' => b.rst(),
            'P' => b.psh(),
            _ => b,
        });

        let ethernet = EthernetBuilder::new()
            .dst(crate::MacAddress::BROADCAST)
            .src(crate::MacAddress::new([0, 1, 2, 3, 4, 5]));
        let ipv4 = Ipv4Builder::new()
            .src(std::net::Ipv4Addr::new(10, 0, 0, 1))
            .dst(std::net::Ipv4Addr::new(10, 0, 0, 2));

        let headers = LayerStack::new()
            .push(LayerStackEntry::Ethernet(ethernet))
            .push(LayerStackEntry::Ipv4(ipv4))
            .push(LayerStackEntry::Tcp(tcp));

        let full = if payload.is_empty() {
            headers
        } else {
            headers.push(LayerStackEntry::Raw(payload.to_vec()))
        };
        full.build_packet()
    }

    /// Splits a built packet into its TCP layer view and raw bytes.
    fn get_tcp_and_buf(pkt: &crate::Packet) -> (TcpLayer, &[u8]) {
        let tcp = pkt.tcp().unwrap();
        let buf = pkt.as_bytes();
        (tcp, buf)
    }

    /// Feeds one packet into `state`, panicking if processing fails.
    fn feed(
        state: &mut TcpConversationState,
        direction: FlowDirection,
        pkt: &crate::Packet,
        config: &FlowConfig,
    ) {
        let (tcp, buf) = get_tcp_and_buf(pkt);
        state
            .process_packet(direction, &tcp, buf, config)
            .unwrap();
    }

    #[test]
    fn test_three_way_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();

        // (direction, packet, state expected after processing it)
        let steps = [
            // SYN (client → server, forward)
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 0, "S", &[]),
                TcpConnectionState::SynSent,
            ),
            // SYN-ACK (server → client, reverse)
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "SA", &[]),
                TcpConnectionState::SynRcvd,
            ),
            // ACK (client → server, forward)
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1001, 2001, "A", &[]),
                TcpConnectionState::Established,
            ),
        ];

        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_rst_closes_connection() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        let reset = make_tcp_packet(12345, 80, 1000, 0, "R", &[]);
        feed(&mut state, FlowDirection::Forward, &reset, &config);

        assert_eq!(state.conn_state, TcpConnectionState::Closed);
    }

    #[test]
    fn test_fin_handshake() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // (direction, packet, state expected after processing it)
        let steps = [
            // FIN from forward direction
            (
                FlowDirection::Forward,
                make_tcp_packet(12345, 80, 1000, 2000, "FA", &[]),
                TcpConnectionState::FinWait1,
            ),
            // ACK of FIN from reverse
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "A", &[]),
                TcpConnectionState::FinWait2,
            ),
            // FIN from reverse
            (
                FlowDirection::Reverse,
                make_tcp_packet(80, 12345, 2000, 1001, "FA", &[]),
                TcpConnectionState::TimeWait,
            ),
        ];

        for (direction, pkt, expected) in steps {
            feed(&mut state, direction, &pkt, &config);
            assert_eq!(state.conn_state, expected);
        }
    }

    #[test]
    fn test_data_transfer_and_reassembly() {
        let config = FlowConfig::default();
        let mut state = TcpConversationState::new();
        state.conn_state = TcpConnectionState::Established;

        // Pretend the forward SYN used ISN 999, so data starts at 1000.
        state.forward_endpoint.initial_seq = Some(999);
        state.reassembler_fwd.initialize(1000);

        let request = make_tcp_packet(12345, 80, 1000, 2000, "A", b"GET /");
        feed(&mut state, FlowDirection::Forward, &request, &config);

        assert_eq!(state.reassembler_fwd.reassembled_data(), b"GET /");
    }

    #[test]
    fn test_state_display() {
        assert_eq!(TcpConnectionState::Established.name(), "ESTABLISHED");
        assert_eq!(TcpConnectionState::SynSent.name(), "SYN_SENT");
        assert!(TcpConnectionState::Closed.is_closed());
        assert!(TcpConnectionState::TimeWait.is_closed());
        assert!(TcpConnectionState::SynSent.is_half_open());
        assert!(!TcpConnectionState::Established.is_half_open());
    }
}
447        assert!(TcpConnectionState::TimeWait.is_closed());
448        assert!(TcpConnectionState::SynSent.is_half_open());
449        assert!(!TcpConnectionState::Established.is_half_open());
450    }
451}