surge_ping/
client.rs

1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
5
6use std::{
7    collections::HashMap,
8    io,
9    net::{IpAddr, SocketAddr},
10    sync::Arc,
11    time::Instant,
12};
13
14use parking_lot::Mutex;
15use socket2::{Domain, Protocol, Socket, Type as SockType};
16use tokio::{
17    net::UdpSocket,
18    sync::oneshot,
19    task::{self, JoinHandle},
20};
21use tracing::debug;
22
23use crate::{
24    config::Config,
25    icmp::{icmpv4::Icmpv4Packet, icmpv6::Icmpv6Packet},
26    IcmpPacket, PingIdentifier, PingSequence, Pinger, SurgeError, ICMP,
27};
28
/// Evaluates to `true` when a socket of type `$sock_type` is handled as a
/// Linux-style ICMP socket.
///
/// The result is `false` for `RAW` sockets on every platform and for `DGRAM`
/// sockets on any OS other than Linux/Android; it is `true` otherwise (in
/// practice: `DGRAM` on Linux or Android). The receive loop in this module
/// uses it to decide whether replies are matched without the ICMP identifier.
///
/// `$sock_type` may be evaluated more than once, so pass a side-effect-free
/// expression.
#[macro_export]
macro_rules! is_linux_icmp_socket {
    ($sock_type:expr) => {
        ($sock_type != socket2::Type::DGRAM
            || cfg!(any(target_os = "linux", target_os = "android")))
            && $sock_type != socket2::Type::RAW
    };
}
43
44#[derive(Clone)]
45pub struct AsyncSocket {
46    inner: Arc<UdpSocket>,
47    sock_type: SockType,
48}
49
50impl AsyncSocket {
51    pub fn new(config: &Config) -> io::Result<Self> {
52        let (sock_type, socket) = Self::create_socket(config)?;
53
54        socket.set_nonblocking(true)?;
55        if let Some(sock_addr) = &config.bind {
56            socket.bind(sock_addr)?;
57        }
58        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
59        if let Some(interface) = &config.interface {
60            socket.bind_device(Some(interface.as_bytes()))?;
61        }
62        if let Some(ttl) = config.ttl {
63            socket.set_ttl(ttl)?;
64        }
65        #[cfg(target_os = "freebsd")]
66        if let Some(fib) = config.fib {
67            socket.set_fib(fib)?;
68        }
69        #[cfg(windows)]
70        let socket = UdpSocket::from_std(unsafe {
71            std::net::UdpSocket::from_raw_socket(socket.into_raw_socket())
72        })?;
73        #[cfg(unix)]
74        let socket =
75            UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) })?;
76        Ok(Self {
77            inner: Arc::new(socket),
78            sock_type,
79        })
80    }
81
82    fn create_socket(config: &Config) -> io::Result<(SockType, Socket)> {
83        let (domain, proto) = match config.kind {
84            ICMP::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)),
85            ICMP::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)),
86        };
87
88        match Socket::new(domain, config.sock_type_hint, proto) {
89            Ok(sock) => Ok((config.sock_type_hint, sock)),
90            Err(err) => {
91                let new_type = if config.sock_type_hint == SockType::DGRAM {
92                    SockType::RAW
93                } else {
94                    SockType::DGRAM
95                };
96
97                debug!(
98                    "error opening {:?} type socket, trying {:?}: {:?}",
99                    config.sock_type_hint, new_type, err
100                );
101
102                Ok((new_type, Socket::new(domain, new_type, proto)?))
103            }
104        }
105    }
106
107    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
108        self.inner.recv_from(buf).await
109    }
110
111    pub async fn send_to(&self, buf: &mut [u8], target: &SocketAddr) -> io::Result<usize> {
112        self.inner.send_to(buf, target).await
113    }
114
115    pub fn local_addr(&self) -> io::Result<SocketAddr> {
116        self.inner.local_addr()
117    }
118
119    pub fn get_type(&self) -> SockType {
120        self.sock_type
121    }
122
123    #[cfg(unix)]
124    pub fn get_native_sock(&self) -> RawFd {
125        self.inner.as_raw_fd()
126    }
127
128    #[cfg(windows)]
129    pub fn get_native_sock(&self) -> RawSocket {
130        self.inner.as_raw_socket()
131    }
132}
133
134#[derive(PartialEq, Eq, Hash)]
135struct ReplyToken(IpAddr, Option<PingIdentifier>, PingSequence);
136
137pub(crate) struct Reply {
138    pub timestamp: Instant,
139    pub packet: IcmpPacket,
140}
141
142#[derive(Clone, Default)]
143pub(crate) struct ReplyMap(Arc<Mutex<HashMap<ReplyToken, oneshot::Sender<Reply>>>>);
144
145impl ReplyMap {
146    /// Register to wait for a reply from host with ident and sequence number.
147    /// If there is already someone waiting for this specific reply then an
148    /// error is returned.
149    pub fn new_waiter(
150        &self,
151        host: IpAddr,
152        ident: Option<PingIdentifier>,
153        seq: PingSequence,
154    ) -> Result<oneshot::Receiver<Reply>, SurgeError> {
155        let (tx, rx) = oneshot::channel();
156        if self
157            .0
158            .lock()
159            .insert(ReplyToken(host, ident, seq), tx)
160            .is_some()
161        {
162            return Err(SurgeError::IdenticalRequests { host, ident, seq });
163        }
164        Ok(rx)
165    }
166
167    /// Remove a waiter.
168    pub(crate) fn remove(
169        &self,
170        host: IpAddr,
171        ident: Option<PingIdentifier>,
172        seq: PingSequence,
173    ) -> Option<oneshot::Sender<Reply>> {
174        self.0.lock().remove(&ReplyToken(host, ident, seq))
175    }
176}
177
178///
179/// If you want to pass the `Client` in the task, please wrap it with `Arc`: `Arc<Client>`.
180/// and can realize the simultaneous ping of multiple addresses when only one `socket` is created.
181///
182#[derive(Clone)]
183pub struct Client {
184    socket: AsyncSocket,
185    reply_map: ReplyMap,
186    recv: Arc<JoinHandle<()>>,
187}
188
189impl Drop for Client {
190    fn drop(&mut self) {
191        // The client may pass through multiple tasks, so need to judge whether the number of references is 1.
192        if Arc::strong_count(&self.recv) <= 1 {
193            self.recv.abort();
194        }
195    }
196}
197
198impl Client {
199    /// A client is generated according to the configuration. In fact, a `AsyncSocket` is wrapped inside,
200    /// and you can clone to any `task` at will.
201    pub fn new(config: &Config) -> io::Result<Self> {
202        let socket = AsyncSocket::new(config)?;
203        let reply_map = ReplyMap::default();
204        let recv = task::spawn(recv_task(socket.clone(), reply_map.clone()));
205        Ok(Self {
206            socket,
207            reply_map,
208            recv: Arc::new(recv),
209        })
210    }
211
212    /// Create a `Pinger` instance, you can make special configuration for this instance.
213    pub async fn pinger(&self, host: IpAddr, ident: PingIdentifier) -> Pinger {
214        Pinger::new(host, ident, self.socket.clone(), self.reply_map.clone())
215    }
216
217    /// Expose the underlying socket, if user wants to modify any options on it
218    pub fn get_socket(&self) -> AsyncSocket {
219        self.socket.clone()
220    }
221}
222
223async fn recv_task(socket: AsyncSocket, reply_map: ReplyMap) {
224    let mut buf = [0; 2048];
225    loop {
226        if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
227            let timestamp = Instant::now();
228            let message = &buf[..sz];
229            let local_addr = socket.local_addr().unwrap().ip();
230            let packet = {
231                let result = match addr.ip() {
232                    IpAddr::V4(src_addr) => {
233                        let local_addr_ip4 = match local_addr {
234                            IpAddr::V4(local_addr_ip4) => local_addr_ip4,
235                            _ => continue,
236                        };
237
238                        Icmpv4Packet::decode(message, socket.sock_type, src_addr, local_addr_ip4)
239                            .map(IcmpPacket::V4)
240                    }
241                    IpAddr::V6(src_addr) => {
242                        Icmpv6Packet::decode(message, src_addr).map(IcmpPacket::V6)
243                    }
244                };
245                match result {
246                    Ok(packet) => packet,
247                    Err(err) => {
248                        debug!("error decoding ICMP packet: {:?}", err);
249                        continue;
250                    }
251                }
252            };
253
254            let ident = if is_linux_icmp_socket!(socket.get_type()) {
255                None
256            } else {
257                Some(packet.get_identifier())
258            };
259
260            if let Some(waiter) = reply_map.remove(addr.ip(), ident, packet.get_sequence()) {
261                // If send fails the receiving end has closed. Nothing to do.
262                let _ = waiter.send(Reply { timestamp, packet });
263            } else {
264                debug!("no one is waiting for ICMP packet ({:?})", packet);
265            }
266        }
267    }
268}