// trojan_core/io/relay.rs
1//! Bidirectional data relay with configurable metrics.
2//!
3//! This module provides a generic bidirectional relay that can be used by both
4//! server and client implementations. Metrics recording is abstracted via the
5//! `RelayMetrics` trait, allowing each implementation to provide its own
6//! metrics backend.
7//!
8//! Each direction is driven as an independent poll-based state machine within
9//! a single future, so back-pressure on one direction never stalls the other.
10//! This prevents deadlocks in multi-hop relay chains.
11
12use std::io;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use tokio::time::Instant as TokioInstant;
19
/// Trait for recording relay metrics.
///
/// Implementors record bytes transferred in each direction of the relay.
/// The server implementation typically records to Prometheus,
/// while clients may use a no-op or custom implementation.
///
/// Methods take `&self`, so implementors that are shared across tasks
/// need interior mutability (e.g. atomics) for their counters.
pub trait RelayMetrics {
    /// Record bytes received from inbound (client -> server direction).
    fn record_inbound(&self, bytes: u64);
    /// Record bytes sent to outbound (server -> target direction).
    fn record_outbound(&self, bytes: u64);
}
31
/// No-op metrics implementation for cases where metrics aren't needed.
///
/// Zero-sized and `Copy`, so passing `&NoOpMetrics` has no runtime cost.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpMetrics;

impl RelayMetrics for NoOpMetrics {
    /// Discards the inbound byte count.
    #[inline]
    fn record_inbound(&self, _bytes: u64) {}
    /// Discards the outbound byte count.
    #[inline]
    fn record_outbound(&self, _bytes: u64) {}
}
42
/// State machine for one-directional copy with deferred flush.
///
/// Unlike a naive read→write→flush loop, this batches multiple read/write
/// cycles before flushing. A flush only happens when the reader returns
/// `Pending` (no more data immediately available) or on EOF. This mirrors
/// the strategy used by `tokio::io::copy` and avoids excessive flush
/// syscalls on buffered writers like TLS streams.
enum CopyState {
    /// Polling the reader; payload = bytes written since the last flush.
    Reading(usize),               // accumulated bytes since last flush
    /// Draining `buf[pos..len]` into the writer; `accumulated` is carried
    /// through so the eventual flush can report the full batch size.
    Writing(usize, usize, usize), // (pos, len, accumulated)
    /// Flushing the writer; on success reports the byte total, and
    /// `is_eof` decides whether to shut down or resume reading.
    Flushing(usize, bool),        // (total bytes to report, is_eof)
    /// EOF reached (and any pending data flushed); closing the write half.
    ShuttingDown,
    /// Terminal state; subsequent polls keep reporting `Finished`.
    Done,
}
57
/// Result of polling one copy direction to a `Ready` state.
enum CopyPoll {
    /// Data was flushed — contains byte count for metrics.
    Flushed(usize),
    /// Direction finished (EOF + shutdown); no more data will flow this way.
    Finished,
}
65
66/// Poll-driven one-directional copy with deferred flush.
67///
68/// Reads and writes in a loop, only flushing when the reader has no more
69/// data immediately available (`Pending`) or at EOF. This batches multiple
70/// read/write cycles into a single flush, reducing syscall overhead on
71/// buffered writers (e.g. TLS streams).
72fn poll_copy_direction<R, W>(
73    cx: &mut Context<'_>,
74    reader: &mut R,
75    writer: &mut W,
76    buf: &mut [u8],
77    state: &mut CopyState,
78) -> Poll<io::Result<CopyPoll>>
79where
80    R: AsyncRead + Unpin + ?Sized,
81    W: AsyncWrite + Unpin + ?Sized,
82{
83    loop {
84        match state {
85            CopyState::Reading(flushed) => {
86                let mut read_buf = ReadBuf::new(buf);
87                match Pin::new(&mut *reader).poll_read(cx, &mut read_buf) {
88                    Poll::Ready(Ok(())) => {
89                        let n = read_buf.filled().len();
90                        if n == 0 {
91                            // EOF — flush any accumulated bytes, then shut down.
92                            if *flushed > 0 {
93                                let total = *flushed;
94                                *state = CopyState::Flushing(total, true);
95                            } else {
96                                *state = CopyState::ShuttingDown;
97                            }
98                        } else {
99                            let acc = *flushed;
100                            *state = CopyState::Writing(0, n, acc);
101                        }
102                    }
103                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
104                    Poll::Pending => {
105                        // Reader has no more data — flush accumulated bytes.
106                        if *flushed > 0 {
107                            let total = *flushed;
108                            *state = CopyState::Flushing(total, false);
109                        } else {
110                            return Poll::Pending;
111                        }
112                    }
113                }
114            }
115            CopyState::Writing(pos, len, acc) => {
116                match Pin::new(&mut *writer).poll_write(cx, &buf[*pos..*len]) {
117                    Poll::Ready(Ok(n)) => {
118                        *pos += n;
119                        if *pos >= *len {
120                            let total = *acc + *len;
121                            // Don't flush yet — try to read more data first.
122                            *state = CopyState::Reading(total);
123                        }
124                    }
125                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
126                    Poll::Pending => return Poll::Pending,
127                }
128            }
129            CopyState::Flushing(bytes, is_eof) => {
130                let bytes = *bytes;
131                let eof = *is_eof;
132                match Pin::new(&mut *writer).poll_flush(cx) {
133                    Poll::Ready(Ok(())) => {
134                        if eof {
135                            *state = CopyState::ShuttingDown;
136                        } else {
137                            *state = CopyState::Reading(0);
138                        }
139                        return Poll::Ready(Ok(CopyPoll::Flushed(bytes)));
140                    }
141                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
142                    Poll::Pending => return Poll::Pending,
143                }
144            }
145            CopyState::ShuttingDown => match Pin::new(&mut *writer).poll_shutdown(cx) {
146                Poll::Ready(_) => {
147                    *state = CopyState::Done;
148                    return Poll::Ready(Ok(CopyPoll::Finished));
149                }
150                Poll::Pending => return Poll::Pending,
151            },
152            CopyState::Done => return Poll::Ready(Ok(CopyPoll::Finished)),
153        }
154    }
155}
156
/// Bytes transferred in each direction during a relay session.
#[derive(Debug, Clone, Copy, Default)]
pub struct RelayStats {
    /// Bytes from inbound to outbound (client → target).
    pub inbound: u64,
    /// Bytes from outbound to inbound (target → client).
    pub outbound: u64,
}

impl RelayStats {
    /// Total bytes transferred in both directions.
    ///
    /// Uses saturating addition: a (theoretical) sum beyond `u64::MAX`
    /// clamps instead of panicking in debug builds. For all realistic
    /// byte counts the result is identical to plain addition.
    #[inline]
    pub fn total(self) -> u64 {
        self.inbound.saturating_add(self.outbound)
    }
}
173
/// Bidirectional relay with proper half-close handling.
///
/// Both directions run concurrently within a single task using poll-based
/// I/O, so back-pressure on one direction cannot stall the other. An
/// idle-timeout fires when **neither** direction has transferred data
/// within `idle_timeout`.
///
/// # Arguments
///
/// * `inbound` - The inbound stream (e.g., client connection)
/// * `outbound` - The outbound stream (e.g., target server connection)
/// * `idle_timeout` - Maximum time without data transfer before closing
/// * `buffer_size` - Size of the read buffers
/// * `metrics` - Metrics recorder for tracking bytes transferred
///
/// # Returns
///
/// Byte totals for each direction. Note that an idle timeout is reported
/// as a normal `Ok` completion carrying the stats gathered so far — it is
/// not an error.
///
/// # Errors
///
/// Returns the first I/O error raised by either direction's reader or
/// writer; the relay stops immediately in that case.
pub async fn relay_bidirectional<A, B, M>(
    inbound: A,
    outbound: B,
    idle_timeout: Duration,
    buffer_size: usize,
    metrics: &M,
) -> io::Result<RelayStats>
where
    A: AsyncRead + AsyncWrite + Unpin,
    B: AsyncRead + AsyncWrite + Unpin,
    M: RelayMetrics,
{
    let (mut in_r, mut in_w) = tokio::io::split(inbound);
    let (mut out_r, mut out_w) = tokio::io::split(outbound);

    // Each direction gets its own buffer and its own copy state machine.
    let mut buf_a = vec![0u8; buffer_size];
    let mut buf_b = vec![0u8; buffer_size];
    let mut state_a = CopyState::Reading(0);
    let mut state_b = CopyState::Reading(0);

    let idle_sleep = tokio::time::sleep(idle_timeout);
    tokio::pin!(idle_sleep);

    let mut a_done = false;
    let mut b_done = false;
    let mut total_inbound: u64 = 0;
    let mut total_outbound: u64 = 0;

    loop {
        // Both halves finished (EOF propagated + write halves shut down).
        if a_done && b_done {
            return Ok(RelayStats {
                inbound: total_inbound,
                outbound: total_outbound,
            });
        }

        // Build a future that polls both directions concurrently.
        // Each direction registers its own waker so either can make progress
        // independently — one blocked write cannot stall the other direction.
        let both = std::future::poll_fn(|cx| {
            let mut any_ready = false;
            let mut activity = false;
            let mut error: Option<io::Error> = None;

            if !a_done {
                match poll_copy_direction(cx, &mut in_r, &mut out_w, &mut buf_a, &mut state_a) {
                    Poll::Ready(Ok(CopyPoll::Flushed(n))) => {
                        let bytes = n as u64;
                        metrics.record_inbound(bytes);
                        total_inbound += bytes;
                        activity = true;
                        any_ready = true;
                    }
                    Poll::Ready(Ok(CopyPoll::Finished)) => {
                        a_done = true;
                        any_ready = true;
                    }
                    Poll::Ready(Err(e)) => {
                        error = Some(e);
                        any_ready = true;
                    }
                    Poll::Pending => {}
                }
            }

            if !b_done {
                match poll_copy_direction(cx, &mut out_r, &mut in_w, &mut buf_b, &mut state_b) {
                    Poll::Ready(Ok(CopyPoll::Flushed(n))) => {
                        let bytes = n as u64;
                        metrics.record_outbound(bytes);
                        total_outbound += bytes;
                        activity = true;
                        any_ready = true;
                    }
                    Poll::Ready(Ok(CopyPoll::Finished)) => {
                        b_done = true;
                        any_ready = true;
                    }
                    Poll::Ready(Err(e)) => {
                        error = Some(e);
                        any_ready = true;
                    }
                    Poll::Pending => {}
                }
            }

            // An error on either side aborts the whole relay.
            if let Some(e) = error {
                return Poll::Ready(Err(e));
            }

            if any_ready {
                Poll::Ready(Ok(activity))
            } else {
                Poll::Pending
            }
        });

        tokio::select! {
            result = both => {
                let activity = result?;
                if activity {
                    // Data moved — push the idle deadline out again.
                    idle_sleep.as_mut().reset(TokioInstant::now() + idle_timeout);
                }
            }
            _ = &mut idle_sleep => {
                // Neither direction moved data within `idle_timeout`:
                // close gracefully with the totals gathered so far.
                return Ok(RelayStats {
                    inbound: total_inbound,
                    outbound: total_outbound,
                });
            }
        }
    }
}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// Metrics recorder backed by atomic counters, for test assertions.
    struct TestMetrics {
        inbound: AtomicU64,
        outbound: AtomicU64,
    }

    impl TestMetrics {
        fn new() -> Self {
            Self {
                inbound: AtomicU64::new(0),
                outbound: AtomicU64::new(0),
            }
        }
    }

    impl RelayMetrics for TestMetrics {
        fn record_inbound(&self, bytes: u64) {
            self.inbound.fetch_add(bytes, Ordering::Relaxed);
        }
        fn record_outbound(&self, bytes: u64) {
            self.outbound.fetch_add(bytes, Ordering::Relaxed);
        }
    }

    /// End-to-end: data flows both ways through the relay, then both
    /// sides close and the relay future completes.
    #[tokio::test]
    async fn test_relay_basic() {
        let (client, server_side) = duplex(1024);
        let (target_side, target) = duplex(1024);

        let metrics = TestMetrics::new();

        // Spawn relay (metrics is moved into the task).
        let relay_handle = tokio::spawn(async move {
            relay_bidirectional(
                server_side,
                target_side,
                Duration::from_secs(5),
                1024,
                &metrics,
            )
            .await
        });

        // Client sends data
        let (mut client_r, mut client_w) = tokio::io::split(client);
        let (mut target_r, mut target_w) = tokio::io::split(target);

        client_w.write_all(b"hello").await.unwrap();
        drop(client_w); // Close write side

        let mut buf = vec![0u8; 1024];
        let n = target_r.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");

        // Target sends response
        target_w.write_all(b"world").await.unwrap();
        drop(target_w);

        let n = client_r.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"world");

        // Relay should complete
        relay_handle.await.unwrap().unwrap();
    }

    /// With no traffic at all, the relay must return Ok after the idle
    /// timeout rather than hanging forever.
    #[tokio::test]
    async fn test_relay_idle_timeout() {
        let (client, server_side) = duplex(1024);
        let (target_side, _target) = duplex(1024);

        let start = TokioInstant::now();
        let result = relay_bidirectional(
            server_side,
            target_side,
            Duration::from_millis(50),
            1024,
            &NoOpMetrics,
        )
        .await;

        result.unwrap();
        assert!(start.elapsed() >= Duration::from_millis(50));

        drop(client); // cleanup
    }

    // ── Flush-batching tests ──

    /// A mock reader that yields chunks from a queue.
    /// Returns `Pending` (with waker notification) between groups separated
    /// by `None` entries, simulating data arriving in bursts.
    struct MockReader {
        /// `Some(data)` = a read returning data, `None` = return Pending once.
        chunks: VecDeque<Option<Vec<u8>>>,
        // NOTE(review): this waker is stored on Pending but never read or
        // woken later (wake_by_ref below does the actual wake-up) — looks
        // vestigial; confirm before removing.
        pending_waker: Option<std::task::Waker>,
    }

    impl MockReader {
        fn new(chunks: Vec<Option<Vec<u8>>>) -> Self {
            Self {
                chunks: chunks.into(),
                pending_waker: None,
            }
        }
    }

    impl AsyncRead for MockReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            match self.chunks.front() {
                Some(Some(_)) => {
                    let data = self.chunks.pop_front().unwrap().unwrap();
                    buf.put_slice(&data);
                    Poll::Ready(Ok(()))
                }
                Some(None) => {
                    // Consume the Pending marker, wake immediately so the
                    // next poll will return the next chunk.
                    self.chunks.pop_front();
                    self.pending_waker = Some(cx.waker().clone());
                    // Schedule a wake so the state machine advances.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                None => {
                    // EOF: Ready with nothing written to `buf`.
                    Poll::Ready(Ok(()))
                }
            }
        }
    }

    /// A writer that counts flush calls and records written data.
    struct FlushCountingWriter {
        written: Vec<u8>,
        flush_count: usize,
    }

    impl FlushCountingWriter {
        fn new() -> Self {
            Self {
                written: Vec::new(),
                flush_count: 0,
            }
        }
    }

    impl AsyncWrite for FlushCountingWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.flush_count += 1;
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn test_flush_batching_consecutive_reads() {
        // Simulate 3 chunks arriving in a burst (no Pending between them),
        // followed by EOF. The state machine should batch all 3 writes into
        // a single flush (the EOF flush).
        let mut reader = MockReader::new(vec![
            Some(b"aaa".to_vec()),
            Some(b"bbb".to_vec()),
            Some(b"ccc".to_vec()),
            // EOF follows (empty queue)
        ]);
        let mut writer = FlushCountingWriter::new();
        let mut buf = vec![0u8; 64];
        let mut state = CopyState::Reading(0);

        let mut total_bytes = 0;
        loop {
            let result = std::future::poll_fn(|cx| {
                poll_copy_direction(cx, &mut reader, &mut writer, &mut buf, &mut state)
            })
            .await
            .unwrap();
            match result {
                CopyPoll::Flushed(n) => total_bytes += n,
                CopyPoll::Finished => break,
            }
        }

        assert_eq!(writer.written, b"aaabbbccc");
        assert_eq!(total_bytes, 9);
        // All 3 chunks were available consecutively — should batch into 1 flush
        // (the EOF-triggered flush), not 3 separate flushes.
        assert_eq!(
            writer.flush_count, 1,
            "consecutive reads should batch flushes"
        );
    }

    #[tokio::test]
    async fn test_flush_on_pending() {
        // Sequence: chunk1, Pending, chunk2, Pending, EOF — one flush per
        // Pending gap, and nothing left to flush at EOF. Expected trace:
        //   chunk1 → write → Reading(3) → Pending → Flushing(3, false) → flush #1
        //   Reading(0) → chunk2 → write → Reading(3) → Pending → Flushing(3, false) → flush #2
        //   Reading(0) → EOF (0 accumulated) → ShuttingDown → Finished
        let mut reader = MockReader::new(vec![
            Some(b"aaa".to_vec()),
            None, // Pending
            Some(b"bbb".to_vec()),
            None, // Pending
                  // EOF
        ]);
        let mut writer = FlushCountingWriter::new();
        let mut buf = vec![0u8; 64];
        let mut state = CopyState::Reading(0);

        let mut total_bytes = 0;
        loop {
            let result = std::future::poll_fn(|cx| {
                poll_copy_direction(cx, &mut reader, &mut writer, &mut buf, &mut state)
            })
            .await
            .unwrap();
            match result {
                CopyPoll::Flushed(n) => total_bytes += n,
                CopyPoll::Finished => break,
            }
        }

        assert_eq!(writer.written, b"aaabbb");
        assert_eq!(total_bytes, 6);
        // 2 flushes: one after each Pending gap.
        assert_eq!(writer.flush_count, 2, "should flush once per Pending gap");
    }

    #[tokio::test]
    async fn test_flush_batching_burst_then_pending() {
        // 3 chunks in a burst, then Pending, then 1 more chunk, then EOF.
        // Should produce 2 flushes: one for the burst (at Pending), one at EOF.
        let mut reader = MockReader::new(vec![
            Some(b"a".to_vec()),
            Some(b"b".to_vec()),
            Some(b"c".to_vec()),
            None, // Pending — triggers flush of accumulated 3 bytes
            Some(b"d".to_vec()),
            // EOF — triggers flush of 1 byte
        ]);
        let mut writer = FlushCountingWriter::new();
        let mut buf = vec![0u8; 64];
        let mut state = CopyState::Reading(0);

        let mut total_bytes = 0;
        loop {
            let result = std::future::poll_fn(|cx| {
                poll_copy_direction(cx, &mut reader, &mut writer, &mut buf, &mut state)
            })
            .await
            .unwrap();
            match result {
                CopyPoll::Flushed(n) => total_bytes += n,
                CopyPoll::Finished => break,
            }
        }

        assert_eq!(writer.written, b"abcd");
        assert_eq!(total_bytes, 4);
        assert_eq!(
            writer.flush_count, 2,
            "burst then pending then EOF = 2 flushes"
        );
    }
}