1use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, ready};
8
9use futures::Stream;
10use futures::stream::FuturesUnordered;
11use pin_project_lite::pin_project;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
13use vortex_error::VortexUnwrap;
14
15pin_project! {
16 struct SizedFut<Fut> {
21 #[pin]
22 inner: Fut,
23 _permits: OwnedSemaphorePermit,
25 }
26}
27
28impl<Fut: Future> Future for SizedFut<Fut> {
29 type Output = Fut::Output;
30
31 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
32 let this = self.project();
33 let result = ready!(this.inner.poll(cx));
34 Poll::Ready(result)
35 }
36}
37
38pin_project! {
39 pub struct SizeLimitedStream<Fut> {
49 #[pin]
50 inflight: FuturesUnordered<SizedFut<Fut>>,
51 bytes_available: Arc<Semaphore>,
52 }
53}
54
55impl<Fut> SizeLimitedStream<Fut> {
56 pub fn new(max_bytes: usize) -> Self {
57 Self {
58 inflight: FuturesUnordered::new(),
59 bytes_available: Arc::new(Semaphore::new(max_bytes)),
60 }
61 }
62
63 pub fn bytes_available(&self) -> usize {
64 self.bytes_available.available_permits()
65 }
66}
67
68impl<Fut> SizeLimitedStream<Fut>
69where
70 Fut: Future,
71{
72 pub async fn push(&self, fut: Fut, bytes: usize) {
77 let permits = self
81 .bytes_available
82 .clone()
83 .acquire_many_owned(bytes.try_into().vortex_unwrap())
84 .await
85 .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
86
87 let sized_fut = SizedFut {
88 inner: fut,
89 _permits: permits,
90 };
91
92 self.inflight.push(sized_fut);
94 }
95
96 pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
101 match self
102 .bytes_available
103 .clone()
104 .try_acquire_many_owned(bytes.try_into().vortex_unwrap())
105 {
106 Ok(permits) => {
107 let sized_fut = SizedFut {
108 inner: fut,
109 _permits: permits,
110 };
111
112 self.inflight.push(sized_fut);
113 Ok(())
114 }
115 Err(acquire_err) => match acquire_err {
116 TryAcquireError::Closed => {
117 unreachable!("try_pushing to closed semaphore");
118 }
119
120 TryAcquireError::NoPermits => Err(fut),
123 },
124 }
125 }
126}
127
128impl<Fut> Stream for SizeLimitedStream<Fut>
129where
130 Fut: Future,
131{
132 type Item = Fut::Output;
133
134 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135 let this = self.project();
136 match ready!(this.inflight.poll_next(cx)) {
137 None => Poll::Ready(None),
138 Some(result) => {
139 Poll::Ready(Some(result))
142 }
143 }
144 }
145}
146
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::{future, io};

    use bytes::Bytes;
    use futures::future::BoxFuture;
    use futures::{FutureExt, StreamExt};

    use crate::limit::SizeLimitedStream;

    /// Returns a future that resolves to `len` bytes of `b'a'`.
    async fn make_future(len: usize) -> Bytes {
        "a".as_bytes().iter().copied().cycle().take(len).collect()
    }

    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;

        // The two pushes above reserve the entire 10-byte budget.
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // Completing one future hands its 5 bytes back to the budget.
        assert!(size_limited.next().await.is_some());
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
    }

    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let good_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;

        // The failing future holds the whole budget, so the good one is handed back.
        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");

        // A future that resolves to `Err` must still release its reservation.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());

        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let stream = SizeLimitedStream::new(0);

        let result = stream.try_push(async { vec![1u8] }, 1);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // This future never completes, so only dropping it can return its 5 bytes.
        stream
            .push(
                Box::pin(async { futures::future::pending::<Vec<u8>>().await }),
                5,
            )
            .await;

        stream.push(Box::pin(async { vec![1u8; 3] }), 3).await;

        assert_eq!(stream.bytes_available(), 2);

        // Keep a handle on the budget so it can be inspected after the stream is gone.
        let budget = Arc::clone(&stream.bytes_available);
        drop(stream);
        // Dropping the stream drops every in-flight future, including the pending one,
        // which must release all of their reservations.
        assert_eq!(budget.available_permits(), 10);

        let mut new_stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        new_stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;
        assert_eq!(new_stream.bytes_available(), 0);

        let result = new_stream.next().await;
        assert!(result.is_some());
        assert_eq!(new_stream.bytes_available(), 10);
    }

    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        let result = stream.try_push(Box::pin(async { vec![1u8] }), 1);
        assert!(result.is_err());

        stream
            .next()
            .await
            .expect("The 10 byte vector ought to be in there!");
        assert_eq!(stream.bytes_available(), 10);

        let result = stream.try_push(Box::pin(async { vec![1u8; 5] }), 5);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        for i in 0..10 {
            // `i` is always < 10, so the narrowing cast cannot truncate.
            #[allow(clippy::cast_possible_truncation)]
            stream.push(async move { vec![i as u8; 5] }, 5).await;
        }

        assert_eq!(stream.bytes_available(), 50);

        let mut count = 0;
        while stream.next().await.is_some() {
            count += 1;
            if count == 10 {
                break;
            }
        }

        assert_eq!(count, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    #[test]
    fn test_size_overflow_protection() {
        let stream = SizeLimitedStream::new(100);

        let result = stream.try_push(async { vec![0u8; 50] }, 50);
        assert!(result.is_ok());
    }

    /// Byte counts that do not fit in the semaphore's `u32` permit count must be rejected
    /// with a panic rather than silently truncated, and must not reserve any budget.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn test_size_overflow_panics() {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        let stream = SizeLimitedStream::new(100);
        let too_large = (u32::MAX as usize) + 1;

        let outcome = catch_unwind(AssertUnwindSafe(|| {
            stream.try_push(async { vec![0u8; 1] }, too_large)
        }));
        assert!(outcome.is_err());
        assert_eq!(stream.bytes_available(), 100);
    }
}