rust_p2p_core/tunnel/
tcp.rs

1use crate::route::{Index, RouteKey};
2use crate::socket::{connect_tcp, create_tcp_listener, LocalInterface};
3use crate::tunnel::config::TcpTunnelConfig;
4use crate::tunnel::recycle::RecycleBuf;
5use async_lock::Mutex;
6use async_trait::async_trait;
7use bytes::BytesMut;
8use dashmap::DashMap;
9use dyn_clone::DynClone;
10use rand::Rng;
11use std::io;
12use std::io::IoSlice;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16use tachyonix::{Receiver, Sender, TrySendError};
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
19use tokio::net::{TcpListener, TcpStream};
20
21pub struct TcpTunnelDispatcher {
22    route_idle_time: Duration,
23    tcp_listener: TcpListener,
24    connect_receiver: Receiver<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
25    #[allow(dead_code)]
26    pub(crate) socket_manager: Arc<TcpSocketManager>,
27    write_half_collect: WriteHalfCollect,
28    init_codec: Arc<Box<dyn InitCodec>>,
29}
30
31impl TcpTunnelDispatcher {
32    /// Construct a `TCP` tunnel with the specified configuration
33    pub fn new(config: TcpTunnelConfig) -> io::Result<TcpTunnelDispatcher> {
34        config.check()?;
35        let address: SocketAddr = if config.use_v6 {
36            format!("[::]:{}", config.tcp_port).parse().unwrap()
37        } else {
38            format!("0.0.0.0:{}", config.tcp_port).parse().unwrap()
39        };
40
41        let tcp_listener = create_tcp_listener(address)?;
42        let local_addr = tcp_listener.local_addr()?;
43        let tcp_listener = TcpListener::from_std(tcp_listener)?;
44        let (connect_sender, connect_receiver) = tachyonix::channel(128);
45        let write_half_collect =
46            WriteHalfCollect::new(config.tcp_multiplexing_limit, config.recycle_buf);
47        let init_codec = Arc::new(config.init_codec);
48        let socket_manager = Arc::new(TcpSocketManager::new(
49            local_addr,
50            config.tcp_multiplexing_limit,
51            write_half_collect.clone(),
52            connect_sender,
53            config.default_interface,
54            init_codec.clone(),
55        ));
56        Ok(TcpTunnelDispatcher {
57            route_idle_time: config.route_idle_time,
58            tcp_listener,
59            connect_receiver,
60            socket_manager,
61            write_half_collect,
62            init_codec,
63        })
64    }
65}
66
67impl TcpTunnelDispatcher {
68    /// Dispatch `TCP` tunnel from this kind Dispatcher
69    pub async fn dispatch(&mut self) -> io::Result<TcpTunnel> {
70        tokio::select! {
71            rs=self.connect_receiver.recv()=>{
72                let (route_key,read_half,sender) = rs.
73                    map_err(|_| io::Error::other("connect_receiver done"))?;
74                let local_addr = read_half.read_half.local_addr()?;
75                Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
76            },
77            rs=self.tcp_listener.accept()=>{
78                let (tcp_stream,addr) = rs?;
79                tcp_stream.set_nodelay(true)?;
80                let route_key = tcp_stream.route_key()?;
81                let (read_half,write_half) = tcp_stream.into_split();
82                let (decoder,encoder) = self.init_codec.codec(addr)?;
83                let read_half = ReadHalfBox::new(read_half,decoder);
84                let sender = self.write_half_collect.add_write_half(route_key,0, write_half,encoder);
85                let local_addr = read_half.read_half.local_addr()?;
86                Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
87            }
88        }
89    }
90    pub fn manager(&self) -> &Arc<TcpSocketManager> {
91        &self.socket_manager
92    }
93}
94
95pub struct TcpTunnel {
96    local_addr: SocketAddr,
97    route_key: RouteKey,
98    route_idle_time: Duration,
99    tcp_read: OwnedReadHalf,
100    decoder: Box<dyn Decoder>,
101    sender: Sender<BytesMut>,
102}
103
104impl TcpTunnel {
105    pub(crate) fn new(
106        local_addr: SocketAddr,
107        route_idle_time: Duration,
108        route_key: RouteKey,
109        read: ReadHalfBox,
110        sender: Sender<BytesMut>,
111    ) -> Self {
112        let decoder = read.decoder;
113        let tcp_read = read.read_half;
114        Self {
115            local_addr,
116            route_key,
117            route_idle_time,
118            tcp_read,
119            decoder,
120            sender,
121        }
122    }
123    #[inline]
124    pub fn route_key(&self) -> RouteKey {
125        self.route_key
126    }
127    pub fn local_addr(&self) -> SocketAddr {
128        self.local_addr
129    }
130    pub fn done(&mut self) {
131        self.sender.close();
132    }
133    pub fn sender(&self) -> io::Result<WeakTcpTunnelSender> {
134        Ok(WeakTcpTunnelSender::new(self.sender.clone()))
135    }
136}
137#[derive(Clone)]
138pub struct WeakTcpTunnelSender {
139    sender: Sender<BytesMut>,
140}
141impl WeakTcpTunnelSender {
142    fn new(sender: Sender<BytesMut>) -> Self {
143        Self { sender }
144    }
145    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
146        if buf.is_empty() {
147            return Ok(());
148        }
149        self.sender
150            .send(buf)
151            .await
152            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
153    }
154}
155
156impl Drop for TcpTunnel {
157    fn drop(&mut self) {
158        self.done();
159    }
160}
161
162impl TcpTunnel {
163    /// Writing `buf` to the target denoted by `route_key` via this tunnel
164    pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
165        if buf.is_empty() {
166            return Ok(());
167        }
168        self.sender
169            .send(buf)
170            .await
171            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
172    }
173
174    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175        match tokio::time::timeout(
176            self.route_idle_time,
177            self.decoder.decode(&mut self.tcp_read, buf),
178        )
179        .await
180        {
181            Ok(rs) => rs,
182            Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
183        }
184    }
185    pub async fn batch_recv<B: AsMut<[u8]>>(
186        &mut self,
187        bufs: &mut [B],
188        sizes: &mut [usize],
189    ) -> io::Result<usize> {
190        if bufs.is_empty() || bufs.len() != sizes.len() {
191            return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs error"));
192        }
193        match tokio::time::timeout(
194            self.route_idle_time,
195            self.decoder.decode(&mut self.tcp_read, bufs[0].as_mut()),
196        )
197        .await
198        {
199            Ok(rs) => {
200                let len = rs?;
201                sizes[0] = len;
202                let mut num = 1;
203                while num < bufs.len() {
204                    let rs = self
205                        .decoder
206                        .try_decode(&mut self.tcp_read, bufs[num].as_mut());
207                    match rs {
208                        Ok(len) => {
209                            sizes[num] = len;
210                            num += 1;
211                        }
212                        Err(_e) => break,
213                    }
214                }
215
216                Ok(num)
217            }
218            Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
219        }
220    }
221    /// Receive bytes from this tunnel, which the configured Decoder pre-processes
222    /// `usize` in the `Ok` branch indicates how many bytes are received
223    /// `RouteKey` in the `Ok` branch denotes the source where these bytes are received from
224    pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, RouteKey)> {
225        match self.recv(buf).await {
226            Ok(len) => Ok((len, self.route_key())),
227            Err(e) => Err(e),
228        }
229    }
230    pub async fn batch_recv_from<B: AsMut<[u8]>>(
231        &mut self,
232        bufs: &mut [B],
233        sizes: &mut [usize],
234    ) -> io::Result<(usize, RouteKey)> {
235        match self.batch_recv(bufs, sizes).await {
236            Ok(len) => Ok((len, self.route_key())),
237            Err(e) => Err(e),
238        }
239    }
240}
241
242#[derive(Clone)]
243pub struct WriteHalfCollect {
244    tcp_multiplexing_limit: usize,
245    addr_mapping: Arc<DashMap<SocketAddr, Vec<usize>>>,
246    write_half_map: Arc<DashMap<usize, Sender<BytesMut>>>,
247    recycle_buf: Option<RecycleBuf>,
248}
249
250impl WriteHalfCollect {
251    fn new(tcp_multiplexing_limit: usize, recycle_buf: Option<RecycleBuf>) -> Self {
252        Self {
253            tcp_multiplexing_limit,
254            addr_mapping: Default::default(),
255            write_half_map: Default::default(),
256            recycle_buf,
257        }
258    }
259}
260
261pub(crate) struct ReadHalfBox {
262    read_half: OwnedReadHalf,
263    decoder: Box<dyn Decoder>,
264}
265
266impl ReadHalfBox {
267    pub(crate) fn new(read_half: OwnedReadHalf, decoder: Box<dyn Decoder>) -> Self {
268        Self { read_half, decoder }
269    }
270}
271
272impl WriteHalfCollect {
273    pub(crate) fn add_write_half(
274        &self,
275        route_key: RouteKey,
276        index_offset: usize,
277        mut writer: OwnedWriteHalf,
278        mut decoder: Box<dyn Encoder>,
279    ) -> Sender<BytesMut> {
280        assert!(index_offset < self.tcp_multiplexing_limit);
281
282        let index = route_key.index_usize();
283        let _ref = self
284            .addr_mapping
285            .entry(route_key.addr())
286            .and_modify(|v| {
287                v[index_offset] = index;
288            })
289            .or_insert_with(|| {
290                let mut v = vec![0; self.tcp_multiplexing_limit];
291                v[index_offset] = index;
292                v
293            });
294        let (s, mut r) = tachyonix::channel(128);
295        let sender = s.clone();
296        self.write_half_map.insert(index, s);
297        let collect = self.clone();
298        let recycle_buf = self.recycle_buf.clone();
299        tokio::spawn(async move {
300            let mut vec_buf = Vec::with_capacity(16);
301            const IO_SLICE_CAPACITY: usize = 16;
302            let mut io_buffer: Vec<IoSlice> = Vec::with_capacity(IO_SLICE_CAPACITY);
303            let io_slice_storage = io_buffer.as_mut_slice();
304            while let Ok(v) = r.recv().await {
305                if let Ok(buf) = r.try_recv() {
306                    vec_buf.push(v);
307                    vec_buf.push(buf);
308                    while let Ok(buf) = r.try_recv() {
309                        vec_buf.push(buf);
310                        if vec_buf.len() == 16 {
311                            break;
312                        }
313                    }
314
315                    // Safety
316                    // reuse the storage of `io_buffer` via `vec` that only lives in this block and manually clear the content
317                    // within the storage when exiting the block
318                    // leak the memory storage after using `vec` since the storage is managed by `io_buffer`
319                    let rs = {
320                        let mut vec = unsafe {
321                            Vec::from_raw_parts(io_slice_storage.as_mut_ptr(), 0, IO_SLICE_CAPACITY)
322                        };
323                        for x in &vec_buf {
324                            vec.push(IoSlice::new(x));
325                        }
326                        let rs = decoder.encode_multiple(&mut writer, &vec).await;
327                        vec.clear();
328                        std::mem::forget(vec);
329                        rs
330                    };
331
332                    if let Some(recycle_buf) = recycle_buf.as_ref() {
333                        while let Some(buf) = vec_buf.pop() {
334                            recycle_buf.push(buf);
335                        }
336                    } else {
337                        vec_buf.clear()
338                    }
339                    if let Err(e) = rs {
340                        log::debug!("{route_key:?},{e:?}");
341                        break;
342                    }
343                } else {
344                    let rs = decoder.encode(&mut writer, &v).await;
345                    if let Some(recycle_buf) = recycle_buf.as_ref() {
346                        recycle_buf.push(v);
347                    }
348                    if let Err(e) = rs {
349                        log::debug!("{route_key:?},{e:?}");
350                        break;
351                    }
352                }
353            }
354            collect.remove(&route_key);
355        });
356        sender
357    }
358    pub(crate) fn remove(&self, route_key: &RouteKey) {
359        let index_usize = route_key.index_usize();
360        self.addr_mapping
361            .remove_if_mut(&route_key.addr(), |_k, index_vec| {
362                let mut remove = true;
363                for v in index_vec {
364                    if *v == index_usize {
365                        *v = 0;
366                    }
367                    if *v != 0 {
368                        remove = false;
369                    }
370                }
371                remove
372            });
373
374        self.write_half_map.remove(&index_usize);
375    }
376    pub(crate) fn get_write_half(&self, index: &usize) -> Option<Sender<BytesMut>> {
377        self.write_half_map.get(index).map(|v| v.value().clone())
378    }
379    pub(crate) fn get_write_half_by_key(
380        &self,
381        route_key: &RouteKey,
382    ) -> io::Result<Sender<BytesMut>> {
383        match route_key.index() {
384            Index::Tcp(index) => {
385                let sender = self
386                    .get_write_half(&index)
387                    .ok_or_else(|| io::Error::other(format!("not found {route_key:?}")))?;
388                Ok(sender)
389            }
390            _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
391        }
392    }
393
394    pub(crate) fn get_one_route_key(&self, addr: &SocketAddr) -> Option<RouteKey> {
395        if let Some(v) = self.addr_mapping.get(addr) {
396            for index_usize in v.value() {
397                if *index_usize != 0 {
398                    return Some(RouteKey::new(Index::Tcp(*index_usize), *addr));
399                }
400            }
401        }
402        None
403    }
404    pub(crate) fn get_limit_route_key(&self, index: usize, addr: &SocketAddr) -> Option<RouteKey> {
405        if let Some(v) = self.addr_mapping.get(addr) {
406            assert_eq!(v.len(), self.tcp_multiplexing_limit);
407            let index_usize = v[index];
408            if index_usize == 0 {
409                return None;
410            }
411            return Some(RouteKey::new(Index::Tcp(index_usize), *addr));
412        }
413        None
414    }
415    pub async fn send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
416        let write_half = self.get_write_half_by_key(route_key)?;
417        if buf.is_empty() {
418            return Ok(());
419        }
420        if let Err(_e) = write_half.send(buf).await {
421            Err(io::Error::from(io::ErrorKind::WriteZero))
422        } else {
423            Ok(())
424        }
425    }
426    pub fn try_send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
427        let write_half = self.get_write_half_by_key(route_key)?;
428        if buf.is_empty() {
429            return Ok(());
430        }
431        if let Err(e) = write_half.try_send(buf) {
432            match e {
433                TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
434                TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
435            }
436        } else {
437            Ok(())
438        }
439    }
440}
441
442pub struct TcpSocketManager {
443    lock: DashMap<SocketAddr, Arc<Mutex<()>>>,
444    local_addr: SocketAddr,
445    tcp_multiplexing_limit: usize,
446    write_half_collect: WriteHalfCollect,
447    connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
448    default_interface: Option<LocalInterface>,
449    init_codec: Arc<Box<dyn InitCodec>>,
450}
451
452impl TcpSocketManager {
453    pub(crate) fn new(
454        local_addr: SocketAddr,
455        tcp_multiplexing_limit: usize,
456        write_half_collect: WriteHalfCollect,
457        connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
458        default_interface: Option<LocalInterface>,
459        init_codec: Arc<Box<dyn InitCodec>>,
460    ) -> Self {
461        Self {
462            local_addr,
463            lock: Default::default(),
464            tcp_multiplexing_limit,
465            write_half_collect,
466            connect_sender,
467            default_interface,
468            init_codec,
469        }
470    }
471    pub fn local_addr(&self) -> SocketAddr {
472        self.local_addr
473    }
474}
475
476impl TcpSocketManager {
477    /// Multiple connections can be initiated to the target address.
478    pub async fn multi_connect(
479        &self,
480        addr: SocketAddr,
481        index_offset: usize,
482    ) -> io::Result<RouteKey> {
483        self.multi_connect_impl(addr, index_offset, None).await
484    }
485    pub(crate) async fn multi_connect_impl(
486        &self,
487        addr: SocketAddr,
488        index_offset: usize,
489        ttl: Option<u8>,
490    ) -> io::Result<RouteKey> {
491        let len = self.tcp_multiplexing_limit;
492        if index_offset >= len {
493            return Err(io::Error::new(
494                io::ErrorKind::InvalidInput,
495                "index out of bounds",
496            ));
497        }
498        let lock = self
499            .lock
500            .entry(addr)
501            .or_insert_with(|| Arc::new(Mutex::new(())))
502            .value()
503            .clone();
504        let _guard = lock.lock().await;
505        if let Some(route_key) = self
506            .write_half_collect
507            .get_limit_route_key(index_offset, &addr)
508        {
509            return Ok(route_key);
510        }
511        self.connect_impl(0, addr, index_offset, ttl).await
512    }
513    /// Initiate a connection.
514    pub async fn connect(&self, addr: SocketAddr) -> io::Result<RouteKey> {
515        let lock = self
516            .lock
517            .entry(addr)
518            .or_insert_with(|| Arc::new(Mutex::new(())))
519            .value()
520            .clone();
521        let _guard = lock.lock().await;
522        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
523            return Ok(route_key);
524        }
525        self.connect_impl(0, addr, 0, None).await
526    }
527    pub async fn connect_ttl(&self, addr: SocketAddr, ttl: Option<u8>) -> io::Result<RouteKey> {
528        let lock = self
529            .lock
530            .entry(addr)
531            .or_insert_with(|| Arc::new(Mutex::new(())))
532            .value()
533            .clone();
534        let _guard = lock.lock().await;
535        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
536            return Ok(route_key);
537        }
538        self.connect_impl(0, addr, 0, ttl).await
539    }
540    /// Reuse the bound port to initiate a connection, which can be used to penetrate NAT1 network type.
541    pub async fn connect_reuse_port(&self, addr: SocketAddr) -> io::Result<RouteKey> {
542        let lock = self
543            .lock
544            .entry(addr)
545            .or_insert_with(|| Arc::new(Mutex::new(())))
546            .value()
547            .clone();
548        let _guard = lock.lock().await;
549        if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
550            return Ok(route_key);
551        }
552        self.connect_impl(self.local_addr.port(), addr, 0, None)
553            .await
554    }
555    pub async fn connect_reuse_port_raw(&self, addr: SocketAddr) -> io::Result<TcpStream> {
556        let stream = connect_tcp(
557            addr,
558            self.local_addr.port(),
559            self.default_interface.as_ref(),
560            None,
561        )
562        .await?;
563        Ok(stream)
564    }
565    async fn connect_impl(
566        &self,
567        bind_port: u16,
568        addr: SocketAddr,
569        index_offset: usize,
570        ttl: Option<u8>,
571    ) -> io::Result<RouteKey> {
572        let stream = connect_tcp(addr, bind_port, self.default_interface.as_ref(), ttl).await?;
573        let route_key = stream.route_key()?;
574        let (read_half, write_half) = stream.into_split();
575        let (decoder, encoder) = self.init_codec.codec(addr)?;
576        let read_half = ReadHalfBox::new(read_half, decoder);
577        let sender =
578            self.write_half_collect
579                .add_write_half(route_key, index_offset, write_half, encoder);
580        if let Err(_e) = self
581            .connect_sender
582            .send((route_key, read_half, sender))
583            .await
584        {
585            Err(io::Error::new(
586                io::ErrorKind::UnexpectedEof,
587                "connect close",
588            ))?
589        }
590        Ok(route_key)
591    }
592}
593
594impl TcpSocketManager {
595    pub async fn multi_send_to<A: Into<SocketAddr>>(
596        &self,
597        buf: BytesMut,
598        addr: A,
599    ) -> io::Result<()> {
600        self.multi_send_to_impl(buf, addr, None).await
601    }
602    pub(crate) async fn multi_send_to_impl<A: Into<SocketAddr>>(
603        &self,
604        buf: BytesMut,
605        addr: A,
606        ttl: Option<u8>,
607    ) -> io::Result<()> {
608        let index_offset = rand::rng().random_range(0..self.tcp_multiplexing_limit);
609        let route_key = self
610            .multi_connect_impl(addr.into(), index_offset, ttl)
611            .await?;
612        self.send_to(buf, &route_key).await
613    }
614    /// Reuse the bound port to initiate a connection, which can be used to penetrate NAT1 network type.
615    pub async fn reuse_port_send_to<A: Into<SocketAddr>>(
616        &self,
617        buf: BytesMut,
618        addr: A,
619    ) -> io::Result<()> {
620        let route_key = self.connect_reuse_port(addr.into()).await?;
621        self.send_to(buf, &route_key).await
622    }
623
624    /// Writing `buf` to the target denoted by `route_key`
625    pub async fn send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
626        let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
627        self.write_half_collect.send_to(buf, &route_key).await
628    }
629    pub async fn send_to_addr<D: Into<SocketAddr>>(
630        &self,
631        buf: BytesMut,
632        dest: D,
633    ) -> io::Result<()> {
634        let route_key = self.connect(dest.into()).await?;
635        self.write_half_collect.send_to(buf, &route_key).await
636    }
637    pub fn try_send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
638        let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
639        self.write_half_collect.try_send_to(buf, &route_key)
640    }
641}
642pub trait ToRouteKeyForTcp<T> {
643    fn route_key(_: &TcpSocketManager, _: Self) -> io::Result<RouteKey>;
644}
645
646impl ToRouteKeyForTcp<()> for RouteKey {
647    fn route_key(_: &TcpSocketManager, dest: RouteKey) -> io::Result<RouteKey> {
648        Ok(dest)
649    }
650}
651
652impl ToRouteKeyForTcp<()> for &RouteKey {
653    fn route_key(_: &TcpSocketManager, dest: &RouteKey) -> io::Result<RouteKey> {
654        Ok(*dest)
655    }
656}
657
658impl ToRouteKeyForTcp<()> for &mut RouteKey {
659    fn route_key(_: &TcpSocketManager, dest: &mut RouteKey) -> io::Result<RouteKey> {
660        Ok(*dest)
661    }
662}
663// impl<S: Into<SocketAddr>> ToRouteKeyForTcp<()> for S {
664//     async fn route_key(socket_manager: &SocketManager, dest: Self) -> io::Result<RouteKey> {
665//         socket_manager.connect(dest.into()).await
666//     }
667// }
668
669pub trait TcpStreamIndex {
670    fn route_key(&self) -> io::Result<RouteKey>;
671    fn index(&self) -> Index;
672}
673
674impl TcpStreamIndex for TcpStream {
675    fn route_key(&self) -> io::Result<RouteKey> {
676        let addr = self.peer_addr()?;
677
678        Ok(RouteKey::new(self.index(), addr))
679    }
680
681    fn index(&self) -> Index {
682        #[cfg(windows)]
683        use std::os::windows::io::AsRawSocket;
684        #[cfg(windows)]
685        let index = self.as_raw_socket() as usize;
686        #[cfg(unix)]
687        use std::os::fd::{AsFd, AsRawFd};
688        #[cfg(unix)]
689        let index = self.as_fd().as_raw_fd() as usize;
690        Index::Tcp(index)
691    }
692}
693
/// The default byte encoder/decoder: bytes are passed through unchanged, so
/// using it is no different from reading and writing the `TCP` stream directly.
pub struct BytesCodec;
696
/// Length-prefixed encoder/decoder: every message is preceded by its length as
/// a 4-byte big-endian integer.
pub struct LengthPrefixedCodec;
699
700#[async_trait]
701impl Decoder for BytesCodec {
702    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
703        let len = read.read(src).await?;
704        Ok(len)
705    }
706}
707
708#[async_trait]
709impl Encoder for BytesCodec {
710    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
711        write.write_all(data).await?;
712        Ok(())
713    }
714}
715
716#[async_trait]
717impl Decoder for LengthPrefixedCodec {
718    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
719        let mut head = [0; 4];
720        read.read_exact(&mut head).await?;
721        let len = u32::from_be_bytes(head) as usize;
722        read.read_exact(&mut src[..len]).await?;
723        Ok(len)
724    }
725}
726
727#[async_trait]
728impl Encoder for LengthPrefixedCodec {
729    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
730        let head: [u8; 4] = (data.len() as u32).to_be_bytes();
731        write.write_all(&head).await?;
732        write.write_all(data).await?;
733        Ok(())
734    }
735}
736
/// Codec factory that equips every connection with [`BytesCodec`].
#[derive(Clone)]
pub struct BytesInitCodec;
739
740impl InitCodec for BytesInitCodec {
741    fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
742        Ok((Box::new(BytesCodec), Box::new(BytesCodec)))
743    }
744}
745
/// Codec factory that equips every connection with [`LengthPrefixedCodec`].
#[derive(Clone)]
pub struct LengthPrefixedInitCodec;
748
749impl InitCodec for LengthPrefixedInitCodec {
750    fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
751        Ok((Box::new(LengthPrefixedCodec), Box::new(LengthPrefixedCodec)))
752    }
753}
754
755pub trait InitCodec: Send + Sync + DynClone {
756    fn codec(&self, addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)>;
757}
758dyn_clone::clone_trait_object!(InitCodec);
759
760#[async_trait]
761pub trait Decoder: Send + Sync {
762    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize>;
763    fn try_decode(&mut self, _read: &mut OwnedReadHalf, _src: &mut [u8]) -> io::Result<usize> {
764        Err(io::Error::from(io::ErrorKind::WouldBlock))
765    }
766}
767
768#[async_trait]
769pub trait Encoder: Send + Sync {
770    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()>;
771    async fn encode_multiple(
772        &mut self,
773        write: &mut OwnedWriteHalf,
774        bufs: &[IoSlice<'_>],
775    ) -> io::Result<()> {
776        for buf in bufs {
777            self.encode(write, buf).await?
778        }
779        Ok(())
780    }
781}
782
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use std::io;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

    use crate::tunnel::config::TcpTunnelConfig;
    use crate::tunnel::tcp::{Decoder, Encoder, InitCodec, TcpTunnelDispatcher};

    /// A dispatcher can be built from the default configuration.
    #[tokio::test]
    pub async fn create_tcp_tunnel() {
        let config: TcpTunnelConfig = TcpTunnelConfig::default();
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// A dispatcher can be built with a user-supplied codec factory.
    #[tokio::test]
    pub async fn create_codec_tcp_tunnel() {
        let config = TcpTunnelConfig::new(Box::new(MyInitCodeC));
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// Codec factory handing out [`MyCodeC`] for both directions.
    #[derive(Clone)]
    struct MyInitCodeC;

    impl InitCodec for MyInitCodeC {
        fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
            Ok((Box::new(MyCodeC), Box::new(MyCodeC)))
        }
    }

    /// Sample codec framing each message with a 2-byte big-endian length.
    struct MyCodeC;

    #[async_trait]
    impl Decoder for MyCodeC {
        async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
            let mut head = [0; 2];
            read.read_exact(&mut head).await?;
            let len = u16::from_be_bytes(head) as usize;
            // Reject frames that do not fit instead of panicking on the slice.
            if len > src.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "frame length exceeds receive buffer",
                ));
            }
            read.read_exact(&mut src[..len]).await?;
            Ok(len)
        }
    }

    #[async_trait]
    impl Encoder for MyCodeC {
        async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
            // Refuse payloads whose length would be truncated by the 2-byte prefix.
            let len = u16::try_from(data.len()).map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "payload too large for a 2-byte length prefix",
                )
            })?;
            let head: [u8; 2] = len.to_be_bytes();
            write.write_all(&head).await?;
            write.write_all(data).await?;
            Ok(())
        }
    }
}