Skip to main content

viceroy_lib/
body_tee.rs

1use futures::stream::{Stream, StreamExt};
2use hyper::body::{Body, Bytes, HttpBody};
3use std::collections::VecDeque;
4use std::fmt;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll, Waker};
8
/// A cloneable, string-backed error.
///
/// The tee has to hand the *same* failure to both output forks, so the source
/// error is flattened to its message, which can be cloned freely.
#[derive(Debug, Clone)]
pub struct StringError(String);

impl fmt::Display for StringError {
    /// Writes the stored message verbatim (formatter width/precision flags are
    /// not applied, because the text is emitted with `write_str`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_str(message.as_str())
    }
}

// No `source()`: the original error was reduced to its message.
impl std::error::Error for StringError {}
20
/// Per-fork bookkeeping kept inside the shared state.
#[derive(Clone, Debug, Default)]
struct ConsumerState {
    /// Waker stored by the fork's last `Poll::Pending`; taken and woken when
    /// new data arrives or the sibling fork is dropped.
    waker: Option<Waker>,
    /// Absolute index (counted from the start of the source stream) of the
    /// next item this fork will read.
    cursor: usize,
    /// `false` once the fork has been dropped; inactive forks no longer hold
    /// chunks in the buffer.
    active: bool,
}
27
28/// The shared state between the two output streams.
29#[derive(Debug)]
30struct SharedState {
31    /// The buffer holds chunks or an error from the source stream.
32    buffer: VecDeque<Result<Bytes, StringError>>,
33    /// The absolute index of the first element in the buffer.
34    offset: usize,
35    /// True if the source stream has finished.
36    is_done: bool,
37    /// State for the two consumer streams.
38    consumers: [ConsumerState; 2],
39}
40
41impl Default for SharedState {
42    fn default() -> Self {
43        Self {
44            buffer: VecDeque::new(),
45            offset: 0,
46            is_done: false,
47            consumers: [
48                ConsumerState {
49                    active: true,
50                    ..Default::default()
51                },
52                ConsumerState {
53                    active: true,
54                    ..Default::default()
55                },
56            ],
57        }
58    }
59}
60
61/// A stream that is one of two outputs from the tee operation.
62#[derive(Debug)]
63pub struct BodyTeeStream {
64    shared: Arc<Mutex<SharedState>>,
65    id: usize,
66}
67
68/// Tees a Body into two independent, error-propagating, and memory-safe streams.
69pub async fn tee(mut hyper_body: Body) -> (Body, Body) {
70    if HttpBody::size_hint(&hyper_body).exact().is_some() {
71        // If the size is known, we MUST buffer the body to preserve the
72        // Content-Length.
73        let bytes = hyper::body::to_bytes(hyper_body)
74            .await
75            .expect("Failed to buffer known-size body");
76        // `Bytes` is cheap to clone.
77        return (hyper::Body::from(bytes.clone()), hyper::Body::from(bytes));
78    }
79
80    let shared_state = Arc::new(Mutex::new(SharedState::default()));
81
82    let s1 = BodyTeeStream {
83        shared: shared_state.clone(),
84        id: 0,
85    };
86
87    let s2 = BodyTeeStream {
88        shared: shared_state.clone(),
89        id: 1,
90    };
91
92    tokio::spawn(async move {
93        loop {
94            let result = hyper_body.next().await;
95            let mut state = shared_state.lock().unwrap();
96
97            let finished = if let Some(item) = result {
98                // Convert any error into our simple, cloneable StringError.
99                let item_to_store = item.map_err(|e| StringError(e.to_string()));
100                let is_err = item_to_store.is_err();
101                state.buffer.push_back(item_to_store);
102                is_err
103            } else {
104                true
105            };
106
107            if finished {
108                state.is_done = true;
109            }
110
111            for consumer in state.consumers.iter_mut().filter(|c| c.active) {
112                if let Some(waker) = consumer.waker.take() {
113                    waker.wake();
114                }
115            }
116
117            drain_buffer(&mut state);
118
119            if finished {
120                break;
121            }
122        }
123    });
124
125    (Body::wrap_stream(s1), Body::wrap_stream(s2))
126}
127
128impl HttpBody for BodyTeeStream {
129    type Data = Bytes;
130    // The error type must be convertible into hyper's error type. A boxed
131    // standard error is the idiomatic way to do this.
132    type Error = Box<dyn std::error::Error + Send + Sync>;
133
134    fn poll_data(
135        self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
138        let this = self.get_mut();
139        let mut state = this.shared.lock().unwrap();
140
141        let SharedState {
142            buffer,
143            offset,
144            is_done,
145            consumers,
146            ..
147        } = &mut *state;
148
149        let consumer = &mut consumers[this.id];
150
151        if consumer.cursor >= *offset {
152            let buffer_index = consumer.cursor - *offset;
153            if let Some(result) = buffer.get(buffer_index) {
154                consumer.cursor += 1;
155                // FIX: When we read from the buffer, explicitly cast the boxed concrete
156                // error to a boxed trait object to satisfy the type checker.
157                return Poll::Ready(Some(result.clone().map_err(|e| Box::new(e) as Self::Error)));
158            }
159        }
160
161        if *is_done {
162            return Poll::Ready(None);
163        }
164
165        consumer.waker = Some(cx.waker().clone());
166        Poll::Pending
167    }
168
169    fn poll_trailers(
170        self: Pin<&mut Self>,
171        _cx: &mut Context<'_>,
172    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
173        Poll::Ready(Ok(None))
174    }
175
176    fn is_end_stream(&self) -> bool {
177        let state = self.shared.lock().unwrap();
178        if !state.is_done {
179            return false;
180        }
181        let consumer = &state.consumers[self.id];
182        let total_buffered_chunks = state.offset + state.buffer.len();
183        consumer.cursor >= total_buffered_chunks
184    }
185}
186
187// so it can be used with `Body::wrap_stream`.
188impl Stream for BodyTeeStream {
189    type Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>;
190
191    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        self.poll_data(cx)
193    }
194}
195
196impl Drop for BodyTeeStream {
197    fn drop(&mut self) {
198        let mut state = self.shared.lock().unwrap();
199        state.consumers[self.id].active = false;
200
201        let other_id = 1 - self.id;
202        if state.consumers[other_id].active
203            && let Some(waker) = state.consumers[other_id].waker.take()
204        {
205            waker.wake();
206        }
207
208        drain_buffer(&mut state);
209    }
210}
211
212/// Helper to remove chunks from the buffer that all active consumers have read.
213fn drain_buffer(state: &mut SharedState) {
214    let min_cursor = state
215        .consumers
216        .iter()
217        .filter(|c| c.active)
218        .map(|c| c.cursor)
219        .min()
220        .unwrap_or(state.offset + state.buffer.len());
221
222    let to_drain = min_cursor.saturating_sub(state.offset);
223    if to_drain > 0 {
224        state.buffer.drain(0..to_drain);
225        state.offset += to_drain;
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};
    use hyper::{Body, body::Bytes};
    use std::convert::Infallible;

    /// Both forks of a streaming (unknown-size) body see every chunk, in order.
    #[tokio::test]
    async fn test_simple_duplication() {
        let chunks = vec!["hello", " ", "world"];
        let stream = stream::iter(chunks.clone()).map(|s| Ok::<_, Infallible>(Bytes::from(s)));
        let body = Body::wrap_stream(stream);

        let (body1, body2) = tee(body).await;

        let res1_fut = body1
            .map(|chunk_res| chunk_res.unwrap())
            .collect::<Vec<_>>();
        let res2_fut = body2
            .map(|chunk_res| chunk_res.unwrap())
            .collect::<Vec<_>>();

        // Drive both forks concurrently.
        let (res1, res2) = futures::join!(res1_fut, res2_fut);

        let res1_str: Vec<&str> = res1
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();
        let res2_str: Vec<&str> = res2
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();

        assert_eq!(res1_str, chunks);
        assert_eq!(res2_str, chunks);
    }

    /// A source error reaches both forks and ends each of them; items after
    /// the error are never delivered.
    #[tokio::test]
    async fn test_error_propagation() {
        let error = std::io::Error::other("test error");
        let stream = stream::iter(vec![
            Ok(Bytes::from("one")),
            Err(error),
            Ok(Bytes::from("two")),
        ]);
        let body = Body::wrap_stream(stream);

        let (mut body1, mut body2) = tee(body).await;

        assert_eq!(body1.next().await.unwrap().unwrap(), Bytes::from("one"));
        let err1 = body1.next().await.unwrap().unwrap_err();
        assert!(
            err1.to_string().contains("test error"),
            "Got error: {}",
            err1
        );
        assert!(
            body1.next().await.is_none(),
            "Stream should end after error"
        );

        assert_eq!(body2.next().await.unwrap().unwrap(), Bytes::from("one"));
        let err2 = body2.next().await.unwrap().unwrap_err();
        assert!(
            err2.to_string().contains("test error"),
            "Got error: {}",
            err2
        );
        assert!(
            body2.next().await.is_none(),
            "Stream should end after error"
        );
    }

    /// Dropping one fork does not stop the other from receiving data and the
    /// terminal error.
    #[tokio::test]
    async fn test_error_with_one_consumer_dropped() {
        let error = std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "aborted");
        let stream = stream::iter(vec![Ok(Bytes::from("first")), Err(error)]);
        let body = Body::wrap_stream(stream);

        let (mut body1, body2) = tee(body).await;

        drop(body2);

        assert_eq!(body1.next().await.unwrap().unwrap(), Bytes::from("first"));
        let err1 = body1.next().await.unwrap().unwrap_err();
        assert!(err1.to_string().contains("aborted"));
        assert!(
            body1.next().await.is_none(),
            "Stream should end after error"
        );
    }

    /// A body with an exact size hint keeps that hint on both forks.
    #[tokio::test]
    async fn test_size_hint_preservation() {
        let data = "this has a known size";
        let body = Body::from(data);
        let original_size_hint = HttpBody::size_hint(&body);

        assert_eq!(original_size_hint.exact(), Some(data.len() as u64));

        let (body1, body2) = tee(body).await;

        assert_eq!(
            HttpBody::size_hint(&body1).exact(),
            original_size_hint.exact()
        );
        assert_eq!(
            HttpBody::size_hint(&body2).exact(),
            original_size_hint.exact()
        );

        let body1_bytes = hyper::body::to_bytes(body1).await.unwrap();
        let body2_bytes = hyper::body::to_bytes(body2).await.unwrap();

        assert_eq!(body1_bytes, data.as_bytes());
        assert_eq!(body2_bytes, data.as_bytes());
    }
}