Skip to main content

rust_p2p_core/punch/
mod.rs

1//! NAT hole punching for establishing direct peer-to-peer connections.
2//!
3//! This module implements UDP and TCP hole punching techniques to traverse NAT
4//! and establish direct connections between peers. It handles both Cone and
5//! Symmetric NAT types.
6//!
7//! # Examples
8//!
9//! ```rust,no_run
10//! use rust_p2p_core::punch::{Puncher, PunchInfo};
11//!
12//! # async fn example(puncher: Puncher) -> std::io::Result<()> {
13//! let punch_info = PunchInfo::default();
14//!
15//! // Check if punching is needed
16//! if puncher.need_punch(&punch_info) {
17//!     // Perform hole punching
18//!     puncher.punch_now(None, b"punch", punch_info).await?;
19//! }
20//! # Ok(())
21//! # }
22//! ```
23
24use bytes::Bytes;
25use parking_lot::Mutex;
26use rand::seq::SliceRandom;
27use rand::Rng;
28use std::collections::HashMap;
29use std::io;
30use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
31use std::ops::{Div, Mul};
32use std::sync::Arc;
33use std::time::Duration;
34
35use crate::nat::{NatInfo, NatType};
36
37use crate::tunnel::TunnelDispatcher;
38use crate::tunnel::{tcp, udp};
39pub use config::*;
40pub mod config;
41
42/// NAT hole punching coordinator.
43///
44/// `Puncher` manages the hole punching process, tracking attempts and
45/// adapting strategies based on NAT type and previous results.
46///
47/// # Examples
48///
49/// ```rust,no_run
50/// use rust_p2p_core::punch::Puncher;
51///
52/// let puncher = Puncher::new(None, None);
53/// ```
54#[derive(Clone)]
55pub struct Puncher {
56    // 端口顺序
57    port_vec: Arc<Vec<u16>>,
58    // 指定IP的打洞记录
59    sym_record: Arc<Mutex<HashMap<SocketAddr, usize>>>,
60    #[allow(clippy::type_complexity)]
61    count_record: Arc<Mutex<HashMap<SocketAddr, (usize, usize, u64)>>>,
62    udp_socket_manager: Option<Arc<udp::UdpSocketManager>>,
63    tcp_socket_manager: Option<Arc<tcp::TcpSocketManager>>,
64}
65
66impl From<&TunnelDispatcher> for Puncher {
67    fn from(value: &TunnelDispatcher) -> Self {
68        let tcp_socket_manager = value.shared_tcp_socket_manager();
69        let udp_socket_manager = value.shared_udp_socket_manager();
70        Self::new(udp_socket_manager, tcp_socket_manager)
71    }
72}
73
74impl Puncher {
75    /// Creates a new Puncher with the given socket managers.
76    ///
77    /// # Arguments
78    ///
79    /// * `udp_socket_manager` - Optional UDP socket manager for UDP punching
80    /// * `tcp_socket_manager` - Optional TCP socket manager for TCP punching
81    pub fn new(
82        udp_socket_manager: Option<Arc<udp::UdpSocketManager>>,
83        tcp_socket_manager: Option<Arc<tcp::TcpSocketManager>>,
84    ) -> Puncher {
85        let mut port_vec: Vec<u16> = (1..=65535).collect();
86        let mut rng = rand::rng();
87        port_vec.shuffle(&mut rng);
88        Self {
89            port_vec: Arc::new(port_vec),
90            sym_record: Arc::new(Mutex::new(HashMap::new())),
91            count_record: Arc::new(Mutex::new(HashMap::new())),
92            udp_socket_manager,
93            tcp_socket_manager,
94        }
95    }
96}
/// Returns the system clock as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the clock reads earlier than the epoch.
fn now() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or_default()
}
104impl Puncher {
105    fn clean(&self) {
106        let mut count_record = self.count_record.lock();
107        let ten_minutes_ago = now() - 1200;
108        count_record.retain(|_addr, &mut (_u1, _u2, timestamp)| timestamp >= ten_minutes_ago);
109        let valid_keys: std::collections::HashSet<_> = count_record.keys().cloned().collect();
110        let mut sym_map = self.sym_record.lock();
111        sym_map.retain(|addr, _| valid_keys.contains(addr));
112    }
113    /// Determines whether hole punching is needed for a peer.
114    ///
115    /// Call this method periodically. It uses adaptive frequency based on
116    /// previous attempts to avoid excessive punching.
117    ///
118    /// # Arguments
119    ///
120    /// * `punch_info` - Information about the peer to punch
121    ///
122    /// # Returns
123    ///
124    /// `true` if punching should be attempted, `false` otherwise.
125    ///
126    /// # Examples
127    ///
128    /// ```rust,no_run
129    /// # use rust_p2p_core::punch::{Puncher, PunchInfo};
130    /// # async fn example(puncher: Puncher, punch_info: PunchInfo) {
131    /// if puncher.need_punch(&punch_info) {
132    ///     // Perform punching
133    /// }
134    /// # }
135    /// ```
136    pub fn need_punch(&self, punch_info: &PunchInfo) -> bool {
137        let Some(id) = punch_info.peer_nat_info.flag() else {
138            return false;
139        };
140        let (count, _, _) = *self.count_record.lock().entry(id).or_insert((0, 0, now()));
141        if count > 8 {
142            //降低频率
143            let interval = count / 8;
144            return count % interval.min(360) == 0;
145        }
146        true
147    }
148
149    /// Attempts hole punching if needed (convenience method).
150    ///
151    /// This combines `need_punch` and `punch_now` into a single call.
152    ///
153    /// # Arguments
154    ///
155    /// * `buf` - The data to send during punching
156    /// * `punch_info` - Information about the peer to punch
157    ///
158    /// # Examples
159    ///
160    /// ```rust,no_run
161    /// # use rust_p2p_core::punch::{Puncher, PunchInfo};
162    /// # async fn example(puncher: Puncher) -> std::io::Result<()> {
163    /// let punch_info = PunchInfo::default();
164    /// puncher.punch(b"punch_data", punch_info).await?;
165    /// # Ok(())
166    /// # }
167    /// ```
168    pub async fn punch(&self, buf: &[u8], punch_info: PunchInfo) -> io::Result<()> {
169        if !self.need_punch(&punch_info) {
170            return Ok(());
171        }
172        self.punch_now(Some(buf), buf, punch_info).await
173    }
174
175    /// Performs hole punching immediately without checking if it's needed.
176    ///
177    /// Attempts both TCP and UDP hole punching based on available socket managers
178    /// and the peer's NAT information.
179    ///
180    /// # Arguments
181    ///
182    /// * `tcp_buf` - Optional TCP punch data
183    /// * `udp_buf` - UDP punch data
184    /// * `punch_info` - Information about the peer to punch
185    ///
186    /// # Examples
187    ///
188    /// ```rust,no_run
189    /// # use rust_p2p_core::punch::{Puncher, PunchInfo};
190    /// # async fn example(puncher: Puncher) -> std::io::Result<()> {
191    /// let punch_info = PunchInfo::default();
192    /// puncher.punch_now(Some(b"tcp"), b"udp", punch_info).await?;
193    /// # Ok(())
194    /// # }
195    /// ```
196    pub async fn punch_now(
197        &self,
198        tcp_buf: Option<&[u8]>,
199        udp_buf: &[u8],
200        punch_info: PunchInfo,
201    ) -> io::Result<()> {
202        self.clean();
203        let peer = punch_info
204            .peer_nat_info
205            .flag()
206            .unwrap_or(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)));
207        let (_, count, _) = *self
208            .count_record
209            .lock()
210            .entry(peer)
211            .and_modify(|(_, v, time)| {
212                *v += 1;
213                *time = now();
214            })
215            .or_insert((0, 0, now()));
216        let ttl = if count < 255 { Some(count as u8) } else { None };
217        let peer_nat_info = punch_info.peer_nat_info;
218        let punch_model = punch_info.punch_model;
219        self.punch_udp(peer, count, udp_buf, &peer_nat_info, &punch_model)
220            .await?;
221        type Scope<'a, T> = async_scoped::TokioScope<'a, T>;
222        Scope::scope_and_block(|s| {
223            if let Some(tcp_socket_manager) = self.tcp_socket_manager.as_ref() {
224                for addr in &peer_nat_info.mapping_tcp_addr {
225                    s.spawn(async move {
226                        Self::connect_tcp(
227                            tcp_socket_manager,
228                            tcp_buf,
229                            *addr,
230                            ttl,
231                            Duration::from_secs(3),
232                        )
233                        .await;
234                    })
235                }
236                if punch_model.is_match(PunchPolicy::IPv4Tcp) {
237                    if let Some(addr) = peer_nat_info.local_ipv4_tcp() {
238                        s.spawn(async move {
239                            Self::connect_tcp(
240                                tcp_socket_manager,
241                                tcp_buf,
242                                addr,
243                                ttl,
244                                Duration::from_millis(100),
245                            )
246                            .await;
247                        })
248                    }
249                    for addr in peer_nat_info.public_ipv4_tcp() {
250                        s.spawn(async move {
251                            Self::connect_tcp(
252                                tcp_socket_manager,
253                                tcp_buf,
254                                addr,
255                                ttl,
256                                Duration::from_secs(3),
257                            )
258                            .await;
259                        })
260                    }
261                }
262                if punch_model.is_match(PunchPolicy::IPv6Tcp) {
263                    if let Some(addr) = peer_nat_info.ipv6_tcp_addr() {
264                        s.spawn(async move {
265                            Self::connect_tcp(
266                                tcp_socket_manager,
267                                tcp_buf,
268                                addr,
269                                ttl,
270                                Duration::from_secs(3),
271                            )
272                            .await;
273                        })
274                    }
275                }
276            }
277        });
278
279        Ok(())
280    }
281    async fn connect_tcp(
282        tcp_socket_manager: &tcp::TcpSocketManager,
283        buf: Option<&[u8]>,
284        addr: SocketAddr,
285        ttl: Option<u8>,
286        timeout: Duration,
287    ) {
288        let rs = if let Some(buf) = buf {
289            tokio::time::timeout(
290                timeout,
291                tcp_socket_manager.multi_send_to_impl(Bytes::copy_from_slice(buf), addr, ttl),
292            )
293            .await
294        } else {
295            tokio::time::timeout(timeout, async {
296                tcp_socket_manager.connect_ttl(addr, ttl).await?;
297                Ok(())
298            })
299            .await
300        };
301        match rs {
302            Ok(rs) => {
303                if let Err(e) = rs {
304                    log::warn!("tcp connect {addr},{e:?}");
305                }
306            }
307            Err(_) => {
308                log::warn!("tcp connect timeout {addr}");
309            }
310        }
311    }
312    async fn punch_udp(
313        &self,
314        peer_id: SocketAddr,
315        count: usize,
316        buf: &[u8],
317        peer_nat_info: &NatInfo,
318        punch_model: &PunchModel,
319    ) -> io::Result<()> {
320        let udp_socket_manager = if let Some(udp_socket_manager) = self.udp_socket_manager.as_ref()
321        {
322            udp_socket_manager
323        } else {
324            return Ok(());
325        };
326        if !peer_nat_info.mapping_udp_addr.is_empty() {
327            let mapping_udp_v4_addr: Vec<SocketAddr> = peer_nat_info
328                .mapping_udp_addr
329                .iter()
330                .filter(|a| a.is_ipv4())
331                .copied()
332                .collect();
333            udp_socket_manager.try_main_v4_batch_send_to(buf, &mapping_udp_v4_addr);
334
335            let mapping_udp_v6_addr: Vec<SocketAddr> = peer_nat_info
336                .mapping_udp_addr
337                .iter()
338                .filter(|a| a.is_ipv6())
339                .copied()
340                .collect();
341            udp_socket_manager.try_main_v6_batch_send_to(buf, &mapping_udp_v6_addr);
342        }
343        let local_ipv4_addrs = peer_nat_info.local_ipv4_addrs();
344        if !local_ipv4_addrs.is_empty() {
345            udp_socket_manager.try_main_v4_batch_send_to(buf, &local_ipv4_addrs);
346        }
347
348        if punch_model.is_match(PunchPolicy::IPv6Udp) {
349            let v6_addr = peer_nat_info.ipv6_udp_addr();
350            udp_socket_manager.try_main_v6_batch_send_to(buf, &v6_addr);
351        }
352        if !punch_model.is_match(PunchPolicy::IPv4Udp) {
353            return Ok(());
354        }
355        if peer_nat_info.public_ips.is_empty() {
356            return Ok(());
357        }
358
359        match peer_nat_info.nat_type {
360            NatType::Symmetric => {
361                // 假设对方绑定n个端口,通过NAT对外映射出n个 公网ip:公网端口,自己随机尝试k次的情况下
362                // 猜中的概率 p = 1-((65535-n)/65535)*((65535-n-1)/(65535-1))*...*((65535-n-k+1)/(65535-k+1))
363                // n取76,k取600,猜中的概率就超过50%了
364                // 前提 自己是锥形网络,否则猜中了也通信不了
365
366                //预测范围内最多发送max_k1个包
367                let max_k1 = 60;
368                //全局最多发送max_k2个包
369                let mut max_k2: usize = rand::rng().random_range(600..800);
370                if count > 8 {
371                    //递减探测规模
372                    max_k2 = max_k2.mul(8).div(count).max(max_k1 as usize);
373                }
374                let port = peer_nat_info.public_udp_ports.first().copied().unwrap_or(0);
375                if peer_nat_info.public_port_range < max_k1 * 3 {
376                    //端口变化不大时,在预测的范围内随机发送
377                    let min_port = if port > peer_nat_info.public_port_range {
378                        port - peer_nat_info.public_port_range
379                    } else {
380                        1
381                    };
382                    let (max_port, overflow) =
383                        port.overflowing_add(peer_nat_info.public_port_range);
384                    let max_port = if overflow { 65535 } else { max_port };
385                    let k = if max_port - min_port + 1 > max_k1 {
386                        max_k1 as usize
387                    } else {
388                        (max_port - min_port + 1) as usize
389                    };
390                    let mut nums: Vec<u16> = (min_port..=max_port).collect();
391                    nums.shuffle(&mut rand::rng());
392                    self.punch_symmetric(
393                        udp_socket_manager,
394                        &nums[..k],
395                        buf,
396                        &peer_nat_info.public_ips,
397                        max_k1 as usize,
398                    )
399                    .await?;
400                }
401                let start = self
402                    .sym_record
403                    .lock()
404                    .get(&peer_id)
405                    .cloned()
406                    .unwrap_or_default();
407                let mut end = start + max_k2;
408                if end > self.port_vec.len() {
409                    end = self.port_vec.len();
410                }
411                let mut index = start
412                    + self
413                        .punch_symmetric(
414                            udp_socket_manager,
415                            &self.port_vec[start..end],
416                            buf,
417                            &peer_nat_info.public_ips,
418                            max_k2,
419                        )
420                        .await?;
421                if index >= self.port_vec.len() {
422                    index = 0
423                }
424                // 记录这个IP的打洞记录
425                self.sym_record.lock().insert(peer_id, index);
426            }
427            NatType::Cone => {
428                let addr = peer_nat_info.public_ipv4_addr();
429                if addr.is_empty() {
430                    return Ok(());
431                }
432                udp_socket_manager.try_main_v4_batch_send_to(buf, &addr);
433                udp_socket_manager.try_sub_batch_send_to(buf, addr[0]);
434            }
435        }
436        Ok(())
437    }
438
439    async fn punch_symmetric(
440        &self,
441        udp_socket_manager: &udp::UdpSocketManager,
442        ports: &[u16],
443        buf: &[u8],
444        ips: &Vec<Ipv4Addr>,
445        max: usize,
446    ) -> io::Result<usize> {
447        let mut count = 0;
448        for (index, port) in ports.iter().enumerate() {
449            for pub_ip in ips {
450                count += 1;
451                if count == max {
452                    return Ok(index);
453                }
454                let addr: SocketAddr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, *port));
455                if let Err(e) = udp_socket_manager.try_send_to(buf, addr) {
456                    log::info!("{addr},{e:?}");
457                }
458                tokio::time::sleep(Duration::from_millis(2)).await
459            }
460        }
461        Ok(ports.len())
462    }
463}