1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use bytes::{Bytes, BytesMut};
6use lru::LruCache;
7use tokio::sync::{Mutex, Notify, RwLock};
8
9use super::RangeReader;
10use crate::error::IoError;
11
/// Default size of each cached block: 256 KiB.
pub const DEFAULT_BLOCK_SIZE: usize = 256 * 1024;

/// Default maximum number of blocks retained in the LRU cache.
const DEFAULT_CACHE_CAPACITY: usize = 100;
19
/// A read-through block cache layered over a [`RangeReader`].
///
/// Requested byte ranges are mapped onto fixed-size blocks; missing blocks
/// are fetched from `inner` and kept in an LRU cache. Concurrent misses for
/// the same block are coalesced so the underlying reader sees at most one
/// in-flight fetch per block.
pub struct BlockCache<R> {
    /// The underlying reader blocks are fetched from.
    inner: Arc<R>,
    /// Size in bytes of each cached block (the source's final block may be shorter).
    block_size: usize,
    /// Cached blocks keyed by block index.
    cache: RwLock<LruCache<u64, Bytes>>,
    /// Per-block notifiers used to coalesce concurrent fetches of the same
    /// block (singleflight): waiters park on the `Notify` until the leader
    /// finishes its fetch.
    in_flight: Mutex<HashMap<u64, Arc<Notify>>>,
}
42
43impl<R: RangeReader> BlockCache<R> {
44 pub fn new(inner: R) -> Self {
48 Self::with_capacity(inner, DEFAULT_BLOCK_SIZE, DEFAULT_CACHE_CAPACITY)
49 }
50
51 pub fn with_capacity(inner: R, block_size: usize, capacity: usize) -> Self {
58 Self {
59 inner: Arc::new(inner),
60 block_size,
61 cache: RwLock::new(LruCache::new(
62 std::num::NonZeroUsize::new(capacity).unwrap(),
63 )),
64 in_flight: Mutex::new(HashMap::new()),
65 }
66 }
67
68 async fn get_block(&self, block_idx: u64) -> Result<Bytes, IoError> {
73 loop {
74 {
76 let cache = self.cache.read().await;
77 if let Some(data) = cache.peek(&block_idx) {
78 return Ok(data.clone());
79 }
80 }
81
82 let notify = {
84 let mut in_flight = self.in_flight.lock().await;
85
86 if let Some(notify) = in_flight.get(&block_idx) {
87 let notify = notify.clone();
89 drop(in_flight);
90 notify.notified().await;
91 continue;
93 }
94
95 let notify = Arc::new(Notify::new());
97 in_flight.insert(block_idx, notify.clone());
98 notify
99 };
100
101 let result = self.fetch_block_from_source(block_idx).await;
103
104 {
106 let mut cache = self.cache.write().await;
107 let mut in_flight = self.in_flight.lock().await;
108
109 if let Ok(ref data) = result {
110 cache.put(block_idx, data.clone());
111 }
112
113 in_flight.remove(&block_idx);
114 }
115
116 notify.notify_waiters();
117
118 return result;
119 }
120 }
121
122 async fn fetch_block_from_source(&self, block_idx: u64) -> Result<Bytes, IoError> {
124 let offset = block_idx * self.block_size as u64;
125 let size = self.inner.size();
126
127 let remaining = size.saturating_sub(offset);
129 if remaining == 0 {
130 return Err(IoError::RangeOutOfBounds {
131 offset,
132 requested: self.block_size as u64,
133 size,
134 });
135 }
136
137 let len = std::cmp::min(self.block_size as u64, remaining) as usize;
138 self.inner.read_exact_at(offset, len).await
139 }
140
141 #[inline]
143 fn block_for_offset(&self, offset: u64) -> u64 {
144 offset / self.block_size as u64
145 }
146
147 #[inline]
149 fn offset_within_block(&self, offset: u64) -> usize {
150 (offset % self.block_size as u64) as usize
151 }
152}
153
154#[async_trait]
155impl<R: RangeReader + 'static> RangeReader for BlockCache<R> {
156 async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
157 let size = self.inner.size();
159 if offset + len as u64 > size {
160 return Err(IoError::RangeOutOfBounds {
161 offset,
162 requested: len as u64,
163 size,
164 });
165 }
166
167 if len == 0 {
169 return Ok(Bytes::new());
170 }
171
172 let start_block = self.block_for_offset(offset);
174 let end_block = self.block_for_offset(offset + len as u64 - 1);
175
176 if start_block == end_block {
177 let block = self.get_block(start_block).await?;
179 let block_offset = self.offset_within_block(offset);
180 Ok(block.slice(block_offset..block_offset + len))
181 } else {
182 let mut result = BytesMut::with_capacity(len);
184 let mut remaining = len;
185 let mut current_offset = offset;
186
187 for block_idx in start_block..=end_block {
188 let block = self.get_block(block_idx).await?;
189 let block_offset = self.offset_within_block(current_offset);
190 let bytes_in_block = std::cmp::min(block.len() - block_offset, remaining);
191
192 result.extend_from_slice(&block[block_offset..block_offset + bytes_in_block]);
193
194 remaining -= bytes_in_block;
195 current_offset += bytes_in_block as u64;
196 }
197
198 Ok(result.freeze())
199 }
200 }
201
202 fn size(&self) -> u64 {
203 self.inner.size()
204 }
205
206 fn identifier(&self) -> &str {
207 self.inner.identifier()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// In-memory `RangeReader` that counts calls to the source, so tests can
    /// distinguish cache hits from misses.
    struct MockReader {
        data: Bytes,
        identifier: String,
        read_count: AtomicUsize,
    }

    impl MockReader {
        fn new(data: Vec<u8>) -> Self {
            Self {
                data: Bytes::from(data),
                identifier: "mock://test".to_string(),
                read_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `read_exact_at` was invoked on the source.
        fn read_count(&self) -> usize {
            self.read_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RangeReader for MockReader {
        async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
            // Count every source read, including out-of-bounds attempts.
            self.read_count.fetch_add(1, Ordering::SeqCst);

            if offset + len as u64 > self.data.len() as u64 {
                return Err(IoError::RangeOutOfBounds {
                    offset,
                    requested: len as u64,
                    size: self.data.len() as u64,
                });
            }

            Ok(self.data.slice(offset as usize..offset as usize + len))
        }

        fn size(&self) -> u64 {
            self.data.len() as u64
        }

        fn identifier(&self) -> &str {
            &self.identifier
        }
    }

    #[tokio::test]
    async fn test_single_block_read() {
        let data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        let mock = MockReader::new(data.clone());

        let cache = BlockCache::with_capacity(mock, 256, 10);

        // 50..150 fits entirely in block 0 (256-byte blocks).
        let result = cache.read_exact_at(50, 100).await.unwrap();
        assert_eq!(result.len(), 100);
        assert_eq!(&result[..], &data[50..150]);

        // Exactly one source fetch for the whole block.
        assert_eq!(cache.inner.read_count(), 1);

        let result2 = cache.read_exact_at(10, 50).await.unwrap();
        assert_eq!(&result2[..], &data[10..60]);

        // Second read of the same block is served from cache — no new fetch.
        assert_eq!(cache.inner.read_count(), 1);
    }

    #[tokio::test]
    async fn test_multi_block_read() {
        let data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        let mock = MockReader::new(data.clone());

        let cache = BlockCache::with_capacity(mock, 256, 10);

        // 100..400 spans blocks 0 and 1; result is stitched from both.
        let result = cache.read_exact_at(100, 300).await.unwrap();
        assert_eq!(result.len(), 300);
        assert_eq!(&result[..], &data[100..400]);

        // One source fetch per touched block.
        assert_eq!(cache.inner.read_count(), 2);
    }

    #[tokio::test]
    async fn test_cache_eviction() {
        let data: Vec<u8> = (0..2048).map(|i| (i % 256) as u8).collect();
        let mock = MockReader::new(data);

        // Capacity of only 2 blocks so the third fetch forces an eviction.
        let cache = BlockCache::with_capacity(mock, 256, 2);

        // Touch blocks 0, 1, 2 — block 0 is evicted when block 2 arrives.
        cache.read_exact_at(0, 10).await.unwrap(); cache.read_exact_at(256, 10).await.unwrap(); cache.read_exact_at(512, 10).await.unwrap(); assert_eq!(cache.inner.read_count(), 3);

        // Block 1 is still cached — no new fetch.
        cache.read_exact_at(300, 10).await.unwrap();
        assert_eq!(cache.inner.read_count(), 3);

        // Block 0 was evicted, so reading it again hits the source.
        cache.read_exact_at(0, 10).await.unwrap();
        assert_eq!(cache.inner.read_count(), 4);
    }

    #[tokio::test]
    async fn test_concurrent_reads_singleflight() {
        use std::sync::atomic::AtomicBool;
        use tokio::time::{sleep, Duration};

        /// Source reader that sleeps during each read and asserts that no
        /// two reads ever overlap, proving requests were coalesced.
        struct SlowMockReader {
            data: Bytes,
            read_count: AtomicUsize,
            is_reading: AtomicBool,
        }

        impl SlowMockReader {
            fn new(data: Vec<u8>) -> Self {
                Self {
                    data: Bytes::from(data),
                    read_count: AtomicUsize::new(0),
                    is_reading: AtomicBool::new(false),
                }
            }
        }

        #[async_trait]
        impl RangeReader for SlowMockReader {
            async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
                // Flag overlap: if another read is already in progress the
                // singleflight guarantee has been violated.
                let was_reading = self.is_reading.swap(true, Ordering::SeqCst);
                assert!(
                    !was_reading,
                    "Concurrent reads detected - singleflight failed!"
                );

                self.read_count.fetch_add(1, Ordering::SeqCst);
                // Sleep keeps this read in-flight long enough for the other
                // tasks to pile up behind it.
                sleep(Duration::from_millis(50)).await;

                self.is_reading.store(false, Ordering::SeqCst);

                Ok(self.data.slice(offset as usize..offset as usize + len))
            }

            fn size(&self) -> u64 {
                self.data.len() as u64
            }

            fn identifier(&self) -> &str {
                "slow://test"
            }
        }

        let data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        let mock = SlowMockReader::new(data);
        let cache = Arc::new(BlockCache::with_capacity(mock, 256, 10));

        // Ten tasks all request the same block concurrently.
        let mut handles = Vec::new();
        for _ in 0..10 {
            let cache = cache.clone();
            handles.push(tokio::spawn(async move {
                cache.read_exact_at(50, 100).await.unwrap()
            }));
        }

        // NOTE(review): a lost wakeup in the singleflight path would make one
        // of these awaits hang rather than fail — watch for test timeouts.
        for handle in handles {
            handle.await.unwrap();
        }

        // All ten requests should have been served by a single source fetch.
        assert_eq!(cache.inner.read_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_out_of_bounds() {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mock = MockReader::new(data);
        let cache = BlockCache::with_capacity(mock, 256, 10);

        // Requesting past the 5-byte source must be rejected.
        let result = cache.read_exact_at(3, 10).await;
        assert!(matches!(result, Err(IoError::RangeOutOfBounds { .. })));
    }

    #[tokio::test]
    async fn test_zero_length_read() {
        let data: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mock = MockReader::new(data);
        let cache = BlockCache::with_capacity(mock, 256, 10);

        let result = cache.read_exact_at(0, 0).await.unwrap();
        assert!(result.is_empty());

        // A zero-length read must not touch the source at all.
        assert_eq!(cache.inner.read_count(), 0);
    }

    #[tokio::test]
    async fn test_last_partial_block() {
        // 300 bytes with 256-byte blocks: block 1 holds only 44 bytes.
        let data: Vec<u8> = (0..300).map(|i| (i % 256) as u8).collect();
        let mock = MockReader::new(data.clone());
        let cache = BlockCache::with_capacity(mock, 256, 10);

        // Read entirely within the short final block.
        let result = cache.read_exact_at(260, 30).await.unwrap();
        assert_eq!(result.len(), 30);
        assert_eq!(&result[..], &data[260..290]);
    }
}