1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
2
/// Capacity, in bytes, of each of the two transfer buffers used by `copy_io`
/// (16 KiB).
const BUFFER_SIZE: usize = 16_384;
4
5pub async fn copy_io<A, B>(a: &mut A, b: &mut B) -> (usize, usize, Option<std::io::Error>)
6where
7 A: AsyncRead + AsyncWrite + Unpin + ?Sized,
8 B: AsyncRead + AsyncWrite + Unpin + ?Sized,
9{
10 let mut a2b = [0u8; BUFFER_SIZE];
11 let mut b2a = [0u8; BUFFER_SIZE];
12
13 let mut a2b_num = 0;
14 let mut b2a_num = 0;
15
16 let mut last_err = None;
17
18 loop {
19 tokio::select! {
20 a2b_res = a.read(&mut a2b) => match a2b_res {
21 Ok(num) => {
22 if num == 0 {
24 break;
25 }
26 a2b_num += num;
27 if let Err(err) = b.write_all(&a2b[..num]).await {
28 last_err = Some(err);
29 break;
30 }
31 },
32 Err(err) => {
33 last_err = Some(err);
34 break;
35 }
36 },
37 b2a_res = b.read(&mut b2a) => match b2a_res {
38 Ok(num) => {
39 if num == 0 {
41 break;
42 }
43 b2a_num += num;
44 if let Err(err) = a.write_all(&b2a[..num]).await {
45 last_err = Some(err);
46 break;
47 }
48 },
49 Err(err) => {
50 last_err = Some(err);
51 break;
52 },
53 }
54 }
55 }
56
57 (a2b_num, b2a_num, last_err)
58}
59
/// Tokio I/O adapter for quinn streams. Compiled only when the `quic` feature
/// is enabled.
#[cfg(feature = "quic")]
pub mod quinn {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use quinn::{RecvStream, SendStream};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    /// Pairs a quinn [`SendStream`] and [`RecvStream`] into a single value.
    ///
    /// Reads ([`AsyncRead`]) are served from the receive half. Writes,
    /// flushes and shutdowns ([`AsyncWrite`]) go to the send half.
    pub struct QuinnCompat {
        send: SendStream,
        recv: RecvStream,
    }

    impl QuinnCompat {
        /// Builds an adapter owning both halves.
        pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self {
            Self {
                send: send_stream,
                recv: recv_stream,
            }
        }

        /// Shared access to the wrapped send half.
        pub fn send_stream(&self) -> &SendStream {
            &self.send
        }

        /// Shared access to the wrapped receive half.
        pub fn recv_stream(&self) -> &RecvStream {
            &self.recv
        }

        /// Exclusive access to the wrapped send half.
        pub fn send_stream_mut(&mut self) -> &mut SendStream {
            &mut self.send
        }

        /// Exclusive access to the wrapped receive half.
        pub fn recv_stream_mut(&mut self) -> &mut RecvStream {
            &mut self.recv
        }

        /// Consumes the adapter, handing back `(send, recv)`.
        pub fn into_inner(self) -> (SendStream, RecvStream) {
            let Self { send, recv } = self;
            (send, recv)
        }
    }

    impl AsyncWrite for QuinnCompat {
        /// Writes to the send half. Any error it reports is converted into an
        /// `io::Error` via `io::Error::other`.
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            Pin::new(&mut this.send)
                .poll_write(cx, buf)
                .map_err(io::Error::other)
        }

        /// Delegates flushing to the send half.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.send).poll_flush(cx)
        }

        /// Delegates shutdown to the send half; the receive half is untouched.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.send).poll_shutdown(cx)
        }
    }

    impl AsyncRead for QuinnCompat {
        /// Delegates reading to the receive half.
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.recv).poll_read(cx, buf)
        }
    }
}