Skip to main content

stackforge_core/layer/
stack.rs

1//! Layer stacking support for packet composition.
2//!
3//! This module provides the `LayerStack` type for composing packets using
4//! the `/` operator syntax: `Ether() / IP() / TCP()`.
5//!
6//! # Example
7//!
8//! ```rust
9//! use stackforge_core::layer::{
10//!     EthernetBuilder, LayerStack, LayerStackEntry, LayerKind,
11//!     ipv4::Ipv4Builder,
12//!     tcp::TcpBuilder,
13//! };
14//! use stackforge_core::layer::field::MacAddress;
15//! use std::net::Ipv4Addr;
16//!
17//! // Build a TCP SYN packet
18//! let pkt = LayerStack::new()
19//!     .push(LayerStackEntry::Ethernet(
20//!         EthernetBuilder::new()
21//!             .dst(MacAddress::BROADCAST)
22//!             .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]))
23//!     ))
24//!     .push(LayerStackEntry::Ipv4(
25//!         Ipv4Builder::new()
26//!             .src(Ipv4Addr::new(192, 168, 1, 1))
27//!             .dst(Ipv4Addr::new(192, 168, 1, 2))
28//!             .ttl(64)
29//!     ))
30//!     .push(LayerStackEntry::Tcp(
31//!         TcpBuilder::new()
32//!             .src_port(12345)
33//!             .dst_port(80)
34//!             .syn()
35//!     ))
36//!     .build();
37//! ```
38
39use super::bindings::apply_binding;
40use super::dns::builder::DnsBuilder;
41use super::ethernet::{ETHERNET_HEADER_LEN, EthernetBuilder};
42use super::ftp::builder::FtpBuilder;
43use super::http2::builder::Http2FrameBuilder;
44use super::icmp::builder::IcmpBuilder;
45use super::icmpv6::builder::Icmpv6Builder;
46use super::imap::builder::ImapBuilder;
47use super::ipv4::builder::Ipv4Builder;
48use super::ipv6::builder::Ipv6Builder;
49use super::l2tp::builder::L2tpBuilder;
50use super::modbus::builder::ModbusBuilder;
51use super::mqtt::builder::MqttBuilder;
52use super::mqttsn::builder::MqttSnBuilder;
53use super::pop3::builder::Pop3Builder;
54use super::smtp::builder::SmtpBuilder;
55use super::ssh::builder::SshBuilder;
56use super::tcp::builder::TcpBuilder;
57use super::tftp::builder::TftpBuilder;
58use super::tls::builder::TlsRecordBuilder;
59use super::udp::builder::UdpBuilder;
60use super::zwave::builder::ZWaveBuilder;
61use super::{ArpBuilder, LayerKind};
62use crate::Packet;
63use crate::layer::arp::ARP_HEADER_LEN;
64use crate::layer::dns::DNS_HEADER_LEN;
65use crate::layer::ftp::FTP_MIN_HEADER_LEN;
66use crate::layer::icmp::ICMP_MIN_HEADER_LEN;
67use crate::layer::icmpv6::ICMPV6_MIN_HEADER_LEN;
68use crate::layer::imap::IMAP_MIN_HEADER_LEN;
69use crate::layer::ipv4::IPV4_MIN_HEADER_LEN;
70use crate::layer::ipv6::IPV6_HEADER_LEN;
71use crate::layer::l2tp::L2TP_MIN_HEADER_LEN;
72use crate::layer::modbus::MODBUS_MIN_HEADER_LEN;
73use crate::layer::mqtt::MQTT_MIN_HEADER_LEN;
74use crate::layer::mqttsn::MQTTSN_MIN_HEADER_LEN;
75use crate::layer::pop3::POP3_MIN_HEADER_LEN;
76use crate::layer::smtp::SMTP_MIN_HEADER_LEN;
77use crate::layer::tcp::TCP_MIN_HEADER_LEN;
78use crate::layer::tftp::TFTP_MIN_HEADER_LEN;
79use crate::layer::udp::UDP_HEADER_LEN;
80use crate::layer::zwave::ZWAVE_MIN_HEADER_LEN;
81
82/// An entry in a layer stack, representing a protocol layer builder.
83#[derive(Debug, Clone)]
84pub enum LayerStackEntry {
85    /// Ethernet II layer
86    Ethernet(EthernetBuilder),
87    /// ARP layer
88    Arp(ArpBuilder),
89    /// IPv4 layer
90    Ipv4(Ipv4Builder),
91    /// IPv6 layer
92    Ipv6(Ipv6Builder),
93    /// TCP layer
94    Tcp(TcpBuilder),
95    /// UDP layer
96    Udp(UdpBuilder),
97    /// ICMP layer
98    Icmp(IcmpBuilder),
99    /// `ICMPv6` layer
100    Icmpv6(Icmpv6Builder),
101    /// SSH layer
102    Ssh(SshBuilder),
103    /// TLS record layer
104    Tls(TlsRecordBuilder),
105    /// DNS layer
106    Dns(DnsBuilder),
107    /// HTTP/2 frame layer
108    Http2(Http2FrameBuilder),
109    /// L2TP layer
110    L2tp(L2tpBuilder),
111    /// MQTT layer
112    Mqtt(MqttBuilder),
113    /// MQTT-SN layer
114    MqttSn(MqttSnBuilder),
115    /// Modbus layer
116    Modbus(ModbusBuilder),
117    /// Z-Wave layer
118    ZWave(ZWaveBuilder),
119    /// FTP layer
120    Ftp(FtpBuilder),
121    /// TFTP layer
122    Tftp(TftpBuilder),
123    /// SMTP layer
124    Smtp(SmtpBuilder),
125    /// POP3 layer
126    Pop3(Pop3Builder),
127    /// IMAP layer
128    Imap(ImapBuilder),
129    /// Raw bytes payload
130    Raw(Vec<u8>),
131}
132
133impl LayerStackEntry {
134    /// Get the `LayerKind` for this entry.
135    #[must_use]
136    pub fn kind(&self) -> LayerKind {
137        match self {
138            Self::Ethernet(_) => LayerKind::Ethernet,
139            Self::Arp(_) => LayerKind::Arp,
140            Self::Ipv4(_) => LayerKind::Ipv4,
141            Self::Ipv6(_) => LayerKind::Ipv6,
142            Self::Tcp(_) => LayerKind::Tcp,
143            Self::Udp(_) => LayerKind::Udp,
144            Self::Icmp(_) => LayerKind::Icmp,
145            Self::Icmpv6(_) => LayerKind::Icmpv6,
146            Self::Ssh(_) => LayerKind::Ssh,
147            Self::Tls(_) => LayerKind::Tls,
148            Self::Dns(_) => LayerKind::Dns,
149            Self::Http2(_) => LayerKind::Http2,
150            Self::L2tp(_) => LayerKind::L2tp,
151            Self::Mqtt(_) => LayerKind::Mqtt,
152            Self::MqttSn(_) => LayerKind::MqttSn,
153            Self::Modbus(_) => LayerKind::Modbus,
154            Self::ZWave(_) => LayerKind::ZWave,
155            Self::Ftp(_) => LayerKind::Ftp,
156            Self::Tftp(_) => LayerKind::Tftp,
157            Self::Smtp(_) => LayerKind::Smtp,
158            Self::Pop3(_) => LayerKind::Pop3,
159            Self::Imap(_) => LayerKind::Imap,
160            Self::Raw(_) => LayerKind::Raw,
161        }
162    }
163
164    /// Build this layer into bytes, without applying bindings.
165    #[must_use]
166    pub fn build_bytes(&self) -> Vec<u8> {
167        match self {
168            Self::Ethernet(b) => b.build(),
169            Self::Arp(b) => b.build(),
170            Self::Ipv4(b) => b.build(),
171            Self::Ipv6(b) => b.build(),
172            Self::Tcp(b) => b.build(),
173            Self::Udp(b) => b.build(),
174            Self::Icmp(b) => b.build(),
175            Self::Icmpv6(b) => b.build(),
176            Self::Ssh(b) => b.build(),
177            Self::Tls(b) => b.build(),
178            Self::Dns(b) => b.build(),
179            Self::Http2(b) => b.build(),
180            Self::L2tp(b) => b.build(),
181            Self::Mqtt(b) => b.build(),
182            Self::MqttSn(b) => b.build(),
183            Self::Modbus(b) => b.build(),
184            Self::ZWave(b) => b.build(),
185            Self::Ftp(b) => b.build(),
186            Self::Tftp(b) => b.build(),
187            Self::Smtp(b) => b.build(),
188            Self::Pop3(b) => b.build(),
189            Self::Imap(b) => b.build(),
190            Self::Raw(data) => data.clone(),
191        }
192    }
193
194    /// Get the header size for this layer.
195    #[must_use]
196    pub fn header_size(&self) -> usize {
197        match self {
198            Self::Ethernet(_) => ETHERNET_HEADER_LEN,
199            Self::Arp(_) => ARP_HEADER_LEN,
200            Self::Ipv4(b) => b.header_size(),
201            Self::Ipv6(b) => b.header_size(),
202            Self::Tcp(b) => b.header_size(),
203            Self::Udp(b) => b.header_size(),
204            Self::Icmp(b) => b.header_size(),
205            Self::Icmpv6(b) => b.header_size(),
206            Self::Ssh(b) => b.header_size(),
207            Self::Tls(b) => b.record_size(),
208            Self::Dns(b) => b.header_size(),
209            Self::Http2(b) => b.build().len(), // frame size is dynamic
210            Self::L2tp(b) => b.header_size(),
211            Self::Mqtt(b) => b.build().len(),
212            Self::MqttSn(b) => b.build().len(),
213            Self::Modbus(b) => b.build().len(),
214            Self::ZWave(b) => b.build().len(),
215            Self::Ftp(b) => b.build().len(),
216            Self::Tftp(b) => b.build().len(),
217            Self::Smtp(b) => b.build().len(),
218            Self::Pop3(b) => b.build().len(),
219            Self::Imap(b) => b.build().len(),
220            Self::Raw(data) => data.len(),
221        }
222    }
223
224    /// Get minimum header size for this layer type.
225    #[must_use]
226    pub fn min_header_size(&self) -> usize {
227        match self {
228            Self::Ethernet(_) => ETHERNET_HEADER_LEN,
229            Self::Arp(_) => ARP_HEADER_LEN,
230            Self::Ipv4(_) => IPV4_MIN_HEADER_LEN,
231            Self::Ipv6(_) => IPV6_HEADER_LEN,
232            Self::Tcp(_) => TCP_MIN_HEADER_LEN,
233            Self::Udp(_) => UDP_HEADER_LEN,
234            Self::Icmp(_) => ICMP_MIN_HEADER_LEN,
235            Self::Icmpv6(_) => ICMPV6_MIN_HEADER_LEN,
236            Self::Ssh(b) => b.header_size(),
237            Self::Tls(_) => 5, // TLS record header is 5 bytes
238            Self::Dns(_) => DNS_HEADER_LEN,
239            Self::Http2(_) => 9, // HTTP/2 frame header is 9 bytes
240            Self::L2tp(_) => L2TP_MIN_HEADER_LEN,
241            Self::Mqtt(_) => MQTT_MIN_HEADER_LEN,
242            Self::MqttSn(_) => MQTTSN_MIN_HEADER_LEN,
243            Self::Modbus(_) => MODBUS_MIN_HEADER_LEN,
244            Self::ZWave(_) => ZWAVE_MIN_HEADER_LEN,
245            Self::Ftp(_) => FTP_MIN_HEADER_LEN,
246            Self::Tftp(_) => TFTP_MIN_HEADER_LEN,
247            Self::Smtp(_) => SMTP_MIN_HEADER_LEN,
248            Self::Pop3(_) => POP3_MIN_HEADER_LEN,
249            Self::Imap(_) => IMAP_MIN_HEADER_LEN,
250            Self::Raw(data) => data.len(),
251        }
252    }
253}
254
255/// A stack of protocol layers that can be combined into a packet.
256///
257/// The stack maintains the order of layers and automatically applies
258/// bindings when building the final packet.
259#[derive(Debug, Clone, Default)]
260pub struct LayerStack {
261    layers: Vec<LayerStackEntry>,
262}
263
264impl LayerStack {
265    /// Create a new empty layer stack.
266    #[must_use]
267    pub fn new() -> Self {
268        Self { layers: Vec::new() }
269    }
270
271    /// Push a new layer onto the stack.
272    ///
273    /// Layers are stacked from bottom (first) to top (last).
274    #[must_use]
275    pub fn push(mut self, layer: LayerStackEntry) -> Self {
276        self.layers.push(layer);
277        self
278    }
279
280    /// Add a layer to the stack (mutable version).
281    pub fn add(&mut self, layer: LayerStackEntry) {
282        self.layers.push(layer);
283    }
284
285    /// Stack another `LayerStack` on top of this one.
286    ///
287    /// This is the implementation of the `/` operator.
288    #[must_use]
289    pub fn stack(mut self, other: LayerStack) -> Self {
290        self.layers.extend(other.layers);
291        self
292    }
293
294    /// Get the number of layers in the stack.
295    #[must_use]
296    pub fn len(&self) -> usize {
297        self.layers.len()
298    }
299
300    /// Check if the stack is empty.
301    #[must_use]
302    pub fn is_empty(&self) -> bool {
303        self.layers.is_empty()
304    }
305
306    /// Get the layers in the stack.
307    #[must_use]
308    pub fn layers(&self) -> &[LayerStackEntry] {
309        &self.layers
310    }
311
312    /// Build the stacked layers into raw bytes.
313    ///
314    /// This method:
315    /// 1. Builds each layer
316    /// 2. Applies bindings (e.g., sets Ethernet type when IP is stacked on top)
317    /// 3. Concatenates all layer bytes
318    /// 4. Recalculates checksums and lengths where applicable
319    #[must_use]
320    pub fn build(&self) -> Vec<u8> {
321        if self.layers.is_empty() {
322            return Vec::new();
323        }
324
325        // First pass: build all layers to determine total size
326        let mut layer_bytes: Vec<Vec<u8>> = self
327            .layers
328            .iter()
329            .map(LayerStackEntry::build_bytes)
330            .collect();
331
332        // Calculate total payload sizes for each layer
333        let total_len: usize = layer_bytes.iter().map(std::vec::Vec::len).sum();
334
335        // Second pass: apply bindings and fix length/checksum fields
336        for i in 0..self.layers.len().saturating_sub(1) {
337            let lower_kind = self.layers[i].kind();
338            let upper_kind = self.layers[i + 1].kind();
339
340            // Apply binding: set the appropriate field in the lower layer
341            if let Some((field_name, field_value)) = apply_binding(lower_kind, upper_kind) {
342                apply_field_to_bytes(&mut layer_bytes[i], lower_kind, field_name, field_value);
343            }
344        }
345
346        // Fix IP and TCP checksum/length if present
347        self.fix_ip_fields(&mut layer_bytes, total_len);
348        self.fix_ipv6_fields(&mut layer_bytes);
349        self.fix_tcp_fields(&mut layer_bytes);
350        self.fix_udp_fields(&mut layer_bytes);
351        self.fix_icmp_fields(&mut layer_bytes);
352        self.fix_icmpv6_fields(&mut layer_bytes);
353
354        // Concatenate all layers
355        layer_bytes.into_iter().flatten().collect()
356    }
357
358    /// Build the stack into a Packet.
359    #[must_use]
360    pub fn build_packet(&self) -> Packet {
361        let bytes = self.build();
362        let mut pkt = Packet::from_bytes(bytes);
363        let _ = pkt.parse();
364        pkt
365    }
366
367    /// Fix IP total length and checksum fields.
368    fn fix_ip_fields(&self, layer_bytes: &mut [Vec<u8>], _total_len: usize) {
369        for (i, layer) in self.layers.iter().enumerate() {
370            if let LayerStackEntry::Ipv4(_) = layer {
371                // Calculate payload size (everything after IP header)
372                let payload_size: usize = layer_bytes[i + 1..].iter().map(std::vec::Vec::len).sum();
373                let ip_header_len = layer_bytes[i].len();
374                let ip_total_len = ip_header_len + payload_size;
375
376                // Update total length field (bytes 2-3)
377                if layer_bytes[i].len() >= 4 {
378                    layer_bytes[i][2] = ((ip_total_len >> 8) & 0xFF) as u8;
379                    layer_bytes[i][3] = (ip_total_len & 0xFF) as u8;
380
381                    // Recalculate checksum
382                    // First zero out the checksum field
383                    layer_bytes[i][10] = 0;
384                    layer_bytes[i][11] = 0;
385
386                    // Calculate new checksum
387                    let checksum = crate::layer::ipv4::checksum::ipv4_checksum(
388                        &layer_bytes[i][..ip_header_len],
389                    );
390                    layer_bytes[i][10] = ((checksum >> 8) & 0xFF) as u8;
391                    layer_bytes[i][11] = (checksum & 0xFF) as u8;
392                }
393            }
394        }
395    }
396
397    /// Fix TCP checksum if present.
398    fn fix_tcp_fields(&self, layer_bytes: &mut [Vec<u8>]) {
399        // Find IP layer for pseudo-header
400        let mut ip_layer_idx = None;
401        let mut tcp_layer_idx = None;
402
403        for (i, layer) in self.layers.iter().enumerate() {
404            if let LayerStackEntry::Ipv4(_) = layer {
405                ip_layer_idx = Some(i);
406            }
407            if let LayerStackEntry::Tcp(_) = layer {
408                tcp_layer_idx = Some(i);
409            }
410        }
411
412        if let (Some(ip_idx), Some(tcp_idx)) = (ip_layer_idx, tcp_layer_idx) {
413            let ip_bytes = &layer_bytes[ip_idx];
414            if ip_bytes.len() >= 20 {
415                // Extract source and destination IPs
416                let mut src_bytes = [0u8; 4];
417                let mut dst_bytes = [0u8; 4];
418                src_bytes.copy_from_slice(&ip_bytes[12..16]);
419                dst_bytes.copy_from_slice(&ip_bytes[16..20]);
420                let src_ip = std::net::Ipv4Addr::from(src_bytes);
421                let dst_ip = std::net::Ipv4Addr::from(dst_bytes);
422
423                // Zero out checksum field
424                if layer_bytes[tcp_idx].len() >= 18 {
425                    layer_bytes[tcp_idx][16] = 0;
426                    layer_bytes[tcp_idx][17] = 0;
427                }
428
429                // Recalculate checksum with zeroed field
430                let mut tcp_with_zero_checksum: Vec<u8> =
431                    layer_bytes[tcp_idx..].iter().flatten().copied().collect();
432                if tcp_with_zero_checksum.len() >= 18 {
433                    tcp_with_zero_checksum[16] = 0;
434                    tcp_with_zero_checksum[17] = 0;
435                }
436
437                let checksum = crate::layer::tcp::checksum::tcp_checksum_ipv4(
438                    src_ip,
439                    dst_ip,
440                    &tcp_with_zero_checksum,
441                );
442
443                if layer_bytes[tcp_idx].len() >= 18 {
444                    layer_bytes[tcp_idx][16] = ((checksum >> 8) & 0xFF) as u8;
445                    layer_bytes[tcp_idx][17] = (checksum & 0xFF) as u8;
446                }
447            }
448        }
449    }
450
451    /// Fix UDP checksum and length if present.
452    fn fix_udp_fields(&self, layer_bytes: &mut [Vec<u8>]) {
453        // Find IP layer for pseudo-header
454        let mut ip_layer_idx = None;
455        let mut udp_layer_idx = None;
456
457        for (i, layer) in self.layers.iter().enumerate() {
458            if let LayerStackEntry::Ipv4(_) = layer {
459                ip_layer_idx = Some(i);
460            }
461            if let LayerStackEntry::Udp(_) = layer {
462                udp_layer_idx = Some(i);
463            }
464        }
465
466        // If UDP is present without IP, just fix the length field
467        if let Some(udp_idx) = udp_layer_idx
468            && ip_layer_idx.is_none()
469        {
470            // Standalone UDP - only fix length field, leave checksum as 0
471            let udp_len: usize = layer_bytes[udp_idx..].iter().map(std::vec::Vec::len).sum();
472            if layer_bytes[udp_idx].len() >= 6 {
473                layer_bytes[udp_idx][4] = ((udp_len >> 8) & 0xFF) as u8;
474                layer_bytes[udp_idx][5] = (udp_len & 0xFF) as u8;
475            }
476            return;
477        }
478
479        if let (Some(ip_idx), Some(udp_idx)) = (ip_layer_idx, udp_layer_idx) {
480            let ip_bytes = &layer_bytes[ip_idx];
481            if ip_bytes.len() >= 20 {
482                // Extract source and destination IPs
483                let mut src_bytes = [0u8; 4];
484                let mut dst_bytes = [0u8; 4];
485                src_bytes.copy_from_slice(&ip_bytes[12..16]);
486                dst_bytes.copy_from_slice(&ip_bytes[16..20]);
487                let src_ip = std::net::Ipv4Addr::from(src_bytes);
488                let dst_ip = std::net::Ipv4Addr::from(dst_bytes);
489
490                // Calculate UDP length (header + payload)
491                let udp_len: usize = layer_bytes[udp_idx..].iter().map(std::vec::Vec::len).sum();
492
493                // Update length field (bytes 4-5)
494                if layer_bytes[udp_idx].len() >= 6 {
495                    layer_bytes[udp_idx][4] = ((udp_len >> 8) & 0xFF) as u8;
496                    layer_bytes[udp_idx][5] = (udp_len & 0xFF) as u8;
497                }
498
499                // Zero out checksum field (bytes 6-7)
500                if layer_bytes[udp_idx].len() >= 8 {
501                    layer_bytes[udp_idx][6] = 0;
502                    layer_bytes[udp_idx][7] = 0;
503                }
504
505                // Collect UDP segment (header + payload)
506                let udp_segment: Vec<u8> =
507                    layer_bytes[udp_idx..].iter().flatten().copied().collect();
508
509                // Calculate checksum
510                let checksum =
511                    crate::layer::udp::checksum::udp_checksum_ipv4(src_ip, dst_ip, &udp_segment);
512
513                // Write checksum back
514                if layer_bytes[udp_idx].len() >= 8 {
515                    layer_bytes[udp_idx][6] = ((checksum >> 8) & 0xFF) as u8;
516                    layer_bytes[udp_idx][7] = (checksum & 0xFF) as u8;
517                }
518            }
519        }
520    }
521
522    /// Fix ICMP checksum if present.
523    fn fix_icmp_fields(&self, layer_bytes: &mut [Vec<u8>]) {
524        // Find ICMP layer
525        let mut icmp_layer_idx = None;
526
527        for (i, layer) in self.layers.iter().enumerate() {
528            if let LayerStackEntry::Icmp(_) = layer {
529                icmp_layer_idx = Some(i);
530                break;
531            }
532        }
533
534        if let Some(icmp_idx) = icmp_layer_idx {
535            // Zero out checksum field (bytes 2-3)
536            if layer_bytes[icmp_idx].len() >= 4 {
537                layer_bytes[icmp_idx][2] = 0;
538                layer_bytes[icmp_idx][3] = 0;
539            }
540
541            // Collect ICMP message (header + payload)
542            let icmp_message: Vec<u8> = layer_bytes[icmp_idx..].iter().flatten().copied().collect();
543
544            // Calculate checksum
545            let checksum = crate::layer::icmp::checksum::icmp_checksum(&icmp_message);
546
547            // Write checksum back
548            if layer_bytes[icmp_idx].len() >= 4 {
549                layer_bytes[icmp_idx][2] = ((checksum >> 8) & 0xFF) as u8;
550                layer_bytes[icmp_idx][3] = (checksum & 0xFF) as u8;
551            }
552        }
553    }
554
555    /// Fix IPv6 payload length field if present.
556    fn fix_ipv6_fields(&self, layer_bytes: &mut [Vec<u8>]) {
557        for (i, layer) in self.layers.iter().enumerate() {
558            if let LayerStackEntry::Ipv6(_) = layer {
559                // Payload = everything after this IPv6 header
560                let payload_size: usize = layer_bytes[i + 1..].iter().map(std::vec::Vec::len).sum();
561
562                // Update payload length field (bytes 4-5 of IPv6 header)
563                if layer_bytes[i].len() >= 6 {
564                    layer_bytes[i][4] = ((payload_size >> 8) & 0xFF) as u8;
565                    layer_bytes[i][5] = (payload_size & 0xFF) as u8;
566                }
567            }
568        }
569    }
570
571    /// Fix `ICMPv6` checksum if an IPv6 layer and `ICMPv6` layer are both present.
572    fn fix_icmpv6_fields(&self, layer_bytes: &mut [Vec<u8>]) {
573        let mut ipv6_layer_idx = None;
574        let mut icmpv6_layer_idx = None;
575
576        for (i, layer) in self.layers.iter().enumerate() {
577            if let LayerStackEntry::Ipv6(_) = layer {
578                ipv6_layer_idx = Some(i);
579            }
580            if let LayerStackEntry::Icmpv6(_) = layer {
581                icmpv6_layer_idx = Some(i);
582                break;
583            }
584        }
585
586        if let (Some(ipv6_idx), Some(icmpv6_idx)) = (ipv6_layer_idx, icmpv6_layer_idx) {
587            let ipv6_bytes = &layer_bytes[ipv6_idx];
588            if ipv6_bytes.len() >= 40 {
589                // Extract source and destination IPv6 addresses
590                let mut src_bytes = [0u8; 16];
591                let mut dst_bytes = [0u8; 16];
592                src_bytes.copy_from_slice(&ipv6_bytes[8..24]);
593                dst_bytes.copy_from_slice(&ipv6_bytes[24..40]);
594                let src_ip = std::net::Ipv6Addr::from(src_bytes);
595                let dst_ip = std::net::Ipv6Addr::from(dst_bytes);
596
597                // Zero out the checksum field in ICMPv6 header (bytes 2-3)
598                if layer_bytes[icmpv6_idx].len() >= 4 {
599                    layer_bytes[icmpv6_idx][2] = 0;
600                    layer_bytes[icmpv6_idx][3] = 0;
601                }
602
603                // Collect ICMPv6 message (header + payload)
604                let icmpv6_message: Vec<u8> = layer_bytes[icmpv6_idx..]
605                    .iter()
606                    .flatten()
607                    .copied()
608                    .collect();
609
610                // Calculate ICMPv6 checksum using IPv6 pseudo-header
611                let checksum =
612                    crate::layer::icmpv6::icmpv6_checksum(src_ip, dst_ip, &icmpv6_message);
613
614                // Write checksum back
615                if layer_bytes[icmpv6_idx].len() >= 4 {
616                    layer_bytes[icmpv6_idx][2] = ((checksum >> 8) & 0xFF) as u8;
617                    layer_bytes[icmpv6_idx][3] = (checksum & 0xFF) as u8;
618                }
619            }
620        }
621    }
622}
623
624/// Apply a binding field value to layer bytes.
625fn apply_field_to_bytes(bytes: &mut Vec<u8>, layer_kind: LayerKind, field_name: &str, value: u16) {
626    match layer_kind {
627        LayerKind::Ethernet => {
628            // type field is at offset 12, 2 bytes
629            if field_name == "type" && bytes.len() >= 14 {
630                bytes[12] = ((value >> 8) & 0xFF) as u8;
631                bytes[13] = (value & 0xFF) as u8;
632            }
633        },
634        LayerKind::Ipv4 => {
635            // proto field is at offset 9, 1 byte
636            if field_name == "proto" && bytes.len() >= 10 {
637                bytes[9] = (value & 0xFF) as u8;
638            }
639        },
640        LayerKind::Ipv6 => {
641            // nh (next header) field is at offset 6, 1 byte
642            if field_name == "nh" && bytes.len() >= 7 {
643                bytes[6] = (value & 0xFF) as u8;
644            }
645        },
646        LayerKind::Dot1Q | LayerKind::Dot1AD => {
647            // type field is at offset 2, 2 bytes (after TCI)
648            if field_name == "type" && bytes.len() >= 4 {
649                bytes[2] = ((value >> 8) & 0xFF) as u8;
650                bytes[3] = (value & 0xFF) as u8;
651            }
652        },
653        LayerKind::Tcp => {
654            // dport field is at offset 2, 2 bytes
655            if field_name == "dport" && bytes.len() >= 4 {
656                bytes[2] = ((value >> 8) & 0xFF) as u8;
657                bytes[3] = (value & 0xFF) as u8;
658            }
659        },
660        LayerKind::Udp => {
661            // dport field is at offset 2, 2 bytes
662            if field_name == "dport" && bytes.len() >= 4 {
663                bytes[2] = ((value >> 8) & 0xFF) as u8;
664                bytes[3] = (value & 0xFF) as u8;
665            }
666        },
667        _ => {},
668    }
669}
670
671/// Trait for types that can be converted into a `LayerStackEntry`.
672pub trait IntoLayerStackEntry {
673    fn into_layer_stack_entry(self) -> LayerStackEntry;
674}
675
676impl IntoLayerStackEntry for EthernetBuilder {
677    fn into_layer_stack_entry(self) -> LayerStackEntry {
678        LayerStackEntry::Ethernet(self)
679    }
680}
681
682impl IntoLayerStackEntry for ArpBuilder {
683    fn into_layer_stack_entry(self) -> LayerStackEntry {
684        LayerStackEntry::Arp(self)
685    }
686}
687
688impl IntoLayerStackEntry for Ipv4Builder {
689    fn into_layer_stack_entry(self) -> LayerStackEntry {
690        LayerStackEntry::Ipv4(self)
691    }
692}
693
694impl IntoLayerStackEntry for Ipv6Builder {
695    fn into_layer_stack_entry(self) -> LayerStackEntry {
696        LayerStackEntry::Ipv6(self)
697    }
698}
699
700impl IntoLayerStackEntry for Icmpv6Builder {
701    fn into_layer_stack_entry(self) -> LayerStackEntry {
702        LayerStackEntry::Icmpv6(self)
703    }
704}
705
706impl IntoLayerStackEntry for TcpBuilder {
707    fn into_layer_stack_entry(self) -> LayerStackEntry {
708        LayerStackEntry::Tcp(self)
709    }
710}
711
712impl IntoLayerStackEntry for UdpBuilder {
713    fn into_layer_stack_entry(self) -> LayerStackEntry {
714        LayerStackEntry::Udp(self)
715    }
716}
717
718impl IntoLayerStackEntry for DnsBuilder {
719    fn into_layer_stack_entry(self) -> LayerStackEntry {
720        LayerStackEntry::Dns(self)
721    }
722}
723
724impl IntoLayerStackEntry for Http2FrameBuilder {
725    fn into_layer_stack_entry(self) -> LayerStackEntry {
726        LayerStackEntry::Http2(self)
727    }
728}
729
730impl IntoLayerStackEntry for L2tpBuilder {
731    fn into_layer_stack_entry(self) -> LayerStackEntry {
732        LayerStackEntry::L2tp(self)
733    }
734}
735
736impl IntoLayerStackEntry for MqttBuilder {
737    fn into_layer_stack_entry(self) -> LayerStackEntry {
738        LayerStackEntry::Mqtt(self)
739    }
740}
741
742impl IntoLayerStackEntry for MqttSnBuilder {
743    fn into_layer_stack_entry(self) -> LayerStackEntry {
744        LayerStackEntry::MqttSn(self)
745    }
746}
747
748impl IntoLayerStackEntry for ModbusBuilder {
749    fn into_layer_stack_entry(self) -> LayerStackEntry {
750        LayerStackEntry::Modbus(self)
751    }
752}
753
754impl IntoLayerStackEntry for ZWaveBuilder {
755    fn into_layer_stack_entry(self) -> LayerStackEntry {
756        LayerStackEntry::ZWave(self)
757    }
758}
759
760impl IntoLayerStackEntry for Vec<u8> {
761    fn into_layer_stack_entry(self) -> LayerStackEntry {
762        LayerStackEntry::Raw(self)
763    }
764}
765
766impl IntoLayerStackEntry for &[u8] {
767    fn into_layer_stack_entry(self) -> LayerStackEntry {
768        LayerStackEntry::Raw(self.to_vec())
769    }
770}
771
772// Implement Div for stacking
773impl std::ops::Div<LayerStack> for LayerStack {
774    type Output = LayerStack;
775
776    fn div(self, rhs: LayerStack) -> Self::Output {
777        self.stack(rhs)
778    }
779}
780
781impl std::ops::Div<LayerStackEntry> for LayerStack {
782    type Output = LayerStack;
783
784    fn div(self, rhs: LayerStackEntry) -> Self::Output {
785        self.push(rhs)
786    }
787}
788
789impl std::ops::Div<LayerStackEntry> for LayerStackEntry {
790    type Output = LayerStack;
791
792    fn div(self, rhs: LayerStackEntry) -> Self::Output {
793        LayerStack::new().push(self).push(rhs)
794    }
795}
796
797// Implement From for common builder types to create single-layer stacks
798impl From<EthernetBuilder> for LayerStack {
799    fn from(builder: EthernetBuilder) -> Self {
800        LayerStack::new().push(LayerStackEntry::Ethernet(builder))
801    }
802}
803
804impl From<ArpBuilder> for LayerStack {
805    fn from(builder: ArpBuilder) -> Self {
806        LayerStack::new().push(LayerStackEntry::Arp(builder))
807    }
808}
809
810impl From<Ipv4Builder> for LayerStack {
811    fn from(builder: Ipv4Builder) -> Self {
812        LayerStack::new().push(LayerStackEntry::Ipv4(builder))
813    }
814}
815
816impl From<Ipv6Builder> for LayerStack {
817    fn from(builder: Ipv6Builder) -> Self {
818        LayerStack::new().push(LayerStackEntry::Ipv6(builder))
819    }
820}
821
822impl From<Icmpv6Builder> for LayerStack {
823    fn from(builder: Icmpv6Builder) -> Self {
824        LayerStack::new().push(LayerStackEntry::Icmpv6(builder))
825    }
826}
827
828impl From<TcpBuilder> for LayerStack {
829    fn from(builder: TcpBuilder) -> Self {
830        LayerStack::new().push(LayerStackEntry::Tcp(builder))
831    }
832}
833
834impl From<UdpBuilder> for LayerStack {
835    fn from(builder: UdpBuilder) -> Self {
836        LayerStack::new().push(LayerStackEntry::Udp(builder))
837    }
838}
839
840impl From<DnsBuilder> for LayerStack {
841    fn from(builder: DnsBuilder) -> Self {
842        LayerStack::new().push(LayerStackEntry::Dns(builder))
843    }
844}
845
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::field::MacAddress;
    use crate::layer::{ethertype, ip_protocol};
    use std::net::Ipv4Addr;

    /// Ethernet entry shared by most tests: broadcast destination and the
    /// fixed source MAC `00:11:22:33:44:55`.
    fn broadcast_eth() -> LayerStackEntry {
        let builder = EthernetBuilder::new()
            .dst(MacAddress::BROADCAST)
            .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]));
        LayerStackEntry::Ethernet(builder)
    }

    /// IPv4 entry shared by most tests: 192.168.1.1 -> 192.168.1.2, TTL 64.
    fn lan_ipv4() -> LayerStackEntry {
        let builder = Ipv4Builder::new()
            .src(Ipv4Addr::new(192, 168, 1, 1))
            .dst(Ipv4Addr::new(192, 168, 1, 2))
            .ttl(64);
        LayerStackEntry::Ipv4(builder)
    }

    #[test]
    fn test_ethernet_ip_stack() {
        let bytes = LayerStack::new()
            .push(broadcast_eth())
            .push(lan_ipv4())
            .build();

        // The EtherType field (frame bytes 12..14) must announce IPv4.
        let etype = u16::from_be_bytes([bytes[12], bytes[13]]);
        assert_eq!(etype, ethertype::IPV4);

        // The built frame must parse back into both layers.
        let mut pkt = Packet::from_bytes(bytes);
        pkt.parse().unwrap();
        assert!(pkt.get_layer(LayerKind::Ethernet).is_some());
        assert!(pkt.get_layer(LayerKind::Ipv4).is_some());
    }

    #[test]
    fn test_ethernet_ip_tcp_stack() {
        let tcp = LayerStackEntry::Tcp(TcpBuilder::new().src_port(12345).dst_port(80).syn());
        let bytes = LayerStack::new()
            .push(broadcast_eth())
            .push(lan_ipv4())
            .push(tcp)
            .build();

        // The EtherType field (frame bytes 12..14) must announce IPv4.
        let etype = u16::from_be_bytes([bytes[12], bytes[13]]);
        assert_eq!(etype, ethertype::IPV4);

        // The IP protocol byte sits 9 bytes into the IP header, which
        // follows the 14-byte Ethernet header.
        let proto = bytes[14 + 9];
        assert_eq!(proto, ip_protocol::TCP);

        // The built frame must parse back into all three layers.
        let mut pkt = Packet::from_bytes(bytes);
        pkt.parse().unwrap();
        assert!(pkt.get_layer(LayerKind::Ethernet).is_some());
        assert!(pkt.get_layer(LayerKind::Ipv4).is_some());
        assert!(pkt.get_layer(LayerKind::Tcp).is_some());
    }

    #[test]
    fn test_ethernet_arp_stack() {
        let arp = LayerStackEntry::Arp(
            ArpBuilder::new()
                .op(1) // REQUEST
                .hwsrc(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]))
                .psrc(Ipv4Addr::new(192, 168, 1, 1))
                .pdst(Ipv4Addr::new(192, 168, 1, 2)),
        );
        let bytes = LayerStack::new().push(broadcast_eth()).push(arp).build();

        // The EtherType field (frame bytes 12..14) must announce ARP.
        let etype = u16::from_be_bytes([bytes[12], bytes[13]]);
        assert_eq!(etype, ethertype::ARP);

        // The built frame must parse back into both layers.
        let mut pkt = Packet::from_bytes(bytes);
        pkt.parse().unwrap();
        assert!(pkt.get_layer(LayerKind::Ethernet).is_some());
        assert!(pkt.get_layer(LayerKind::Arp).is_some());
    }

    #[test]
    fn test_div_operator() {
        // `entry / entry` must yield a two-layer stack in left-to-right order.
        let stack = broadcast_eth() / lan_ipv4();

        assert_eq!(stack.len(), 2);
        assert_eq!(stack.layers()[0].kind(), LayerKind::Ethernet);
        assert_eq!(stack.layers()[1].kind(), LayerKind::Ipv4);
    }

    #[test]
    fn test_raw_payload() {
        let payload = b"Hello, World!";
        let bytes = LayerStack::new()
            .push(broadcast_eth())
            .push(lan_ipv4())
            .push(LayerStackEntry::Raw(payload.to_vec()))
            .build();

        // The raw bytes must start right after the 14-byte Ethernet header
        // and the 20-byte IP header.
        let start = 14 + 20;
        let end = start + payload.len();
        assert_eq!(&bytes[start..end], payload);
    }

    #[test]
    fn test_ip_total_length_calculation() {
        let ip = LayerStackEntry::Ipv4(
            Ipv4Builder::new()
                .src(Ipv4Addr::new(10, 0, 0, 1))
                .dst(Ipv4Addr::new(10, 0, 0, 2)),
        );
        let bytes = LayerStack::new()
            .push(LayerStackEntry::Ethernet(EthernetBuilder::new()))
            .push(ip)
            .push(LayerStackEntry::Raw(vec![0u8; 100]))
            .build();

        // Total length lives at bytes 2..4 of the IP header and should be
        // 20 (header) + 100 (payload) = 120.
        let ip_total_len = u16::from_be_bytes([bytes[14 + 2], bytes[14 + 3]]);
        assert_eq!(ip_total_len, 120);
    }
}