Skip to main content

stackforge_core/layer/tcp/
flags.rs

1//! TCP flags implementation.
2//!
3//! TCP flags are a 9-bit field in the TCP header (including NS, CWR, ECE).
4//! This module provides a structured representation matching Scapy's `FlagsField`.
5
6use std::fmt;
7
/// Decoded TCP control flags (nine bits).
///
/// Flags in header order, from the most significant to the least significant
/// bit of the flags portion of the 16-bit data-offset/reserved/flags word:
///
/// - NS  — ECN-nonce concealment protection (RFC 3540)
/// - CWR — Congestion Window Reduced (RFC 3168)
/// - ECE — ECN-Echo (RFC 3168)
/// - URG — Urgent pointer field significant
/// - ACK — Acknowledgment field significant
/// - PSH — Push function
/// - RST — Reset the connection
/// - SYN — Synchronize sequence numbers
/// - FIN — No more data from sender
///
/// Scapy names the same flags with the letters `"FSRPAUECN"`, which is the
/// list above read bottom-up (least significant bit first).
#[derive(Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TcpFlags {
    /// FIN: the sender has no more data.
    pub fin: bool,
    /// SYN: synchronize sequence numbers.
    pub syn: bool,
    /// RST: reset the connection.
    pub rst: bool,
    /// PSH: push function.
    pub psh: bool,
    /// ACK: the acknowledgment field is significant.
    pub ack: bool,
    /// URG: the urgent pointer field is significant.
    pub urg: bool,
    /// ECE: ECN-Echo (RFC 3168).
    pub ece: bool,
    /// CWR: Congestion Window Reduced (RFC 3168).
    pub cwr: bool,
    /// NS: ECN-nonce concealment protection (RFC 3540).
    pub ns: bool,
}
43
44impl TcpFlags {
45    /// No flags set
46    pub const NONE: Self = Self {
47        fin: false,
48        syn: false,
49        rst: false,
50        psh: false,
51        ack: false,
52        urg: false,
53        ece: false,
54        cwr: false,
55        ns: false,
56    };
57
58    /// SYN flag only (connection initiation)
59    pub const S: Self = Self {
60        fin: false,
61        syn: true,
62        rst: false,
63        psh: false,
64        ack: false,
65        urg: false,
66        ece: false,
67        cwr: false,
68        ns: false,
69    };
70
71    /// SYN+ACK flags (connection acknowledgment)
72    pub const SA: Self = Self {
73        fin: false,
74        syn: true,
75        rst: false,
76        psh: false,
77        ack: true,
78        urg: false,
79        ece: false,
80        cwr: false,
81        ns: false,
82    };
83
84    /// ACK flag only
85    pub const A: Self = Self {
86        fin: false,
87        syn: false,
88        rst: false,
89        psh: false,
90        ack: true,
91        urg: false,
92        ece: false,
93        cwr: false,
94        ns: false,
95    };
96
97    /// FIN+ACK flags (connection termination)
98    pub const FA: Self = Self {
99        fin: true,
100        syn: false,
101        rst: false,
102        psh: false,
103        ack: true,
104        urg: false,
105        ece: false,
106        cwr: false,
107        ns: false,
108    };
109
110    /// RST flag only (connection reset)
111    pub const R: Self = Self {
112        fin: false,
113        syn: false,
114        rst: true,
115        psh: false,
116        ack: false,
117        urg: false,
118        ece: false,
119        cwr: false,
120        ns: false,
121    };
122
123    /// RST+ACK flags
124    pub const RA: Self = Self {
125        fin: false,
126        syn: false,
127        rst: true,
128        psh: false,
129        ack: true,
130        urg: false,
131        ece: false,
132        cwr: false,
133        ns: false,
134    };
135
136    /// PSH+ACK flags (data push)
137    pub const PA: Self = Self {
138        fin: false,
139        syn: false,
140        rst: false,
141        psh: true,
142        ack: true,
143        urg: false,
144        ece: false,
145        cwr: false,
146        ns: false,
147    };
148
149    /// Flag bit positions (in the 9-bit flags field)
150    pub const FIN_BIT: u16 = 0x001;
151    pub const SYN_BIT: u16 = 0x002;
152    pub const RST_BIT: u16 = 0x004;
153    pub const PSH_BIT: u16 = 0x008;
154    pub const ACK_BIT: u16 = 0x010;
155    pub const URG_BIT: u16 = 0x020;
156    pub const ECE_BIT: u16 = 0x040;
157    pub const CWR_BIT: u16 = 0x080;
158    pub const NS_BIT: u16 = 0x100;
159
160    /// Create flags from a raw 16-bit value (data offset + reserved + flags).
161    ///
162    /// The flags are in the lower 9 bits (with NS in bit 8 of the high byte).
163    #[inline]
164    #[must_use]
165    pub fn from_u16(value: u16) -> Self {
166        Self {
167            fin: (value & Self::FIN_BIT) != 0,
168            syn: (value & Self::SYN_BIT) != 0,
169            rst: (value & Self::RST_BIT) != 0,
170            psh: (value & Self::PSH_BIT) != 0,
171            ack: (value & Self::ACK_BIT) != 0,
172            urg: (value & Self::URG_BIT) != 0,
173            ece: (value & Self::ECE_BIT) != 0,
174            cwr: (value & Self::CWR_BIT) != 0,
175            ns: (value & Self::NS_BIT) != 0,
176        }
177    }
178
179    /// Create flags from just the flags byte (lower 8 bits).
180    #[inline]
181    #[must_use]
182    pub fn from_byte(byte: u8) -> Self {
183        Self::from_u16(u16::from(byte))
184    }
185
186    /// Create flags from two bytes (`data_offset_reserved` + flags).
187    #[inline]
188    #[must_use]
189    pub fn from_bytes(hi: u8, lo: u8) -> Self {
190        let ns = (hi & 0x01) != 0;
191        let mut flags = Self::from_byte(lo);
192        flags.ns = ns;
193        flags
194    }
195
196    /// Convert to a raw 9-bit value.
197    #[inline]
198    #[must_use]
199    pub fn to_u16(self) -> u16 {
200        let mut value = 0u16;
201        if self.fin {
202            value |= Self::FIN_BIT;
203        }
204        if self.syn {
205            value |= Self::SYN_BIT;
206        }
207        if self.rst {
208            value |= Self::RST_BIT;
209        }
210        if self.psh {
211            value |= Self::PSH_BIT;
212        }
213        if self.ack {
214            value |= Self::ACK_BIT;
215        }
216        if self.urg {
217            value |= Self::URG_BIT;
218        }
219        if self.ece {
220            value |= Self::ECE_BIT;
221        }
222        if self.cwr {
223            value |= Self::CWR_BIT;
224        }
225        if self.ns {
226            value |= Self::NS_BIT;
227        }
228        value
229    }
230
231    /// Convert to the lower flags byte (without NS).
232    #[inline]
233    #[must_use]
234    pub fn to_byte(self) -> u8 {
235        (self.to_u16() & 0xFF) as u8
236    }
237
238    /// Get the NS bit for the high byte.
239    #[inline]
240    #[must_use]
241    pub fn ns_bit(self) -> u8 {
242        u8::from(self.ns)
243    }
244
245    /// Create flags from a string like "S", "SA", "FA", "PA", "R", etc.
246    /// Uses Scapy's "FSRPAUECN" convention.
247    #[must_use]
248    pub fn from_str(s: &str) -> Self {
249        let mut flags = Self::NONE;
250        for c in s.chars() {
251            match c {
252                'F' | 'f' => flags.fin = true,
253                'S' | 's' => flags.syn = true,
254                'R' | 'r' => flags.rst = true,
255                'P' | 'p' => flags.psh = true,
256                'A' | 'a' => flags.ack = true,
257                'U' | 'u' => flags.urg = true,
258                'E' | 'e' => flags.ece = true,
259                'C' | 'c' => flags.cwr = true,
260                'N' | 'n' => flags.ns = true,
261                _ => {}, // Ignore unknown characters
262            }
263        }
264        flags
265    }
266
267    /// Check if this is a SYN packet (SYN set, ACK not set).
268    #[inline]
269    #[must_use]
270    pub fn is_syn(&self) -> bool {
271        self.syn && !self.ack
272    }
273
274    /// Check if this is a SYN-ACK packet.
275    #[inline]
276    #[must_use]
277    pub fn is_syn_ack(&self) -> bool {
278        self.syn && self.ack
279    }
280
281    /// Check if this is a pure ACK packet.
282    #[inline]
283    #[must_use]
284    pub fn is_ack(&self) -> bool {
285        self.ack && !self.syn && !self.fin && !self.rst
286    }
287
288    /// Check if this is a FIN packet.
289    #[inline]
290    #[must_use]
291    pub fn is_fin(&self) -> bool {
292        self.fin
293    }
294
295    /// Check if this is a RST packet.
296    #[inline]
297    #[must_use]
298    pub fn is_rst(&self) -> bool {
299        self.rst
300    }
301
302    /// Check if ECN is enabled (ECE or CWR set).
303    #[inline]
304    #[must_use]
305    pub fn has_ecn(&self) -> bool {
306        self.ece || self.cwr
307    }
308
309    /// Check if any flag is set.
310    #[inline]
311    #[must_use]
312    pub fn is_empty(&self) -> bool {
313        !self.fin
314            && !self.syn
315            && !self.rst
316            && !self.psh
317            && !self.ack
318            && !self.urg
319            && !self.ece
320            && !self.cwr
321            && !self.ns
322    }
323
324    /// Count how many flags are set.
325    #[inline]
326    #[must_use]
327    pub fn count(&self) -> u8 {
328        let mut count = 0;
329        if self.fin {
330            count += 1;
331        }
332        if self.syn {
333            count += 1;
334        }
335        if self.rst {
336            count += 1;
337        }
338        if self.psh {
339            count += 1;
340        }
341        if self.ack {
342            count += 1;
343        }
344        if self.urg {
345            count += 1;
346        }
347        if self.ece {
348            count += 1;
349        }
350        if self.cwr {
351            count += 1;
352        }
353        if self.ns {
354            count += 1;
355        }
356        count
357    }
358}
359
360impl fmt::Display for TcpFlags {
361    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362        // Scapy order: FSRPAUECN
363        let mut s = String::with_capacity(9);
364        if self.fin {
365            s.push('F');
366        }
367        if self.syn {
368            s.push('S');
369        }
370        if self.rst {
371            s.push('R');
372        }
373        if self.psh {
374            s.push('P');
375        }
376        if self.ack {
377            s.push('A');
378        }
379        if self.urg {
380            s.push('U');
381        }
382        if self.ece {
383            s.push('E');
384        }
385        if self.cwr {
386            s.push('C');
387        }
388        if self.ns {
389            s.push('N');
390        }
391
392        if s.is_empty() {
393            write!(f, "-")
394        } else {
395            write!(f, "{s}")
396        }
397    }
398}
399
400impl fmt::Debug for TcpFlags {
401    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402        write!(f, "TcpFlags({self})")
403    }
404}
405
406impl From<u16> for TcpFlags {
407    fn from(value: u16) -> Self {
408        Self::from_u16(value)
409    }
410}
411
412impl From<u8> for TcpFlags {
413    fn from(value: u8) -> Self {
414        Self::from_byte(value)
415    }
416}
417
418impl From<TcpFlags> for u16 {
419    fn from(flags: TcpFlags) -> Self {
420        flags.to_u16()
421    }
422}
423
424impl From<TcpFlags> for u8 {
425    fn from(flags: TcpFlags) -> Self {
426        flags.to_byte()
427    }
428}
429
430impl From<&str> for TcpFlags {
431    fn from(s: &str) -> Self {
432        Self::from_str(s)
433    }
434}
435
436impl std::ops::BitOr for TcpFlags {
437    type Output = Self;
438
439    fn bitor(self, rhs: Self) -> Self::Output {
440        Self {
441            fin: self.fin || rhs.fin,
442            syn: self.syn || rhs.syn,
443            rst: self.rst || rhs.rst,
444            psh: self.psh || rhs.psh,
445            ack: self.ack || rhs.ack,
446            urg: self.urg || rhs.urg,
447            ece: self.ece || rhs.ece,
448            cwr: self.cwr || rhs.cwr,
449            ns: self.ns || rhs.ns,
450        }
451    }
452}
453
454impl std::ops::BitAnd for TcpFlags {
455    type Output = Self;
456
457    fn bitand(self, rhs: Self) -> Self::Output {
458        Self {
459            fin: self.fin && rhs.fin,
460            syn: self.syn && rhs.syn,
461            rst: self.rst && rhs.rst,
462            psh: self.psh && rhs.psh,
463            ack: self.ack && rhs.ack,
464            urg: self.urg && rhs.urg,
465            ece: self.ece && rhs.ece,
466            cwr: self.cwr && rhs.cwr,
467            ns: self.ns && rhs.ns,
468        }
469    }
470}
471
472impl std::ops::BitOrAssign for TcpFlags {
473    fn bitor_assign(&mut self, rhs: Self) {
474        *self = *self | rhs;
475    }
476}
477
478impl std::ops::BitAndAssign for TcpFlags {
479    fn bitand_assign(&mut self, rhs: Self) {
480        *self = *self & rhs;
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_u16() {
        let flags = TcpFlags::from_u16(0x02); // SYN
        assert!(flags.syn);
        assert!(!flags.ack);
        assert!(!flags.fin);

        let flags = TcpFlags::from_u16(0x12); // SYN+ACK
        assert!(flags.syn);
        assert!(flags.ack);

        let flags = TcpFlags::from_u16(0x10); // ACK
        assert!(!flags.syn);
        assert!(flags.ack);

        let flags = TcpFlags::from_u16(0x11); // FIN+ACK
        assert!(flags.fin);
        assert!(flags.ack);

        let flags = TcpFlags::from_u16(0x04); // RST
        assert!(flags.rst);

        let flags = TcpFlags::from_u16(0x100); // NS
        assert!(flags.ns);
    }

    #[test]
    fn test_from_u16_ignores_non_flag_bits() {
        // Whole bytes 12-13 word: data offset 5, NS clear, SYN+ACK.
        assert_eq!(TcpFlags::from_u16(0x5012), TcpFlags::SA);
        // Same word with NS (bit 0 of byte 12) set.
        let flags = TcpFlags::from_u16(0x5112);
        assert!(flags.ns && flags.syn && flags.ack);
        assert_eq!(flags.to_u16(), 0x112);
    }

    #[test]
    fn test_to_u16() {
        assert_eq!(TcpFlags::S.to_u16(), 0x02);
        assert_eq!(TcpFlags::SA.to_u16(), 0x12);
        assert_eq!(TcpFlags::A.to_u16(), 0x10);
        assert_eq!(TcpFlags::FA.to_u16(), 0x11);
        assert_eq!(TcpFlags::R.to_u16(), 0x04);
    }

    #[test]
    fn test_from_str() {
        let flags = TcpFlags::from_str("S");
        assert!(flags.syn);
        assert!(!flags.ack);

        let flags = TcpFlags::from_str("SA");
        assert!(flags.syn);
        assert!(flags.ack);

        let flags = TcpFlags::from_str("FSRPAUECN");
        assert!(flags.fin);
        assert!(flags.syn);
        assert!(flags.rst);
        assert!(flags.psh);
        assert!(flags.ack);
        assert!(flags.urg);
        assert!(flags.ece);
        assert!(flags.cwr);
        assert!(flags.ns);
    }

    #[test]
    fn test_from_str_ignores_unknown_characters() {
        assert_eq!(TcpFlags::from_str("S-A?"), TcpFlags::SA);
        assert_eq!(TcpFlags::from_str("sa"), TcpFlags::SA);
        assert!(TcpFlags::from_str("").is_empty());
        assert!(TcpFlags::from_str("xyz").is_empty());
    }

    #[test]
    fn test_display() {
        assert_eq!(TcpFlags::S.to_string(), "S");
        assert_eq!(TcpFlags::SA.to_string(), "SA");
        assert_eq!(TcpFlags::FA.to_string(), "FA");
        assert_eq!(TcpFlags::PA.to_string(), "PA");
        assert_eq!(TcpFlags::NONE.to_string(), "-");
    }

    #[test]
    fn test_debug() {
        assert_eq!(format!("{:?}", TcpFlags::SA), "TcpFlags(SA)");
        assert_eq!(format!("{:?}", TcpFlags::NONE), "TcpFlags(-)");
    }

    #[test]
    fn test_is_methods() {
        assert!(TcpFlags::S.is_syn());
        assert!(!TcpFlags::SA.is_syn()); // SYN-ACK is not "just SYN"
        assert!(TcpFlags::SA.is_syn_ack());
        assert!(TcpFlags::A.is_ack());
        assert!(!TcpFlags::SA.is_ack());
        assert!(!TcpFlags::FA.is_ack());
        assert!(!TcpFlags::RA.is_ack());
        assert!(TcpFlags::FA.is_fin());
        assert!(TcpFlags::R.is_rst());
    }

    #[test]
    fn test_has_ecn() {
        assert!(TcpFlags::from_str("E").has_ecn());
        assert!(TcpFlags::from_str("C").has_ecn());
        assert!(!TcpFlags::SA.has_ecn());
        assert!(!TcpFlags::from_str("N").has_ecn());
    }

    #[test]
    fn test_is_empty_and_count() {
        assert!(TcpFlags::NONE.is_empty());
        assert_eq!(TcpFlags::NONE.count(), 0);
        assert!(!TcpFlags::S.is_empty());
        assert_eq!(TcpFlags::S.count(), 1);
        assert_eq!(TcpFlags::PA.count(), 2);
        // NS alone still counts as a set flag.
        assert!(!TcpFlags::from_u16(TcpFlags::NS_BIT).is_empty());

        let all = TcpFlags::from_str("FSRPAUECN");
        assert!(!all.is_empty());
        assert_eq!(all.count(), 9);
        assert_eq!(all.to_u16(), 0x1FF);
    }

    #[test]
    fn test_bit_ops() {
        let flags = TcpFlags::S | TcpFlags::A;
        assert!(flags.syn);
        assert!(flags.ack);

        let flags = TcpFlags::SA & TcpFlags::S;
        assert!(flags.syn);
        assert!(!flags.ack);
    }

    #[test]
    fn test_assign_ops() {
        let mut flags = TcpFlags::S;
        flags |= TcpFlags::A;
        assert_eq!(flags, TcpFlags::SA);
        flags &= TcpFlags::A;
        assert_eq!(flags, TcpFlags::A);
    }

    #[test]
    fn test_from_bytes() {
        // Data offset (5) + reserved (000) + NS (0) + flags (0x12 = SYN+ACK)
        // Byte 12: 0101_0000 = 0x50 (data offset 5, NS=0)
        // Byte 13: 0001_0010 = 0x12 (SYN+ACK)
        let flags = TcpFlags::from_bytes(0x50, 0x12);
        assert!(flags.syn);
        assert!(flags.ack);
        assert!(!flags.ns);

        // With NS bit set (bit 0 of byte 12)
        let flags = TcpFlags::from_bytes(0x51, 0x12);
        assert!(flags.syn);
        assert!(flags.ack);
        assert!(flags.ns);
    }

    #[test]
    fn test_byte_split_round_trip() {
        let flags = TcpFlags::from_bytes(0x51, 0x12);
        assert_eq!(flags.to_byte(), 0x12);
        assert_eq!(flags.ns_bit(), 1);
        assert_eq!(TcpFlags::from_bytes(0x50 | flags.ns_bit(), flags.to_byte()), flags);
        // The flags byte alone can never carry NS.
        assert!(!TcpFlags::from_byte(0xFF).ns);

        assert_eq!(u8::from(TcpFlags::SA), 0x12);
        assert_eq!(u16::from(flags), 0x112);
        assert_eq!(TcpFlags::from(0x12u8), TcpFlags::SA);
        assert_eq!(TcpFlags::from(0x112u16), flags);
        assert_eq!(TcpFlags::from("SA"), TcpFlags::SA);
    }

    #[test]
    fn test_constants() {
        assert_eq!(TcpFlags::NONE.to_u16(), 0);
        assert_eq!(TcpFlags::S.to_u16(), 0x002);
        assert_eq!(TcpFlags::SA.to_u16(), 0x012);
        assert_eq!(TcpFlags::A.to_u16(), 0x010);
        assert_eq!(TcpFlags::FA.to_u16(), 0x011);
        assert_eq!(TcpFlags::R.to_u16(), 0x004);
        assert_eq!(TcpFlags::RA.to_u16(), 0x014);
        assert_eq!(TcpFlags::PA.to_u16(), 0x018);
    }
}