streamcatcher/
future.rs

1//! Support types for `AsyncRead`/`AsyncSeek` compatible stream buffers.
2//! Requires the `"async"` feature.
3use crate::*;
4#[cfg(feature = "tokio-compat")]
5pub use async_compat::Compat;
6use async_trait::async_trait;
7use core::{
8	future::Future,
9	pin::Pin,
10	task::{Context, Poll},
11};
12use futures_util::io::{self, AsyncRead, AsyncReadExt, AsyncSeek};
13use std::{
14	io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult, SeekFrom},
15	marker::Unpin,
16	mem,
17	sync::atomic::Ordering,
18};
19
20/// Async variant of [`Transform`].
21///
22/// [`Transform`]: ../trait.Transform.html
23pub trait AsyncTransform<TInput: AsyncRead> {
24	fn transform_poll_read(
25		&mut self,
26		src: Pin<&mut TInput>,
27		cx: &mut Context,
28		buf: &mut [u8],
29	) -> Poll<IoResult<TransformPosition>>;
30
31	fn min_bytes_required(&self) -> usize {
32		1
33	}
34}
35
36impl<T: AsyncRead> AsyncTransform<T> for Identity {
37	fn transform_poll_read(
38		&mut self,
39		src: Pin<&mut T>,
40		cx: &mut Context,
41		buf: &mut [u8],
42	) -> Poll<IoResult<TransformPosition>> {
43		src.poll_read(cx, buf).map(|res| {
44			res.map(|count| match count {
45				0 => TransformPosition::Finished,
46				n => TransformPosition::Read(n),
47			})
48		})
49	}
50}
51
52impl<T, Tx> TxCatcher<T, Tx>
53where
54	T: AsyncRead + Unpin + 'static,
55	Tx: AsyncTransform<T> + Unpin + 'static,
56{
57	/// Read all bytes from the underlying stream
58	/// into the backing store in the current task.
59	pub fn load_all_async(self) -> LoadAll<T, Tx> {
60		LoadAll::new(self)
61	}
62}
63
64/// Future returned by [`TxCatcher::load_all_async`].
65///
66/// [`TxCatcher::load_all_async`]: ../struct.TxCatcher.html#method.load_all_async
67pub struct LoadAll<T, Tx>
68where
69	T: AsyncRead + Unpin + 'static,
70	Tx: AsyncTransform<T> + Unpin + 'static,
71{
72	catcher: TxCatcher<T, Tx>,
73	in_pos: usize,
74}
75
76impl<T, Tx> LoadAll<T, Tx>
77where
78	T: AsyncRead + Unpin + 'static,
79	Tx: AsyncTransform<T> + Unpin + 'static,
80{
81	fn new(catcher: TxCatcher<T, Tx>) -> Self {
82		let in_pos = catcher.pos;
83
84		Self { catcher, in_pos }
85	}
86}
87
88impl<T, Tx> Future for LoadAll<T, Tx>
89where
90	T: AsyncRead + Unpin + 'static,
91	Tx: AsyncTransform<T> + Unpin + 'static,
92{
93	type Output = TxCatcher<T, Tx>;
94
95	fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
96		self.catcher.pos = self.catcher.len();
97
98		loop {
99			if self.catcher.is_finalised() {
100				break;
101			}
102
103			let mut skip_attempt = self.catcher.skip(7680);
104
105			match Future::poll(Pin::new(&mut skip_attempt), cx) {
106				Poll::Ready(0) => break,
107				Poll::Ready(_n) => {},
108				Poll::Pending => {
109					return Poll::Pending;
110				},
111			}
112		}
113
114		self.catcher.pos = self.in_pos;
115
116		Poll::Ready(self.catcher.new_handle())
117	}
118}
119
120impl<T, Tx> AsyncRead for TxCatcher<T, Tx>
121where
122	T: AsyncRead + Unpin + 'static,
123	Tx: AsyncTransform<T> + Unpin + 'static,
124{
125	fn poll_read(
126		mut self: Pin<&mut Self>,
127		cx: &mut Context,
128		buf: &mut [u8],
129	) -> Poll<IoResult<usize>> {
130		self.core.read_from_pos_async(self.pos, cx, buf).map(
131			|(bytes_read, should_finalise_here)| {
132				if should_finalise_here {
133					let handle = self.core.clone();
134					match self.core.config.spawn_finaliser {
135						Finaliser::InPlace => unreachable!(),
136						Finaliser::NewThread => {
137							std::thread::spawn(move || handle.do_finalise());
138						},
139						#[cfg(feature = "async-std-compat")]
140						Finaliser::AsyncStd => {
141							async_std::task::spawn(async move {
142								handle.do_finalise();
143							});
144						},
145						#[cfg(feature = "tokio-compat")]
146						Finaliser::Tokio => {
147							let _ = tokio::spawn(async move {
148								handle.do_finalise();
149							});
150						},
151						#[cfg(feature = "smol-compat")]
152						Finaliser::Smol => {
153							smol::spawn(async move {
154								handle.do_finalise();
155							})
156							.detach();
157						},
158					}
159				}
160
161				if let Ok(size) = bytes_read {
162					self.pos += size;
163				}
164
165				bytes_read
166			},
167		)
168	}
169}
170
#[cfg(feature = "tokio-compat")]
impl<T, Tx> TxCatcher<T, Tx> {
	/// Wrap this catcher in `async_compat`'s [`Compat`] adapter, for use with
	/// Tokio-based I/O.
	///
	/// Requires the `"tokio-compat"` feature.
	pub fn tokio(self) -> Compat<Self> {
		Compat::<Self>::new(self)
	}
}
177
178impl<T, Tx> AsyncSeek for TxCatcher<T, Tx>
179where
180	T: AsyncRead + Unpin + 'static,
181	Tx: AsyncTransform<T> + Unpin + 'static,
182{
183	fn poll_seek(mut self: Pin<&mut Self>, cx: &mut Context, pos: SeekFrom) -> Poll<IoResult<u64>> {
184		let old_pos = self.pos as u64;
185
186		let (valid, new_pos) = match pos {
187			SeekFrom::Current(adj) => {
188				// overflow expected in many cases.
189				let new_pos = old_pos.wrapping_add(adj as u64);
190				(adj >= 0 || (adj.abs() as u64) <= old_pos, new_pos)
191			},
192			SeekFrom::End(adj) => {
193				// Slower to load in the whole stream first, but safer.
194				// We could, in theory, use metadata as the basis,
195				// but incorrect metadata would be tricky to work around.
196				let mut end_read_future = self.new_handle().load_all_async();
197				if Future::poll(Pin::new(&mut end_read_future), cx).is_pending() {
198					return Poll::Pending;
199				}
200
201				let len = self.len() as u64;
202				let new_pos = len.wrapping_add(adj as u64);
203				(adj >= 0 || (adj.abs() as u64) <= len, new_pos)
204			},
205			SeekFrom::Start(new_pos) => (true, new_pos),
206		};
207
208		Poll::Ready(if valid {
209			if new_pos > old_pos {
210				self.pos = (new_pos as usize).min(self.len());
211				if new_pos != self.pos as u64 {
212					let mut skip_future = self.skip(new_pos as usize - self.pos);
213					if Future::poll(Pin::new(&mut skip_future), cx).is_pending() {
214						return Poll::Pending;
215					}
216				}
217			}
218
219			let len = self.len() as u64;
220
221			self.pos = new_pos.min(len) as usize;
222			Ok(self.pos as u64)
223		} else {
224			Err(IoError::new(
225				IoErrorKind::InvalidInput,
226				"Tried to seek before start of stream.",
227			))
228		})
229	}
230}
231
232impl<T, Tx> RawStore<T, Tx>
233where
234	T: AsyncRead + Unpin,
235	Tx: AsyncTransform<T> + Unpin,
236{
237	/// Returns read count, should_upgrade, should_finalise_external
238	fn read_from_pos_async(
239		&self,
240		pos: usize,
241		cx: &mut Context,
242		buf: &mut [u8],
243	) -> Poll<(IoResult<usize>, bool)> {
244		// Place read of finalised first to be certain that if we see finalised,
245		// then backing_len *must* be the true length.
246		let (loc, mut finalised) = self.get_location();
247
248		let mut backing_len = self.len();
249
250		let mut should_finalise_external = false;
251
252		let target_len = pos + buf.len();
253
254		let mut progress_before_pending = false;
255
256		let out = if finalised.is_source_finished() || target_len <= backing_len {
257			// If finalised, there is zero risk of triggering more writes.
258			progress_before_pending = true;
259			let read_amt = buf.len().min(backing_len - pos);
260			Ok(self.read_from_local(pos, loc, buf, read_amt))
261		} else {
262			let mut read = 0;
263			let mut base_result = None;
264
265			loop {
266				finalised = self.finalised();
267				backing_len = self.len();
268				let mut remaining_in_store = backing_len - pos - read;
269
270				if remaining_in_store == 0 {
271					let mut guard = self.lock.lock();
272
273					if Future::poll(Pin::new(&mut guard), cx).is_pending() {
274						break;
275					}
276
277					finalised = self.finalised();
278					backing_len = self.len();
279
280					// If length changed between our check and
281					// acquiring the lock, then drop it -- we don't need new bytes *yet*
282					// and might not!
283					remaining_in_store = backing_len - pos - read;
284					if remaining_in_store == 0 && finalised.is_source_live() {
285						if let Poll::Ready(read_count) =
286							self.fill_from_source_async(cx, buf.len() - read)
287						{
288							progress_before_pending = true;
289							if let Ok((read_count, finalise_elsewhere)) = read_count {
290								remaining_in_store += read_count;
291								should_finalise_external |= finalise_elsewhere;
292							}
293							base_result = Some(read_count.map(|a| a.0));
294
295							finalised = self.finalised();
296						} else {
297							break;
298						}
299					}
300
301					// (Explicitly) unlocked here.
302					mem::drop(guard);
303				}
304
305				if remaining_in_store > 0 {
306					let count = remaining_in_store.min(buf.len() - read);
307					read += self.read_from_local(pos + read, loc, &mut buf[read..], count);
308				}
309
310				// break out if:
311				// * no space in reader's buffer
312				// * hit an error
313				// * or nothing remaining, AND finalised
314				if matches!(base_result, Some(Err(_)))
315					|| read == buf.len() || (finalised.is_source_finished()
316					&& backing_len == pos + read)
317				{
318					break;
319				}
320			}
321
322			base_result.unwrap_or(Ok(0)).map(|_| read)
323		};
324
325		if loc == CacheReadLocation::Roped {
326			self.remove_rope_full();
327		}
328
329		if progress_before_pending {
330			Poll::Ready((out, should_finalise_external))
331		} else {
332			Poll::Pending
333		}
334	}
335
336	// ONLY SAFE TO CALL WITH LOCK.
337	// The critical section concerns:
338	// * adding new elements to the rope
339	// * drawing bytes from the source
340	// * modifying len
341	// * modifying encoder state
342	fn fill_from_source_async(
343		&self,
344		cx: &mut Context,
345		mut bytes_needed: usize,
346	) -> Poll<IoResult<(usize, bool)>> {
347		let minimum_to_write = self
348			.transform
349			.with(|ptr| unsafe { &*ptr }.min_bytes_required());
350
351		let overspill = bytes_needed % self.config.read_burst_len;
352		if overspill != 0 {
353			bytes_needed += self.config.read_burst_len - overspill;
354		}
355
356		let mut remaining_bytes = bytes_needed;
357		let mut recorded_error = None;
358
359		let mut spawn_new_finaliser = false;
360
361		let mut progress_before_pending = false;
362
363		loop {
364			let should_break = self.rope.with_mut(|ptr| {
365				let rope = unsafe { &mut *ptr }
366					.as_mut()
367					.expect("Writes should only occur while the rope exists.");
368
369				let chunk_count = rope.len();
370
371				let rope_el = rope
372					.back_mut()
373					.expect("There will always be at least one element in rope.");
374
375				let old_len = rope_el.data.len();
376				let cap = rope_el.data.capacity();
377				let space = cap - old_len;
378
379				let new_len = old_len + space.min(remaining_bytes);
380
381				if space < minimum_to_write {
382					let end = rope_el.end_pos;
383					// Make a new chunk!
384					rope.push_back(BufferChunk::new(
385						end,
386						self.config.next_chunk_size(cap, chunk_count),
387					));
388
389					false
390				} else {
391					rope_el.data.resize(new_len, 0);
392
393					let poll = self.transform.with_mut(|tx_ptr| {
394						self.source.with_mut(|src_ptr| {
395							let src = unsafe { &mut *src_ptr }
396								.as_mut()
397								.expect("Source must exist while not finalised.");
398
399							unsafe { &mut *tx_ptr }.transform_poll_read(
400								Pin::new(src),
401								cx,
402								&mut rope_el.data[old_len..],
403							)
404						})
405					});
406
407					if let Poll::Ready(pos) = poll {
408						progress_before_pending = true;
409
410						match pos {
411							Ok(TransformPosition::Read(len)) => {
412								rope_el.end_pos += len;
413								rope_el.data.truncate(old_len + len);
414
415								remaining_bytes -= len;
416								self.len.fetch_add(len, Ordering::Release);
417							},
418							Ok(TransformPosition::Finished) => {
419								spawn_new_finaliser = self.finalise();
420							},
421							Err(e) if e.kind() == IoErrorKind::Interrupted => {
422								// DO nothing, so try again.
423							},
424							Err(e) => {
425								recorded_error = Some(Err(e));
426							},
427						}
428
429						self.finalised().is_source_finished()
430							|| remaining_bytes < minimum_to_write
431							|| recorded_error.is_some()
432					} else {
433						// Pending
434						true
435					}
436				}
437			});
438
439			if should_break {
440				break;
441			}
442		}
443
444		if progress_before_pending {
445			Poll::Ready(
446				recorded_error.unwrap_or(Ok((bytes_needed - remaining_bytes, spawn_new_finaliser))),
447			)
448		} else {
449			Poll::Pending
450		}
451	}
452}
453
454#[async_trait]
455/// Async variant of [`ReadSkipExt`].
456///
457/// [`ReadSkipExt`]: ../trait.ReadSkipExt.html
458pub trait AsyncReadSkipExt {
459	async fn skip(&mut self, amt: usize) -> usize
460	where
461		Self: Sized;
462}
463
464#[async_trait]
465impl<R: AsyncRead + Sized + Unpin + Send> AsyncReadSkipExt for R {
466	async fn skip(&mut self, amt: usize) -> usize {
467		io::copy(&mut self.take(amt as u64), &mut io::sink())
468			.await
469			.unwrap_or(0) as usize
470	}
471}