Skip to main content

stackforge_core/layer/
stack.rs

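//! Layer stacking support for packet composition.
//!
//! This module provides the `LayerStack` type for composing packets from
//! protocol layers, either by pushing layers explicitly (see the example
//! below) or with the Scapy-style `/` operator applied to `LayerStackEntry`
//! and `LayerStack` values (`eth / ip / tcp`, the Rust counterpart of
//! Scapy's `Ether() / IP() / TCP()`).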
5//!
6//! # Example
7//!
8//! ```rust
9//! use stackforge_core::layer::{
10//!     EthernetBuilder, LayerStack, LayerStackEntry, LayerKind,
11//!     ipv4::Ipv4Builder,
12//!     tcp::TcpBuilder,
13//! };
14//! use stackforge_core::layer::field::MacAddress;
15//! use std::net::Ipv4Addr;
16//!
17//! // Build a TCP SYN packet
18//! let pkt = LayerStack::new()
19//!     .push(LayerStackEntry::Ethernet(
20//!         EthernetBuilder::new()
21//!             .dst(MacAddress::BROADCAST)
22//!             .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]))
23//!     ))
24//!     .push(LayerStackEntry::Ipv4(
25//!         Ipv4Builder::new()
26//!             .src(Ipv4Addr::new(192, 168, 1, 1))
27//!             .dst(Ipv4Addr::new(192, 168, 1, 2))
28//!             .ttl(64)
29//!     ))
30//!     .push(LayerStackEntry::Tcp(
31//!         TcpBuilder::new()
32//!             .src_port(12345)
33//!             .dst_port(80)
34//!             .syn()
35//!     ))
36//!     .build();
37//! ```
38
39use super::bindings::apply_binding;
40use super::dns::builder::DnsBuilder;
41use super::ethernet::{ETHERNET_HEADER_LEN, EthernetBuilder};
42use super::http2::builder::Http2FrameBuilder;
43use super::icmp::builder::IcmpBuilder;
44use super::icmpv6::builder::Icmpv6Builder;
45use super::ipv4::builder::Ipv4Builder;
46use super::ipv6::builder::Ipv6Builder;
47use super::l2tp::builder::L2tpBuilder;
48use super::ssh::builder::SshBuilder;
49use super::tcp::builder::TcpBuilder;
50use super::tls::builder::TlsRecordBuilder;
51use super::udp::builder::UdpBuilder;
52use super::{ArpBuilder, LayerKind};
53use crate::Packet;
54use crate::layer::arp::ARP_HEADER_LEN;
55use crate::layer::dns::DNS_HEADER_LEN;
56use crate::layer::icmp::ICMP_MIN_HEADER_LEN;
57use crate::layer::icmpv6::ICMPV6_MIN_HEADER_LEN;
58use crate::layer::ipv4::IPV4_MIN_HEADER_LEN;
59use crate::layer::ipv6::IPV6_HEADER_LEN;
60use crate::layer::l2tp::L2TP_MIN_HEADER_LEN;
61use crate::layer::tcp::TCP_MIN_HEADER_LEN;
62use crate::layer::udp::UDP_HEADER_LEN;
63
64/// An entry in a layer stack, representing a protocol layer builder.
65#[derive(Debug, Clone)]
66pub enum LayerStackEntry {
67    /// Ethernet II layer
68    Ethernet(EthernetBuilder),
69    /// ARP layer
70    Arp(ArpBuilder),
71    /// IPv4 layer
72    Ipv4(Ipv4Builder),
73    /// IPv6 layer
74    Ipv6(Ipv6Builder),
75    /// TCP layer
76    Tcp(TcpBuilder),
77    /// UDP layer
78    Udp(UdpBuilder),
79    /// ICMP layer
80    Icmp(IcmpBuilder),
81    /// ICMPv6 layer
82    Icmpv6(Icmpv6Builder),
83    /// SSH layer
84    Ssh(SshBuilder),
85    /// TLS record layer
86    Tls(TlsRecordBuilder),
87    /// DNS layer
88    Dns(DnsBuilder),
89    /// HTTP/2 frame layer
90    Http2(Http2FrameBuilder),
91    /// L2TP layer
92    L2tp(L2tpBuilder),
93    /// Raw bytes payload
94    Raw(Vec<u8>),
95}
96
97impl LayerStackEntry {
98    /// Get the LayerKind for this entry.
99    pub fn kind(&self) -> LayerKind {
100        match self {
101            Self::Ethernet(_) => LayerKind::Ethernet,
102            Self::Arp(_) => LayerKind::Arp,
103            Self::Ipv4(_) => LayerKind::Ipv4,
104            Self::Ipv6(_) => LayerKind::Ipv6,
105            Self::Tcp(_) => LayerKind::Tcp,
106            Self::Udp(_) => LayerKind::Udp,
107            Self::Icmp(_) => LayerKind::Icmp,
108            Self::Icmpv6(_) => LayerKind::Icmpv6,
109            Self::Ssh(_) => LayerKind::Ssh,
110            Self::Tls(_) => LayerKind::Tls,
111            Self::Dns(_) => LayerKind::Dns,
112            Self::Http2(_) => LayerKind::Http2,
113            Self::L2tp(_) => LayerKind::L2tp,
114            Self::Raw(_) => LayerKind::Raw,
115        }
116    }
117
118    /// Build this layer into bytes, without applying bindings.
119    pub fn build_bytes(&self) -> Vec<u8> {
120        match self {
121            Self::Ethernet(b) => b.build(),
122            Self::Arp(b) => b.build(),
123            Self::Ipv4(b) => b.build(),
124            Self::Ipv6(b) => b.build(),
125            Self::Tcp(b) => b.build(),
126            Self::Udp(b) => b.build(),
127            Self::Icmp(b) => b.build(),
128            Self::Icmpv6(b) => b.build(),
129            Self::Ssh(b) => b.build(),
130            Self::Tls(b) => b.build(),
131            Self::Dns(b) => b.build(),
132            Self::Http2(b) => b.build(),
133            Self::L2tp(b) => b.build(),
134            Self::Raw(data) => data.clone(),
135        }
136    }
137
138    /// Get the header size for this layer.
139    pub fn header_size(&self) -> usize {
140        match self {
141            Self::Ethernet(_) => ETHERNET_HEADER_LEN,
142            Self::Arp(_) => ARP_HEADER_LEN,
143            Self::Ipv4(b) => b.header_size(),
144            Self::Ipv6(b) => b.header_size(),
145            Self::Tcp(b) => b.header_size(),
146            Self::Udp(b) => b.header_size(),
147            Self::Icmp(b) => b.header_size(),
148            Self::Icmpv6(b) => b.header_size(),
149            Self::Ssh(b) => b.header_size(),
150            Self::Tls(b) => b.record_size(),
151            Self::Dns(b) => b.header_size(),
152            Self::Http2(b) => b.build().len(), // frame size is dynamic
153            Self::L2tp(b) => b.header_size(),
154            Self::Raw(data) => data.len(),
155        }
156    }
157
158    /// Get minimum header size for this layer type.
159    pub fn min_header_size(&self) -> usize {
160        match self {
161            Self::Ethernet(_) => ETHERNET_HEADER_LEN,
162            Self::Arp(_) => ARP_HEADER_LEN,
163            Self::Ipv4(_) => IPV4_MIN_HEADER_LEN,
164            Self::Ipv6(_) => IPV6_HEADER_LEN,
165            Self::Tcp(_) => TCP_MIN_HEADER_LEN,
166            Self::Udp(_) => UDP_HEADER_LEN,
167            Self::Icmp(_) => ICMP_MIN_HEADER_LEN,
168            Self::Icmpv6(_) => ICMPV6_MIN_HEADER_LEN,
169            Self::Ssh(b) => b.header_size(),
170            Self::Tls(_) => 5, // TLS record header is 5 bytes
171            Self::Dns(_) => DNS_HEADER_LEN,
172            Self::Http2(_) => 9, // HTTP/2 frame header is 9 bytes
173            Self::L2tp(_) => L2TP_MIN_HEADER_LEN,
174            Self::Raw(data) => data.len(),
175        }
176    }
177}
178
179/// A stack of protocol layers that can be combined into a packet.
180///
181/// The stack maintains the order of layers and automatically applies
182/// bindings when building the final packet.
183#[derive(Debug, Clone, Default)]
184pub struct LayerStack {
185    layers: Vec<LayerStackEntry>,
186}
187
188impl LayerStack {
189    /// Create a new empty layer stack.
190    pub fn new() -> Self {
191        Self { layers: Vec::new() }
192    }
193
194    /// Push a new layer onto the stack.
195    ///
196    /// Layers are stacked from bottom (first) to top (last).
197    pub fn push(mut self, layer: LayerStackEntry) -> Self {
198        self.layers.push(layer);
199        self
200    }
201
202    /// Add a layer to the stack (mutable version).
203    pub fn add(&mut self, layer: LayerStackEntry) {
204        self.layers.push(layer);
205    }
206
207    /// Stack another LayerStack on top of this one.
208    ///
209    /// This is the implementation of the `/` operator.
210    pub fn stack(mut self, other: LayerStack) -> Self {
211        self.layers.extend(other.layers);
212        self
213    }
214
215    /// Get the number of layers in the stack.
216    pub fn len(&self) -> usize {
217        self.layers.len()
218    }
219
220    /// Check if the stack is empty.
221    pub fn is_empty(&self) -> bool {
222        self.layers.is_empty()
223    }
224
225    /// Get the layers in the stack.
226    pub fn layers(&self) -> &[LayerStackEntry] {
227        &self.layers
228    }
229
230    /// Build the stacked layers into raw bytes.
231    ///
232    /// This method:
233    /// 1. Builds each layer
234    /// 2. Applies bindings (e.g., sets Ethernet type when IP is stacked on top)
235    /// 3. Concatenates all layer bytes
236    /// 4. Recalculates checksums and lengths where applicable
237    pub fn build(&self) -> Vec<u8> {
238        if self.layers.is_empty() {
239            return Vec::new();
240        }
241
242        // First pass: build all layers to determine total size
243        let mut layer_bytes: Vec<Vec<u8>> = self.layers.iter().map(|l| l.build_bytes()).collect();
244
245        // Calculate total payload sizes for each layer
246        let total_len: usize = layer_bytes.iter().map(|b| b.len()).sum();
247
248        // Second pass: apply bindings and fix length/checksum fields
249        for i in 0..self.layers.len().saturating_sub(1) {
250            let lower_kind = self.layers[i].kind();
251            let upper_kind = self.layers[i + 1].kind();
252
253            // Apply binding: set the appropriate field in the lower layer
254            if let Some((field_name, field_value)) = apply_binding(lower_kind, upper_kind) {
255                apply_field_to_bytes(&mut layer_bytes[i], lower_kind, field_name, field_value);
256            }
257        }
258
259        // Fix IP and TCP checksum/length if present
260        self.fix_ip_fields(&mut layer_bytes, total_len);
261        self.fix_ipv6_fields(&mut layer_bytes);
262        self.fix_tcp_fields(&mut layer_bytes);
263        self.fix_udp_fields(&mut layer_bytes);
264        self.fix_icmp_fields(&mut layer_bytes);
265        self.fix_icmpv6_fields(&mut layer_bytes);
266
267        // Concatenate all layers
268        layer_bytes.into_iter().flatten().collect()
269    }
270
271    /// Build the stack into a Packet.
272    pub fn build_packet(&self) -> Packet {
273        let bytes = self.build();
274        let mut pkt = Packet::from_bytes(bytes);
275        let _ = pkt.parse();
276        pkt
277    }
278
279    /// Fix IP total length and checksum fields.
280    fn fix_ip_fields(&self, layer_bytes: &mut [Vec<u8>], _total_len: usize) {
281        for (i, layer) in self.layers.iter().enumerate() {
282            if let LayerStackEntry::Ipv4(_) = layer {
283                // Calculate payload size (everything after IP header)
284                let payload_size: usize = layer_bytes[i + 1..].iter().map(|b| b.len()).sum();
285                let ip_header_len = layer_bytes[i].len();
286                let ip_total_len = ip_header_len + payload_size;
287
288                // Update total length field (bytes 2-3)
289                if layer_bytes[i].len() >= 4 {
290                    layer_bytes[i][2] = ((ip_total_len >> 8) & 0xFF) as u8;
291                    layer_bytes[i][3] = (ip_total_len & 0xFF) as u8;
292
293                    // Recalculate checksum
294                    // First zero out the checksum field
295                    layer_bytes[i][10] = 0;
296                    layer_bytes[i][11] = 0;
297
298                    // Calculate new checksum
299                    let checksum = crate::layer::ipv4::checksum::ipv4_checksum(
300                        &layer_bytes[i][..ip_header_len],
301                    );
302                    layer_bytes[i][10] = ((checksum >> 8) & 0xFF) as u8;
303                    layer_bytes[i][11] = (checksum & 0xFF) as u8;
304                }
305            }
306        }
307    }
308
309    /// Fix TCP checksum if present.
310    fn fix_tcp_fields(&self, layer_bytes: &mut [Vec<u8>]) {
311        // Find IP layer for pseudo-header
312        let mut ip_layer_idx = None;
313        let mut tcp_layer_idx = None;
314
315        for (i, layer) in self.layers.iter().enumerate() {
316            if let LayerStackEntry::Ipv4(_) = layer {
317                ip_layer_idx = Some(i);
318            }
319            if let LayerStackEntry::Tcp(_) = layer {
320                tcp_layer_idx = Some(i);
321            }
322        }
323
324        if let (Some(ip_idx), Some(tcp_idx)) = (ip_layer_idx, tcp_layer_idx) {
325            let ip_bytes = &layer_bytes[ip_idx];
326            if ip_bytes.len() >= 20 {
327                // Extract source and destination IPs
328                let mut src_bytes = [0u8; 4];
329                let mut dst_bytes = [0u8; 4];
330                src_bytes.copy_from_slice(&ip_bytes[12..16]);
331                dst_bytes.copy_from_slice(&ip_bytes[16..20]);
332                let src_ip = std::net::Ipv4Addr::from(src_bytes);
333                let dst_ip = std::net::Ipv4Addr::from(dst_bytes);
334
335                // Zero out checksum field
336                if layer_bytes[tcp_idx].len() >= 18 {
337                    layer_bytes[tcp_idx][16] = 0;
338                    layer_bytes[tcp_idx][17] = 0;
339                }
340
341                // Recalculate checksum with zeroed field
342                let mut tcp_with_zero_checksum: Vec<u8> =
343                    layer_bytes[tcp_idx..].iter().flatten().copied().collect();
344                if tcp_with_zero_checksum.len() >= 18 {
345                    tcp_with_zero_checksum[16] = 0;
346                    tcp_with_zero_checksum[17] = 0;
347                }
348
349                let checksum = crate::layer::tcp::checksum::tcp_checksum_ipv4(
350                    src_ip,
351                    dst_ip,
352                    &tcp_with_zero_checksum,
353                );
354
355                if layer_bytes[tcp_idx].len() >= 18 {
356                    layer_bytes[tcp_idx][16] = ((checksum >> 8) & 0xFF) as u8;
357                    layer_bytes[tcp_idx][17] = (checksum & 0xFF) as u8;
358                }
359            }
360        }
361    }
362
363    /// Fix UDP checksum and length if present.
364    fn fix_udp_fields(&self, layer_bytes: &mut [Vec<u8>]) {
365        // Find IP layer for pseudo-header
366        let mut ip_layer_idx = None;
367        let mut udp_layer_idx = None;
368
369        for (i, layer) in self.layers.iter().enumerate() {
370            if let LayerStackEntry::Ipv4(_) = layer {
371                ip_layer_idx = Some(i);
372            }
373            if let LayerStackEntry::Udp(_) = layer {
374                udp_layer_idx = Some(i);
375            }
376        }
377
378        // If UDP is present without IP, just fix the length field
379        if let Some(udp_idx) = udp_layer_idx {
380            if ip_layer_idx.is_none() {
381                // Standalone UDP - only fix length field, leave checksum as 0
382                let udp_len: usize = layer_bytes[udp_idx..].iter().map(|b| b.len()).sum();
383                if layer_bytes[udp_idx].len() >= 6 {
384                    layer_bytes[udp_idx][4] = ((udp_len >> 8) & 0xFF) as u8;
385                    layer_bytes[udp_idx][5] = (udp_len & 0xFF) as u8;
386                }
387                return;
388            }
389        }
390
391        if let (Some(ip_idx), Some(udp_idx)) = (ip_layer_idx, udp_layer_idx) {
392            let ip_bytes = &layer_bytes[ip_idx];
393            if ip_bytes.len() >= 20 {
394                // Extract source and destination IPs
395                let mut src_bytes = [0u8; 4];
396                let mut dst_bytes = [0u8; 4];
397                src_bytes.copy_from_slice(&ip_bytes[12..16]);
398                dst_bytes.copy_from_slice(&ip_bytes[16..20]);
399                let src_ip = std::net::Ipv4Addr::from(src_bytes);
400                let dst_ip = std::net::Ipv4Addr::from(dst_bytes);
401
402                // Calculate UDP length (header + payload)
403                let udp_len: usize = layer_bytes[udp_idx..].iter().map(|b| b.len()).sum();
404
405                // Update length field (bytes 4-5)
406                if layer_bytes[udp_idx].len() >= 6 {
407                    layer_bytes[udp_idx][4] = ((udp_len >> 8) & 0xFF) as u8;
408                    layer_bytes[udp_idx][5] = (udp_len & 0xFF) as u8;
409                }
410
411                // Zero out checksum field (bytes 6-7)
412                if layer_bytes[udp_idx].len() >= 8 {
413                    layer_bytes[udp_idx][6] = 0;
414                    layer_bytes[udp_idx][7] = 0;
415                }
416
417                // Collect UDP segment (header + payload)
418                let udp_segment: Vec<u8> =
419                    layer_bytes[udp_idx..].iter().flatten().copied().collect();
420
421                // Calculate checksum
422                let checksum =
423                    crate::layer::udp::checksum::udp_checksum_ipv4(src_ip, dst_ip, &udp_segment);
424
425                // Write checksum back
426                if layer_bytes[udp_idx].len() >= 8 {
427                    layer_bytes[udp_idx][6] = ((checksum >> 8) & 0xFF) as u8;
428                    layer_bytes[udp_idx][7] = (checksum & 0xFF) as u8;
429                }
430            }
431        }
432    }
433
434    /// Fix ICMP checksum if present.
435    fn fix_icmp_fields(&self, layer_bytes: &mut [Vec<u8>]) {
436        // Find ICMP layer
437        let mut icmp_layer_idx = None;
438
439        for (i, layer) in self.layers.iter().enumerate() {
440            if let LayerStackEntry::Icmp(_) = layer {
441                icmp_layer_idx = Some(i);
442                break;
443            }
444        }
445
446        if let Some(icmp_idx) = icmp_layer_idx {
447            // Zero out checksum field (bytes 2-3)
448            if layer_bytes[icmp_idx].len() >= 4 {
449                layer_bytes[icmp_idx][2] = 0;
450                layer_bytes[icmp_idx][3] = 0;
451            }
452
453            // Collect ICMP message (header + payload)
454            let icmp_message: Vec<u8> = layer_bytes[icmp_idx..].iter().flatten().copied().collect();
455
456            // Calculate checksum
457            let checksum = crate::layer::icmp::checksum::icmp_checksum(&icmp_message);
458
459            // Write checksum back
460            if layer_bytes[icmp_idx].len() >= 4 {
461                layer_bytes[icmp_idx][2] = ((checksum >> 8) & 0xFF) as u8;
462                layer_bytes[icmp_idx][3] = (checksum & 0xFF) as u8;
463            }
464        }
465    }
466
467    /// Fix IPv6 payload length field if present.
468    fn fix_ipv6_fields(&self, layer_bytes: &mut [Vec<u8>]) {
469        for (i, layer) in self.layers.iter().enumerate() {
470            if let LayerStackEntry::Ipv6(_) = layer {
471                // Payload = everything after this IPv6 header
472                let payload_size: usize = layer_bytes[i + 1..].iter().map(|b| b.len()).sum();
473
474                // Update payload length field (bytes 4-5 of IPv6 header)
475                if layer_bytes[i].len() >= 6 {
476                    layer_bytes[i][4] = ((payload_size >> 8) & 0xFF) as u8;
477                    layer_bytes[i][5] = (payload_size & 0xFF) as u8;
478                }
479            }
480        }
481    }
482
483    /// Fix ICMPv6 checksum if an IPv6 layer and ICMPv6 layer are both present.
484    fn fix_icmpv6_fields(&self, layer_bytes: &mut [Vec<u8>]) {
485        let mut ipv6_layer_idx = None;
486        let mut icmpv6_layer_idx = None;
487
488        for (i, layer) in self.layers.iter().enumerate() {
489            if let LayerStackEntry::Ipv6(_) = layer {
490                ipv6_layer_idx = Some(i);
491            }
492            if let LayerStackEntry::Icmpv6(_) = layer {
493                icmpv6_layer_idx = Some(i);
494                break;
495            }
496        }
497
498        if let (Some(ipv6_idx), Some(icmpv6_idx)) = (ipv6_layer_idx, icmpv6_layer_idx) {
499            let ipv6_bytes = &layer_bytes[ipv6_idx];
500            if ipv6_bytes.len() >= 40 {
501                // Extract source and destination IPv6 addresses
502                let mut src_bytes = [0u8; 16];
503                let mut dst_bytes = [0u8; 16];
504                src_bytes.copy_from_slice(&ipv6_bytes[8..24]);
505                dst_bytes.copy_from_slice(&ipv6_bytes[24..40]);
506                let src_ip = std::net::Ipv6Addr::from(src_bytes);
507                let dst_ip = std::net::Ipv6Addr::from(dst_bytes);
508
509                // Zero out the checksum field in ICMPv6 header (bytes 2-3)
510                if layer_bytes[icmpv6_idx].len() >= 4 {
511                    layer_bytes[icmpv6_idx][2] = 0;
512                    layer_bytes[icmpv6_idx][3] = 0;
513                }
514
515                // Collect ICMPv6 message (header + payload)
516                let icmpv6_message: Vec<u8> = layer_bytes[icmpv6_idx..]
517                    .iter()
518                    .flatten()
519                    .copied()
520                    .collect();
521
522                // Calculate ICMPv6 checksum using IPv6 pseudo-header
523                let checksum =
524                    crate::layer::icmpv6::icmpv6_checksum(src_ip, dst_ip, &icmpv6_message);
525
526                // Write checksum back
527                if layer_bytes[icmpv6_idx].len() >= 4 {
528                    layer_bytes[icmpv6_idx][2] = ((checksum >> 8) & 0xFF) as u8;
529                    layer_bytes[icmpv6_idx][3] = (checksum & 0xFF) as u8;
530                }
531            }
532        }
533    }
534}
535
536/// Apply a binding field value to layer bytes.
537fn apply_field_to_bytes(bytes: &mut Vec<u8>, layer_kind: LayerKind, field_name: &str, value: u16) {
538    match layer_kind {
539        LayerKind::Ethernet => {
540            // type field is at offset 12, 2 bytes
541            if field_name == "type" && bytes.len() >= 14 {
542                bytes[12] = ((value >> 8) & 0xFF) as u8;
543                bytes[13] = (value & 0xFF) as u8;
544            }
545        },
546        LayerKind::Ipv4 => {
547            // proto field is at offset 9, 1 byte
548            if field_name == "proto" && bytes.len() >= 10 {
549                bytes[9] = (value & 0xFF) as u8;
550            }
551        },
552        LayerKind::Ipv6 => {
553            // nh (next header) field is at offset 6, 1 byte
554            if field_name == "nh" && bytes.len() >= 7 {
555                bytes[6] = (value & 0xFF) as u8;
556            }
557        },
558        LayerKind::Dot1Q | LayerKind::Dot1AD => {
559            // type field is at offset 2, 2 bytes (after TCI)
560            if field_name == "type" && bytes.len() >= 4 {
561                bytes[2] = ((value >> 8) & 0xFF) as u8;
562                bytes[3] = (value & 0xFF) as u8;
563            }
564        },
565        LayerKind::Tcp => {
566            // dport field is at offset 2, 2 bytes
567            if field_name == "dport" && bytes.len() >= 4 {
568                bytes[2] = ((value >> 8) & 0xFF) as u8;
569                bytes[3] = (value & 0xFF) as u8;
570            }
571        },
572        LayerKind::Udp => {
573            // dport field is at offset 2, 2 bytes
574            if field_name == "dport" && bytes.len() >= 4 {
575                bytes[2] = ((value >> 8) & 0xFF) as u8;
576                bytes[3] = (value & 0xFF) as u8;
577            }
578        },
579        _ => {},
580    }
581}
582
583/// Trait for types that can be converted into a LayerStackEntry.
584pub trait IntoLayerStackEntry {
585    fn into_layer_stack_entry(self) -> LayerStackEntry;
586}
587
588impl IntoLayerStackEntry for EthernetBuilder {
589    fn into_layer_stack_entry(self) -> LayerStackEntry {
590        LayerStackEntry::Ethernet(self)
591    }
592}
593
594impl IntoLayerStackEntry for ArpBuilder {
595    fn into_layer_stack_entry(self) -> LayerStackEntry {
596        LayerStackEntry::Arp(self)
597    }
598}
599
600impl IntoLayerStackEntry for Ipv4Builder {
601    fn into_layer_stack_entry(self) -> LayerStackEntry {
602        LayerStackEntry::Ipv4(self)
603    }
604}
605
606impl IntoLayerStackEntry for Ipv6Builder {
607    fn into_layer_stack_entry(self) -> LayerStackEntry {
608        LayerStackEntry::Ipv6(self)
609    }
610}
611
612impl IntoLayerStackEntry for Icmpv6Builder {
613    fn into_layer_stack_entry(self) -> LayerStackEntry {
614        LayerStackEntry::Icmpv6(self)
615    }
616}
617
618impl IntoLayerStackEntry for TcpBuilder {
619    fn into_layer_stack_entry(self) -> LayerStackEntry {
620        LayerStackEntry::Tcp(self)
621    }
622}
623
624impl IntoLayerStackEntry for UdpBuilder {
625    fn into_layer_stack_entry(self) -> LayerStackEntry {
626        LayerStackEntry::Udp(self)
627    }
628}
629
630impl IntoLayerStackEntry for DnsBuilder {
631    fn into_layer_stack_entry(self) -> LayerStackEntry {
632        LayerStackEntry::Dns(self)
633    }
634}
635
636impl IntoLayerStackEntry for Http2FrameBuilder {
637    fn into_layer_stack_entry(self) -> LayerStackEntry {
638        LayerStackEntry::Http2(self)
639    }
640}
641
642impl IntoLayerStackEntry for L2tpBuilder {
643    fn into_layer_stack_entry(self) -> LayerStackEntry {
644        LayerStackEntry::L2tp(self)
645    }
646}
647
648impl IntoLayerStackEntry for Vec<u8> {
649    fn into_layer_stack_entry(self) -> LayerStackEntry {
650        LayerStackEntry::Raw(self)
651    }
652}
653
654impl IntoLayerStackEntry for &[u8] {
655    fn into_layer_stack_entry(self) -> LayerStackEntry {
656        LayerStackEntry::Raw(self.to_vec())
657    }
658}
659
660// Implement Div for stacking
661impl std::ops::Div<LayerStack> for LayerStack {
662    type Output = LayerStack;
663
664    fn div(self, rhs: LayerStack) -> Self::Output {
665        self.stack(rhs)
666    }
667}
668
669impl std::ops::Div<LayerStackEntry> for LayerStack {
670    type Output = LayerStack;
671
672    fn div(self, rhs: LayerStackEntry) -> Self::Output {
673        self.push(rhs)
674    }
675}
676
677impl std::ops::Div<LayerStackEntry> for LayerStackEntry {
678    type Output = LayerStack;
679
680    fn div(self, rhs: LayerStackEntry) -> Self::Output {
681        LayerStack::new().push(self).push(rhs)
682    }
683}
684
685// Implement From for common builder types to create single-layer stacks
686impl From<EthernetBuilder> for LayerStack {
687    fn from(builder: EthernetBuilder) -> Self {
688        LayerStack::new().push(LayerStackEntry::Ethernet(builder))
689    }
690}
691
692impl From<ArpBuilder> for LayerStack {
693    fn from(builder: ArpBuilder) -> Self {
694        LayerStack::new().push(LayerStackEntry::Arp(builder))
695    }
696}
697
698impl From<Ipv4Builder> for LayerStack {
699    fn from(builder: Ipv4Builder) -> Self {
700        LayerStack::new().push(LayerStackEntry::Ipv4(builder))
701    }
702}
703
704impl From<Ipv6Builder> for LayerStack {
705    fn from(builder: Ipv6Builder) -> Self {
706        LayerStack::new().push(LayerStackEntry::Ipv6(builder))
707    }
708}
709
710impl From<Icmpv6Builder> for LayerStack {
711    fn from(builder: Icmpv6Builder) -> Self {
712        LayerStack::new().push(LayerStackEntry::Icmpv6(builder))
713    }
714}
715
716impl From<TcpBuilder> for LayerStack {
717    fn from(builder: TcpBuilder) -> Self {
718        LayerStack::new().push(LayerStackEntry::Tcp(builder))
719    }
720}
721
722impl From<UdpBuilder> for LayerStack {
723    fn from(builder: UdpBuilder) -> Self {
724        LayerStack::new().push(LayerStackEntry::Udp(builder))
725    }
726}
727
728impl From<DnsBuilder> for LayerStack {
729    fn from(builder: DnsBuilder) -> Self {
730        LayerStack::new().push(LayerStackEntry::Dns(builder))
731    }
732}
733
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::field::MacAddress;
    use crate::layer::{ethertype, ip_protocol};
    use std::net::Ipv4Addr;

    /// Ethernet / IPv4 / TCP SYN stack shared by the checksum tests.
    fn tcp_syn_stack() -> LayerStack {
        LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(Ipv4Addr::new(192, 168, 1, 1))
                    .dst(Ipv4Addr::new(192, 168, 1, 2))
                    .ttl(64),
            ))
            .push(LayerStackEntry::Tcp(
                TcpBuilder::new().src_port(12345).dst_port(80).syn(),
            ))
    }

    /// Folded one's-complement sum of big-endian 16-bit words (a trailing odd
    /// byte is zero-padded). Data carrying a correct Internet checksum sums
    /// to 0xFFFF.
    fn ones_complement_sum(data: &[u8]) -> u16 {
        let mut sum: u32 = data
            .chunks(2)
            .map(|w| (u32::from(w[0]) << 8) | u32::from(w.get(1).copied().unwrap_or(0)))
            .sum();
        while sum > 0xFFFF {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        sum as u16
    }

    #[test]
    fn test_empty_stack_builds_nothing() {
        let stack = LayerStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
        assert!(stack.build().is_empty());
    }

    #[test]
    fn test_ethernet_ip_stack() {
        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(Ipv4Addr::new(192, 168, 1, 1))
                    .dst(Ipv4Addr::new(192, 168, 1, 2))
                    .ttl(64),
            ));

        let bytes = stack.build();

        // Verify Ethernet type was set correctly
        let etype = u16::from_be_bytes([bytes[12], bytes[13]]);
        assert_eq!(etype, ethertype::IPV4);

        // Parse and verify
        let mut pkt = Packet::from_bytes(bytes);
        pkt.parse().unwrap();
        assert!(pkt.get_layer(LayerKind::Ethernet).is_some());
        assert!(pkt.get_layer(LayerKind::Ipv4).is_some());
    }

    #[test]
    fn test_ethernet_ip_tcp_stack() {
        let bytes = tcp_syn_stack().build();

        // Verify Ethernet type was set correctly
        let etype = u16::from_be_bytes([bytes[12], bytes[13]]);
        assert_eq!(etype, ethertype::IPV4);

        // Verify IP protocol was set correctly
        let proto = bytes[14 + 9]; // Ethernet header + IP protocol offset
        assert_eq!(proto, ip_protocol::TCP);

        // Parse and verify
        let mut pkt = Packet::from_bytes(bytes);
        pkt.parse().unwrap();
        assert!(pkt.get_layer(LayerKind::Ethernet).is_some());
        assert!(pkt.get_layer(LayerKind::Ipv4).is_some());
        assert!(pkt.get_layer(LayerKind::Tcp).is_some());
    }

    #[test]
    fn test_ipv4_header_checksum_is_valid() {
        let bytes = tcp_syn_stack().build();
        // IHL counts 32-bit words.
        let ihl = usize::from(bytes[14] & 0x0F) * 4;
        assert_eq!(ones_complement_sum(&bytes[14..14 + ihl]), 0xFFFF);
    }

    #[test]
    fn test_tcp_checksum_is_valid() {
        let bytes = tcp_syn_stack().build();
        let ihl = usize::from(bytes[14] & 0x0F) * 4;
        let ip_header = &bytes[14..14 + ihl];
        let segment = &bytes[14 + ihl..];

        // IPv4 pseudo-header: source, destination, zero, protocol, TCP length.
        let mut covered = Vec::with_capacity(12 + segment.len());
        covered.extend_from_slice(&ip_header[12..20]);
        covered.push(0);
        covered.push(ip_protocol::TCP);
        covered.extend_from_slice(&(segment.len() as u16).to_be_bytes());
        covered.extend_from_slice(segment);

        assert_eq!(ones_complement_sum(&covered), 0xFFFF);
    }

    #[test]
    fn test_ethernet_arp_stack() {
        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])),
            ))
            .push(LayerStackEntry::Arp(
                ArpBuilder::new()
                    .op(1) // REQUEST
                    .hwsrc(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]))
                    .psrc(Ipv4Addr::new(192, 168, 1, 1))
                    .pdst(Ipv4Addr::new(192, 168, 1, 2)),
            ));

        let bytes = stack.build();

        // Verify Ethernet type was set correctly
        let etype = u16::from_be_bytes([bytes[12], bytes[13]]);
        assert_eq!(etype, ethertype::ARP);

        // Parse and verify
        let mut pkt = Packet::from_bytes(bytes);
        pkt.parse().unwrap();
        assert!(pkt.get_layer(LayerKind::Ethernet).is_some());
        assert!(pkt.get_layer(LayerKind::Arp).is_some());
    }

    #[test]
    fn test_div_operator() {
        let eth = LayerStackEntry::Ethernet(
            EthernetBuilder::new()
                .dst(MacAddress::BROADCAST)
                .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])),
        );
        let ip = LayerStackEntry::Ipv4(
            Ipv4Builder::new()
                .src(Ipv4Addr::new(192, 168, 1, 1))
                .dst(Ipv4Addr::new(192, 168, 1, 2))
                .ttl(64),
        );

        let stack = eth / ip;

        assert_eq!(stack.len(), 2);
        assert_eq!(stack.layers()[0].kind(), LayerKind::Ethernet);
        assert_eq!(stack.layers()[1].kind(), LayerKind::Ipv4);
    }

    #[test]
    fn test_div_operator_chains() {
        let stack = LayerStackEntry::Ethernet(EthernetBuilder::new())
            / LayerStackEntry::Ipv4(Ipv4Builder::new())
            / LayerStackEntry::Raw(vec![1, 2, 3]);

        assert_eq!(stack.len(), 3);
        assert_eq!(stack.layers()[0].kind(), LayerKind::Ethernet);
        assert_eq!(stack.layers()[2].kind(), LayerKind::Raw);
    }

    #[test]
    fn test_raw_payload() {
        let payload = b"Hello, World!";
        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(
                EthernetBuilder::new()
                    .dst(MacAddress::BROADCAST)
                    .src(MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])),
            ))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(Ipv4Addr::new(192, 168, 1, 1))
                    .dst(Ipv4Addr::new(192, 168, 1, 2))
                    .ttl(64),
            ))
            .push(LayerStackEntry::Raw(payload.to_vec()));

        let bytes = stack.build();

        // Verify payload is at the end
        let expected_offset = 14 + 20; // Ethernet + IP header
        assert_eq!(
            &bytes[expected_offset..expected_offset + payload.len()],
            payload
        );
    }

    #[test]
    fn test_ip_total_length_calculation() {
        let payload = vec![0u8; 100];
        let stack = LayerStack::new()
            .push(LayerStackEntry::Ethernet(EthernetBuilder::new()))
            .push(LayerStackEntry::Ipv4(
                Ipv4Builder::new()
                    .src(Ipv4Addr::new(10, 0, 0, 1))
                    .dst(Ipv4Addr::new(10, 0, 0, 2)),
            ))
            .push(LayerStackEntry::Raw(payload));

        let bytes = stack.build();

        // Check IP total length (should be 20 + 100 = 120)
        let ip_total_len = u16::from_be_bytes([bytes[14 + 2], bytes[14 + 3]]);
        assert_eq!(ip_total_len, 120);
    }
}