Skip to main content

rust_p2p_core/tunnel/
tcp.rs

1use crate::route::{Index, RouteKey};
2use crate::socket::{connect_tcp, create_tcp_listener, LocalInterface};
3use crate::tunnel::config::TcpTunnelConfig;
4use async_lock::Mutex;
5use async_trait::async_trait;
6use bytes::Bytes;
7use dashmap::DashMap;
8use dyn_clone::DynClone;
9use rand::Rng;
10use std::io;
11use std::io::IoSlice;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15use tachyonix::{Receiver, Sender, TrySendError};
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
18use tokio::net::{TcpListener, TcpStream};
19
20pub struct TcpTunnelDispatcher {
21    route_idle_time: Duration,
22    tcp_listener: TcpListener,
23    connect_receiver: Receiver<(RouteKey, ReadHalfBox, Sender<Bytes>)>,
24    #[allow(dead_code)]
25    pub(crate) socket_manager: Arc<TcpSocketManager>,
26    write_half_collect: WriteHalfCollect,
27    init_codec: Arc<Box<dyn InitCodec>>,
28}
29
30impl TcpTunnelDispatcher {
31    /// Construct a `TCP` tunnel with the specified configuration
32    pub fn new(config: TcpTunnelConfig) -> io::Result<TcpTunnelDispatcher> {
33        config.check()?;
34        let address: SocketAddr = if config.use_v6 {
35            format!("[::]:{}", config.tcp_port).parse().unwrap()
36        } else {
37            format!("0.0.0.0:{}", config.tcp_port).parse().unwrap()
38        };
39
40        let tcp_listener = create_tcp_listener(address)?;
41        let local_addr = tcp_listener.local_addr()?;
42        let tcp_listener = TcpListener::from_std(tcp_listener)?;
43        let (connect_sender, connect_receiver) = tachyonix::channel(128);
44        let write_half_collect = WriteHalfCollect::new(config.tcp_multiplexing_limit);
45        let init_codec = Arc::new(config.init_codec);
46        let socket_manager = Arc::new(TcpSocketManager::new(
47            local_addr,
48            config.tcp_multiplexing_limit,
49            write_half_collect.clone(),
50            connect_sender,
51            config.default_interface,
52            init_codec.clone(),
53        ));
54        Ok(TcpTunnelDispatcher {
55            route_idle_time: config.route_idle_time,
56            tcp_listener,
57            connect_receiver,
58            socket_manager,
59            write_half_collect,
60            init_codec,
61        })
62    }
63}
64
65impl TcpTunnelDispatcher {
66    /// Dispatch `TCP` tunnel from this kind Dispatcher
67    pub async fn dispatch(&mut self) -> io::Result<TcpTunnel> {
68        tokio::select! {
69            rs=self.connect_receiver.recv()=>{
70                let (route_key,read_half,sender) = rs.
71                    map_err(|_| io::Error::other("connect_receiver done"))?;
72                let local_addr = read_half.read_half.local_addr()?;
73                Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
74            },
75            rs=self.tcp_listener.accept()=>{
76                let (tcp_stream,addr) = rs?;
77                tcp_stream.set_nodelay(true)?;
78                let route_key = tcp_stream.route_key()?;
79                let (read_half,write_half) = tcp_stream.into_split();
80                let (decoder,encoder) = self.init_codec.codec(addr)?;
81                let read_half = ReadHalfBox::new(read_half,decoder);
82                let sender = self.write_half_collect.add_write_half(route_key,0, write_half,encoder);
83                let local_addr = read_half.read_half.local_addr()?;
84                Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
85            }
86        }
87    }
88    pub fn manager(&self) -> &Arc<TcpSocketManager> {
89        &self.socket_manager
90    }
91}
92
93pub struct TcpTunnel {
94    local_addr: SocketAddr,
95    route_key: RouteKey,
96    route_idle_time: Duration,
97    tcp_read: OwnedReadHalf,
98    decoder: Box<dyn Decoder>,
99    sender: Sender<Bytes>,
100}
101
102impl TcpTunnel {
103    pub(crate) fn new(
104        local_addr: SocketAddr,
105        route_idle_time: Duration,
106        route_key: RouteKey,
107        read: ReadHalfBox,
108        sender: Sender<Bytes>,
109    ) -> Self {
110        let decoder = read.decoder;
111        let tcp_read = read.read_half;
112        Self {
113            local_addr,
114            route_key,
115            route_idle_time,
116            tcp_read,
117            decoder,
118            sender,
119        }
120    }
121    #[inline]
122    pub fn route_key(&self) -> RouteKey {
123        self.route_key
124    }
125    pub fn local_addr(&self) -> SocketAddr {
126        self.local_addr
127    }
128    pub fn done(&mut self) {
129        self.sender.close();
130    }
131    pub fn sender(&self) -> io::Result<WeakTcpTunnelSender> {
132        Ok(WeakTcpTunnelSender::new(self.sender.clone()))
133    }
134}
135#[derive(Clone)]
136pub struct WeakTcpTunnelSender {
137    sender: Sender<Bytes>,
138}
139impl WeakTcpTunnelSender {
140    fn new(sender: Sender<Bytes>) -> Self {
141        Self { sender }
142    }
143    pub async fn send(&self, buf: Bytes) -> io::Result<()> {
144        if buf.is_empty() {
145            return Ok(());
146        }
147        self.sender
148            .send(buf)
149            .await
150            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
151    }
152}
153
154impl Drop for TcpTunnel {
155    fn drop(&mut self) {
156        self.done();
157    }
158}
159
160impl TcpTunnel {
161    /// Writing `buf` to the target denoted by `route_key` via this tunnel
162    pub async fn send(&self, buf: Bytes) -> io::Result<()> {
163        if buf.is_empty() {
164            return Ok(());
165        }
166        self.sender
167            .send(buf)
168            .await
169            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
170    }
171
172    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
173        match tokio::time::timeout(
174            self.route_idle_time,
175            self.decoder.decode(&mut self.tcp_read, buf),
176        )
177        .await
178        {
179            Ok(rs) => rs,
180            Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
181        }
182    }
183    pub async fn batch_recv<B: AsMut<[u8]>>(
184        &mut self,
185        bufs: &mut [B],
186        sizes: &mut [usize],
187    ) -> io::Result<usize> {
188        if bufs.is_empty() || bufs.len() != sizes.len() {
189            return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs error"));
190        }
191        match tokio::time::timeout(
192            self.route_idle_time,
193            self.decoder.decode(&mut self.tcp_read, bufs[0].as_mut()),
194        )
195        .await
196        {
197            Ok(rs) => {
198                let len = rs?;
199                sizes[0] = len;
200                let mut num = 1;
201                while num < bufs.len() {
202                    let rs = self
203                        .decoder
204                        .try_decode(&mut self.tcp_read, bufs[num].as_mut());
205                    match rs {
206                        Ok(len) => {
207                            sizes[num] = len;
208                            num += 1;
209                        }
210                        Err(_e) => break,
211                    }
212                }
213
214                Ok(num)
215            }
216            Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
217        }
218    }
219    /// Receive bytes from this tunnel, which the configured Decoder pre-processes
220    /// `usize` in the `Ok` branch indicates how many bytes are received
221    /// `RouteKey` in the `Ok` branch denotes the source where these bytes are received from
222    pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, RouteKey)> {
223        match self.recv(buf).await {
224            Ok(len) => Ok((len, self.route_key())),
225            Err(e) => Err(e),
226        }
227    }
228    pub async fn batch_recv_from<B: AsMut<[u8]>>(
229        &mut self,
230        bufs: &mut [B],
231        sizes: &mut [usize],
232    ) -> io::Result<(usize, RouteKey)> {
233        match self.batch_recv(bufs, sizes).await {
234            Ok(len) => Ok((len, self.route_key())),
235            Err(e) => Err(e),
236        }
237    }
238}
239
240#[derive(Clone)]
241pub struct WriteHalfCollect {
242    tcp_multiplexing_limit: usize,
243    addr_mapping: Arc<DashMap<SocketAddr, Vec<usize>>>,
244    write_half_map: Arc<DashMap<usize, Sender<Bytes>>>,
245}
246
247impl WriteHalfCollect {
248    fn new(tcp_multiplexing_limit: usize) -> Self {
249        Self {
250            tcp_multiplexing_limit,
251            addr_mapping: Default::default(),
252            write_half_map: Default::default(),
253        }
254    }
255}
256
257pub(crate) struct ReadHalfBox {
258    read_half: OwnedReadHalf,
259    decoder: Box<dyn Decoder>,
260}
261
262impl ReadHalfBox {
263    pub(crate) fn new(read_half: OwnedReadHalf, decoder: Box<dyn Decoder>) -> Self {
264        Self { read_half, decoder }
265    }
266}
267
268impl WriteHalfCollect {
269    pub(crate) fn add_write_half(
270        &self,
271        route_key: RouteKey,
272        index_offset: usize,
273        mut writer: OwnedWriteHalf,
274        mut decoder: Box<dyn Encoder>,
275    ) -> Sender<Bytes> {
276        assert!(index_offset < self.tcp_multiplexing_limit);
277
278        let index = route_key.index_usize();
279        let _ref = self
280            .addr_mapping
281            .entry(route_key.addr())
282            .and_modify(|v| {
283                v[index_offset] = index;
284            })
285            .or_insert_with(|| {
286                let mut v = vec![0; self.tcp_multiplexing_limit];
287                v[index_offset] = index;
288                v
289            });
290        let (s, mut r) = tachyonix::channel(128);
291        let sender = s.clone();
292        self.write_half_map.insert(index, s);
293        let collect = self.clone();
294        tokio::spawn(async move {
295            let mut vec_buf = Vec::with_capacity(16);
296            const IO_SLICE_CAPACITY: usize = 16;
297            let mut io_buffer: Vec<IoSlice> = Vec::with_capacity(IO_SLICE_CAPACITY);
298            let io_slice_storage = io_buffer.as_mut_slice();
299            while let Ok(v) = r.recv().await {
300                if let Ok(buf) = r.try_recv() {
301                    vec_buf.push(v);
302                    vec_buf.push(buf);
303                    while let Ok(buf) = r.try_recv() {
304                        vec_buf.push(buf);
305                        if vec_buf.len() == 16 {
306                            break;
307                        }
308                    }
309
310                    // Safety
311                    // reuse the storage of `io_buffer` via `vec` that only lives in this block and manually clear the content
312                    // within the storage when exiting the block
313                    // leak the memory storage after using `vec` since the storage is managed by `io_buffer`
314                    let rs = {
315                        let mut vec = unsafe {
316                            Vec::from_raw_parts(io_slice_storage.as_mut_ptr(), 0, IO_SLICE_CAPACITY)
317                        };
318                        for x in &vec_buf {
319                            vec.push(IoSlice::new(x));
320                        }
321                        let rs = decoder.encode_multiple(&mut writer, &vec).await;
322                        vec.clear();
323                        std::mem::forget(vec);
324                        rs
325                    };
326                    vec_buf.clear();
327                    if let Err(e) = rs {
328                        log::debug!("{route_key:?},{e:?}");
329                        break;
330                    }
331                } else {
332                    let rs = decoder.encode(&mut writer, &v).await;
333                    if let Err(e) = rs {
334                        log::debug!("{route_key:?},{e:?}");
335                        break;
336                    }
337                }
338            }
339            collect.remove(&route_key);
340        });
341        sender
342    }
343    pub(crate) fn remove(&self, route_key: &RouteKey) {
344        let index_usize = route_key.index_usize();
345        self.addr_mapping
346            .remove_if_mut(&route_key.addr(), |_k, index_vec| {
347                let mut remove = true;
348                for v in index_vec {
349                    if *v == index_usize {
350                        *v = 0;
351                    }
352                    if *v != 0 {
353                        remove = false;
354                    }
355                }
356                remove
357            });
358
359        self.write_half_map.remove(&index_usize);
360    }
361    pub(crate) fn get_write_half(&self, index: &usize) -> Option<Sender<Bytes>> {
362        self.write_half_map.get(index).map(|v| v.value().clone())
363    }
364    pub(crate) fn get_write_half_by_key(&self, route_key: &RouteKey) -> io::Result<Sender<Bytes>> {
365        match route_key.index() {
366            Index::Tcp(index) => {
367                let sender = self
368                    .get_write_half(&index)
369                    .ok_or_else(|| io::Error::other(format!("not found {route_key:?}")))?;
370                Ok(sender)
371            }
372            _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
373        }
374    }
375
376    pub(crate) fn get_one_route_key(&self, addr: &SocketAddr) -> Option<RouteKey> {
377        if let Some(v) = self.addr_mapping.get(addr) {
378            for index_usize in v.value() {
379                if *index_usize != 0 {
380                    return Some(RouteKey::new(Index::Tcp(*index_usize), *addr));
381                }
382            }
383        }
384        None
385    }
386    pub(crate) fn get_limit_route_key(&self, index: usize, addr: &SocketAddr) -> Option<RouteKey> {
387        if let Some(v) = self.addr_mapping.get(addr) {
388            assert_eq!(v.len(), self.tcp_multiplexing_limit);
389            let index_usize = v[index];
390            if index_usize == 0 {
391                return None;
392            }
393            return Some(RouteKey::new(Index::Tcp(index_usize), *addr));
394        }
395        None
396    }
397    pub async fn send_to(&self, buf: Bytes, route_key: &RouteKey) -> io::Result<()> {
398        let write_half = self.get_write_half_by_key(route_key)?;
399        if buf.is_empty() {
400            return Ok(());
401        }
402        if let Err(_e) = write_half.send(buf).await {
403            Err(io::Error::from(io::ErrorKind::WriteZero))
404        } else {
405            Ok(())
406        }
407    }
408    pub fn try_send_to(&self, buf: Bytes, route_key: &RouteKey) -> io::Result<()> {
409        let write_half = self.get_write_half_by_key(route_key)?;
410        if buf.is_empty() {
411            return Ok(());
412        }
413        if let Err(e) = write_half.try_send(buf) {
414            match e {
415                TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
416                TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
417            }
418        } else {
419            Ok(())
420        }
421    }
422}
423
424pub struct TcpSocketManager {
425    lock: DashMap<SocketAddr, Arc<Mutex<()>>>,
426    local_addr: SocketAddr,
427    tcp_multiplexing_limit: usize,
428    write_half_collect: WriteHalfCollect,
429    connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<Bytes>)>,
430    default_interface: Option<LocalInterface>,
431    init_codec: Arc<Box<dyn InitCodec>>,
432}
433
434impl TcpSocketManager {
435    pub(crate) fn new(
436        local_addr: SocketAddr,
437        tcp_multiplexing_limit: usize,
438        write_half_collect: WriteHalfCollect,
439        connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<Bytes>)>,
440        default_interface: Option<LocalInterface>,
441        init_codec: Arc<Box<dyn InitCodec>>,
442    ) -> Self {
443        Self {
444            local_addr,
445            lock: Default::default(),
446            tcp_multiplexing_limit,
447            write_half_collect,
448            connect_sender,
449            default_interface,
450            init_codec,
451        }
452    }
453    pub fn local_addr(&self) -> SocketAddr {
454        self.local_addr
455    }
456}
457
458impl TcpSocketManager {
459    /// Multiple connections can be initiated to the target address.
460    pub async fn multi_connect(
461        &self,
462        addr: SocketAddr,
463        index_offset: usize,
464    ) -> io::Result<RouteKey> {
465        self.multi_connect_impl(addr, index_offset, None).await
466    }
467    pub(crate) async fn multi_connect_impl(
468        &self,
469        addr: SocketAddr,
470        index_offset: usize,
471        ttl: Option<u8>,
472    ) -> io::Result<RouteKey> {
473        let len = self.tcp_multiplexing_limit;
474        if index_offset >= len {
475            return Err(io::Error::new(
476                io::ErrorKind::InvalidInput,
477                "index out of bounds",
478            ));
479        }
480        let lock = self
481            .lock
482            .entry(addr)
483            .or_insert_with(|| Arc::new(Mutex::new(())))
484            .value()
485            .clone();
486        let _guard = lock.lock().await;
487        if let Some(route_key) = self
488            .write_half_collect
489            .get_limit_route_key(index_offset, &addr)
490        {
491            return Ok(route_key);
492        }
493        self.connect_impl(0, addr, index_offset, ttl).await
494    }
495    /// Initiate a connection.
496    pub async fn connect(&self, addr: SocketAddr) -> io::Result<RouteKey> {
497        let lock = self
498            .lock
499            .entry(addr)
500            .or_insert_with(|| Arc::new(Mutex::new(())))
501            .value()
502            .clone();
503        let _guard = lock.lock().await;
504        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
505            return Ok(route_key);
506        }
507        self.connect_impl(0, addr, 0, None).await
508    }
509    pub async fn connect_ttl(&self, addr: SocketAddr, ttl: Option<u8>) -> io::Result<RouteKey> {
510        let lock = self
511            .lock
512            .entry(addr)
513            .or_insert_with(|| Arc::new(Mutex::new(())))
514            .value()
515            .clone();
516        let _guard = lock.lock().await;
517        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
518            return Ok(route_key);
519        }
520        self.connect_impl(0, addr, 0, ttl).await
521    }
522    /// Reuse the bound port to initiate a connection, which can be used to penetrate NAT1 network type.
523    pub async fn connect_reuse_port(&self, addr: SocketAddr) -> io::Result<RouteKey> {
524        let lock = self
525            .lock
526            .entry(addr)
527            .or_insert_with(|| Arc::new(Mutex::new(())))
528            .value()
529            .clone();
530        let _guard = lock.lock().await;
531        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
532            return Ok(route_key);
533        }
534        self.connect_impl(self.local_addr.port(), addr, 0, None)
535            .await
536    }
537    pub async fn connect_reuse_port_raw(&self, addr: SocketAddr) -> io::Result<TcpStream> {
538        let stream = connect_tcp(
539            addr,
540            self.local_addr.port(),
541            self.default_interface.as_ref(),
542            None,
543        )
544        .await?;
545        Ok(stream)
546    }
547    async fn connect_impl(
548        &self,
549        bind_port: u16,
550        addr: SocketAddr,
551        index_offset: usize,
552        ttl: Option<u8>,
553    ) -> io::Result<RouteKey> {
554        let stream = connect_tcp(addr, bind_port, self.default_interface.as_ref(), ttl).await?;
555        let route_key = stream.route_key()?;
556        let (read_half, write_half) = stream.into_split();
557        let (decoder, encoder) = self.init_codec.codec(addr)?;
558        let read_half = ReadHalfBox::new(read_half, decoder);
559        let sender =
560            self.write_half_collect
561                .add_write_half(route_key, index_offset, write_half, encoder);
562        if let Err(_e) = self
563            .connect_sender
564            .send((route_key, read_half, sender))
565            .await
566        {
567            Err(io::Error::new(
568                io::ErrorKind::UnexpectedEof,
569                "connect close",
570            ))?
571        }
572        Ok(route_key)
573    }
574}
575
576impl TcpSocketManager {
577    pub async fn multi_send_to<A: Into<SocketAddr>>(&self, buf: Bytes, addr: A) -> io::Result<()> {
578        self.multi_send_to_impl(buf, addr, None).await
579    }
580    pub(crate) async fn multi_send_to_impl<A: Into<SocketAddr>>(
581        &self,
582        buf: Bytes,
583        addr: A,
584        ttl: Option<u8>,
585    ) -> io::Result<()> {
586        let index_offset = rand::rng().random_range(0..self.tcp_multiplexing_limit);
587        let route_key = self
588            .multi_connect_impl(addr.into(), index_offset, ttl)
589            .await?;
590        self.send_to(buf, &route_key).await
591    }
592    /// Reuse the bound port to initiate a connection, which can be used to penetrate NAT1 network type.
593    pub async fn reuse_port_send_to<A: Into<SocketAddr>>(
594        &self,
595        buf: Bytes,
596        addr: A,
597    ) -> io::Result<()> {
598        let route_key = self.connect_reuse_port(addr.into()).await?;
599        self.send_to(buf, &route_key).await
600    }
601
602    /// Writing `buf` to the target denoted by `route_key`
603    pub async fn send_to<D: ToRouteKeyForTcp<()>>(&self, buf: Bytes, dest: D) -> io::Result<()> {
604        let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
605        self.write_half_collect.send_to(buf, &route_key).await
606    }
607    pub async fn send_to_addr<D: Into<SocketAddr>>(&self, buf: Bytes, dest: D) -> io::Result<()> {
608        let route_key = self.connect(dest.into()).await?;
609        self.write_half_collect.send_to(buf, &route_key).await
610    }
611    pub fn try_send_to<D: ToRouteKeyForTcp<()>>(&self, buf: Bytes, dest: D) -> io::Result<()> {
612        let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
613        self.write_half_collect.try_send_to(buf, &route_key)
614    }
615}
616pub trait ToRouteKeyForTcp<T> {
617    fn route_key(_: &TcpSocketManager, _: Self) -> io::Result<RouteKey>;
618}
619
620impl ToRouteKeyForTcp<()> for RouteKey {
621    fn route_key(_: &TcpSocketManager, dest: RouteKey) -> io::Result<RouteKey> {
622        Ok(dest)
623    }
624}
625
626impl ToRouteKeyForTcp<()> for &RouteKey {
627    fn route_key(_: &TcpSocketManager, dest: &RouteKey) -> io::Result<RouteKey> {
628        Ok(*dest)
629    }
630}
631
632impl ToRouteKeyForTcp<()> for &mut RouteKey {
633    fn route_key(_: &TcpSocketManager, dest: &mut RouteKey) -> io::Result<RouteKey> {
634        Ok(*dest)
635    }
636}
637// impl<S: Into<SocketAddr>> ToRouteKeyForTcp<()> for S {
638//     async fn route_key(socket_manager: &SocketManager, dest: Self) -> io::Result<RouteKey> {
639//         socket_manager.connect(dest.into()).await
640//     }
641// }
642
643pub trait TcpStreamIndex {
644    fn route_key(&self) -> io::Result<RouteKey>;
645    fn index(&self) -> Index;
646}
647
648impl TcpStreamIndex for TcpStream {
649    fn route_key(&self) -> io::Result<RouteKey> {
650        let addr = self.peer_addr()?;
651
652        Ok(RouteKey::new(self.index(), addr))
653    }
654
655    fn index(&self) -> Index {
656        #[cfg(windows)]
657        use std::os::windows::io::AsRawSocket;
658        #[cfg(windows)]
659        let index = self.as_raw_socket() as usize;
660        #[cfg(unix)]
661        use std::os::fd::{AsFd, AsRawFd};
662        #[cfg(unix)]
663        let index = self.as_fd().as_raw_fd() as usize;
664        Index::Tcp(index)
665    }
666}
667
/// The default pass-through codec.
///
/// Bytes are written and read unchanged, so using it is no different from using a plain,
/// reliable TCP byte stream (there is no message framing).
pub struct BytesCodec;

/// Framing codec that prefixes every message with its length as a 4-byte big-endian integer.
pub struct LengthPrefixedCodec;
673
674#[async_trait]
675impl Decoder for BytesCodec {
676    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
677        let len = read.read(src).await?;
678        Ok(len)
679    }
680}
681
682#[async_trait]
683impl Encoder for BytesCodec {
684    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
685        write.write_all(data).await?;
686        Ok(())
687    }
688}
689
690#[async_trait]
691impl Decoder for LengthPrefixedCodec {
692    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
693        let mut head = [0; 4];
694        read.read_exact(&mut head).await?;
695        let len = u32::from_be_bytes(head) as usize;
696        read.read_exact(&mut src[..len]).await?;
697        Ok(len)
698    }
699}
700
701#[async_trait]
702impl Encoder for LengthPrefixedCodec {
703    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
704        let head: [u8; 4] = (data.len() as u32).to_be_bytes();
705        write.write_all(&head).await?;
706        write.write_all(data).await?;
707        Ok(())
708    }
709}
710
711#[derive(Clone)]
712pub struct BytesInitCodec;
713
714impl InitCodec for BytesInitCodec {
715    fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
716        Ok((Box::new(BytesCodec), Box::new(BytesCodec)))
717    }
718}
719
720#[derive(Clone)]
721pub struct LengthPrefixedInitCodec;
722
723impl InitCodec for LengthPrefixedInitCodec {
724    fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
725        Ok((Box::new(LengthPrefixedCodec), Box::new(LengthPrefixedCodec)))
726    }
727}
728
729pub trait InitCodec: Send + Sync + DynClone {
730    fn codec(&self, addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)>;
731}
732dyn_clone::clone_trait_object!(InitCodec);
733
734#[async_trait]
735pub trait Decoder: Send + Sync {
736    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize>;
737    fn try_decode(&mut self, _read: &mut OwnedReadHalf, _src: &mut [u8]) -> io::Result<usize> {
738        Err(io::Error::from(io::ErrorKind::WouldBlock))
739    }
740}
741
742#[async_trait]
743pub trait Encoder: Send + Sync {
744    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()>;
745    async fn encode_multiple(
746        &mut self,
747        write: &mut OwnedWriteHalf,
748        bufs: &[IoSlice<'_>],
749    ) -> io::Result<()> {
750        for buf in bufs {
751            self.encode(write, buf).await?
752        }
753        Ok(())
754    }
755}
756
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use std::io;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

    use crate::tunnel::config::TcpTunnelConfig;
    use crate::tunnel::tcp::{Decoder, Encoder, InitCodec, TcpTunnelDispatcher};

    /// A dispatcher can be built from the default configuration.
    #[tokio::test]
    pub async fn create_tcp_tunnel() {
        let config: TcpTunnelConfig = TcpTunnelConfig::default();
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// A dispatcher can be built with a custom codec factory.
    #[tokio::test]
    pub async fn create_codec_tcp_tunnel() {
        let config = TcpTunnelConfig::new(Box::new(MyInitCodeC));
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    #[derive(Clone)]
    struct MyInitCodeC;

    impl InitCodec for MyInitCodeC {
        fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
            Ok((Box::new(MyCodeC), Box::new(MyCodeC)))
        }
    }

    /// Test codec framing each message with a 2-byte big-endian length header.
    struct MyCodeC;

    #[async_trait]
    impl Decoder for MyCodeC {
        async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
            let mut head = [0; 2];
            read.read_exact(&mut head).await?;
            let len = u16::from_be_bytes(head) as usize;
            // The header is peer-controlled; reject it instead of panicking on the slice.
            if len > src.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("frame length {len} exceeds buffer size {}", src.len()),
                ));
            }
            read.read_exact(&mut src[..len]).await?;
            Ok(len)
        }
    }

    #[async_trait]
    impl Encoder for MyCodeC {
        async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
            // `as u16` would silently truncate the header for payloads above 64 KiB.
            if data.len() > u16::MAX as usize {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "frame too large for a 2-byte length prefix",
                ));
            }
            let head: [u8; 2] = (data.len() as u16).to_be_bytes();
            write.write_all(&head).await?;
            write.write_all(data).await?;
            Ok(())
        }
    }
}