1use crate::route::{Index, RouteKey};
2use crate::socket::{connect_tcp, create_tcp_listener, LocalInterface};
3use crate::tunnel::config::TcpTunnelConfig;
4use async_lock::Mutex;
5use async_trait::async_trait;
6use bytes::Bytes;
7use dashmap::DashMap;
8use dyn_clone::DynClone;
9use rand::Rng;
10use std::io;
11use std::io::IoSlice;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15use tachyonix::{Receiver, Sender, TrySendError};
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
18use tokio::net::{TcpListener, TcpStream};
19
20pub struct TcpTunnelDispatcher {
21 route_idle_time: Duration,
22 tcp_listener: TcpListener,
23 connect_receiver: Receiver<(RouteKey, ReadHalfBox, Sender<Bytes>)>,
24 #[allow(dead_code)]
25 pub(crate) socket_manager: Arc<TcpSocketManager>,
26 write_half_collect: WriteHalfCollect,
27 init_codec: Arc<Box<dyn InitCodec>>,
28}
29
30impl TcpTunnelDispatcher {
31 pub fn new(config: TcpTunnelConfig) -> io::Result<TcpTunnelDispatcher> {
33 config.check()?;
34 let address: SocketAddr = if config.use_v6 {
35 format!("[::]:{}", config.tcp_port).parse().unwrap()
36 } else {
37 format!("0.0.0.0:{}", config.tcp_port).parse().unwrap()
38 };
39
40 let tcp_listener = create_tcp_listener(address)?;
41 let local_addr = tcp_listener.local_addr()?;
42 let tcp_listener = TcpListener::from_std(tcp_listener)?;
43 let (connect_sender, connect_receiver) = tachyonix::channel(128);
44 let write_half_collect = WriteHalfCollect::new(config.tcp_multiplexing_limit);
45 let init_codec = Arc::new(config.init_codec);
46 let socket_manager = Arc::new(TcpSocketManager::new(
47 local_addr,
48 config.tcp_multiplexing_limit,
49 write_half_collect.clone(),
50 connect_sender,
51 config.default_interface,
52 init_codec.clone(),
53 ));
54 Ok(TcpTunnelDispatcher {
55 route_idle_time: config.route_idle_time,
56 tcp_listener,
57 connect_receiver,
58 socket_manager,
59 write_half_collect,
60 init_codec,
61 })
62 }
63}
64
65impl TcpTunnelDispatcher {
66 pub async fn dispatch(&mut self) -> io::Result<TcpTunnel> {
68 tokio::select! {
69 rs=self.connect_receiver.recv()=>{
70 let (route_key,read_half,sender) = rs.
71 map_err(|_| io::Error::other("connect_receiver done"))?;
72 let local_addr = read_half.read_half.local_addr()?;
73 Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
74 },
75 rs=self.tcp_listener.accept()=>{
76 let (tcp_stream,addr) = rs?;
77 tcp_stream.set_nodelay(true)?;
78 let route_key = tcp_stream.route_key()?;
79 let (read_half,write_half) = tcp_stream.into_split();
80 let (decoder,encoder) = self.init_codec.codec(addr)?;
81 let read_half = ReadHalfBox::new(read_half,decoder);
82 let sender = self.write_half_collect.add_write_half(route_key,0, write_half,encoder);
83 let local_addr = read_half.read_half.local_addr()?;
84 Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
85 }
86 }
87 }
88 pub fn manager(&self) -> &Arc<TcpSocketManager> {
89 &self.socket_manager
90 }
91}
92
93pub struct TcpTunnel {
94 local_addr: SocketAddr,
95 route_key: RouteKey,
96 route_idle_time: Duration,
97 tcp_read: OwnedReadHalf,
98 decoder: Box<dyn Decoder>,
99 sender: Sender<Bytes>,
100}
101
102impl TcpTunnel {
103 pub(crate) fn new(
104 local_addr: SocketAddr,
105 route_idle_time: Duration,
106 route_key: RouteKey,
107 read: ReadHalfBox,
108 sender: Sender<Bytes>,
109 ) -> Self {
110 let decoder = read.decoder;
111 let tcp_read = read.read_half;
112 Self {
113 local_addr,
114 route_key,
115 route_idle_time,
116 tcp_read,
117 decoder,
118 sender,
119 }
120 }
121 #[inline]
122 pub fn route_key(&self) -> RouteKey {
123 self.route_key
124 }
125 pub fn local_addr(&self) -> SocketAddr {
126 self.local_addr
127 }
128 pub fn done(&mut self) {
129 self.sender.close();
130 }
131 pub fn sender(&self) -> io::Result<WeakTcpTunnelSender> {
132 Ok(WeakTcpTunnelSender::new(self.sender.clone()))
133 }
134}
135#[derive(Clone)]
136pub struct WeakTcpTunnelSender {
137 sender: Sender<Bytes>,
138}
139impl WeakTcpTunnelSender {
140 fn new(sender: Sender<Bytes>) -> Self {
141 Self { sender }
142 }
143 pub async fn send(&self, buf: Bytes) -> io::Result<()> {
144 if buf.is_empty() {
145 return Ok(());
146 }
147 self.sender
148 .send(buf)
149 .await
150 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
151 }
152}
153
154impl Drop for TcpTunnel {
155 fn drop(&mut self) {
156 self.done();
157 }
158}
159
160impl TcpTunnel {
161 pub async fn send(&self, buf: Bytes) -> io::Result<()> {
163 if buf.is_empty() {
164 return Ok(());
165 }
166 self.sender
167 .send(buf)
168 .await
169 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
170 }
171
172 pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
173 match tokio::time::timeout(
174 self.route_idle_time,
175 self.decoder.decode(&mut self.tcp_read, buf),
176 )
177 .await
178 {
179 Ok(rs) => rs,
180 Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
181 }
182 }
183 pub async fn batch_recv<B: AsMut<[u8]>>(
184 &mut self,
185 bufs: &mut [B],
186 sizes: &mut [usize],
187 ) -> io::Result<usize> {
188 if bufs.is_empty() || bufs.len() != sizes.len() {
189 return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs error"));
190 }
191 match tokio::time::timeout(
192 self.route_idle_time,
193 self.decoder.decode(&mut self.tcp_read, bufs[0].as_mut()),
194 )
195 .await
196 {
197 Ok(rs) => {
198 let len = rs?;
199 sizes[0] = len;
200 let mut num = 1;
201 while num < bufs.len() {
202 let rs = self
203 .decoder
204 .try_decode(&mut self.tcp_read, bufs[num].as_mut());
205 match rs {
206 Ok(len) => {
207 sizes[num] = len;
208 num += 1;
209 }
210 Err(_e) => break,
211 }
212 }
213
214 Ok(num)
215 }
216 Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
217 }
218 }
219 pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, RouteKey)> {
223 match self.recv(buf).await {
224 Ok(len) => Ok((len, self.route_key())),
225 Err(e) => Err(e),
226 }
227 }
228 pub async fn batch_recv_from<B: AsMut<[u8]>>(
229 &mut self,
230 bufs: &mut [B],
231 sizes: &mut [usize],
232 ) -> io::Result<(usize, RouteKey)> {
233 match self.batch_recv(bufs, sizes).await {
234 Ok(len) => Ok((len, self.route_key())),
235 Err(e) => Err(e),
236 }
237 }
238}
239
240#[derive(Clone)]
241pub struct WriteHalfCollect {
242 tcp_multiplexing_limit: usize,
243 addr_mapping: Arc<DashMap<SocketAddr, Vec<usize>>>,
244 write_half_map: Arc<DashMap<usize, Sender<Bytes>>>,
245}
246
247impl WriteHalfCollect {
248 fn new(tcp_multiplexing_limit: usize) -> Self {
249 Self {
250 tcp_multiplexing_limit,
251 addr_mapping: Default::default(),
252 write_half_map: Default::default(),
253 }
254 }
255}
256
257pub(crate) struct ReadHalfBox {
258 read_half: OwnedReadHalf,
259 decoder: Box<dyn Decoder>,
260}
261
262impl ReadHalfBox {
263 pub(crate) fn new(read_half: OwnedReadHalf, decoder: Box<dyn Decoder>) -> Self {
264 Self { read_half, decoder }
265 }
266}
267
268impl WriteHalfCollect {
269 pub(crate) fn add_write_half(
270 &self,
271 route_key: RouteKey,
272 index_offset: usize,
273 mut writer: OwnedWriteHalf,
274 mut decoder: Box<dyn Encoder>,
275 ) -> Sender<Bytes> {
276 assert!(index_offset < self.tcp_multiplexing_limit);
277
278 let index = route_key.index_usize();
279 let _ref = self
280 .addr_mapping
281 .entry(route_key.addr())
282 .and_modify(|v| {
283 v[index_offset] = index;
284 })
285 .or_insert_with(|| {
286 let mut v = vec![0; self.tcp_multiplexing_limit];
287 v[index_offset] = index;
288 v
289 });
290 let (s, mut r) = tachyonix::channel(128);
291 let sender = s.clone();
292 self.write_half_map.insert(index, s);
293 let collect = self.clone();
294 tokio::spawn(async move {
295 let mut vec_buf = Vec::with_capacity(16);
296 const IO_SLICE_CAPACITY: usize = 16;
297 let mut io_buffer: Vec<IoSlice> = Vec::with_capacity(IO_SLICE_CAPACITY);
298 let io_slice_storage = io_buffer.as_mut_slice();
299 while let Ok(v) = r.recv().await {
300 if let Ok(buf) = r.try_recv() {
301 vec_buf.push(v);
302 vec_buf.push(buf);
303 while let Ok(buf) = r.try_recv() {
304 vec_buf.push(buf);
305 if vec_buf.len() == 16 {
306 break;
307 }
308 }
309
310 let rs = {
315 let mut vec = unsafe {
316 Vec::from_raw_parts(io_slice_storage.as_mut_ptr(), 0, IO_SLICE_CAPACITY)
317 };
318 for x in &vec_buf {
319 vec.push(IoSlice::new(x));
320 }
321 let rs = decoder.encode_multiple(&mut writer, &vec).await;
322 vec.clear();
323 std::mem::forget(vec);
324 rs
325 };
326 vec_buf.clear();
327 if let Err(e) = rs {
328 log::debug!("{route_key:?},{e:?}");
329 break;
330 }
331 } else {
332 let rs = decoder.encode(&mut writer, &v).await;
333 if let Err(e) = rs {
334 log::debug!("{route_key:?},{e:?}");
335 break;
336 }
337 }
338 }
339 collect.remove(&route_key);
340 });
341 sender
342 }
343 pub(crate) fn remove(&self, route_key: &RouteKey) {
344 let index_usize = route_key.index_usize();
345 self.addr_mapping
346 .remove_if_mut(&route_key.addr(), |_k, index_vec| {
347 let mut remove = true;
348 for v in index_vec {
349 if *v == index_usize {
350 *v = 0;
351 }
352 if *v != 0 {
353 remove = false;
354 }
355 }
356 remove
357 });
358
359 self.write_half_map.remove(&index_usize);
360 }
361 pub(crate) fn get_write_half(&self, index: &usize) -> Option<Sender<Bytes>> {
362 self.write_half_map.get(index).map(|v| v.value().clone())
363 }
364 pub(crate) fn get_write_half_by_key(&self, route_key: &RouteKey) -> io::Result<Sender<Bytes>> {
365 match route_key.index() {
366 Index::Tcp(index) => {
367 let sender = self
368 .get_write_half(&index)
369 .ok_or_else(|| io::Error::other(format!("not found {route_key:?}")))?;
370 Ok(sender)
371 }
372 _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
373 }
374 }
375
376 pub(crate) fn get_one_route_key(&self, addr: &SocketAddr) -> Option<RouteKey> {
377 if let Some(v) = self.addr_mapping.get(addr) {
378 for index_usize in v.value() {
379 if *index_usize != 0 {
380 return Some(RouteKey::new(Index::Tcp(*index_usize), *addr));
381 }
382 }
383 }
384 None
385 }
386 pub(crate) fn get_limit_route_key(&self, index: usize, addr: &SocketAddr) -> Option<RouteKey> {
387 if let Some(v) = self.addr_mapping.get(addr) {
388 assert_eq!(v.len(), self.tcp_multiplexing_limit);
389 let index_usize = v[index];
390 if index_usize == 0 {
391 return None;
392 }
393 return Some(RouteKey::new(Index::Tcp(index_usize), *addr));
394 }
395 None
396 }
397 pub async fn send_to(&self, buf: Bytes, route_key: &RouteKey) -> io::Result<()> {
398 let write_half = self.get_write_half_by_key(route_key)?;
399 if buf.is_empty() {
400 return Ok(());
401 }
402 if let Err(_e) = write_half.send(buf).await {
403 Err(io::Error::from(io::ErrorKind::WriteZero))
404 } else {
405 Ok(())
406 }
407 }
408 pub fn try_send_to(&self, buf: Bytes, route_key: &RouteKey) -> io::Result<()> {
409 let write_half = self.get_write_half_by_key(route_key)?;
410 if buf.is_empty() {
411 return Ok(());
412 }
413 if let Err(e) = write_half.try_send(buf) {
414 match e {
415 TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
416 TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
417 }
418 } else {
419 Ok(())
420 }
421 }
422}
423
424pub struct TcpSocketManager {
425 lock: DashMap<SocketAddr, Arc<Mutex<()>>>,
426 local_addr: SocketAddr,
427 tcp_multiplexing_limit: usize,
428 write_half_collect: WriteHalfCollect,
429 connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<Bytes>)>,
430 default_interface: Option<LocalInterface>,
431 init_codec: Arc<Box<dyn InitCodec>>,
432}
433
434impl TcpSocketManager {
435 pub(crate) fn new(
436 local_addr: SocketAddr,
437 tcp_multiplexing_limit: usize,
438 write_half_collect: WriteHalfCollect,
439 connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<Bytes>)>,
440 default_interface: Option<LocalInterface>,
441 init_codec: Arc<Box<dyn InitCodec>>,
442 ) -> Self {
443 Self {
444 local_addr,
445 lock: Default::default(),
446 tcp_multiplexing_limit,
447 write_half_collect,
448 connect_sender,
449 default_interface,
450 init_codec,
451 }
452 }
453 pub fn local_addr(&self) -> SocketAddr {
454 self.local_addr
455 }
456}
457
458impl TcpSocketManager {
459 pub async fn multi_connect(
461 &self,
462 addr: SocketAddr,
463 index_offset: usize,
464 ) -> io::Result<RouteKey> {
465 self.multi_connect_impl(addr, index_offset, None).await
466 }
467 pub(crate) async fn multi_connect_impl(
468 &self,
469 addr: SocketAddr,
470 index_offset: usize,
471 ttl: Option<u8>,
472 ) -> io::Result<RouteKey> {
473 let len = self.tcp_multiplexing_limit;
474 if index_offset >= len {
475 return Err(io::Error::new(
476 io::ErrorKind::InvalidInput,
477 "index out of bounds",
478 ));
479 }
480 let lock = self
481 .lock
482 .entry(addr)
483 .or_insert_with(|| Arc::new(Mutex::new(())))
484 .value()
485 .clone();
486 let _guard = lock.lock().await;
487 if let Some(route_key) = self
488 .write_half_collect
489 .get_limit_route_key(index_offset, &addr)
490 {
491 return Ok(route_key);
492 }
493 self.connect_impl(0, addr, index_offset, ttl).await
494 }
495 pub async fn connect(&self, addr: SocketAddr) -> io::Result<RouteKey> {
497 let lock = self
498 .lock
499 .entry(addr)
500 .or_insert_with(|| Arc::new(Mutex::new(())))
501 .value()
502 .clone();
503 let _guard = lock.lock().await;
504 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
505 return Ok(route_key);
506 }
507 self.connect_impl(0, addr, 0, None).await
508 }
509 pub async fn connect_ttl(&self, addr: SocketAddr, ttl: Option<u8>) -> io::Result<RouteKey> {
510 let lock = self
511 .lock
512 .entry(addr)
513 .or_insert_with(|| Arc::new(Mutex::new(())))
514 .value()
515 .clone();
516 let _guard = lock.lock().await;
517 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
518 return Ok(route_key);
519 }
520 self.connect_impl(0, addr, 0, ttl).await
521 }
522 pub async fn connect_reuse_port(&self, addr: SocketAddr) -> io::Result<RouteKey> {
524 let lock = self
525 .lock
526 .entry(addr)
527 .or_insert_with(|| Arc::new(Mutex::new(())))
528 .value()
529 .clone();
530 let _guard = lock.lock().await;
531 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
532 return Ok(route_key);
533 }
534 self.connect_impl(self.local_addr.port(), addr, 0, None)
535 .await
536 }
537 pub async fn connect_reuse_port_raw(&self, addr: SocketAddr) -> io::Result<TcpStream> {
538 let stream = connect_tcp(
539 addr,
540 self.local_addr.port(),
541 self.default_interface.as_ref(),
542 None,
543 )
544 .await?;
545 Ok(stream)
546 }
547 async fn connect_impl(
548 &self,
549 bind_port: u16,
550 addr: SocketAddr,
551 index_offset: usize,
552 ttl: Option<u8>,
553 ) -> io::Result<RouteKey> {
554 let stream = connect_tcp(addr, bind_port, self.default_interface.as_ref(), ttl).await?;
555 let route_key = stream.route_key()?;
556 let (read_half, write_half) = stream.into_split();
557 let (decoder, encoder) = self.init_codec.codec(addr)?;
558 let read_half = ReadHalfBox::new(read_half, decoder);
559 let sender =
560 self.write_half_collect
561 .add_write_half(route_key, index_offset, write_half, encoder);
562 if let Err(_e) = self
563 .connect_sender
564 .send((route_key, read_half, sender))
565 .await
566 {
567 Err(io::Error::new(
568 io::ErrorKind::UnexpectedEof,
569 "connect close",
570 ))?
571 }
572 Ok(route_key)
573 }
574}
575
576impl TcpSocketManager {
577 pub async fn multi_send_to<A: Into<SocketAddr>>(&self, buf: Bytes, addr: A) -> io::Result<()> {
578 self.multi_send_to_impl(buf, addr, None).await
579 }
580 pub(crate) async fn multi_send_to_impl<A: Into<SocketAddr>>(
581 &self,
582 buf: Bytes,
583 addr: A,
584 ttl: Option<u8>,
585 ) -> io::Result<()> {
586 let index_offset = rand::rng().random_range(0..self.tcp_multiplexing_limit);
587 let route_key = self
588 .multi_connect_impl(addr.into(), index_offset, ttl)
589 .await?;
590 self.send_to(buf, &route_key).await
591 }
592 pub async fn reuse_port_send_to<A: Into<SocketAddr>>(
594 &self,
595 buf: Bytes,
596 addr: A,
597 ) -> io::Result<()> {
598 let route_key = self.connect_reuse_port(addr.into()).await?;
599 self.send_to(buf, &route_key).await
600 }
601
602 pub async fn send_to<D: ToRouteKeyForTcp<()>>(&self, buf: Bytes, dest: D) -> io::Result<()> {
604 let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
605 self.write_half_collect.send_to(buf, &route_key).await
606 }
607 pub async fn send_to_addr<D: Into<SocketAddr>>(&self, buf: Bytes, dest: D) -> io::Result<()> {
608 let route_key = self.connect(dest.into()).await?;
609 self.write_half_collect.send_to(buf, &route_key).await
610 }
611 pub fn try_send_to<D: ToRouteKeyForTcp<()>>(&self, buf: Bytes, dest: D) -> io::Result<()> {
612 let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
613 self.write_half_collect.try_send_to(buf, &route_key)
614 }
615}
616pub trait ToRouteKeyForTcp<T> {
617 fn route_key(_: &TcpSocketManager, _: Self) -> io::Result<RouteKey>;
618}
619
620impl ToRouteKeyForTcp<()> for RouteKey {
621 fn route_key(_: &TcpSocketManager, dest: RouteKey) -> io::Result<RouteKey> {
622 Ok(dest)
623 }
624}
625
626impl ToRouteKeyForTcp<()> for &RouteKey {
627 fn route_key(_: &TcpSocketManager, dest: &RouteKey) -> io::Result<RouteKey> {
628 Ok(*dest)
629 }
630}
631
632impl ToRouteKeyForTcp<()> for &mut RouteKey {
633 fn route_key(_: &TcpSocketManager, dest: &mut RouteKey) -> io::Result<RouteKey> {
634 Ok(*dest)
635 }
636}
637pub trait TcpStreamIndex {
644 fn route_key(&self) -> io::Result<RouteKey>;
645 fn index(&self) -> Index;
646}
647
648impl TcpStreamIndex for TcpStream {
649 fn route_key(&self) -> io::Result<RouteKey> {
650 let addr = self.peer_addr()?;
651
652 Ok(RouteKey::new(self.index(), addr))
653 }
654
655 fn index(&self) -> Index {
656 #[cfg(windows)]
657 use std::os::windows::io::AsRawSocket;
658 #[cfg(windows)]
659 let index = self.as_raw_socket() as usize;
660 #[cfg(unix)]
661 use std::os::fd::{AsFd, AsRawFd};
662 #[cfg(unix)]
663 let index = self.as_fd().as_raw_fd() as usize;
664 Index::Tcp(index)
665 }
666}
667
/// Pass-through codec: reads whatever bytes are available and writes data unframed.
pub struct BytesCodec;

/// Codec framing each message with a 4-byte big-endian length prefix.
pub struct LengthPrefixedCodec;
673
674#[async_trait]
675impl Decoder for BytesCodec {
676 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
677 let len = read.read(src).await?;
678 Ok(len)
679 }
680}
681
682#[async_trait]
683impl Encoder for BytesCodec {
684 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
685 write.write_all(data).await?;
686 Ok(())
687 }
688}
689
690#[async_trait]
691impl Decoder for LengthPrefixedCodec {
692 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
693 let mut head = [0; 4];
694 read.read_exact(&mut head).await?;
695 let len = u32::from_be_bytes(head) as usize;
696 read.read_exact(&mut src[..len]).await?;
697 Ok(len)
698 }
699}
700
701#[async_trait]
702impl Encoder for LengthPrefixedCodec {
703 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
704 let head: [u8; 4] = (data.len() as u32).to_be_bytes();
705 write.write_all(&head).await?;
706 write.write_all(data).await?;
707 Ok(())
708 }
709}
710
711#[derive(Clone)]
712pub struct BytesInitCodec;
713
714impl InitCodec for BytesInitCodec {
715 fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
716 Ok((Box::new(BytesCodec), Box::new(BytesCodec)))
717 }
718}
719
720#[derive(Clone)]
721pub struct LengthPrefixedInitCodec;
722
723impl InitCodec for LengthPrefixedInitCodec {
724 fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
725 Ok((Box::new(LengthPrefixedCodec), Box::new(LengthPrefixedCodec)))
726 }
727}
728
729pub trait InitCodec: Send + Sync + DynClone {
730 fn codec(&self, addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)>;
731}
732dyn_clone::clone_trait_object!(InitCodec);
733
734#[async_trait]
735pub trait Decoder: Send + Sync {
736 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize>;
737 fn try_decode(&mut self, _read: &mut OwnedReadHalf, _src: &mut [u8]) -> io::Result<usize> {
738 Err(io::Error::from(io::ErrorKind::WouldBlock))
739 }
740}
741
742#[async_trait]
743pub trait Encoder: Send + Sync {
744 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()>;
745 async fn encode_multiple(
746 &mut self,
747 write: &mut OwnedWriteHalf,
748 bufs: &[IoSlice<'_>],
749 ) -> io::Result<()> {
750 for buf in bufs {
751 self.encode(write, buf).await?
752 }
753 Ok(())
754 }
755}
756
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use std::io;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

    use crate::tunnel::config::TcpTunnelConfig;
    use crate::tunnel::tcp::{Decoder, Encoder, InitCodec, TcpTunnelDispatcher};

    /// Building a dispatcher from the default configuration succeeds.
    #[tokio::test]
    pub async fn create_tcp_tunnel() {
        let config: TcpTunnelConfig = TcpTunnelConfig::default();
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// Building a dispatcher with a custom `InitCodec` succeeds.
    #[tokio::test]
    pub async fn create_codec_tcp_tunnel() {
        let config = TcpTunnelConfig::new(Box::new(MyInitCodeC));
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// Test `InitCodec` producing [`MyCodeC`] for both directions.
    #[derive(Clone)]
    struct MyInitCodeC;

    impl InitCodec for MyInitCodeC {
        fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
            Ok((Box::new(MyCodeC), Box::new(MyCodeC)))
        }
    }

    /// Example codec with a 2-byte big-endian length prefix.
    struct MyCodeC;

    #[async_trait]
    impl Decoder for MyCodeC {
        async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
            let mut head = [0; 2];
            read.read_exact(&mut head).await?;
            let len = u16::from_be_bytes(head) as usize;
            // Reject rather than panic on a frame larger than the caller's buffer.
            if len > src.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "frame larger than buffer",
                ));
            }
            read.read_exact(&mut src[..len]).await?;
            Ok(len)
        }
    }

    #[async_trait]
    impl Encoder for MyCodeC {
        async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
            // A 2-byte header cannot describe more than u16::MAX bytes; don't truncate.
            if data.len() > u16::MAX as usize {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "frame too large for 2-byte length prefix",
                ));
            }
            let head: [u8; 2] = (data.len() as u16).to_be_bytes();
            write.write_all(&head).await?;
            write.write_all(data).await?;
            Ok(())
        }
    }
}