// rust_p2p_core/punch/mod.rs

1use std::collections::HashMap;
2use std::io;
3use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
4use std::ops::{Div, Mul};
5use std::sync::Arc;
6use std::time::Duration;
7
8use parking_lot::Mutex;
9use rand::seq::SliceRandom;
10use rand::Rng;
11
12use crate::nat::{NatInfo, NatType};
13
14use crate::tunnel::TunnelDispatcher;
15use crate::tunnel::{tcp, udp};
16pub use config::*;
17pub mod config;
18
19#[derive(Clone)]
20pub struct Puncher {
21    // 端口顺序
22    port_vec: Arc<Vec<u16>>,
23    // 指定IP的打洞记录
24    sym_record: Arc<Mutex<HashMap<SocketAddr, usize>>>,
25    #[allow(clippy::type_complexity)]
26    count_record: Arc<Mutex<HashMap<SocketAddr, (usize, usize, u64)>>>,
27    udp_socket_manager: Option<Arc<udp::UdpSocketManager>>,
28    tcp_socket_manager: Option<Arc<tcp::TcpSocketManager>>,
29}
30
31impl From<&TunnelDispatcher> for Puncher {
32    fn from(value: &TunnelDispatcher) -> Self {
33        let tcp_socket_manager = value.shared_tcp_socket_manager();
34        let udp_socket_manager = value.shared_udp_socket_manager();
35        Self::new(udp_socket_manager, tcp_socket_manager)
36    }
37}
38
39impl Puncher {
40    pub fn new(
41        udp_socket_manager: Option<Arc<udp::UdpSocketManager>>,
42        tcp_socket_manager: Option<Arc<tcp::TcpSocketManager>>,
43    ) -> Puncher {
44        let mut port_vec: Vec<u16> = (1..=65535).collect();
45        let mut rng = rand::rng();
46        port_vec.shuffle(&mut rng);
47        Self {
48            port_vec: Arc::new(port_vec),
49            sym_record: Arc::new(Mutex::new(HashMap::new())),
50            count_record: Arc::new(Mutex::new(HashMap::new())),
51            udp_socket_manager,
52            tcp_socket_manager,
53        }
54    }
55}
/// Current wall-clock time in whole seconds since the UNIX epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
63impl Puncher {
64    fn clean(&self) {
65        let mut count_record = self.count_record.lock();
66        let ten_minutes_ago = now() - 1200;
67        count_record.retain(|_addr, &mut (_u1, _u2, timestamp)| timestamp >= ten_minutes_ago);
68        let valid_keys: std::collections::HashSet<_> = count_record.keys().cloned().collect();
69        let mut sym_map = self.sym_record.lock();
70        sym_map.retain(|addr, _| valid_keys.contains(addr));
71    }
72    /// Call `need_punch` at a certain frequency, and call [`punch_now`](Self::punch_now) after getting true.
73    /// Determine whether punching is needed.
74    pub fn need_punch(&self, punch_info: &PunchInfo) -> bool {
75        let Some(id) = punch_info.peer_nat_info.flag() else {
76            return false;
77        };
78        let (count, _, _) = *self
79            .count_record
80            .lock()
81            .entry(id)
82            .and_modify(|(v, _, time)| {
83                *v += 1;
84                *time = now();
85            })
86            .or_insert((0, 0, now()));
87        if count > 8 {
88            //降低频率
89            let interval = count / 8;
90            return count % interval.min(360) == 0;
91        }
92        true
93    }
94
95    /// Call `punch` at a certain frequency
96    pub async fn punch(&self, buf: &[u8], punch_info: PunchInfo) -> io::Result<()> {
97        if !self.need_punch(&punch_info) {
98            return Ok(());
99        }
100        self.punch_now(Some(buf), buf, punch_info).await
101    }
102    pub async fn punch_now(
103        &self,
104        tcp_buf: Option<&[u8]>,
105        udp_buf: &[u8],
106        punch_info: PunchInfo,
107    ) -> io::Result<()> {
108        self.clean();
109        let peer = punch_info
110            .peer_nat_info
111            .flag()
112            .unwrap_or(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)));
113        let (_, count, _) = *self
114            .count_record
115            .lock()
116            .entry(peer)
117            .and_modify(|(_, v, time)| {
118                *v += 1;
119                *time = now();
120            })
121            .or_insert((0, 0, now()));
122        let ttl = if count < 255 { Some(count as u8) } else { None };
123        let peer_nat_info = punch_info.peer_nat_info;
124        let punch_model = punch_info.punch_model;
125
126        type Scope<'a, T> = async_scoped::TokioScope<'a, T>;
127        Scope::scope_and_block(|s| {
128            if let Some(tcp_socket_manager) = self.tcp_socket_manager.as_ref() {
129                for addr in &peer_nat_info.mapping_tcp_addr {
130                    s.spawn(async move {
131                        Self::connect_tcp(tcp_socket_manager, tcp_buf, *addr, ttl).await;
132                    })
133                }
134                if punch_model.is_match(PunchPolicy::IPv4Tcp) {
135                    if let Some(addr) = peer_nat_info.local_ipv4_tcp() {
136                        s.spawn(async move {
137                            Self::connect_tcp(tcp_socket_manager, tcp_buf, addr, ttl).await;
138                        })
139                    }
140                    for addr in peer_nat_info.public_ipv4_tcp() {
141                        s.spawn(async move {
142                            Self::connect_tcp(tcp_socket_manager, tcp_buf, addr, ttl).await;
143                        })
144                    }
145                }
146                if punch_model.is_match(PunchPolicy::IPv6Tcp) {
147                    if let Some(addr) = peer_nat_info.ipv6_tcp_addr() {
148                        s.spawn(async move {
149                            Self::connect_tcp(tcp_socket_manager, tcp_buf, addr, ttl).await;
150                        })
151                    }
152                }
153            }
154        });
155        self.punch_udp(peer, count, udp_buf, &peer_nat_info, &punch_model)
156            .await?;
157
158        Ok(())
159    }
160    async fn connect_tcp(
161        tcp_socket_manager: &tcp::TcpSocketManager,
162        buf: Option<&[u8]>,
163        addr: SocketAddr,
164        ttl: Option<u8>,
165    ) {
166        let rs = if let Some(buf) = buf {
167            tokio::time::timeout(
168                Duration::from_secs(3),
169                tcp_socket_manager.multi_send_to_impl(buf.into(), addr, ttl),
170            )
171            .await
172        } else {
173            tokio::time::timeout(Duration::from_secs(3), async {
174                tcp_socket_manager.connect_ttl(addr, ttl).await?;
175                Ok(())
176            })
177            .await
178        };
179        match rs {
180            Ok(rs) => {
181                if let Err(e) = rs {
182                    log::warn!("tcp connect {addr},{e:?}");
183                }
184            }
185            Err(_) => {
186                log::warn!("tcp connect timeout {addr}");
187            }
188        }
189    }
190    async fn punch_udp(
191        &self,
192        peer_id: SocketAddr,
193        count: usize,
194        buf: &[u8],
195        peer_nat_info: &NatInfo,
196        punch_model: &PunchModel,
197    ) -> io::Result<()> {
198        let udp_socket_manager = if let Some(udp_socket_manager) = self.udp_socket_manager.as_ref()
199        {
200            udp_socket_manager
201        } else {
202            return Ok(());
203        };
204        if !peer_nat_info.mapping_udp_addr.is_empty() {
205            let mapping_udp_v4_addr: Vec<SocketAddr> = peer_nat_info
206                .mapping_udp_addr
207                .iter()
208                .filter(|a| a.is_ipv4())
209                .copied()
210                .collect();
211            udp_socket_manager.try_main_v4_batch_send_to(buf, &mapping_udp_v4_addr);
212
213            let mapping_udp_v6_addr: Vec<SocketAddr> = peer_nat_info
214                .mapping_udp_addr
215                .iter()
216                .filter(|a| a.is_ipv6())
217                .copied()
218                .collect();
219            udp_socket_manager.try_main_v6_batch_send_to(buf, &mapping_udp_v6_addr);
220        }
221        let local_ipv4_addrs = peer_nat_info.local_ipv4_addrs();
222        if !local_ipv4_addrs.is_empty() {
223            udp_socket_manager.try_main_v4_batch_send_to(buf, &local_ipv4_addrs);
224        }
225
226        if punch_model.is_match(PunchPolicy::IPv6Udp) {
227            let v6_addr = peer_nat_info.ipv6_udp_addr();
228            udp_socket_manager.try_main_v6_batch_send_to(buf, &v6_addr);
229        }
230        if !punch_model.is_match(PunchPolicy::IPv4Udp) {
231            return Ok(());
232        }
233        if peer_nat_info.public_ips.is_empty() {
234            return Ok(());
235        }
236
237        match peer_nat_info.nat_type {
238            NatType::Symmetric => {
239                // 假设对方绑定n个端口,通过NAT对外映射出n个 公网ip:公网端口,自己随机尝试k次的情况下
240                // 猜中的概率 p = 1-((65535-n)/65535)*((65535-n-1)/(65535-1))*...*((65535-n-k+1)/(65535-k+1))
241                // n取76,k取600,猜中的概率就超过50%了
242                // 前提 自己是锥形网络,否则猜中了也通信不了
243
244                //预测范围内最多发送max_k1个包
245                let max_k1 = 60;
246                //全局最多发送max_k2个包
247                let mut max_k2: usize = rand::rng().random_range(600..800);
248                if count > 8 {
249                    //递减探测规模
250                    max_k2 = max_k2.mul(8).div(count).max(max_k1 as usize);
251                }
252                let port = peer_nat_info.public_udp_ports.first().copied().unwrap_or(0);
253                if peer_nat_info.public_port_range < max_k1 * 3 {
254                    //端口变化不大时,在预测的范围内随机发送
255                    let min_port = if port > peer_nat_info.public_port_range {
256                        port - peer_nat_info.public_port_range
257                    } else {
258                        1
259                    };
260                    let (max_port, overflow) =
261                        port.overflowing_add(peer_nat_info.public_port_range);
262                    let max_port = if overflow { 65535 } else { max_port };
263                    let k = if max_port - min_port + 1 > max_k1 {
264                        max_k1 as usize
265                    } else {
266                        (max_port - min_port + 1) as usize
267                    };
268                    let mut nums: Vec<u16> = (min_port..=max_port).collect();
269                    nums.shuffle(&mut rand::rng());
270                    self.punch_symmetric(
271                        udp_socket_manager,
272                        &nums[..k],
273                        buf,
274                        &peer_nat_info.public_ips,
275                        max_k1 as usize,
276                    )
277                    .await?;
278                }
279                let start = self
280                    .sym_record
281                    .lock()
282                    .get(&peer_id)
283                    .cloned()
284                    .unwrap_or_default();
285                let mut end = start + max_k2;
286                if end > self.port_vec.len() {
287                    end = self.port_vec.len();
288                }
289                let mut index = start
290                    + self
291                        .punch_symmetric(
292                            udp_socket_manager,
293                            &self.port_vec[start..end],
294                            buf,
295                            &peer_nat_info.public_ips,
296                            max_k2,
297                        )
298                        .await?;
299                if index >= self.port_vec.len() {
300                    index = 0
301                }
302                // 记录这个IP的打洞记录
303                self.sym_record.lock().insert(peer_id, index);
304            }
305            NatType::Cone => {
306                let addr = peer_nat_info.public_ipv4_addr();
307                if addr.is_empty() {
308                    return Ok(());
309                }
310                udp_socket_manager.try_main_v4_batch_send_to(buf, &addr);
311                udp_socket_manager.try_sub_batch_send_to(buf, addr[0]);
312            }
313        }
314        Ok(())
315    }
316
317    async fn punch_symmetric(
318        &self,
319        udp_socket_manager: &udp::UdpSocketManager,
320        ports: &[u16],
321        buf: &[u8],
322        ips: &Vec<Ipv4Addr>,
323        max: usize,
324    ) -> io::Result<usize> {
325        let mut count = 0;
326        for (index, port) in ports.iter().enumerate() {
327            for pub_ip in ips {
328                count += 1;
329                if count == max {
330                    return Ok(index);
331                }
332                let addr: SocketAddr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, *port));
333                if let Err(e) = udp_socket_manager.try_send_to(buf, addr) {
334                    log::info!("{addr},{e:?}");
335                }
336                tokio::time::sleep(Duration::from_millis(2)).await
337            }
338        }
339        Ok(ports.len())
340    }
341}