1use crate::error::Error;
2use std::io::Write;
3
/// Length in bytes of the ICMP echo header: type (1), code (1), checksum (2),
/// identifier (2) and sequence number (2).
pub const HEADER_SIZE: usize = 8;

/// Marker type selecting ICMP for IPv4 message constants via [`Proto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct IcmpV4;

/// Marker type selecting ICMP for IPv6 message constants via [`Proto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct IcmpV6;
8
/// Per-protocol ICMP `type`/`code` values used to build echo requests and to
/// recognise echo replies.
///
/// In this module it is implemented by the marker types `IcmpV4` and `IcmpV6`
/// and selected at compile time through a generic parameter.
pub trait Proto {
    /// `type` byte written into an outgoing echo request.
    const ECHO_REQUEST_TYPE: u8;
    /// `type` byte an incoming echo reply must carry.
    const ECHO_REPLY_TYPE: u8;

    /// `code` byte written into an outgoing echo request.
    const ECHO_REQUEST_CODE: u8;
    /// `code` byte an incoming echo reply must carry.
    const ECHO_REPLY_CODE: u8;
}
15
16impl Proto for IcmpV4 {
17 const ECHO_REQUEST_TYPE: u8 = 8;
18 const ECHO_REQUEST_CODE: u8 = 0;
19 const ECHO_REPLY_TYPE: u8 = 0;
20 const ECHO_REPLY_CODE: u8 = 0;
21}
22
23impl Proto for IcmpV6 {
24 const ECHO_REQUEST_TYPE: u8 = 128;
25 const ECHO_REQUEST_CODE: u8 = 0;
26 const ECHO_REPLY_TYPE: u8 = 129;
27 const ECHO_REPLY_CODE: u8 = 0;
28}
29
/// Fields of an ICMP echo request that vary per probe.
///
/// Both values are written big-endian into the header by `EchoRequest::encode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct EchoRequest {
    /// Echo identifier (header bytes 4..6).
    pub ident: u16,
    /// Echo sequence number (header bytes 6..8).
    pub seq_cnt: u16,
}
34
35impl EchoRequest {
36 pub fn encode<P: Proto>(&self, buffer: &mut [u8], payload: &[u8]) -> Result<Vec<u8>, Error> {
37 if buffer.len() < HEADER_SIZE + payload.len() {
38 return Err(Error::InvalidSize);
39 }
40
41 buffer[0] = P::ECHO_REQUEST_TYPE;
42 buffer[1] = P::ECHO_REQUEST_CODE;
43 buffer[2] = 0;
44 buffer[3] = 0;
45
46 buffer[4] = (self.ident >> 8) as u8;
47 buffer[5] = self.ident as u8;
48 buffer[6] = (self.seq_cnt >> 8) as u8;
49 buffer[7] = self.seq_cnt as u8;
50
51 if (&mut buffer[HEADER_SIZE..HEADER_SIZE + payload.len()])
52 .write_all(payload)
53 .is_err()
54 {
55 return Err(Error::InvalidSize);
56 }
57
58 write_checksum(&mut buffer[..HEADER_SIZE + payload.len()]);
59 Ok(buffer.to_vec())
60 }
61}
62
/// A parsed ICMP echo reply that borrows its payload from the receive buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EchoReply<'a> {
    /// Echo identifier (header bytes 4..6, big-endian).
    pub ident: u16,
    /// Echo sequence number (header bytes 6..8, big-endian).
    pub seq_cnt: u16,
    /// Every byte after the 8-byte ICMP header.
    pub payload: &'a [u8],
}
68
69impl<'a> EchoReply<'a> {
70 pub fn decode<P: Proto>(buffer: &'a [u8]) -> Result<Self, Error> {
71 if buffer.as_ref().len() < HEADER_SIZE {
72 return Err(Error::InvalidSize);
73 }
74
75 let type_ = buffer[0];
76 let code = buffer[1];
77 if type_ != P::ECHO_REPLY_TYPE || code != P::ECHO_REPLY_CODE {
78 return Err(Error::InvalidPacket);
79 }
80
81 let ident = (u16::from(buffer[4]) << 8) + u16::from(buffer[5]);
82 let seq_cnt = (u16::from(buffer[6]) << 8) + u16::from(buffer[7]);
83
84 let payload = &buffer[HEADER_SIZE..];
85
86 Ok(EchoReply {
87 ident,
88 seq_cnt,
89 payload,
90 })
91 }
92}
93
/// Computes the Internet checksum (RFC 1071) over all of `buffer` and stores
/// it big-endian in bytes 2..4, the ICMP checksum field.
///
/// The checksum field itself is included in the sum, so callers zero it first
/// (as `EchoRequest::encode` does). A trailing odd byte is treated as the high
/// byte of a zero-padded 16-bit word.
///
/// # Panics
///
/// Panics if `buffer` is shorter than 4 bytes.
fn write_checksum(buffer: &mut [u8]) {
    // Accumulate in u64 so no carry is ever lost: a u32 accumulator overflows
    // after roughly 65k maximal words and would silently corrupt the result.
    let mut words = buffer.chunks_exact(2);
    let mut sum: u64 = words
        .by_ref()
        .map(|word| u64::from(u16::from_be_bytes([word[0], word[1]])))
        .sum();
    if let [last] = words.remainder() {
        sum += u64::from(*last) << 8;
    }

    // Fold carries back into the low 16 bits (one's-complement end-around carry).
    while sum > 0xffff {
        sum = (sum & 0xffff) + (sum >> 16);
    }

    // `sum` now fits in 16 bits, so the narrowing cast is lossless.
    let checksum = !(sum as u16);
    buffer[2..4].copy_from_slice(&checksum.to_be_bytes());
}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encodes_echo_request_with_valid_checksum() {
        let request = EchoRequest {
            ident: 0x1234,
            seq_cnt: 7,
        };
        let payload = [1, 2, 3, 4];
        let mut buffer = [0; HEADER_SIZE + 4];

        let encoded = request.encode::<IcmpV4>(&mut buffer, &payload).unwrap();

        assert_eq!(encoded[0], IcmpV4::ECHO_REQUEST_TYPE);
        assert_eq!(encoded[1], IcmpV4::ECHO_REQUEST_CODE);
        assert_eq!(&encoded[4..8], &[0x12, 0x34, 0, 7]);
        assert_eq!(&encoded[8..], &payload);
        assert_eq!(checksum_sum(&encoded), 0xffff);
    }

    #[test]
    fn rejects_encode_buffer_too_small_for_payload() {
        let request = EchoRequest {
            ident: 1,
            seq_cnt: 1,
        };
        let mut buffer = [0; HEADER_SIZE];

        assert!(matches!(
            request.encode::<IcmpV4>(&mut buffer, &[1]),
            Err(Error::InvalidSize)
        ));
    }

    #[test]
    fn rejects_wrong_echo_reply_code() {
        let packet = [IcmpV4::ECHO_REPLY_TYPE, 1, 0, 0, 0, 1, 0, 1];

        assert!(matches!(
            EchoReply::decode::<IcmpV4>(&packet),
            Err(Error::InvalidPacket)
        ));
    }

    #[test]
    fn rejects_echo_request_type_when_decoding_reply() {
        let packet = [IcmpV4::ECHO_REQUEST_TYPE, 0, 0, 0, 0, 1, 0, 1];

        assert!(matches!(
            EchoReply::decode::<IcmpV4>(&packet),
            Err(Error::InvalidPacket)
        ));
    }

    #[test]
    fn rejects_truncated_echo_reply() {
        let packet = [IcmpV4::ECHO_REPLY_TYPE, 0, 0, 0, 0, 1, 0];

        assert!(matches!(
            EchoReply::decode::<IcmpV4>(&packet),
            Err(Error::InvalidSize)
        ));
    }

    #[test]
    fn decodes_icmpv6_echo_reply() {
        let packet = [IcmpV6::ECHO_REPLY_TYPE, 0, 0, 0, 0x12, 0x34, 0, 9, 1, 2];
        let reply = EchoReply::decode::<IcmpV6>(&packet).unwrap();

        assert_eq!(reply.ident, 0x1234);
        assert_eq!(reply.seq_cnt, 9);
        assert_eq!(reply.payload, &[1, 2]);
    }

    #[test]
    fn decodes_echo_reply_with_empty_payload() {
        let packet = [IcmpV4::ECHO_REPLY_TYPE, 0, 0, 0, 0xab, 0xcd, 0x01, 0x02];
        let reply = EchoReply::decode::<IcmpV4>(&packet).unwrap();

        assert_eq!(reply.ident, 0xabcd);
        assert_eq!(reply.seq_cnt, 0x0102);
        assert!(reply.payload.is_empty());
    }

    /// Folds the one's-complement sum of `buffer` (including its checksum
    /// field). A packet carrying a correct Internet checksum folds to `0xffff`.
    fn checksum_sum(buffer: &[u8]) -> u16 {
        // u64 accumulator so the helper itself can never lose carries.
        let mut words = buffer.chunks_exact(2);
        let mut sum: u64 = words
            .by_ref()
            .map(|word| u64::from(u16::from_be_bytes([word[0], word[1]])))
            .sum();
        if let [last] = words.remainder() {
            sum += u64::from(*last) << 8;
        }
        while sum > 0xffff {
            sum = (sum & 0xffff) + (sum >> 16);
        }
        sum as u16
    }
}