trojan_core/io/
prefixed.rs1use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14#[derive(Debug)]
34pub struct PrefixedStream<S> {
35 prefix: Bytes,
36 pos: usize,
37 inner: S,
38}
39
40impl<S> PrefixedStream<S> {
41 pub fn new(prefix: Bytes, inner: S) -> Self {
48 Self {
49 prefix,
50 pos: 0,
51 inner,
52 }
53 }
54
55 pub fn prefix_remaining(&self) -> usize {
57 self.prefix.len().saturating_sub(self.pos)
58 }
59
60 pub fn into_inner(self) -> S {
64 self.inner
65 }
66}
67
68impl<S: AsyncRead + Unpin> AsyncRead for PrefixedStream<S> {
69 fn poll_read(
70 mut self: Pin<&mut Self>,
71 cx: &mut Context<'_>,
72 buf: &mut ReadBuf<'_>,
73 ) -> Poll<std::io::Result<()>> {
74 if self.pos < self.prefix.len() {
76 let remaining = &self.prefix[self.pos..];
77 let to_copy = remaining.len().min(buf.remaining());
78 buf.put_slice(&remaining[..to_copy]);
79 self.pos += to_copy;
80 return Poll::Ready(Ok(()));
81 }
82 Pin::new(&mut self.inner).poll_read(cx, buf)
84 }
85}
86
87impl<S: AsyncWrite + Unpin> AsyncWrite for PrefixedStream<S> {
88 fn poll_write(
89 mut self: Pin<&mut Self>,
90 cx: &mut Context<'_>,
91 data: &[u8],
92 ) -> Poll<std::io::Result<usize>> {
93 Pin::new(&mut self.inner).poll_write(cx, data)
94 }
95
96 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
97 Pin::new(&mut self.inner).poll_flush(cx)
98 }
99
100 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
101 Pin::new(&mut self.inner).poll_shutdown(cx)
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    #[tokio::test]
    async fn test_prefixed_stream_read() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"prefix:");
        let mut prefixed = PrefixedStream::new(prefix, server);

        client.write_all(b"suffix").await.unwrap();
        // Dropping the writer lets the reader observe EOF after "suffix".
        drop(client);

        let mut buf = vec![0u8; 1024];
        let mut total = Vec::new();

        loop {
            let n = prefixed.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            total.extend_from_slice(&buf[..n]);
        }

        assert_eq!(total, b"prefix:suffix");
    }

    #[tokio::test]
    async fn test_prefixed_stream_partial_read() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello world");
        let mut prefixed = PrefixedStream::new(prefix, server);

        let mut buf = [0u8; 5];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"hello");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b" worl");

        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(&buf[..1], b"d");
    }

    #[tokio::test]
    async fn test_prefixed_stream_write_passthrough() {
        let (mut client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"prefix");
        let mut prefixed = PrefixedStream::new(prefix, server);

        prefixed.write_all(b"hello").await.unwrap();

        let mut buf = [0u8; 10];
        let n = client.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");
    }

    #[tokio::test]
    async fn test_prefix_remaining() {
        let (_client, server) = duplex(1024);

        let prefix = Bytes::from_static(b"hello");
        let mut prefixed = PrefixedStream::new(prefix, server);

        assert_eq!(prefixed.prefix_remaining(), 5);

        let mut buf = [0u8; 3];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 3);
        assert_eq!(prefixed.prefix_remaining(), 2);
    }

    #[tokio::test]
    async fn test_empty_prefix_reads_inner_directly() {
        let (mut client, server) = duplex(1024);

        let mut prefixed = PrefixedStream::new(Bytes::new(), server);
        assert_eq!(prefixed.prefix_remaining(), 0);

        client.write_all(b"data").await.unwrap();

        let mut buf = [0u8; 16];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"data");
    }

    #[tokio::test]
    async fn test_prefix_drained_then_into_inner() {
        let (mut client, server) = duplex(1024);

        let mut prefixed = PrefixedStream::new(Bytes::from_static(b"ab"), server);

        // A read with spare capacity returns only the prefix, never mixing in
        // bytes from the inner stream.
        let mut buf = [0u8; 8];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"ab");
        assert_eq!(prefixed.prefix_remaining(), 0);

        // After the prefix is drained, unwrapping loses nothing and the inner
        // stream is still usable.
        let mut inner = prefixed.into_inner();
        client.write_all(b"cd").await.unwrap();
        let n = inner.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"cd");
    }
}