Skip to main content

stackforge_core/layer/ipv4/
header.rs

1//! IPv4 header layer implementation.
2//!
3//! Provides zero-copy access to IPv4 header fields.
4
5use std::net::Ipv4Addr;
6
7use crate::layer::field::{Field, FieldDesc, FieldError, FieldType, FieldValue};
8use crate::layer::{Layer, LayerIndex, LayerKind};
9
10use super::checksum::ipv4_checksum;
11use super::options::{Ipv4Options, parse_options};
12use super::protocol;
13use super::routing::Ipv4Route;
14
/// Minimum IPv4 header length in bytes: IHL = 5 words of 4 bytes, no options.
pub const IPV4_MIN_HEADER_LEN: usize = 5 * 4;

/// Maximum IPv4 header length in bytes: the 4-bit IHL field tops out at 15 words.
pub const IPV4_MAX_HEADER_LEN: usize = 15 * 4;

/// Maximum total length of an IPv4 packet (the Total Length field is 16 bits wide).
pub const IPV4_MAX_PACKET_LEN: usize = u16::MAX as usize;
23
/// Byte offsets of each field, relative to the first byte of the IPv4 header.
///
/// Each offset is derived from the previous field's offset plus that field's width.
pub mod offsets {
    /// First byte: version (high nibble) and IHL (low nibble).
    pub const VERSION_IHL: usize = 0;
    /// Type of Service byte: DSCP (high 6 bits) and ECN (low 2 bits).
    pub const TOS: usize = VERSION_IHL + 1;
    /// Total length, 16 bits.
    pub const TOTAL_LEN: usize = TOS + 1;
    /// Identification, 16 bits.
    pub const ID: usize = TOTAL_LEN + 2;
    /// Flags (high 3 bits) and fragment offset (low 13 bits), 16 bits together.
    pub const FLAGS_FRAG: usize = ID + 2;
    /// Time to live, 8 bits.
    pub const TTL: usize = FLAGS_FRAG + 2;
    /// Protocol number, 8 bits.
    pub const PROTOCOL: usize = TTL + 1;
    /// Header checksum, 16 bits.
    pub const CHECKSUM: usize = PROTOCOL + 1;
    /// Source address, 32 bits.
    pub const SRC: usize = CHECKSUM + 2;
    /// Destination address, 32 bits.
    pub const DST: usize = SRC + 4;
    /// First option byte; only meaningful when IHL > 5.
    pub const OPTIONS: usize = DST + 4;
}
49
50/// Field descriptors for dynamic access.
51pub static FIELDS: &[FieldDesc] = &[
52    FieldDesc::new("version", offsets::VERSION_IHL, 1, FieldType::U8),
53    FieldDesc::new("ihl", offsets::VERSION_IHL, 1, FieldType::U8),
54    FieldDesc::new("tos", offsets::TOS, 1, FieldType::U8),
55    FieldDesc::new("dscp", offsets::TOS, 1, FieldType::U8),
56    FieldDesc::new("ecn", offsets::TOS, 1, FieldType::U8),
57    FieldDesc::new("len", offsets::TOTAL_LEN, 2, FieldType::U16),
58    FieldDesc::new("id", offsets::ID, 2, FieldType::U16),
59    FieldDesc::new("flags", offsets::FLAGS_FRAG, 1, FieldType::U8),
60    FieldDesc::new("frag", offsets::FLAGS_FRAG, 2, FieldType::U16),
61    FieldDesc::new("ttl", offsets::TTL, 1, FieldType::U8),
62    FieldDesc::new("proto", offsets::PROTOCOL, 1, FieldType::U8),
63    FieldDesc::new("chksum", offsets::CHECKSUM, 2, FieldType::U16),
64    FieldDesc::new("src", offsets::SRC, 4, FieldType::Ipv4),
65    FieldDesc::new("dst", offsets::DST, 4, FieldType::Ipv4),
66];
67
/// The three flag bits stored in the top of the flags/fragment-offset field.
///
/// `Default` yields all flags cleared.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Ipv4Flags {
    /// Reserved bit, a.k.a. the "evil" bit; expected to be 0.
    pub reserved: bool,
    /// Don't Fragment.
    pub df: bool,
    /// More Fragments follow this one.
    pub mf: bool,
}
78
79impl Ipv4Flags {
80    pub const NONE: Self = Self {
81        reserved: false,
82        df: false,
83        mf: false,
84    };
85
86    pub const DF: Self = Self {
87        reserved: false,
88        df: true,
89        mf: false,
90    };
91
92    pub const MF: Self = Self {
93        reserved: false,
94        df: false,
95        mf: true,
96    };
97
98    /// Create flags from a raw byte value (upper 3 bits).
99    #[inline]
100    pub fn from_byte(byte: u8) -> Self {
101        Self {
102            reserved: (byte & 0x80) != 0,
103            df: (byte & 0x40) != 0,
104            mf: (byte & 0x20) != 0,
105        }
106    }
107
108    /// Convert to a raw byte value (upper 3 bits).
109    #[inline]
110    pub fn to_byte(self) -> u8 {
111        let mut b = 0u8;
112        if self.reserved {
113            b |= 0x80;
114        }
115        if self.df {
116            b |= 0x40;
117        }
118        if self.mf {
119            b |= 0x20;
120        }
121        b
122    }
123
124    /// Check if this is a fragment (MF set or offset > 0).
125    #[inline]
126    pub fn is_fragment(self, offset: u16) -> bool {
127        self.mf || offset > 0
128    }
129}
130
131impl std::fmt::Display for Ipv4Flags {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        let mut parts = Vec::new();
134        if self.reserved {
135            parts.push("evil");
136        }
137        if self.df {
138            parts.push("DF");
139        }
140        if self.mf {
141            parts.push("MF");
142        }
143        if parts.is_empty() {
144            write!(f, "-")
145        } else {
146            write!(f, "{}", parts.join("+"))
147        }
148    }
149}
150
151/// A view into an IPv4 packet header.
152#[derive(Debug, Clone)]
153pub struct Ipv4Layer {
154    pub index: LayerIndex,
155}
156
157impl Ipv4Layer {
158    /// Create a new IPv4 layer view with specified bounds.
159    #[inline]
160    pub const fn new(start: usize, end: usize) -> Self {
161        Self {
162            index: LayerIndex::new(LayerKind::Ipv4, start, end),
163        }
164    }
165
166    /// Create a layer at offset 0 with minimum header length.
167    #[inline]
168    pub const fn at_start() -> Self {
169        Self::new(0, IPV4_MIN_HEADER_LEN)
170    }
171
172    /// Create a layer at the specified offset with minimum header length.
173    #[inline]
174    pub const fn at_offset(offset: usize) -> Self {
175        Self::new(offset, offset + IPV4_MIN_HEADER_LEN)
176    }
177
178    /// Create a layer at offset, calculating actual header length from IHL.
179    pub fn at_offset_dynamic(buf: &[u8], offset: usize) -> Result<Self, FieldError> {
180        if buf.len() < offset + 1 {
181            return Err(FieldError::BufferTooShort {
182                offset,
183                need: 1,
184                have: buf.len().saturating_sub(offset),
185            });
186        }
187
188        let ihl = (buf[offset] & 0x0F) as usize;
189        let header_len = ihl * 4;
190
191        if header_len < IPV4_MIN_HEADER_LEN {
192            return Err(FieldError::InvalidValue(format!(
193                "IHL {} is less than minimum (5)",
194                ihl
195            )));
196        }
197
198        if buf.len() < offset + header_len {
199            return Err(FieldError::BufferTooShort {
200                offset,
201                need: header_len,
202                have: buf.len().saturating_sub(offset),
203            });
204        }
205
206        Ok(Self::new(offset, offset + header_len))
207    }
208
209    /// Validate that the buffer contains a valid IPv4 header at the offset.
210    pub fn validate(buf: &[u8], offset: usize) -> Result<(), FieldError> {
211        if buf.len() < offset + IPV4_MIN_HEADER_LEN {
212            return Err(FieldError::BufferTooShort {
213                offset,
214                need: IPV4_MIN_HEADER_LEN,
215                have: buf.len().saturating_sub(offset),
216            });
217        }
218
219        let version = (buf[offset] >> 4) & 0x0F;
220        if version != 4 {
221            return Err(FieldError::InvalidValue(format!(
222                "not IPv4: version = {}",
223                version
224            )));
225        }
226
227        let ihl = (buf[offset] & 0x0F) as usize;
228        if ihl < 5 {
229            return Err(FieldError::InvalidValue(format!(
230                "IHL {} is less than minimum (5)",
231                ihl
232            )));
233        }
234
235        let header_len = ihl * 4;
236        if buf.len() < offset + header_len {
237            return Err(FieldError::BufferTooShort {
238                offset,
239                need: header_len,
240                have: buf.len().saturating_sub(offset),
241            });
242        }
243
244        Ok(())
245    }
246
247    /// Calculate the actual header length from the buffer.
248    pub fn calculate_header_len(&self, buf: &[u8]) -> usize {
249        self.ihl(buf).map(|ihl| (ihl as usize) * 4).unwrap_or(20)
250    }
251
252    /// Get the options length (header length - 20).
253    pub fn options_len(&self, buf: &[u8]) -> usize {
254        self.calculate_header_len(buf)
255            .saturating_sub(IPV4_MIN_HEADER_LEN)
256    }
257
258    // ========== Field Readers ==========
259
260    /// Read the version field (should be 4).
261    #[inline]
262    pub fn version(&self, buf: &[u8]) -> Result<u8, FieldError> {
263        let b = u8::read(buf, self.index.start + offsets::VERSION_IHL)?;
264        Ok((b >> 4) & 0x0F)
265    }
266
267    /// Read the Internet Header Length (in 32-bit words).
268    #[inline]
269    pub fn ihl(&self, buf: &[u8]) -> Result<u8, FieldError> {
270        let b = u8::read(buf, self.index.start + offsets::VERSION_IHL)?;
271        Ok(b & 0x0F)
272    }
273
274    /// Read the header length in bytes.
275    #[inline]
276    pub fn header_len_bytes(&self, buf: &[u8]) -> Result<usize, FieldError> {
277        Ok((self.ihl(buf)? as usize) * 4)
278    }
279
280    /// Read the Type of Service (TOS) field.
281    #[inline]
282    pub fn tos(&self, buf: &[u8]) -> Result<u8, FieldError> {
283        u8::read(buf, self.index.start + offsets::TOS)
284    }
285
286    /// Read the DSCP (Differentiated Services Code Point).
287    #[inline]
288    pub fn dscp(&self, buf: &[u8]) -> Result<u8, FieldError> {
289        Ok((self.tos(buf)? >> 2) & 0x3F)
290    }
291
292    /// Read the ECN (Explicit Congestion Notification).
293    #[inline]
294    pub fn ecn(&self, buf: &[u8]) -> Result<u8, FieldError> {
295        Ok(self.tos(buf)? & 0x03)
296    }
297
298    /// Read the total length field.
299    #[inline]
300    pub fn total_len(&self, buf: &[u8]) -> Result<u16, FieldError> {
301        u16::read(buf, self.index.start + offsets::TOTAL_LEN)
302    }
303
304    /// Read the identification field.
305    #[inline]
306    pub fn id(&self, buf: &[u8]) -> Result<u16, FieldError> {
307        u16::read(buf, self.index.start + offsets::ID)
308    }
309
310    /// Read the flags as a structured type.
311    #[inline]
312    pub fn flags(&self, buf: &[u8]) -> Result<Ipv4Flags, FieldError> {
313        let b = u8::read(buf, self.index.start + offsets::FLAGS_FRAG)?;
314        Ok(Ipv4Flags::from_byte(b))
315    }
316
317    /// Read the raw flags byte (upper 3 bits of flags/frag field).
318    #[inline]
319    pub fn flags_raw(&self, buf: &[u8]) -> Result<u8, FieldError> {
320        let b = u8::read(buf, self.index.start + offsets::FLAGS_FRAG)?;
321        Ok((b >> 5) & 0x07)
322    }
323
324    /// Read the fragment offset (in 8-byte units).
325    #[inline]
326    pub fn frag_offset(&self, buf: &[u8]) -> Result<u16, FieldError> {
327        let val = u16::read(buf, self.index.start + offsets::FLAGS_FRAG)?;
328        Ok(val & 0x1FFF)
329    }
330
331    /// Read the fragment offset in bytes.
332    #[inline]
333    pub fn frag_offset_bytes(&self, buf: &[u8]) -> Result<u32, FieldError> {
334        Ok((self.frag_offset(buf)? as u32) * 8)
335    }
336
337    /// Read the Time to Live field.
338    #[inline]
339    pub fn ttl(&self, buf: &[u8]) -> Result<u8, FieldError> {
340        u8::read(buf, self.index.start + offsets::TTL)
341    }
342
343    /// Read the protocol field.
344    #[inline]
345    pub fn protocol(&self, buf: &[u8]) -> Result<u8, FieldError> {
346        u8::read(buf, self.index.start + offsets::PROTOCOL)
347    }
348
349    /// Read the header checksum.
350    #[inline]
351    pub fn checksum(&self, buf: &[u8]) -> Result<u16, FieldError> {
352        u16::read(buf, self.index.start + offsets::CHECKSUM)
353    }
354
355    /// Read the source IP address.
356    #[inline]
357    pub fn src(&self, buf: &[u8]) -> Result<Ipv4Addr, FieldError> {
358        Ipv4Addr::read(buf, self.index.start + offsets::SRC)
359    }
360
361    /// Read the destination IP address.
362    #[inline]
363    pub fn dst(&self, buf: &[u8]) -> Result<Ipv4Addr, FieldError> {
364        Ipv4Addr::read(buf, self.index.start + offsets::DST)
365    }
366
367    /// Read the options bytes (if any).
368    pub fn options_bytes<'a>(&self, buf: &'a [u8]) -> Result<&'a [u8], FieldError> {
369        let header_len = self.calculate_header_len(buf);
370        let opts_start = self.index.start + IPV4_MIN_HEADER_LEN;
371        let opts_end = self.index.start + header_len;
372
373        if buf.len() < opts_end {
374            return Err(FieldError::BufferTooShort {
375                offset: opts_start,
376                need: header_len - IPV4_MIN_HEADER_LEN,
377                have: buf.len().saturating_sub(opts_start),
378            });
379        }
380
381        Ok(&buf[opts_start..opts_end])
382    }
383
384    /// Parse and return the options.
385    pub fn options(&self, buf: &[u8]) -> Result<Ipv4Options, FieldError> {
386        let opts_bytes = self.options_bytes(buf)?;
387        parse_options(opts_bytes)
388    }
389
390    // ========== Field Writers ==========
391
392    /// Set the version field.
393    #[inline]
394    pub fn set_version(&self, buf: &mut [u8], version: u8) -> Result<(), FieldError> {
395        let offset = self.index.start + offsets::VERSION_IHL;
396        let current = u8::read(buf, offset)?;
397        let new_val = (current & 0x0F) | ((version & 0x0F) << 4);
398        new_val.write(buf, offset)
399    }
400
401    /// Set the IHL field.
402    #[inline]
403    pub fn set_ihl(&self, buf: &mut [u8], ihl: u8) -> Result<(), FieldError> {
404        let offset = self.index.start + offsets::VERSION_IHL;
405        let current = u8::read(buf, offset)?;
406        let new_val = (current & 0xF0) | (ihl & 0x0F);
407        new_val.write(buf, offset)
408    }
409
410    /// Set the TOS field.
411    #[inline]
412    pub fn set_tos(&self, buf: &mut [u8], tos: u8) -> Result<(), FieldError> {
413        tos.write(buf, self.index.start + offsets::TOS)
414    }
415
416    /// Set the DSCP field.
417    #[inline]
418    pub fn set_dscp(&self, buf: &mut [u8], dscp: u8) -> Result<(), FieldError> {
419        let offset = self.index.start + offsets::TOS;
420        let current = u8::read(buf, offset)?;
421        let new_val = (current & 0x03) | ((dscp & 0x3F) << 2);
422        new_val.write(buf, offset)
423    }
424
425    /// Set the ECN field.
426    #[inline]
427    pub fn set_ecn(&self, buf: &mut [u8], ecn: u8) -> Result<(), FieldError> {
428        let offset = self.index.start + offsets::TOS;
429        let current = u8::read(buf, offset)?;
430        let new_val = (current & 0xFC) | (ecn & 0x03);
431        new_val.write(buf, offset)
432    }
433
434    /// Set the total length field.
435    #[inline]
436    pub fn set_total_len(&self, buf: &mut [u8], len: u16) -> Result<(), FieldError> {
437        len.write(buf, self.index.start + offsets::TOTAL_LEN)
438    }
439
440    /// Set the identification field.
441    #[inline]
442    pub fn set_id(&self, buf: &mut [u8], id: u16) -> Result<(), FieldError> {
443        id.write(buf, self.index.start + offsets::ID)
444    }
445
446    /// Set the flags field.
447    #[inline]
448    pub fn set_flags(&self, buf: &mut [u8], flags: Ipv4Flags) -> Result<(), FieldError> {
449        let offset = self.index.start + offsets::FLAGS_FRAG;
450        let current = u8::read(buf, offset)?;
451        let new_val = (current & 0x1F) | flags.to_byte();
452        new_val.write(buf, offset)
453    }
454
455    /// Set the fragment offset (in 8-byte units).
456    #[inline]
457    pub fn set_frag_offset(&self, buf: &mut [u8], offset_val: u16) -> Result<(), FieldError> {
458        let offset = self.index.start + offsets::FLAGS_FRAG;
459        let current = u16::read(buf, offset)?;
460        let new_val = (current & 0xE000) | (offset_val & 0x1FFF);
461        new_val.write(buf, offset)
462    }
463
464    /// Set the TTL field.
465    #[inline]
466    pub fn set_ttl(&self, buf: &mut [u8], ttl: u8) -> Result<(), FieldError> {
467        ttl.write(buf, self.index.start + offsets::TTL)
468    }
469
470    /// Set the protocol field.
471    #[inline]
472    pub fn set_protocol(&self, buf: &mut [u8], proto: u8) -> Result<(), FieldError> {
473        proto.write(buf, self.index.start + offsets::PROTOCOL)
474    }
475
476    /// Set the checksum field.
477    #[inline]
478    pub fn set_checksum(&self, buf: &mut [u8], checksum: u16) -> Result<(), FieldError> {
479        checksum.write(buf, self.index.start + offsets::CHECKSUM)
480    }
481
482    /// Set the source IP address.
483    #[inline]
484    pub fn set_src(&self, buf: &mut [u8], src: Ipv4Addr) -> Result<(), FieldError> {
485        src.write(buf, self.index.start + offsets::SRC)
486    }
487
488    /// Set the destination IP address.
489    #[inline]
490    pub fn set_dst(&self, buf: &mut [u8], dst: Ipv4Addr) -> Result<(), FieldError> {
491        dst.write(buf, self.index.start + offsets::DST)
492    }
493
494    /// Compute and set the header checksum.
495    pub fn compute_checksum(&self, buf: &mut [u8]) -> Result<u16, FieldError> {
496        // Zero out existing checksum first
497        self.set_checksum(buf, 0)?;
498
499        // Calculate header length
500        let header_len = self.calculate_header_len(buf);
501        let header_end = self.index.start + header_len;
502
503        if buf.len() < header_end {
504            return Err(FieldError::BufferTooShort {
505                offset: self.index.start,
506                need: header_len,
507                have: buf.len().saturating_sub(self.index.start),
508            });
509        }
510
511        // Compute checksum over header bytes
512        let checksum = ipv4_checksum(&buf[self.index.start..header_end]);
513
514        // Write checksum
515        self.set_checksum(buf, checksum)?;
516
517        Ok(checksum)
518    }
519
520    /// Verify the header checksum.
521    pub fn verify_checksum(&self, buf: &[u8]) -> Result<bool, FieldError> {
522        let header_len = self.calculate_header_len(buf);
523        let header_end = self.index.start + header_len;
524
525        if buf.len() < header_end {
526            return Err(FieldError::BufferTooShort {
527                offset: self.index.start,
528                need: header_len,
529                have: buf.len().saturating_sub(self.index.start),
530            });
531        }
532
533        let checksum = ipv4_checksum(&buf[self.index.start..header_end]);
534        Ok(checksum == 0 || checksum == 0xFFFF)
535    }
536
537    // ========== Dynamic Field Access ==========
538
539    /// Get a field value by name.
540    pub fn get_field(&self, buf: &[u8], name: &str) -> Option<Result<FieldValue, FieldError>> {
541        match name {
542            "version" => Some(self.version(buf).map(FieldValue::U8)),
543            "ihl" => Some(self.ihl(buf).map(FieldValue::U8)),
544            "tos" => Some(self.tos(buf).map(FieldValue::U8)),
545            "dscp" => Some(self.dscp(buf).map(FieldValue::U8)),
546            "ecn" => Some(self.ecn(buf).map(FieldValue::U8)),
547            "len" | "total_len" => Some(self.total_len(buf).map(FieldValue::U16)),
548            "id" => Some(self.id(buf).map(FieldValue::U16)),
549            "flags" => Some(self.flags_raw(buf).map(FieldValue::U8)),
550            "frag" | "frag_offset" => Some(self.frag_offset(buf).map(FieldValue::U16)),
551            "ttl" => Some(self.ttl(buf).map(FieldValue::U8)),
552            "proto" | "protocol" => Some(self.protocol(buf).map(FieldValue::U8)),
553            "chksum" | "checksum" => Some(self.checksum(buf).map(FieldValue::U16)),
554            "src" => Some(self.src(buf).map(FieldValue::Ipv4)),
555            "dst" => Some(self.dst(buf).map(FieldValue::Ipv4)),
556            _ => None,
557        }
558    }
559
560    /// Set a field value by name.
561    pub fn set_field(
562        &self,
563        buf: &mut [u8],
564        name: &str,
565        value: FieldValue,
566    ) -> Option<Result<(), FieldError>> {
567        match (name, value) {
568            ("version", FieldValue::U8(v)) => Some(self.set_version(buf, v)),
569            ("ihl", FieldValue::U8(v)) => Some(self.set_ihl(buf, v)),
570            ("tos", FieldValue::U8(v)) => Some(self.set_tos(buf, v)),
571            ("dscp", FieldValue::U8(v)) => Some(self.set_dscp(buf, v)),
572            ("ecn", FieldValue::U8(v)) => Some(self.set_ecn(buf, v)),
573            ("len" | "total_len", FieldValue::U16(v)) => Some(self.set_total_len(buf, v)),
574            ("id", FieldValue::U16(v)) => Some(self.set_id(buf, v)),
575            ("frag" | "frag_offset", FieldValue::U16(v)) => Some(self.set_frag_offset(buf, v)),
576            ("ttl", FieldValue::U8(v)) => Some(self.set_ttl(buf, v)),
577            ("proto" | "protocol", FieldValue::U8(v)) => Some(self.set_protocol(buf, v)),
578            ("chksum" | "checksum", FieldValue::U16(v)) => Some(self.set_checksum(buf, v)),
579            ("src", FieldValue::Ipv4(v)) => Some(self.set_src(buf, v)),
580            ("dst", FieldValue::Ipv4(v)) => Some(self.set_dst(buf, v)),
581            _ => None,
582        }
583    }
584
585    /// Get list of field names.
586    pub fn field_names() -> &'static [&'static str] {
587        &[
588            "version", "ihl", "tos", "dscp", "ecn", "len", "id", "flags", "frag", "ttl", "proto",
589            "chksum", "src", "dst",
590        ]
591    }
592
593    // ========== Utility Methods ==========
594
595    /// Check if this is a fragment.
596    pub fn is_fragment(&self, buf: &[u8]) -> bool {
597        let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
598        let offset = self.frag_offset(buf).unwrap_or(0);
599        flags.is_fragment(offset)
600    }
601
602    /// Check if this is the first fragment.
603    pub fn is_first_fragment(&self, buf: &[u8]) -> bool {
604        let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
605        let offset = self.frag_offset(buf).unwrap_or(0);
606        flags.mf && offset == 0
607    }
608
609    /// Check if this is the last fragment.
610    pub fn is_last_fragment(&self, buf: &[u8]) -> bool {
611        let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
612        let offset = self.frag_offset(buf).unwrap_or(0);
613        !flags.mf && offset > 0
614    }
615
616    /// Check if the Don't Fragment flag is set.
617    pub fn is_dont_fragment(&self, buf: &[u8]) -> bool {
618        self.flags(buf).map(|f| f.df).unwrap_or(false)
619    }
620
621    /// Get the payload length (total_len - header_len).
622    pub fn payload_len(&self, buf: &[u8]) -> Result<usize, FieldError> {
623        let total = self.total_len(buf)? as usize;
624        let header = self.calculate_header_len(buf);
625        Ok(total.saturating_sub(header))
626    }
627
628    /// Get a slice of the payload data.
629    pub fn payload<'a>(&self, buf: &'a [u8]) -> Result<&'a [u8], FieldError> {
630        let header_len = self.calculate_header_len(buf);
631        let total_len = self.total_len(buf)? as usize;
632        let payload_start = self.index.start + header_len;
633        let payload_end = (self.index.start + total_len).min(buf.len());
634
635        if payload_start > buf.len() {
636            return Err(FieldError::BufferTooShort {
637                offset: payload_start,
638                need: 0,
639                have: buf.len().saturating_sub(payload_start),
640            });
641        }
642
643        Ok(&buf[payload_start..payload_end])
644    }
645
646    /// Get the header bytes.
647    #[inline]
648    pub fn header_bytes<'a>(&self, buf: &'a [u8]) -> &'a [u8] {
649        let header_len = self.calculate_header_len(buf);
650        let end = (self.index.start + header_len).min(buf.len());
651        &buf[self.index.start..end]
652    }
653
654    /// Copy the header bytes.
655    #[inline]
656    pub fn header_copy(&self, buf: &[u8]) -> Vec<u8> {
657        self.header_bytes(buf).to_vec()
658    }
659
660    /// Get the protocol name.
661    pub fn protocol_name(&self, buf: &[u8]) -> &'static str {
662        self.protocol(buf)
663            .map(protocol::to_name)
664            .unwrap_or("Unknown")
665    }
666
667    /// Determine the next layer kind based on protocol.
668    pub fn next_layer(&self, buf: &[u8]) -> Option<LayerKind> {
669        self.protocol(buf).ok().and_then(|proto| match proto {
670            protocol::TCP => Some(LayerKind::Tcp),
671            protocol::UDP => Some(LayerKind::Udp),
672            protocol::ICMP => Some(LayerKind::Icmp),
673            protocol::IPV4 => Some(LayerKind::Ipv4),
674            protocol::IPV6 => Some(LayerKind::Ipv6),
675            _ => None,
676        })
677    }
678
679    /// Compute hash for packet matching (like Scapy's hashret).
680    pub fn hashret(&self, buf: &[u8]) -> Vec<u8> {
681        let proto = self.protocol(buf).unwrap_or(0);
682
683        // For ICMP error messages, delegate to inner packet
684        if proto == protocol::ICMP {
685            // Check if it's an ICMP error (type 3, 4, 5, 11, 12)
686            let header_len = self.calculate_header_len(buf);
687            let icmp_start = self.index.start + header_len;
688            if buf.len() > icmp_start {
689                let icmp_type = buf[icmp_start];
690                if matches!(icmp_type, 3 | 4 | 5 | 11 | 12) {
691                    // Return hash of the embedded packet
692                    // For now, just use src/dst XOR + proto
693                }
694            }
695        }
696
697        // For IP-in-IP tunnels, delegate to inner packet
698        if matches!(proto, protocol::IPV4 | protocol::IPV6) {
699            // Could recurse here
700        }
701
702        // Standard hash: XOR of src and dst + protocol
703        let src = self.src(buf).map(|ip| ip.octets()).unwrap_or([0; 4]);
704        let dst = self.dst(buf).map(|ip| ip.octets()).unwrap_or([0; 4]);
705
706        let mut result = Vec::with_capacity(5);
707        for i in 0..4 {
708            result.push(src[i] ^ dst[i]);
709        }
710        result.push(proto);
711        result
712    }
713
714    /// Check if this packet answers another (for sr() matching).
715    pub fn answers(&self, buf: &[u8], other: &Ipv4Layer, other_buf: &[u8]) -> bool {
716        // Protocol must match
717        let self_proto = self.protocol(buf).unwrap_or(0);
718        let other_proto = other.protocol(other_buf).unwrap_or(0);
719
720        // Handle ICMP errors
721        if self_proto == protocol::ICMP {
722            let header_len = self.calculate_header_len(buf);
723            let icmp_start = self.index.start + header_len;
724            if buf.len() > icmp_start {
725                let icmp_type = buf[icmp_start];
726                if matches!(icmp_type, 3 | 4 | 5 | 11 | 12) {
727                    // ICMP error - check embedded packet
728                    // The embedded packet should match the original
729                    return true; // Simplified - real impl would check embedded
730                }
731            }
732        }
733
734        // Handle IP-in-IP tunnels
735        if matches!(other_proto, protocol::IPV4 | protocol::IPV6) {
736            // Delegate to inner packet
737        }
738
739        if self_proto != other_proto {
740            return false;
741        }
742
743        // Check addresses
744        let self_src = self.src(buf).ok();
745        let self_dst = self.dst(buf).ok();
746        let other_src = other.src(other_buf).ok();
747        let other_dst = other.dst(other_buf).ok();
748
749        // Response src should match request dst
750        if self_src != other_dst {
751            return false;
752        }
753
754        // Response dst should match request src
755        if self_dst != other_src {
756            return false;
757        }
758
759        true
760    }
761
762    /// Extract padding from the packet.
763    /// Returns (payload, padding) tuple.
764    pub fn extract_padding<'a>(&self, buf: &'a [u8]) -> (&'a [u8], &'a [u8]) {
765        let header_len = self.calculate_header_len(buf);
766        let total_len = self.total_len(buf).unwrap_or(0) as usize;
767
768        let payload_start = self.index.start + header_len;
769        let payload_end = (self.index.start + total_len).min(buf.len());
770
771        if payload_start >= buf.len() {
772            return (&[], &buf[buf.len()..]);
773        }
774
775        let payload = &buf[payload_start..payload_end];
776        let padding = &buf[payload_end..];
777
778        (payload, padding)
779    }
780
781    /// Get routing information for this packet.
782    pub fn route(&self, buf: &[u8]) -> Ipv4Route {
783        use crate::layer::ipv4::routing::get_route;
784        let dst = self.dst(buf).unwrap_or(Ipv4Addr::UNSPECIFIED);
785        get_route(dst)
786    }
787
788    /// Estimate the original TTL.
789    pub fn original_ttl(&self, buf: &[u8]) -> u8 {
790        let current = self.ttl(buf).unwrap_or(0);
791        super::ttl::estimate_original(current)
792    }
793
794    /// Estimate the number of hops.
795    pub fn hops(&self, buf: &[u8]) -> u8 {
796        let current = self.ttl(buf).unwrap_or(0);
797        super::ttl::estimate_hops(current)
798    }
799}
800
801impl Layer for Ipv4Layer {
802    fn kind(&self) -> LayerKind {
803        LayerKind::Ipv4
804    }
805
806    fn summary(&self, buf: &[u8]) -> String {
807        let src = self
808            .src(buf)
809            .map(|ip| ip.to_string())
810            .unwrap_or_else(|_| "?".into());
811        let dst = self
812            .dst(buf)
813            .map(|ip| ip.to_string())
814            .unwrap_or_else(|_| "?".into());
815        let proto = self.protocol_name(buf);
816        let ttl = self.ttl(buf).unwrap_or(0);
817
818        let mut s = format!("IP {} > {} {} ttl={}", src, dst, proto, ttl);
819
820        // Add fragment info if fragmented
821        if self.is_fragment(buf) {
822            let flags = self.flags(buf).unwrap_or(Ipv4Flags::NONE);
823            let offset = self.frag_offset(buf).unwrap_or(0);
824            s.push_str(&format!(
825                " frag:{}+{}",
826                offset,
827                if flags.mf { "MF" } else { "" }
828            ));
829        }
830
831        s
832    }
833
834    fn header_len(&self, buf: &[u8]) -> usize {
835        self.calculate_header_len(buf)
836    }
837
838    fn hashret(&self, buf: &[u8]) -> Vec<u8> {
839        self.hashret(buf)
840    }
841
842    fn answers(&self, buf: &[u8], other: &Self, other_buf: &[u8]) -> bool {
843        self.answers(buf, other, other_buf)
844    }
845
846    fn extract_padding<'a>(&self, buf: &'a [u8]) -> (&'a [u8], &'a [u8]) {
847        self.extract_padding(buf)
848    }
849}
850
#[cfg(test)]
mod tests {
    //! Unit tests for the IPv4 header layer: field accessors, checksum,
    //! flags, fragmentation detection, validation and padding extraction.

    use super::*;

    /// A minimal 20-byte IPv4 header (IHL=5, no options) with a zeroed
    /// checksum field.
    fn sample_ipv4_header() -> Vec<u8> {
        vec![
            0x45, // Version=4, IHL=5
            0x00, // TOS=0
            0x00, 0x3c, // Total length = 60
            0x1c, 0x46, // ID = 0x1c46
            0x40, 0x00, // Flags=DF, Frag offset=0
            0x40, // TTL=64
            0x06, // Protocol=TCP
            0x00, 0x00, // Checksum (to be computed)
            0xc0, 0xa8, 0x01, 0x01, // Src: 192.168.1.1
            0xc0, 0xa8, 0x01, 0x02, // Dst: 192.168.1.2
        ]
    }

    #[test]
    fn test_field_readers() {
        let buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        assert_eq!(layer.version(&buf).unwrap(), 4);
        assert_eq!(layer.ihl(&buf).unwrap(), 5);
        assert_eq!(layer.tos(&buf).unwrap(), 0);
        assert_eq!(layer.total_len(&buf).unwrap(), 60);
        assert_eq!(layer.id(&buf).unwrap(), 0x1c46);
        assert!(layer.flags(&buf).unwrap().df);
        assert_eq!(layer.frag_offset(&buf).unwrap(), 0);
        assert_eq!(layer.ttl(&buf).unwrap(), 64);
        assert_eq!(layer.protocol(&buf).unwrap(), protocol::TCP);
        assert_eq!(layer.src(&buf).unwrap(), Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(layer.dst(&buf).unwrap(), Ipv4Addr::new(192, 168, 1, 2));
    }

    #[test]
    fn test_field_writers() {
        let mut buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        layer.set_ttl(&mut buf, 128).unwrap();
        assert_eq!(layer.ttl(&buf).unwrap(), 128);

        layer.set_src(&mut buf, Ipv4Addr::new(10, 0, 0, 1)).unwrap();
        assert_eq!(layer.src(&buf).unwrap(), Ipv4Addr::new(10, 0, 0, 1));

        layer.set_flags(&mut buf, Ipv4Flags::MF).unwrap();
        assert!(layer.flags(&buf).unwrap().mf);
        assert!(!layer.flags(&buf).unwrap().df);
    }

    #[test]
    fn test_checksum() {
        let mut buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        // RFC 1071: the one's-complement sum of the ten header words (checksum
        // field zeroed) folds to 0x64dd, so the header checksum is !0x64dd.
        let checksum = layer.compute_checksum(&mut buf).unwrap();
        assert_eq!(checksum, 0x9b22);
        // The computed value is stored in network byte order.
        assert_eq!(
            &buf[offsets::CHECKSUM..offsets::CHECKSUM + 2],
            &[0x9b, 0x22][..]
        );

        // Verify it
        assert!(layer.verify_checksum(&buf).unwrap());

        // Corrupt a byte that does not influence header parsing (TTL). Flipping
        // a bit in byte 0 would change the IHL nibble and test length handling
        // instead of checksum detection.
        buf[offsets::TTL] ^= 0x01;
        assert!(!layer.verify_checksum(&buf).unwrap());
    }

    #[test]
    fn test_flags() {
        let flags = Ipv4Flags::from_byte(0x40); // DF
        assert!(flags.df);
        assert!(!flags.mf);
        assert!(!flags.reserved);
        assert_eq!(flags.to_byte(), 0x40);

        let flags = Ipv4Flags::from_byte(0x20); // MF
        assert!(!flags.df);
        assert!(flags.mf);
        assert!(!flags.reserved);

        let flags = Ipv4Flags::from_byte(0xE0); // All set
        assert!(flags.df);
        assert!(flags.mf);
        assert!(flags.reserved);
    }

    #[test]
    fn test_is_fragment() {
        let mut buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        // DF set, not a fragment
        assert!(!layer.is_fragment(&buf));

        // Set MF
        layer.set_flags(&mut buf, Ipv4Flags::MF).unwrap();
        assert!(layer.is_fragment(&buf));

        // Clear MF, set offset
        layer.set_flags(&mut buf, Ipv4Flags::NONE).unwrap();
        layer.set_frag_offset(&mut buf, 100).unwrap();
        assert!(layer.is_fragment(&buf));
    }

    #[test]
    fn test_dynamic_field_access() {
        let buf = sample_ipv4_header();
        let layer = Ipv4Layer::at_offset(0);

        let ttl = layer.get_field(&buf, "ttl").unwrap().unwrap();
        assert_eq!(ttl.as_u8(), Some(64));

        let src = layer.get_field(&buf, "src").unwrap().unwrap();
        assert_eq!(src.as_ipv4(), Some(Ipv4Addr::new(192, 168, 1, 1)));
    }

    #[test]
    fn test_validate() {
        let buf = sample_ipv4_header();
        assert!(Ipv4Layer::validate(&buf, 0).is_ok());

        // Too short
        let short = vec![0x45, 0x00];
        assert!(Ipv4Layer::validate(&short, 0).is_err());

        // Wrong version
        let mut wrong_version = sample_ipv4_header();
        wrong_version[0] = 0x65; // Version 6
        assert!(Ipv4Layer::validate(&wrong_version, 0).is_err());

        // Invalid IHL
        let mut bad_ihl = sample_ipv4_header();
        bad_ihl[0] = 0x43; // IHL=3 (< minimum 5)
        assert!(Ipv4Layer::validate(&bad_ihl, 0).is_err());
    }

    #[test]
    fn test_extract_padding() {
        let mut buf = sample_ipv4_header();
        // Set total_len to 30 (header=20 + payload=10), big-endian on the wire.
        buf[offsets::TOTAL_LEN..offsets::TOTAL_LEN + 2].copy_from_slice(&30u16.to_be_bytes());

        // Add some payload and padding
        let payload_bytes: [u8; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let padding_bytes: [u8; 4] = [0, 0, 0, 0];
        buf.extend_from_slice(&payload_bytes);
        buf.extend_from_slice(&padding_bytes);

        let layer = Ipv4Layer::at_offset(0);
        let (payload, padding) = layer.extract_padding(&buf);

        // Check contents, not just lengths, so a misplaced split is caught.
        assert_eq!(payload, &payload_bytes[..]);
        assert_eq!(padding, &padding_bytes[..]);
    }
}