vortex_io/
limit.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, ready};
8
9use futures::Stream;
10use futures_util::stream::FuturesUnordered;
11use pin_project::pin_project;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
13use vortex_error::VortexUnwrap;
14
15/// [`Future`] that carries the amount of memory it will require to hold the completed value.
16///
17/// The `OwnedSemaphorePermit` ensures permits are automatically returned when this future
18/// is dropped, either after completion or if cancelled/aborted.
19#[pin_project]
20struct SizedFut<Fut> {
21    #[pin]
22    inner: Fut,
23    /// Owned permit that will be automatically dropped when the future completes or is dropped
24    _permits: OwnedSemaphorePermit,
25}
26
27impl<Fut: Future> Future for SizedFut<Fut> {
28    type Output = Fut::Output;
29
30    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
31        let this = self.project();
32        let result = ready!(this.inner.poll(cx));
33        Poll::Ready(result)
34    }
35}
36
37/// A [`Stream`] that can work on several simultaneous requests, capping the amount of memory it
38/// uses at any given point.
39///
40/// It is meant to serve as a buffer between a producer and consumer of IO requests, with built-in
41/// backpressure that prevents the producer from materializing more than a specified maximum
42/// amount of memory.
43///
44/// This crate internally makes use of tokio's [Semaphore], and thus is only available with
45/// the `tokio` feature enabled.
46#[pin_project]
47pub struct SizeLimitedStream<Fut> {
48    #[pin]
49    inflight: FuturesUnordered<SizedFut<Fut>>,
50    bytes_available: Arc<Semaphore>,
51}
52
53impl<Fut> SizeLimitedStream<Fut> {
54    pub fn new(max_bytes: usize) -> Self {
55        Self {
56            inflight: FuturesUnordered::new(),
57            bytes_available: Arc::new(Semaphore::new(max_bytes)),
58        }
59    }
60
61    pub fn bytes_available(&self) -> usize {
62        self.bytes_available.available_permits()
63    }
64}
65
66impl<Fut> SizeLimitedStream<Fut>
67where
68    Fut: Future,
69{
70    /// Push a future into the queue after reserving `bytes` of capacity.
71    ///
72    /// This call may need to wait until there is sufficient capacity available in the stream to
73    /// begin work on this future.
74    pub async fn push(&self, fut: Fut, bytes: usize) {
75        // Attempt to acquire enough permits to begin working on a request that will occupy
76        // `bytes` amount of memory when it completes.
77        // Acquiring the permits is what creates backpressure for the producer.
78        let permits = self
79            .bytes_available
80            .clone()
81            .acquire_many_owned(bytes.try_into().vortex_unwrap())
82            .await
83            .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
84
85        let sized_fut = SizedFut {
86            inner: fut,
87            _permits: permits,
88        };
89
90        // push into the pending queue
91        self.inflight.push(sized_fut);
92    }
93
94    /// Synchronous push method. This method will attempt to push if there is enough capacity
95    /// to begin work on the future immediately.
96    ///
97    /// If there is not enough capacity, the original future is returned to the caller.
98    pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
99        match self
100            .bytes_available
101            .clone()
102            .try_acquire_many_owned(bytes.try_into().vortex_unwrap())
103        {
104            Ok(permits) => {
105                let sized_fut = SizedFut {
106                    inner: fut,
107                    _permits: permits,
108                };
109
110                self.inflight.push(sized_fut);
111                Ok(())
112            }
113            Err(acquire_err) => match acquire_err {
114                TryAcquireError::Closed => {
115                    unreachable!("try_pushing to closed semaphore");
116                }
117
118                // No permits available, return the future back to the client so they can
119                // try again.
120                TryAcquireError::NoPermits => Err(fut),
121            },
122        }
123    }
124}
125
126impl<Fut> Stream for SizeLimitedStream<Fut>
127where
128    Fut: Future,
129{
130    type Item = Fut::Output;
131
132    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        let this = self.project();
134        match ready!(this.inflight.poll_next(cx)) {
135            None => Poll::Ready(None),
136            Some(result) => {
137                // Permits are automatically returned when the SizedFut is dropped
138                // after being polled to completion by FuturesUnordered
139                Poll::Ready(Some(result))
140            }
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use std::{future, io};

    use bytes::Bytes;
    use futures_util::future::BoxFuture;
    use futures_util::{FutureExt, StreamExt};

    use crate::limit::SizeLimitedStream;

    /// Builds a future that resolves to `len` bytes of `b'a'`.
    async fn make_future(len: usize) -> Bytes {
        Bytes::from(vec![b'a'; len])
    }

    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;

        // Pushing another request should fail, because we have 10 bytes outstanding.
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // But we can pop off a finished work item, and then enqueue.
        assert!(size_limited.next().await.is_some());
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
    }

    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let good_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;

        // Attempt to push should fail, as all 10 bytes of capacity are occupied by bad_fut.
        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");

        // Even though the result was an error, the 10 bytes of capacity should be returned back
        // to the stream, allowing us to push the next request.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());

        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let stream = SizeLimitedStream::new(0);

        // Should not be able to push anything.
        let result = stream.try_push(async { vec![1u8] }, 1);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // Keep a handle on the semaphore so that capacity can still be observed after the
        // stream itself has been dropped. Checking a freshly created stream instead would
        // prove nothing, because it would own a brand-new semaphore.
        let semaphore = stream.bytes_available.clone();

        // Push a future that will never complete.
        stream
            .push(
                Box::pin(async {
                    // This future will be dropped before completion.
                    futures::future::pending::<Vec<u8>>().await
                }),
                5,
            )
            .await;

        // Push another future.
        stream.push(Box::pin(async { vec![1u8; 3] }), 3).await;

        // We should have 2 bytes available now.
        assert_eq!(stream.bytes_available(), 2);

        // Drop the stream without consuming the futures.
        drop(stream);

        // Every permit held by the dropped futures must be back in the semaphore.
        assert_eq!(semaphore.available_permits(), 10);
    }

    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // Push exactly the capacity.
        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        // Should not be able to push more.
        let result = stream.try_push(Box::pin(async { vec![1u8] }), 1);
        assert!(result.is_err());

        // After consuming, should be able to push again.
        let _ = stream.next().await;
        assert_eq!(stream.bytes_available(), 10);

        let result = stream.try_push(Box::pin(async { vec![1u8; 5] }), 5);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        // Push many small items. Iterating over `u8` directly avoids a truncating cast.
        for i in 0u8..10 {
            stream.push(async move { vec![i; 5] }, 5).await;
        }

        // Should have used 50 bytes.
        assert_eq!(stream.bytes_available(), 50);

        // Consume all ten items, stopping as soon as the expected number has been seen.
        let mut count = 0;
        while stream.next().await.is_some() {
            count += 1;
            if count == 10 {
                break;
            }
        }

        assert_eq!(count, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    #[test]
    fn test_size_overflow_protection() {
        let stream = SizeLimitedStream::new(100);

        // A request larger than `u32::MAX` bytes (only expressible where `usize` is wider than
        // 32 bits) currently panics while converting the size to a permit count. That panic is
        // deliberately not asserted here, as the behavior may change.

        // A request of a reasonable size must be accepted and must reserve its capacity.
        let result = stream.try_push(async { vec![0u8; 50] }, 50);
        assert!(result.is_ok());
        assert_eq!(stream.bytes_available(), 50);
    }
}