1use futures::stream::{Stream, StreamExt};
2use hyper::body::{Body, Bytes, HttpBody};
3use std::collections::VecDeque;
4use std::fmt;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll, Waker};
8
/// A cloneable error that carries nothing but a message string.
///
/// `Display` writes the message verbatim; width/alignment flags of the outer format spec are
/// not applied.
#[derive(Clone, Debug)]
pub struct StringError(String);

impl fmt::Display for StringError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StringError(message) = self;
        write!(f, "{message}")
    }
}

/// Marker impl: no `source()` chain is kept, only the message text.
impl std::error::Error for StringError {}
20
/// Per-output bookkeeping kept inside the shared tee state.
#[derive(Debug, Default, Clone)]
struct ConsumerState {
    /// Waker stored by the last `poll_data` that returned `Pending`; taken when woken.
    waker: Option<Waker>,
    /// Absolute index of the next chunk this consumer will read (buffer index + `offset`).
    cursor: usize,
    /// Cleared when the consumer is dropped; inactive consumers do not keep chunks buffered.
    active: bool,
}
27
/// State shared between the producer task spawned by `tee` and both output streams.
#[derive(Debug)]
struct SharedState {
    /// Chunks (and possibly a terminal error) not yet read by every active consumer.
    buffer: VecDeque<Result<Bytes, StringError>>,
    /// Absolute index of `buffer[0]`; advanced as read chunks are drained from the front.
    offset: usize,
    /// Set once the source has ended or produced an error; nothing more will be pushed.
    is_done: bool,
    /// Bookkeeping for the two outputs, indexed by `BodyTeeStream::id`.
    consumers: [ConsumerState; 2],
}
40
41impl Default for SharedState {
42 fn default() -> Self {
43 Self {
44 buffer: VecDeque::new(),
45 offset: 0,
46 is_done: false,
47 consumers: [
48 ConsumerState {
49 active: true,
50 ..Default::default()
51 },
52 ConsumerState {
53 active: true,
54 ..Default::default()
55 },
56 ],
57 }
58 }
59}
60
/// One of the two outputs created by `tee`; reads chunks from the shared buffer.
#[derive(Debug)]
pub struct BodyTeeStream {
    /// State shared with the producer task and the other output.
    shared: Arc<Mutex<SharedState>>,
    /// Which slot of `SharedState::consumers` this stream owns (0 or 1).
    id: usize,
}
67
68pub async fn tee(mut hyper_body: Body) -> (Body, Body) {
70 if HttpBody::size_hint(&hyper_body).exact().is_some() {
71 let bytes = hyper::body::to_bytes(hyper_body)
74 .await
75 .expect("Failed to buffer known-size body");
76 return (hyper::Body::from(bytes.clone()), hyper::Body::from(bytes));
78 }
79
80 let shared_state = Arc::new(Mutex::new(SharedState::default()));
81
82 let s1 = BodyTeeStream {
83 shared: shared_state.clone(),
84 id: 0,
85 };
86
87 let s2 = BodyTeeStream {
88 shared: shared_state.clone(),
89 id: 1,
90 };
91
92 tokio::spawn(async move {
93 loop {
94 let result = hyper_body.next().await;
95 let mut state = shared_state.lock().unwrap();
96
97 let finished = if let Some(item) = result {
98 let item_to_store = item.map_err(|e| StringError(e.to_string()));
100 let is_err = item_to_store.is_err();
101 state.buffer.push_back(item_to_store);
102 is_err
103 } else {
104 true
105 };
106
107 if finished {
108 state.is_done = true;
109 }
110
111 for consumer in state.consumers.iter_mut().filter(|c| c.active) {
112 if let Some(waker) = consumer.waker.take() {
113 waker.wake();
114 }
115 }
116
117 drain_buffer(&mut state);
118
119 if finished {
120 break;
121 }
122 }
123 });
124
125 (Body::wrap_stream(s1), Body::wrap_stream(s2))
126}
127
128impl HttpBody for BodyTeeStream {
129 type Data = Bytes;
130 type Error = Box<dyn std::error::Error + Send + Sync>;
133
134 fn poll_data(
135 self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
138 let this = self.get_mut();
139 let mut state = this.shared.lock().unwrap();
140
141 let SharedState {
142 buffer,
143 offset,
144 is_done,
145 consumers,
146 ..
147 } = &mut *state;
148
149 let consumer = &mut consumers[this.id];
150
151 if consumer.cursor >= *offset {
152 let buffer_index = consumer.cursor - *offset;
153 if let Some(result) = buffer.get(buffer_index) {
154 consumer.cursor += 1;
155 return Poll::Ready(Some(result.clone().map_err(|e| Box::new(e) as Self::Error)));
158 }
159 }
160
161 if *is_done {
162 return Poll::Ready(None);
163 }
164
165 consumer.waker = Some(cx.waker().clone());
166 Poll::Pending
167 }
168
169 fn poll_trailers(
170 self: Pin<&mut Self>,
171 _cx: &mut Context<'_>,
172 ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
173 Poll::Ready(Ok(None))
174 }
175
176 fn is_end_stream(&self) -> bool {
177 let state = self.shared.lock().unwrap();
178 if !state.is_done {
179 return false;
180 }
181 let consumer = &state.consumers[self.id];
182 let total_buffered_chunks = state.offset + state.buffer.len();
183 consumer.cursor >= total_buffered_chunks
184 }
185}
186
187impl Stream for BodyTeeStream {
189 type Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>;
190
191 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192 self.poll_data(cx)
193 }
194}
195
196impl Drop for BodyTeeStream {
197 fn drop(&mut self) {
198 let mut state = self.shared.lock().unwrap();
199 state.consumers[self.id].active = false;
200
201 let other_id = 1 - self.id;
202 if state.consumers[other_id].active
203 && let Some(waker) = state.consumers[other_id].waker.take()
204 {
205 waker.wake();
206 }
207
208 drain_buffer(&mut state);
209 }
210}
211
212fn drain_buffer(state: &mut SharedState) {
214 let min_cursor = state
215 .consumers
216 .iter()
217 .filter(|c| c.active)
218 .map(|c| c.cursor)
219 .min()
220 .unwrap_or(state.offset + state.buffer.len());
221
222 let to_drain = min_cursor.saturating_sub(state.offset);
223 if to_drain > 0 {
224 state.buffer.drain(0..to_drain);
225 state.offset += to_drain;
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};
    use hyper::{Body, body::Bytes};
    use std::convert::Infallible;

    /// Both outputs yield every chunk of a streamed body, in order.
    #[tokio::test]
    async fn test_simple_duplication() {
        let chunks = vec!["hello", " ", "world"];
        let stream = stream::iter(chunks.clone()).map(|s| Ok::<_, Infallible>(Bytes::from(s)));
        let body = Body::wrap_stream(stream);

        let (body1, body2) = tee(body).await;

        let res1_fut = body1
            .map(|chunk_res| chunk_res.unwrap())
            .collect::<Vec<_>>();
        let res2_fut = body2
            .map(|chunk_res| chunk_res.unwrap())
            .collect::<Vec<_>>();

        let (res1, res2) = futures::join!(res1_fut, res2_fut);

        let res1_str: Vec<&str> = res1
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();
        let res2_str: Vec<&str> = res2
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();

        assert_eq!(res1_str, chunks);
        assert_eq!(res2_str, chunks);
    }

    /// A source error reaches both outputs after the preceding chunk and ends each stream.
    /// The second output is read only after the first has finished.
    #[tokio::test]
    async fn test_error_propagation() {
        let error = std::io::Error::other("test error");
        let stream = stream::iter(vec![
            Ok(Bytes::from("one")),
            Err(error),
            Ok(Bytes::from("two")),
        ]);
        let body = Body::wrap_stream(stream);

        let (mut body1, mut body2) = tee(body).await;

        assert_eq!(body1.next().await.unwrap().unwrap(), Bytes::from("one"));
        let err1 = body1.next().await.unwrap().unwrap_err();
        assert!(
            err1.to_string().contains("test error"),
            "Got error: {}",
            err1
        );
        assert!(
            body1.next().await.is_none(),
            "Stream should end after error"
        );

        assert_eq!(body2.next().await.unwrap().unwrap(), Bytes::from("one"));
        let err2 = body2.next().await.unwrap().unwrap_err();
        assert!(
            err2.to_string().contains("test error"),
            "Got error: {}",
            err2
        );
        assert!(
            body2.next().await.is_none(),
            "Stream should end after error"
        );
    }

    /// Dropping one output does not prevent the other from seeing data and the error.
    #[tokio::test]
    async fn test_error_with_one_consumer_dropped() {
        let error = std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "aborted");
        let stream = stream::iter(vec![Ok(Bytes::from("first")), Err(error)]);
        let body = Body::wrap_stream(stream);

        let (mut body1, body2) = tee(body).await;

        drop(body2);

        assert_eq!(body1.next().await.unwrap().unwrap(), Bytes::from("first"));
        let err1 = body1.next().await.unwrap().unwrap_err();
        assert!(err1.to_string().contains("aborted"));
        assert!(
            body1.next().await.is_none(),
            "Stream should end after error"
        );
    }

    /// A body with an exact size hint keeps that hint on both outputs.
    #[tokio::test]
    async fn test_size_hint_preservation() {
        let data = "this has a known size";
        let body = Body::from(data);
        let original_size_hint = HttpBody::size_hint(&body);

        assert_eq!(original_size_hint.exact(), Some(data.len() as u64));

        let (body1, body2) = tee(body).await;

        assert_eq!(
            HttpBody::size_hint(&body1).exact(),
            original_size_hint.exact()
        );
        assert_eq!(
            HttpBody::size_hint(&body2).exact(),
            original_size_hint.exact()
        );

        let body1_bytes = hyper::body::to_bytes(body1).await.unwrap();
        let body2_bytes = hyper::body::to_bytes(body2).await.unwrap();

        assert_eq!(body1_bytes, data.as_bytes());
        assert_eq!(body2_bytes, data.as_bytes());
    }
}