1use std::{
2 mem::MaybeUninit,
3 net::{IpAddr, SocketAddr},
4 sync::atomic::{AtomicU16, Ordering},
5 time::{Duration, Instant},
6};
7
8use tokio::time::timeout;
9
10use crate::error::{Error, Result};
11use crate::icmp::{EchoReply, EchoRequest};
12use crate::socket::AsyncSocket;
13
14pub use crate::socket::SocketType;
15
/// Payload length, in bytes, used for echo requests unless overridden with `Pinger::size`.
const DEFAULT_PAYLOAD_SIZE: usize = 56;
/// How long a pinger waits for replies unless overridden with `Pinger::timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Number of leading payload bytes reserved for the token written by `request_payload`.
const TOKEN_SIZE: usize = 8;

/// Process-wide counter mixed into `default_ident` so successive pingers get different
/// identifiers.
static NEXT_IDENT: AtomicU16 = AtomicU16::new(1);
21
22#[derive(Clone, Debug, Eq, PartialEq)]
23#[non_exhaustive]
24pub struct PingResult {
25 pub reply: EchoReply,
26 pub rtt: Duration,
27 pub socket_type: SocketType,
28}
29
30#[derive(Debug, Clone)]
32pub struct Pinger {
33 target: SocketAddr,
34 ident: u16,
35 size: usize,
36 timeout: Duration,
37 ttl: Option<u32>,
38 socket: AsyncSocket,
39}
40
41impl Pinger {
42 pub fn new(host: IpAddr) -> Result<Pinger> {
44 Self::with_socket_type(host, SocketType::Raw)
45 }
46
47 pub fn with_socket_type(host: IpAddr, socket_type: SocketType) -> Result<Pinger> {
49 Self::with_socket_addr(SocketAddr::new(host, 0), socket_type)
50 }
51
52 pub fn with_socket_addr(target: SocketAddr, socket_type: SocketType) -> Result<Pinger> {
57 Ok(Pinger {
58 target,
59 ident: default_ident(),
60 size: DEFAULT_PAYLOAD_SIZE,
61 timeout: DEFAULT_TIMEOUT,
62 ttl: None,
63 socket: AsyncSocket::new(target.ip(), socket_type)?,
64 })
65 }
66
67 pub fn socket_type(&mut self, socket_type: SocketType) -> Result<&mut Pinger> {
69 let socket = AsyncSocket::new(self.target.ip(), socket_type)?;
70 if let Some(ttl) = self.ttl {
71 socket.set_ttl(self.target.ip(), ttl)?;
72 }
73 self.socket = socket;
74 Ok(self)
75 }
76
77 pub fn active_socket_type(&self) -> SocketType {
79 self.socket.socket_type()
80 }
81
82 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
92 pub fn bind_device(&mut self, interface: Option<&[u8]>) -> Result<&mut Pinger> {
93 self.socket.bind_device(interface)?;
94 Ok(self)
95 }
96
97 pub fn ident(&mut self, val: u16) -> &mut Pinger {
99 self.ident = val;
100 self
101 }
102
103 pub fn size(&mut self, size: usize) -> &mut Pinger {
105 self.size = size;
106 self
107 }
108
109 pub fn timeout(&mut self, timeout: Duration) -> &mut Pinger {
111 self.timeout = timeout;
112 self
113 }
114
115 pub fn ttl(&mut self, ttl: u32) -> Result<&mut Pinger> {
117 self.socket.set_ttl(self.target.ip(), ttl)?;
118 self.ttl = Some(ttl);
119 Ok(self)
120 }
121
122 async fn recv_reply(&self, seq_cnt: u16, payload: &[u8]) -> Result<EchoReply> {
123 let mut buffer = [MaybeUninit::new(0); 2048];
124 loop {
125 let (size, source) = self.socket.recv_from(&mut buffer).await?;
126 let buf = unsafe { assume_init(&buffer[..size]) };
127 let source = source.map(|addr| addr.ip()).unwrap_or(self.target.ip());
128 let decoded = match self.socket.socket_type() {
129 SocketType::Raw if self.target.ip().is_ipv6() => EchoReply::decode_raw(source, buf),
130 SocketType::Raw => EchoReply::decode_raw(self.target.ip(), buf),
131 SocketType::Dgram => EchoReply::decode_dgram(source, buf),
132 };
133
134 match decoded {
135 Ok(reply) if self.reply_matches(&reply, seq_cnt, payload) => return Ok(reply),
136 Ok(_) => continue,
137 Err(Error::InvalidPacket)
138 | Err(Error::NotEchoReply)
139 | Err(Error::NotV6EchoReply)
140 | Err(Error::OtherICMP)
141 | Err(Error::UnknownProtocol) => continue,
142 Err(e) => return Err(e),
143 }
144 }
145 }
146
147 fn reply_matches(&self, reply: &EchoReply, seq_cnt: u16, payload: &[u8]) -> bool {
148 if reply.sequence != seq_cnt {
149 return false;
150 }
151
152 if self.socket.socket_type() == SocketType::Raw && reply.identifier != self.ident {
153 return false;
154 }
155
156 payload.is_empty() || reply.payload == payload
157 }
158
159 async fn send_request(&self, seq_cnt: u16, payload: &[u8]) -> Result<Instant> {
160 let packet =
161 EchoRequest::new(self.target.ip(), self.ident, seq_cnt).encode_with_payload(payload)?;
162
163 let sent = Instant::now();
164 let size = self.socket.send_to(&packet, &self.target.into()).await?;
165 if size != packet.len() {
166 return Err(Error::InvalidSize);
167 }
168
169 Ok(sent)
170 }
171
172 pub async fn ping(&self, seq_cnt: u16) -> Result<PingResult> {
174 let payload = request_payload(self.ident, seq_cnt, self.size);
175 let sent = self.send_request(seq_cnt, &payload).await?;
176
177 let reply = timeout(self.timeout, self.recv_reply(seq_cnt, &payload))
178 .await
179 .map_err(|_| Error::Timeout)??;
180
181 Ok(PingResult {
182 reply,
183 rtt: sent.elapsed(),
184 socket_type: self.socket.socket_type(),
185 })
186 }
187
188 pub async fn ping_replies(&self, seq_cnt: u16) -> Result<Vec<PingResult>> {
195 let payload = request_payload(self.ident, seq_cnt, self.size);
196 let sent = self.send_request(seq_cnt, &payload).await?;
197 let deadline = sent + self.timeout;
198 let mut replies = Vec::new();
199
200 while let Some(remaining) = deadline.checked_duration_since(Instant::now()) {
201 let reply = match timeout(remaining, self.recv_reply(seq_cnt, &payload)).await {
202 Ok(reply) => reply?,
203 Err(_) => break,
204 };
205
206 replies.push(PingResult {
207 reply,
208 rtt: sent.elapsed(),
209 socket_type: self.socket.socket_type(),
210 });
211 }
212
213 Ok(replies)
214 }
215}
216
217fn default_ident() -> u16 {
218 let pid = std::process::id() as u16;
219 let next = NEXT_IDENT.fetch_add(1, Ordering::Relaxed);
220 pid.wrapping_add(next)
221}
222
223fn request_payload(ident: u16, seq_cnt: u16, size: usize) -> Vec<u8> {
224 let mut payload = vec![0; size];
225 let token = [
226 b't',
227 b'p',
228 (ident >> 8) as u8,
229 ident as u8,
230 (seq_cnt >> 8) as u8,
231 seq_cnt as u8,
232 (size >> 8) as u8,
233 size as u8,
234 ];
235 let len = payload.len().min(TOKEN_SIZE);
236 payload[..len].copy_from_slice(&token[..len]);
237 payload
238}
239
/// Reinterprets a slice of `MaybeUninit<u8>` as a slice of initialized bytes.
///
/// # Safety
///
/// Every element of `buf` must be initialized. The returned slice borrows from `buf` and
/// covers exactly the same memory.
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
    let len = buf.len();
    let data = buf.as_ptr().cast::<u8>();
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`; `data` and `len`
    // describe a live slice borrowed for the returned lifetime; and the caller guarantees
    // that every element is initialized.
    unsafe { std::slice::from_raw_parts(data, len) }
}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// The token is truncated to the requested size and never exceeds it.
    #[test]
    fn request_payload_respects_size() {
        assert_eq!(request_payload(1, 2, 0), Vec::<u8>::new());
        assert_eq!(request_payload(1, 2, 4), vec![b't', b'p', 0, 1]);
        assert_eq!(request_payload(1, 2, 8), vec![b't', b'p', 0, 1, 0, 2, 0, 8]);
    }

    /// Multi-byte fields are big-endian and bytes after the token are zero.
    #[test]
    fn request_payload_pads_with_zeros_after_token() {
        let payload = request_payload(0x1234, 0xabcd, 12);
        assert_eq!(
            payload,
            vec![b't', b'p', 0x12, 0x34, 0xab, 0xcd, 0, 12, 0, 0, 0, 0]
        );
    }

    /// Only the low 16 bits of the size are encoded, but the payload keeps its full length.
    #[test]
    fn request_payload_encodes_low_16_bits_of_size() {
        let payload = request_payload(7, 9, 0x1_0203);
        assert_eq!(payload.len(), 0x1_0203);
        assert_eq!(
            &payload[..TOKEN_SIZE],
            &[b't', b'p', 0u8, 7, 0, 9, 0x02, 0x03][..]
        );
        assert!(payload[TOKEN_SIZE..].iter().all(|&b| b == 0));
    }

    /// `assume_init` exposes exactly the initialized bytes, including for empty input.
    #[test]
    fn assume_init_preserves_bytes() {
        let buf = [MaybeUninit::new(1u8), MaybeUninit::new(2), MaybeUninit::new(3)];
        // SAFETY: every element was created with `MaybeUninit::new`.
        let bytes = unsafe { assume_init(&buf) };
        assert_eq!(bytes, &[1u8, 2, 3][..]);

        let empty: [MaybeUninit<u8>; 0] = [];
        // SAFETY: an empty slice has no elements to initialize.
        assert!(unsafe { assume_init(&empty) }.is_empty());
    }
}