vortex_io/
limit.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Context;
8use std::task::Poll;
9use std::task::ready;
10
11use futures::Stream;
12use futures::stream::FuturesUnordered;
13use pin_project_lite::pin_project;
14use tokio::sync::OwnedSemaphorePermit;
15use tokio::sync::Semaphore;
16use tokio::sync::TryAcquireError;
17use vortex_error::VortexExpect;
18
19pin_project! {
20    /// [`Future`] that carries the amount of memory it will require to hold the completed value.
21    ///
22    /// The `OwnedSemaphorePermit` ensures permits are automatically returned when this future
23    /// is dropped, either after completion or if cancelled/aborted.
24    struct SizedFut<Fut> {
25        #[pin]
26        inner: Fut,
27        // Owned permit that will be automatically dropped when the future completes or is dropped
28        _permits: OwnedSemaphorePermit,
29    }
30}
31
32impl<Fut: Future> Future for SizedFut<Fut> {
33    type Output = Fut::Output;
34
35    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36        let this = self.project();
37        let result = ready!(this.inner.poll(cx));
38        Poll::Ready(result)
39    }
40}
41
42pin_project! {
43    /// A [`Stream`] that can work on several simultaneous requests, capping the amount of memory it
44    /// uses at any given point.
45    ///
46    /// It is meant to serve as a buffer between a producer and consumer of IO requests, with built-in
47    /// backpressure that prevents the producer from materializing more than a specified maximum
48    /// amount of memory.
49    ///
50    /// This crate internally makes use of tokio's [Semaphore], and thus is only available with
51    /// the `tokio` feature enabled.
52    pub struct SizeLimitedStream<Fut> {
53        #[pin]
54        inflight: FuturesUnordered<SizedFut<Fut>>,
55        bytes_available: Arc<Semaphore>,
56    }
57}
58
59impl<Fut> SizeLimitedStream<Fut> {
60    pub fn new(max_bytes: usize) -> Self {
61        Self {
62            inflight: FuturesUnordered::new(),
63            bytes_available: Arc::new(Semaphore::new(max_bytes)),
64        }
65    }
66
67    pub fn bytes_available(&self) -> usize {
68        self.bytes_available.available_permits()
69    }
70}
71
72impl<Fut> SizeLimitedStream<Fut>
73where
74    Fut: Future,
75{
76    /// Push a future into the queue after reserving `bytes` of capacity.
77    ///
78    /// This call may need to wait until there is sufficient capacity available in the stream to
79    /// begin work on this future.
80    pub async fn push(&self, fut: Fut, bytes: usize) {
81        // Attempt to acquire enough permits to begin working on a request that will occupy
82        // `bytes` amount of memory when it completes.
83        // Acquiring the permits is what creates backpressure for the producer.
84        let permits = self
85            .bytes_available
86            .clone()
87            .acquire_many_owned(bytes.try_into().vortex_expect("bytes must fit in u32"))
88            .await
89            .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
90
91        let sized_fut = SizedFut {
92            inner: fut,
93            _permits: permits,
94        };
95
96        // push into the pending queue
97        self.inflight.push(sized_fut);
98    }
99
100    /// Synchronous push method. This method will attempt to push if there is enough capacity
101    /// to begin work on the future immediately.
102    ///
103    /// If there is not enough capacity, the original future is returned to the caller.
104    pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
105        match self
106            .bytes_available
107            .clone()
108            .try_acquire_many_owned(bytes.try_into().vortex_expect("bytes must fit in u32"))
109        {
110            Ok(permits) => {
111                let sized_fut = SizedFut {
112                    inner: fut,
113                    _permits: permits,
114                };
115
116                self.inflight.push(sized_fut);
117                Ok(())
118            }
119            Err(acquire_err) => match acquire_err {
120                TryAcquireError::Closed => {
121                    unreachable!("try_pushing to closed semaphore");
122                }
123
124                // No permits available, return the future back to the client so they can
125                // try again.
126                TryAcquireError::NoPermits => Err(fut),
127            },
128        }
129    }
130}
131
132impl<Fut> Stream for SizeLimitedStream<Fut>
133where
134    Fut: Future,
135{
136    type Item = Fut::Output;
137
138    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        let this = self.project();
140        match ready!(this.inflight.poll_next(cx)) {
141            None => Poll::Ready(None),
142            Some(result) => {
143                // Permits are automatically returned when the SizedFut is dropped
144                // after being polled to completion by FuturesUnordered
145                Poll::Ready(Some(result))
146            }
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::future;
    use std::io;

    use bytes::Bytes;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures::future::BoxFuture;

    use crate::limit::SizeLimitedStream;

    /// Builds a future that resolves to `len` bytes, all set to `b'a'`.
    async fn make_future(len: usize) -> Bytes {
        "a".as_bytes().iter().copied().cycle().take(len).collect()
    }

    /// A full stream rejects `try_push`, and accepts it again once an item has been drained.
    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;

        // Pushing last request should fail, because we have 10 bytes outstanding.
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // but, we can pop off a finished work item, and then enqueue.
        assert!(size_limited.next().await.is_some());
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
    }

    /// Capacity is released even when the completed request resolved to an error value.
    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let good_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;

        // attempt to push should fail, as all 10 bytes of capacity is occupied by bad_fut.
        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");

        // Even though the result was an error, the 10 bytes of capacity should be returned back to
        // the stream, allowing us to push the next request.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());

        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    /// A stream with no capacity at all rejects any non-empty request.
    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let stream = SizeLimitedStream::new(0);

        // Should not be able to push anything
        let result = stream.try_push(async { vec![1u8] }, 1);
        assert!(result.is_err());
    }

    /// Dropping the stream with unfinished work hands every reserved permit back.
    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // Keep a handle on the stream's own semaphore so that the budget can still be inspected
        // after the stream itself is gone. A fresh stream has a fresh semaphore and therefore
        // says nothing about whether this one leaked.
        let budget = stream.bytes_available.clone();

        // Push a future that will never complete
        stream
            .push(
                Box::pin(async {
                    // This future will be dropped before completion
                    futures::future::pending::<Vec<u8>>().await
                }),
                5,
            )
            .await;

        // Push another future
        stream.push(Box::pin(async { vec![1u8; 3] }), 3).await;

        // We should have 2 bytes available now
        assert_eq!(stream.bytes_available(), 2);

        // Drop the stream without consuming the futures
        drop(stream);

        // Both futures were dropped along with the stream, so all 8 reserved permits must be
        // back in the original semaphore.
        assert_eq!(budget.available_permits(), 10);

        // A fresh stream is unaffected by the dropped one.
        let mut new_stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // Should be able to use all 10 bytes
        new_stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;
        assert_eq!(new_stream.bytes_available(), 0);

        // Consume to verify it works
        let result = new_stream.next().await;
        assert!(result.is_some());
        assert_eq!(new_stream.bytes_available(), 10);
    }

    /// A single request may use the whole budget; the budget is restored once it is drained.
    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // Push exactly the capacity
        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        // Should not be able to push more
        let result = stream.try_push(Box::pin(async { vec![1u8] }), 1);
        assert!(result.is_err());

        // After consuming, should be able to push again
        stream
            .next()
            .await
            .expect("The 10 byte vector ought to be in there!");
        assert_eq!(stream.bytes_available(), 10);

        let result = stream.try_push(Box::pin(async { vec![1u8; 5] }), 5);
        assert!(result.is_ok());
    }

    /// Many small requests each reserve their share, and draining them all restores the budget.
    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        // Push many small items
        for i in 0..10 {
            #[allow(clippy::cast_possible_truncation)]
            stream.push(async move { vec![i as u8; 5] }, 5).await;
        }

        // Should have used 50 bytes
        assert_eq!(stream.bytes_available(), 50);

        // Consume all
        let mut count = 0;
        while stream.next().await.is_some() {
            count += 1;
            if count == 10 {
                break;
            }
        }

        assert_eq!(count, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    /// A request whose size comfortably fits the permit counter is accepted.
    #[test]
    fn test_size_overflow_protection() {
        let stream = SizeLimitedStream::new(100);

        // NOTE: a request larger than `u32::MAX` bytes (only expressible where `usize` is wider
        // than 32 bits) does not fit the permit count and currently makes `try_push` panic with
        // "bytes must fit in u32". That is documented here rather than asserted, as the behavior
        // may change.

        // Test with reasonable size
        let result = stream.try_push(async { vec![0u8; 50] }, 50);
        assert!(result.is_ok());
    }
}