Skip to main content

stackforge_core/layer/icmp/
builder.rs

//! ICMP packet builder.
//!
//! Provides a fluent API for constructing ICMP packets with type-specific fields.
//! The checksum is calculated automatically unless a manual value is supplied.
//!
//! # Example
//!
//! ```rust
//! use stackforge_core::layer::icmp::IcmpBuilder;
//!
//! // Build an echo request
//! let packet = IcmpBuilder::echo_request(0x1234, 1)
//!     .payload(b"ping data")
//!     .build();
//!
//! // Build a destination unreachable message
//! let packet = IcmpBuilder::dest_unreach(3) // Port unreachable
//!     .build();
//! ```
20
21use std::net::Ipv4Addr;
22
23use super::checksum::icmp_checksum;
24use super::types::types;
25use super::{ICMP_MIN_HEADER_LEN, offsets};
26use crate::layer::field::FieldError;
27
/// Builder for ICMP packets.
///
/// Due to ICMP's type-specific fields, this builder provides factory methods
/// for common ICMP message types rather than a generic constructor.
///
/// Builder methods consume and return `self`. The configured packet is only
/// produced by `build`, `build_into`, or `build_header`, so an unused builder
/// value is almost certainly a mistake.
#[must_use = "an IcmpBuilder does nothing until `build`, `build_into`, or `build_header` is called"]
#[derive(Debug, Clone)]
pub struct IcmpBuilder {
    // Base fields (all ICMP messages)
    icmp_type: u8,
    code: u8,
    // Manually supplied checksum; only written to the packet when
    // `auto_checksum` is false.
    checksum: Option<u16>,

    // Type-specific data (the 4 bytes after the checksum, "rest of header").
    // Layout depends on ICMP type:
    // - Echo / Timestamp / Address Mask: [id: u16, seq: u16]
    // - Redirect: [gateway: 4-byte IPv4 address]
    // - Dest Unreach: [unused: u8, length: u8, next-hop MTU: u16 (code 4 only)]
    // - Time Exceeded: [unused: u8, length: u8, unused: u16]
    // - Param Problem: [ptr: u8, length: u8, unused: u16]
    //   (the RFC 4884 "length" byte is always left at 0 by this builder)
    type_specific: [u8; 4],

    // Additional data for timestamp messages (12 bytes, written after the
    // base header): ts_ori, ts_rx, ts_tx (3 x big-endian u32).
    timestamp_data: Option<[u8; 12]>,

    // Payload data written after the header (and timestamps, if any).
    // Address-mask messages keep their 4-byte mask here.
    payload: Vec<u8>,

    // Build options: when true, the checksum is computed at build time and
    // any manual `checksum` value is ignored.
    auto_checksum: bool,
}
58
59impl Default for IcmpBuilder {
60    fn default() -> Self {
61        Self {
62            icmp_type: types::ECHO_REQUEST,
63            code: 0,
64            checksum: None,
65            type_specific: [0; 4],
66            timestamp_data: None,
67            payload: Vec::new(),
68            auto_checksum: true,
69        }
70    }
71}
72
73impl IcmpBuilder {
74    /// Create a new ICMP builder with default values (echo request).
75    pub fn new() -> Self {
76        Self::default()
77    }
78
79    // ========== Factory Methods for Common Types ==========
80
81    /// Create an echo request (ping) packet.
82    ///
83    /// # Arguments
84    /// * `id` - Identifier
85    /// * `seq` - Sequence number
86    pub fn echo_request(id: u16, seq: u16) -> Self {
87        let mut builder = Self::new();
88        builder.icmp_type = types::ECHO_REQUEST;
89        builder.code = 0;
90        builder.set_id_seq(id, seq);
91        builder
92    }
93
94    /// Create an echo reply (pong) packet.
95    ///
96    /// # Arguments
97    /// * `id` - Identifier (should match request)
98    /// * `seq` - Sequence number (should match request)
99    pub fn echo_reply(id: u16, seq: u16) -> Self {
100        let mut builder = Self::new();
101        builder.icmp_type = types::ECHO_REPLY;
102        builder.code = 0;
103        builder.set_id_seq(id, seq);
104        builder
105    }
106
107    /// Create a destination unreachable message.
108    ///
109    /// # Arguments
110    /// * `code` - Specific unreachable code (0-15)
111    ///   - 0: Network unreachable
112    ///   - 1: Host unreachable
113    ///   - 2: Protocol unreachable
114    ///   - 3: Port unreachable
115    ///   - 4: Fragmentation needed (use `dest_unreach_need_frag` for this)
116    pub fn dest_unreach(code: u8) -> Self {
117        let mut builder = Self::new();
118        builder.icmp_type = types::DEST_UNREACH;
119        builder.code = code;
120        builder.type_specific = [0; 4]; // Unused for most codes
121        builder
122    }
123
124    /// Create a destination unreachable - fragmentation needed message.
125    ///
126    /// # Arguments
127    /// * `mtu` - Next-hop MTU value
128    pub fn dest_unreach_need_frag(mtu: u16) -> Self {
129        let mut builder = Self::new();
130        builder.icmp_type = types::DEST_UNREACH;
131        builder.code = 4; // Fragmentation needed
132        builder.type_specific[0] = 0; // Unused byte
133        builder.type_specific[1] = 0; // Unused byte
134        builder.type_specific[2..4].copy_from_slice(&mtu.to_be_bytes());
135        builder
136    }
137
138    /// Create a redirect message.
139    ///
140    /// # Arguments
141    /// * `code` - Redirect code (0-3)
142    ///   - 0: Redirect for network
143    ///   - 1: Redirect for host
144    ///   - 2: Redirect for TOS and network
145    ///   - 3: Redirect for TOS and host
146    /// * `gateway` - Gateway IP address to redirect to
147    pub fn redirect(code: u8, gateway: Ipv4Addr) -> Self {
148        let mut builder = Self::new();
149        builder.icmp_type = types::REDIRECT;
150        builder.code = code;
151        builder.type_specific.copy_from_slice(&gateway.octets());
152        builder
153    }
154
155    /// Create a time exceeded message.
156    ///
157    /// # Arguments
158    /// * `code` - Time exceeded code
159    ///   - 0: TTL exceeded in transit
160    ///   - 1: Fragment reassembly time exceeded
161    pub fn time_exceeded(code: u8) -> Self {
162        let mut builder = Self::new();
163        builder.icmp_type = types::TIME_EXCEEDED;
164        builder.code = code;
165        builder.type_specific = [0; 4]; // Unused
166        builder
167    }
168
169    /// Create a parameter problem message.
170    ///
171    /// # Arguments
172    /// * `ptr` - Pointer to the problematic byte in the original packet
173    pub fn param_problem(ptr: u8) -> Self {
174        let mut builder = Self::new();
175        builder.icmp_type = types::PARAM_PROBLEM;
176        builder.code = 0;
177        builder.type_specific[0] = ptr;
178        builder.type_specific[1] = 0; // Unused
179        builder.type_specific[2] = 0; // Length
180        builder.type_specific[3] = 0; // Unused
181        builder
182    }
183
184    /// Create a source quench message (deprecated).
185    pub fn source_quench() -> Self {
186        let mut builder = Self::new();
187        builder.icmp_type = types::SOURCE_QUENCH;
188        builder.code = 0;
189        builder.type_specific = [0; 4]; // Unused
190        builder
191    }
192
193    /// Create a timestamp request message.
194    ///
195    /// # Arguments
196    /// * `id` - Identifier
197    /// * `seq` - Sequence number
198    /// * `ts_ori` - Originate timestamp (milliseconds since midnight UT)
199    /// * `ts_rx` - Receive timestamp (0 for request)
200    /// * `ts_tx` - Transmit timestamp (0 for request)
201    pub fn timestamp_request(id: u16, seq: u16, ts_ori: u32, ts_rx: u32, ts_tx: u32) -> Self {
202        let mut builder = Self::new();
203        builder.icmp_type = types::TIMESTAMP;
204        builder.code = 0;
205        builder.set_id_seq(id, seq);
206
207        let mut ts_data = [0u8; 12];
208        ts_data[0..4].copy_from_slice(&ts_ori.to_be_bytes());
209        ts_data[4..8].copy_from_slice(&ts_rx.to_be_bytes());
210        ts_data[8..12].copy_from_slice(&ts_tx.to_be_bytes());
211        builder.timestamp_data = Some(ts_data);
212
213        builder
214    }
215
216    /// Create a timestamp reply message.
217    ///
218    /// # Arguments
219    /// * `id` - Identifier (should match request)
220    /// * `seq` - Sequence number (should match request)
221    /// * `ts_ori` - Originate timestamp from request
222    /// * `ts_rx` - Receive timestamp (when request was received)
223    /// * `ts_tx` - Transmit timestamp (when reply is sent)
224    pub fn timestamp_reply(id: u16, seq: u16, ts_ori: u32, ts_rx: u32, ts_tx: u32) -> Self {
225        let mut builder = Self::new();
226        builder.icmp_type = types::TIMESTAMP_REPLY;
227        builder.code = 0;
228        builder.set_id_seq(id, seq);
229
230        let mut ts_data = [0u8; 12];
231        ts_data[0..4].copy_from_slice(&ts_ori.to_be_bytes());
232        ts_data[4..8].copy_from_slice(&ts_rx.to_be_bytes());
233        ts_data[8..12].copy_from_slice(&ts_tx.to_be_bytes());
234        builder.timestamp_data = Some(ts_data);
235
236        builder
237    }
238
239    /// Create an address mask request message.
240    ///
241    /// # Arguments
242    /// * `id` - Identifier
243    /// * `seq` - Sequence number
244    pub fn address_mask_request(id: u16, seq: u16) -> Self {
245        let mut builder = Self::new();
246        builder.icmp_type = types::ADDRESS_MASK_REQUEST;
247        builder.code = 0;
248        builder.set_id_seq(id, seq);
249        builder
250    }
251
252    /// Create an address mask reply message.
253    ///
254    /// # Arguments
255    /// * `id` - Identifier (should match request)
256    /// * `seq` - Sequence number (should match request)
257    /// * `mask` - Address mask
258    pub fn address_mask_reply(id: u16, seq: u16, mask: Ipv4Addr) -> Self {
259        let mut builder = Self::new();
260        builder.icmp_type = types::ADDRESS_MASK_REPLY;
261        builder.code = 0;
262        builder.set_id_seq(id, seq);
263        // For address mask, the mask goes in the payload area
264        builder.payload = mask.octets().to_vec();
265        builder
266    }
267
268    // ========== Helper Methods ==========
269
270    /// Set ID and sequence number (for echo, timestamp, etc.)
271    fn set_id_seq(&mut self, id: u16, seq: u16) {
272        self.type_specific[0..2].copy_from_slice(&id.to_be_bytes());
273        self.type_specific[2..4].copy_from_slice(&seq.to_be_bytes());
274    }
275
276    // ========== Field Setters ==========
277
278    /// Set the ICMP type manually (use factory methods instead when possible).
279    pub fn icmp_type(mut self, t: u8) -> Self {
280        self.icmp_type = t;
281        self
282    }
283
284    /// Set the ICMP code manually.
285    pub fn code(mut self, c: u8) -> Self {
286        self.code = c;
287        self
288    }
289
290    /// Set the checksum manually.
291    ///
292    /// If not set, the checksum will be calculated automatically.
293    pub fn checksum(mut self, csum: u16) -> Self {
294        self.checksum = Some(csum);
295        self.auto_checksum = false;
296        self
297    }
298
299    /// Alias for checksum (Scapy compatibility).
300    pub fn chksum(self, csum: u16) -> Self {
301        self.checksum(csum)
302    }
303
304    /// Enable automatic checksum calculation (default).
305    pub fn enable_auto_checksum(mut self) -> Self {
306        self.auto_checksum = true;
307        self.checksum = None;
308        self
309    }
310
311    /// Disable automatic checksum calculation.
312    pub fn disable_auto_checksum(mut self) -> Self {
313        self.auto_checksum = false;
314        self
315    }
316
317    /// Set the payload data.
318    pub fn payload<T: Into<Vec<u8>>>(mut self, data: T) -> Self {
319        self.payload = data.into();
320        self
321    }
322
323    /// Append to the payload data.
324    pub fn append_payload<T: AsRef<[u8]>>(mut self, data: T) -> Self {
325        self.payload.extend_from_slice(data.as_ref());
326        self
327    }
328
329    // ========== Size Calculation ==========
330
331    /// Get the total packet size (header + optional timestamp + payload).
332    pub fn packet_size(&self) -> usize {
333        let mut size = ICMP_MIN_HEADER_LEN; // Base 8 bytes
334
335        // Timestamp messages have 12 additional bytes
336        if self.timestamp_data.is_some() {
337            size += 12;
338        }
339
340        size + self.payload.len()
341    }
342
343    /// Get the header size (8 bytes for most, 20 for timestamp).
344    pub fn header_size(&self) -> usize {
345        if self.timestamp_data.is_some() {
346            20 // 8 base + 12 timestamp data
347        } else {
348            ICMP_MIN_HEADER_LEN
349        }
350    }
351
352    // ========== Build Methods ==========
353
354    /// Build the ICMP packet into a new buffer.
355    pub fn build(&self) -> Vec<u8> {
356        let total_size = self.packet_size();
357        let mut buf = vec![0u8; total_size];
358        self.build_into(&mut buf)
359            .expect("buffer is correctly sized");
360        buf
361    }
362
363    /// Build the ICMP packet into an existing buffer.
364    pub fn build_into(&self, buf: &mut [u8]) -> Result<usize, FieldError> {
365        let total_size = self.packet_size();
366
367        if buf.len() < total_size {
368            return Err(FieldError::BufferTooShort {
369                offset: 0,
370                need: total_size,
371                have: buf.len(),
372            });
373        }
374
375        // Type
376        buf[offsets::TYPE] = self.icmp_type;
377
378        // Code
379        buf[offsets::CODE] = self.code;
380
381        // Checksum (initially 0, calculated later if auto_checksum is enabled)
382        buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&[0, 0]);
383
384        // Type-specific data (4 bytes)
385        buf[4..8].copy_from_slice(&self.type_specific);
386
387        let mut offset = 8;
388
389        // Timestamp data if present (12 bytes)
390        if let Some(ts_data) = &self.timestamp_data {
391            buf[offset..offset + 12].copy_from_slice(ts_data);
392            offset += 12;
393        }
394
395        // Payload
396        if !self.payload.is_empty() {
397            buf[offset..offset + self.payload.len()].copy_from_slice(&self.payload);
398        }
399
400        // Calculate checksum if enabled
401        if self.auto_checksum {
402            let csum = icmp_checksum(&buf[..total_size]);
403            buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&csum.to_be_bytes());
404        } else if let Some(csum) = self.checksum {
405            buf[offsets::CHECKSUM..offsets::CHECKSUM + 2].copy_from_slice(&csum.to_be_bytes());
406        }
407
408        Ok(total_size)
409    }
410
411    /// Build just the ICMP header (without payload).
412    pub fn build_header(&self) -> Vec<u8> {
413        let header_size = self.header_size();
414        let mut buf = vec![0u8; header_size];
415
416        // Create a copy without payload for header-only build
417        let builder = Self {
418            payload: Vec::new(),
419            ..self.clone()
420        };
421        builder
422            .build_into(&mut buf)
423            .expect("buffer is correctly sized");
424
425        buf
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// One's-complement sum of `data` taken as big-endian 16-bit words, with
    /// a trailing odd byte zero-padded. A packet whose checksum field is
    /// correct sums to 0xFFFF.
    fn ones_complement_sum(data: &[u8]) -> u16 {
        let mut sum: u32 = data
            .chunks(2)
            .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair.get(1).copied().unwrap_or(0)])))
            .sum();
        while sum > 0xFFFF {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        u16::try_from(sum).expect("carries folded into 16 bits")
    }

    /// Assert that the checksum embedded in `packet` validates.
    fn assert_checksum_valid(packet: &[u8]) {
        assert_eq!(
            ones_complement_sum(packet),
            0xFFFF,
            "invalid checksum in {packet:02x?}"
        );
    }

    #[test]
    fn test_echo_request() {
        let packet = IcmpBuilder::echo_request(0x1234, 5)
            .payload(b"Hello")
            .build();

        assert_eq!(packet[0], types::ECHO_REQUEST); // type
        assert_eq!(packet[1], 0); // code
        // bytes 2-3 are checksum
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 0x1234); // id
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 5); // seq
        assert_eq!(&packet[8..], b"Hello"); // payload
    }

    #[test]
    fn test_echo_reply() {
        let packet = IcmpBuilder::echo_reply(0x5678, 10).build();

        assert_eq!(packet[0], types::ECHO_REPLY);
        assert_eq!(packet[1], 0);
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 0x5678);
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 10);
    }

    #[test]
    fn test_dest_unreach() {
        let packet = IcmpBuilder::dest_unreach(3).build(); // Port unreachable

        assert_eq!(packet[0], types::DEST_UNREACH);
        assert_eq!(packet[1], 3); // code
        assert_eq!(&packet[4..8], &[0, 0, 0, 0]); // unused
    }

    #[test]
    fn test_dest_unreach_need_frag() {
        let packet = IcmpBuilder::dest_unreach_need_frag(1500).build();

        assert_eq!(packet[0], types::DEST_UNREACH);
        assert_eq!(packet[1], 4); // fragmentation needed
        assert_eq!(&packet[4..6], &[0, 0]); // unused
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 1500); // MTU
    }

    #[test]
    fn test_redirect() {
        let gateway = Ipv4Addr::new(192, 168, 1, 1);
        let packet = IcmpBuilder::redirect(1, gateway).build();

        assert_eq!(packet[0], types::REDIRECT);
        assert_eq!(packet[1], 1); // code: redirect host
        assert_eq!(&packet[4..8], &[192, 168, 1, 1]); // gateway IP
    }

    #[test]
    fn test_time_exceeded() {
        let packet = IcmpBuilder::time_exceeded(0).build(); // TTL exceeded

        assert_eq!(packet[0], types::TIME_EXCEEDED);
        assert_eq!(packet[1], 0);
    }

    #[test]
    fn test_param_problem() {
        let packet = IcmpBuilder::param_problem(20).build();

        assert_eq!(packet[0], types::PARAM_PROBLEM);
        assert_eq!(packet[4], 20); // pointer
        assert_eq!(&packet[5..8], &[0, 0, 0]); // length + unused
    }

    #[test]
    fn test_timestamp_request() {
        let packet = IcmpBuilder::timestamp_request(0x1234, 1, 1000, 0, 0).build();

        assert_eq!(packet[0], types::TIMESTAMP);
        assert_eq!(packet.len(), 20); // 8 base + 12 timestamp data
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 0x1234); // id
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 1); // seq
        assert_eq!(
            u32::from_be_bytes([packet[8], packet[9], packet[10], packet[11]]),
            1000
        ); // ts_ori
        assert_eq!(&packet[12..20], &[0u8; 8]); // ts_rx, ts_tx
    }

    #[test]
    fn test_timestamp_reply() {
        let packet = IcmpBuilder::timestamp_reply(0x1234, 2, 1000, 2000, 3000).build();
        let ts_at =
            |at: usize| u32::from_be_bytes([packet[at], packet[at + 1], packet[at + 2], packet[at + 3]]);

        assert_eq!(packet[0], types::TIMESTAMP_REPLY);
        assert_eq!(packet.len(), 20);
        assert_eq!(ts_at(8), 1000); // ts_ori
        assert_eq!(ts_at(12), 2000); // ts_rx
        assert_eq!(ts_at(16), 3000); // ts_tx
        assert_checksum_valid(&packet);
    }

    #[test]
    fn test_address_mask_reply() {
        let mask = Ipv4Addr::new(255, 255, 255, 0);
        let packet = IcmpBuilder::address_mask_reply(3, 4, mask).build();

        assert_eq!(packet[0], types::ADDRESS_MASK_REPLY);
        assert_eq!(u16::from_be_bytes([packet[4], packet[5]]), 3);
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 4);
        assert_eq!(&packet[8..], &[255, 255, 255, 0]); // mask
    }

    #[test]
    fn test_checksum_calculation() {
        let packet = IcmpBuilder::echo_request(1, 1).payload(b"test").build();

        // A non-zero value alone proves nothing; the whole packet must sum to 0xFFFF.
        assert_checksum_valid(&packet);
    }

    #[test]
    fn test_checksum_odd_length_payload() {
        let packet = IcmpBuilder::echo_request(0x1234, 5).payload(b"Hello").build();

        assert_eq!(packet.len(), 13);
        assert_checksum_valid(&packet);
    }

    #[test]
    fn test_manual_checksum() {
        let packet = IcmpBuilder::echo_request(1, 1).checksum(0xABCD).build();

        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 0xABCD);
    }

    #[test]
    fn test_enable_auto_checksum_overrides_manual() {
        let packet = IcmpBuilder::echo_request(1, 1)
            .checksum(0xABCD)
            .enable_auto_checksum()
            .build();

        assert_checksum_valid(&packet);
    }

    #[test]
    fn test_disable_auto_checksum_leaves_zero() {
        let packet = IcmpBuilder::echo_request(1, 1).disable_auto_checksum().build();

        assert_eq!(&packet[2..4], &[0, 0]);
    }

    #[test]
    fn test_build_into_rejects_short_buffer() {
        let mut buf = [0u8; 4];
        let err = IcmpBuilder::echo_request(1, 1)
            .build_into(&mut buf)
            .unwrap_err();

        assert!(matches!(
            err,
            FieldError::BufferTooShort { need: 8, have: 4, .. }
        ));
    }

    #[test]
    fn test_build_into_larger_buffer() {
        let builder = IcmpBuilder::echo_request(0x1234, 5).payload(b"Hi");
        let mut buf = [0xEEu8; 16];
        let written = builder.build_into(&mut buf).expect("buffer is large enough");

        assert_eq!(written, 10);
        assert_eq!(&buf[..written], builder.build().as_slice());
        assert!(buf[written..].iter().all(|&b| b == 0xEE)); // tail untouched
    }

    #[test]
    fn test_sizes() {
        let echo = IcmpBuilder::echo_request(1, 1).payload(b"Hello");
        assert_eq!(echo.header_size(), 8);
        assert_eq!(echo.packet_size(), 13);

        let ts = IcmpBuilder::timestamp_request(1, 1, 0, 0, 0).payload(b"xy");
        assert_eq!(ts.header_size(), 20);
        assert_eq!(ts.packet_size(), 22);
    }

    #[test]
    fn test_build_header_only() {
        let header = IcmpBuilder::echo_request(1, 1)
            .payload(b"this should not be included")
            .build_header();

        assert_eq!(header.len(), 8); // Header only
        assert_eq!(&header[4..8], &[0, 1, 0, 1]); // id, seq
    }

    #[test]
    fn test_timestamp_header_size() {
        let builder = IcmpBuilder::timestamp_request(1, 1, 1000, 2000, 3000);
        let header = builder.build_header();

        assert_eq!(header.len(), 20); // 8 + 12 for timestamps
        // With no payload, the header is the entire packet.
        assert_eq!(header, builder.build());
    }
}