1use socket2::Type as SockType;
2use std::convert::TryInto;
3use std::net::Ipv4Addr;
4
5use pnet_packet::icmp::{self, IcmpCode, IcmpType};
6use pnet_packet::Packet;
7use pnet_packet::{ipv4, PacketSize};
8
9use crate::{
10 error::{MalformedPacketError, Result, SurgeError},
11 is_linux_icmp_socket,
12};
13
14use super::{PingIdentifier, PingSequence};
15
16pub fn make_icmpv4_echo_packet(
17 ident_hint: PingIdentifier,
18 seq_cnt: PingSequence,
19 sock_type: SockType,
20 payload: &[u8],
21) -> Result<Vec<u8>> {
22 let mut buf = vec![0; 8 + payload.len()];
24 let mut packet = icmp::echo_request::MutableEchoRequestPacket::new(&mut buf[..])
25 .ok_or(SurgeError::IncorrectBufferSize)?;
26
27 packet.set_icmp_type(icmp::IcmpTypes::EchoRequest);
28 packet.set_payload(payload);
29 packet.set_sequence_number(seq_cnt.into_u16());
30
31 if !(is_linux_icmp_socket!(sock_type)) {
32 packet.set_identifier(ident_hint.into_u16());
33
34 let icmp_packet =
36 icmp::IcmpPacket::new(packet.packet()).ok_or(SurgeError::IncorrectBufferSize)?;
37
38 let checksum = icmp::checksum(&icmp_packet);
39 packet.set_checksum(checksum);
40 }
41
42 Ok(packet.packet().to_vec())
43}
44
45#[derive(Debug)]
47pub struct Icmpv4Packet {
48 source: Ipv4Addr,
49 destination: Ipv4Addr,
50 ttl: Option<u8>,
51 icmp_type: IcmpType,
52 icmp_code: IcmpCode,
53 size: usize,
54 real_dest: Ipv4Addr,
55 identifier: PingIdentifier,
56 sequence: PingSequence,
57}
58
59impl Default for Icmpv4Packet {
60 fn default() -> Self {
61 Icmpv4Packet {
62 source: Ipv4Addr::new(127, 0, 0, 1),
63 destination: Ipv4Addr::new(127, 0, 0, 1),
64 ttl: None,
65 icmp_type: IcmpType::new(0),
66 icmp_code: IcmpCode::new(0),
67 size: 0,
68 real_dest: Ipv4Addr::new(127, 0, 0, 1),
69 identifier: PingIdentifier(0),
70 sequence: PingSequence(0),
71 }
72 }
73}
74
75impl Icmpv4Packet {
76 fn source(&mut self, source: Ipv4Addr) -> &mut Self {
77 self.source = source;
78 self
79 }
80
81 pub fn get_source(&self) -> Ipv4Addr {
83 self.source
84 }
85
86 fn destination(&mut self, destination: Ipv4Addr) -> &mut Self {
87 self.destination = destination;
88 self
89 }
90
91 pub fn get_destination(&self) -> Ipv4Addr {
93 self.destination
94 }
95
96 fn ttl(&mut self, ttl: u8) -> &mut Self {
97 self.ttl = Some(ttl);
98 self
99 }
100
101 pub fn get_ttl(&self) -> Option<u8> {
103 self.ttl
104 }
105
106 fn icmp_type(&mut self, icmp_type: IcmpType) -> &mut Self {
107 self.icmp_type = icmp_type;
108 self
109 }
110
111 pub fn get_icmp_type(&self) -> IcmpType {
113 self.icmp_type
114 }
115
116 fn icmp_code(&mut self, icmp_code: IcmpCode) -> &mut Self {
117 self.icmp_code = icmp_code;
118 self
119 }
120
121 pub fn get_icmp_code(&self) -> IcmpCode {
123 self.icmp_code
124 }
125
126 fn size(&mut self, size: usize) -> &mut Self {
127 self.size = size;
128 self
129 }
130
131 pub fn get_size(&self) -> usize {
133 self.size
134 }
135
136 fn real_dest(&mut self, addr: Ipv4Addr) -> &mut Self {
137 self.real_dest = addr;
138 self
139 }
140
141 pub fn get_real_dest(&self) -> Ipv4Addr {
144 self.real_dest
145 }
146
147 fn identifier(&mut self, identifier: PingIdentifier) -> &mut Self {
148 self.identifier = identifier;
149 self
150 }
151
152 pub fn get_identifier(&self) -> PingIdentifier {
154 self.identifier
155 }
156
157 fn sequence(&mut self, sequence: PingSequence) -> &mut Self {
158 self.sequence = sequence;
159 self
160 }
161
162 pub fn get_sequence(&self) -> PingSequence {
164 self.sequence
165 }
166
167 pub fn decode(
169 buf: &[u8],
170 sock_type: SockType,
171 src_addr: Ipv4Addr,
172 dst_addr: Ipv4Addr,
173 ) -> Result<Self> {
174 if is_linux_icmp_socket!(sock_type) {
175 Self::decode_from_icmp(buf, src_addr, dst_addr)
176 } else {
177 Self::decode_from_ipv4(buf)
178 }
179 }
180
181 fn decode_from_ipv4(buf: &[u8]) -> Result<Self> {
182 let ipv4_packet = ipv4::Ipv4Packet::new(buf)
183 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIpv4Packet))?;
184 let icmp_packet = icmp::IcmpPacket::new(ipv4_packet.payload())
185 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
186 let mut packet = Icmpv4Packet::default();
187
188 match icmp_packet.get_icmp_type() {
189 icmp::IcmpTypes::EchoReply => {
190 let icmp_packet = icmp::echo_reply::EchoReplyPacket::new(icmp_packet.packet())
191 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
192
193 packet
194 .source(ipv4_packet.get_source())
195 .destination(ipv4_packet.get_destination())
196 .ttl(ipv4_packet.get_ttl())
197 .icmp_type(icmp_packet.get_icmp_type())
198 .icmp_code(icmp_packet.get_icmp_code())
199 .size(icmp_packet.packet().len())
200 .real_dest(ipv4_packet.get_source())
201 .identifier(icmp_packet.get_identifier().into())
202 .sequence(icmp_packet.get_sequence_number().into());
203 }
204 icmp::IcmpTypes::EchoRequest => return Err(SurgeError::EchoRequestPacket),
205 _ => {
206 let icmp_payload = icmp_packet.payload();
207
208 if icmp_payload.len() < 32 {
209 return Err(SurgeError::from(MalformedPacketError::PayloadTooShort {
210 got: icmp_payload.len(),
211 want: 32,
212 }));
213 }
214 let real_ip_packet = ipv4::Ipv4Packet::new(&icmp_payload[4..])
216 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIpv4Packet))?;
217 let identifier = u16::from_be_bytes(icmp_payload[28..30].try_into().unwrap());
218 let sequence = u16::from_be_bytes(icmp_payload[30..32].try_into().unwrap());
219
220 packet
221 .source(ipv4_packet.get_source())
222 .destination(ipv4_packet.get_destination())
223 .ttl(ipv4_packet.get_ttl())
224 .icmp_type(icmp_packet.get_icmp_type())
225 .icmp_code(icmp_packet.get_icmp_code())
226 .size(icmp_packet.packet_size())
227 .real_dest(real_ip_packet.get_destination())
228 .identifier(identifier.into())
229 .sequence(sequence.into());
230 }
231 }
232
233 Ok(packet)
234 }
235
236 fn decode_from_icmp(buf: &[u8], src_addr: Ipv4Addr, dst_addr: Ipv4Addr) -> Result<Self> {
237 let icmp_packet = icmp::IcmpPacket::new(buf)
238 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
239 let mut packet = Icmpv4Packet::default();
240
241 match icmp_packet.get_icmp_type() {
242 icmp::IcmpTypes::EchoReply => {
243 let icmp_packet = icmp::echo_reply::EchoReplyPacket::new(icmp_packet.packet())
244 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIcmpv4Packet))?;
245
246 packet
247 .source(src_addr)
248 .destination(dst_addr)
249 .icmp_type(icmp_packet.get_icmp_type())
250 .icmp_code(icmp_packet.get_icmp_code())
251 .size(icmp_packet.packet().len())
252 .real_dest(src_addr)
253 .identifier(icmp_packet.get_identifier().into())
254 .sequence(icmp_packet.get_sequence_number().into());
255 }
256 icmp::IcmpTypes::EchoRequest => return Err(SurgeError::EchoRequestPacket),
257 _ => {
258 let icmp_payload = icmp_packet.payload();
259
260 if icmp_payload.len() < 32 {
261 return Err(SurgeError::from(MalformedPacketError::PayloadTooShort {
262 got: icmp_payload.len(),
263 want: 32,
264 }));
265 }
266
267 let real_ip_packet = ipv4::Ipv4Packet::new(&icmp_payload[4..])
269 .ok_or_else(|| SurgeError::from(MalformedPacketError::NotIpv4Packet))?;
270 let identifier = u16::from_be_bytes(icmp_payload[28..30].try_into().unwrap());
271 let sequence = u16::from_be_bytes(icmp_payload[30..32].try_into().unwrap());
272
273 packet
274 .source(src_addr)
275 .destination(dst_addr)
276 .icmp_type(icmp_packet.get_icmp_type())
277 .icmp_code(icmp_packet.get_icmp_code())
278 .size(icmp_packet.packet_size())
279 .real_dest(real_ip_packet.get_destination())
280 .identifier(identifier.into())
281 .sequence(sequence.into());
282 }
283 }
284
285 Ok(packet)
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Icmpv4Packet;

    /// Hex-decodes `packet_hex` and runs it through `Icmpv4Packet::decode`
    /// with the fixed source and destination addresses shared by every case.
    fn decode_hex(packet_hex: &str, sock_type: SockType) -> Result<Icmpv4Packet> {
        let raw = hex::decode(packet_hex).unwrap();
        let src: Ipv4Addr = ("172.217.14.110").parse().unwrap();
        let dst: Ipv4Addr = ("10.0.242.34").parse().unwrap();

        Icmpv4Packet::decode(&raw, sock_type, src, dst)
    }

    /// A destination-unreachable message with only two bytes after the ICMP
    /// checksum is rejected on both decode paths.
    #[test]
    fn malformed_packet() {
        let as_ipv4 = decode_hex(
            "4500001d0000000079018a76acd90e6e0a00f22203006c3293cc",
            SockType::RAW,
        );
        assert!(as_ipv4.is_err());

        let as_icmp = decode_hex("03006c3293cc", SockType::DGRAM);
        assert!(as_icmp.is_err());
    }

    /// A few more bytes than `malformed_packet`, but still too short to hold a
    /// quoted request, is rejected on both decode paths.
    #[test]
    fn short_packet() {
        let as_ipv4 = decode_hex(
            "4500001d0000000079018a76acd90e6e0a00f22203006c3293cc000100",
            SockType::RAW,
        );
        assert!(as_ipv4.is_err());

        let as_icmp = decode_hex("03006c3293cc000100", SockType::DGRAM);
        assert!(as_icmp.is_err());
    }

    /// A full-length message decodes successfully on both decode paths.
    #[test]
    fn standard_packet() {
        decode_hex(
            "45000054000000007901067e8efab00e0a00f22203004176a1ee0001613dd762000000002127040000000000101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f3031323334353637",
            SockType::RAW,
        )
        .unwrap();

        decode_hex(
            "03004176a1ee0001613dd762000000002127040000000000101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f3031323334353637",
            SockType::DGRAM,
        )
        .unwrap();
    }
}