Skip to main content

trojan_core/io/
prefixed.rs

//! Prefixed stream adapter for replaying buffered data.
//!
//! This module provides `PrefixedStream`, a stream wrapper that yields
//! pre-buffered bytes before reading from the inner stream. This is useful
//! for protocol detection, where incoming data must be inspected first and then
//! replayed to the real handler as if it had never been consumed from the
//! underlying stream.
7
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14/// A stream wrapper that yields a prefetched prefix before reading from the inner stream.
15///
16/// This is commonly used in mixed-mode protocol detection (e.g., WebSocket upgrade)
17/// where we need to read HTTP headers to determine the protocol, then replay those
18/// bytes to the actual protocol handler.
19///
20/// # Example
21///
22/// ```ignore
23/// use trojan_core::io::PrefixedStream;
24/// use bytes::Bytes;
25///
26/// // After reading some bytes for protocol detection
27/// let buffered = Bytes::from(b"GET / HTTP/1.1\r\n...");
28/// let prefixed = PrefixedStream::new(buffered, tcp_stream);
29///
30/// // Now the prefixed stream will first yield the buffered bytes,
31/// // then continue reading from tcp_stream
32/// ```
33#[derive(Debug)]
34pub struct PrefixedStream<S> {
35    prefix: Bytes,
36    pos: usize,
37    inner: S,
38}
39
40impl<S> PrefixedStream<S> {
41    /// Create a new prefixed stream.
42    ///
43    /// # Arguments
44    ///
45    /// * `prefix` - The buffered bytes to yield first
46    /// * `inner` - The underlying stream to read from after prefix is exhausted
47    pub fn new(prefix: Bytes, inner: S) -> Self {
48        Self {
49            prefix,
50            pos: 0,
51            inner,
52        }
53    }
54
55    /// Returns the remaining unread prefix bytes.
56    pub fn prefix_remaining(&self) -> usize {
57        self.prefix.len().saturating_sub(self.pos)
58    }
59
60    /// Consumes the wrapper, returning the inner stream.
61    ///
62    /// Note: Any unread prefix bytes will be lost.
63    pub fn into_inner(self) -> S {
64        self.inner
65    }
66}
67
68impl<S: AsyncRead + Unpin> AsyncRead for PrefixedStream<S> {
69    fn poll_read(
70        mut self: Pin<&mut Self>,
71        cx: &mut Context<'_>,
72        buf: &mut ReadBuf<'_>,
73    ) -> Poll<std::io::Result<()>> {
74        // First, yield any remaining prefix bytes
75        if self.pos < self.prefix.len() {
76            let remaining = &self.prefix[self.pos..];
77            let to_copy = remaining.len().min(buf.remaining());
78            buf.put_slice(&remaining[..to_copy]);
79            self.pos += to_copy;
80            return Poll::Ready(Ok(()));
81        }
82        // Then delegate to inner stream
83        Pin::new(&mut self.inner).poll_read(cx, buf)
84    }
85}
86
87impl<S: AsyncWrite + Unpin> AsyncWrite for PrefixedStream<S> {
88    fn poll_write(
89        mut self: Pin<&mut Self>,
90        cx: &mut Context<'_>,
91        data: &[u8],
92    ) -> Poll<std::io::Result<usize>> {
93        Pin::new(&mut self.inner).poll_write(cx, data)
94    }
95
96    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
97        Pin::new(&mut self.inner).poll_flush(cx)
98    }
99
100    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
101        Pin::new(&mut self.inner).poll_shutdown(cx)
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    #[tokio::test]
    async fn test_prefixed_stream_read() {
        let (mut client, server) = duplex(1024);

        // Server will read from prefixed stream
        let prefix = Bytes::from_static(b"prefix:");
        let mut prefixed = PrefixedStream::new(prefix, server);

        // Client sends some data
        client.write_all(b"suffix").await.unwrap();
        drop(client);

        // Read should yield prefix first, then inner stream data
        let mut buf = vec![0u8; 1024];
        let mut total = Vec::new();

        loop {
            let n = prefixed.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            total.extend_from_slice(&buf[..n]);
        }

        assert_eq!(total, b"prefix:suffix");
    }

    #[tokio::test]
    async fn test_prefixed_stream_partial_read() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello world");
        let mut prefixed = PrefixedStream::new(prefix, server);

        // Read with small buffer
        let mut buf = [0u8; 5];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"hello");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b" worl");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(&buf[..1], b"d");
    }

    #[tokio::test]
    async fn test_prefixed_stream_write_passthrough() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"prefix");
        let mut prefixed = PrefixedStream::new(prefix, server);

        // Write should go directly to inner stream
        prefixed.write_all(b"hello").await.unwrap();

        let mut buf = [0u8; 10];
        let n = client.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");
    }

    #[tokio::test]
    async fn test_prefix_remaining() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello");
        let mut prefixed = PrefixedStream::new(prefix, server);

        assert_eq!(prefixed.prefix_remaining(), 5);

        let mut buf = [0u8; 3];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 3);
        assert_eq!(prefixed.prefix_remaining(), 2);
    }

    #[tokio::test]
    async fn test_empty_prefix_reads_inner() {
        let (mut client, server) = duplex(1024);

        // An empty prefix must behave exactly like the bare inner stream.
        let mut prefixed = PrefixedStream::new(Bytes::new(), server);
        assert_eq!(prefixed.prefix_remaining(), 0);

        client.write_all(b"data").await.unwrap();
        drop(client);

        let mut total = Vec::new();
        prefixed.read_to_end(&mut total).await.unwrap();
        assert_eq!(total, b"data");
    }

    #[tokio::test]
    async fn test_prefix_then_inner_with_small_buffer() {
        let (mut client, server) = duplex(1024);

        // Buffer smaller than the prefix forces the prefix -> inner handoff
        // to happen across several reads.
        let mut prefixed = PrefixedStream::new(Bytes::from_static(b"abc"), server);

        client.write_all(b"def").await.unwrap();
        drop(client);

        let mut buf = [0u8; 2];
        let mut total = Vec::new();
        loop {
            let n = prefixed.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            total.extend_from_slice(&buf[..n]);
        }

        assert_eq!(total, b"abcdef");
        assert_eq!(prefixed.prefix_remaining(), 0);
    }

    #[tokio::test]
    async fn test_into_inner_skips_unread_prefix() {
        let (mut client, server) = duplex(1024);

        // `into_inner` is documented to drop any unread prefix bytes.
        let prefixed = PrefixedStream::new(Bytes::from_static(b"ignored"), server);

        client.write_all(b"inner").await.unwrap();
        drop(client);

        let mut inner = prefixed.into_inner();
        let mut total = Vec::new();
        inner.read_to_end(&mut total).await.unwrap();
        assert_eq!(total, b"inner");
    }
}