//! wsi_streamer/io/block_cache.rs
//!
//! Block cache over a `RangeReader`: fixed-size LRU-evicted blocks with
//! singleflight deduplication of concurrent fetches.

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use bytes::{Bytes, BytesMut};
6use lru::LruCache;
7use tokio::sync::{Mutex, Notify, RwLock};
8
9use super::RangeReader;
10use crate::error::IoError;
11
/// Default block size: 256KB.
///
/// Large enough to amortize per-request S3 latency, small enough to not
/// waste bandwidth when callers only need a few bytes.
pub const DEFAULT_BLOCK_SIZE: usize = 256 * 1024;

/// Default cache capacity in number of blocks.
/// 100 blocks * 256KB = 25.6MB default cache size.
const DEFAULT_CACHE_CAPACITY: usize = 100;
19
/// Block-based caching layer that wraps any RangeReader.
///
/// This cache is critical for performance:
/// - TIFF parsing requires many small reads at scattered offsets
/// - Without caching, each read would be an S3 request
/// - Block cache amortizes these into fewer, larger requests
///
/// Features:
/// - Fixed-size block cache (default 256KB blocks)
/// - LRU eviction when cache reaches capacity
/// - Singleflight: concurrent requests for the same block share one fetch
/// - Handles reads spanning multiple blocks
pub struct BlockCache<R> {
    /// The underlying reader.
    /// NOTE(review): `inner` is only ever accessed through `&self` in this
    /// file, so the `Arc` looks unnecessary — confirm no external sharing
    /// before simplifying to plain `R`.
    inner: Arc<R>,
    /// Block size in bytes; fixed at construction time.
    block_size: usize,
    /// Cached blocks indexed by block number (offset / block_size).
    /// `Bytes` clones are refcount bumps, so cache hits copy no data.
    cache: RwLock<LruCache<u64, Bytes>>,
    /// In-flight block fetches for singleflight pattern: one `Notify` per
    /// block currently being fetched; concurrent requesters wait on it
    /// instead of issuing duplicate fetches.
    in_flight: Mutex<HashMap<u64, Arc<Notify>>>,
}
42
43impl<R: RangeReader> BlockCache<R> {
44    /// Create a new BlockCache wrapping the given reader.
45    ///
46    /// Uses default block size (256KB) and cache capacity (100 blocks).
47    pub fn new(inner: R) -> Self {
48        Self::with_capacity(inner, DEFAULT_BLOCK_SIZE, DEFAULT_CACHE_CAPACITY)
49    }
50
51    /// Create a new BlockCache with custom block size and capacity.
52    ///
53    /// # Arguments
54    /// * `inner` - The underlying reader to wrap
55    /// * `block_size` - Size of each cached block in bytes
56    /// * `capacity` - Maximum number of blocks to cache
57    pub fn with_capacity(inner: R, block_size: usize, capacity: usize) -> Self {
58        Self {
59            inner: Arc::new(inner),
60            block_size,
61            cache: RwLock::new(LruCache::new(
62                std::num::NonZeroUsize::new(capacity).unwrap(),
63            )),
64            in_flight: Mutex::new(HashMap::new()),
65        }
66    }
67
68    /// Get a block from cache or fetch it from the underlying reader.
69    ///
70    /// Implements the singleflight pattern: if multiple tasks request the same
71    /// block concurrently, only one fetch is performed and all tasks share the result.
72    async fn get_block(&self, block_idx: u64) -> Result<Bytes, IoError> {
73        loop {
74            // Fast path: check cache
75            {
76                let cache = self.cache.read().await;
77                if let Some(data) = cache.peek(&block_idx) {
78                    return Ok(data.clone());
79                }
80            }
81
82            // Slow path: check in_flight or become leader
83            let notify = {
84                let mut in_flight = self.in_flight.lock().await;
85
86                if let Some(notify) = in_flight.get(&block_idx) {
87                    // Another task is fetching this block, wait for it
88                    let notify = notify.clone();
89                    drop(in_flight);
90                    notify.notified().await;
91                    // Loop back to check cache
92                    continue;
93                }
94
95                // We're the leader for this block
96                let notify = Arc::new(Notify::new());
97                in_flight.insert(block_idx, notify.clone());
98                notify
99            };
100
101            // Fetch the block from source
102            let result = self.fetch_block_from_source(block_idx).await;
103
104            // Update cache and in_flight atomically, then notify waiters
105            {
106                let mut cache = self.cache.write().await;
107                let mut in_flight = self.in_flight.lock().await;
108
109                if let Ok(ref data) = result {
110                    cache.put(block_idx, data.clone());
111                }
112
113                in_flight.remove(&block_idx);
114            }
115
116            notify.notify_waiters();
117
118            return result;
119        }
120    }
121
122    /// Fetch a block directly from the underlying reader.
123    async fn fetch_block_from_source(&self, block_idx: u64) -> Result<Bytes, IoError> {
124        let offset = block_idx * self.block_size as u64;
125        let size = self.inner.size();
126
127        // Calculate actual bytes to read (may be less for last block)
128        let remaining = size.saturating_sub(offset);
129        if remaining == 0 {
130            return Err(IoError::RangeOutOfBounds {
131                offset,
132                requested: self.block_size as u64,
133                size,
134            });
135        }
136
137        let len = std::cmp::min(self.block_size as u64, remaining) as usize;
138        self.inner.read_exact_at(offset, len).await
139    }
140
    /// Calculate which block contains the given byte offset.
    ///
    /// Integer division floors, so every offset inside a block maps to the
    /// same index. Assumes `block_size > 0` (division by zero panics).
    #[inline]
    fn block_for_offset(&self, offset: u64) -> u64 {
        offset / self.block_size as u64
    }
146
    /// Calculate the byte offset within the block containing `offset`.
    ///
    /// The remainder is always `< block_size` (a `usize`), so the narrowing
    /// cast back to `usize` cannot truncate.
    #[inline]
    fn offset_within_block(&self, offset: u64) -> usize {
        (offset % self.block_size as u64) as usize
    }
152}
153
154#[async_trait]
155impl<R: RangeReader + 'static> RangeReader for BlockCache<R> {
156    async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
157        // Validate range
158        let size = self.inner.size();
159        if offset + len as u64 > size {
160            return Err(IoError::RangeOutOfBounds {
161                offset,
162                requested: len as u64,
163                size,
164            });
165        }
166
167        // Handle zero-length reads
168        if len == 0 {
169            return Ok(Bytes::new());
170        }
171
172        // Calculate which blocks we need
173        let start_block = self.block_for_offset(offset);
174        let end_block = self.block_for_offset(offset + len as u64 - 1);
175
176        if start_block == end_block {
177            // Single block read (common case)
178            let block = self.get_block(start_block).await?;
179            let block_offset = self.offset_within_block(offset);
180            Ok(block.slice(block_offset..block_offset + len))
181        } else {
182            // Multi-block read: fetch all required blocks and combine
183            let mut result = BytesMut::with_capacity(len);
184            let mut remaining = len;
185            let mut current_offset = offset;
186
187            for block_idx in start_block..=end_block {
188                let block = self.get_block(block_idx).await?;
189                let block_offset = self.offset_within_block(current_offset);
190                let bytes_in_block = std::cmp::min(block.len() - block_offset, remaining);
191
192                result.extend_from_slice(&block[block_offset..block_offset + bytes_in_block]);
193
194                remaining -= bytes_in_block;
195                current_offset += bytes_in_block as u64;
196            }
197
198            Ok(result.freeze())
199        }
200    }
201
202    fn size(&self) -> u64 {
203        self.inner.size()
204    }
205
206    fn identifier(&self) -> &str {
207        self.inner.identifier()
208    }
209}
210
211#[cfg(test)]
212mod tests {
213    use super::*;
214    use std::sync::atomic::{AtomicUsize, Ordering};
215
    /// Mock reader for testing that tracks read calls.
    struct MockReader {
        /// In-memory backing bytes served to callers.
        data: Bytes,
        /// Fixed string returned by `identifier()`.
        identifier: String,
        /// Incremented on every `read_exact_at` call; tests use it to count
        /// how many fetches reached the source (i.e. cache misses).
        read_count: AtomicUsize,
    }
222
223    impl MockReader {
224        fn new(data: Vec<u8>) -> Self {
225            Self {
226                data: Bytes::from(data),
227                identifier: "mock://test".to_string(),
228                read_count: AtomicUsize::new(0),
229            }
230        }
231
232        fn read_count(&self) -> usize {
233            self.read_count.load(Ordering::SeqCst)
234        }
235    }
236
    #[async_trait]
    impl RangeReader for MockReader {
        /// Serve `len` bytes at `offset` from the in-memory buffer.
        /// The call counter is incremented even for out-of-range requests.
        async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
            self.read_count.fetch_add(1, Ordering::SeqCst);

            // Reject reads past EOF, mirroring a real reader's contract.
            if offset + len as u64 > self.data.len() as u64 {
                return Err(IoError::RangeOutOfBounds {
                    offset,
                    requested: len as u64,
                    size: self.data.len() as u64,
                });
            }

            // `Bytes::slice` is a zero-copy refcounted view.
            Ok(self.data.slice(offset as usize..offset as usize + len))
        }

        fn size(&self) -> u64 {
            self.data.len() as u64
        }

        fn identifier(&self) -> &str {
            &self.identifier
        }
    }
261
262    #[tokio::test]
263    async fn test_single_block_read() {
264        // Create mock with 1KB of data
265        let data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
266        let mock = MockReader::new(data.clone());
267
268        // Use small 256-byte blocks for testing
269        let cache = BlockCache::with_capacity(mock, 256, 10);
270
271        // Read 100 bytes from offset 50
272        let result = cache.read_exact_at(50, 100).await.unwrap();
273        assert_eq!(result.len(), 100);
274        assert_eq!(&result[..], &data[50..150]);
275
276        // Should have made 1 read (fetched block 0)
277        assert_eq!(cache.inner.read_count(), 1);
278
279        // Read again from same block - should hit cache
280        let result2 = cache.read_exact_at(10, 50).await.unwrap();
281        assert_eq!(&result2[..], &data[10..60]);
282
283        // Still just 1 read (cache hit)
284        assert_eq!(cache.inner.read_count(), 1);
285    }
286
287    #[tokio::test]
288    async fn test_multi_block_read() {
289        // Create mock with 1KB of data
290        let data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
291        let mock = MockReader::new(data.clone());
292
293        // Use small 256-byte blocks
294        let cache = BlockCache::with_capacity(mock, 256, 10);
295
296        // Read 300 bytes starting at offset 100
297        // This spans blocks 0 (bytes 0-255) and 1 (bytes 256-511)
298        let result = cache.read_exact_at(100, 300).await.unwrap();
299        assert_eq!(result.len(), 300);
300        assert_eq!(&result[..], &data[100..400]);
301
302        // Should have made 2 reads (blocks 0 and 1)
303        assert_eq!(cache.inner.read_count(), 2);
304    }
305
    #[tokio::test]
    async fn test_cache_eviction() {
        let data: Vec<u8> = (0..2048).map(|i| (i % 256) as u8).collect();
        let mock = MockReader::new(data);

        // Small cache that can only hold 2 blocks
        let cache = BlockCache::with_capacity(mock, 256, 2);

        // Read from blocks 0, 1, 2 (will evict block 0)
        cache.read_exact_at(0, 10).await.unwrap(); // Block 0
        cache.read_exact_at(256, 10).await.unwrap(); // Block 1
        cache.read_exact_at(512, 10).await.unwrap(); // Block 2, evicts block 0

        // read_count equals the number of cache misses: one fetch per block.
        assert_eq!(cache.inner.read_count(), 3);

        // Read block 1 again - should hit cache
        cache.read_exact_at(300, 10).await.unwrap();
        assert_eq!(cache.inner.read_count(), 3);

        // Read block 0 again - cache miss (was evicted)
        cache.read_exact_at(0, 10).await.unwrap();
        assert_eq!(cache.inner.read_count(), 4);
    }
329
    #[tokio::test]
    async fn test_concurrent_reads_singleflight() {
        use std::sync::atomic::AtomicBool;
        use tokio::time::{sleep, Duration};

        /// Slow mock reader that takes 50ms per read
        struct SlowMockReader {
            data: Bytes,
            // Total number of source reads performed.
            read_count: AtomicUsize,
            // True while a read is in progress; overlapping reads indicate
            // that singleflight deduplication failed.
            is_reading: AtomicBool,
        }

        impl SlowMockReader {
            fn new(data: Vec<u8>) -> Self {
                Self {
                    data: Bytes::from(data),
                    read_count: AtomicUsize::new(0),
                    is_reading: AtomicBool::new(false),
                }
            }
        }

        #[async_trait]
        impl RangeReader for SlowMockReader {
            async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
                // Check if another read is in progress (would indicate singleflight failure)
                let was_reading = self.is_reading.swap(true, Ordering::SeqCst);
                assert!(
                    !was_reading,
                    "Concurrent reads detected - singleflight failed!"
                );

                self.read_count.fetch_add(1, Ordering::SeqCst);
                // The sleep widens the race window so overlapping fetches
                // would reliably be caught by the assertion above.
                sleep(Duration::from_millis(50)).await;

                self.is_reading.store(false, Ordering::SeqCst);

                Ok(self.data.slice(offset as usize..offset as usize + len))
            }

            fn size(&self) -> u64 {
                self.data.len() as u64
            }

            fn identifier(&self) -> &str {
                "slow://test"
            }
        }

        let data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        let mock = SlowMockReader::new(data);
        let cache = Arc::new(BlockCache::with_capacity(mock, 256, 10));

        // Spawn 10 concurrent reads for the same block
        let mut handles = Vec::new();
        for _ in 0..10 {
            let cache = cache.clone();
            handles.push(tokio::spawn(async move {
                cache.read_exact_at(50, 100).await.unwrap()
            }));
        }

        // Wait for all reads
        for handle in handles {
            handle.await.unwrap();
        }

        // Should have made only 1 read due to singleflight
        assert_eq!(cache.inner.read_count.load(Ordering::SeqCst), 1);
    }
400
401    #[tokio::test]
402    async fn test_out_of_bounds() {
403        let data: Vec<u8> = vec![1, 2, 3, 4, 5];
404        let mock = MockReader::new(data);
405        let cache = BlockCache::with_capacity(mock, 256, 10);
406
407        // Read past end of file
408        let result = cache.read_exact_at(3, 10).await;
409        assert!(matches!(result, Err(IoError::RangeOutOfBounds { .. })));
410    }
411
412    #[tokio::test]
413    async fn test_zero_length_read() {
414        let data: Vec<u8> = vec![1, 2, 3, 4, 5];
415        let mock = MockReader::new(data);
416        let cache = BlockCache::with_capacity(mock, 256, 10);
417
418        let result = cache.read_exact_at(0, 0).await.unwrap();
419        assert!(result.is_empty());
420
421        // No reads should have been made
422        assert_eq!(cache.inner.read_count(), 0);
423    }
424
425    #[tokio::test]
426    async fn test_last_partial_block() {
427        // Data that doesn't fill the last block completely
428        let data: Vec<u8> = (0..300).map(|i| (i % 256) as u8).collect();
429        let mock = MockReader::new(data.clone());
430        let cache = BlockCache::with_capacity(mock, 256, 10);
431
432        // Read from second block (which is partial: only 44 bytes)
433        let result = cache.read_exact_at(260, 30).await.unwrap();
434        assert_eq!(result.len(), 30);
435        assert_eq!(&result[..], &data[260..290]);
436    }
437}