1use crate::route::{Index, RouteKey};
2use crate::socket::{connect_tcp, create_tcp_listener, LocalInterface};
3use crate::tunnel::config::TcpTunnelConfig;
4use crate::tunnel::recycle::RecycleBuf;
5use async_lock::Mutex;
6use async_trait::async_trait;
7use bytes::BytesMut;
8use dashmap::DashMap;
9use dyn_clone::DynClone;
10use rand::Rng;
11use std::io;
12use std::io::IoSlice;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16use tachyonix::{Receiver, Sender, TrySendError};
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
19use tokio::net::{TcpListener, TcpStream};
20
21pub struct TcpTunnelDispatcher {
22 route_idle_time: Duration,
23 tcp_listener: TcpListener,
24 connect_receiver: Receiver<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
25 #[allow(dead_code)]
26 pub(crate) socket_manager: Arc<TcpSocketManager>,
27 write_half_collect: WriteHalfCollect,
28 init_codec: Arc<Box<dyn InitCodec>>,
29}
30
31impl TcpTunnelDispatcher {
32 pub fn new(config: TcpTunnelConfig) -> io::Result<TcpTunnelDispatcher> {
34 config.check()?;
35 let address: SocketAddr = if config.use_v6 {
36 format!("[::]:{}", config.tcp_port).parse().unwrap()
37 } else {
38 format!("0.0.0.0:{}", config.tcp_port).parse().unwrap()
39 };
40
41 let tcp_listener = create_tcp_listener(address)?;
42 let local_addr = tcp_listener.local_addr()?;
43 let tcp_listener = TcpListener::from_std(tcp_listener)?;
44 let (connect_sender, connect_receiver) = tachyonix::channel(128);
45 let write_half_collect =
46 WriteHalfCollect::new(config.tcp_multiplexing_limit, config.recycle_buf);
47 let init_codec = Arc::new(config.init_codec);
48 let socket_manager = Arc::new(TcpSocketManager::new(
49 local_addr,
50 config.tcp_multiplexing_limit,
51 write_half_collect.clone(),
52 connect_sender,
53 config.default_interface,
54 init_codec.clone(),
55 ));
56 Ok(TcpTunnelDispatcher {
57 route_idle_time: config.route_idle_time,
58 tcp_listener,
59 connect_receiver,
60 socket_manager,
61 write_half_collect,
62 init_codec,
63 })
64 }
65}
66
67impl TcpTunnelDispatcher {
68 pub async fn dispatch(&mut self) -> io::Result<TcpTunnel> {
70 tokio::select! {
71 rs=self.connect_receiver.recv()=>{
72 let (route_key,read_half,sender) = rs.
73 map_err(|_| io::Error::other("connect_receiver done"))?;
74 let local_addr = read_half.read_half.local_addr()?;
75 Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
76 },
77 rs=self.tcp_listener.accept()=>{
78 let (tcp_stream,addr) = rs?;
79 tcp_stream.set_nodelay(true)?;
80 let route_key = tcp_stream.route_key()?;
81 let (read_half,write_half) = tcp_stream.into_split();
82 let (decoder,encoder) = self.init_codec.codec(addr)?;
83 let read_half = ReadHalfBox::new(read_half,decoder);
84 let sender = self.write_half_collect.add_write_half(route_key,0, write_half,encoder);
85 let local_addr = read_half.read_half.local_addr()?;
86 Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
87 }
88 }
89 }
90 pub fn manager(&self) -> &Arc<TcpSocketManager> {
91 &self.socket_manager
92 }
93}
94
/// One TCP connection as seen by the application: the decoded read side plus a
/// channel feeding the connection's writer task.
pub struct TcpTunnel {
    /// Local address of the socket, captured when the tunnel was created.
    local_addr: SocketAddr,
    /// Key identifying this connection (socket index + peer address).
    route_key: RouteKey,
    /// Timeout applied to each `recv` / `batch_recv` call.
    route_idle_time: Duration,
    /// Read half of the socket, consumed through `decoder`.
    tcp_read: OwnedReadHalf,
    /// Turns bytes read from `tcp_read` into application units.
    decoder: Box<dyn Decoder>,
    /// Queue to the writer task spawned by `WriteHalfCollect::add_write_half`;
    /// closed when the tunnel is dropped.
    sender: Sender<BytesMut>,
}
103
104impl TcpTunnel {
105 pub(crate) fn new(
106 local_addr: SocketAddr,
107 route_idle_time: Duration,
108 route_key: RouteKey,
109 read: ReadHalfBox,
110 sender: Sender<BytesMut>,
111 ) -> Self {
112 let decoder = read.decoder;
113 let tcp_read = read.read_half;
114 Self {
115 local_addr,
116 route_key,
117 route_idle_time,
118 tcp_read,
119 decoder,
120 sender,
121 }
122 }
123 #[inline]
124 pub fn route_key(&self) -> RouteKey {
125 self.route_key
126 }
127 pub fn local_addr(&self) -> SocketAddr {
128 self.local_addr
129 }
130 pub fn done(&mut self) {
131 self.sender.close();
132 }
133 pub fn sender(&self) -> io::Result<WeakTcpTunnelSender> {
134 Ok(WeakTcpTunnelSender::new(self.sender.clone()))
135 }
136}
137#[derive(Clone)]
138pub struct WeakTcpTunnelSender {
139 sender: Sender<BytesMut>,
140}
141impl WeakTcpTunnelSender {
142 fn new(sender: Sender<BytesMut>) -> Self {
143 Self { sender }
144 }
145 pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
146 if buf.is_empty() {
147 return Ok(());
148 }
149 self.sender
150 .send(buf)
151 .await
152 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
153 }
154}
155
156impl Drop for TcpTunnel {
157 fn drop(&mut self) {
158 self.done();
159 }
160}
161
162impl TcpTunnel {
163 pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
165 if buf.is_empty() {
166 return Ok(());
167 }
168 self.sender
169 .send(buf)
170 .await
171 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
172 }
173
174 pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175 match tokio::time::timeout(
176 self.route_idle_time,
177 self.decoder.decode(&mut self.tcp_read, buf),
178 )
179 .await
180 {
181 Ok(rs) => rs,
182 Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
183 }
184 }
185 pub async fn batch_recv<B: AsMut<[u8]>>(
186 &mut self,
187 bufs: &mut [B],
188 sizes: &mut [usize],
189 ) -> io::Result<usize> {
190 if bufs.is_empty() || bufs.len() != sizes.len() {
191 return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs error"));
192 }
193 match tokio::time::timeout(
194 self.route_idle_time,
195 self.decoder.decode(&mut self.tcp_read, bufs[0].as_mut()),
196 )
197 .await
198 {
199 Ok(rs) => {
200 let len = rs?;
201 sizes[0] = len;
202 let mut num = 1;
203 while num < bufs.len() {
204 let rs = self
205 .decoder
206 .try_decode(&mut self.tcp_read, bufs[num].as_mut());
207 match rs {
208 Ok(len) => {
209 sizes[num] = len;
210 num += 1;
211 }
212 Err(_e) => break,
213 }
214 }
215
216 Ok(num)
217 }
218 Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
219 }
220 }
221 pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, RouteKey)> {
225 match self.recv(buf).await {
226 Ok(len) => Ok((len, self.route_key())),
227 Err(e) => Err(e),
228 }
229 }
230 pub async fn batch_recv_from<B: AsMut<[u8]>>(
231 &mut self,
232 bufs: &mut [B],
233 sizes: &mut [usize],
234 ) -> io::Result<(usize, RouteKey)> {
235 match self.batch_recv(bufs, sizes).await {
236 Ok(len) => Ok((len, self.route_key())),
237 Err(e) => Err(e),
238 }
239 }
240}
241
/// Shared registry of per-connection writer channels; clones share the same maps
/// (it is cloned into the socket manager and into every writer task).
#[derive(Clone)]
pub struct WriteHalfCollect {
    /// Number of slots per peer address in `addr_mapping`.
    tcp_multiplexing_limit: usize,
    /// Peer address -> connection index per multiplexing slot; `0` marks an empty slot.
    /// NOTE(review): a connection whose raw socket index is 0 would be
    /// indistinguishable from an empty slot.
    addr_mapping: Arc<DashMap<SocketAddr, Vec<usize>>>,
    /// Connection index -> sender feeding that connection's writer task.
    write_half_map: Arc<DashMap<usize, Sender<BytesMut>>>,
    /// Optional pool that written buffers are handed back to.
    recycle_buf: Option<RecycleBuf>,
}
249
250impl WriteHalfCollect {
251 fn new(tcp_multiplexing_limit: usize, recycle_buf: Option<RecycleBuf>) -> Self {
252 Self {
253 tcp_multiplexing_limit,
254 addr_mapping: Default::default(),
255 write_half_map: Default::default(),
256 recycle_buf,
257 }
258 }
259}
260
/// Read half of a connection paired with the decoder that parses it; consumed by
/// `TcpTunnel::new`.
pub(crate) struct ReadHalfBox {
    /// Read side of the split TCP stream.
    read_half: OwnedReadHalf,
    /// Decoder produced by `InitCodec::codec` for this connection.
    decoder: Box<dyn Decoder>,
}
265
266impl ReadHalfBox {
267 pub(crate) fn new(read_half: OwnedReadHalf, decoder: Box<dyn Decoder>) -> Self {
268 Self { read_half, decoder }
269 }
270}
271
272impl WriteHalfCollect {
273 pub(crate) fn add_write_half(
274 &self,
275 route_key: RouteKey,
276 index_offset: usize,
277 mut writer: OwnedWriteHalf,
278 mut decoder: Box<dyn Encoder>,
279 ) -> Sender<BytesMut> {
280 assert!(index_offset < self.tcp_multiplexing_limit);
281
282 let index = route_key.index_usize();
283 let _ref = self
284 .addr_mapping
285 .entry(route_key.addr())
286 .and_modify(|v| {
287 v[index_offset] = index;
288 })
289 .or_insert_with(|| {
290 let mut v = vec![0; self.tcp_multiplexing_limit];
291 v[index_offset] = index;
292 v
293 });
294 let (s, mut r) = tachyonix::channel(128);
295 let sender = s.clone();
296 self.write_half_map.insert(index, s);
297 let collect = self.clone();
298 let recycle_buf = self.recycle_buf.clone();
299 tokio::spawn(async move {
300 let mut vec_buf = Vec::with_capacity(16);
301 const IO_SLICE_CAPACITY: usize = 16;
302 let mut io_buffer: Vec<IoSlice> = Vec::with_capacity(IO_SLICE_CAPACITY);
303 let io_slice_storage = io_buffer.as_mut_slice();
304 while let Ok(v) = r.recv().await {
305 if let Ok(buf) = r.try_recv() {
306 vec_buf.push(v);
307 vec_buf.push(buf);
308 while let Ok(buf) = r.try_recv() {
309 vec_buf.push(buf);
310 if vec_buf.len() == 16 {
311 break;
312 }
313 }
314
315 let rs = {
320 let mut vec = unsafe {
321 Vec::from_raw_parts(io_slice_storage.as_mut_ptr(), 0, IO_SLICE_CAPACITY)
322 };
323 for x in &vec_buf {
324 vec.push(IoSlice::new(x));
325 }
326 let rs = decoder.encode_multiple(&mut writer, &vec).await;
327 vec.clear();
328 std::mem::forget(vec);
329 rs
330 };
331
332 if let Some(recycle_buf) = recycle_buf.as_ref() {
333 while let Some(buf) = vec_buf.pop() {
334 recycle_buf.push(buf);
335 }
336 } else {
337 vec_buf.clear()
338 }
339 if let Err(e) = rs {
340 log::debug!("{route_key:?},{e:?}");
341 break;
342 }
343 } else {
344 let rs = decoder.encode(&mut writer, &v).await;
345 if let Some(recycle_buf) = recycle_buf.as_ref() {
346 recycle_buf.push(v);
347 }
348 if let Err(e) = rs {
349 log::debug!("{route_key:?},{e:?}");
350 break;
351 }
352 }
353 }
354 collect.remove(&route_key);
355 });
356 sender
357 }
358 pub(crate) fn remove(&self, route_key: &RouteKey) {
359 let index_usize = route_key.index_usize();
360 self.addr_mapping
361 .remove_if_mut(&route_key.addr(), |_k, index_vec| {
362 let mut remove = true;
363 for v in index_vec {
364 if *v == index_usize {
365 *v = 0;
366 }
367 if *v != 0 {
368 remove = false;
369 }
370 }
371 remove
372 });
373
374 self.write_half_map.remove(&index_usize);
375 }
376 pub(crate) fn get_write_half(&self, index: &usize) -> Option<Sender<BytesMut>> {
377 self.write_half_map.get(index).map(|v| v.value().clone())
378 }
379 pub(crate) fn get_write_half_by_key(
380 &self,
381 route_key: &RouteKey,
382 ) -> io::Result<Sender<BytesMut>> {
383 match route_key.index() {
384 Index::Tcp(index) => {
385 let sender = self
386 .get_write_half(&index)
387 .ok_or_else(|| io::Error::other(format!("not found {route_key:?}")))?;
388 Ok(sender)
389 }
390 _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
391 }
392 }
393
394 pub(crate) fn get_one_route_key(&self, addr: &SocketAddr) -> Option<RouteKey> {
395 if let Some(v) = self.addr_mapping.get(addr) {
396 for index_usize in v.value() {
397 if *index_usize != 0 {
398 return Some(RouteKey::new(Index::Tcp(*index_usize), *addr));
399 }
400 }
401 }
402 None
403 }
404 pub(crate) fn get_limit_route_key(&self, index: usize, addr: &SocketAddr) -> Option<RouteKey> {
405 if let Some(v) = self.addr_mapping.get(addr) {
406 assert_eq!(v.len(), self.tcp_multiplexing_limit);
407 let index_usize = v[index];
408 if index_usize == 0 {
409 return None;
410 }
411 return Some(RouteKey::new(Index::Tcp(index_usize), *addr));
412 }
413 None
414 }
415 pub async fn send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
416 let write_half = self.get_write_half_by_key(route_key)?;
417 if buf.is_empty() {
418 return Ok(());
419 }
420 if let Err(_e) = write_half.send(buf).await {
421 Err(io::Error::from(io::ErrorKind::WriteZero))
422 } else {
423 Ok(())
424 }
425 }
426 pub fn try_send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
427 let write_half = self.get_write_half_by_key(route_key)?;
428 if buf.is_empty() {
429 return Ok(());
430 }
431 if let Err(e) = write_half.try_send(buf) {
432 match e {
433 TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
434 TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
435 }
436 } else {
437 Ok(())
438 }
439 }
440}
441
/// Opens and looks up outbound TCP connections and sends data over registered ones.
pub struct TcpSocketManager {
    /// Per-destination mutex serializing connection setup towards that address.
    /// NOTE(review): no code in this file removes entries from this map.
    lock: DashMap<SocketAddr, Arc<Mutex<()>>>,
    /// Local address of the dispatcher's listener; its port is reused by the
    /// `connect_reuse_port*` methods.
    local_addr: SocketAddr,
    /// Number of multiplexing slots per destination address.
    tcp_multiplexing_limit: usize,
    /// Registry of per-connection writer channels, shared with the dispatcher.
    write_half_collect: WriteHalfCollect,
    /// Hands newly connected read halves to `TcpTunnelDispatcher::dispatch`.
    connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
    /// Interface passed to `connect_tcp` for outbound connections.
    default_interface: Option<LocalInterface>,
    /// Factory for the per-connection decoder/encoder pair.
    init_codec: Arc<Box<dyn InitCodec>>,
}
451
452impl TcpSocketManager {
453 pub(crate) fn new(
454 local_addr: SocketAddr,
455 tcp_multiplexing_limit: usize,
456 write_half_collect: WriteHalfCollect,
457 connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
458 default_interface: Option<LocalInterface>,
459 init_codec: Arc<Box<dyn InitCodec>>,
460 ) -> Self {
461 Self {
462 local_addr,
463 lock: Default::default(),
464 tcp_multiplexing_limit,
465 write_half_collect,
466 connect_sender,
467 default_interface,
468 init_codec,
469 }
470 }
471 pub fn local_addr(&self) -> SocketAddr {
472 self.local_addr
473 }
474}
475
476impl TcpSocketManager {
477 pub async fn multi_connect(
479 &self,
480 addr: SocketAddr,
481 index_offset: usize,
482 ) -> io::Result<RouteKey> {
483 self.multi_connect_impl(addr, index_offset, None).await
484 }
485 pub(crate) async fn multi_connect_impl(
486 &self,
487 addr: SocketAddr,
488 index_offset: usize,
489 ttl: Option<u8>,
490 ) -> io::Result<RouteKey> {
491 let len = self.tcp_multiplexing_limit;
492 if index_offset >= len {
493 return Err(io::Error::new(
494 io::ErrorKind::InvalidInput,
495 "index out of bounds",
496 ));
497 }
498 let lock = self
499 .lock
500 .entry(addr)
501 .or_insert_with(|| Arc::new(Mutex::new(())))
502 .value()
503 .clone();
504 let _guard = lock.lock().await;
505 if let Some(route_key) = self
506 .write_half_collect
507 .get_limit_route_key(index_offset, &addr)
508 {
509 return Ok(route_key);
510 }
511 self.connect_impl(0, addr, index_offset, ttl).await
512 }
513 pub async fn connect(&self, addr: SocketAddr) -> io::Result<RouteKey> {
515 let lock = self
516 .lock
517 .entry(addr)
518 .or_insert_with(|| Arc::new(Mutex::new(())))
519 .value()
520 .clone();
521 let _guard = lock.lock().await;
522 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
523 return Ok(route_key);
524 }
525 self.connect_impl(0, addr, 0, None).await
526 }
527 pub async fn connect_ttl(&self, addr: SocketAddr, ttl: Option<u8>) -> io::Result<RouteKey> {
528 let lock = self
529 .lock
530 .entry(addr)
531 .or_insert_with(|| Arc::new(Mutex::new(())))
532 .value()
533 .clone();
534 let _guard = lock.lock().await;
535 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
536 return Ok(route_key);
537 }
538 self.connect_impl(0, addr, 0, ttl).await
539 }
540 pub async fn connect_reuse_port(&self, addr: SocketAddr) -> io::Result<RouteKey> {
542 let lock = self
543 .lock
544 .entry(addr)
545 .or_insert_with(|| Arc::new(Mutex::new(())))
546 .value()
547 .clone();
548 let _guard = lock.lock().await;
549 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
550 return Ok(route_key);
551 }
552 self.connect_impl(self.local_addr.port(), addr, 0, None)
553 .await
554 }
555 pub async fn connect_reuse_port_raw(&self, addr: SocketAddr) -> io::Result<TcpStream> {
556 let stream = connect_tcp(
557 addr,
558 self.local_addr.port(),
559 self.default_interface.as_ref(),
560 None,
561 )
562 .await?;
563 Ok(stream)
564 }
565 async fn connect_impl(
566 &self,
567 bind_port: u16,
568 addr: SocketAddr,
569 index_offset: usize,
570 ttl: Option<u8>,
571 ) -> io::Result<RouteKey> {
572 let stream = connect_tcp(addr, bind_port, self.default_interface.as_ref(), ttl).await?;
573 let route_key = stream.route_key()?;
574 let (read_half, write_half) = stream.into_split();
575 let (decoder, encoder) = self.init_codec.codec(addr)?;
576 let read_half = ReadHalfBox::new(read_half, decoder);
577 let sender =
578 self.write_half_collect
579 .add_write_half(route_key, index_offset, write_half, encoder);
580 if let Err(_e) = self
581 .connect_sender
582 .send((route_key, read_half, sender))
583 .await
584 {
585 Err(io::Error::new(
586 io::ErrorKind::UnexpectedEof,
587 "connect close",
588 ))?
589 }
590 Ok(route_key)
591 }
592}
593
594impl TcpSocketManager {
595 pub async fn multi_send_to<A: Into<SocketAddr>>(
596 &self,
597 buf: BytesMut,
598 addr: A,
599 ) -> io::Result<()> {
600 self.multi_send_to_impl(buf, addr, None).await
601 }
602 pub(crate) async fn multi_send_to_impl<A: Into<SocketAddr>>(
603 &self,
604 buf: BytesMut,
605 addr: A,
606 ttl: Option<u8>,
607 ) -> io::Result<()> {
608 let index_offset = rand::rng().random_range(0..self.tcp_multiplexing_limit);
609 let route_key = self
610 .multi_connect_impl(addr.into(), index_offset, ttl)
611 .await?;
612 self.send_to(buf, &route_key).await
613 }
614 pub async fn reuse_port_send_to<A: Into<SocketAddr>>(
616 &self,
617 buf: BytesMut,
618 addr: A,
619 ) -> io::Result<()> {
620 let route_key = self.connect_reuse_port(addr.into()).await?;
621 self.send_to(buf, &route_key).await
622 }
623
624 pub async fn send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
626 let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
627 self.write_half_collect.send_to(buf, &route_key).await
628 }
629 pub async fn send_to_addr<D: Into<SocketAddr>>(
630 &self,
631 buf: BytesMut,
632 dest: D,
633 ) -> io::Result<()> {
634 let route_key = self.connect(dest.into()).await?;
635 self.write_half_collect.send_to(buf, &route_key).await
636 }
637 pub fn try_send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
638 let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
639 self.write_half_collect.try_send_to(buf, &route_key)
640 }
641}
/// Conversion of a send destination into a concrete [`RouteKey`], used by
/// `TcpSocketManager::send_to` and `try_send_to`.
///
/// NOTE(review): `T` is not used by the method; every impl in this file uses `()`,
/// so it appears to be only a disambiguation marker. Confirm before relying on it.
pub trait ToRouteKeyForTcp<T> {
    /// Resolves `Self` against the manager; the impls in this file return the key unchanged.
    fn route_key(_: &TcpSocketManager, _: Self) -> io::Result<RouteKey>;
}
645
646impl ToRouteKeyForTcp<()> for RouteKey {
647 fn route_key(_: &TcpSocketManager, dest: RouteKey) -> io::Result<RouteKey> {
648 Ok(dest)
649 }
650}
651
652impl ToRouteKeyForTcp<()> for &RouteKey {
653 fn route_key(_: &TcpSocketManager, dest: &RouteKey) -> io::Result<RouteKey> {
654 Ok(*dest)
655 }
656}
657
658impl ToRouteKeyForTcp<()> for &mut RouteKey {
659 fn route_key(_: &TcpSocketManager, dest: &mut RouteKey) -> io::Result<RouteKey> {
660 Ok(*dest)
661 }
662}
/// Derives this module's connection identifiers from a TCP stream.
pub trait TcpStreamIndex {
    /// Route key made of [`TcpStreamIndex::index`] and the stream's peer address.
    fn route_key(&self) -> io::Result<RouteKey>;
    /// `Index::Tcp` built from the raw socket handle (fd on unix, SOCKET on windows).
    fn index(&self) -> Index;
}
673
674impl TcpStreamIndex for TcpStream {
675 fn route_key(&self) -> io::Result<RouteKey> {
676 let addr = self.peer_addr()?;
677
678 Ok(RouteKey::new(self.index(), addr))
679 }
680
681 fn index(&self) -> Index {
682 #[cfg(windows)]
683 use std::os::windows::io::AsRawSocket;
684 #[cfg(windows)]
685 let index = self.as_raw_socket() as usize;
686 #[cfg(unix)]
687 use std::os::fd::{AsFd, AsRawFd};
688 #[cfg(unix)]
689 let index = self.as_fd().as_raw_fd() as usize;
690 Index::Tcp(index)
691 }
692}
693
/// Unframed codec: each `decode` returns whatever one socket read yields and
/// `encode` writes the bytes as is.
pub struct BytesCodec;

/// Codec framing every unit with a 4-byte big-endian length prefix.
pub struct LengthPrefixedCodec;

700#[async_trait]
701impl Decoder for BytesCodec {
702 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
703 let len = read.read(src).await?;
704 Ok(len)
705 }
706}
707
708#[async_trait]
709impl Encoder for BytesCodec {
710 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
711 write.write_all(data).await?;
712 Ok(())
713 }
714}
715
716#[async_trait]
717impl Decoder for LengthPrefixedCodec {
718 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
719 let mut head = [0; 4];
720 read.read_exact(&mut head).await?;
721 let len = u32::from_be_bytes(head) as usize;
722 read.read_exact(&mut src[..len]).await?;
723 Ok(len)
724 }
725}
726
727#[async_trait]
728impl Encoder for LengthPrefixedCodec {
729 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
730 let head: [u8; 4] = (data.len() as u32).to_be_bytes();
731 write.write_all(&head).await?;
732 write.write_all(data).await?;
733 Ok(())
734 }
735}
736
737#[derive(Clone)]
738pub struct BytesInitCodec;
739
740impl InitCodec for BytesInitCodec {
741 fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
742 Ok((Box::new(BytesCodec), Box::new(BytesCodec)))
743 }
744}
745
746#[derive(Clone)]
747pub struct LengthPrefixedInitCodec;
748
749impl InitCodec for LengthPrefixedInitCodec {
750 fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
751 Ok((Box::new(LengthPrefixedCodec), Box::new(LengthPrefixedCodec)))
752 }
753}
754
/// Factory for the decoder/encoder pair of each TCP connection.
///
/// `TcpTunnelDispatcher::dispatch` and `TcpSocketManager::connect_impl` call it once
/// per connection. The decoder goes with the read half; the encoder goes to the
/// connection's writer task. `DynClone` makes `Box<dyn InitCodec>` cloneable.
pub trait InitCodec: Send + Sync + DynClone {
    /// Returns `(decoder, encoder)` for a connection with peer `addr`.
    fn codec(&self, addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)>;
}
dyn_clone::clone_trait_object!(InitCodec);
759
/// Reads application units from a connection's read half.
#[async_trait]
pub trait Decoder: Send + Sync {
    /// Reads the next unit into `src` and returns the number of bytes written to it.
    /// `TcpTunnel::recv`/`batch_recv` await this under a timeout, so it may be
    /// cancelled part-way through a read.
    async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize>;
    /// Non-waiting variant used by `TcpTunnel::batch_recv` to collect additional
    /// units; any error ends the batch. The default always returns `WouldBlock`,
    /// so batching yields one unit unless a codec overrides it.
    fn try_decode(&mut self, _read: &mut OwnedReadHalf, _src: &mut [u8]) -> io::Result<usize> {
        Err(io::Error::from(io::ErrorKind::WouldBlock))
    }
}
767
/// Writes application units to a connection's write half; driven by the writer
/// task spawned in `WriteHalfCollect::add_write_half`.
#[async_trait]
pub trait Encoder: Send + Sync {
    /// Writes one unit.
    async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()>;
    /// Writes several queued units; called when more than one buffer is pending.
    /// The default encodes them one by one and stops at the first error.
    async fn encode_multiple(
        &mut self,
        write: &mut OwnedWriteHalf,
        bufs: &[IoSlice<'_>],
    ) -> io::Result<()> {
        for buf in bufs {
            self.encode(write, buf).await?
        }
        Ok(())
    }
}
782
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use std::io;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

    use crate::tunnel::config::TcpTunnelConfig;
    use crate::tunnel::tcp::{Decoder, Encoder, InitCodec, TcpTunnelDispatcher};

    #[tokio::test]
    pub async fn create_tcp_tunnel() {
        let config: TcpTunnelConfig = TcpTunnelConfig::default();
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    #[tokio::test]
    pub async fn create_codec_tcp_tunnel() {
        let config = TcpTunnelConfig::new(Box::new(MyInitCodeC));
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// Sample custom codec factory using a 2-byte length prefix.
    #[derive(Clone)]
    struct MyInitCodeC;

    impl InitCodec for MyInitCodeC {
        fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
            Ok((Box::new(MyCodeC), Box::new(MyCodeC)))
        }
    }

    /// Frames each unit with a 2-byte big-endian length prefix.
    struct MyCodeC;

    #[async_trait]
    impl Decoder for MyCodeC {
        async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
            let mut head = [0; 2];
            read.read_exact(&mut head).await?;
            let len = u16::from_be_bytes(head) as usize;
            // The length is peer-controlled: reject it instead of panicking on slicing.
            let frame = src.get_mut(..len).ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidData, "frame larger than receive buffer")
            })?;
            read.read_exact(frame).await?;
            Ok(len)
        }
    }

    #[async_trait]
    impl Encoder for MyCodeC {
        async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
            // `as u16` would silently truncate payloads over 64 KiB and corrupt framing.
            let len = u16::try_from(data.len()).map_err(|_| {
                io::Error::new(io::ErrorKind::InvalidInput, "frame too large for a u16 prefix")
            })?;
            write.write_all(&len.to_be_bytes()).await?;
            write.write_all(data).await?;
            Ok(())
        }
    }
}