// trojan_core/io/prefixed.rs
use std::io::IoSlice;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14pub struct PrefixedStream<S> {
34 prefix: Bytes,
35 pos: usize,
36 inner: S,
37}
38
39impl<S> PrefixedStream<S> {
40 pub fn new(prefix: Bytes, inner: S) -> Self {
47 Self {
48 prefix,
49 pos: 0,
50 inner,
51 }
52 }
53
54 pub fn prefix_remaining(&self) -> usize {
56 self.prefix.len().saturating_sub(self.pos)
57 }
58
59 pub fn into_inner(self) -> S {
63 self.inner
64 }
65}
66
67impl<S: AsyncRead + Unpin> AsyncRead for PrefixedStream<S> {
68 fn poll_read(
69 mut self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &mut ReadBuf<'_>,
72 ) -> Poll<std::io::Result<()>> {
73 if self.pos < self.prefix.len() {
75 let remaining = &self.prefix[self.pos..];
76 let to_copy = remaining.len().min(buf.remaining());
77 buf.put_slice(&remaining[..to_copy]);
78 self.pos += to_copy;
79 return Poll::Ready(Ok(()));
80 }
81 Pin::new(&mut self.inner).poll_read(cx, buf)
83 }
84}
85
86impl<S: AsyncWrite + Unpin> AsyncWrite for PrefixedStream<S> {
87 fn poll_write(
88 mut self: Pin<&mut Self>,
89 cx: &mut Context<'_>,
90 data: &[u8],
91 ) -> Poll<std::io::Result<usize>> {
92 Pin::new(&mut self.inner).poll_write(cx, data)
93 }
94
95 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
96 Pin::new(&mut self.inner).poll_flush(cx)
97 }
98
99 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
100 Pin::new(&mut self.inner).poll_shutdown(cx)
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// Prefix bytes come first, then the inner stream's data, then EOF.
    #[tokio::test]
    async fn test_prefixed_stream_read() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"prefix:");
        let mut prefixed = PrefixedStream::new(prefix, server);

        client.write_all(b"suffix").await.unwrap();
        // Closing the peer makes the inner stream report EOF after "suffix".
        drop(client);

        let mut buf = vec![0u8; 1024];
        let mut total = Vec::new();

        loop {
            let n = prefixed.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            total.extend_from_slice(&buf[..n]);
        }

        assert_eq!(total, b"prefix:suffix");
    }

    /// A prefix larger than the caller's buffer is served across several reads.
    #[tokio::test]
    async fn test_prefixed_stream_partial_read() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello world");
        let mut prefixed = PrefixedStream::new(prefix, server);

        let mut buf = [0u8; 5];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"hello");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b" worl");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(&buf[..1], b"d");
    }

    /// Writes reach the inner stream untouched; the prefix is never echoed.
    #[tokio::test]
    async fn test_prefixed_stream_write_passthrough() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"prefix");
        let mut prefixed = PrefixedStream::new(prefix, server);

        prefixed.write_all(b"hello").await.unwrap();

        let mut buf = [0u8; 10];
        let n = client.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");
    }

    #[tokio::test]
    async fn test_prefix_remaining() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello");
        let mut prefixed = PrefixedStream::new(prefix, server);

        assert_eq!(prefixed.prefix_remaining(), 5);

        let mut buf = [0u8; 3];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 3);
        assert_eq!(prefixed.prefix_remaining(), 2);
    }

    /// A read larger than the prefix returns only the prefix and leaves nothing pending.
    #[tokio::test]
    async fn test_prefix_remaining_reaches_zero() {
        let (_client, server) = duplex(64);
        let mut prefixed = PrefixedStream::new(Bytes::from_static(b"abc"), server);

        let mut buf = [0u8; 8];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 3);
        assert_eq!(&buf[..n], b"abc");
        assert_eq!(prefixed.prefix_remaining(), 0);
    }

    /// With an empty prefix the wrapper behaves exactly like the inner stream.
    #[tokio::test]
    async fn test_empty_prefix_reads_inner_directly() {
        let (mut client, server) = duplex(64);
        let mut prefixed = PrefixedStream::new(Bytes::new(), server);
        assert_eq!(prefixed.prefix_remaining(), 0);

        client.write_all(b"data").await.unwrap();
        drop(client);

        let mut out = Vec::new();
        prefixed.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, b"data");
    }

    /// `into_inner` hands back a working inner stream.
    #[tokio::test]
    async fn test_into_inner_returns_wrapped_stream() {
        let (mut client, server) = duplex(64);
        let prefixed = PrefixedStream::new(Bytes::from_static(b"unused"), server);

        let mut server = prefixed.into_inner();
        server.write_all(b"ping").await.unwrap();

        let mut buf = [0u8; 4];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"ping");
    }
}