1use std::io;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use tokio::time::Instant as TokioInstant;
19
/// Sink for byte counts produced by [`relay_bidirectional`].
///
/// Each call reports the size of one flushed batch (the bytes written since the
/// previous flush in that direction), not a running total.
pub trait RelayMetrics {
    /// Reports bytes read from the `inbound` stream and flushed to the
    /// `outbound` stream.
    fn record_inbound(&self, bytes: u64);
    /// Reports bytes read from the `outbound` stream and flushed to the
    /// `inbound` stream.
    fn record_outbound(&self, bytes: u64);
}

/// A [`RelayMetrics`] implementation that discards every report.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoOpMetrics;

impl RelayMetrics for NoOpMetrics {
    #[inline]
    fn record_outbound(&self, _: u64) {}

    #[inline]
    fn record_inbound(&self, _: u64) {}
}
42
/// Progress of one relay direction, advanced by `poll_copy_direction`.
///
/// "Unflushed" counts below are bytes the writer has accepted since its last
/// successful flush.
enum CopyState {
    /// Waiting on the reader; holds the unflushed byte count.
    Reading(usize),
    /// Writing `buf[pos..len]`: `(pos, len, unflushed bytes from earlier chunks)`.
    Writing(usize, usize, usize),
    /// Flushing the writer: `(bytes to report once flushed, reader hit EOF)`.
    Flushing(usize, bool),
    /// The reader hit EOF and everything is flushed; shutting the writer down.
    ShuttingDown,
    /// The writer shutdown attempt has completed; terminal state.
    Done,
}
57
/// Outcome of a `poll_copy_direction` call that returned `Ready(Ok(..))`.
enum CopyPoll {
    /// The writer was flushed; carries the bytes written since the previous flush.
    Flushed(usize),
    /// EOF was reached and the writer's shutdown was attempted; the direction is done.
    Finished,
}
65
66fn poll_copy_direction<R, W>(
73 cx: &mut Context<'_>,
74 reader: &mut R,
75 writer: &mut W,
76 buf: &mut [u8],
77 state: &mut CopyState,
78) -> Poll<io::Result<CopyPoll>>
79where
80 R: AsyncRead + Unpin + ?Sized,
81 W: AsyncWrite + Unpin + ?Sized,
82{
83 loop {
84 match state {
85 CopyState::Reading(flushed) => {
86 let mut read_buf = ReadBuf::new(buf);
87 match Pin::new(&mut *reader).poll_read(cx, &mut read_buf) {
88 Poll::Ready(Ok(())) => {
89 let n = read_buf.filled().len();
90 if n == 0 {
91 if *flushed > 0 {
93 let total = *flushed;
94 *state = CopyState::Flushing(total, true);
95 } else {
96 *state = CopyState::ShuttingDown;
97 }
98 } else {
99 let acc = *flushed;
100 *state = CopyState::Writing(0, n, acc);
101 }
102 }
103 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
104 Poll::Pending => {
105 if *flushed > 0 {
107 let total = *flushed;
108 *state = CopyState::Flushing(total, false);
109 } else {
110 return Poll::Pending;
111 }
112 }
113 }
114 }
115 CopyState::Writing(pos, len, acc) => {
116 match Pin::new(&mut *writer).poll_write(cx, &buf[*pos..*len]) {
117 Poll::Ready(Ok(n)) => {
118 *pos += n;
119 if *pos >= *len {
120 let total = *acc + *len;
121 *state = CopyState::Reading(total);
123 }
124 }
125 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
126 Poll::Pending => return Poll::Pending,
127 }
128 }
129 CopyState::Flushing(bytes, is_eof) => {
130 let bytes = *bytes;
131 let eof = *is_eof;
132 match Pin::new(&mut *writer).poll_flush(cx) {
133 Poll::Ready(Ok(())) => {
134 if eof {
135 *state = CopyState::ShuttingDown;
136 } else {
137 *state = CopyState::Reading(0);
138 }
139 return Poll::Ready(Ok(CopyPoll::Flushed(bytes)));
140 }
141 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
142 Poll::Pending => return Poll::Pending,
143 }
144 }
145 CopyState::ShuttingDown => match Pin::new(&mut *writer).poll_shutdown(cx) {
146 Poll::Ready(_) => {
147 *state = CopyState::Done;
148 return Poll::Ready(Ok(CopyPoll::Finished));
149 }
150 Poll::Pending => return Poll::Pending,
151 },
152 CopyState::Done => return Poll::Ready(Ok(CopyPoll::Finished)),
153 }
154 }
155}
156
/// Byte totals reported by [`relay_bidirectional`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RelayStats {
    /// Bytes relayed from the `inbound` stream to the `outbound` stream.
    pub inbound: u64,
    /// Bytes relayed from the `outbound` stream to the `inbound` stream.
    pub outbound: u64,
}

impl RelayStats {
    /// Returns the bytes relayed in both directions combined.
    ///
    /// Saturates at `u64::MAX` instead of overflowing. Plain `+` would panic in
    /// debug builds and wrap in release builds, and the fields are public, so
    /// callers can set any values.
    #[inline]
    pub fn total(self) -> u64 {
        self.inbound.saturating_add(self.outbound)
    }
}
173
174pub async fn relay_bidirectional<A, B, M>(
189 inbound: A,
190 outbound: B,
191 idle_timeout: Duration,
192 buffer_size: usize,
193 metrics: &M,
194) -> io::Result<RelayStats>
195where
196 A: AsyncRead + AsyncWrite + Unpin,
197 B: AsyncRead + AsyncWrite + Unpin,
198 M: RelayMetrics,
199{
200 let (mut in_r, mut in_w) = tokio::io::split(inbound);
201 let (mut out_r, mut out_w) = tokio::io::split(outbound);
202
203 let mut buf_a = vec![0u8; buffer_size];
204 let mut buf_b = vec![0u8; buffer_size];
205 let mut state_a = CopyState::Reading(0);
206 let mut state_b = CopyState::Reading(0);
207
208 let idle_sleep = tokio::time::sleep(idle_timeout);
209 tokio::pin!(idle_sleep);
210
211 let mut a_done = false;
212 let mut b_done = false;
213 let mut total_inbound: u64 = 0;
214 let mut total_outbound: u64 = 0;
215
216 loop {
217 if a_done && b_done {
218 return Ok(RelayStats {
219 inbound: total_inbound,
220 outbound: total_outbound,
221 });
222 }
223
224 let both = std::future::poll_fn(|cx| {
228 let mut any_ready = false;
229 let mut activity = false;
230 let mut error: Option<io::Error> = None;
231
232 if !a_done {
233 match poll_copy_direction(cx, &mut in_r, &mut out_w, &mut buf_a, &mut state_a) {
234 Poll::Ready(Ok(CopyPoll::Flushed(n))) => {
235 let bytes = n as u64;
236 metrics.record_inbound(bytes);
237 total_inbound += bytes;
238 activity = true;
239 any_ready = true;
240 }
241 Poll::Ready(Ok(CopyPoll::Finished)) => {
242 a_done = true;
243 any_ready = true;
244 }
245 Poll::Ready(Err(e)) => {
246 error = Some(e);
247 any_ready = true;
248 }
249 Poll::Pending => {}
250 }
251 }
252
253 if !b_done {
254 match poll_copy_direction(cx, &mut out_r, &mut in_w, &mut buf_b, &mut state_b) {
255 Poll::Ready(Ok(CopyPoll::Flushed(n))) => {
256 let bytes = n as u64;
257 metrics.record_outbound(bytes);
258 total_outbound += bytes;
259 activity = true;
260 any_ready = true;
261 }
262 Poll::Ready(Ok(CopyPoll::Finished)) => {
263 b_done = true;
264 any_ready = true;
265 }
266 Poll::Ready(Err(e)) => {
267 error = Some(e);
268 any_ready = true;
269 }
270 Poll::Pending => {}
271 }
272 }
273
274 if let Some(e) = error {
275 return Poll::Ready(Err(e));
276 }
277
278 if any_ready {
279 Poll::Ready(Ok(activity))
280 } else {
281 Poll::Pending
282 }
283 });
284
285 tokio::select! {
286 result = both => {
287 let activity = result?;
288 if activity {
289 idle_sleep.as_mut().reset(TokioInstant::now() + idle_timeout);
290 }
291 }
292 _ = &mut idle_sleep => {
293 return Ok(RelayStats {
294 inbound: total_inbound,
295 outbound: total_outbound,
296 });
297 }
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// Metrics sink that accumulates every reported byte count.
    struct TestMetrics {
        inbound: AtomicU64,
        outbound: AtomicU64,
    }

    impl TestMetrics {
        fn new() -> Self {
            Self {
                inbound: AtomicU64::new(0),
                outbound: AtomicU64::new(0),
            }
        }
    }

    impl RelayMetrics for TestMetrics {
        fn record_inbound(&self, bytes: u64) {
            self.inbound.fetch_add(bytes, Ordering::Relaxed);
        }
        fn record_outbound(&self, bytes: u64) {
            self.outbound.fetch_add(bytes, Ordering::Relaxed);
        }
    }

    #[tokio::test]
    async fn test_relay_basic() {
        let (client, server_side) = duplex(1024);
        let (target_side, target) = duplex(1024);

        let metrics = Arc::new(TestMetrics::new());
        let task_metrics = Arc::clone(&metrics);

        let relay_handle = tokio::spawn(async move {
            relay_bidirectional(
                server_side,
                target_side,
                Duration::from_secs(5),
                1024,
                &*task_metrics,
            )
            .await
        });

        let (mut client_r, mut client_w) = tokio::io::split(client);
        let (mut target_r, mut target_w) = tokio::io::split(target);
        let mut buf = vec![0u8; 1024];

        // Dropping a `WriteHalf` does not close the stream (the `ReadHalf` still
        // owns it), so shut the write side down explicitly to deliver EOF.
        client_w.write_all(b"hello").await.unwrap();
        client_w.shutdown().await.unwrap();

        let n = target_r.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");

        target_w.write_all(b"world").await.unwrap();
        target_w.shutdown().await.unwrap();

        let n = client_r.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"world");

        // Both directions reached EOF, so the relay must end well before its
        // 5 s idle timeout.
        let stats = tokio::time::timeout(Duration::from_secs(2), relay_handle)
            .await
            .expect("relay should finish once both sides shut down")
            .unwrap()
            .unwrap();
        assert_eq!(stats.inbound, 5);
        assert_eq!(stats.outbound, 5);
        assert_eq!(metrics.inbound.load(Ordering::Relaxed), 5);
        assert_eq!(metrics.outbound.load(Ordering::Relaxed), 5);
    }

    #[tokio::test]
    async fn test_relay_idle_timeout() {
        let (client, server_side) = duplex(1024);
        let (target_side, _target) = duplex(1024);

        let start = TokioInstant::now();
        let stats = relay_bidirectional(
            server_side,
            target_side,
            Duration::from_millis(50),
            1024,
            &NoOpMetrics,
        )
        .await
        .unwrap();

        assert!(start.elapsed() >= Duration::from_millis(50));
        assert_eq!(stats.inbound, 0);
        assert_eq!(stats.outbound, 0);

        // Keep the client alive until here so the relay cannot observe EOF.
        drop(client);
    }

    /// Reader that replays a script.
    ///
    /// `Some(bytes)` yields that chunk. `None` yields a single `Pending`, after
    /// waking the task so the poller retries immediately. An exhausted script
    /// reports EOF.
    struct MockReader {
        chunks: VecDeque<Option<Vec<u8>>>,
    }

    impl MockReader {
        fn new(chunks: Vec<Option<Vec<u8>>>) -> Self {
            Self {
                chunks: chunks.into(),
            }
        }
    }

    impl AsyncRead for MockReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            match self.chunks.pop_front() {
                Some(Some(data)) => {
                    buf.put_slice(&data);
                    Poll::Ready(Ok(()))
                }
                Some(None) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                None => Poll::Ready(Ok(())),
            }
        }
    }

    /// Writer that records written bytes and counts `poll_flush` calls.
    struct FlushCountingWriter {
        written: Vec<u8>,
        flush_count: usize,
    }

    impl FlushCountingWriter {
        fn new() -> Self {
            Self {
                written: Vec::new(),
                flush_count: 0,
            }
        }
    }

    impl AsyncWrite for FlushCountingWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.flush_count += 1;
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Drives `poll_copy_direction` with a 64-byte buffer until the direction
    /// finishes, returning the sum of all `CopyPoll::Flushed` byte counts.
    async fn drive_to_completion(
        reader: &mut MockReader,
        writer: &mut FlushCountingWriter,
    ) -> usize {
        let mut buf = vec![0u8; 64];
        let mut state = CopyState::Reading(0);
        let mut total_bytes = 0;
        loop {
            let result = std::future::poll_fn(|cx| {
                poll_copy_direction(cx, &mut *reader, &mut *writer, &mut buf, &mut state)
            })
            .await
            .unwrap();
            match result {
                CopyPoll::Flushed(n) => total_bytes += n,
                CopyPoll::Finished => return total_bytes,
            }
        }
    }

    #[tokio::test]
    async fn test_flush_batching_consecutive_reads() {
        let mut reader = MockReader::new(vec![
            Some(b"aaa".to_vec()),
            Some(b"bbb".to_vec()),
            Some(b"ccc".to_vec()),
        ]);
        let mut writer = FlushCountingWriter::new();

        let total_bytes = drive_to_completion(&mut reader, &mut writer).await;

        assert_eq!(writer.written, b"aaabbbccc");
        assert_eq!(total_bytes, 9);
        assert_eq!(
            writer.flush_count, 1,
            "consecutive reads should batch flushes"
        );
    }

    #[tokio::test]
    async fn test_flush_on_pending() {
        let mut reader = MockReader::new(vec![
            Some(b"aaa".to_vec()),
            None,
            Some(b"bbb".to_vec()),
            None,
        ]);
        let mut writer = FlushCountingWriter::new();

        let total_bytes = drive_to_completion(&mut reader, &mut writer).await;

        assert_eq!(writer.written, b"aaabbb");
        assert_eq!(total_bytes, 6);
        assert_eq!(writer.flush_count, 2, "should flush once per Pending gap");
    }

    #[tokio::test]
    async fn test_flush_batching_burst_then_pending() {
        let mut reader = MockReader::new(vec![
            Some(b"a".to_vec()),
            Some(b"b".to_vec()),
            Some(b"c".to_vec()),
            None,
            Some(b"d".to_vec()),
        ]);
        let mut writer = FlushCountingWriter::new();

        let total_bytes = drive_to_completion(&mut reader, &mut writer).await;

        assert_eq!(writer.written, b"abcd");
        assert_eq!(total_bytes, 4);
        assert_eq!(
            writer.flush_count, 2,
            "burst then pending then EOF = 2 flushes"
        );
    }
}