vortex_io/
limit.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, ready};
8
9use futures::Stream;
10use futures::stream::FuturesUnordered;
11use pin_project_lite::pin_project;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
13use vortex_error::VortexUnwrap;
14
15pin_project! {
16    /// [`Future`] that carries the amount of memory it will require to hold the completed value.
17    ///
18    /// The `OwnedSemaphorePermit` ensures permits are automatically returned when this future
19    /// is dropped, either after completion or if cancelled/aborted.
20    struct SizedFut<Fut> {
21        #[pin]
22        inner: Fut,
23        // Owned permit that will be automatically dropped when the future completes or is dropped
24        _permits: OwnedSemaphorePermit,
25    }
26}
27
28impl<Fut: Future> Future for SizedFut<Fut> {
29    type Output = Fut::Output;
30
31    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
32        let this = self.project();
33        let result = ready!(this.inner.poll(cx));
34        Poll::Ready(result)
35    }
36}
37
38pin_project! {
39    /// A [`Stream`] that can work on several simultaneous requests, capping the amount of memory it
40    /// uses at any given point.
41    ///
42    /// It is meant to serve as a buffer between a producer and consumer of IO requests, with built-in
43    /// backpressure that prevents the producer from materializing more than a specified maximum
44    /// amount of memory.
45    ///
46    /// This crate internally makes use of tokio's [Semaphore], and thus is only available with
47    /// the `tokio` feature enabled.
48    pub struct SizeLimitedStream<Fut> {
49        #[pin]
50        inflight: FuturesUnordered<SizedFut<Fut>>,
51        bytes_available: Arc<Semaphore>,
52    }
53}
54
55impl<Fut> SizeLimitedStream<Fut> {
56    pub fn new(max_bytes: usize) -> Self {
57        Self {
58            inflight: FuturesUnordered::new(),
59            bytes_available: Arc::new(Semaphore::new(max_bytes)),
60        }
61    }
62
63    pub fn bytes_available(&self) -> usize {
64        self.bytes_available.available_permits()
65    }
66}
67
68impl<Fut> SizeLimitedStream<Fut>
69where
70    Fut: Future,
71{
72    /// Push a future into the queue after reserving `bytes` of capacity.
73    ///
74    /// This call may need to wait until there is sufficient capacity available in the stream to
75    /// begin work on this future.
76    pub async fn push(&self, fut: Fut, bytes: usize) {
77        // Attempt to acquire enough permits to begin working on a request that will occupy
78        // `bytes` amount of memory when it completes.
79        // Acquiring the permits is what creates backpressure for the producer.
80        let permits = self
81            .bytes_available
82            .clone()
83            .acquire_many_owned(bytes.try_into().vortex_unwrap())
84            .await
85            .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
86
87        let sized_fut = SizedFut {
88            inner: fut,
89            _permits: permits,
90        };
91
92        // push into the pending queue
93        self.inflight.push(sized_fut);
94    }
95
96    /// Synchronous push method. This method will attempt to push if there is enough capacity
97    /// to begin work on the future immediately.
98    ///
99    /// If there is not enough capacity, the original future is returned to the caller.
100    pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
101        match self
102            .bytes_available
103            .clone()
104            .try_acquire_many_owned(bytes.try_into().vortex_unwrap())
105        {
106            Ok(permits) => {
107                let sized_fut = SizedFut {
108                    inner: fut,
109                    _permits: permits,
110                };
111
112                self.inflight.push(sized_fut);
113                Ok(())
114            }
115            Err(acquire_err) => match acquire_err {
116                TryAcquireError::Closed => {
117                    unreachable!("try_pushing to closed semaphore");
118                }
119
120                // No permits available, return the future back to the client so they can
121                // try again.
122                TryAcquireError::NoPermits => Err(fut),
123            },
124        }
125    }
126}
127
128impl<Fut> Stream for SizeLimitedStream<Fut>
129where
130    Fut: Future,
131{
132    type Item = Fut::Output;
133
134    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135        let this = self.project();
136        match ready!(this.inflight.poll_next(cx)) {
137            None => Poll::Ready(None),
138            Some(result) => {
139                // Permits are automatically returned when the SizedFut is dropped
140                // after being polled to completion by FuturesUnordered
141                Poll::Ready(Some(result))
142            }
143        }
144    }
145}
146
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::{future, io};

    use bytes::Bytes;
    use futures::future::BoxFuture;
    use futures::{FutureExt, StreamExt};

    use crate::limit::SizeLimitedStream;

    /// A future resolving to `len` bytes of `b'a'`.
    async fn make_future(len: usize) -> Bytes {
        Bytes::from(vec![b'a'; len])
    }

    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;

        // Pushing last request should fail, because we have 10 bytes outstanding.
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // but, we can pop off a finished work item, and then enqueue.
        assert!(size_limited.next().await.is_some());
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
    }

    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let good_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;

        // attempt to push should fail, as all 10 bytes of capacity is occupied by bad_fut.
        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");

        // Even though the result was an error, the 10 bytes of capacity should be returned back to
        // the stream, allowing us to push the next request.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());

        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let stream = SizeLimitedStream::new(0);

        // Should not be able to push anything
        let result = stream.try_push(async { vec![1u8] }, 1);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);
        // Keep a handle on this stream's own semaphore so its permits can still be inspected
        // after the stream (and every future it owns) has been dropped.
        let semaphore = Arc::clone(&stream.bytes_available);

        // A future that never completes holds its 5 bytes for as long as it is alive.
        stream.push(future::pending::<Vec<u8>>().boxed(), 5).await;
        stream.push(future::ready(vec![1u8; 3]).boxed(), 3).await;
        assert_eq!(stream.bytes_available(), 2);

        // Only the ready future can complete: its 3 bytes come back, the pending one keeps 5.
        assert_eq!(stream.next().await, Some(vec![1u8; 3]));
        assert_eq!(stream.bytes_available(), 5);

        // Dropping the stream drops the still-pending future, which must release its permits.
        drop(stream);
        assert_eq!(semaphore.available_permits(), 10);
    }

    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // Push exactly the capacity
        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        // Should not be able to push more
        let result = stream.try_push(Box::pin(async { vec![1u8] }), 1);
        assert!(result.is_err());

        // After consuming, should be able to push again
        stream
            .next()
            .await
            .expect("The 10 byte vector ought to be in there!");
        assert_eq!(stream.bytes_available(), 10);

        let result = stream.try_push(Box::pin(async { vec![1u8; 5] }), 5);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        // Push ten 5-byte items.
        for i in 0u8..10 {
            stream.push(async move { vec![i; 5] }, 5).await;
        }

        // Should have used 50 bytes
        assert_eq!(stream.bytes_available(), 50);

        // Drain everything; the stream yields `None` once nothing is left in flight.
        let mut count = 0;
        while stream.next().await.is_some() {
            count += 1;
        }

        assert_eq!(count, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    #[test]
    fn test_size_overflow_protection() {
        // NOTE: on 64-bit targets a reservation above `u32::MAX` bytes panics with the current
        // implementation, because the semaphore counts permits in `u32`. That behaviour is
        // deliberately not pinned here as it may change; only an in-range reservation is tested.
        let stream = SizeLimitedStream::new(100);

        let result = stream.try_push(async { vec![0u8; 50] }, 50);
        assert!(result.is_ok());
        assert_eq!(stream.bytes_available(), 50);
    }
}