1mod cast;
2mod range;
3mod waker;
4
5use std::{
6 future::Future,
7 io,
8 pin::Pin,
9 sync::Arc,
10 task::{ready, Context, Poll, Waker},
11};
12
13use arrayvec::ArrayVec;
14use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
15use waker::PollPendingQueue;
16
17pub use range::AsyncRangeRead;
18
19use crate::cast::CastExt;
20
/// Tile size, in bytes, used by [`TileBuffer::new`].
pub const DEFAULT_TILE_SIZE: usize = 4 * 1024;

/// Width, in bits, of a tile slot index.
const TILE_COUNT_BITS: usize = 5;

/// Largest permitted value of `TileBuffer`'s `N` parameter (`2^TILE_COUNT_BITS`).
pub const MAX_TILE_COUNT: usize = 1 << TILE_COUNT_BITS;

/// Mask selecting the low `TILE_COUNT_BITS` bits (`MAX_TILE_COUNT - 1`); presumably used
/// to pack slot indices elsewhere in the crate.
const TILE_COUNT_MASK: usize = (1 << TILE_COUNT_BITS) - 1;
26
27pub struct TileBuffer<const N: usize, R: AsyncRangeRead + 'static> {
31 tiles: [Tile<R>; N],
34
35 tile_mapping: [usize; N],
50
51 tile_pointer: usize,
65
66 tile_size: usize,
68
69 tile_total_count: usize,
71
72 offset: usize,
74
75 total_size: usize,
77
78 pending: Arc<PollPendingQueue>,
80 inner: R,
81}
82
83impl<const N: usize, R: AsyncRangeRead> TileBuffer<N, R> {
84 pub fn new(inner: R) -> Self {
86 Self::new_with_tile_size_and_offset(inner, DEFAULT_TILE_SIZE, N / 2)
87 }
88
89 pub fn new_with_tile_size(inner: R, tile_size: usize) -> Self {
91 Self::new_with_tile_size_and_offset(inner, tile_size, N / 2)
92 }
93
94 pub fn new_with_tile_size_and_offset(inner: R, tile_size: usize, tile_pointer: usize) -> Self {
96 assert!(N <= 32, "Maximum number of tiles cannot be greater 32!");
97
98 let total_size = inner.total_size();
99 let tile_total_count = total_size / tile_size;
100
101 let tile_total_count = if (total_size % tile_size) > 0 {
102 tile_total_count + 1
103 } else {
104 tile_total_count
105 };
106
107 let pending = Arc::new(PollPendingQueue::default());
108
109 Self {
110 tiles: (0..N)
111 .map(|i| {
112 let mut tile = Tile::new(i, tile_size, waker::create_waker(pending.clone(), i));
113 let offset = i * tile_size;
114 let length = usize::min(offset + tile_size, total_size) - offset;
115 tile.stage(&inner, offset, length);
116 tile
117 })
118 .cast(),
119 tile_mapping: (0..N).cast(),
120 tile_pointer,
121 tile_size,
122 tile_total_count,
123 total_size,
124 offset: 0,
125 inner,
126 pending,
127 }
128 }
129
130 pub fn set_offset(&mut self, new_offset: usize) {
136 self.offset = if new_offset > self.total_size {
137 self.total_size
138 } else {
139 new_offset
140 };
141
142 let (tile_begin, _) = self.current_tile();
143 let tile_end = tile_begin + N;
144
145 let mut free: ArrayVec<usize, N> = ArrayVec::new();
146
147 self.tile_mapping.iter_mut().for_each(|x| *x = usize::MAX);
149
150 for (idx, b) in self.tiles.iter().enumerate() {
152 if b.index >= tile_begin && b.index < tile_end {
153 self.tile_mapping[b.index - tile_begin] = idx;
154 } else {
155 free.push(idx);
156 }
157 }
158
159 for m in 0..N {
161 if self.tile_mapping[m] == usize::MAX {
162 let bindex = free.pop().unwrap();
163 self.tiles[bindex].index = tile_begin + m;
164 self.tile_mapping[m] = bindex;
165
166 let offset = (tile_begin + m) * self.tile_size;
167 let length = usize::min(offset + self.tile_size, self.total_size) - offset;
168
169 self.tiles[bindex].stage(&self.inner, offset, length);
170 }
171 }
172 }
173
174 fn poll_tiles(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
175 self.pending.register_waker(cx.waker());
176
177 loop {
178 let Some(index) = self.pending.next_item() else {
179 break Poll::Pending;
180 };
181
182 if let Poll::Ready(val) = self.tiles[index].poll(cx) {
183 break Poll::Ready(match val {
184 Ok(_) => Ok(index),
185 Err(err) => Err(err),
186 });
187 }
188 }
189 }
190
191 #[inline]
192 fn current_tile(&self) -> (usize, usize) {
193 let curr_tile = self.offset / self.tile_size;
194
195 let tile_begin = if curr_tile < self.tile_pointer {
196 0
197 } else if curr_tile + self.tile_pointer < self.tile_total_count {
198 curr_tile - self.tile_pointer
199 } else if self.tile_total_count >= N {
200 self.tile_total_count - N
201 } else {
202 0
203 };
204
205 (tile_begin, curr_tile - tile_begin)
206 }
207
208 fn current_buffer_read(&mut self, buf: &mut ReadBuf<'_>) -> Poll<usize> {
209 let (begin, current) = self.current_tile();
210 let current_tile_idx = begin + current;
211
212 let mapped = self.tile_mapping[current];
213 let tile = &mut self.tiles[mapped];
214
215 if tile.task.is_some() {
216 return Poll::Pending;
217 }
218
219 let begin_offset = current_tile_idx * self.tile_size;
220 let tile_offset = self.offset - begin_offset;
221 let tile_buffer_remaining = &tile.data[tile_offset..];
222
223 let upto = usize::min(buf.remaining(), tile_buffer_remaining.len());
224
225 buf.put_slice(&tile_buffer_remaining[..upto]);
226
227 Poll::Ready(upto)
228 }
229}
230
231impl<const N: usize, R: AsyncRangeRead> AsyncRead for TileBuffer<N, R> {
232 fn poll_read(
233 self: Pin<&mut Self>,
234 cx: &mut Context<'_>,
235 buf: &mut ReadBuf<'_>,
236 ) -> Poll<io::Result<()>> {
237 let this = unsafe { self.get_unchecked_mut() };
239
240 while let Poll::Ready(val) = this.poll_tiles(cx) {
241 match val {
242 Ok(_index) => (),
243 Err(err) => return Poll::Ready(Err(err)),
244 }
245 }
246
247 let remaining = buf.remaining();
248 while let Poll::Ready(read) = this.current_buffer_read(buf) {
249 this.set_offset(this.offset + read);
250
251 if read == 0 || buf.remaining() == 0 {
252 return Poll::Ready(Ok(()));
253 }
254 }
255
256 if remaining != buf.remaining() {
257 Poll::Ready(Ok(()))
258 } else {
259 Poll::Pending
260 }
261 }
262}
263
264impl<const N: usize, R: AsyncRangeRead> AsyncSeek for TileBuffer<N, R> {
265 fn start_seek(self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
266 let new_offset = match position {
267 io::SeekFrom::Start(offset) => offset as _,
268 io::SeekFrom::End(offset) => (self.total_size as i64 + offset).try_into().unwrap(),
269 io::SeekFrom::Current(offset) => (self.offset as i64 + offset).try_into().unwrap(),
270 };
271
272 unsafe { self.get_unchecked_mut() }.set_offset(new_offset);
273
274 Ok(())
275 }
276
277 fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
278 Poll::Ready(Ok(self.offset as _))
279 }
280}
281
282struct Tile<R: AsyncRangeRead + 'static> {
283 index: usize,
284 data: Vec<u8>,
285 task: Option<R::Fut<'static>>,
286 waker: Waker,
287}
288
289impl<R: AsyncRangeRead + 'static> Tile<R> {
290 fn new(index: usize, tile_size: usize, waker: Waker) -> Self {
291 Self {
292 index,
293 data: Vec::with_capacity(tile_size),
294 task: None,
295 waker,
296 }
297 }
298
299 fn poll(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
300 if let Some(fut) = self.task.as_mut() {
301 let mut ctx = Context::from_waker(&self.waker);
302
303 let pinned = unsafe { Pin::new_unchecked(fut) };
305 ready!(pinned.poll(&mut ctx))?;
306
307 drop(self.task.take());
308
309 Poll::Ready(Ok(()))
310 } else {
311 Poll::Pending
312 }
313 }
314
315 fn stage(&mut self, inner: &R, offset: usize, length: usize) {
316 if self.data.len() != length {
317 self.data.resize(length, 0);
318 }
319
320 let fut = inner.range_read(&mut self.data, offset);
321
322 self.task = Some(unsafe { std::mem::transmute(fut) });
324 self.waker.wake_by_ref();
325 }
326}
327
#[cfg(test)]
mod tests {
    use std::{io, pin::Pin};

    use super::{AsyncRangeRead, TileBuffer};
    use futures::Future;
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    /// Size of the synthetic source; deliberately not a multiple of the 1024-byte tiles.
    const TOTAL_SIZE: usize = 50630;

    /// Synthetic source whose byte at position `p` is `p as u8`, delivered after a delay.
    struct Test;

    impl AsyncRangeRead for Test {
        type Fut<'a> = Pin<Box<dyn Future<Output = io::Result<()>> + 'a>>
        where
            Self: 'a;

        fn total_size(&self) -> usize {
            TOTAL_SIZE
        }

        fn range_read<'a>(&'a self, buf: &'a mut [u8], offset: usize) -> Self::Fut<'a> {
            Box::pin(async move {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;

                for (position, byte) in (offset..).zip(buf.iter_mut()) {
                    *byte = position as u8;
                }

                Ok(())
            })
        }
    }

    /// The 256 bytes `Test` yields starting at any multiple of 256.
    fn byte_ramp() -> Vec<u8> {
        (0u8..=255).collect()
    }

    /// Reads exactly 256 bytes and checks that they form a full `0..=255` ramp.
    async fn expect_ramp(buff: &mut TileBuffer<5, Test>) {
        let mut chunk = [0u8; 256];
        let read = buff.read_exact(&mut chunk).await.unwrap();
        assert_eq!(read, 256);
        assert_eq!(chunk.as_slice(), byte_ramp().as_slice());
    }

    #[tokio::test]
    async fn sequential_read_test() {
        let mut buff: TileBuffer<5, _> = TileBuffer::new_with_tile_size(Test, 1024);

        let mut contents = Vec::new();
        let read = buff.read_to_end(&mut contents).await.unwrap();

        let expected: Vec<u8> = (0..TOTAL_SIZE).map(|position| position as u8).collect();
        assert_eq!(read, TOTAL_SIZE);
        assert_eq!(contents, expected);
    }

    #[tokio::test]
    async fn random_read_test() {
        let mut buff: TileBuffer<5, _> = TileBuffer::new_with_tile_size(Test, 1024);

        expect_ramp(&mut buff).await;

        // Every target below is a multiple of 256, so each read yields a full ramp.
        for target in [
            io::SeekFrom::Start(8448),
            io::SeekFrom::Start(7424),
            io::SeekFrom::Current(-1024),
        ] {
            buff.seek(target).await.unwrap();
            expect_ramp(&mut buff).await;
        }

        buff.seek(io::SeekFrom::Start(0)).await.unwrap();
        buff.seek(io::SeekFrom::End(-454)).await.unwrap();
        expect_ramp(&mut buff).await;
    }
}