Skip to main content

stackforge_core/layer/ipv4/
header.rs

1//! IPv4 header layer implementation.
2//!
3//! Provides zero-copy access to IPv4 header fields.
4
5use std::net::Ipv4Addr;
6
7use crate::layer::field::{Field, FieldDesc, FieldError, FieldType, FieldValue};
8use crate::layer::{Layer, LayerIndex, LayerKind};
9
10use super::checksum::ipv4_checksum;
11use super::options::{Ipv4Options, parse_options};
12use super::protocol;
13use super::routing::Ipv4Route;
14
/// Length in bytes of an IPv4 header without options (IHL = 5 32-bit words).
pub const IPV4_MIN_HEADER_LEN: usize = 5 * 4;

/// Length in bytes of the largest possible IPv4 header (IHL = 15 32-bit words).
pub const IPV4_MAX_HEADER_LEN: usize = 15 * 4;

/// Largest value the 16-bit total-length field can express.
pub const IPV4_MAX_PACKET_LEN: usize = u16::MAX as usize;
23
/// Byte offsets of each field, relative to the first byte of the IPv4 header.
///
/// Every offset is expressed as the previous field's offset plus that
/// field's width, so the layout reads top-to-bottom like the RFC 791 diagram.
pub mod offsets {
    /// Byte 0: version (high nibble) and IHL (low nibble).
    pub const VERSION_IHL: usize = 0;
    /// Byte 1: TOS byte, i.e. DSCP (6 bits) followed by ECN (2 bits).
    pub const TOS: usize = VERSION_IHL + 1;
    /// Bytes 2-3: total length.
    pub const TOTAL_LEN: usize = TOS + 1;
    /// Bytes 4-5: identification.
    pub const ID: usize = TOTAL_LEN + 2;
    /// Bytes 6-7: flags (3 bits) followed by the fragment offset (13 bits).
    pub const FLAGS_FRAG: usize = ID + 2;
    /// Byte 8: time to live.
    pub const TTL: usize = FLAGS_FRAG + 2;
    /// Byte 9: protocol number of the payload.
    pub const PROTOCOL: usize = TTL + 1;
    /// Bytes 10-11: header checksum.
    pub const CHECKSUM: usize = PROTOCOL + 1;
    /// Bytes 12-15: source address.
    pub const SRC: usize = CHECKSUM + 2;
    /// Bytes 16-19: destination address.
    pub const DST: usize = SRC + 4;
    /// First options byte; only meaningful when IHL > 5.
    pub const OPTIONS: usize = DST + 4;
}
49
/// Field descriptors for dynamic access.
///
/// Each entry gives a field name, its byte offset within the header, the number
/// of bytes it spans, and its value type. The names match the primary spellings
/// accepted by `Ipv4Layer::get_field`.
///
/// NOTE(review): several IPv4 fields are bit-packed, and their descriptors
/// describe the whole containing byte/word rather than the sub-field:
/// `version`/`ihl` both cover byte 0, `tos`/`dscp`/`ecn` all cover byte 1,
/// `flags` covers byte 6 and `frag` covers the 16-bit word at byte 6 including
/// the flag bits. The masking/shifting is done by `Ipv4Layer`'s typed accessors;
/// confirm that any consumer reading raw values through these descriptors
/// applies the same masks.
pub static FIELDS: &[FieldDesc] = &[
    // Byte 0: bit-packed version (high nibble) / IHL (low nibble).
    FieldDesc::new("version", offsets::VERSION_IHL, 1, FieldType::U8),
    FieldDesc::new("ihl", offsets::VERSION_IHL, 1, FieldType::U8),
    // Byte 1: full TOS byte, then its DSCP (high 6 bits) / ECN (low 2 bits) parts.
    FieldDesc::new("tos", offsets::TOS, 1, FieldType::U8),
    FieldDesc::new("dscp", offsets::TOS, 1, FieldType::U8),
    FieldDesc::new("ecn", offsets::TOS, 1, FieldType::U8),
    FieldDesc::new("len", offsets::TOTAL_LEN, 2, FieldType::U16),
    FieldDesc::new("id", offsets::ID, 2, FieldType::U16),
    // Bytes 6-7: flags live in the top 3 bits, fragment offset in the low 13.
    FieldDesc::new("flags", offsets::FLAGS_FRAG, 1, FieldType::U8),
    FieldDesc::new("frag", offsets::FLAGS_FRAG, 2, FieldType::U16),
    FieldDesc::new("ttl", offsets::TTL, 1, FieldType::U8),
    FieldDesc::new("proto", offsets::PROTOCOL, 1, FieldType::U8),
    FieldDesc::new("chksum", offsets::CHECKSUM, 2, FieldType::U16),
    FieldDesc::new("src", offsets::SRC, 4, FieldType::Ipv4),
    FieldDesc::new("dst", offsets::DST, 4, FieldType::Ipv4),
];
67
/// The three flag bits at the top of the IPv4 flags/fragment-offset field.
///
/// In the first byte of that field they occupy bit 7 (reserved), bit 6 (DF)
/// and bit 5 (MF); the remaining 5 bits belong to the fragment offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Ipv4Flags {
    /// Reserved/"evil" bit (should be 0).
    pub reserved: bool,
    /// Don't Fragment.
    pub df: bool,
    /// More Fragments.
    pub mf: bool,
}

impl Ipv4Flags {
    /// No flag bits set.
    pub const NONE: Self = Self {
        reserved: false,
        df: false,
        mf: false,
    };

    /// Only Don't Fragment set.
    pub const DF: Self = Self {
        reserved: false,
        df: true,
        mf: false,
    };

    /// Only More Fragments set.
    pub const MF: Self = Self {
        reserved: false,
        df: false,
        mf: true,
    };

    /// Decode flags from the upper 3 bits of `byte`; the low 5 bits (part of
    /// the fragment offset) are ignored.
    ///
    /// `const` so flag constants can be built from wire bytes at compile time.
    #[inline]
    #[must_use]
    pub const fn from_byte(byte: u8) -> Self {
        Self {
            reserved: (byte & 0x80) != 0,
            df: (byte & 0x40) != 0,
            mf: (byte & 0x20) != 0,
        }
    }

    /// Encode the flags into the upper 3 bits of a byte (low 5 bits are 0).
    #[inline]
    #[must_use]
    pub const fn to_byte(self) -> u8 {
        ((self.reserved as u8) << 7) | ((self.df as u8) << 6) | ((self.mf as u8) << 5)
    }

    /// Whether a datagram with these flags and the given fragment `offset`
    /// is a fragment: MF set, or a non-zero offset.
    #[inline]
    #[must_use]
    pub const fn is_fragment(self, offset: u16) -> bool {
        self.mf || offset > 0
    }
}

impl std::fmt::Display for Ipv4Flags {
    /// Writes the set flags joined by `+`, in the order `evil`, `DF`, `MF`,
    /// or `-` when no flag is set. Writes directly to the formatter without
    /// building an intermediate collection.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut wrote_any = false;
        for (set, name) in [(self.reserved, "evil"), (self.df, "DF"), (self.mf, "MF")] {
            if !set {
                continue;
            }
            if wrote_any {
                f.write_str("+")?;
            }
            f.write_str(name)?;
            wrote_any = true;
        }
        if !wrote_any {
            f.write_str("-")?;
        }
        Ok(())
    }
}
153
154/// A view into an IPv4 packet header.
155#[derive(Debug, Clone)]
156pub struct Ipv4Layer {
157    pub index: LayerIndex,
158}
159
160impl Ipv4Layer {
161    /// Create a new IPv4 layer view with specified bounds.
162    #[inline]
163    #[must_use]
164    pub const fn new(start: usize, end: usize) -> Self {
165        Self {
166            index: LayerIndex::new(LayerKind::Ipv4, start, end),
167        }
168    }
169
170    /// Create a layer at offset 0 with minimum header length.
171    #[inline]
172    #[must_use]
173    pub const fn at_start() -> Self {
174        Self::new(0, IPV4_MIN_HEADER_LEN)
175    }
176
177    /// Create a layer at the specified offset with minimum header length.
178    #[inline]
179    #[must_use]
180    pub const fn at_offset(offset: usize) -> Self {
181        Self::new(offset, offset + IPV4_MIN_HEADER_LEN)
182    }
183
184    /// Create a layer at offset, calculating actual header length from IHL.
185    pub fn at_offset_dynamic(buf: &[u8], offset: usize) -> Result<Self, FieldError> {
186        if buf.len() < offset + 1 {
187            return Err(FieldError::BufferTooShort {
188                offset,
189                need: 1,
190                have: buf.len().saturating_sub(offset),
191            });
192        }
193
194        let ihl = (buf[offset] & 0x0F) as usize;
195        let header_len = ihl * 4;
196
197        if header_len < IPV4_MIN_HEADER_LEN {
198            return Err(FieldError::InvalidValue(format!(
199                "IHL {ihl} is less than minimum (5)"
200            )));
201        }
202
203        if buf.len() < offset + header_len {
204            return Err(FieldError::BufferTooShort {
205                offset,
206                need: header_len,
207                have: buf.len().saturating_sub(offset),
208            });
209        }
210
211        Ok(Self::new(offset, offset + header_len))
212    }
213
214    /// Validate that the buffer contains a valid IPv4 header at the offset.
215    pub fn validate(buf: &[u8], offset: usize) -> Result<(), FieldError> {
216        if buf.len() < offset + IPV4_MIN_HEADER_LEN {
217            return Err(FieldError::BufferTooShort {
218                offset,
219                need: IPV4_MIN_HEADER_LEN,
220                have: buf.len().saturating_sub(offset),
221            });
222        }
223
224        let version = (buf[offset] >> 4) & 0x0F;
225        if version != 4 {
226            return Err(FieldError::InvalidValue(format!(
227                "not IPv4: version = {version}"
228            )));
229        }
230
231        let ihl = (buf[offset] & 0x0F) as usize;
232        if ihl < 5 {
233            return Err(FieldError::InvalidValue(format!(
234                "IHL {ihl} is less than minimum (5)"
235            )));
236        }
237
238        let header_len = ihl * 4;
239        if buf.len() < offset + header_len {
240            return Err(FieldError::BufferTooShort {
241                offset,
242                need: header_len,
243                have: buf.len().saturating_sub(offset),
244            });
245        }
246
247        Ok(())
248    }
249
250    /// Calculate the actual header length from the buffer.
251    #[must_use]
252    pub fn calculate_header_len(&self, buf: &[u8]) -> usize {
253        self.ihl(buf).map(|ihl| (ihl as usize) * 4).unwrap_or(20)
254    }
255
256    /// Get the options length (header length - 20).
257    #[must_use]
258    pub fn options_len(&self, buf: &[u8]) -> usize {
259        self.calculate_header_len(buf)
260            .saturating_sub(IPV4_MIN_HEADER_LEN)
261    }
262
263    // ========== Field Readers ==========
264
265    /// Read the version field (should be 4).
266    #[inline]
267    pub fn version(&self, buf: &[u8]) -> Result<u8, FieldError> {
268        let b = u8::read(buf, self.index.start + offsets::VERSION_IHL)?;
269        Ok((b >> 4) & 0x0F)
270    }
271
272    /// Read the Internet Header Length (in 32-bit words).
273    #[inline]
274    pub fn ihl(&self, buf: &[u8]) -> Result<u8, FieldError> {
275        let b = u8::read(buf, self.index.start + offsets::VERSION_IHL)?;
276        Ok(b & 0x0F)
277    }
278
279    /// Read the header length in bytes.
280    #[inline]
281    pub fn header_len_bytes(&self, buf: &[u8]) -> Result<usize, FieldError> {
282        Ok((self.ihl(buf)? as usize) * 4)
283    }
284
285    /// Read the Type of Service (TOS) field.
286    #[inline]
287    pub fn tos(&self, buf: &[u8]) -> Result<u8, FieldError> {
288        u8::read(buf, self.index.start + offsets::TOS)
289    }
290
291    /// Read the DSCP (Differentiated Services Code Point).
292    #[inline]
293    pub fn dscp(&self, buf: &[u8]) -> Result<u8, FieldError> {
294        Ok((self.tos(buf)? >> 2) & 0x3F)
295    }
296
297    /// Read the ECN (Explicit Congestion Notification).
298    #[inline]
299    pub fn ecn(&self, buf: &[u8]) -> Result<u8, FieldError> {
300        Ok(self.tos(buf)? & 0x03)
301    }
302
303    /// Read the total length field.
304    #[inline]
305    pub fn total_len(&self, buf: &[u8]) -> Result<u16, FieldError> {
306        u16::read(buf, self.index.start + offsets::TOTAL_LEN)
307    }
308
309    /// Read the identification field.
310    #[inline]
311    pub fn id(&self, buf: &[u8]) -> Result<u16, FieldError> {
312        u16::read(buf, self.index.start + offsets::ID)
313    }
314
315    /// Read the flags as a structured type.
316    #[inline]
317    pub fn flags(&self, buf: &[u8]) -> Result<Ipv4Flags, FieldError> {
318        let b = u8::read(buf, self.index.start + offsets::FLAGS_FRAG)?;
319        Ok(Ipv4Flags::from_byte(b))
320    }
321
322    /// Read the raw flags byte (upper 3 bits of flags/frag field).
323    #[inline]
324    pub fn flags_raw(&self, buf: &[u8]) -> Result<u8, FieldError> {
325        let b = u8::read(buf, self.index.start + offsets::FLAGS_FRAG)?;
326        Ok((b >> 5) & 0x07)
327    }
328
329    /// Read the fragment offset (in 8-byte units).
330    #[inline]
331    pub fn frag_offset(&self, buf: &[u8]) -> Result<u16, FieldError> {
332        let val = u16::read(buf, self.index.start + offsets::FLAGS_FRAG)?;
333        Ok(val & 0x1FFF)
334    }
335
336    /// Read the fragment offset in bytes.
337    #[inline]
338    pub fn frag_offset_bytes(&self, buf: &[u8]) -> Result<u32, FieldError> {
339        Ok(u32::from(self.frag_offset(buf)?) * 8)
340    }
341
342    /// Read the Time to Live field.
343    #[inline]
344    pub fn ttl(&self, buf: &[u8]) -> Result<u8, FieldError> {
345        u8::read(buf, self.index.start + offsets::TTL)
346    }
347
348    /// Read the protocol field.
349    #[inline]
350    pub fn protocol(&self, buf: &[u8]) -> Result<u8, FieldError> {
351        u8::read(buf, self.index.start + offsets::PROTOCOL)
352    }
353
354    /// Read the header checksum.
355    #[inline]
356    pub fn checksum(&self, buf: &[u8]) -> Result<u16, FieldError> {
357        u16::read(buf, self.index.start + offsets::CHECKSUM)
358    }
359
360    /// Read the source IP address.
361    #[inline]
362    pub fn src(&self, buf: &[u8]) -> Result<Ipv4Addr, FieldError> {
363        Ipv4Addr::read(buf, self.index.start + offsets::SRC)
364    }
365
366    /// Read the destination IP address.
367    #[inline]
368    pub fn dst(&self, buf: &[u8]) -> Result<Ipv4Addr, FieldError> {
369        Ipv4Addr::read(buf, self.index.start + offsets::DST)
370    }
371
372    /// Read the options bytes (if any).
373    pub fn options_bytes<'a>(&self, buf: &'a [u8]) -> Result<&'a [u8], FieldError> {
374        let header_len = self.calculate_header_len(buf);
375        let opts_start = self.index.start + IPV4_MIN_HEADER_LEN;
376        let opts_end = self.index.start + header_len;
377
378        if buf.len() < opts_end {
379            return Err(FieldError::BufferTooShort {
380                offset: opts_start,
381                need: header_len - IPV4_MIN_HEADER_LEN,
382                have: buf.len().saturating_sub(opts_start),
383            });
384        }
385
386        Ok(&buf[opts_start..opts_end])
387    }
388
389    /// Parse and return the options.
390    pub fn options(&self, buf: &[u8]) -> Result<Ipv4Options, FieldError> {
391        let opts_bytes = self.options_bytes(buf)?;
392        parse_options(opts_bytes)
393    }
394
395    // ========== Field Writers ==========
396
397    /// Set the version field.
398    #[inline]
399    pub fn set_version(&self, buf: &mut [u8], version: u8) -> Result<(), FieldError> {
400        let offset = self.index.start + offsets::VERSION_IHL;
401        let current = u8::read(buf, offset)?;
402        let new_val = (current & 0x0F) | ((version & 0x0F) << 4);
403        new_val.write(buf, offset)
404    }
405
406    /// Set the IHL field.
407    #[inline]
408    pub fn set_ihl(&self, buf: &mut [u8], ihl: u8) -> Result<(), FieldError> {
409        let offset = self.index.start + offsets::VERSION_IHL;
410        let current = u8::read(buf, offset)?;
411        let new_val = (current & 0xF0) | (ihl & 0x0F);
412        new_val.write(buf, offset)
413    }
414
415    /// Set the TOS field.
416    #[inline]
417    pub fn set_tos(&self, buf: &mut [u8], tos: u8) -> Result<(), FieldError> {
418        tos.write(buf, self.index.start + offsets::TOS)
419    }
420
421    /// Set the DSCP field.
422    #[inline]
423    pub fn set_dscp(&self, buf: &mut [u8], dscp: u8) -> Result<(), FieldError> {
424        let offset = self.index.start + offsets::TOS;
425        let current = u8::read(buf, offset)?;
426        let new_val = (current & 0x03) | ((dscp & 0x3F) << 2);
427        new_val.write(buf, offset)
428    }
429
430    /// Set the ECN field.
431    #[inline]
432    pub fn set_ecn(&self, buf: &mut [u8], ecn: u8) -> Result<(), FieldError> {
433        let offset = self.index.start + offsets::TOS;
434        let current = u8::read(buf, offset)?;
435        let new_val = (current & 0xFC) | (ecn & 0x03);
436        new_val.write(buf, offset)
437    }
438
439    /// Set the total length field.
440    #[inline]
441    pub fn set_total_len(&self, buf: &mut [u8], len: u16) -> Result<(), FieldError> {
442        len.write(buf, self.index.start + offsets::TOTAL_LEN)
443    }
444
445    /// Set the identification field.
446    #[inline]
447    pub fn set_id(&self, buf: &mut [u8], id: u16) -> Result<(), FieldError> {
448        id.write(buf, self.index.start + offsets::ID)
449    }
450
451    /// Set the flags field.
452    #[inline]
453    pub fn set_flags(&self, buf: &mut [u8], flags: Ipv4Flags) -> Result<(), FieldError> {
454        let offset = self.index.start + offsets::FLAGS_FRAG;
455        let current = u8::read(buf, offset)?;
456        let new_val = (current & 0x1F) | flags.to_byte();
457        new_val.write(buf, offset)
458    }
459
460    /// Set the fragment offset (in 8-byte units).
461    #[inline]
462    pub fn set_frag_offset(&self, buf: &mut [u8], offset_val: u16) -> Result<(), FieldError> {
463        let offset = self.index.start + offsets::FLAGS_FRAG;
464        let current = u16::read(buf, offset)?;
465        let new_val = (current & 0xE000) | (offset_val & 0x1FFF);
466        new_val.write(buf, offset)
467    }
468
469    /// Set the TTL field.
470    #[inline]
471    pub fn set_ttl(&self, buf: &mut [u8], ttl: u8) -> Result<(), FieldError> {
472        ttl.write(buf, self.index.start + offsets::TTL)
473    }
474
475    /// Set the protocol field.
476    #[inline]
477    pub fn set_protocol(&self, buf: &mut [u8], proto: u8) -> Result<(), FieldError> {
478        proto.write(buf, self.index.start + offsets::PROTOCOL)
479    }
480
481    /// Set the checksum field.
482    #[inline]
483    pub fn set_checksum(&self, buf: &mut [u8], checksum: u16) -> Result<(), FieldError> {
484        checksum.write(buf, self.index.start + offsets::CHECKSUM)
485    }
486
487    /// Set the source IP address.
488    #[inline]
489    pub fn set_src(&self, buf: &mut [u8], src: Ipv4Addr) -> Result<(), FieldError> {
490        src.write(buf, self.index.start + offsets::SRC)
491    }
492
493    /// Set the destination IP address.
494    #[inline]
495    pub fn set_dst(&self, buf: &mut [u8], dst: Ipv4Addr) -> Result<(), FieldError> {
496        dst.write(buf, self.index.start + offsets::DST)
497    }
498
499    /// Compute and set the header checksum.
500    pub fn compute_checksum(&self, buf: &mut [u8]) -> Result<u16, FieldError> {
501        // Zero out existing checksum first
502        self.set_checksum(buf, 0)?;
503
504        // Calculate header length
505        let header_len = self.calculate_header_len(buf);
506        let header_end = self.index.start + header_len;
507
508        if buf.len() < header_end {
509            return Err(FieldError::BufferTooShort {
510                offset: self.index.start,
511                need: header_len,
512                have: buf.len().saturating_sub(self.index.start),
513            });
514        }
515
516        // Compute checksum over header bytes
517        let checksum = ipv4_checksum(&buf[self.index.start..header_end]);
518
519        // Write checksum
520        self.set_checksum(buf, checksum)?;
521
522        Ok(checksum)
523    }
524
525    /// Verify the header checksum.
526    pub fn verify_checksum(&self, buf: &[u8]) -> Result<bool, FieldError> {
527        let header_len = self.calculate_header_len(buf);
528        let header_end = self.index.start + header_len;
529
530        if buf.len() < header_end {
531            return Err(FieldError::BufferTooShort {
532                offset: self.index.start,
533                need: header_len,
534                have: buf.len().saturating_sub(self.index.start),
535            });
536        }
537
538        let checksum = ipv4_checksum(&buf[self.index.start..header_end]);
539        Ok(checksum == 0 || checksum == 0xFFFF)
540    }
541
542    // ========== Dynamic Field Access ==========
543
544    /// Get a field value by name.
545    pub fn get_field(&self, buf: &[u8], name: &str) -> Option<Result<FieldValue, FieldError>> {
546        match name {
547            "version" => Some(self.version(buf).map(FieldValue::U8)),
548            "ihl" => Some(self.ihl(buf).map(FieldValue::U8)),
549            "tos" => Some(self.tos(buf).map(FieldValue::U8)),
550            "dscp" => Some(self.dscp(buf).map(FieldValue::U8)),
551            "ecn" => Some(self.ecn(buf).map(FieldValue::U8)),
552            "len" | "total_len" => Some(self.total_len(buf).map(FieldValue::U16)),
553            "id" => Some(self.id(buf).map(FieldValue::U16)),
554            "flags" => Some(self.flags_raw(buf).map(FieldValue::U8)),
555            "frag" | "frag_offset" => Some(self.frag_offset(buf).map(FieldValue::U16)),
556            "ttl" => Some(self.ttl(buf).map(FieldValue::U8)),
557            "proto" | "protocol" => Some(self.protocol(buf).map(FieldValue::U8)),
558            "chksum" | "checksum" => Some(self.checksum(buf).map(FieldValue::U16)),
559            "src" => Some(self.src(buf).map(FieldValue::Ipv4)),
560            "dst" => Some(self.dst(buf).map(FieldValue::Ipv4)),
561            _ => None,
562        }
563    }
564
565    /// Set a field value by name.
566    pub fn set_field(
567        &self,
568        buf: &mut [u8],
569        name: &str,
570        value: FieldValue,
571    ) -> Option<Result<(), FieldError>> {
572        match (name, value) {
573            ("version", FieldValue::U8(v)) => Some(self.set_version(buf, v)),
574            ("ihl", FieldValue::U8(v)) => Some(self.set_ihl(buf, v)),
575            ("tos", FieldValue::U8(v)) => Some(self.set_tos(buf, v)),
576            ("dscp", FieldValue::U8(v)) => Some(self.set_dscp(buf, v)),
577            ("ecn", FieldValue::U8(v)) => Some(self.set_ecn(buf, v)),
578            ("len" | "total_len", FieldValue::U16(v)) => Some(self.set_total_len(buf, v)),
579            ("id", FieldValue::U16(v)) => Some(self.set_id(buf, v)),
580            ("frag" | "frag_offset", FieldValue::U16(v)) => Some(self.set_frag_offset(buf, v)),
581            ("ttl", FieldValue::U8(v)) => Some(self.set_ttl(buf, v)),
582            ("proto" | "protocol", FieldValue::U8(v)) => Some(self.set_protocol(buf, v)),
583            ("chksum" | "checksum", FieldValue::U16(v)) => Some(self.set_checksum(buf, v)),
584            ("src", FieldValue::Ipv4(v)) => Some(self.set_src(buf, v)),
585            ("dst", FieldValue::Ipv4(v)) => Some(self.set_dst(buf, v)),
586            _ => None,
587        }
588    }
589
590    /// Get list of field names.
591    #[must_use]
592    pub fn field_names() -> &'static [&'static str] {
593        &[
594            "version", "ihl", "tos", "dscp", "ecn", "len", "id", "flags", "frag", "ttl", "proto",
595            "chksum", "src", "dst",
596        ]
597    }
598
599    // ========== Utility Methods ==========
600
601    /// Check if this is a fragment.
602    #[must_use]
603    pub fn is_fragment(&self, buf: &[u8]) -> bool {
604        let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
605        let offset = self.frag_offset(buf).unwrap_or(0);
606        flags.is_fragment(offset)
607    }
608
609    /// Check if this is the first fragment.
610    #[must_use]
611    pub fn is_first_fragment(&self, buf: &[u8]) -> bool {
612        let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
613        let offset = self.frag_offset(buf).unwrap_or(0);
614        flags.mf && offset == 0
615    }
616
617    /// Check if this is the last fragment.
618    #[must_use]
619    pub fn is_last_fragment(&self, buf: &[u8]) -> bool {
620        let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
621        let offset = self.frag_offset(buf).unwrap_or(0);
622        !flags.mf && offset > 0
623    }
624
625    /// Check if the Don't Fragment flag is set.
626    #[must_use]
627    pub fn is_dont_fragment(&self, buf: &[u8]) -> bool {
628        self.flags(buf).map(|f| f.df).unwrap_or(false)
629    }
630
631    /// Get the payload length (`total_len` - `header_len`).
632    pub fn payload_len(&self, buf: &[u8]) -> Result<usize, FieldError> {
633        let total = self.total_len(buf)? as usize;
634        let header = self.calculate_header_len(buf);
635        Ok(total.saturating_sub(header))
636    }
637
638    /// Get a slice of the payload data.
639    pub fn payload<'a>(&self, buf: &'a [u8]) -> Result<&'a [u8], FieldError> {
640        let header_len = self.calculate_header_len(buf);
641        let total_len = self.total_len(buf)? as usize;
642        let payload_start = self.index.start + header_len;
643        let payload_end = (self.index.start + total_len).min(buf.len());
644
645        if payload_start > buf.len() {
646            return Err(FieldError::BufferTooShort {
647                offset: payload_start,
648                need: 0,
649                have: buf.len().saturating_sub(payload_start),
650            });
651        }
652
653        Ok(&buf[payload_start..payload_end])
654    }
655
656    /// Get the header bytes.
657    #[inline]
658    #[must_use]
659    pub fn header_bytes<'a>(&self, buf: &'a [u8]) -> &'a [u8] {
660        let header_len = self.calculate_header_len(buf);
661        let end = (self.index.start + header_len).min(buf.len());
662        &buf[self.index.start..end]
663    }
664
665    /// Copy the header bytes.
666    #[inline]
667    #[must_use]
668    pub fn header_copy(&self, buf: &[u8]) -> Vec<u8> {
669        self.header_bytes(buf).to_vec()
670    }
671
672    /// Get the protocol name.
673    pub fn protocol_name(&self, buf: &[u8]) -> &'static str {
674        self.protocol(buf)
675            .map(protocol::to_name)
676            .unwrap_or("Unknown")
677    }
678
679    /// Determine the next layer kind based on protocol.
680    #[must_use]
681    pub fn next_layer(&self, buf: &[u8]) -> Option<LayerKind> {
682        self.protocol(buf).ok().and_then(|proto| match proto {
683            protocol::TCP => Some(LayerKind::Tcp),
684            protocol::UDP => Some(LayerKind::Udp),
685            protocol::ICMP => Some(LayerKind::Icmp),
686            protocol::IPV4 => Some(LayerKind::Ipv4),
687            protocol::IPV6 => Some(LayerKind::Ipv6),
688            _ => None,
689        })
690    }
691
692    /// Compute hash for packet matching (like Scapy's hashret).
693    #[must_use]
694    pub fn hashret(&self, buf: &[u8]) -> Vec<u8> {
695        let proto = self.protocol(buf).unwrap_or(0);
696
697        // For ICMP error messages, delegate to inner packet
698        if proto == protocol::ICMP {
699            // Check if it's an ICMP error (type 3, 4, 5, 11, 12)
700            let header_len = self.calculate_header_len(buf);
701            let icmp_start = self.index.start + header_len;
702            if buf.len() > icmp_start {
703                let icmp_type = buf[icmp_start];
704                if matches!(icmp_type, 3 | 4 | 5 | 11 | 12) {
705                    // Return hash of the embedded packet
706                    // For now, just use src/dst XOR + proto
707                }
708            }
709        }
710
711        // For IP-in-IP tunnels, delegate to inner packet
712        if matches!(proto, protocol::IPV4 | protocol::IPV6) {
713            // Could recurse here
714        }
715
716        // Standard hash: XOR of src and dst + protocol
717        let src = self.src(buf).map(|ip| ip.octets()).unwrap_or([0; 4]);
718        let dst = self.dst(buf).map(|ip| ip.octets()).unwrap_or([0; 4]);
719
720        let mut result = Vec::with_capacity(5);
721        for i in 0..4 {
722            result.push(src[i] ^ dst[i]);
723        }
724        result.push(proto);
725        result
726    }
727
728    /// Check if this packet answers another (for `sr()` matching).
729    #[must_use]
730    pub fn answers(&self, buf: &[u8], other: &Ipv4Layer, other_buf: &[u8]) -> bool {
731        // Protocol must match
732        let self_proto = self.protocol(buf).unwrap_or(0);
733        let other_proto = other.protocol(other_buf).unwrap_or(0);
734
735        // Handle ICMP errors
736        if self_proto == protocol::ICMP {
737            let header_len = self.calculate_header_len(buf);
738            let icmp_start = self.index.start + header_len;
739            if buf.len() > icmp_start {
740                let icmp_type = buf[icmp_start];
741                if matches!(icmp_type, 3 | 4 | 5 | 11 | 12) {
742                    // ICMP error - check embedded packet
743                    // The embedded packet should match the original
744                    return true; // Simplified - real impl would check embedded
745                }
746            }
747        }
748
749        // Handle IP-in-IP tunnels
750        if matches!(other_proto, protocol::IPV4 | protocol::IPV6) {
751            // Delegate to inner packet
752        }
753
754        if self_proto != other_proto {
755            return false;
756        }
757
758        // Check addresses
759        let self_src = self.src(buf).ok();
760        let self_dst = self.dst(buf).ok();
761        let other_src = other.src(other_buf).ok();
762        let other_dst = other.dst(other_buf).ok();
763
764        // Response src should match request dst
765        if self_src != other_dst {
766            return false;
767        }
768
769        // Response dst should match request src
770        if self_dst != other_src {
771            return false;
772        }
773
774        true
775    }
776
777    /// Extract padding from the packet.
778    /// Returns (payload, padding) tuple.
779    #[must_use]
780    pub fn extract_padding<'a>(&self, buf: &'a [u8]) -> (&'a [u8], &'a [u8]) {
781        let header_len = self.calculate_header_len(buf);
782        let total_len = self.total_len(buf).unwrap_or(0) as usize;
783
784        let payload_start = self.index.start + header_len;
785        let payload_end = (self.index.start + total_len).min(buf.len());
786
787        if payload_start >= buf.len() {
788            return (&[], &buf[buf.len()..]);
789        }
790
791        let payload = &buf[payload_start..payload_end];
792        let padding = &buf[payload_end..];
793
794        (payload, padding)
795    }
796
797    /// Get routing information for this packet.
798    #[must_use]
799    pub fn route(&self, buf: &[u8]) -> Ipv4Route {
800        use crate::layer::ipv4::routing::get_route;
801        let dst = self.dst(buf).unwrap_or(Ipv4Addr::UNSPECIFIED);
802        get_route(dst)
803    }
804
805    /// Estimate the original TTL.
806    #[must_use]
807    pub fn original_ttl(&self, buf: &[u8]) -> u8 {
808        let current = self.ttl(buf).unwrap_or(0);
809        super::ttl::estimate_original(current)
810    }
811
812    /// Estimate the number of hops.
813    #[must_use]
814    pub fn hops(&self, buf: &[u8]) -> u8 {
815        let current = self.ttl(buf).unwrap_or(0);
816        super::ttl::estimate_hops(current)
817    }
818}
819
820impl Layer for Ipv4Layer {
821    fn kind(&self) -> LayerKind {
822        LayerKind::Ipv4
823    }
824
825    fn summary(&self, buf: &[u8]) -> String {
826        let src = self
827            .src(buf)
828            .map_or_else(|_| "?".into(), |ip| ip.to_string());
829        let dst = self
830            .dst(buf)
831            .map_or_else(|_| "?".into(), |ip| ip.to_string());
832        let proto = self.protocol_name(buf);
833        let ttl = self.ttl(buf).unwrap_or(0);
834
835        let mut s = format!("IP {src} > {dst} {proto} ttl={ttl}");
836
837        // Add fragment info if fragmented
838        if self.is_fragment(buf) {
839            let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
840            let offset = self.frag_offset(buf).unwrap_or(0);
841            s.push_str(&format!(
842                " frag:{}+{}",
843                offset,
844                if flags.mf { "MF" } else { "" }
845            ));
846        }
847
848        s
849    }
850
851    fn header_len(&self, buf: &[u8]) -> usize {
852        self.calculate_header_len(buf)
853    }
854
855    fn hashret(&self, buf: &[u8]) -> Vec<u8> {
856        self.hashret(buf)
857    }
858
859    fn answers(&self, buf: &[u8], other: &Self, other_buf: &[u8]) -> bool {
860        self.answers(buf, other, other_buf)
861    }
862
863    fn extract_padding<'a>(&self, buf: &'a [u8]) -> (&'a [u8], &'a [u8]) {
864        self.extract_padding(buf)
865    }
866
867    fn field_names(&self) -> &'static [&'static str] {
868        Ipv4Layer::field_names()
869    }
870}
871
#[cfg(test)]
mod tests {
    use super::*;

    /// A 20-byte option-less IPv4 header (TCP, DF set, checksum zeroed).
    ///
    /// Note: `total_len` claims 60 bytes, but only the 20-byte header is present.
    fn sample_ipv4_header() -> Vec<u8> {
        vec![
            0x45, // Version=4, IHL=5
            0x00, // TOS=0
            0x00, 0x3c, // Total length = 60
            0x1c, 0x46, // ID = 0x1c46
            0x40, 0x00, // Flags=DF, Frag offset=0
            0x40, // TTL=64
            0x06, // Protocol=TCP
            0x00, 0x00, // Checksum (to be computed)
            0xc0, 0xa8, 0x01, 0x01, // Src: 192.168.1.1
            0xc0, 0xa8, 0x01, 0x02, // Dst: 192.168.1.2
        ]
    }

    #[test]
    fn test_field_readers() {
        let buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        assert_eq!(layer.version(&buf).unwrap(), 4);
        assert_eq!(layer.ihl(&buf).unwrap(), 5);
        assert_eq!(layer.tos(&buf).unwrap(), 0);
        assert_eq!(layer.total_len(&buf).unwrap(), 60);
        assert_eq!(layer.id(&buf).unwrap(), 0x1c46);
        assert!(layer.flags(&buf).unwrap().df);
        assert_eq!(layer.frag_offset(&buf).unwrap(), 0);
        assert_eq!(layer.ttl(&buf).unwrap(), 64);
        assert_eq!(layer.protocol(&buf).unwrap(), protocol::TCP);
        assert_eq!(layer.src(&buf).unwrap(), Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(layer.dst(&buf).unwrap(), Ipv4Addr::new(192, 168, 1, 2));
    }

    #[test]
    fn test_field_writers() {
        let mut buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        layer.set_ttl(&mut buf, 128).unwrap();
        assert_eq!(layer.ttl(&buf).unwrap(), 128);

        layer.set_src(&mut buf, Ipv4Addr::new(10, 0, 0, 1)).unwrap();
        assert_eq!(layer.src(&buf).unwrap(), Ipv4Addr::new(10, 0, 0, 1));

        layer.set_flags(&mut buf, Ipv4Flags::MF).unwrap();
        assert!(layer.flags(&buf).unwrap().mf);
        assert!(!layer.flags(&buf).unwrap().df);
    }

    #[test]
    fn test_checksum() {
        let mut buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        // RFC 1071 ones'-complement sum over the sample header yields 0x9b22.
        let checksum = layer.compute_checksum(&mut buf).unwrap();
        assert_eq!(checksum, 0x9b22);
        // The value must also be written back into the header, big-endian.
        assert_eq!(
            &buf[offsets::CHECKSUM..offsets::CHECKSUM + 2],
            &[0x9b, 0x22]
        );

        // Verify it
        assert!(layer.verify_checksum(&buf).unwrap());

        // Corrupt a plain data field (TTL). Flipping a bit in byte 0 would also
        // change the IHL, and with it the range being checksummed, mixing up
        // "bad checksum" with "malformed header".
        buf[offsets::TTL] ^= 0x01;
        assert!(!layer.verify_checksum(&buf).unwrap());
    }

    #[test]
    fn test_flags() {
        let flags = Ipv4Flags::from_byte(0x40); // DF
        assert!(flags.df);
        assert!(!flags.mf);
        assert!(!flags.reserved);
        assert_eq!(flags.to_byte(), 0x40);

        let flags = Ipv4Flags::from_byte(0x20); // MF
        assert!(!flags.df);
        assert!(flags.mf);
        assert!(!flags.reserved);
        assert_eq!(flags.to_byte(), 0x20);

        let flags = Ipv4Flags::from_byte(0xE0); // All set
        assert!(flags.df);
        assert!(flags.mf);
        assert!(flags.reserved);
    }

    #[test]
    fn test_is_fragment() {
        let mut buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        // DF set, not a fragment
        assert!(!layer.is_fragment(&buf));

        // Set MF
        layer.set_flags(&mut buf, Ipv4Flags::MF).unwrap();
        assert!(layer.is_fragment(&buf));

        // Clear MF, set offset
        layer.set_flags(&mut buf, Ipv4Flags::NONE).unwrap();
        layer.set_frag_offset(&mut buf, 100).unwrap();
        assert!(layer.is_fragment(&buf));
    }

    #[test]
    fn test_dynamic_field_access() {
        let buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        let ttl = layer.get_field(&buf, "ttl").unwrap().unwrap();
        assert_eq!(ttl.as_u8(), Some(64));

        let src = layer.get_field(&buf, "src").unwrap().unwrap();
        assert_eq!(src.as_ipv4(), Some(Ipv4Addr::new(192, 168, 1, 1)));
    }

    #[test]
    fn test_validate() {
        let buf = sample_ipv4_header();
        assert!(Ipv4Layer::validate(&buf, 0).is_ok());

        // Too short
        let short = vec![0x45, 0x00];
        assert!(Ipv4Layer::validate(&short, 0).is_err());

        // Wrong version
        let mut wrong_version = sample_ipv4_header();
        wrong_version[0] = 0x65; // Version 6
        assert!(Ipv4Layer::validate(&wrong_version, 0).is_err());

        // Invalid IHL
        let mut bad_ihl = sample_ipv4_header();
        bad_ihl[0] = 0x43; // IHL=3 (< minimum 5)
        assert!(Ipv4Layer::validate(&bad_ihl, 0).is_err());
    }

    #[test]
    fn test_extract_padding() {
        let mut buf = sample_ipv4_header();
        // Set total_len to 30 (header=20 + payload=10)
        buf[2] = 0x00;
        buf[3] = 0x1e; // 30

        // Add some payload and padding
        let payload_bytes: [u8; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let padding_bytes: [u8; 4] = [0, 0, 0, 0];
        buf.extend_from_slice(&payload_bytes); // 10 bytes payload
        buf.extend_from_slice(&padding_bytes); // 4 bytes padding

        let layer = Ipv4Layer::at_offset(0);
        let (payload, padding) = layer.extract_padding(&buf);

        assert_eq!(payload.len(), 10);
        assert_eq!(padding.len(), 4);
        // Check the split lands on the right bytes, not just the right lengths.
        assert_eq!(payload, &payload_bytes[..]);
        assert_eq!(padding, &padding_bytes[..]);
    }
}