Skip to main content

ruvector_temporal_tensor/
coherence.rs

1//! Coherence gate: read-after-write validation for the temporal tensor store.
2//!
3//! Ensures data integrity by verifying that a `get()` immediately after `put()`
4//! returns data within the expected quantization error bounds for the tier.
5//!
6//! # Overview
7//!
8//! Quantization is lossy -- the error introduced depends on the tier's bit
9//! width (8-bit for Tier1, 7-bit for Tier2, 3-bit for Tier3).  The coherence
10//! gate validates that the round-trip error stays within configurable
11//! per-tier bounds, catching silent corruption or encoding bugs.
12//!
13//! # Epoch Tracking
14//!
15//! [`EpochTracker`] provides a lightweight write-epoch mechanism so that
16//! readers can detect stale data (i.e. data that was overwritten between
17//! the time it was read and the time it was consumed).
18
19use std::collections::HashMap;
20
21use crate::store::{BlockKey, StoreError, Tier, TieredStore};
22
23// ---------------------------------------------------------------------------
24// CoherenceResult
25// ---------------------------------------------------------------------------
26
27/// Outcome of a coherence check.
28#[derive(Clone, Debug, PartialEq)]
29pub struct CoherenceResult {
30    /// Maximum relative error observed across all elements.
31    pub max_error: f32,
32    /// The tier at which the block is stored.
33    pub tier: Tier,
34    /// Whether the observed error is within the configured bound for this tier.
35    pub passed: bool,
36}
37
38// ---------------------------------------------------------------------------
39// CoherenceCheck
40// ---------------------------------------------------------------------------
41
/// Read-after-write validator configured with one maximum error bound per
/// tier.
///
/// Once a block has been `put()`, it is read straight back and the largest
/// per-element relative error, `|orig - decoded| / |orig|` (absolute error
/// for near-zero originals), is compared with the bound of the tier the
/// block currently lives in.
#[derive(Debug, Clone)]
pub struct CoherenceCheck {
    /// Largest relative error tolerated for each tier.  The array is indexed
    /// by `Tier as usize`, i.e. `[Tier0, Tier1, Tier2, Tier3]`.
    ///
    /// Tier0 (evicted) has no payload, so a read fails before any error
    /// comparison happens; the default configuration stores `f32::MAX` in
    /// that slot purely as a sentinel.
    pub max_relative_errors: [f32; 4],
}
57
58impl Default for CoherenceCheck {
59    fn default() -> Self {
60        Self {
61            // Tier0: evicted, reads always fail (sentinel value).
62            // Tier1: 8-bit, very tight bound.
63            // Tier2: 7-bit, slightly looser.
64            // Tier3: 3-bit, aggressive quantization allows up to 35% error.
65            max_relative_errors: [f32::MAX, 0.01, 0.02, 0.35],
66        }
67    }
68}
69
70impl CoherenceCheck {
71    /// Create a `CoherenceCheck` with custom per-tier error bounds.
72    pub fn new(max_relative_errors: [f32; 4]) -> Self {
73        Self { max_relative_errors }
74    }
75
76    /// Validate read-after-write coherence for a block that was just written.
77    ///
78    /// Reads the block back from `store`, computes the maximum relative
79    /// error against `original_data`, and checks whether it falls within
80    /// the configured bound for the block's tier.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`StoreError::BlockNotFound`] if the key does not exist,
85    /// [`StoreError::TensorEvicted`] if the block is in Tier0, or any
86    /// other `StoreError` from the underlying read.
87    pub fn check_coherence(
88        &self,
89        store: &mut TieredStore,
90        key: BlockKey,
91        original_data: &[f32],
92        now: u64,
93    ) -> Result<CoherenceResult, StoreError> {
94        // Look up the tier before reading (needed for the error bound).
95        let tier = store
96            .meta(key)
97            .ok_or(StoreError::BlockNotFound)?
98            .tier;
99
100        // Read back the block.
101        let mut buf = vec![0.0f32; original_data.len()];
102        let n = store.get(key, &mut buf, now)?;
103
104        // Compute the maximum relative error.
105        let max_error = compute_max_relative_error(original_data, &buf[..n]);
106
107        let tier_idx = tier as usize;
108        let bound = if tier_idx < self.max_relative_errors.len() {
109            self.max_relative_errors[tier_idx]
110        } else {
111            f32::MAX
112        };
113
114        Ok(CoherenceResult {
115            max_error,
116            tier,
117            passed: max_error <= bound,
118        })
119    }
120
121    /// Convenience: `put` followed by `check_coherence` in one call.
122    ///
123    /// Stores the data at the given tier, then immediately reads it back
124    /// and validates the round-trip error.  Returns the coherence result
125    /// so the caller can decide whether to retry at a higher-fidelity tier.
126    ///
127    /// # Errors
128    ///
129    /// Propagates errors from both `put` and the subsequent `get`.
130    pub fn verify_put(
131        &self,
132        store: &mut TieredStore,
133        key: BlockKey,
134        data: &[f32],
135        tier: Tier,
136        now: u64,
137    ) -> Result<CoherenceResult, StoreError> {
138        store.put(key, data, tier, now)?;
139        self.check_coherence(store, key, data, now)
140    }
141}
142
143// ---------------------------------------------------------------------------
144// Helper: relative error computation
145// ---------------------------------------------------------------------------
146
/// Compute the maximum element-wise relative error between `original` and
/// `decoded`.
///
/// Only the first `min(original.len(), decoded.len())` element pairs are
/// compared; an empty comparison yields `0.0`.
///
/// For elements where `|original| <= epsilon` (near-zero), the absolute
/// error is used directly to avoid division-by-zero amplification.
///
/// Elements that compare exactly equal (including matching infinities)
/// contribute zero error.  If the error of any element cannot be expressed
/// as a number -- a NaN on either side, or an infinite original that did not
/// round-trip exactly -- the function returns `f32::INFINITY`, so the caller's
/// `max_error <= bound` test fails instead of silently ignoring the element.
fn compute_max_relative_error(original: &[f32], decoded: &[f32]) -> f32 {
    // Magnitude at or below which an original value is treated as zero.
    const EPSILON: f32 = 1e-6;

    let mut max_err: f32 = 0.0;

    // `zip` stops at the shorter slice, matching the documented prefix rule.
    for (&orig, &dec) in original.iter().zip(decoded) {
        // Exact round trip; also covers `inf == inf`, where the subtraction
        // below would otherwise produce NaN.
        if orig == dec {
            continue;
        }

        let abs_err = (orig - dec).abs();
        let magnitude = orig.abs();
        let err = if magnitude > EPSILON {
            abs_err / magnitude
        } else {
            abs_err
        };

        // NaN never compares greater than anything, so without this guard a
        // corrupted element would be dropped from the maximum.
        if err.is_nan() {
            return f32::INFINITY;
        }

        max_err = max_err.max(err);
    }

    max_err
}
176
177// ---------------------------------------------------------------------------
178// EpochTracker
179// ---------------------------------------------------------------------------
180
181/// Monotonic write-epoch tracker keyed by [`BlockKey`].
182///
183/// Each call to [`record_write`](EpochTracker::record_write) increments a
184/// global counter and associates the new epoch with the given key.  Readers
185/// can later check whether their snapshot is stale via
186/// [`is_stale`](EpochTracker::is_stale).
187#[derive(Clone, Debug)]
188pub struct EpochTracker {
189    /// Global monotonically increasing write counter.
190    next_epoch: u64,
191    /// Per-key latest write epoch.
192    epochs: HashMap<BlockKey, u64>,
193}
194
195impl EpochTracker {
196    /// Create a new tracker with epoch starting at 1.
197    pub fn new() -> Self {
198        Self {
199            next_epoch: 1,
200            epochs: HashMap::new(),
201        }
202    }
203
204    /// Record a write for `key`, returning the new epoch number.
205    ///
206    /// The epoch is strictly monotonically increasing across all keys.
207    pub fn record_write(&mut self, key: BlockKey) -> u64 {
208        let epoch = self.next_epoch;
209        self.next_epoch += 1;
210        self.epochs.insert(key, epoch);
211        epoch
212    }
213
214    /// Return the latest write epoch for `key`, if any write has been recorded.
215    pub fn check_epoch(&self, key: BlockKey) -> Option<u64> {
216        self.epochs.get(&key).copied()
217    }
218
219    /// Returns `true` if the block identified by `key` has been written
220    /// after `read_epoch`, meaning the reader's snapshot is stale.
221    ///
222    /// Returns `false` if no write has been recorded for `key` (the key
223    /// does not exist in the tracker).
224    pub fn is_stale(&self, key: BlockKey, read_epoch: u64) -> bool {
225        match self.epochs.get(&key) {
226            Some(&write_epoch) => write_epoch > read_epoch,
227            None => false,
228        }
229    }
230}
231
232impl Default for EpochTracker {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
238// ---------------------------------------------------------------------------
239// Tests
240// ---------------------------------------------------------------------------
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::{BlockKey, Tier, TieredStore};

    /// Build a `BlockKey` from a tensor id and a block index.
    fn make_key(tid: u128, idx: u32) -> BlockKey {
        BlockKey {
            tensor_id: tid,
            block_index: idx,
        }
    }

    /// Fresh store built with the constructor argument shared by all tests.
    fn new_store() -> TieredStore {
        TieredStore::new(4096)
    }

    /// Produce `count` samples `f(0.0), f(1.0), ...`.
    fn samples(count: u32, f: impl Fn(f32) -> f32) -> Vec<f32> {
        (0..count).map(|i| f(i as f32)).collect()
    }

    // -- CoherenceCheck -----------------------------------------------------

    #[test]
    fn test_coherence_check_default_bounds() {
        let [tier0, tier1, tier2, tier3] = CoherenceCheck::default().max_relative_errors;
        assert_eq!(tier0, f32::MAX);
        assert!((tier1 - 0.01).abs() < 1e-9);
        assert!((tier2 - 0.02).abs() < 1e-9);
        assert!((tier3 - 0.35).abs() < 1e-9);
    }

    #[test]
    fn test_coherence_check_custom_bounds() {
        let wanted: [f32; 4] = [0.0, 0.05, 0.10, 0.50];
        let checker = CoherenceCheck::new(wanted);
        assert_eq!(checker.max_relative_errors, wanted);
    }

    #[test]
    fn test_check_coherence_tier1_passes() {
        let mut store = new_store();
        let key = make_key(1, 0);
        let data = samples(64, |x| (x + 1.0) * 0.25);
        store.put(key, &data, Tier::Tier1, 0).unwrap();

        let checker = CoherenceCheck::default();
        let tier1_bound = checker.max_relative_errors[1];
        let outcome = checker.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(outcome.tier, Tier::Tier1);
        assert!(
            outcome.passed,
            "Tier1 coherence should pass; max_error={}, bound={}",
            outcome.max_error,
            tier1_bound,
        );
        assert!(
            outcome.max_error < tier1_bound,
            "max_error {} should be < bound {}",
            outcome.max_error,
            tier1_bound,
        );
    }

    #[test]
    fn test_check_coherence_tier3_passes() {
        let mut store = new_store();
        let key = make_key(2, 0);
        // Large-magnitude values keep the relative error small under 3-bit
        // quantization (only 7 levels); near-zero values are avoided because
        // even a small absolute error becomes a large relative one there.
        let data = samples(32, |x| 10.0 + x * 0.1);
        store.put(key, &data, Tier::Tier3, 0).unwrap();

        let outcome = CoherenceCheck::default()
            .check_coherence(&mut store, key, &data, 1)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier3);
        assert!(
            outcome.passed,
            "Tier3 coherence should pass with default 0.35 bound; max_error={}",
            outcome.max_error,
        );
    }

    #[test]
    fn test_check_coherence_missing_block() {
        let mut store = new_store();
        let absent = make_key(99, 0);
        let data = vec![1.0f32; 8];

        let outcome = CoherenceCheck::default().check_coherence(&mut store, absent, &data, 0);
        assert_eq!(outcome, Err(StoreError::BlockNotFound));
    }

    #[test]
    fn test_check_coherence_evicted_block() {
        use crate::store::ReconstructPolicy;

        let mut store = new_store();
        let key = make_key(3, 0);
        let data = vec![1.0f32; 16];

        // Write, then evict, so the block exists but has no payload to read.
        store.put(key, &data, Tier::Tier1, 0).unwrap();
        store.evict(key, ReconstructPolicy::None).unwrap();

        let outcome = CoherenceCheck::default().check_coherence(&mut store, key, &data, 1);
        assert_eq!(outcome, Err(StoreError::TensorEvicted));
    }

    #[test]
    fn test_check_coherence_tight_bound_fails() {
        let mut store = new_store();
        let key = make_key(4, 0);
        // Wide dynamic range, stored at Tier3 (3-bit), to maximize the
        // quantization error.
        let data = samples(64, |x| (x - 32.0) * 10.0);
        store.put(key, &data, Tier::Tier3, 0).unwrap();

        // A bound this tight cannot be met by 3-bit quantization.
        let strict = CoherenceCheck::new([f32::MAX, 0.001, 0.001, 0.001]);
        let outcome = strict.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(outcome.tier, Tier::Tier3);
        assert!(
            !outcome.passed,
            "Tier3 with 0.001 bound should fail; max_error={}",
            outcome.max_error,
        );
    }

    // -- verify_put ---------------------------------------------------------

    #[test]
    fn test_verify_put_tier1() {
        let mut store = new_store();
        let key = make_key(10, 0);
        let data = samples(64, |x| (x + 1.0) * 0.1);

        let outcome = CoherenceCheck::default()
            .verify_put(&mut store, key, &data, Tier::Tier1, 0)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier1);
        assert!(outcome.passed, "verify_put Tier1 should pass");
        assert_eq!(store.block_count(), 1);
    }

    #[test]
    fn test_verify_put_tier0_rejected() {
        let mut store = new_store();
        let key = make_key(11, 0);
        let data = vec![1.0f32; 16];

        let outcome = CoherenceCheck::default().verify_put(&mut store, key, &data, Tier::Tier0, 0);
        assert_eq!(outcome, Err(StoreError::InvalidBlock));
    }

    #[test]
    fn test_verify_put_tier2() {
        let mut store = new_store();
        let key = make_key(12, 0);
        let data = samples(64, |x| (x + 1.0) * 0.3);

        let outcome = CoherenceCheck::default()
            .verify_put(&mut store, key, &data, Tier::Tier2, 0)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier2);
        assert!(outcome.passed, "verify_put Tier2 should pass; max_error={}", outcome.max_error);
    }

    // -- compute_max_relative_error -----------------------------------------

    #[test]
    fn test_relative_error_identical() {
        let values = [1.0, 2.0, 3.0];
        assert_eq!(compute_max_relative_error(&values, &values), 0.0);
    }

    #[test]
    fn test_relative_error_known() {
        // Element 0: |0.5| / 10.0 = 0.05
        // Element 1: 0.0
        // Element 2: |2.0| / 50.0 = 0.04
        let err = compute_max_relative_error(&[10.0, 20.0, 50.0], &[10.5, 20.0, 48.0]);
        assert!((err - 0.05).abs() < 1e-6, "expected 0.05, got {err}");
    }

    #[test]
    fn test_relative_error_near_zero() {
        // Near-zero originals fall back to the absolute error:
        // Element 0: |0.001| (absolute, since orig < epsilon)
        // Element 1: |1e-8| (absolute, since orig < epsilon)
        // Element 2: 0.0
        let err = compute_max_relative_error(&[0.0, 1e-8, 1.0], &[0.001, 0.0, 1.0]);
        assert!((err - 0.001).abs() < 1e-6, "expected ~0.001, got {err}");
    }

    #[test]
    fn test_relative_error_empty() {
        let nothing: [f32; 0] = [];
        assert_eq!(compute_max_relative_error(&nothing, &nothing), 0.0);
    }

    #[test]
    fn test_relative_error_mismatched_lengths() {
        // Only the first min(len(a), len(b)) = 2 elements are compared.
        let longer = [1.0, 2.0, 3.0];
        let shorter = [1.0, 2.0];
        assert_eq!(compute_max_relative_error(&longer, &shorter), 0.0);
    }

    // -- EpochTracker -------------------------------------------------------

    #[test]
    fn test_epoch_tracker_new() {
        let tracker = EpochTracker::new();
        let key = make_key(1, 0);
        assert_eq!(tracker.check_epoch(key), None);
        assert!(!tracker.is_stale(key, 0));
    }

    #[test]
    fn test_epoch_tracker_record_write() {
        let mut tracker = EpochTracker::new();
        let key = make_key(1, 0);

        let first = tracker.record_write(key);
        assert_eq!(first, 1);
        assert_eq!(tracker.check_epoch(key), Some(1));

        let second = tracker.record_write(key);
        assert_eq!(second, 2);
        assert_eq!(tracker.check_epoch(key), Some(2));
    }

    #[test]
    fn test_epoch_tracker_monotonic_across_keys() {
        let mut tracker = EpochTracker::new();
        let key_a = make_key(1, 0);
        let key_b = make_key(2, 0);

        let issued = [
            tracker.record_write(key_a),
            tracker.record_write(key_b),
            tracker.record_write(key_a),
        ];
        assert_eq!(issued[0], 1);
        assert_eq!(issued[1], 2);
        assert_eq!(issued[2], 3);

        assert_eq!(tracker.check_epoch(key_a), Some(3));
        assert_eq!(tracker.check_epoch(key_b), Some(2));
    }

    #[test]
    fn test_epoch_tracker_is_stale() {
        let mut tracker = EpochTracker::new();
        let key = make_key(1, 0);

        let seen = tracker.record_write(key);
        assert!(!tracker.is_stale(key, seen), "same epoch should not be stale");
        assert!(!tracker.is_stale(key, seen + 1), "future epoch should not be stale");

        // A second write advances the key's epoch past `seen`.
        tracker.record_write(key);
        assert!(
            tracker.is_stale(key, seen),
            "old epoch should now be stale after a new write"
        );
    }

    #[test]
    fn test_epoch_tracker_unknown_key_not_stale() {
        let tracker = EpochTracker::new();
        let unknown = make_key(99, 0);
        for read_epoch in [0, u64::MAX] {
            assert!(!tracker.is_stale(unknown, read_epoch));
        }
    }

    #[test]
    fn test_epoch_tracker_multiple_keys_independent() {
        let mut tracker = EpochTracker::new();
        let key_a = make_key(1, 0);
        let key_b = make_key(2, 0);

        let epoch_a = tracker.record_write(key_a);
        tracker.record_write(key_b);

        // A write to key_b must not make key_a stale at its own epoch.
        assert!(!tracker.is_stale(key_a, epoch_a));
    }

    #[test]
    fn test_epoch_tracker_default_trait() {
        let tracker = EpochTracker::default();
        assert_eq!(tracker.check_epoch(make_key(1, 0)), None);
    }
}