1use std::{
2 mem::MaybeUninit,
3 net::{IpAddr, SocketAddr},
4 sync::atomic::{AtomicU16, Ordering},
5 time::{Duration, Instant},
6};
7
8use tokio::time::timeout;
9
10use crate::error::{Error, Result};
11use crate::icmp::{EchoReply, EchoRequest};
12use crate::socket::AsyncSocket;
13
14pub use crate::socket::SocketType;
15
/// Number of payload bytes generated for requests that carry no custom payload.
const DEFAULT_PAYLOAD_SIZE: usize = 56;
/// Default time to wait for a matching echo reply.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Length in bytes of the identifying token written at the start of generated payloads.
const TOKEN_SIZE: usize = 8;

/// Process-wide counter mixed into default identifiers, so that successive
/// pingers created in one process get different identifiers (until it wraps).
static NEXT_IDENT: AtomicU16 = AtomicU16::new(1);
21
22#[derive(Clone, Debug, Eq, PartialEq)]
23#[non_exhaustive]
24pub struct PingResult {
25 pub reply: EchoReply,
26 pub rtt: Duration,
27 pub socket_type: SocketType,
28}
29
/// A single echo request: an ICMP sequence number plus an optional custom payload.
///
/// Without a custom payload the pinger generates one of its configured size.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PingRequest {
    sequence: u16,
    payload: Option<Vec<u8>>,
}

impl PingRequest {
    /// Creates a request for `sequence` that will use the pinger's generated payload.
    pub fn new(sequence: u16) -> Self {
        let payload = None;
        PingRequest { sequence, payload }
    }

    /// Sets a custom payload; a reply is only accepted if it echoes these exact bytes.
    pub fn payload(self, payload: impl Into<Vec<u8>>) -> Self {
        PingRequest {
            payload: Some(payload.into()),
            ..self
        }
    }

    /// Returns the ICMP sequence number of this request.
    pub fn sequence(&self) -> u16 {
        self.sequence
    }

    /// Returns the custom payload, or `None` when the generated payload will be used.
    pub fn payload_bytes(&self) -> Option<&[u8]> {
        self.payload.as_deref()
    }
}

impl From<u16> for PingRequest {
    /// Equivalent to [`PingRequest::new`].
    fn from(sequence: u16) -> Self {
        PingRequest::new(sequence)
    }
}
68
/// Description of a run of sequential pings for [`Pinger::ping_many`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PingSeries {
    start_sequence: u16,
    count: usize,
    interval: Duration,
    payload: Option<Vec<u8>>,
}

impl PingSeries {
    /// Creates a series of `count` pings whose sequence numbers start at
    /// `start_sequence`, with no pause between pings and generated payloads.
    pub fn new(start_sequence: u16, count: usize) -> Self {
        PingSeries {
            start_sequence,
            count,
            interval: Duration::ZERO,
            payload: None,
        }
    }

    /// Sets the pause inserted between consecutive pings.
    pub fn interval(self, interval: Duration) -> Self {
        PingSeries { interval, ..self }
    }

    /// Sets a custom payload sent with (and expected back from) every ping.
    pub fn payload(self, payload: impl Into<Vec<u8>>) -> Self {
        PingSeries {
            payload: Some(payload.into()),
            ..self
        }
    }
}
101
102#[derive(Debug)]
104#[non_exhaustive]
105pub struct PingAttempt {
106 pub sequence: u16,
107 pub result: std::result::Result<PingResult, Error>,
108}
109
/// Aggregate statistics over the attempts of a ping series.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PingSummary {
    /// Number of attempts made.
    pub transmitted: usize,
    /// Number of attempts that produced a matching reply.
    pub received: usize,
    /// Percentage (0–100) of attempts without a reply; 0.0 when nothing was sent.
    pub loss: f64,
    /// Shortest round-trip time among successful attempts.
    pub min_rtt: Option<Duration>,
    /// Mean round-trip time among successful attempts.
    pub avg_rtt: Option<Duration>,
    /// Longest round-trip time among successful attempts.
    pub max_rtt: Option<Duration>,
}
121
122#[derive(Debug)]
124#[non_exhaustive]
125pub struct PingSeriesResult {
126 pub attempts: Vec<PingAttempt>,
127 pub summary: PingSummary,
128}
129
130#[derive(Debug, Clone)]
132pub struct Pinger {
133 target: SocketAddr,
134 source: Option<SocketAddr>,
135 ident: u16,
136 size: usize,
137 timeout: Duration,
138 ttl: Option<u32>,
139 socket: AsyncSocket,
140}
141
142impl Pinger {
143 pub fn new(host: IpAddr) -> Result<Pinger> {
145 Self::with_socket_type(host, SocketType::Raw)
146 }
147
148 pub fn with_socket_type(host: IpAddr, socket_type: SocketType) -> Result<Pinger> {
150 Self::with_socket_addr(SocketAddr::new(host, 0), socket_type)
151 }
152
153 pub fn with_socket_addr(target: SocketAddr, socket_type: SocketType) -> Result<Pinger> {
158 Ok(Pinger {
159 target,
160 source: None,
161 ident: default_ident(),
162 size: DEFAULT_PAYLOAD_SIZE,
163 timeout: DEFAULT_TIMEOUT,
164 ttl: None,
165 socket: AsyncSocket::new(target.ip(), socket_type)?,
166 })
167 }
168
169 pub fn socket_type(&mut self, socket_type: SocketType) -> Result<&mut Pinger> {
171 let socket = AsyncSocket::new(self.target.ip(), socket_type)?;
172 if let Some(source) = self.source {
173 socket.bind(&source.into())?;
174 }
175 if let Some(ttl) = self.ttl {
176 socket.set_ttl(self.target.ip(), ttl)?;
177 }
178 self.socket = socket;
179 Ok(self)
180 }
181
182 pub fn active_socket_type(&self) -> SocketType {
184 self.socket.socket_type()
185 }
186
187 pub fn target(&self) -> SocketAddr {
189 self.target
190 }
191
192 pub fn source(&self) -> Option<SocketAddr> {
194 self.source
195 }
196
197 pub fn bind_source(&mut self, source: SocketAddr) -> Result<&mut Pinger> {
201 let source = socket_addr_without_port(source);
202 self.socket.bind(&source.into())?;
203 self.source = Some(source);
204 Ok(self)
205 }
206
207 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
217 pub fn bind_device(&mut self, interface: Option<&[u8]>) -> Result<&mut Pinger> {
218 self.socket.bind_device(interface)?;
219 Ok(self)
220 }
221
222 pub fn ident(&mut self, val: u16) -> &mut Pinger {
224 self.ident = val;
225 self
226 }
227
228 pub fn identifier(&self) -> u16 {
230 self.ident
231 }
232
233 pub fn size(&mut self, size: usize) -> &mut Pinger {
235 self.size = size;
236 self
237 }
238
239 pub fn payload_size(&self) -> usize {
241 self.size
242 }
243
244 pub fn timeout(&mut self, timeout: Duration) -> &mut Pinger {
246 self.timeout = timeout;
247 self
248 }
249
250 pub fn timeout_duration(&self) -> Duration {
252 self.timeout
253 }
254
255 pub fn ttl(&mut self, ttl: u32) -> Result<&mut Pinger> {
257 self.socket.set_ttl(self.target.ip(), ttl)?;
258 self.ttl = Some(ttl);
259 Ok(self)
260 }
261
262 pub fn ttl_value(&self) -> Option<u32> {
264 self.ttl
265 }
266
267 async fn recv_reply(&self, request: &ResolvedPingRequest) -> Result<EchoReply> {
268 let mut buffer = [MaybeUninit::new(0); 2048];
269 loop {
270 let (size, source) = self.socket.recv_from(&mut buffer).await?;
271 let buf = unsafe { assume_init(&buffer[..size]) };
272 let source = source.map(|addr| addr.ip()).unwrap_or(self.target.ip());
273 let decoded = match self.socket.socket_type() {
274 SocketType::Raw if self.target.ip().is_ipv6() => EchoReply::decode_raw(source, buf),
275 SocketType::Raw => EchoReply::decode_raw(self.target.ip(), buf),
276 SocketType::Dgram => EchoReply::decode_dgram(source, buf),
277 };
278
279 match decoded {
280 Ok(reply) if self.reply_matches(&reply, request) => return Ok(reply),
281 Ok(_) => continue,
282 Err(Error::InvalidPacket)
283 | Err(Error::NotEchoReply)
284 | Err(Error::NotV6EchoReply)
285 | Err(Error::OtherICMP)
286 | Err(Error::UnknownProtocol) => continue,
287 Err(e) => return Err(e),
288 }
289 }
290 }
291
292 fn reply_matches(&self, reply: &EchoReply, request: &ResolvedPingRequest) -> bool {
293 if reply.sequence != request.sequence {
294 return false;
295 }
296
297 if self.socket.socket_type() == SocketType::Raw && reply.identifier != self.ident {
298 return false;
299 }
300
301 !request.match_payload || reply.payload == request.payload
302 }
303
304 async fn send_request(&self, request: &ResolvedPingRequest) -> Result<Instant> {
305 let packet = EchoRequest::new(self.target.ip(), self.ident, request.sequence)
306 .encode_with_payload(&request.payload)?;
307
308 let sent = Instant::now();
309 let size = self.socket.send_to(&packet, &self.target.into()).await?;
310 if size != packet.len() {
311 return Err(Error::InvalidSize);
312 }
313
314 Ok(sent)
315 }
316
317 pub async fn ping(&self, request: impl Into<PingRequest>) -> Result<PingResult> {
319 let request = self.resolve_request(request);
320 let sent = self.send_request(&request).await?;
321
322 let reply = timeout(self.timeout, self.recv_reply(&request))
323 .await
324 .map_err(|_| Error::Timeout)??;
325
326 Ok(PingResult {
327 reply,
328 rtt: sent.elapsed(),
329 socket_type: self.socket.socket_type(),
330 })
331 }
332
333 pub async fn ping_replies(&self, request: impl Into<PingRequest>) -> Result<Vec<PingResult>> {
340 let request = self.resolve_request(request);
341 let sent = self.send_request(&request).await?;
342 let deadline = sent + self.timeout;
343 let mut replies = Vec::new();
344
345 while let Some(remaining) = deadline.checked_duration_since(Instant::now()) {
346 let reply = match timeout(remaining, self.recv_reply(&request)).await {
347 Ok(reply) => reply?,
348 Err(_) => break,
349 };
350
351 replies.push(PingResult {
352 reply,
353 rtt: sent.elapsed(),
354 socket_type: self.socket.socket_type(),
355 });
356 }
357
358 Ok(replies)
359 }
360
361 pub async fn ping_many(&self, series: PingSeries) -> PingSeriesResult {
363 let mut attempts = Vec::with_capacity(series.count);
364
365 for index in 0..series.count {
366 let sequence = series.start_sequence.wrapping_add(index as u16);
367 let request = match &series.payload {
368 Some(payload) => PingRequest::new(sequence).payload(payload.clone()),
369 None => PingRequest::new(sequence),
370 };
371 let result = self.ping(request).await;
372 attempts.push(PingAttempt { sequence, result });
373
374 if index + 1 < series.count && !series.interval.is_zero() {
375 tokio::time::sleep(series.interval).await;
376 }
377 }
378
379 let summary = PingSummary::from_attempts(&attempts);
380 PingSeriesResult { attempts, summary }
381 }
382
383 fn resolve_request(&self, request: impl Into<PingRequest>) -> ResolvedPingRequest {
384 resolve_ping_request(self.ident, self.size, request.into())
385 }
386}
387
388impl PingSummary {
389 fn from_attempts(attempts: &[PingAttempt]) -> Self {
390 let transmitted = attempts.len();
391 let rtts: Vec<Duration> = attempts
392 .iter()
393 .filter_map(|attempt| attempt.result.as_ref().ok().map(|result| result.rtt))
394 .collect();
395 let received = rtts.len();
396 let loss = if transmitted == 0 {
397 0.0
398 } else {
399 ((transmitted - received) as f64 / transmitted as f64) * 100.0
400 };
401 let min_rtt = rtts.iter().copied().min();
402 let max_rtt = rtts.iter().copied().max();
403 let avg_rtt = average_duration(&rtts);
404
405 Self {
406 transmitted,
407 received,
408 loss,
409 min_rtt,
410 avg_rtt,
411 max_rtt,
412 }
413 }
414}
415
416struct ResolvedPingRequest {
417 sequence: u16,
418 payload: Vec<u8>,
419 match_payload: bool,
420}
421
422fn resolve_ping_request(
423 ident: u16,
424 default_payload_size: usize,
425 request: PingRequest,
426) -> ResolvedPingRequest {
427 match request.payload {
428 Some(payload) => ResolvedPingRequest {
429 sequence: request.sequence,
430 payload,
431 match_payload: true,
432 },
433 None => {
434 let payload = request_payload(ident, request.sequence, default_payload_size);
435 let match_payload = !payload.is_empty();
436 ResolvedPingRequest {
437 sequence: request.sequence,
438 payload,
439 match_payload,
440 }
441 }
442 }
443}
444
/// Returns the arithmetic mean of `durations`, or `None` when the slice is empty.
///
/// The mean is computed in nanoseconds with a `u128` accumulator and is
/// truncated toward zero. The result is rebuilt from whole seconds plus
/// sub-second nanoseconds rather than via `Duration::from_nanos`. That keeps
/// it exact for any mean a `Duration` can represent, instead of saturating at
/// `u64::MAX` nanoseconds (about 584 years).
fn average_duration(durations: &[Duration]) -> Option<Duration> {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    let total: u128 = durations.iter().map(Duration::as_nanos).sum();
    let average = total.checked_div(durations.len() as u128)?;
    // Lossless: the mean never exceeds the largest input, whose whole-second
    // part already fits in a `u64`.
    let secs = (average / NANOS_PER_SEC) as u64;
    // Always below one billion, so it fits in a `u32`.
    let nanos = (average % NANOS_PER_SEC) as u32;
    Some(Duration::new(secs, nanos))
}
450
451fn default_ident() -> u16 {
452 let pid = std::process::id() as u16;
453 let next = NEXT_IDENT.fetch_add(1, Ordering::Relaxed);
454 pid.wrapping_add(next)
455}
456
/// Returns `addr` with its port set to 0.
///
/// The IP address is unchanged, and for IPv6 so are the flow info and scope ID.
fn socket_addr_without_port(addr: SocketAddr) -> SocketAddr {
    let mut portless = addr;
    portless.set_port(0);
    portless
}
469
470fn request_payload(ident: u16, seq_cnt: u16, size: usize) -> Vec<u8> {
471 let mut payload = vec![0; size];
472 let token = [
473 b't',
474 b'p',
475 (ident >> 8) as u8,
476 ident as u8,
477 (seq_cnt >> 8) as u8,
478 seq_cnt as u8,
479 (size >> 8) as u8,
480 size as u8,
481 ];
482 let len = payload.len().min(TOKEN_SIZE);
483 payload[..len].copy_from_slice(&token[..len]);
484 payload
485}
486
/// Reinterprets a slice of `MaybeUninit<u8>` as a slice of initialized bytes.
///
/// # Safety
///
/// Every element of `buf` must be initialized. The receive buffer in this
/// module is created with `MaybeUninit::new(0)` for every element, so any
/// sub-slice of it satisfies this.
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`, and the
    // caller guarantees every element is initialized. The pointer cast therefore
    // yields a valid `&[u8]` with the same length and lifetime as `buf`.
    unsafe { &*(buf as *const [MaybeUninit<u8>] as *const [u8]) }
}
495
// Unit tests for the pure helpers in this module; none of them open a socket.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    // The generated payload is exactly `size` bytes, with the token truncated to fit.
    #[test]
    fn request_payload_respects_size() {
        assert_eq!(request_payload(1, 2, 0), Vec::<u8>::new());
        assert_eq!(request_payload(1, 2, 4), vec![b't', b'p', 0, 1]);
        assert_eq!(request_payload(1, 2, 8), vec![b't', b'p', 0, 1, 0, 2, 0, 8]);
    }

    #[test]
    fn ping_request_from_sequence_uses_default_payload() {
        let request = PingRequest::from(7);

        assert_eq!(request.sequence(), 7);
        assert_eq!(request.payload_bytes(), None);
    }

    #[test]
    fn ping_request_keeps_custom_payload() {
        let request = PingRequest::new(9).payload(b"hello");

        assert_eq!(request.sequence(), 9);
        assert_eq!(request.payload_bytes(), Some(b"hello".as_slice()));
    }

    // A zero-size generated payload carries nothing to compare, so payload matching is off.
    #[test]
    fn default_request_with_empty_generated_payload_matches_any_payload() {
        let request = resolve_ping_request(1, 0, PingRequest::new(2));

        assert_eq!(request.sequence, 2);
        assert!(request.payload.is_empty());
        assert!(!request.match_payload);
    }

    // An explicitly empty custom payload is still matched exactly.
    #[test]
    fn custom_empty_payload_matches_exactly() {
        let request = resolve_ping_request(1, 56, PingRequest::new(2).payload(Vec::new()));

        assert_eq!(request.sequence, 2);
        assert!(request.payload.is_empty());
        assert!(request.match_payload);
    }

    // One failure out of three attempts: 1/3 loss, RTT stats over the two successes.
    #[test]
    fn ping_summary_counts_successes_and_rtts() {
        let attempts = vec![
            successful_attempt(1, Duration::from_millis(10)),
            PingAttempt {
                sequence: 2,
                result: Err(Error::Timeout),
            },
            successful_attempt(3, Duration::from_millis(30)),
        ];

        let summary = PingSummary::from_attempts(&attempts);

        assert_eq!(summary.transmitted, 3);
        assert_eq!(summary.received, 2);
        assert!((summary.loss - (100.0 / 3.0)).abs() < 1e-12);
        assert_eq!(summary.min_rtt, Some(Duration::from_millis(10)));
        assert_eq!(summary.avg_rtt, Some(Duration::from_millis(20)));
        assert_eq!(summary.max_rtt, Some(Duration::from_millis(30)));
    }

    #[test]
    fn empty_ping_summary_has_no_rtts() {
        let summary = PingSummary::from_attempts(&[]);

        assert_eq!(summary.transmitted, 0);
        assert_eq!(summary.received, 0);
        assert_eq!(summary.loss, 0.0);
        assert_eq!(summary.min_rtt, None);
        assert_eq!(summary.avg_rtt, None);
        assert_eq!(summary.max_rtt, None);
    }

    // Zeroing the port must keep the IPv6 scope ID (`%4`).
    #[test]
    fn socket_addr_without_port_preserves_ipv6_scope() {
        let addr = "[fe80::1%4]:1234".parse().unwrap();

        assert_eq!(socket_addr_without_port(addr).to_string(), "[fe80::1%4]:0");
    }

    /// Builds a successful attempt whose reply echoes `sequence` and whose RTT is `rtt`.
    fn successful_attempt(sequence: u16, rtt: Duration) -> PingAttempt {
        PingAttempt {
            sequence,
            result: Ok(PingResult {
                reply: EchoReply {
                    ttl: Some(64),
                    source: IpAddr::V4(Ipv4Addr::LOCALHOST),
                    sequence,
                    identifier: 1,
                    payload_len: 0,
                    payload: Vec::new(),
                    // `EchoReply::size` is presumably marked deprecated, hence the allow.
                    #[allow(deprecated)]
                    size: 0,
                },
                rtt,
                socket_type: SocketType::Dgram,
            }),
        }
    }
}