1use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Context;
8use std::task::Poll;
9use std::task::ready;
10
11use futures::Stream;
12use futures::stream::FuturesUnordered;
13use pin_project_lite::pin_project;
14use tokio::sync::OwnedSemaphorePermit;
15use tokio::sync::Semaphore;
16use tokio::sync::TryAcquireError;
17use vortex_error::VortexExpect;
18
19pin_project! {
20 struct SizedFut<Fut> {
25 #[pin]
26 inner: Fut,
27 _permits: OwnedSemaphorePermit,
29 }
30}
31
32impl<Fut: Future> Future for SizedFut<Fut> {
33 type Output = Fut::Output;
34
35 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36 let this = self.project();
37 let result = ready!(this.inner.poll(cx));
38 Poll::Ready(result)
39 }
40}
41
42pin_project! {
43 pub struct SizeLimitedStream<Fut> {
53 #[pin]
54 inflight: FuturesUnordered<SizedFut<Fut>>,
55 bytes_available: Arc<Semaphore>,
56 }
57}
58
59impl<Fut> SizeLimitedStream<Fut> {
60 pub fn new(max_bytes: usize) -> Self {
61 Self {
62 inflight: FuturesUnordered::new(),
63 bytes_available: Arc::new(Semaphore::new(max_bytes)),
64 }
65 }
66
67 pub fn bytes_available(&self) -> usize {
68 self.bytes_available.available_permits()
69 }
70}
71
72impl<Fut> SizeLimitedStream<Fut>
73where
74 Fut: Future,
75{
76 pub async fn push(&self, fut: Fut, bytes: usize) {
81 let permits = self
85 .bytes_available
86 .clone()
87 .acquire_many_owned(bytes.try_into().vortex_expect("bytes must fit in u32"))
88 .await
89 .unwrap_or_else(|_| unreachable!("pushing to closed semaphore"));
90
91 let sized_fut = SizedFut {
92 inner: fut,
93 _permits: permits,
94 };
95
96 self.inflight.push(sized_fut);
98 }
99
100 pub fn try_push(&self, fut: Fut, bytes: usize) -> Result<(), Fut> {
105 match self
106 .bytes_available
107 .clone()
108 .try_acquire_many_owned(bytes.try_into().vortex_expect("bytes must fit in u32"))
109 {
110 Ok(permits) => {
111 let sized_fut = SizedFut {
112 inner: fut,
113 _permits: permits,
114 };
115
116 self.inflight.push(sized_fut);
117 Ok(())
118 }
119 Err(acquire_err) => match acquire_err {
120 TryAcquireError::Closed => {
121 unreachable!("try_pushing to closed semaphore");
122 }
123
124 TryAcquireError::NoPermits => Err(fut),
127 },
128 }
129 }
130}
131
132impl<Fut> Stream for SizeLimitedStream<Fut>
133where
134 Fut: Future,
135{
136 type Item = Fut::Output;
137
138 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139 let this = self.project();
140 match ready!(this.inflight.poll_next(cx)) {
141 None => Poll::Ready(None),
142 Some(result) => {
143 Poll::Ready(Some(result))
146 }
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::future;
    use std::io;

    use bytes::Bytes;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures::future::BoxFuture;

    use crate::limit::SizeLimitedStream;

    /// Produces `len` bytes of `b'a'`.
    async fn make_future(len: usize) -> Bytes {
        "a".as_bytes().iter().copied().cycle().take(len).collect()
    }

    #[tokio::test]
    async fn test_size_limit() {
        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(make_future(5), 5).await;
        size_limited.push(make_future(5), 5).await;

        // The two 5-byte futures exhaust the 10-byte budget.
        assert!(size_limited.try_push(make_future(1), 1).is_err());

        // Yielding one output frees that future's 5 bytes.
        assert!(size_limited.next().await.is_some());
        assert!(size_limited.try_push(make_future(1), 1).is_ok());
    }

    #[tokio::test]
    async fn test_does_not_leak_permits() {
        let bad_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Err(io::Error::other("badness"))).boxed();

        let good_fut: BoxFuture<'static, io::Result<Bytes>> =
            future::ready(Ok(Bytes::from("aaaaa"))).boxed();

        let mut size_limited = SizeLimitedStream::new(10);
        size_limited.push(bad_fut, 10).await;

        let good_fut = size_limited
            .try_push(good_fut, 5)
            .expect_err("try_push should fail");

        // A future whose output is an error must still give its permits back.
        let next = size_limited.next().await.unwrap();
        assert!(next.is_err());

        assert_eq!(size_limited.bytes_available(), 10);
        assert!(size_limited.try_push(good_fut, 5).is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_zero_capacity() {
        let stream = SizeLimitedStream::new(0);

        let result = stream.try_push(async { vec![1u8] }, 1);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_size_limited_stream_dropped_future_releases_permits() {
        let stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        // A future that never completes, so its permits stay reserved until the stream drops.
        stream
            .push(Box::pin(futures::future::pending::<Vec<u8>>()), 5)
            .await;
        stream.push(Box::pin(async { vec![1u8; 3] }), 3).await;

        assert_eq!(stream.bytes_available(), 2);

        // Dropping a stream that still holds an unfinished future must not hang or panic.
        drop(stream);

        let mut new_stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        new_stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;
        assert_eq!(new_stream.bytes_available(), 0);

        let result = new_stream.next().await;
        assert!(result.is_some());
        assert_eq!(new_stream.bytes_available(), 10);
    }

    #[tokio::test]
    async fn test_size_limited_stream_exact_capacity() {
        let mut stream = SizeLimitedStream::<BoxFuture<'static, Vec<u8>>>::new(10);

        stream.push(Box::pin(async { vec![0u8; 10] }), 10).await;

        let result = stream.try_push(Box::pin(async { vec![1u8] }), 1);
        assert!(result.is_err());

        stream
            .next()
            .await
            .expect("The 10 byte vector ought to be in there!");
        assert_eq!(stream.bytes_available(), 10);

        let result = stream.try_push(Box::pin(async { vec![1u8; 5] }), 5);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_size_limited_stream_multiple_small_pushes() {
        let mut stream = SizeLimitedStream::new(100);

        // Iterate `u8` directly so no truncating cast (and no lint suppression) is needed.
        for i in 0u8..10 {
            stream.push(async move { vec![i; 5] }, 5).await;
        }

        assert_eq!(stream.bytes_available(), 50);

        // The stream ends on its own once all ten futures have been yielded.
        let mut count = 0;
        while stream.next().await.is_some() {
            count += 1;
        }

        assert_eq!(count, 10);
        assert_eq!(stream.bytes_available(), 100);
    }

    #[test]
    fn test_size_overflow_protection() {
        let stream = SizeLimitedStream::new(100);

        let result = stream.try_push(async { vec![0u8; 50] }, 50);
        assert!(result.is_ok());
    }

    // Byte sizes are converted to `u32` permit counts; anything larger must be rejected
    // loudly rather than silently truncated. Only reachable where `usize` is wider than `u32`.
    #[cfg(target_pointer_width = "64")]
    #[test]
    #[should_panic]
    fn test_size_overflow_panics() {
        let stream = SizeLimitedStream::new(100);
        let _ = stream.try_push(future::ready(()), (u32::MAX as usize) + 1);
    }
}