1use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Context;
8use std::task::Poll;
9use std::task::ready;
10
11use futures::Stream;
12use futures::stream::FuturesUnordered;
13use pin_project_lite::pin_project;
14use tokio::sync::OwnedSemaphorePermit;
15use tokio::sync::Semaphore;
16use tokio::sync::TryAcquireError;
17use vortex_error::VortexExpect;
18
19pin_project! {
20 struct SizedFut<Fut> {
25 #[pin]
26 inner: Fut,
27 _permits: OwnedSemaphorePermit,
29 }
30}
31
32impl<Fut: Future> Future for SizedFut<Fut> {
33 type Output = Fut::Output;
34
35 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36 let this = self.project();
37 let result = ready!(this.inner.poll(cx));
38 Poll::Ready(result)
39 }
40}
41
42pin_project! {
43 pub struct SizeLimitedStream<Fut> {
53 #[pin]
54 inflight: FuturesUnordered<SizedFut<Fut>>,
55 bytes_available: Arc<Semaphore>,
56 }
57}
58
59impl<Fut> SizeLimitedStream<Fut> {
60 pub fn new(max_bytes: usize) -> Self {
61 Self {
62 inflight: FuturesUnordered::new(),
63 bytes_available: Arc::new(Semaphore::new(max_bytes)),
64 }
65 }
66
67 pub fn bytes_available(&self) -> usize {
68 self.bytes_available.available_permits()
69 }
70}
71
72impl<Fut> SizeLimitedStream<Fut>
73where
74 Fut: Future,
75{
76 pub async fn push(&self, fut: Fut, bytes: usize) {
81 let permits = Arc::clone(&self.bytes_available)
85 .acquire_many_owned(bytes.try_into().vortex_expect("bytes must fit in u32"))
86 .await
87 .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
88
89 let sized_fut = SizedFut {
90 inner: fut,
91 _permits: permits,
92 };
93
94 self.inflight.push(sized_fut);
96 }
97
98 pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
103 match Arc::clone(&self.bytes_available)
104 .try_acquire_many_owned(bytes.try_into().vortex_expect("bytes must fit in u32"))
105 {
106 Ok(permits) => {
107 let sized_fut = SizedFut {
108 inner: fut,
109 _permits: permits,
110 };
111
112 self.inflight.push(sized_fut);
113 Ok(())
114 }
115 Err(acquire_err) => match acquire_err {
116 TryAcquireError::Closed => {
117 unreachable!("try_pushing to closed semaphore");
118 }
119
120 TryAcquireError::NoPermits => Err(fut),
123 },
124 }
125 }
126}
127
128impl<Fut> Stream for SizeLimitedStream<Fut>
129where
130 Fut: Future,
131{
132 type Item = Fut::Output;
133
134 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135 let this = self.project();
136 match ready!(this.inflight.poll_next(cx)) {
137 None => Poll::Ready(None),
138 Some(result) => {
139 Poll::Ready(Some(result))
142 }
143 }
144 }
145}
146
#[cfg(test)]
mod tests {
    use std::future;
    use std::io;

    use bytes::Bytes;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures::future::BoxFuture;

    use crate::limit::SizeLimitedStream;

    /// Resolves to a buffer of `len` copies of the byte `a`.
    async fn make_future(len: usize) -> Bytes {
        Bytes::from(vec![b'a'; len])
    }

    /// A full stream rejects `try_push` until an output has been taken out of it.
    #[tokio::test]
    async fn test_size_limit() {
        let mut limited = SizeLimitedStream::new(10);
        for _ in 0..2 {
            limited.push(make_future(5), 5).await;
        }

        // The whole budget is reserved, so even a single byte is refused.
        let rejected = limited.try_push(make_future(1), 1);
        assert!(rejected.is_err());

        // Taking one output out makes room again.
        let first = limited.next().await;
        assert!(first.is_some());
        let accepted = limited.try_push(make_future(1), 1);
        assert!(accepted.is_ok());
    }

    /// A future that resolves to an error still gives its reservation back.
    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let failing: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let succeeding: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut limited = SizeLimitedStream::new(10);
        limited.push(failing, 10).await;

        // The failing future holds the whole budget, so this one is handed back.
        let succeeding = limited
            .try_push(succeeding, 5)
            .expect_err("try_push should fail");

        let output = limited.next().await.unwrap();
        assert!(output.is_err());

        assert_eq!(limited.bytes_available(), 10);
        assert!(limited.try_push(succeeding, 5).is_ok());
    }

    /// With no budget at all, nothing can be pushed without waiting.
    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let empty = SizeLimitedStream::new(0);

        assert!(empty.try_push(async { vec![1u8] }, 1).is_err());
    }

    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // A future that never completes keeps its 5 bytes reserved.
        let never_ready: BoxFuture<'static, Vec<u8>> =
            Box::pin(async { futures::future::pending::<Vec<u8>>().await });
        stream.push(never_ready, 5).await;

        let small: BoxFuture<'static, Vec<u8>> = Box::pin(async { vec![1u8; 3] });
        stream.push(small, 3).await;

        assert_eq!(stream.bytes_available(), 2);

        drop(stream);

        // A freshly created stream starts with its own, full budget.
        let mut new_stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        let full: BoxFuture<'static, Vec<u8>> = Box::pin(async { vec![0u8; 10] });
        new_stream.push(full, 10).await;
        assert_eq!(new_stream.bytes_available(), 0);

        let drained = new_stream.next().await;
        assert!(drained.is_some());
        assert_eq!(new_stream.bytes_available(), 10);
    }

    /// A push that uses exactly the whole budget succeeds and blocks further pushes.
    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        assert!(stream.try_push(Box::pin(async { vec![1u8] }), 1).is_err());

        stream
            .next()
            .await
            .expect("The 10 byte vector ought to be in there!");
        assert_eq!(stream.bytes_available(), 10);

        assert!(
            stream
                .try_push(Box::pin(async { vec![1u8; 5] }), 5)
                .is_ok()
        );
    }

    /// Ten 5-byte pushes use half of a 100-byte budget; draining them restores it.
    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        for fill in 0u8..10 {
            stream.push(async move { vec![fill; 5] }, 5).await;
        }

        assert_eq!(stream.bytes_available(), 50);

        // Take at most the ten outputs that were pushed; stop early if the stream ends.
        let mut drained = 0;
        for _ in 0..10 {
            if stream.next().await.is_none() {
                break;
            }
            drained += 1;
        }

        assert_eq!(drained, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    #[test]
    fn test_size_overflow_protection() {
        let stream = SizeLimitedStream::new(100);

        // NOTE(review): this size does not fit in a `u32`. It is only computed, never
        // pushed — presumably because the `u32` conversion inside `try_push` would
        // panic. Confirm whether a dedicated panic test is wanted.
        #[cfg(target_pointer_width = "64")]
        {
            let _too_large = (u32::MAX as usize) + 1;
        }

        // A size that fits comfortably is accepted.
        assert!(stream.try_push(async { vec![0u8; 50] }, 50).is_ok());
    }
}