// surge_ping/client.rs

1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
5
6use std::{
7    collections::HashMap,
8    io,
9    net::{IpAddr, SocketAddr},
10    sync::atomic::{AtomicBool, Ordering},
11    sync::Arc,
12    time::Instant,
13};
14
15use parking_lot::Mutex;
16use socket2::{Domain, Protocol, Socket, Type as SockType};
17use tokio::{
18    net::UdpSocket,
19    sync::oneshot,
20    task::{self, JoinHandle},
21};
22use tracing::debug;
23
24use crate::{
25    config::Config,
26    icmp::{icmpv4::Icmpv4Packet, icmpv6::Icmpv6Packet},
27    IcmpPacket, PingIdentifier, PingSequence, Pinger, SurgeError, ICMP,
28};
29
/// Evaluates to `true` when `$sock_type` (a `socket2::Type`) is handled like a
/// Linux/Android unprivileged ICMP socket.
///
/// * `RAW` sockets: always `false`.
/// * `DGRAM` sockets: `true` on Linux and Android, `false` on every other target.
/// * Any other socket type: `true`.
///
/// The receive loop uses this to decide whether replies are matched without an
/// ICMP identifier (presumably because Linux/Android ping sockets have the
/// identifier rewritten by the kernel — confirm against the kernel docs).
///
/// The argument expression is evaluated exactly once.
#[macro_export]
macro_rules! is_linux_icmp_socket {
    ($sock_type:expr) => {{
        // Bind once so side-effecting or expensive arguments are not evaluated twice.
        let sock_type = $sock_type;
        let dgram_off_linux = sock_type == socket2::Type::DGRAM
            && cfg!(not(any(target_os = "linux", target_os = "android")));
        !(dgram_off_linux || sock_type == socket2::Type::RAW)
    }};
}
44
45#[derive(Clone)]
46pub struct AsyncSocket {
47    inner: Arc<UdpSocket>,
48    sock_type: SockType,
49}
50
51impl AsyncSocket {
52    pub fn new(config: &Config) -> io::Result<Self> {
53        let (sock_type, socket) = Self::create_socket(config)?;
54
55        socket.set_nonblocking(true)?;
56        if let Some(sock_addr) = &config.bind {
57            socket.bind(sock_addr)?;
58        }
59        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
60        if let Some(interface) = &config.interface {
61            socket.bind_device(Some(interface.as_bytes()))?;
62        }
63        #[cfg(any(
64            target_os = "ios",
65            target_os = "visionos",
66            target_os = "macos",
67            target_os = "tvos",
68            target_os = "watchos",
69            target_os = "illumos",
70            target_os = "solaris",
71            target_os = "linux",
72            target_os = "android",
73        ))]
74        {
75            if config.interface_index.is_some() {
76                match config.kind {
77                    ICMP::V4 => socket.bind_device_by_index_v4(config.interface_index)?,
78                    ICMP::V6 => socket.bind_device_by_index_v6(config.interface_index)?,
79                }
80            }
81        }
82        if let Some(ttl) = config.ttl {
83            match config.kind {
84                ICMP::V4 => socket.set_ttl_v4(ttl)?,
85                ICMP::V6 => socket.set_unicast_hops_v6(ttl)?,
86            }
87        }
88        #[cfg(target_os = "freebsd")]
89        if let Some(fib) = config.fib {
90            socket.set_fib(fib)?;
91        }
92        #[cfg(windows)]
93        let socket = UdpSocket::from_std(unsafe {
94            std::net::UdpSocket::from_raw_socket(socket.into_raw_socket())
95        })?;
96        #[cfg(unix)]
97        let socket =
98            UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) })?;
99        Ok(Self {
100            inner: Arc::new(socket),
101            sock_type,
102        })
103    }
104
105    fn create_socket(config: &Config) -> io::Result<(SockType, Socket)> {
106        let (domain, proto) = match config.kind {
107            ICMP::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)),
108            ICMP::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)),
109        };
110
111        match Socket::new(domain, config.sock_type_hint, proto) {
112            Ok(sock) => Ok((config.sock_type_hint, sock)),
113            Err(err) => {
114                let new_type = if config.sock_type_hint == SockType::DGRAM {
115                    SockType::RAW
116                } else {
117                    SockType::DGRAM
118                };
119
120                debug!(
121                    "error opening {:?} type socket, trying {:?}: {:?}",
122                    config.sock_type_hint, new_type, err
123                );
124
125                Ok((new_type, Socket::new(domain, new_type, proto)?))
126            }
127        }
128    }
129
130    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
131        self.inner.recv_from(buf).await
132    }
133
134    pub async fn send_to(&self, buf: &mut [u8], target: &SocketAddr) -> io::Result<usize> {
135        self.inner.send_to(buf, target).await
136    }
137
138    pub fn local_addr(&self) -> io::Result<SocketAddr> {
139        self.inner.local_addr()
140    }
141
142    pub fn get_type(&self) -> SockType {
143        self.sock_type
144    }
145
146    #[cfg(unix)]
147    pub fn get_native_sock(&self) -> RawFd {
148        self.inner.as_raw_fd()
149    }
150
151    #[cfg(windows)]
152    pub fn get_native_sock(&self) -> RawSocket {
153        self.inner.as_raw_socket()
154    }
155}
156
/// Key identifying one outstanding ping.
///
/// The fields are: the peer address, the ICMP identifier (`None` when replies are
/// matched without an identifier), and the sequence number.
#[derive(PartialEq, Eq, Hash)]
struct ReplyToken(IpAddr, Option<PingIdentifier>, PingSequence);
159
/// A decoded ICMP reply handed to the waiter that registered for it.
pub(crate) struct Reply {
    /// When the receive loop obtained the datagram (taken right after `recv_from` returned).
    pub timestamp: Instant,
    /// The decoded ICMPv4 or ICMPv6 packet.
    pub packet: IcmpPacket,
}
164
/// Registry of outstanding pings.
///
/// Cloning shares the same state (both fields are `Arc`s). Clones are held by
/// the client, by each `Pinger` created from it, and by the receive task.
#[derive(Clone)]
pub(crate) struct ReplyMap {
    /// One-shot senders keyed by `(host, identifier, sequence)`. The receive task
    /// removes an entry and completes it when a matching reply arrives.
    inner: Arc<Mutex<HashMap<ReplyToken, oneshot::Sender<Reply>>>>,
    /// Cleared by `mark_destroyed`; once `false`, `new_waiter` refuses new registrations.
    alive: Arc<AtomicBool>,
}
170
171impl Default for ReplyMap {
172    fn default() -> Self {
173        Self {
174            inner: Arc::new(Mutex::new(HashMap::new())),
175            alive: Arc::new(AtomicBool::new(true)),
176        }
177    }
178}
179
180impl ReplyMap {
181    /// Register to wait for a reply from host with ident and sequence number.
182    /// If there is already someone waiting for this specific reply then an
183    /// error is returned.
184    pub fn new_waiter(
185        &self,
186        host: IpAddr,
187        ident: Option<PingIdentifier>,
188        seq: PingSequence,
189    ) -> Result<oneshot::Receiver<Reply>, SurgeError> {
190        if !self.alive.load(Ordering::Relaxed) {
191            return Err(SurgeError::ClientDestroyed);
192        }
193        let (tx, rx) = oneshot::channel();
194        if self
195            .inner
196            .lock()
197            .insert(ReplyToken(host, ident, seq), tx)
198            .is_some()
199        {
200            return Err(SurgeError::IdenticalRequests { host, ident, seq });
201        }
202        Ok(rx)
203    }
204
205    /// Remove a waiter.
206    pub(crate) fn remove(
207        &self,
208        host: IpAddr,
209        ident: Option<PingIdentifier>,
210        seq: PingSequence,
211    ) -> Option<oneshot::Sender<Reply>> {
212        self.inner.lock().remove(&ReplyToken(host, ident, seq))
213    }
214
215    /// Mark the client as destroyed. This is called when the Client is dropped.
216    pub(crate) fn mark_destroyed(&self) {
217        self.alive.store(false, Ordering::Relaxed);
218    }
219}
220
221///
222/// If you want to pass the `Client` in the task, please wrap it with `Arc`: `Arc<Client>`.
223/// and can realize the simultaneous ping of multiple addresses when only one `socket` is created.
224///
225#[derive(Clone)]
226pub struct Client {
227    socket: AsyncSocket,
228    reply_map: ReplyMap,
229    recv: Arc<JoinHandle<()>>,
230}
231
232impl Drop for Client {
233    fn drop(&mut self) {
234        // Mark the reply_map as destroyed so any pending or new ping operations
235        // will fail with ClientDestroyed error instead of timing out.
236        self.reply_map.mark_destroyed();
237        // The client may pass through multiple tasks, so need to judge whether the number of references is 1.
238        if Arc::strong_count(&self.recv) <= 1 {
239            self.recv.abort();
240        }
241    }
242}
243
244impl Client {
245    /// A client is generated according to the configuration. In fact, a `AsyncSocket` is wrapped inside,
246    /// and you can clone to any `task` at will.
247    pub fn new(config: &Config) -> io::Result<Self> {
248        let socket = AsyncSocket::new(config)?;
249        let reply_map = ReplyMap::default();
250        let recv = task::spawn(recv_task(socket.clone(), reply_map.clone()));
251        Ok(Self {
252            socket,
253            reply_map,
254            recv: Arc::new(recv),
255        })
256    }
257
258    /// Create a `Pinger` instance, you can make special configuration for this instance.
259    pub async fn pinger(&self, host: IpAddr, ident: PingIdentifier) -> Pinger {
260        Pinger::new(host, ident, self.socket.clone(), self.reply_map.clone())
261    }
262
263    /// Expose the underlying socket, if user wants to modify any options on it
264    pub fn get_socket(&self) -> AsyncSocket {
265        self.socket.clone()
266    }
267}
268
269async fn recv_task(socket: AsyncSocket, reply_map: ReplyMap) {
270    let mut buf = [0; 2048];
271    loop {
272        if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
273            let timestamp = Instant::now();
274            let message = &buf[..sz];
275            let local_addr = socket.local_addr().unwrap().ip();
276            let packet = {
277                let result = match addr.ip() {
278                    IpAddr::V4(src_addr) => {
279                        let local_addr_ip4 = match local_addr {
280                            IpAddr::V4(local_addr_ip4) => local_addr_ip4,
281                            _ => continue,
282                        };
283
284                        Icmpv4Packet::decode(message, socket.sock_type, src_addr, local_addr_ip4)
285                            .map(IcmpPacket::V4)
286                    }
287                    IpAddr::V6(src_addr) => {
288                        Icmpv6Packet::decode(message, src_addr).map(IcmpPacket::V6)
289                    }
290                };
291                match result {
292                    Ok(packet) => packet,
293                    Err(err) => {
294                        debug!("error decoding ICMP packet: {:?}", err);
295                        continue;
296                    }
297                }
298            };
299
300            let ident = if is_linux_icmp_socket!(socket.get_type()) {
301                None
302            } else {
303                Some(packet.get_identifier())
304            };
305
306            if let Some(waiter) = reply_map.remove(addr.ip(), ident, packet.get_sequence()) {
307                // If send fails the receiving end has closed. Nothing to do.
308                let _ = waiter.send(Reply { timestamp, packet });
309            } else {
310                debug!("no one is waiting for ICMP packet ({:?})", packet);
311            }
312        }
313    }
314}