1use crate::*;
4#[cfg(feature = "tokio-compat")]
5pub use async_compat::Compat;
6use async_trait::async_trait;
7use core::{
8 future::Future,
9 pin::Pin,
10 task::{Context, Poll},
11};
12use futures_util::io::{self, AsyncRead, AsyncReadExt, AsyncSeek};
13use std::{
14 io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult, SeekFrom},
15 marker::Unpin,
16 mem,
17 sync::atomic::Ordering,
18};
19
/// A transform applied to bytes as they are read from an async source into
/// the backing cache.
pub trait AsyncTransform<TInput: AsyncRead> {
    /// Polls `src`, writing (possibly transformed) bytes into `buf`.
    ///
    /// Implementations report progress via [`TransformPosition`]:
    /// `Read(n)` when `n` bytes were produced into `buf`, or `Finished`
    /// once the underlying stream is exhausted.
    fn transform_poll_read(
        &mut self,
        src: Pin<&mut TInput>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<IoResult<TransformPosition>>;

    /// Minimum free space (in bytes) `buf` must offer for this transform to
    /// make progress; the cache writer starts a fresh chunk when less than
    /// this remains in the current one.
    ///
    /// Defaults to `1`.
    fn min_bytes_required(&self) -> usize {
        1
    }
}
35
36impl<T: AsyncRead> AsyncTransform<T> for Identity {
37 fn transform_poll_read(
38 &mut self,
39 src: Pin<&mut T>,
40 cx: &mut Context,
41 buf: &mut [u8],
42 ) -> Poll<IoResult<TransformPosition>> {
43 src.poll_read(cx, buf).map(|res| {
44 res.map(|count| match count {
45 0 => TransformPosition::Finished,
46 n => TransformPosition::Read(n),
47 })
48 })
49 }
50}
51
impl<T, Tx> TxCatcher<T, Tx>
where
    T: AsyncRead + Unpin + 'static,
    Tx: AsyncTransform<T> + Unpin + 'static,
{
    /// Returns a future which drives the entire source into the cache,
    /// resolving to a handle positioned where this one currently is.
    pub fn load_all_async(self) -> LoadAll<T, Tx> {
        LoadAll::new(self)
    }
}
63
/// Future which reads a [`TxCatcher`]'s source to completion, then yields a
/// handle to the (now fully cached) stream.
pub struct LoadAll<T, Tx>
where
    T: AsyncRead + Unpin + 'static,
    Tx: AsyncTransform<T> + Unpin + 'static,
{
    // Catcher being driven to completion.
    catcher: TxCatcher<T, Tx>,
    // Read position the catcher held when this future was created; restored
    // just before the handle is returned.
    in_pos: usize,
}
75
76impl<T, Tx> LoadAll<T, Tx>
77where
78 T: AsyncRead + Unpin + 'static,
79 Tx: AsyncTransform<T> + Unpin + 'static,
80{
81 fn new(catcher: TxCatcher<T, Tx>) -> Self {
82 let in_pos = catcher.pos;
83
84 Self { catcher, in_pos }
85 }
86}
87
impl<T, Tx> Future for LoadAll<T, Tx>
where
    T: AsyncRead + Unpin + 'static,
    Tx: AsyncTransform<T> + Unpin + 'static,
{
    type Output = TxCatcher<T, Tx>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Jump to the end of the cached data so each skip below only pulls
        // fresh bytes out of the source.
        self.catcher.pos = self.catcher.len();

        loop {
            if self.catcher.is_finalised() {
                break;
            }

            // Discard-read from the source in bursts of 7680 bytes.
            let mut skip_attempt = self.catcher.skip(7680);

            match Future::poll(Pin::new(&mut skip_attempt), cx) {
                // Zero bytes skipped: the source has nothing more to give.
                Poll::Ready(0) => break,
                Poll::Ready(_n) => {},
                Poll::Pending => {
                    // `pos` is left parked at the cache end here; it is
                    // recomputed at the top of the next poll, so nothing
                    // is lost by bailing out now.
                    return Poll::Pending;
                },
            }
        }

        // Restore the caller's original read position on the handle we yield.
        self.catcher.pos = self.in_pos;

        Poll::Ready(self.catcher.new_handle())
    }
}
119
impl<T, Tx> AsyncRead for TxCatcher<T, Tx>
where
    T: AsyncRead + Unpin + 'static,
    Tx: AsyncTransform<T> + Unpin + 'static,
{
    /// Reads from the shared cache at this handle's position, pulling fresh
    /// bytes from the source when the cache runs short. Advances `self.pos`
    /// by the number of bytes read on success.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        self.core.read_from_pos_async(self.pos, cx, buf).map(
            |(bytes_read, should_finalise_here)| {
                // The core asked *this* handle to run finalisation: dispatch
                // it to whichever executor the config selected.
                if should_finalise_here {
                    let handle = self.core.clone();
                    match self.core.config.spawn_finaliser {
                        // In-place finalisation should be handled inside the
                        // core and never reach this branch (asserted here).
                        Finaliser::InPlace => unreachable!(),
                        Finaliser::NewThread => {
                            std::thread::spawn(move || handle.do_finalise());
                        },
                        #[cfg(feature = "async-std-compat")]
                        Finaliser::AsyncStd => {
                            async_std::task::spawn(async move {
                                handle.do_finalise();
                            });
                        },
                        #[cfg(feature = "tokio-compat")]
                        Finaliser::Tokio => {
                            // JoinHandle deliberately dropped: the task runs
                            // detached.
                            let _ = tokio::spawn(async move {
                                handle.do_finalise();
                            });
                        },
                        #[cfg(feature = "smol-compat")]
                        Finaliser::Smol => {
                            smol::spawn(async move {
                                handle.do_finalise();
                            })
                            .detach();
                        },
                    }
                }

                // Only advance the handle's cursor on a successful read.
                if let Ok(size) = bytes_read {
                    self.pos += size;
                }

                bytes_read
            },
        )
    }
}
170
#[cfg(feature = "tokio-compat")]
impl<T, Tx> TxCatcher<T, Tx> {
    /// Wraps this catcher in [`Compat`] (from `async_compat`) so it can be
    /// driven through tokio-flavoured I/O traits.
    pub fn tokio(self) -> Compat<Self> {
        Compat::new(self)
    }
}
177
178impl<T, Tx> AsyncSeek for TxCatcher<T, Tx>
179where
180 T: AsyncRead + Unpin + 'static,
181 Tx: AsyncTransform<T> + Unpin + 'static,
182{
183 fn poll_seek(mut self: Pin<&mut Self>, cx: &mut Context, pos: SeekFrom) -> Poll<IoResult<u64>> {
184 let old_pos = self.pos as u64;
185
186 let (valid, new_pos) = match pos {
187 SeekFrom::Current(adj) => {
188 let new_pos = old_pos.wrapping_add(adj as u64);
190 (adj >= 0 || (adj.abs() as u64) <= old_pos, new_pos)
191 },
192 SeekFrom::End(adj) => {
193 let mut end_read_future = self.new_handle().load_all_async();
197 if Future::poll(Pin::new(&mut end_read_future), cx).is_pending() {
198 return Poll::Pending;
199 }
200
201 let len = self.len() as u64;
202 let new_pos = len.wrapping_add(adj as u64);
203 (adj >= 0 || (adj.abs() as u64) <= len, new_pos)
204 },
205 SeekFrom::Start(new_pos) => (true, new_pos),
206 };
207
208 Poll::Ready(if valid {
209 if new_pos > old_pos {
210 self.pos = (new_pos as usize).min(self.len());
211 if new_pos != self.pos as u64 {
212 let mut skip_future = self.skip(new_pos as usize - self.pos);
213 if Future::poll(Pin::new(&mut skip_future), cx).is_pending() {
214 return Poll::Pending;
215 }
216 }
217 }
218
219 let len = self.len() as u64;
220
221 self.pos = new_pos.min(len) as usize;
222 Ok(self.pos as u64)
223 } else {
224 Err(IoError::new(
225 IoErrorKind::InvalidInput,
226 "Tried to seek before start of stream.",
227 ))
228 })
229 }
230}
231
impl<T, Tx> RawStore<T, Tx>
where
    T: AsyncRead + Unpin,
    Tx: AsyncTransform<T> + Unpin,
{
    /// Reads up to `buf.len()` bytes of the stream starting at `pos`.
    ///
    /// Returns `(bytes_read, should_finalise_externally)`: the flag asks the
    /// *calling handle* to spawn finalisation on its configured executor.
    ///
    /// Yields `Poll::Pending` only when no forward progress was made at all;
    /// once any bytes have been transferred the partial result is returned.
    fn read_from_pos_async(
        &self,
        pos: usize,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<(IoResult<usize>, bool)> {
        let (loc, mut finalised) = self.get_location();

        // Bytes currently committed to the backing store.
        let mut backing_len = self.len();

        let mut should_finalise_external = false;

        let target_len = pos + buf.len();

        // Set once any observable progress happens; we must not return
        // Pending after progress, or that progress would be lost.
        let mut progress_before_pending = false;

        let out = if finalised.is_source_finished() || target_len <= backing_len {
            // Fast path: the requested range is already cached (or the source
            // is done, so the cache is all there will ever be).
            progress_before_pending = true;
            let read_amt = buf.len().min(backing_len - pos);
            Ok(self.read_from_local(pos, loc, buf, read_amt))
        } else {
            // Slow path: alternate between draining cached bytes and pulling
            // more from the source until `buf` is full, the stream ends, or
            // we would block.
            let mut read = 0;
            let mut base_result = None;

            loop {
                finalised = self.finalised();
                backing_len = self.len();
                let mut remaining_in_store = backing_len - pos - read;

                if remaining_in_store == 0 {
                    // Cache exhausted: take the writer lock so only one task
                    // fills from the source at a time.
                    let mut guard = self.lock.lock();

                    if Future::poll(Pin::new(&mut guard), cx).is_pending() {
                        break;
                    }

                    // Re-check under the lock — another task may have written
                    // while we were acquiring it.
                    finalised = self.finalised();
                    backing_len = self.len();

                    remaining_in_store = backing_len - pos - read;
                    if remaining_in_store == 0 && finalised.is_source_live() {
                        if let Poll::Ready(read_count) =
                            self.fill_from_source_async(cx, buf.len() - read)
                        {
                            progress_before_pending = true;
                            if let Ok((read_count, finalise_elsewhere)) = read_count {
                                remaining_in_store += read_count;
                                should_finalise_external |= finalise_elsewhere;
                            }
                            base_result = Some(read_count.map(|a| a.0));

                            finalised = self.finalised();
                        } else {
                            break;
                        }
                    }

                    mem::drop(guard);
                }

                // Copy whatever is now cached into the caller's buffer.
                if remaining_in_store > 0 {
                    let count = remaining_in_store.min(buf.len() - read);
                    read += self.read_from_local(pos + read, loc, &mut buf[read..], count);
                }

                // Stop on a recorded source error, a full buffer, or a
                // finished source with no bytes left beyond our position.
                if matches!(base_result, Some(Err(_)))
                    || read == buf.len() || (finalised.is_source_finished()
                    && backing_len == pos + read)
                {
                    break;
                }
            }

            // Report total bytes read; `map` keeps an Err from the source but
            // replaces any Ok payload with the running count.
            base_result.unwrap_or(Ok(0)).map(|_| read)
        };

        // NOTE(review): presumably releases this reader's claim on the
        // temporary rope storage once a read completes — confirm against
        // the RawStore definition.
        if loc == CacheReadLocation::Roped {
            self.remove_rope_full();
        }

        if progress_before_pending {
            Poll::Ready((out, should_finalise_external))
        } else {
            Poll::Pending
        }
    }

    /// Pulls up to `bytes_needed` bytes (rounded up to a whole number of
    /// `read_burst_len` bursts) from the source through the transform into
    /// the rope's final chunk, appending new chunks as required.
    ///
    /// Called while holding `self.lock` in `read_from_pos_async`, so there
    /// is a single writer at a time.
    ///
    /// Returns `(bytes_written, spawn_new_finaliser)`; the flag is set when
    /// the source finished here and finalisation should be spawned by the
    /// caller's handle.
    fn fill_from_source_async(
        &self,
        cx: &mut Context,
        mut bytes_needed: usize,
    ) -> Poll<IoResult<(usize, bool)>> {
        let minimum_to_write = self
            .transform
            .with(|ptr| unsafe { &*ptr }.min_bytes_required());

        // Round the request up to a multiple of the configured burst length.
        let overspill = bytes_needed % self.config.read_burst_len;
        if overspill != 0 {
            bytes_needed += self.config.read_burst_len - overspill;
        }

        let mut remaining_bytes = bytes_needed;
        let mut recorded_error = None;

        let mut spawn_new_finaliser = false;

        // As in read_from_pos_async: only return Pending if the inner source
        // poll made no progress at all.
        let mut progress_before_pending = false;

        loop {
            let should_break = self.rope.with_mut(|ptr| {
                let rope = unsafe { &mut *ptr }
                    .as_mut()
                    .expect("Writes should only occur while the rope exists.");

                let chunk_count = rope.len();

                let rope_el = rope
                    .back_mut()
                    .expect("There will always be at least one element in rope.");

                let old_len = rope_el.data.len();
                let cap = rope_el.data.capacity();
                let space = cap - old_len;

                let new_len = old_len + space.min(remaining_bytes);

                if space < minimum_to_write {
                    // Not enough room for the transform to make progress:
                    // open a new chunk and loop again.
                    let end = rope_el.end_pos;
                    rope.push_back(BufferChunk::new(
                        end,
                        self.config.next_chunk_size(cap, chunk_count),
                    ));

                    false
                } else {
                    // Zero-fill the writable region so the transform can be
                    // handed an initialised slice.
                    rope_el.data.resize(new_len, 0);

                    let poll = self.transform.with_mut(|tx_ptr| {
                        self.source.with_mut(|src_ptr| {
                            let src = unsafe { &mut *src_ptr }
                                .as_mut()
                                .expect("Source must exist while not finalised.");

                            unsafe { &mut *tx_ptr }.transform_poll_read(
                                Pin::new(src),
                                cx,
                                &mut rope_el.data[old_len..],
                            )
                        })
                    });

                    if let Poll::Ready(pos) = poll {
                        progress_before_pending = true;

                        match pos {
                            Ok(TransformPosition::Read(len)) => {
                                rope_el.end_pos += len;
                                // Trim any zero padding past the bytes the
                                // transform actually produced.
                                rope_el.data.truncate(old_len + len);

                                remaining_bytes -= len;
                                self.len.fetch_add(len, Ordering::Release);
                            },
                            Ok(TransformPosition::Finished) => {
                                spawn_new_finaliser = self.finalise();
                            },
                            // Interrupted reads are retried transparently.
                            Err(e) if e.kind() == IoErrorKind::Interrupted => {
                            },
                            Err(e) => {
                                recorded_error = Some(Err(e));
                            },
                        }

                        // Stop once the source is done, the remaining budget
                        // is below the transform's minimum, or an error was
                        // recorded.
                        self.finalised().is_source_finished()
                            || remaining_bytes < minimum_to_write
                            || recorded_error.is_some()
                    } else {
                        // Inner read would block; bail out of the fill loop.
                        true
                    }
                }
            });

            if should_break {
                break;
            }
        }

        if progress_before_pending {
            Poll::Ready(
                recorded_error.unwrap_or(Ok((bytes_needed - remaining_bytes, spawn_new_finaliser))),
            )
        } else {
            Poll::Pending
        }
    }
}
453
/// Extension trait adding an async `skip` helper to readers.
#[async_trait]
pub trait AsyncReadSkipExt {
    /// Reads and discards up to `amt` bytes, returning how many were
    /// actually consumed (fewer than `amt` if the stream ends early).
    async fn skip(&mut self, amt: usize) -> usize
    where
        Self: Sized;
}
463
464#[async_trait]
465impl<R: AsyncRead + Sized + Unpin + Send> AsyncReadSkipExt for R {
466 async fn skip(&mut self, amt: usize) -> usize {
467 io::copy(&mut self.take(amt as u64), &mut io::sink())
468 .await
469 .unwrap_or(0) as usize
470 }
471}