Skip to main content

trojan_core/io/
prefixed.rs

1//! Prefixed stream adapter for replaying buffered data.
2//!
3//! This module provides `PrefixedStream`, a stream wrapper that yields
4//! pre-buffered bytes before reading from the inner stream. This is useful
5//! for protocol detection where you need to peek at incoming data without
6//! consuming it from the underlying stream.
7
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14/// A stream wrapper that yields a prefetched prefix before reading from the inner stream.
15///
16/// This is commonly used in mixed-mode protocol detection (e.g., WebSocket upgrade)
17/// where we need to read HTTP headers to determine the protocol, then replay those
18/// bytes to the actual protocol handler.
19///
20/// # Example
21///
22/// ```ignore
23/// use trojan_core::io::PrefixedStream;
24/// use bytes::Bytes;
25///
26/// // After reading some bytes for protocol detection
27/// let buffered = Bytes::from(b"GET / HTTP/1.1\r\n...");
28/// let prefixed = PrefixedStream::new(buffered, tcp_stream);
29///
30/// // Now the prefixed stream will first yield the buffered bytes,
31/// // then continue reading from tcp_stream
32/// ```
33pub struct PrefixedStream<S> {
34    prefix: Bytes,
35    pos: usize,
36    inner: S,
37}
38
39impl<S> PrefixedStream<S> {
40    /// Create a new prefixed stream.
41    ///
42    /// # Arguments
43    ///
44    /// * `prefix` - The buffered bytes to yield first
45    /// * `inner` - The underlying stream to read from after prefix is exhausted
46    pub fn new(prefix: Bytes, inner: S) -> Self {
47        Self {
48            prefix,
49            pos: 0,
50            inner,
51        }
52    }
53
54    /// Returns the remaining unread prefix bytes.
55    pub fn prefix_remaining(&self) -> usize {
56        self.prefix.len().saturating_sub(self.pos)
57    }
58
59    /// Consumes the wrapper, returning the inner stream.
60    ///
61    /// Note: Any unread prefix bytes will be lost.
62    pub fn into_inner(self) -> S {
63        self.inner
64    }
65}
66
67impl<S: AsyncRead + Unpin> AsyncRead for PrefixedStream<S> {
68    fn poll_read(
69        mut self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &mut ReadBuf<'_>,
72    ) -> Poll<std::io::Result<()>> {
73        // First, yield any remaining prefix bytes
74        if self.pos < self.prefix.len() {
75            let remaining = &self.prefix[self.pos..];
76            let to_copy = remaining.len().min(buf.remaining());
77            buf.put_slice(&remaining[..to_copy]);
78            self.pos += to_copy;
79            return Poll::Ready(Ok(()));
80        }
81        // Then delegate to inner stream
82        Pin::new(&mut self.inner).poll_read(cx, buf)
83    }
84}
85
86impl<S: AsyncWrite + Unpin> AsyncWrite for PrefixedStream<S> {
87    fn poll_write(
88        mut self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90        data: &[u8],
91    ) -> Poll<std::io::Result<usize>> {
92        Pin::new(&mut self.inner).poll_write(cx, data)
93    }
94
95    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
96        Pin::new(&mut self.inner).poll_flush(cx)
97    }
98
99    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
100        Pin::new(&mut self.inner).poll_shutdown(cx)
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    #[tokio::test]
    async fn test_prefixed_stream_read() {
        let (mut client, server) = duplex(1024);

        // Server will read from prefixed stream
        let prefix = Bytes::from_static(b"prefix:");
        let mut prefixed = PrefixedStream::new(prefix, server);

        // Client sends some data, then closes so the reader reaches EOF
        client.write_all(b"suffix").await.unwrap();
        drop(client);

        // Read should yield prefix first, then inner stream data
        let mut total = Vec::new();
        prefixed.read_to_end(&mut total).await.unwrap();

        assert_eq!(total, b"prefix:suffix");
    }

    #[tokio::test]
    async fn test_prefixed_stream_partial_read() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello world");
        let mut prefixed = PrefixedStream::new(prefix, server);

        // Read with small buffer
        let mut buf = [0u8; 5];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"hello");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b" worl");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(&buf[..1], b"d");
        assert_eq!(prefixed.prefix_remaining(), 0);

        // Once the prefix is drained, reads fall through to the inner stream.
        client.write_all(b"!").await.unwrap();
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"!");
    }

    #[tokio::test]
    async fn test_empty_prefix_reads_inner_directly() {
        let (mut client, server) = duplex(64);

        let mut prefixed = PrefixedStream::new(Bytes::new(), server);
        assert_eq!(prefixed.prefix_remaining(), 0);

        client.write_all(b"data").await.unwrap();
        drop(client);

        let mut total = Vec::new();
        prefixed.read_to_end(&mut total).await.unwrap();
        assert_eq!(total, b"data");
    }

    #[tokio::test]
    async fn test_prefixed_stream_write_passthrough() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"prefix");
        let mut prefixed = PrefixedStream::new(prefix, server);

        // Write should go directly to inner stream
        prefixed.write_all(b"hello").await.unwrap();

        let mut buf = [0u8; 10];
        let n = client.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");
    }

    #[tokio::test]
    async fn test_prefix_remaining() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello");
        let mut prefixed = PrefixedStream::new(prefix, server);

        assert_eq!(prefixed.prefix_remaining(), 5);

        let mut buf = [0u8; 3];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 3);
        assert_eq!(prefixed.prefix_remaining(), 2);
    }

    #[tokio::test]
    async fn test_into_inner_returns_usable_stream() {
        let (mut client, server) = duplex(64);

        let prefixed = PrefixedStream::new(Bytes::from_static(b"unread"), server);
        let mut inner = prefixed.into_inner();

        // The unwrapped stream reads straight from the peer; the prefix is gone.
        client.write_all(b"raw").await.unwrap();
        let mut buf = [0u8; 3];
        inner.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"raw");
    }
}