1use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, ready};
8
9use futures::Stream;
10use futures_util::stream::FuturesUnordered;
11use pin_project::pin_project;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
13use vortex_error::VortexUnwrap;
14
15#[pin_project]
20struct SizedFut<Fut> {
21 #[pin]
22 inner: Fut,
23 _permits: OwnedSemaphorePermit,
25}
26
27impl<Fut: Future> Future for SizedFut<Fut> {
28 type Output = Fut::Output;
29
30 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
31 let this = self.project();
32 let result = ready!(this.inner.poll(cx));
33 Poll::Ready(result)
34 }
35}
36
37#[pin_project]
47pub struct SizeLimitedStream<Fut> {
48 #[pin]
49 inflight: FuturesUnordered<SizedFut<Fut>>,
50 bytes_available: Arc<Semaphore>,
51}
52
53impl<Fut> SizeLimitedStream<Fut> {
54 pub fn new(max_bytes: usize) -> Self {
55 Self {
56 inflight: FuturesUnordered::new(),
57 bytes_available: Arc::new(Semaphore::new(max_bytes)),
58 }
59 }
60
61 pub fn bytes_available(&self) -> usize {
62 self.bytes_available.available_permits()
63 }
64}
65
66impl<Fut> SizeLimitedStream<Fut>
67where
68 Fut: Future,
69{
70 pub async fn push(&self, fut: Fut, bytes: usize) {
75 let permits = self
79 .bytes_available
80 .clone()
81 .acquire_many_owned(bytes.try_into().vortex_unwrap())
82 .await
83 .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
84
85 let sized_fut = SizedFut {
86 inner: fut,
87 _permits: permits,
88 };
89
90 self.inflight.push(sized_fut);
92 }
93
94 pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
99 match self
100 .bytes_available
101 .clone()
102 .try_acquire_many_owned(bytes.try_into().vortex_unwrap())
103 {
104 Ok(permits) => {
105 let sized_fut = SizedFut {
106 inner: fut,
107 _permits: permits,
108 };
109
110 self.inflight.push(sized_fut);
111 Ok(())
112 }
113 Err(acquire_err) => match acquire_err {
114 TryAcquireError::Closed => {
115 unreachable!("try_pushing to closed semaphore");
116 }
117
118 TryAcquireError::NoPermits => Err(fut),
121 },
122 }
123 }
124}
125
126impl<Fut> Stream for SizeLimitedStream<Fut>
127where
128 Fut: Future,
129{
130 type Item = Fut::Output;
131
132 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133 let this = self.project();
134 match ready!(this.inflight.poll_next(cx)) {
135 None => Poll::Ready(None),
136 Some(result) => {
137 Poll::Ready(Some(result))
140 }
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::{future, io};

    use bytes::Bytes;
    use futures_util::future::BoxFuture;
    use futures_util::{FutureExt, StreamExt};

    use crate::limit::SizeLimitedStream;

    async fn make_future(len: usize) -> Bytes {
        "a".as_bytes().iter().copied().cycle().take(len).collect()
    }

    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;

        // The stream is full, so even a 1-byte reservation is refused.
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // Yielding one completed future gives its 5 bytes back.
        assert!(size_limited.next().await.is_some());
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
    }

    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let good_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;

        // `bad_fut` holds the whole capacity, so `good_fut` is handed back.
        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");

        // A future that resolves to an error must still release its reservation.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());

        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let stream = SizeLimitedStream::new(0);

        let result = stream.try_push(async { vec![1u8] }, 1);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);
        // Keep our own handle on the semaphore so it can be inspected after the stream is gone.
        let semaphore = Arc::clone(&stream.bytes_available);

        // This future never completes, so its permits can only come back through drop.
        stream
            .push(Box::pin(futures::future::pending::<Vec<u8>>()), 5)
            .await;
        stream.push(Box::pin(async { vec![1u8; 3] }), 3).await;
        assert_eq!(stream.bytes_available(), 2);

        drop(stream);
        assert_eq!(semaphore.available_permits(), 10);
    }

    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        let result = stream.try_push(Box::pin(async { vec![1u8] }), 1);
        assert!(result.is_err());

        let _ = stream.next().await;
        assert_eq!(stream.bytes_available(), 10);

        let result = stream.try_push(Box::pin(async { vec![1u8; 5] }), 5);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        for i in 0u8..10 {
            stream.push(async move { vec![i; 5] }, 5).await;
        }
        assert_eq!(stream.bytes_available(), 50);

        // The stream ends once every in-flight future has been yielded.
        let mut count = 0;
        while stream.next().await.is_some() {
            count += 1;
        }

        assert_eq!(count, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    #[test]
    fn test_size_overflow_protection() {
        let stream = SizeLimitedStream::new(100);

        // A reservation larger than the total capacity is refused and consumes nothing.
        assert!(stream.try_push(make_future(101), 101).is_err());
        assert_eq!(stream.bytes_available(), 100);

        assert!(stream.try_push(make_future(50), 50).is_ok());
        assert_eq!(stream.bytes_available(), 50);
    }
}