1use crate::route::{Index, RouteKey};
2use crate::socket::{connect_tcp, create_tcp_listener, LocalInterface};
3use crate::tunnel::config::TcpTunnelConfig;
4use crate::tunnel::recycle::RecycleBuf;
5use async_lock::Mutex;
6use async_trait::async_trait;
7use bytes::BytesMut;
8use dashmap::DashMap;
9use dyn_clone::DynClone;
10use rand::Rng;
11use std::io;
12use std::io::IoSlice;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16use tachyonix::{Receiver, Sender, TrySendError};
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
19use tokio::net::{TcpListener, TcpStream};
20
/// Accepts inbound TCP connections and yields a [`TcpTunnel`] for each one,
/// as well as for outbound connections opened through [`TcpSocketManager`].
pub struct TcpTunnelDispatcher {
    // Idle timeout handed to every tunnel; applied to its receive calls.
    route_idle_time: Duration,
    tcp_listener: TcpListener,
    // Outbound connections opened by `TcpSocketManager::connect_impl` arrive
    // here so `dispatch` can surface them next to accepted ones.
    connect_receiver: Receiver<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
    #[allow(dead_code)]
    pub(crate) socket_manager: Arc<TcpSocketManager>,
    // Registry of per-connection writer tasks, shared with `socket_manager`.
    write_half_collect: WriteHalfCollect,
    // Creates the decoder/encoder pair for each accepted connection.
    init_codec: Arc<Box<dyn InitCodec>>,
}
30
31impl TcpTunnelDispatcher {
32 pub fn new(config: TcpTunnelConfig) -> io::Result<TcpTunnelDispatcher> {
34 config.check()?;
35 let address: SocketAddr = if config.use_v6 {
36 format!("[::]:{}", config.tcp_port).parse().unwrap()
37 } else {
38 format!("0.0.0.0:{}", config.tcp_port).parse().unwrap()
39 };
40
41 let tcp_listener = create_tcp_listener(address)?;
42 let local_addr = tcp_listener.local_addr()?;
43 let tcp_listener = TcpListener::from_std(tcp_listener)?;
44 let (connect_sender, connect_receiver) = tachyonix::channel(128);
45 let write_half_collect =
46 WriteHalfCollect::new(config.tcp_multiplexing_limit, config.recycle_buf);
47 let init_codec = Arc::new(config.init_codec);
48 let socket_manager = Arc::new(TcpSocketManager::new(
49 local_addr,
50 config.tcp_multiplexing_limit,
51 write_half_collect.clone(),
52 connect_sender,
53 config.default_interface,
54 init_codec.clone(),
55 ));
56 Ok(TcpTunnelDispatcher {
57 route_idle_time: config.route_idle_time,
58 tcp_listener,
59 connect_receiver,
60 socket_manager,
61 write_half_collect,
62 init_codec,
63 })
64 }
65}
66
67impl TcpTunnelDispatcher {
68 pub async fn dispatch(&mut self) -> io::Result<TcpTunnel> {
70 tokio::select! {
71 rs=self.connect_receiver.recv()=>{
72 let (route_key,read_half,sender) = rs.
73 map_err(|_| io::Error::new(io::ErrorKind::Other,"connect_receiver done"))?;
74 let local_addr = read_half.read_half.local_addr()?;
75 Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
76 },
77 rs=self.tcp_listener.accept()=>{
78 let (tcp_stream,addr) = rs?;
79 tcp_stream.set_nodelay(true)?;
80 let route_key = tcp_stream.route_key()?;
81 let (read_half,write_half) = tcp_stream.into_split();
82 let (decoder,encoder) = self.init_codec.codec(addr)?;
83 let read_half = ReadHalfBox::new(read_half,decoder);
84 let sender = self.write_half_collect.add_write_half(route_key,0, write_half,encoder);
85 let local_addr = read_half.read_half.local_addr()?;
86 Ok(TcpTunnel::new(local_addr,self.route_idle_time,route_key,read_half,sender))
87 }
88 }
89 }
90 pub fn manager(&self) -> &Arc<TcpSocketManager> {
91 &self.socket_manager
92 }
93}
94
/// A single TCP connection.
///
/// Frames are read directly from the read half through the `recv*` methods.
/// Writes are queued on `sender` and performed by the writer task that was
/// registered for this connection.
pub struct TcpTunnel {
    local_addr: SocketAddr,
    route_key: RouteKey,
    // Timeout applied to each blocking receive.
    route_idle_time: Duration,
    tcp_read: OwnedReadHalf,
    // Frames bytes read from `tcp_read`.
    decoder: Box<dyn Decoder>,
    // Queue feeding the writer task; closed by `done()` / on drop.
    sender: Sender<BytesMut>,
}
103
104impl TcpTunnel {
105 pub(crate) fn new(
106 local_addr: SocketAddr,
107 route_idle_time: Duration,
108 route_key: RouteKey,
109 read: ReadHalfBox,
110 sender: Sender<BytesMut>,
111 ) -> Self {
112 let decoder = read.decoder;
113 let tcp_read = read.read_half;
114 Self {
115 local_addr,
116 route_key,
117 route_idle_time,
118 tcp_read,
119 decoder,
120 sender,
121 }
122 }
123 #[inline]
124 pub fn route_key(&self) -> RouteKey {
125 self.route_key
126 }
127 pub fn local_addr(&self) -> SocketAddr {
128 self.local_addr
129 }
130 pub fn done(&mut self) {
131 self.sender.close();
132 }
133 pub fn sender(&self) -> io::Result<WeakTcpTunnelSender> {
134 Ok(WeakTcpTunnelSender::new(self.sender.clone()))
135 }
136}
137#[derive(Clone)]
138pub struct WeakTcpTunnelSender {
139 sender: Sender<BytesMut>,
140}
141impl WeakTcpTunnelSender {
142 fn new(sender: Sender<BytesMut>) -> Self {
143 Self { sender }
144 }
145 pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
146 if buf.is_empty() {
147 return Ok(());
148 }
149 self.sender
150 .send(buf)
151 .await
152 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
153 }
154}
155
156impl Drop for TcpTunnel {
157 fn drop(&mut self) {
158 self.done();
159 }
160}
161
162impl TcpTunnel {
163 pub async fn send(&self, buf: BytesMut) -> io::Result<()> {
165 if buf.is_empty() {
166 return Ok(());
167 }
168 self.sender
169 .send(buf)
170 .await
171 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
172 }
173
174 pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175 match tokio::time::timeout(
176 self.route_idle_time,
177 self.decoder.decode(&mut self.tcp_read, buf),
178 )
179 .await
180 {
181 Ok(rs) => rs,
182 Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
183 }
184 }
185 pub async fn batch_recv<B: AsMut<[u8]>>(
186 &mut self,
187 bufs: &mut [B],
188 sizes: &mut [usize],
189 ) -> io::Result<usize> {
190 if bufs.is_empty() || bufs.len() != sizes.len() {
191 return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs error"));
192 }
193 match tokio::time::timeout(
194 self.route_idle_time,
195 self.decoder.decode(&mut self.tcp_read, bufs[0].as_mut()),
196 )
197 .await
198 {
199 Ok(rs) => {
200 let len = rs?;
201 sizes[0] = len;
202 let mut num = 1;
203 while num < bufs.len() {
204 let rs = self
205 .decoder
206 .try_decode(&mut self.tcp_read, bufs[num].as_mut());
207 match rs {
208 Ok(len) => {
209 sizes[num] = len;
210 num += 1;
211 }
212 Err(_e) => break,
213 }
214 }
215
216 Ok(num)
217 }
218 Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
219 }
220 }
221 pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, RouteKey)> {
225 match self.recv(buf).await {
226 Ok(len) => Ok((len, self.route_key())),
227 Err(e) => Err(e),
228 }
229 }
230 pub async fn batch_recv_from<B: AsMut<[u8]>>(
231 &mut self,
232 bufs: &mut [B],
233 sizes: &mut [usize],
234 ) -> io::Result<(usize, RouteKey)> {
235 match self.batch_recv(bufs, sizes).await {
236 Ok(len) => Ok((len, self.route_key())),
237 Err(e) => Err(e),
238 }
239 }
240}
241
/// Shared registry of per-connection writer tasks. Clones share the same maps.
#[derive(Clone)]
pub struct WriteHalfCollect {
    // Length of each address's slot vector in `addr_mapping`.
    tcp_multiplexing_limit: usize,
    // Peer address -> slot vector of route indices; 0 marks an empty slot.
    addr_mapping: Arc<DashMap<SocketAddr, Vec<usize>>>,
    // Route index -> sender feeding that connection's writer task.
    write_half_map: Arc<DashMap<usize, Sender<BytesMut>>>,
    // Optional pool that written buffers are pushed back into.
    recycle_buf: Option<RecycleBuf>,
}
249
250impl WriteHalfCollect {
251 fn new(tcp_multiplexing_limit: usize, recycle_buf: Option<RecycleBuf>) -> Self {
252 Self {
253 tcp_multiplexing_limit,
254 addr_mapping: Default::default(),
255 write_half_map: Default::default(),
256 recycle_buf,
257 }
258 }
259}
260
261pub(crate) struct ReadHalfBox {
262 read_half: OwnedReadHalf,
263 decoder: Box<dyn Decoder>,
264}
265
266impl ReadHalfBox {
267 pub(crate) fn new(read_half: OwnedReadHalf, decoder: Box<dyn Decoder>) -> Self {
268 Self { read_half, decoder }
269 }
270}
271
272impl WriteHalfCollect {
273 pub(crate) fn add_write_half(
274 &self,
275 route_key: RouteKey,
276 index_offset: usize,
277 mut writer: OwnedWriteHalf,
278 mut decoder: Box<dyn Encoder>,
279 ) -> Sender<BytesMut> {
280 assert!(index_offset < self.tcp_multiplexing_limit);
281
282 let index = route_key.index_usize();
283 let _ref = self
284 .addr_mapping
285 .entry(route_key.addr())
286 .and_modify(|v| {
287 v[index_offset] = index;
288 })
289 .or_insert_with(|| {
290 let mut v = vec![0; self.tcp_multiplexing_limit];
291 v[index_offset] = index;
292 v
293 });
294 let (s, mut r) = tachyonix::channel(128);
295 let sender = s.clone();
296 self.write_half_map.insert(index, s);
297 let collect = self.clone();
298 let recycle_buf = self.recycle_buf.clone();
299 tokio::spawn(async move {
300 let mut vec_buf = Vec::with_capacity(16);
301 const IO_SLICE_CAPACITY: usize = 16;
302 let mut io_buffer: Vec<IoSlice> = Vec::with_capacity(IO_SLICE_CAPACITY);
303 let io_slice_storage = io_buffer.as_mut_slice();
304 while let Ok(v) = r.recv().await {
305 if let Ok(buf) = r.try_recv() {
306 vec_buf.push(v);
307 vec_buf.push(buf);
308 while let Ok(buf) = r.try_recv() {
309 vec_buf.push(buf);
310 if vec_buf.len() == 16 {
311 break;
312 }
313 }
314
315 let rs = {
320 let mut vec = unsafe {
321 Vec::from_raw_parts(io_slice_storage.as_mut_ptr(), 0, IO_SLICE_CAPACITY)
322 };
323 for x in &vec_buf {
324 vec.push(IoSlice::new(x));
325 }
326 let rs = decoder.encode_multiple(&mut writer, &vec).await;
327 vec.clear();
328 std::mem::forget(vec);
329 rs
330 };
331
332 if let Some(recycle_buf) = recycle_buf.as_ref() {
333 while let Some(buf) = vec_buf.pop() {
334 recycle_buf.push(buf);
335 }
336 } else {
337 vec_buf.clear()
338 }
339 if let Err(e) = rs {
340 log::debug!("{route_key:?},{e:?}");
341 break;
342 }
343 } else {
344 let rs = decoder.encode(&mut writer, &v).await;
345 if let Some(recycle_buf) = recycle_buf.as_ref() {
346 recycle_buf.push(v);
347 }
348 if let Err(e) = rs {
349 log::debug!("{route_key:?},{e:?}");
350 break;
351 }
352 }
353 }
354 collect.remove(&route_key);
355 });
356 sender
357 }
358 pub(crate) fn remove(&self, route_key: &RouteKey) {
359 let index_usize = route_key.index_usize();
360 self.addr_mapping
361 .remove_if_mut(&route_key.addr(), |_k, index_vec| {
362 let mut remove = true;
363 for v in index_vec {
364 if *v == index_usize {
365 *v = 0;
366 }
367 if *v != 0 {
368 remove = false;
369 }
370 }
371 remove
372 });
373
374 self.write_half_map.remove(&index_usize);
375 }
376 pub(crate) fn get_write_half(&self, index: &usize) -> Option<Sender<BytesMut>> {
377 self.write_half_map.get(index).map(|v| v.value().clone())
378 }
379 pub(crate) fn get_write_half_by_key(
380 &self,
381 route_key: &RouteKey,
382 ) -> io::Result<Sender<BytesMut>> {
383 match route_key.index() {
384 Index::Tcp(index) => {
385 let sender = self.get_write_half(&index).ok_or_else(|| {
386 io::Error::new(io::ErrorKind::Other, format!("not found {route_key:?}"))
387 })?;
388 Ok(sender)
389 }
390 _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
391 }
392 }
393
394 pub(crate) fn get_one_route_key(&self, addr: &SocketAddr) -> Option<RouteKey> {
395 if let Some(v) = self.addr_mapping.get(addr) {
396 for index_usize in v.value() {
397 if *index_usize != 0 {
398 return Some(RouteKey::new(Index::Tcp(*index_usize), *addr));
399 }
400 }
401 }
402 None
403 }
404 pub(crate) fn get_limit_route_key(&self, index: usize, addr: &SocketAddr) -> Option<RouteKey> {
405 if let Some(v) = self.addr_mapping.get(addr) {
406 assert_eq!(v.len(), self.tcp_multiplexing_limit);
407 let index_usize = v[index];
408 if index_usize == 0 {
409 return None;
410 }
411 return Some(RouteKey::new(Index::Tcp(index_usize), *addr));
412 }
413 None
414 }
415 pub async fn send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
416 let write_half = self.get_write_half_by_key(route_key)?;
417 if buf.is_empty() {
418 return Ok(());
419 }
420 if let Err(_e) = write_half.send(buf).await {
421 Err(io::Error::from(io::ErrorKind::WriteZero))
422 } else {
423 Ok(())
424 }
425 }
426 pub fn try_send_to(&self, buf: BytesMut, route_key: &RouteKey) -> io::Result<()> {
427 let write_half = self.get_write_half_by_key(route_key)?;
428 if buf.is_empty() {
429 return Ok(());
430 }
431 if let Err(e) = write_half.try_send(buf) {
432 match e {
433 TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
434 TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
435 }
436 } else {
437 Ok(())
438 }
439 }
440}
441
/// Opens outbound TCP connections and sends data over registered connections.
pub struct TcpSocketManager {
    // Serializes the lookup-then-connect sequence in the connect methods.
    lock: Mutex<()>,
    // As built by `TcpTunnelDispatcher::new`, the listener's local address;
    // its port is reused by the `connect_reuse_port*` methods.
    local_addr: SocketAddr,
    tcp_multiplexing_limit: usize,
    write_half_collect: WriteHalfCollect,
    // Hands newly connected read halves to the dispatcher.
    connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
    default_interface: Option<LocalInterface>,
    init_codec: Arc<Box<dyn InitCodec>>,
}
451
452impl TcpSocketManager {
453 pub(crate) fn new(
454 local_addr: SocketAddr,
455 tcp_multiplexing_limit: usize,
456 write_half_collect: WriteHalfCollect,
457 connect_sender: Sender<(RouteKey, ReadHalfBox, Sender<BytesMut>)>,
458 default_interface: Option<LocalInterface>,
459 init_codec: Arc<Box<dyn InitCodec>>,
460 ) -> Self {
461 Self {
462 local_addr,
463 lock: Default::default(),
464 tcp_multiplexing_limit,
465 write_half_collect,
466 connect_sender,
467 default_interface,
468 init_codec,
469 }
470 }
471 pub fn local_addr(&self) -> SocketAddr {
472 self.local_addr
473 }
474}
475
476impl TcpSocketManager {
477 pub async fn multi_connect(
479 &self,
480 addr: SocketAddr,
481 index_offset: usize,
482 ) -> io::Result<RouteKey> {
483 self.multi_connect_impl(addr, index_offset, None).await
484 }
485 pub(crate) async fn multi_connect_impl(
486 &self,
487 addr: SocketAddr,
488 index_offset: usize,
489 ttl: Option<u8>,
490 ) -> io::Result<RouteKey> {
491 let len = self.tcp_multiplexing_limit;
492 if index_offset >= len {
493 return Err(io::Error::new(
494 io::ErrorKind::InvalidInput,
495 "index out of bounds",
496 ));
497 }
498 let _guard = self.lock.lock().await;
499 if let Some(route_key) = self
500 .write_half_collect
501 .get_limit_route_key(index_offset, &addr)
502 {
503 return Ok(route_key);
504 }
505 self.connect_impl(0, addr, index_offset, ttl).await
506 }
507 pub async fn connect(&self, addr: SocketAddr) -> io::Result<RouteKey> {
509 let _guard = self.lock.lock().await;
510 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
511 return Ok(route_key);
512 }
513 self.connect_impl(0, addr, 0, None).await
514 }
515 pub async fn connect_ttl(&self, addr: SocketAddr, ttl: Option<u8>) -> io::Result<RouteKey> {
516 let _guard = self.lock.lock().await;
517 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
518 return Ok(route_key);
519 }
520 self.connect_impl(0, addr, 0, ttl).await
521 }
522 pub async fn connect_reuse_port(&self, addr: SocketAddr) -> io::Result<RouteKey> {
524 let _guard = self.lock.lock().await;
525 if let Some(route_key) = self.write_half_collect.get_one_route_key(&addr) {
526 return Ok(route_key);
527 }
528 self.connect_impl(self.local_addr.port(), addr, 0, None)
529 .await
530 }
531 pub async fn connect_reuse_port_raw(&self, addr: SocketAddr) -> io::Result<TcpStream> {
532 let stream = connect_tcp(
533 addr,
534 self.local_addr.port(),
535 self.default_interface.as_ref(),
536 None,
537 )
538 .await?;
539 Ok(stream)
540 }
541 async fn connect_impl(
542 &self,
543 bind_port: u16,
544 addr: SocketAddr,
545 index_offset: usize,
546 ttl: Option<u8>,
547 ) -> io::Result<RouteKey> {
548 let stream = connect_tcp(addr, bind_port, self.default_interface.as_ref(), ttl).await?;
549 let route_key = stream.route_key()?;
550 let (read_half, write_half) = stream.into_split();
551 let (decoder, encoder) = self.init_codec.codec(addr)?;
552 let read_half = ReadHalfBox::new(read_half, decoder);
553 let sender =
554 self.write_half_collect
555 .add_write_half(route_key, index_offset, write_half, encoder);
556 if let Err(_e) = self
557 .connect_sender
558 .send((route_key, read_half, sender))
559 .await
560 {
561 Err(io::Error::new(
562 io::ErrorKind::UnexpectedEof,
563 "connect close",
564 ))?
565 }
566 Ok(route_key)
567 }
568}
569
570impl TcpSocketManager {
571 pub async fn multi_send_to<A: Into<SocketAddr>>(
572 &self,
573 buf: BytesMut,
574 addr: A,
575 ) -> io::Result<()> {
576 self.multi_send_to_impl(buf, addr, None).await
577 }
578 pub(crate) async fn multi_send_to_impl<A: Into<SocketAddr>>(
579 &self,
580 buf: BytesMut,
581 addr: A,
582 ttl: Option<u8>,
583 ) -> io::Result<()> {
584 let index_offset = rand::rng().random_range(0..self.tcp_multiplexing_limit);
585 let route_key = self
586 .multi_connect_impl(addr.into(), index_offset, ttl)
587 .await?;
588 self.send_to(buf, &route_key).await
589 }
590 pub async fn reuse_port_send_to<A: Into<SocketAddr>>(
592 &self,
593 buf: BytesMut,
594 addr: A,
595 ) -> io::Result<()> {
596 let route_key = self.connect_reuse_port(addr.into()).await?;
597 self.send_to(buf, &route_key).await
598 }
599
600 pub async fn send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
602 let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
603 self.write_half_collect.send_to(buf, &route_key).await
604 }
605 pub async fn send_to_addr<D: Into<SocketAddr>>(
606 &self,
607 buf: BytesMut,
608 dest: D,
609 ) -> io::Result<()> {
610 let route_key = self.connect(dest.into()).await?;
611 self.write_half_collect.send_to(buf, &route_key).await
612 }
613 pub fn try_send_to<D: ToRouteKeyForTcp<()>>(&self, buf: BytesMut, dest: D) -> io::Result<()> {
614 let route_key = ToRouteKeyForTcp::route_key(self, dest)?;
615 self.write_half_collect.try_send_to(buf, &route_key)
616 }
617}
618pub trait ToRouteKeyForTcp<T> {
619 fn route_key(_: &TcpSocketManager, _: Self) -> io::Result<RouteKey>;
620}
621
622impl ToRouteKeyForTcp<()> for RouteKey {
623 fn route_key(_: &TcpSocketManager, dest: RouteKey) -> io::Result<RouteKey> {
624 Ok(dest)
625 }
626}
627
628impl ToRouteKeyForTcp<()> for &RouteKey {
629 fn route_key(_: &TcpSocketManager, dest: &RouteKey) -> io::Result<RouteKey> {
630 Ok(*dest)
631 }
632}
633
634impl ToRouteKeyForTcp<()> for &mut RouteKey {
635 fn route_key(_: &TcpSocketManager, dest: &mut RouteKey) -> io::Result<RouteKey> {
636 Ok(*dest)
637 }
638}
639pub trait TcpStreamIndex {
646 fn route_key(&self) -> io::Result<RouteKey>;
647 fn index(&self) -> Index;
648}
649
650impl TcpStreamIndex for TcpStream {
651 fn route_key(&self) -> io::Result<RouteKey> {
652 let addr = self.peer_addr()?;
653
654 Ok(RouteKey::new(self.index(), addr))
655 }
656
657 fn index(&self) -> Index {
658 #[cfg(windows)]
659 use std::os::windows::io::AsRawSocket;
660 #[cfg(windows)]
661 let index = self.as_raw_socket() as usize;
662 #[cfg(unix)]
663 use std::os::fd::{AsFd, AsRawFd};
664 #[cfg(unix)]
665 let index = self.as_fd().as_raw_fd() as usize;
666 Index::Tcp(index)
667 }
668}
669
/// Pass-through codec: bytes are read and written as-is, without framing.
pub struct BytesCodec;

/// Codec that prefixes each message with its length as a 4-byte big-endian u32.
pub struct LengthPrefixedCodec;
675
676#[async_trait]
677impl Decoder for BytesCodec {
678 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
679 let len = read.read(src).await?;
680 Ok(len)
681 }
682}
683
684#[async_trait]
685impl Encoder for BytesCodec {
686 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
687 write.write_all(data).await?;
688 Ok(())
689 }
690}
691
692#[async_trait]
693impl Decoder for LengthPrefixedCodec {
694 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
695 let mut head = [0; 4];
696 read.read_exact(&mut head).await?;
697 let len = u32::from_be_bytes(head) as usize;
698 read.read_exact(&mut src[..len]).await?;
699 Ok(len)
700 }
701}
702
703#[async_trait]
704impl Encoder for LengthPrefixedCodec {
705 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
706 let head: [u8; 4] = (data.len() as u32).to_be_bytes();
707 write.write_all(&head).await?;
708 write.write_all(data).await?;
709 Ok(())
710 }
711}
712
713#[derive(Clone)]
714pub struct BytesInitCodec;
715
716impl InitCodec for BytesInitCodec {
717 fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
718 Ok((Box::new(BytesCodec), Box::new(BytesCodec)))
719 }
720}
721
722#[derive(Clone)]
723pub struct LengthPrefixedInitCodec;
724
725impl InitCodec for LengthPrefixedInitCodec {
726 fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
727 Ok((Box::new(LengthPrefixedCodec), Box::new(LengthPrefixedCodec)))
728 }
729}
730
731pub trait InitCodec: Send + Sync + DynClone {
732 fn codec(&self, addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)>;
733}
734dyn_clone::clone_trait_object!(InitCodec);
735
736#[async_trait]
737pub trait Decoder: Send + Sync {
738 async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize>;
739 fn try_decode(&mut self, _read: &mut OwnedReadHalf, _src: &mut [u8]) -> io::Result<usize> {
740 Err(io::Error::from(io::ErrorKind::WouldBlock))
741 }
742}
743
744#[async_trait]
745pub trait Encoder: Send + Sync {
746 async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()>;
747 async fn encode_multiple(
748 &mut self,
749 write: &mut OwnedWriteHalf,
750 bufs: &[IoSlice<'_>],
751 ) -> io::Result<()> {
752 for buf in bufs {
753 self.encode(write, buf).await?
754 }
755 Ok(())
756 }
757}
758
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use std::io;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};

    use crate::tunnel::config::TcpTunnelConfig;
    use crate::tunnel::tcp::{Decoder, Encoder, InitCodec, TcpTunnelDispatcher};

    /// A dispatcher can be built from the default configuration.
    #[tokio::test]
    pub async fn create_tcp_tunnel() {
        let config: TcpTunnelConfig = TcpTunnelConfig::default();
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    /// A dispatcher can be built with a custom codec factory.
    #[tokio::test]
    pub async fn create_codec_tcp_tunnel() {
        let config = TcpTunnelConfig::new(Box::new(MyInitCodeC));
        let tcp_tunnel_factory = TcpTunnelDispatcher::new(config).unwrap();
        drop(tcp_tunnel_factory)
    }

    #[derive(Clone)]
    struct MyInitCodeC;

    impl InitCodec for MyInitCodeC {
        fn codec(&self, _addr: SocketAddr) -> io::Result<(Box<dyn Decoder>, Box<dyn Encoder>)> {
            Ok((Box::new(MyCodeC), Box::new(MyCodeC)))
        }
    }

    /// Test codec that frames each message with a 2-byte big-endian length.
    struct MyCodeC;

    #[async_trait]
    impl Decoder for MyCodeC {
        async fn decode(&mut self, read: &mut OwnedReadHalf, src: &mut [u8]) -> io::Result<usize> {
            let mut head = [0; 2];
            read.read_exact(&mut head).await?;
            let len = u16::from_be_bytes(head) as usize;
            // The length comes from the peer; reject it instead of panicking
            // on an out-of-range slice.
            if len > src.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "frame larger than buffer",
                ));
            }
            read.read_exact(&mut src[..len]).await?;
            Ok(len)
        }
    }

    #[async_trait]
    impl Encoder for MyCodeC {
        async fn encode(&mut self, write: &mut OwnedWriteHalf, data: &[u8]) -> io::Result<()> {
            // `as u16` would silently truncate lengths above 65535.
            if data.len() > u16::MAX as usize {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "frame too large for 2-byte length prefix",
                ));
            }
            let head: [u8; 2] = (data.len() as u16).to_be_bytes();
            write.write_all(&head).await?;
            write.write_all(data).await?;
            Ok(())
        }
    }
}