surge_ping/icmp/
icmpv4.rs

1use socket2::Type as SockType;
2use std::convert::TryInto;
3use std::net::Ipv4Addr;
4
5use pnet_packet::icmp::{self, IcmpCode, IcmpType};
6use pnet_packet::Packet;
7use pnet_packet::{ipv4, PacketSize};
8
9use crate::{
10    error::{MalformedPacketError, Result, SurgeError},
11    is_linux_icmp_socket,
12};
13
14use super::{PingIdentifier, PingSequence};
15
16pub fn make_icmpv4_echo_packet(
17    ident_hint: PingIdentifier,
18    seq_cnt: PingSequence,
19    sock_type: SockType,
20    payload: &[u8],
21) -> Result<Vec<u8>> {
22    // 8 bytes of header, then payload.
23    let mut buf = vec![0; 8 + payload.len()];
24    let mut packet = icmp::echo_request::MutableEchoRequestPacket::new(&mut buf[..])
25        .ok_or(SurgeError::IncorrectBufferSize)?;
26
27    packet.set_icmp_type(icmp::IcmpTypes::EchoRequest);
28    packet.set_payload(payload);
29    packet.set_sequence_number(seq_cnt.into_u16());
30
31    if !(is_linux_icmp_socket!(sock_type)) {
32        packet.set_identifier(ident_hint.into_u16());
33
34        // Calculate and set the checksum
35        let icmp_packet =
36            icmp::IcmpPacket::new(packet.packet()).ok_or(SurgeError::IncorrectBufferSize)?;
37
38        let checksum = icmp::checksum(&icmp_packet);
39        packet.set_checksum(checksum);
40    }
41
42    Ok(packet.packet().to_vec())
43}
44
45/// Packet structure returned by ICMPv4.
46#[derive(Debug)]
47pub struct Icmpv4Packet {
48    source: Ipv4Addr,
49    destination: Ipv4Addr,
50    ttl: Option<u8>,
51    icmp_type: IcmpType,
52    icmp_code: IcmpCode,
53    size: usize,
54    real_dest: Ipv4Addr,
55    identifier: PingIdentifier,
56    sequence: PingSequence,
57}
58
59impl Default for Icmpv4Packet {
60    fn default() -> Self {
61        Icmpv4Packet {
62            source: Ipv4Addr::new(127, 0, 0, 1),
63            destination: Ipv4Addr::new(127, 0, 0, 1),
64            ttl: None,
65            icmp_type: IcmpType::new(0),
66            icmp_code: IcmpCode::new(0),
67            size: 0,
68            real_dest: Ipv4Addr::new(127, 0, 0, 1),
69            identifier: PingIdentifier(0),
70            sequence: PingSequence(0),
71        }
72    }
73}
74
75impl Icmpv4Packet {
76    fn source(&mut self, source: Ipv4Addr) -> &mut Self {
77        self.source = source;
78        self
79    }
80
81    /// Get the source field.
82    pub fn get_source(&self) -> Ipv4Addr {
83        self.source
84    }
85
86    fn destination(&mut self, destination: Ipv4Addr) -> &mut Self {
87        self.destination = destination;
88        self
89    }
90
91    /// Get the destination field.
92    pub fn get_destination(&self) -> Ipv4Addr {
93        self.destination
94    }
95
96    fn ttl(&mut self, ttl: u8) -> &mut Self {
97        self.ttl = Some(ttl);
98        self
99    }
100
101    /// Get the ttl field.
102    pub fn get_ttl(&self) -> Option<u8> {
103        self.ttl
104    }
105
106    fn icmp_type(&mut self, icmp_type: IcmpType) -> &mut Self {
107        self.icmp_type = icmp_type;
108        self
109    }
110
111    /// Get the icmp_type of the icmpv4 packet.
112    pub fn get_icmp_type(&self) -> IcmpType {
113        self.icmp_type
114    }
115
116    fn icmp_code(&mut self, icmp_code: IcmpCode) -> &mut Self {
117        self.icmp_code = icmp_code;
118        self
119    }
120
121    /// Get the icmp_code of the icmpv4 packet.
122    pub fn get_icmp_code(&self) -> IcmpCode {
123        self.icmp_code
124    }
125
126    fn size(&mut self, size: usize) -> &mut Self {
127        self.size = size;
128        self
129    }
130
131    /// Get the size of the icmp_v4 packet.
132    pub fn get_size(&self) -> usize {
133        self.size
134    }
135
136    fn real_dest(&mut self, addr: Ipv4Addr) -> &mut Self {
137        self.real_dest = addr;
138        self
139    }
140
141    /// If it is an `echo_reply` packet, it is the source address in the IPv4 packet.
142    /// If it is other packets, it is the destination address in the IPv4 packet in ICMP's payload.
143    pub fn get_real_dest(&self) -> Ipv4Addr {
144        self.real_dest
145    }
146
147    fn identifier(&mut self, identifier: PingIdentifier) -> &mut Self {
148        self.identifier = identifier;
149        self
150    }
151
152    /// Get the identifier of the icmp_v4 packet.
153    pub fn get_identifier(&self) -> PingIdentifier {
154        self.identifier
155    }
156
157    fn sequence(&mut self, sequence: PingSequence) -> &mut Self {
158        self.sequence = sequence;
159        self
160    }
161
162    /// Get the sequence of the icmp_v4 packet.
163    pub fn get_sequence(&self) -> PingSequence {
164        self.sequence
165    }
166
167    /// Decode into icmp packet from the socket message.
168    pub fn decode(
169        buf: &[u8],
170        sock_type: SockType,
171        src_addr: Ipv4Addr,
172        dst_addr: Ipv4Addr,
173    ) -> Result<Self> {
174        if is_linux_icmp_socket!(sock_type) {
175            Self::decode_from_icmp(buf, src_addr, dst_addr)
176        } else {
177            Self::decode_from_ipv4(buf)
178        }
179    }
180
181    fn decode_from_ipv4(buf: &[u8]) -> Result<Self> {
182        let ipv4_packet = ipv4::Ipv4Packet::new(buf)
183            .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIpv4Packet))?;
184        let icmp_packet = icmp::IcmpPacket::new(ipv4_packet.payload())
185            .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
186        let mut packet = Icmpv4Packet::default();
187
188        match icmp_packet.get_icmp_type() {
189            icmp::IcmpTypes::EchoReply => {
190                let icmp_packet = icmp::echo_reply::EchoReplyPacket::new(icmp_packet.packet())
191                    .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
192
193                packet
194                    .source(ipv4_packet.get_source())
195                    .destination(ipv4_packet.get_destination())
196                    .ttl(ipv4_packet.get_ttl())
197                    .icmp_type(icmp_packet.get_icmp_type())
198                    .icmp_code(icmp_packet.get_icmp_code())
199                    .size(icmp_packet.packet().len())
200                    .real_dest(ipv4_packet.get_source())
201                    .identifier(icmp_packet.get_identifier().into())
202                    .sequence(icmp_packet.get_sequence_number().into());
203            }
204            icmp::IcmpTypes::EchoRequest => return Err(SurgeError::EchoRequestPacket),
205            _ => {
206                let icmp_payload = icmp_packet.payload();
207
208                if icmp_payload.len() < 32 {
209                    return Err(SurgeError::from(MalformedPacketError::PayloadTooShort {
210                        got: icmp_payload.len(),
211                        want: 32,
212                    }));
213                }
214                // icmp unused(4) + ip header(20) + echo icmp(4)
215                let real_ip_packet = ipv4::Ipv4Packet::new(&icmp_payload[4..])
216                    .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIpv4Packet))?;
217                let identifier = u16::from_be_bytes(icmp_payload[28..30].try_into().unwrap());
218                let sequence = u16::from_be_bytes(icmp_payload[30..32].try_into().unwrap());
219
220                packet
221                    .source(ipv4_packet.get_source())
222                    .destination(ipv4_packet.get_destination())
223                    .ttl(ipv4_packet.get_ttl())
224                    .icmp_type(icmp_packet.get_icmp_type())
225                    .icmp_code(icmp_packet.get_icmp_code())
226                    .size(icmp_packet.packet_size())
227                    .real_dest(real_ip_packet.get_destination())
228                    .identifier(identifier.into())
229                    .sequence(sequence.into());
230            }
231        }
232
233        Ok(packet)
234    }
235
236    fn decode_from_icmp(buf: &[u8], src_addr: Ipv4Addr, dst_addr: Ipv4Addr) -> Result<Self> {
237        let icmp_packet = icmp::IcmpPacket::new(buf)
238            .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
239        let mut packet = Icmpv4Packet::default();
240
241        match icmp_packet.get_icmp_type() {
242            icmp::IcmpTypes::EchoReply => {
243                let icmp_packet = icmp::echo_reply::EchoReplyPacket::new(icmp_packet.packet())
244                    .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
245
246                packet
247                    .source(src_addr)
248                    .destination(dst_addr)
249                    .icmp_type(icmp_packet.get_icmp_type())
250                    .icmp_code(icmp_packet.get_icmp_code())
251                    .size(icmp_packet.packet().len())
252                    .real_dest(src_addr)
253                    .identifier(icmp_packet.get_identifier().into())
254                    .sequence(icmp_packet.get_sequence_number().into());
255            }
256            icmp::IcmpTypes::EchoRequest => return Err(SurgeError::EchoRequestPacket),
257            _ => {
258                let icmp_payload = icmp_packet.payload();
259
260                if icmp_payload.len() < 32 {
261                    return Err(SurgeError::from(MalformedPacketError::PayloadTooShort {
262                        got: icmp_payload.len(),
263                        want: 32,
264                    }));
265                }
266
267                // icmp unused(4) + ip header(20) + echo icmp(4)
268                let real_ip_packet = ipv4::Ipv4Packet::new(&icmp_payload[4..])
269                    .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIpv4Packet))?;
270                let identifier = u16::from_be_bytes(icmp_payload[28..30].try_into().unwrap());
271                let sequence = u16::from_be_bytes(icmp_payload[30..32].try_into().unwrap());
272
273                packet
274                    .source(src_addr)
275                    .destination(dst_addr)
276                    .icmp_type(icmp_packet.get_icmp_type())
277                    .icmp_code(icmp_packet.get_icmp_code())
278                    .size(icmp_packet.packet_size())
279                    .real_dest(real_ip_packet.get_destination())
280                    .identifier(identifier.into())
281                    .sequence(sequence.into());
282            }
283        }
284
285        Ok(packet)
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Icmpv4Packet;

    /// Decode hex-encoded bytes, passing the fixed addresses used throughout these tests.
    fn decode_hex(encoded: &str, sock_type: SockType) -> Result<Icmpv4Packet> {
        let bytes = hex::decode(encoded).unwrap();
        Icmpv4Packet::decode(
            &bytes,
            sock_type,
            ("172.217.14.110").parse().unwrap(),
            ("10.0.242.34").parse().unwrap(),
        )
    }

    #[test]
    fn malformed_packet() {
        assert!(decode_hex(
            "4500001d0000000079018a76acd90e6e0a00f22203006c3293cc",
            SockType::RAW,
        )
        .is_err());

        assert!(decode_hex("03006c3293cc", SockType::DGRAM).is_err());
    }

    #[test]
    fn short_packet() {
        assert!(decode_hex(
            "4500001d0000000079018a76acd90e6e0a00f22203006c3293cc000100",
            SockType::RAW,
        )
        .is_err());

        assert!(decode_hex("03006c3293cc000100", SockType::DGRAM).is_err());
    }

    #[test]
    fn standard_packet() {
        let packet = decode_hex("45000054000000007901067e8efab00e0a00f22203004176a1ee0001613dd762000000002127040000000000101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f3031323334353637", SockType::RAW).unwrap();

        // Fields from the outer IPv4 header and ICMP header.
        assert_eq!(packet.get_source(), Ipv4Addr::new(142, 250, 176, 14));
        assert_eq!(packet.get_destination(), Ipv4Addr::new(10, 0, 242, 34));
        assert_eq!(packet.get_ttl(), Some(121));
        assert_eq!(packet.get_icmp_type(), icmp::IcmpTypes::DestinationUnreachable);
        assert_eq!(packet.get_icmp_code(), IcmpCode::new(0));
        // Fields recovered from the datagram quoted in the ICMP payload.
        assert_eq!(packet.get_real_dest(), Ipv4Addr::new(16, 17, 18, 19));
        assert_eq!(packet.get_identifier().into_u16(), 0x1819);
        assert_eq!(packet.get_sequence().into_u16(), 0x1a1b);

        decode_hex("03004176a1ee0001613dd762000000002127040000000000101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f3031323334353637", SockType::DGRAM).unwrap();
    }

    #[test]
    fn echo_reply_packet() {
        // IPv4 header (ttl 64, 8.8.8.8 -> 192.168.0.1) + echo reply id 0x1234, seq 5.
        let packet = decode_hex(
            "4500001c000000004001000008080808c0a800010000000012340005",
            SockType::RAW,
        )
        .unwrap();

        assert_eq!(packet.get_source(), Ipv4Addr::new(8, 8, 8, 8));
        assert_eq!(packet.get_destination(), Ipv4Addr::new(192, 168, 0, 1));
        assert_eq!(packet.get_ttl(), Some(64));
        assert_eq!(packet.get_icmp_type(), icmp::IcmpTypes::EchoReply);
        assert_eq!(packet.get_icmp_code(), IcmpCode::new(0));
        assert_eq!(packet.get_size(), 8);
        assert_eq!(packet.get_real_dest(), Ipv4Addr::new(8, 8, 8, 8));
        assert_eq!(packet.get_identifier().into_u16(), 0x1234);
        assert_eq!(packet.get_sequence().into_u16(), 5);
    }

    #[test]
    fn echo_request_is_rejected() {
        let err = decode_hex(
            "4500001c000000004001000008080808c0a800010800000012340005",
            SockType::RAW,
        )
        .unwrap_err();
        assert!(matches!(err, SurgeError::EchoRequestPacket));
    }
}