tile_buffer/lib.rs

mod cast;
mod range;
mod waker;

use std::{
    future::Future,
    io,
    pin::Pin,
    sync::Arc,
    task::{ready, Context, Poll, Waker},
};

use arrayvec::ArrayVec;
use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
use waker::PollPendingQueue;

pub use range::AsyncRangeRead;

use crate::cast::CastExt;

pub const DEFAULT_TILE_SIZE: usize = 4096;
pub const MAX_TILE_COUNT: usize = 1 << TILE_COUNT_BITS;

const TILE_COUNT_BITS: usize = 5;
const TILE_COUNT_MASK: usize = MAX_TILE_COUNT - 1;

///
/// A buffer that exposes an [`AsyncRangeRead`] source as [`AsyncRead`] + [`AsyncSeek`]
/// by keeping a sliding window of `N` fixed-size tiles in memory.
///
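/// A minimal usage sketch (not compiled as a doctest; `MySource` and `source` stand in
/// for a caller-provided [`AsyncRangeRead`] implementation):
///
/// ```ignore
/// use tokio::io::{AsyncReadExt, AsyncSeekExt};
///
/// // keep 5 tiles of 1024 bytes each in memory
/// let mut buf: TileBuffer<5, MySource> = TileBuffer::new_with_tile_size(source, 1024);
///
/// // read sequentially...
/// let mut chunk = [0u8; 256];
/// buf.read_exact(&mut chunk).await?;
///
/// // ...or seek anywhere; tiles already inside the window are reused
/// buf.seek(std::io::SeekFrom::Start(8448)).await?;
/// buf.read_exact(&mut chunk).await?;
/// ```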
pub struct TileBuffer<const N: usize, R: AsyncRangeRead + 'static> {
    ///
    /// Array of tiles. The size of this array usually should not be greater than 5
    tiles: [Tile<R>; N],

    ///
    /// Mapping from relative tile index (0..N) to global tile index (0..tile_total_count)
    /// example:
    ///   tiles: [{index: 12, data: Some}, {index: 13, data: Some}, {index: 14, data: Some}, {index: 15, data: None}]
    ///   tile_mapping: [12, 13, 14, 15]
    ///
    /// means:
    ///                                         |
    /// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16
    ///                                     == XX == ++
    /// |  - offset pointer
    /// == - buffer loaded
    /// ++ - buffer is loading now
    /// XX - buffer loaded and current (relative index is 1; global index is 13)
    tile_mapping: [usize; N],

    ///
    /// Pointer of the current tile within the window.
    /// example:
    ///   if tile_pointer is 0:
    ///     0 1 2 3 4 5 6 7 8
    ///         X = = =
    /// note: the other tiles load further data (good for sequential forward reads)
    ///
    /// example:
    ///   if tile_pointer is 2:
    ///     0 1 2 3 4 5 6 7 8
    ///         = = X =
    /// note: most of the other tiles keep past data (good for random seeking within some distance of the current offset)
    tile_pointer: usize,

    /// Size of one tile in bytes
    tile_size: usize,

    /// Total number of tiles, effectively total_size / tile_size rounded up
    tile_total_count: usize,

    /// Current offset in bytes
    offset: usize,

    /// Total size in bytes
    total_size: usize,

    /// Queue of tile indexes pending to be polled, together with the registered waker
    pending: Arc<PollPendingQueue>,
    inner: R,
}

impl<const N: usize, R: AsyncRangeRead> TileBuffer<N, R> {
    ///
    /// Create a new buffer with the default tile size and the tile pointer in the middle of the window
    pub fn new(inner: R) -> Self {
        Self::new_with_tile_size_and_offset(inner, DEFAULT_TILE_SIZE, N / 2)
    }

    ///
    /// Create a new buffer with the given tile size and the tile pointer in the middle of the window
    pub fn new_with_tile_size(inner: R, tile_size: usize) -> Self {
        Self::new_with_tile_size_and_offset(inner, tile_size, N / 2)
    }

    ///
    /// Create a new buffer with the given tile size and an explicit tile pointer
    /// (see the `tile_pointer` field docs for how the pointer positions the window)
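    ///
    /// An illustrative sketch of picking the pointer (`MySource`, `source`, and `other_source`
    /// are assumed stand-ins for an [`AsyncRangeRead`] implementation):
    ///
    /// ```ignore
    /// // pointer at 0: the other tiles read ahead of the cursor (sequential streaming)
    /// let forward: TileBuffer<5, MySource> =
    ///     TileBuffer::new_with_tile_size_and_offset(source, 4096, 0);
    ///
    /// // pointer in the middle (the `N / 2` default used by `new`): tolerant to short backward seeks
    /// let random: TileBuffer<5, MySource> =
    ///     TileBuffer::new_with_tile_size_and_offset(other_source, 4096, 2);
    /// ```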
    pub fn new_with_tile_size_and_offset(inner: R, tile_size: usize, tile_pointer: usize) -> Self {
        assert!(N <= MAX_TILE_COUNT, "Maximum number of tiles cannot be greater than 32!");

        let total_size = inner.total_size();
        let tile_total_count = total_size / tile_size;

        // round the tile count up when the last tile is only partially filled
        let tile_total_count = if (total_size % tile_size) > 0 {
            tile_total_count + 1
        } else {
            tile_total_count
        };

        let pending = Arc::new(PollPendingQueue::default());

        Self {
            tiles: (0..N)
                .map(|i| {
                    let mut tile = Tile::new(i, tile_size, waker::create_waker(pending.clone(), i));
                    let offset = i * tile_size;
                    let length = usize::min(offset + tile_size, total_size) - offset;
                    tile.stage(&inner, offset, length);
                    tile
                })
                .cast(),
            tile_mapping: (0..N).cast(),
            tile_pointer,
            tile_size,
            tile_total_count,
            total_size,
            offset: 0,
            inner,
            pending,
        }
    }

    ///
    /// Set a new offset
    ///   calling this method will recalculate the mappings,
    ///   reuse already loaded tiles, and stage the ones
    ///   which are not loaded yet
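    ///
    /// A sketch of the effect, assuming tile_size = 1024, N = 5 and tile_pointer = 2:
    ///
    /// ```ignore
    /// buffer.set_offset(8448);
    /// // the current tile becomes 8448 / 1024 = 8, so the window now covers tiles 6..=10;
    /// // tiles already loaded for that range are kept, the rest are re-staged
    /// ```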
    pub fn set_offset(&mut self, new_offset: usize) {
        self.offset = if new_offset > self.total_size {
            self.total_size
        } else {
            new_offset
        };

        let (tile_begin, _) = self.current_tile();
        let tile_end = tile_begin + N;

        let mut free: ArrayVec<usize, N> = ArrayVec::new();

        // clear the previous mappings
        self.tile_mapping.iter_mut().for_each(|x| *x = usize::MAX);

        // keep tiles whose global index already falls inside the new window
        for (idx, b) in self.tiles.iter().enumerate() {
            if b.index >= tile_begin && b.index < tile_end {
                self.tile_mapping[b.index - tile_begin] = idx;
            } else {
                free.push(idx);
            }
        }

        // map the remaining window slots to free buffers and stage loads of the appropriate ranges into them
        for m in 0..N {
            if self.tile_mapping[m] == usize::MAX {
                let bindex = free.pop().unwrap();
                self.tiles[bindex].index = tile_begin + m;
                self.tile_mapping[m] = bindex;

                let offset = (tile_begin + m) * self.tile_size;
                let length = usize::min(offset + self.tile_size, self.total_size) - offset;

                self.tiles[bindex].stage(&self.inner, offset, length);
            }
        }
    }

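    ///
    /// Drains the shared pending queue and polls the corresponding tile futures.
    ///
    /// Each tile is polled with its own waker (built by `waker::create_waker` with the tile's
    /// slot index), so a wakeup can push that slot back onto the shared `PollPendingQueue`.
    /// Only tiles that were actually woken are polled here; the method returns the slot index
    /// of the first tile whose load completed, or the error it produced.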
    fn poll_tiles(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        self.pending.register_waker(cx.waker());

        loop {
            let Some(index) = self.pending.next_item() else {
                break Poll::Pending;
            };

            if let Poll::Ready(val) = self.tiles[index].poll(cx) {
                break Poll::Ready(match val {
                    Ok(_) => Ok(index),
                    Err(err) => Err(err),
                });
            }
        }
    }

    /// Returns the first global tile index of the current window and the
    /// current tile's index relative to that window start
    #[inline]
    fn current_tile(&self) -> (usize, usize) {
        let curr_tile = self.offset / self.tile_size;

        let tile_begin = if curr_tile < self.tile_pointer {
            // the cursor is within the first `tile_pointer` tiles: pin the window to the start
            0
        } else if curr_tile + self.tile_pointer < self.tile_total_count {
            // enough tiles remain ahead: keep the cursor `tile_pointer` slots into the window
            curr_tile - self.tile_pointer
        } else if self.tile_total_count >= N {
            // near the end: pin the window to the last N tiles
            self.tile_total_count - N
        } else {
            // the whole source fits in fewer than N tiles
            0
        };

        (tile_begin, curr_tile - tile_begin)
    }

    /// Copies as many bytes as possible from the current tile into `buf`,
    /// starting at the current offset
    fn current_buffer_read(&mut self, buf: &mut ReadBuf<'_>) -> Poll<usize> {
        let (begin, current) = self.current_tile();
        let current_tile_idx = begin + current;

        let mapped = self.tile_mapping[current];
        let tile = &mut self.tiles[mapped];

        // the current tile is still loading
        if tile.task.is_some() {
            return Poll::Pending;
        }

        let begin_offset = current_tile_idx * self.tile_size;
        let tile_offset = self.offset - begin_offset;
        let tile_buffer_remaining = &tile.data[tile_offset..];

        let upto = usize::min(buf.remaining(), tile_buffer_remaining.len());

        buf.put_slice(&tile_buffer_remaining[..upto]);

        Poll::Ready(upto)
    }
}

impl<const N: usize, R: AsyncRangeRead> AsyncRead for TileBuffer<N, R> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // TODO: use pin-project eventually
        let this = unsafe { self.get_unchecked_mut() };

        while let Poll::Ready(val) = this.poll_tiles(cx) {
            match val {
                Ok(_index) => (),
                Err(err) => return Poll::Ready(Err(err)),
            }
        }

        let remaining = buf.remaining();
        while let Poll::Ready(read) = this.current_buffer_read(buf) {
            this.set_offset(this.offset + read);

            if read == 0 || buf.remaining() == 0 {
                return Poll::Ready(Ok(()));
            }
        }

        if remaining != buf.remaining() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}

impl<const N: usize, R: AsyncRangeRead> AsyncSeek for TileBuffer<N, R> {
    fn start_seek(self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
        let new_offset = match position {
            io::SeekFrom::Start(offset) => offset as _,
            io::SeekFrom::End(offset) => (self.total_size as i64 + offset).try_into().unwrap(),
            io::SeekFrom::Current(offset) => (self.offset as i64 + offset).try_into().unwrap(),
        };

        unsafe { self.get_unchecked_mut() }.set_offset(new_offset);

        Ok(())
    }

    fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        Poll::Ready(Ok(self.offset as _))
    }
}

struct Tile<R: AsyncRangeRead + 'static> {
    index: usize,
    data: Vec<u8>,
    task: Option<R::Fut<'static>>,
    waker: Waker,
}

impl<R: AsyncRangeRead + 'static> Tile<R> {
    fn new(index: usize, tile_size: usize, waker: Waker) -> Self {
        Self {
            index,
            data: Vec::with_capacity(tile_size),
            task: None,
            waker,
        }
    }

    fn poll(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if let Some(fut) = self.task.as_mut() {
            let mut ctx = Context::from_waker(&self.waker);

            // SAFETY: the future stays inside `self.task` while it is polled and is only dropped in place once it completes
            let pinned = unsafe { Pin::new_unchecked(fut) };
            ready!(pinned.poll(&mut ctx))?;

            drop(self.task.take());

            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    fn stage(&mut self, inner: &R, offset: usize, length: usize) {
        if self.data.len() != length {
            self.data.resize(length, 0);
        }

        let fut = inner.range_read(&mut self.data, offset);

        // SAFETY: safe because the future will live within the 'self lifetime
        self.task = Some(unsafe { std::mem::transmute(fut) });
        self.waker.wake_by_ref();
    }
}

#[cfg(test)]
mod tests {
    use std::{io, pin::Pin};

    use super::{AsyncRangeRead, TileBuffer};
    use futures::Future;
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    struct Test;
    impl AsyncRangeRead for Test {
        type Fut<'a> = Pin<Box<dyn Future<Output = io::Result<()>> + 'a>>
        where Self: 'a;

        fn total_size(&self) -> usize {
            50630
        }

        fn range_read<'a>(&'a self, buf: &'a mut [u8], offset: usize) -> Self::Fut<'a> {
            Box::pin(async move {
                let mut counter = offset;

                tokio::time::sleep(std::time::Duration::from_millis(100)).await;

                for val in buf.iter_mut() {
                    *val = counter as _;
                    counter += 1;
                }

                Ok(())
            })
        }
    }

    #[tokio::test]
    async fn sequential_read_test() {
        let inner = Test;

        let mut buff: TileBuffer<5, _> = TileBuffer::new_with_tile_size(inner, 1024);
        let mut data = Vec::new();

        let x = buff.read_to_end(&mut data).await.unwrap();

        let valid = (0..50630u32).map(|x| x as u8).collect::<Vec<u8>>();

        assert_eq!(x, 50630);
        assert_eq!(data, valid);
    }

    #[tokio::test]
    async fn random_read_test() {
        let inner = Test;

        let mut buff: TileBuffer<5, _> = TileBuffer::new_with_tile_size(inner, 1024);
        let mut data = [0u8; 256];
        let valid = (0u8..=255).collect::<Vec<u8>>();

        let x = buff.read_exact(&mut data).await.unwrap();
        assert_eq!(x, 256);
        assert_eq!(data.as_slice(), valid.as_slice());

        buff.seek(io::SeekFrom::Start(8448)).await.unwrap();

        let x = buff.read_exact(&mut data).await.unwrap();
        assert_eq!(x, 256);
        assert_eq!(data.as_slice(), valid.as_slice());

        buff.seek(io::SeekFrom::Start(7424)).await.unwrap();

        let x = buff.read_exact(&mut data).await.unwrap();
        assert_eq!(x, 256);
        assert_eq!(data.as_slice(), valid.as_slice());

        buff.seek(io::SeekFrom::Current(-1024)).await.unwrap();

        let x = buff.read_exact(&mut data).await.unwrap();
        assert_eq!(x, 256);
        assert_eq!(data.as_slice(), valid.as_slice());

        buff.seek(io::SeekFrom::Start(0)).await.unwrap();
        buff.seek(io::SeekFrom::End(-454)).await.unwrap();

        let x = buff.read_exact(&mut data).await.unwrap();
        assert_eq!(x, 256);
        assert_eq!(data.as_slice(), valid.as_slice());
    }
}