zarr_datafusion/reader/
tracked_store.rs

//! A storage adapter that counts the bytes read through it.
//!
//! Wraps any storage backend and records how many value bytes each read returns.
//! Based on zarrs_storage's UsageLogStorageAdapter pattern.
5
6use std::sync::Arc;
7
8use zarrs::storage::{
9    byte_range::{ByteRange, ByteRangeIterator},
10    ListableStorageTraits, MaybeBytes, MaybeBytesIterator, ReadableStorageTraits, StorageError,
11    StoreKey, StoreKeys, StoreKeysPrefixes, StorePrefix,
12};
13
14use super::stats::SharedIoStats;
15
16/// Storage adapter that tracks bytes read from disk.
17///
18/// Wraps an inner storage and accumulates byte counts into shared stats.
19#[derive(Debug)]
20pub struct TrackedStore<S> {
21    inner: Arc<S>,
22    stats: SharedIoStats,
23}
24
25impl<S> TrackedStore<S> {
26    /// Create a new tracked store wrapping the given storage.
27    pub fn new(inner: Arc<S>, stats: SharedIoStats) -> Self {
28        Self { inner, stats }
29    }
30}
31
32impl<S: ReadableStorageTraits> ReadableStorageTraits for TrackedStore<S> {
33    fn get(&self, key: &StoreKey) -> Result<MaybeBytes, StorageError> {
34        let result = self.inner.get(key)?;
35
36        // Track actual bytes read from disk
37        if let Some(ref bytes) = result {
38            self.stats.record_disk_read(bytes.len() as u64);
39        }
40
41        Ok(result)
42    }
43
44    fn get_partial_many<'a>(
45        &'a self,
46        key: &StoreKey,
47        byte_ranges: ByteRangeIterator<'a>,
48    ) -> Result<MaybeBytesIterator<'a>, StorageError> {
49        // Collect ranges to allow reuse
50        let ranges: Vec<ByteRange> = byte_ranges.collect();
51
52        let result = self
53            .inner
54            .get_partial_many(key, Box::new(ranges.into_iter()))?;
55
56        // Track bytes - we need to consume the iterator to count, then recreate it
57        if let Some(iter) = result {
58            let bytes_vec: Vec<_> = iter.collect::<Result<Vec<_>, _>>()?;
59            let total_bytes: u64 = bytes_vec.iter().map(|b| b.len() as u64).sum();
60            self.stats.record_disk_read(total_bytes);
61
62            // Return a new iterator over the collected bytes
63            Ok(Some(Box::new(bytes_vec.into_iter().map(Ok))))
64        } else {
65            Ok(None)
66        }
67    }
68
69    fn size_key(&self, key: &StoreKey) -> Result<Option<u64>, StorageError> {
70        self.inner.size_key(key)
71    }
72
73    fn supports_get_partial(&self) -> bool {
74        self.inner.supports_get_partial()
75    }
76}
77
78impl<S: ListableStorageTraits> ListableStorageTraits for TrackedStore<S> {
79    fn list(&self) -> Result<StoreKeys, StorageError> {
80        self.inner.list()
81    }
82
83    fn list_prefix(&self, prefix: &StorePrefix) -> Result<StoreKeys, StorageError> {
84        self.inner.list_prefix(prefix)
85    }
86
87    fn list_dir(&self, prefix: &StorePrefix) -> Result<StoreKeysPrefixes, StorageError> {
88        self.inner.list_dir(prefix)
89    }
90
91    fn size(&self) -> Result<u64, StorageError> {
92        self.inner.size()
93    }
94
95    fn size_prefix(&self, prefix: &StorePrefix) -> Result<u64, StorageError> {
96        self.inner.size_prefix(prefix)
97    }
98}