rustp2p_reliable/
lib.rs

1pub use crate::config::Config;
2use crate::kcp::{DataType, KcpHandle};
3use crate::maintain::start_task;
4use async_shutdown::ShutdownManager;
5use bytes::BytesMut;
6use flume::Receiver;
7use parking_lot::Mutex;
8use rand::seq::SliceRandom;
9pub use rust_p2p_core::nat::NatInfo;
10use rust_p2p_core::nat::NatType;
11pub use rust_p2p_core::punch::config::*;
12use rust_p2p_core::punch::Puncher as CorePuncher;
13use rust_p2p_core::route::Index;
14use rust_p2p_core::socket::LocalInterface;
15pub use rust_p2p_core::tunnel::config::*;
16pub use rust_p2p_core::tunnel::tcp::{
17    BytesCodec, BytesInitCodec, Decoder, Encoder, InitCodec, LengthPrefixedCodec,
18    LengthPrefixedInitCodec,
19};
20use rust_p2p_core::tunnel::tcp::{TcpTunnel, WeakTcpTunnelSender};
21use rust_p2p_core::tunnel::udp::{UDPIndex, UdpTunnel};
22use rust_p2p_core::tunnel::{SocketManager, Tunnel, TunnelDispatcher};
23use std::io;
24use std::net::{Ipv4Addr, SocketAddr};
25use std::sync::Arc;
26use tokio::sync::mpsc::Sender;
27
28mod config;
29mod kcp;
30mod maintain;
31
32pub async fn from_config(config: Config) -> io::Result<(ReliableTunnelListener, Puncher)> {
33    let tunnel_config = config.tunnel_config;
34    let tcp_stun_servers = config.tcp_stun_servers;
35    let udp_stun_servers = config.udp_stun_servers;
36    let default_interface = tunnel_config
37        .udp_tunnel_config
38        .as_ref()
39        .map(|v| v.default_interface.clone())
40        .unwrap_or_default();
41
42    let (unified_tunnel_factory, puncher) =
43        rust_p2p_core::tunnel::new_tunnel_component(tunnel_config)?;
44    let manager = unified_tunnel_factory.socket_manager();
45    let shutdown_manager = ShutdownManager::<()>::new();
46    let puncher = Puncher::new(
47        default_interface,
48        tcp_stun_servers,
49        udp_stun_servers,
50        puncher,
51        manager,
52    )
53    .await?;
54    let listener = ReliableTunnelListener::new(
55        shutdown_manager.clone(),
56        unified_tunnel_factory,
57        puncher.punch_context.clone(),
58    );
59    start_task(shutdown_manager, puncher.clone());
60    Ok((listener, puncher))
61}
62pub struct ReliableTunnelListener {
63    shutdown_manager: ShutdownManager<()>,
64    punch_context: Arc<PunchContext>,
65    unified_tunnel_factory: TunnelDispatcher,
66    kcp_receiver: tokio::sync::mpsc::Receiver<KcpMessageHub>,
67    kcp_sender: Sender<KcpMessageHub>,
68}
/// Cloneable handle used to start hole punching and to read the current
/// [`NatInfo`].
#[derive(Clone)]
pub struct Puncher {
    /// NAT state shared with the listener.
    punch_context: Arc<PunchContext>,
    /// Core puncher that performs the actual punch.
    puncher: CorePuncher,
    /// Socket manager obtained from the tunnel dispatcher.
    socket_manager: SocketManager,
}
75impl Drop for ReliableTunnelListener {
76    fn drop(&mut self) {
77        _ = self.shutdown_manager.trigger_shutdown(());
78    }
79}
80pub(crate) struct PunchContext {
81    default_interface: Option<LocalInterface>,
82    tcp_stun_servers: Vec<String>,
83    udp_stun_servers: Vec<String>,
84    nat_info: Arc<Mutex<NatInfo>>,
85}
86impl PunchContext {
87    pub fn new(
88        default_interface: Option<LocalInterface>,
89        tcp_stun_servers: Vec<String>,
90        udp_stun_servers: Vec<String>,
91        local_udp_ports: Vec<u16>,
92        local_tcp_port: u16,
93    ) -> Self {
94        let public_udp_ports = vec![0; local_udp_ports.len()];
95        let nat_info = NatInfo {
96            nat_type: Default::default(),
97            public_ips: vec![],
98            public_udp_ports,
99            mapping_tcp_addr: vec![],
100            mapping_udp_addr: vec![],
101            public_port_range: 0,
102            local_ipv4: Ipv4Addr::UNSPECIFIED,
103            ipv6: None,
104            local_udp_ports,
105            local_tcp_port,
106            public_tcp_port: 0,
107        };
108        Self {
109            default_interface,
110            tcp_stun_servers,
111            udp_stun_servers,
112            nat_info: Arc::new(Mutex::new(nat_info)),
113        }
114    }
115    pub fn set_public_info(
116        &self,
117        nat_type: NatType,
118        mut ips: Vec<Ipv4Addr>,
119        public_port_range: u16,
120    ) {
121        ips.retain(rust_p2p_core::extend::addr::is_ipv4_global);
122        let mut guard = self.nat_info.lock();
123        guard.public_ips = ips;
124        guard.nat_type = nat_type;
125        guard.public_port_range = public_port_range;
126    }
127    fn mapping_addr(addr: SocketAddr) -> Option<(Ipv4Addr, u16)> {
128        match addr {
129            SocketAddr::V4(addr) => Some((*addr.ip(), addr.port())),
130            SocketAddr::V6(addr) => addr.ip().to_ipv4_mapped().map(|ip| (ip, addr.port())),
131        }
132    }
133    pub fn update_tcp_public_addr(&self, addr: SocketAddr) {
134        let (ip, port) = if let Some(r) = Self::mapping_addr(addr) {
135            r
136        } else {
137            return;
138        };
139        let mut nat_info = self.nat_info.lock();
140        if rust_p2p_core::extend::addr::is_ipv4_global(&ip) && !nat_info.public_ips.contains(&ip) {
141            nat_info.public_ips.push(ip);
142        }
143        nat_info.public_tcp_port = port;
144    }
145    pub fn update_public_addr(&self, index: Index, addr: SocketAddr) {
146        let (ip, port) = if let Some(r) = Self::mapping_addr(addr) {
147            r
148        } else {
149            return;
150        };
151        let mut nat_info = self.nat_info.lock();
152
153        if rust_p2p_core::extend::addr::is_ipv4_global(&ip) {
154            if !nat_info.public_ips.contains(&ip) {
155                nat_info.public_ips.push(ip);
156            }
157            match index {
158                Index::Udp(index) => {
159                    let index = match index {
160                        UDPIndex::MainV4(index) => index,
161                        UDPIndex::MainV6(index) => index,
162                        UDPIndex::SubV4(_) => return,
163                    };
164                    if let Some(p) = nat_info.public_udp_ports.get_mut(index) {
165                        *p = port;
166                    }
167                }
168                Index::Tcp(_) => {
169                    nat_info.public_tcp_port = port;
170                }
171                _ => {}
172            }
173        } else {
174            log::debug!("not public addr: {addr:?}")
175        }
176    }
177    pub async fn update_local_addr(&self) {
178        let local_ipv4 = rust_p2p_core::extend::addr::local_ipv4().await;
179        let local_ipv6 = rust_p2p_core::extend::addr::local_ipv6().await;
180        let mut nat_info = self.nat_info.lock();
181        if let Ok(local_ipv4) = local_ipv4 {
182            nat_info.local_ipv4 = local_ipv4;
183        }
184        nat_info.ipv6 = local_ipv6.ok();
185    }
186    pub async fn update_nat_info(&self) -> io::Result<NatInfo> {
187        self.update_local_addr().await;
188        let mut udp_stun_servers = self.udp_stun_servers.clone();
189        udp_stun_servers.shuffle(&mut rand::rng());
190        let udp_stun_servers = if udp_stun_servers.len() > 3 {
191            &udp_stun_servers[..3]
192        } else {
193            &udp_stun_servers
194        };
195        let (nat_type, ips, port_range) = rust_p2p_core::stun::stun_test_nat(
196            udp_stun_servers.to_vec(),
197            self.default_interface.as_ref(),
198        )
199        .await?;
200        self.set_public_info(nat_type, ips, port_range);
201        Ok(self.nat_info())
202    }
203    pub fn nat_info(&self) -> NatInfo {
204        self.nat_info.lock().clone()
205    }
206}
207impl Puncher {
208    async fn new(
209        default_interface: Option<LocalInterface>,
210        tcp_stun_servers: Vec<String>,
211        udp_stun_servers: Vec<String>,
212        puncher: CorePuncher,
213        socket_manager: SocketManager,
214    ) -> io::Result<Self> {
215        let local_tcp_port = if let Some(v) = socket_manager.tcp_socket_manager_as_ref() {
216            v.local_addr().port()
217        } else {
218            0
219        };
220        let local_udp_ports = if let Some(v) = socket_manager.udp_socket_manager_as_ref() {
221            v.local_ports()?
222        } else {
223            vec![]
224        };
225        let punch_context = Arc::new(PunchContext::new(
226            default_interface,
227            tcp_stun_servers,
228            udp_stun_servers,
229            local_udp_ports,
230            local_tcp_port,
231        ));
232        punch_context.update_local_addr().await;
233        Ok(Self {
234            punch_context,
235            puncher,
236            socket_manager,
237        })
238    }
239
240    pub async fn punch(&self, punch_info: PunchInfo) -> io::Result<()> {
241        self.punch_conv(0, punch_info).await
242    }
243
244    pub async fn punch_conv(&self, kcp_conv: u32, punch_info: PunchInfo) -> io::Result<()> {
245        let mut punch_udp_buf = [0; 8];
246        punch_udp_buf[..4].copy_from_slice(&kcp_conv.to_le_bytes());
247        // kcp flag
248        punch_udp_buf[0] = 0x02;
249        if rust_p2p_core::stun::is_stun_response(&punch_udp_buf) {
250            return Err(io::Error::new(io::ErrorKind::Other, "kcp_conv error"));
251        }
252        if !self.puncher.need_punch(&punch_info) {
253            return Ok(());
254        }
255        self.puncher
256            .punch_now(None, &punch_udp_buf, punch_info)
257            .await
258    }
259    pub fn nat_info(&self) -> NatInfo {
260        self.punch_context.nat_info()
261    }
262}
263
/// A reliable tunnel returned by [`ReliableTunnelListener::accept`].
pub enum ReliableTunnel {
    /// Tunnel backed by a TCP connection.
    Tcp(TcpMessageHub),
    /// Tunnel backed by KCP.
    Kcp(KcpMessageHub),
}
/// Identifies which variant a [`ReliableTunnel`] is, without borrowing it.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum ReliableTunnelType {
    /// The tunnel is a [`ReliableTunnel::Tcp`].
    Tcp,
    /// The tunnel is a [`ReliableTunnel::Kcp`].
    Kcp,
}
273
274impl ReliableTunnelListener {
275    fn new(
276        shutdown_manager: ShutdownManager<()>,
277        unified_tunnel_factory: TunnelDispatcher,
278        punch_context: Arc<PunchContext>,
279    ) -> Self {
280        let (kcp_sender, kcp_receiver) = tokio::sync::mpsc::channel(64);
281        Self {
282            shutdown_manager,
283            punch_context,
284            unified_tunnel_factory,
285            kcp_receiver,
286            kcp_sender,
287        }
288    }
289    pub async fn accept(&mut self) -> io::Result<ReliableTunnel> {
290        loop {
291            tokio::select! {
292                rs=self.unified_tunnel_factory.dispatch()=>{
293                    let unified_tunnel = rs?;
294                    match unified_tunnel {
295                        Tunnel::Udp(udp) => {
296                            handle_udp(self.shutdown_manager.clone(), udp, self.kcp_sender.clone(), self.punch_context.clone())?;
297                        }
298                        Tunnel::Tcp(tcp) => {
299                            let local_addr = tcp.local_addr();
300                            let remote_addr = tcp.route_key().addr();
301                            let sender = tcp.sender()?;
302                            let receiver = handle_tcp(self.shutdown_manager.clone(),tcp).await?;
303                            let hub = TcpMessageHub::new(local_addr,remote_addr,sender,receiver);
304                            return Ok(ReliableTunnel::Tcp(hub))
305                        }
306                    }
307                }
308                rs=self.kcp_receiver.recv()=>{
309                    return if let Some(hub) = rs{
310                        Ok(ReliableTunnel::Kcp(hub))
311                    }else{
312                        Err(io::Error::from(io::ErrorKind::UnexpectedEof))
313                    }
314                }
315            }
316        }
317    }
318}
319
320impl ReliableTunnel {
321    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
322        match &self {
323            ReliableTunnel::Tcp(tcp) => tcp.send(buf).await,
324            ReliableTunnel::Kcp(kcp) => kcp.send(buf).await,
325        }
326    }
327    pub async fn send_raw(&self, buf: BytesMut) -> io::Result<()> {
328        match &self {
329            ReliableTunnel::Tcp(tcp) => tcp.send(buf).await,
330            ReliableTunnel::Kcp(kcp) => kcp.send_raw(buf).await,
331        }
332    }
333    pub async fn next(&self) -> io::Result<BytesMut> {
334        match &self {
335            ReliableTunnel::Tcp(tcp) => tcp.next().await,
336            ReliableTunnel::Kcp(kcp) => kcp.next().await,
337        }
338    }
339    pub fn local_addr(&self) -> SocketAddr {
340        match &self {
341            ReliableTunnel::Tcp(tcp) => tcp.local_addr,
342            ReliableTunnel::Kcp(kcp) => kcp.local_addr,
343        }
344    }
345    pub fn remote_addr(&self) -> SocketAddr {
346        match &self {
347            ReliableTunnel::Tcp(tcp) => tcp.remote_addr,
348            ReliableTunnel::Kcp(kcp) => kcp.remote_addr,
349        }
350    }
351    pub fn tunnel_type(&self) -> ReliableTunnelType {
352        match &self {
353            ReliableTunnel::Tcp(_tcp) => ReliableTunnelType::Tcp,
354            ReliableTunnel::Kcp(_kcp) => ReliableTunnelType::Kcp,
355        }
356    }
357}
358pub struct TcpMessageHub {
359    local_addr: SocketAddr,
360    remote_addr: SocketAddr,
361    input: WeakTcpTunnelSender,
362    output: Receiver<BytesMut>,
363}
364impl TcpMessageHub {
365    pub(crate) fn new(
366        local_addr: SocketAddr,
367        remote_addr: SocketAddr,
368        input: WeakTcpTunnelSender,
369        output: Receiver<BytesMut>,
370    ) -> Self {
371        Self {
372            local_addr,
373            remote_addr,
374            input,
375            output,
376        }
377    }
378    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
379        self.input.send(buf).await
380    }
381    pub async fn next(&self) -> io::Result<BytesMut> {
382        self.output
383            .recv_async()
384            .await
385            .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
386    }
387}
388pub struct KcpMessageHub {
389    local_addr: SocketAddr,
390    remote_addr: SocketAddr,
391    input: Sender<DataType>,
392    output: Receiver<BytesMut>,
393}
394
395impl KcpMessageHub {
396    pub(crate) fn new(
397        local_addr: SocketAddr,
398        remote_addr: SocketAddr,
399        input: Sender<DataType>,
400        output: Receiver<BytesMut>,
401    ) -> Self {
402        Self {
403            local_addr,
404            remote_addr,
405            input,
406            output,
407        }
408    }
409    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
410        self.input
411            .send(DataType::Kcp(buf))
412            .await
413            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
414    }
415    pub async fn send_raw(&self, buf: BytesMut) -> io::Result<()> {
416        self.input
417            .send(DataType::Raw(buf))
418            .await
419            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
420    }
421    pub async fn next(&self) -> io::Result<BytesMut> {
422        self.output
423            .recv_async()
424            .await
425            .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
426    }
427}
428async fn handle_tcp(
429    shutdown_manager: ShutdownManager<()>,
430    mut tcp_tunnel: TcpTunnel,
431) -> io::Result<Receiver<BytesMut>> {
432    let (sender, receiver) = flume::bounded(128);
433    tokio::spawn(async move {
434        let mut buf = [0; 65536];
435        while let Ok(Ok(len)) = shutdown_manager
436            .wrap_cancel(tcp_tunnel.recv(&mut buf))
437            .await
438        {
439            if sender.send_async(buf[..len].into()).await.is_err() {
440                break;
441            }
442        }
443    });
444    Ok(receiver)
445}
446
447fn handle_udp(
448    shutdown_manager: ShutdownManager<()>,
449    mut udp_tunnel: UdpTunnel,
450    sender: Sender<KcpMessageHub>,
451    punch_context: Arc<PunchContext>,
452) -> io::Result<()> {
453    let mut kcp_handle = KcpHandle::new(udp_tunnel.local_addr(), udp_tunnel.sender()?, sender);
454    tokio::spawn(async move {
455        let mut buf = [0; 65536];
456
457        while let Ok(Some(rs)) = shutdown_manager
458            .wrap_cancel(udp_tunnel.recv_from(&mut buf))
459            .await
460        {
461            let (len, route_key) = match rs {
462                Ok(rs) => rs,
463                Err(e) => {
464                    log::warn!("udp_tunnel.recv_from {e:?}");
465                    continue;
466                }
467            };
468            // check stun data
469            if rust_p2p_core::stun::is_stun_response(&buf[..len]) {
470                if let Some(pub_addr) = rust_p2p_core::stun::recv_stun_response(&buf[..len]) {
471                    punch_context.update_public_addr(route_key.index(), pub_addr);
472                    continue;
473                }
474            }
475            kcp_handle.handle(&buf[..len], route_key).await;
476        }
477    });
478    Ok(())
479}