rust_p2p_core/tunnel/
tcp.rs

1use crate::route::{Index, RouteKey};
2use crate::socket::{connect_tcp, create_tcp_listener, LocalInterface};
3use crate::tunnel::config::TcpTunnelConfig;
4use crate::tunnel::recycle::RecycleBuf;
5use async_lock::Mutex;
6use async_trait::async_trait;
7use bytes::BytesMut;
8use dashmap::DashMap;
9use dyn_clone::DynClone;
10use rand::Rng;
11use std::io;
12use std::io::IoSlice;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16use tachyonix::{Receiver, Sender, TrySendError};
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
19use tokio::net::{TcpListener, TcpStream};
20
21pub struct TcpTunnelDispatcher {
22    route_idle_time: Duration,
23    tcp_listener: TcpListener,
24    connect_receiver: Receiver<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
25    #[allow(dead_code)]
26    pub(crate) socket_manager: Arc<TcpSocketManager>,
27    write_half_collect: WriteHalfCollect,
28    init_codec: Arc<Box<dyn InitCodec>>,
29}
30
31impl TcpTunnelDispatcher {
32    /// Construct a `TCP` tunnel with the specified configuration
33    pub fn new(config: TcpTunnelConfig) -> io::Result<TcpTunnelDispatcher> {
34        config.check()?;
35        let address: SocketAddr = if config.use_v6 {
36            format!("[::]:{}", config.tcp_port).parse().unwrap()
37        } else {
38            format!("0.0.0.0:{}", config.tcp_port).parse().unwrap()
39        };
40
41        let tcp_listener = create_tcp_listener(address)?;
42        let local_addr = tcp_listener.local_addr()?;
43        let tcp_listener = TcpListener::from_std(tcp_listener)?;
44        let (connect_sender, connect_receiver) = tachyonix::channel(128);
45        let write_half_collect =
46            WriteHalfCollect::new(config.tcp_multiplexing_limit, config.recycle_buf);
47        let init_codec = Arc::new(config.init_codec);
48        let socket_manager = Arc::new(TcpSocketManager::new(
49            local_addr,
50            config.tcp_multiplexing_limit,
51            write_half_collect.clone(),
52            connect_sender,
53            config.default_interface,
54            init_codec.clone(),
55        ));
56        Ok(TcpTunnelDispatcher {
57            route_idle_time: config.route_idle_time,
58            tcp_listener,
59            connect_receiver,
60            socket_manager,
61            write_half_collect,
62            init_codec,
63        })
64    }
65}
66
67impl TcpTunnelDispatcher {
68    /// Dispatch `TCP` tunnel from this kind Dispatcher
69    pub async fn dispatch(&mut self) -> io::Result<TcpTunnel> {
70        tokio::select! {
71            rs=self.connect_receiver.recv()=>{
72                let (route_key,read_half,sender) = rs.
73                    map_err(|_| io::Error::other("connect_receiver done"))?;
74                let local_addr = read_half.read_half.local_addr()?;
75                Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
76            },
77            rs=self.tcp_listener.accept()=>{
78                let (tcp_stream,addr) = rs?;
79                tcp_stream.set_nodelay(true)?;
80                let route_key = tcp_stream.route_key()?;
81                let (read_half,write_half) = tcp_stream.into_split();
82                let (decoder,encoder) = self.init_codec.codec(addr)?;
83                let read_half = ReadHalfBox::new(read_half,decoder);
84                let sender = self.write_half_collect.add_write_half(route_key,0, write_half,encoder);
85                let local_addr = read_half.read_half.local_addr()?;
86                Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
87            }
88        }
89    }
90    pub fn manager(&self) -> &Arc<TcpSocketManager> {
91        &self.socket_manager
92    }
93}
94
95pub struct TcpTunnel {
96    local_addr: SocketAddr,
97    route_key: RouteKey,
98    route_idle_time: Duration,
99    tcp_read: OwnedReadHalf,
100    decoder: Box<dyn Decoder>,
101    sender: Sender<BytesMut>,
102}
103
104impl TcpTunnel {
105    pub(crate) fn new(
106        local_addr: SocketAddr,
107        route_idle_time: Duration,
108        route_key: RouteKey,
109        read: ReadHalfBox,
110        sender: Sender<BytesMut>,
111    ) -> Self {
112        let decoder = read.decoder;
113        let tcp_read = read.read_half;
114        Self {
115            local_addr,
116            route_key,
117            route_idle_time,
118            tcp_read,
119            decoder,
120            sender,
121        }
122    }
123    #[inline]
124    pub fn route_key(&self) -> RouteKey {
125        self.route_key
126    }
127    pub fn local_addr(&self) -> SocketAddr {
128        self.local_addr
129    }
130    pub fn done(&mut self) {
131        self.sender.close();
132    }
133    pub fn sender(&self) -> io::Result<WeakTcpTunnelSender> {
134        Ok(WeakTcpTunnelSender::new(self.sender.clone()))
135    }
136}
137#[derive(Clone)]
138pub struct WeakTcpTunnelSender {
139    sender: Sender<BytesMut>,
140}
141impl WeakTcpTunnelSender {
142    fn new(sender: Sender<BytesMut>) -> Self {
143        Self { sender }
144    }
145    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
146        if buf.is_empty() {
147            return Ok(());
148        }
149        self.sender
150            .send(buf)
151            .await
152            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
153    }
154}
155
156impl Drop for TcpTunnel {
157    fn drop(&mut self) {
158        self.done();
159    }
160}
161
162impl TcpTunnel {
163    /// Writing `buf` to the target denoted by `route_key` via this tunnel
164    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
165        if buf.is_empty() {
166            return Ok(());
167        }
168        self.sender
169            .send(buf)
170            .await
171            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
172    }
173
174    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175        match tokio::time::timeout(
176            self.route_idle_time,
177            self.decoder.decode(&mut self.tcp_read, buf),
178        )
179        .await
180        {
181            Ok(rs) => rs,
182            Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
183        }
184    }
185    pub async fn batch_recv<B: AsMut<[u8]>>(
186        &mut self,
187        bufs: &mut [B],
188        sizes: &mut [usize],
189    ) -> io::Result<usize> {
190        if bufs.is_empty() || bufs.len() != sizes.len() {
191            return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs error"));
192        }
193        match tokio::time::timeout(
194            self.route_idle_time,
195            self.decoder.decode(&mut self.tcp_read, bufs[0].as_mut()),
196        )
197        .await
198        {
199            Ok(rs) => {
200                let len = rs?;
201                sizes[0] = len;
202                let mut num = 1;
203                while num < bufs.len() {
204                    let rs = self
205                        .decoder
206                        .try_decode(&mut self.tcp_read, bufs[num].as_mut());
207                    match rs {
208                        Ok(len) => {
209                            sizes[num] = len;
210                            num += 1;
211                        }
212                        Err(_e) => break,
213                    }
214                }
215
216                Ok(num)
217            }
218            Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
219        }
220    }
221    /// Receive bytes from this tunnel, which the configured Decoder pre-processes
222    /// `usize` in the `Ok` branch indicates how many bytes are received
223    /// `RouteKey` in the `Ok` branch denotes the source where these bytes are received from
224    pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, RouteKey)> {
225        match self.recv(buf).await {
226            Ok(len) => Ok((len, self.route_key())),
227            Err(e) => Err(e),
228        }
229    }
230    pub async fn batch_recv_from<B: AsMut<[u8]>>(
231        &mut self,
232        bufs: &mut [B],
233        sizes: &mut [usize],
234    ) -> io::Result<(usize, RouteKey)> {
235        match self.batch_recv(bufs, sizes).await {
236            Ok(len) => Ok((len, self.route_key())),
237            Err(e) => Err(e),
238        }
239    }
240}
241
242#[derive(Clone)]
243pub struct WriteHalfCollect {
244    tcp_multiplexing_limit: usize,
245    addr_mapping: Arc<DashMap<SocketAddr, Vec<usize>>>,
246    write_half_map: Arc<DashMap<usize, Sender<BytesMut>>>,
247    recycle_buf: Option<RecycleBuf>,
248}
249
250impl WriteHalfCollect {
251    fn new(tcp_multiplexing_limit: usize, recycle_buf: Option<RecycleBuf>) -> Self {
252        Self {
253            tcp_multiplexing_limit,
254            addr_mapping: Default::default(),
255            write_half_map: Default::default(),
256            recycle_buf,
257        }
258    }
259}
260
261pub(crate) struct ReadHalfBox {
262    read_half: OwnedReadHalf,
263    decoder: Box<dyn Decoder>,
264}
265
266impl ReadHalfBox {
267    pub(crate) fn new(read_half: OwnedReadHalf, decoder: Box<dyn Decoder>) -> Self {
268        Self { read_half, decoder }
269    }
270}
271
272impl WriteHalfCollect {
273    pub(crate) fn add_write_half(
274        &self,
275        route_key: RouteKey,
276        index_offset: usize,
277        mut writer: OwnedWriteHalf,
278        mut decoder: Box<dyn Encoder>,
279    ) -> Sender<BytesMut> {
280        assert!(index_offset < self.tcp_multiplexing_limit);
281
282        let index = route_key.index_usize();
283        let _ref = self
284            .addr_mapping
285            .entry(route_key.addr())
286            .and_modify(|v| {
287                v[index_offset] = index;
288            })
289            .or_insert_with(|| {
290                let mut v = vec![0; self.tcp_multiplexing_limit];
291                v[index_offset] = index;
292                v
293            });
294        let (s, mut r) = tachyonix::channel(128);
295        let sender = s.clone();
296        self.write_half_map.insert(index, s);
297        let collect = self.clone();
298        let recycle_buf = self.recycle_buf.clone();
299        tokio::spawn(async move {
300            let mut vec_buf = Vec::with_capacity(16);
301            const IO_SLICE_CAPACITY: usize = 16;
302            let mut io_buffer: Vec<IoSlice> = Vec::with_capacity(IO_SLICE_CAPACITY);
303            let io_slice_storage = io_buffer.as_mut_slice();
304            while let Ok(v) = r.recv().await {
305                if let Ok(buf) = r.try_recv() {
306                    vec_buf.push(v);
307                    vec_buf.push(buf);
308                    while let Ok(buf) = r.try_recv() {
309                        vec_buf.push(buf);
310                        if vec_buf.len() == 16 {
311                            break;
312                        }
313                    }
314
315                    // Safety
316                    // reuse the storage of `io_buffer` via `vec` that only lives in this block and manually clear the content
317                    // within the storage when exiting the block
318                    // leak the memory storage after using `vec` since the storage is managed by `io_buffer`
319                    let rs = {
320                        let mut vec = unsafe {
321                            Vec::from_raw_parts(io_slice_storage.as_mut_ptr(), 0, IO_SLICE_CAPACITY)
322                        };
323                        for x in &vec_buf {
324                            vec.push(IoSlice::new(x));
325                        }
326                        let rs = decoder.encode_multiple(&mut writer, &vec).await;
327                        vec.clear();
328                        std::mem::forget(vec);
329                        rs
330                    };
331
332                    if let Some(recycle_buf) = recycle_buf.as_ref() {
333                        while let Some(buf) = vec_buf.pop() {
334                            recycle_buf.push(buf);
335                        }
336                    } else {
337                        vec_buf.clear()
338                    }
339                    if let Err(e) = rs {
340                        log::debug!("{route_key:?},{e:?}");
341                        break;
342                    }
343                } else {
344                    let rs = decoder.encode(&mut writer, &v).await;
345                    if let Some(recycle_buf) = recycle_buf.as_ref() {
346                        recycle_buf.push(v);
347                    }
348                    if let Err(e) = rs {
349                        log::debug!("{route_key:?},{e:?}");
350                        break;
351                    }
352                }
353            }
354            collect.remove(&route_key);
355        });
356        sender
357    }
358    pub(crate) fn remove(&self, route_key: &RouteKey) {
359        let index_usize = route_key.index_usize();
360        self.addr_mapping
361            .remove_if_mut(&route_key.addr(), |_k, index_vec| {
362                let mut remove = true;
363                for v in index_vec {
364                    if *v == index_usize {
365                        *v = 0;
366                    }
367                    if *v != 0 {
368                        remove = false;
369                    }
370                }
371                remove
372            });
373
374        self.write_half_map.remove(&index_usize);
375    }
376    pub(crate) fn get_write_half(&self, index: &usize) -> Option<Sender<BytesMut>> {
377        self.write_half_map.get(index).map(|v| v.value().clone())
378    }
379    pub(crate) fn get_write_half_by_key(
380        &self,
381        route_key: &RouteKey,
382    ) -> io::Result<Sender<BytesMut>> {
383        match route_key.index() {
384            Index::Tcp(index) => {
385                let sender = self
386                    .get_write_half(&index)
387                    .ok_or_else(|| io::Error::other(format!("not found {route_key:?}")))?;
388                Ok(sender)
389            }
390            _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
391        }
392    }
393
394    pub(crate) fn get_one_route_key(&self, addr: &SocketAddr) -> Option<RouteKey> {
395        if let Some(v) = self.addr_mapping.get(addr) {
396            for index_usize in v.value() {
397                if *index_usize != 0 {
398                    return Some(RouteKey::new(Index::Tcp(*index_usize), *addr));
399                }
400            }
401        }
402        None
403    }
404    pub(crate) fn get_limit_route_key(&self, index: usize, addr: &SocketAddr) -> Option<RouteKey> {
405        if let Some(v) = self.addr_mapping.get(addr) {
406            assert_eq!(v.len(), self.tcp_multiplexing_limit);
407            let index_usize = v[index];
408            if index_usize == 0 {
409                return None;
410            }
411            return Some(RouteKey::new(Index::Tcp(index_usize), *addr));
412        }
413        None
414    }
415    pub async fn send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
416        let write_half = self.get_write_half_by_key(route_key)?;
417        if buf.is_empty() {
418            return Ok(());
419        }
420        if let Err(_e) = write_half.send(buf).await {
421            Err(io::Error::from(io::ErrorKind::WriteZero))
422        } else {
423            Ok(())
424        }
425    }
426    pub fn try_send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
427        let write_half = self.get_write_half_by_key(route_key)?;
428        if buf.is_empty() {
429            return Ok(());
430        }
431        if let Err(e) = write_half.try_send(buf) {
432            match e {
433                TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
434                TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
435            }
436        } else {
437            Ok(())
438        }
439    }
440}
441
442pub struct TcpSocketManager {
443    lock: Mutex<()>,
444    local_addr: SocketAddr,
445    tcp_multiplexing_limit: usize,
446    write_half_collect: WriteHalfCollect,
447    connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
448    default_interface: Option<LocalInterface>,
449    init_codec: Arc<Box<dyn InitCodec>>,
450}
451
452impl TcpSocketManager {
453    pub(crate) fn new(
454        local_addr: SocketAddr,
455        tcp_multiplexing_limit: usize,
456        write_half_collect: WriteHalfCollect,
457        connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
458        default_interface: Option<LocalInterface>,
459        init_codec: Arc<Box<dyn InitCodec>>,
460    ) -> Self {
461        Self {
462            local_addr,
463            lock: Default::default(),
464            tcp_multiplexing_limit,
465            write_half_collect,
466            connect_sender,
467            default_interface,
468            init_codec,
469        }
470    }
471    pub fn local_addr(&self) -> SocketAddr {
472        self.local_addr
473    }
474}
475
476impl TcpSocketManager {
477    /// Multiple connections can be initiated to the target address.
478    pub async fn multi_connect(
479        &self,
480        addr: SocketAddr,
481        index_offset: usize,
482    ) -> io::Result<RouteKey> {
483        self.multi_connect_impl(addr, index_offset, None).await
484    }
485    pub(crate) async fn multi_connect_impl(
486        &self,
487        addr: SocketAddr,
488        index_offset: usize,
489        ttl: Option<u8>,
490    ) -> io::Result<RouteKey> {
491        let len = self.tcp_multiplexing_limit;
492        if index_offset >= len {
493            return Err(io::Error::new(
494                io::ErrorKind::InvalidInput,
495                "index out of bounds",
496            ));
497        }
498        let _guard = self.lock.lock().await;
499        if let Some(route_key) = self
500            .write_half_collect
501            .get_limit_route_key(index_offset, &addr)
502        {
503            return Ok(route_key);
504        }
505        self.connect_impl(0, addr, index_offset, ttl).await
506    }
507    /// Initiate a connection.
508    pub async fn connect(&self, addr: SocketAddr) -> io::Result<RouteKey> {
509        let _guard = self.lock.lock().await;
510        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
511            return Ok(route_key);
512        }
513        self.connect_impl(0, addr, 0, None).await
514    }
515    pub async fn connect_ttl(&self, addr: SocketAddr, ttl: Option<u8>) -> io::Result<RouteKey> {
516        let _guard = self.lock.lock().await;
517        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
518            return Ok(route_key);
519        }
520        self.connect_impl(0, addr, 0, ttl).await
521    }
522    /// Reuse the bound port to initiate a connection, which can be used to penetrate NAT1 network type.
523    pub async fn connect_reuse_port(&self, addr: SocketAddr) -> io::Result<RouteKey> {
524        let _guard = self.lock.lock().await;
525        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
526            return Ok(route_key);
527        }
528        self.connect_impl(self.local_addr.port(), addr, 0, None)
529            .await
530    }
531    pub async fn connect_reuse_port_raw(&self, addr: SocketAddr) -> io::Result<TcpStream> {
532        let stream = connect_tcp(
533            addr,
534            self.local_addr.port(),
535            self.default_interface.as_ref(),
536            None,
537        )
538        .await?;
539        Ok(stream)
540    }
541    async fn connect_impl(
542        &self,
543        bind_port: u16,
544        addr: SocketAddr,
545        index_offset: usize,
546        ttl: Option<u8>,
547    ) -> io::Result<RouteKey> {
548        let stream = connect_tcp(addr, bind_port, self.default_interface.as_ref(), ttl).await?;
549        let route_key = stream.route_key()?;
550        let (read_half, write_half) = stream.into_split();
551        let (decoder, encoder) = self.init_codec.codec(addr)?;
552        let read_half = ReadHalfBox::new(read_half, decoder);
553        let sender =
554            self.write_half_collect
555                .add_write_half(route_key, index_offset, write_half, encoder);
556        if let Err(_e) = self
557            .connect_sender
558            .send((route_key, read_half, sender))
559            .await
560        {
561            Err(io::Error::new(
562                io::ErrorKind::UnexpectedEof,
563                "connect close",
564            ))?
565        }
566        Ok(route_key)
567    }
568}
569
570impl TcpSocketManager {
571    pub async fn multi_send_to<A: Into<SocketAddr>>(
572        &self,
573        buf: BytesMut,
574        addr: A,
575    ) -> io::Result<()> {
576        self.multi_send_to_impl(buf, addr, None).await
577    }
578    pub(crate) async fn multi_send_to_impl<A: Into<SocketAddr>>(
579        &self,
580        buf: BytesMut,
581        addr: A,
582        ttl: Option<u8>,
583    ) -> io::Result<()> {
584        let index_offset = rand::rng().random_range(0..self.tcp_multiplexing_limit);
585        let route_key = self
586            .multi_connect_impl(addr.into(), index_offset, ttl)
587            .await?;
588        self.send_to(buf, &route_key).await
589    }
590    /// Reuse the bound port to initiate a connection, which can be used to penetrate NAT1 network type.
591    pub async fn reuse_port_send_to<A: Into<SocketAddr>>(
592        &self,
593        buf: BytesMut,
594        addr: A,
595    ) -> io::Result<()> {
596        let route_key = self.connect_reuse_port(addr.into()).await?;
597        self.send_to(buf, &route_key).await
598    }
599
600    /// Writing `buf` to the target denoted by `route_key`
601    pub async fn send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
602        let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
603        self.write_half_collect.send_to(buf, &route_key).await
604    }
605    pub async fn send_to_addr<D: Into<SocketAddr>>(
606        &self,
607        buf: BytesMut,
608        dest: D,
609    ) -> io::Result<()> {
610        let route_key = self.connect(dest.into()).await?;
611        self.write_half_collect.send_to(buf, &route_key).await
612    }
613    pub fn try_send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
614        let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
615        self.write_half_collect.try_send_to(buf, &route_key)
616    }
617}
618pub trait ToRouteKeyForTcp<T> {
619    fn route_key(_: &TcpSocketManager, _: Self) -> io::Result<RouteKey>;
620}
621
622impl ToRouteKeyForTcp<()> for RouteKey {
623    fn route_key(_: &TcpSocketManager, dest: RouteKey) -> io::Result<RouteKey> {
624        Ok(dest)
625    }
626}
627
628impl ToRouteKeyForTcp<()> for &RouteKey {
629    fn route_key(_: &TcpSocketManager, dest: &RouteKey) -> io::Result<RouteKey> {
630        Ok(*dest)
631    }
632}
633
634impl ToRouteKeyForTcp<()> for &mut RouteKey {
635    fn route_key(_: &TcpSocketManager, dest: &mut RouteKey) -> io::Result<RouteKey> {
636        Ok(*dest)
637    }
638}
639// impl<S: Into<SocketAddr>> ToRouteKeyForTcp<()> for S {
640//     async fn route_key(socket_manager: &SocketManager, dest: Self) -> io::Result<RouteKey> {
641//         socket_manager.connect(dest.into()).await
642//     }
643// }
644
645pub trait TcpStreamIndex {
646    fn route_key(&self) -> io::Result<RouteKey>;
647    fn index(&self) -> Index;
648}
649
650impl TcpStreamIndex for TcpStream {
651    fn route_key(&self) -> io::Result<RouteKey> {
652        let addr = self.peer_addr()?;
653
654        Ok(RouteKey::new(self.index(), addr))
655    }
656
657    fn index(&self) -> Index {
658        #[cfg(windows)]
659        use std::os::windows::io::AsRawSocket;
660        #[cfg(windows)]
661        let index = self.as_raw_socket() as usize;
662        #[cfg(unix)]
663        use std::os::fd::{AsFd, AsRawFd};
664        #[cfg(unix)]
665        let index = self.as_fd().as_raw_fd() as usize;
666        Index::Tcp(index)
667    }
668}
669
/// Default pass-through codec: bytes are read and written unchanged, exactly as with the
/// raw TCP stream (no framing, so message boundaries are not preserved).
pub struct BytesCodec;

/// Framing codec: each message is preceded by a 4-byte big-endian length header.
pub struct LengthPrefixedCodec;
675
676#[async_trait]
677impl Decoder for BytesCodec {
678    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
679        let len = read.read(src).await?;
680        Ok(len)
681    }
682}
683
684#[async_trait]
685impl Encoder for BytesCodec {
686    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
687        write.write_all(data).await?;
688        Ok(())
689    }
690}
691
692#[async_trait]
693impl Decoder for LengthPrefixedCodec {
694    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
695        let mut head = [0; 4];
696        read.read_exact(&mut head).await?;
697        let len = u32::from_be_bytes(head) as usize;
698        read.read_exact(&mut src[..len]).await?;
699        Ok(len)
700    }
701}
702
703#[async_trait]
704impl Encoder for LengthPrefixedCodec {
705    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
706        let head: [u8; 4] = (data.len() as u32).to_be_bytes();
707        write.write_all(&head).await?;
708        write.write_all(data).await?;
709        Ok(())
710    }
711}
712
713#[derive(Clone)]
714pub struct BytesInitCodec;
715
716impl InitCodec for BytesInitCodec {
717    fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
718        Ok((Box::new(BytesCodec), Box::new(BytesCodec)))
719    }
720}
721
722#[derive(Clone)]
723pub struct LengthPrefixedInitCodec;
724
725impl InitCodec for LengthPrefixedInitCodec {
726    fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
727        Ok((Box::new(LengthPrefixedCodec), Box::new(LengthPrefixedCodec)))
728    }
729}
730
731pub trait InitCodec: Send + Sync + DynClone {
732    fn codec(&self, addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)>;
733}
734dyn_clone::clone_trait_object!(InitCodec);
735
736#[async_trait]
737pub trait Decoder: Send + Sync {
738    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize>;
739    fn try_decode(&mut self, _read: &mut OwnedReadHalf, _src: &mut [u8]) -> io::Result<usize> {
740        Err(io::Error::from(io::ErrorKind::WouldBlock))
741    }
742}
743
744#[async_trait]
745pub trait Encoder: Send + Sync {
746    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()>;
747    async fn encode_multiple(
748        &mut self,
749        write: &mut OwnedWriteHalf,
750        bufs: &[IoSlice<'_>],
751    ) -> io::Result<()> {
752        for buf in bufs {
753            self.encode(write, buf).await?
754        }
755        Ok(())
756    }
757}
758
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use std::io;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

    use crate::tunnel::config::TcpTunnelConfig;
    use crate::tunnel::tcp::{Decoder, Encoder, InitCodec, TcpTunnelDispatcher};

    /// A dispatcher can be built (listener bound) with the default configuration.
    #[tokio::test]
    pub async fn create_tcp_tunnel() {
        let config: TcpTunnelConfig = TcpTunnelConfig::default();
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// A dispatcher can be built with a custom codec factory.
    #[tokio::test]
    pub async fn create_codec_tcp_tunnel() {
        let config = TcpTunnelConfig::new(Box::new(MyInitCodeC));
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// Example codec factory handing out a `MyCodeC` per connection.
    #[derive(Clone)]
    struct MyInitCodeC;

    impl InitCodec for MyInitCodeC {
        fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
            Ok((Box::new(MyCodeC), Box::new(MyCodeC)))
        }
    }

    /// Example framing: a 2-byte big-endian length header followed by the payload.
    struct MyCodeC;

    #[async_trait]
    impl Decoder for MyCodeC {
        /// Reads one frame into `src`. An oversized header yields `InvalidData` instead of
        /// panicking on the slice index.
        async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
            let mut head = [0; 2];
            read.read_exact(&mut head).await?;
            let len = u16::from_be_bytes(head) as usize;
            let capacity = src.len();
            let Some(payload) = src.get_mut(..len) else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("frame length {len} exceeds buffer capacity {capacity}"),
                ));
            };
            read.read_exact(payload).await?;
            Ok(len)
        }
    }

    #[async_trait]
    impl Encoder for MyCodeC {
        /// Writes a 2-byte length header and `data`. Payloads longer than `u16::MAX` are
        /// rejected instead of having their length silently truncated.
        async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
            if data.len() > u16::MAX as usize {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "frame of {} bytes does not fit the 2-byte length prefix",
                        data.len()
                    ),
                ));
            }
            let head: [u8; 2] = (data.len() as u16).to_be_bytes();
            write.write_all(&head).await?;
            write.write_all(data).await?;
            Ok(())
        }
    }
}