Skip to main content

ruvector_temporal_tensor/
coherence.rs

//! Coherence gate: read-after-write validation for the temporal tensor store.
//!
//! Ensures data integrity by verifying that a `get()` immediately after `put()`
//! returns data within the expected quantization error bounds for the tier.
//!
//! # Overview
//!
//! Quantization is lossy -- the error introduced depends on the tier's bit
//! width (8-bit for Tier1, 7-bit for Tier2, 3-bit for Tier3).  The coherence
//! gate validates that the round-trip error stays within configurable
//! per-tier bounds, catching silent corruption or encoding bugs.
//!
//! # Epoch Tracking
//!
//! [`EpochTracker`] provides a lightweight write-epoch mechanism so that
//! readers can detect stale data (i.e., data that was overwritten between
//! the time it was read and the time it was consumed).
18
19use std::collections::HashMap;
20
21use crate::store::{BlockKey, StoreError, Tier, TieredStore};
22
23// ---------------------------------------------------------------------------
24// CoherenceResult
25// ---------------------------------------------------------------------------
26
27/// Outcome of a coherence check.
28#[derive(Clone, Debug, PartialEq)]
29pub struct CoherenceResult {
30    /// Maximum relative error observed across all elements.
31    pub max_error: f32,
32    /// The tier at which the block is stored.
33    pub tier: Tier,
34    /// Whether the observed error is within the configured bound for this tier.
35    pub passed: bool,
36}
37
38// ---------------------------------------------------------------------------
39// CoherenceCheck
40// ---------------------------------------------------------------------------
41
/// Read-after-write validator configured with one error bound per tier.
///
/// Right after a `put()`, the block is read back and the largest
/// per-element relative error (`|orig - decoded| / |orig|`) is compared
/// with the bound belonging to the tier the block currently lives in.
#[derive(Debug, Clone)]
pub struct CoherenceCheck {
    /// Largest tolerated relative error per tier, looked up with
    /// `Tier as usize`, i.e. in the order `[Tier0, Tier1, Tier2, Tier3]`.
    ///
    /// A Tier0 (evicted) block carries no payload, so reading it fails
    /// before any error comparison happens; its slot conventionally holds
    /// `f32::MAX` as a sentinel.
    pub max_relative_errors: [f32; 4],
}
57
58impl Default for CoherenceCheck {
59    fn default() -> Self {
60        Self {
61            // Tier0: evicted, reads always fail (sentinel value).
62            // Tier1: 8-bit, very tight bound.
63            // Tier2: 7-bit, slightly looser.
64            // Tier3: 3-bit, aggressive quantization allows up to 35% error.
65            max_relative_errors: [f32::MAX, 0.01, 0.02, 0.35],
66        }
67    }
68}
69
70impl CoherenceCheck {
71    /// Create a `CoherenceCheck` with custom per-tier error bounds.
72    pub fn new(max_relative_errors: [f32; 4]) -> Self {
73        Self {
74            max_relative_errors,
75        }
76    }
77
78    /// Validate read-after-write coherence for a block that was just written.
79    ///
80    /// Reads the block back from `store`, computes the maximum relative
81    /// error against `original_data`, and checks whether it falls within
82    /// the configured bound for the block's tier.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`StoreError::BlockNotFound`] if the key does not exist,
87    /// [`StoreError::TensorEvicted`] if the block is in Tier0, or any
88    /// other `StoreError` from the underlying read.
89    pub fn check_coherence(
90        &self,
91        store: &mut TieredStore,
92        key: BlockKey,
93        original_data: &[f32],
94        now: u64,
95    ) -> Result<CoherenceResult, StoreError> {
96        // Look up the tier before reading (needed for the error bound).
97        let tier = store.meta(key).ok_or(StoreError::BlockNotFound)?.tier;
98
99        // Read back the block.
100        let mut buf = vec![0.0f32; original_data.len()];
101        let n = store.get(key, &mut buf, now)?;
102
103        // Compute the maximum relative error.
104        let max_error = compute_max_relative_error(original_data, &buf[..n]);
105
106        let tier_idx = tier as usize;
107        let bound = if tier_idx < self.max_relative_errors.len() {
108            self.max_relative_errors[tier_idx]
109        } else {
110            f32::MAX
111        };
112
113        Ok(CoherenceResult {
114            max_error,
115            tier,
116            passed: max_error <= bound,
117        })
118    }
119
120    /// Convenience: `put` followed by `check_coherence` in one call.
121    ///
122    /// Stores the data at the given tier, then immediately reads it back
123    /// and validates the round-trip error.  Returns the coherence result
124    /// so the caller can decide whether to retry at a higher-fidelity tier.
125    ///
126    /// # Errors
127    ///
128    /// Propagates errors from both `put` and the subsequent `get`.
129    pub fn verify_put(
130        &self,
131        store: &mut TieredStore,
132        key: BlockKey,
133        data: &[f32],
134        tier: Tier,
135        now: u64,
136    ) -> Result<CoherenceResult, StoreError> {
137        store.put(key, data, tier, now)?;
138        self.check_coherence(store, key, data, now)
139    }
140}
141
142// ---------------------------------------------------------------------------
143// Helper: relative error computation
144// ---------------------------------------------------------------------------
145
/// Compute the maximum element-wise relative error between `original` and
/// `decoded`.
///
/// Only the overlapping prefix (`min(original.len(), decoded.len())`
/// elements) is compared; empty input yields `0.0`.
///
/// For elements where `|original| <= EPSILON` (near-zero), the absolute
/// error is used directly to avoid division-by-zero amplification.
///
/// Elements that compare exactly equal contribute zero error.  If an
/// element's error is undefined (NaN) -- a NaN in either slice, or an
/// infinite original decoded to a different value -- the function returns
/// `f32::INFINITY`, so the result exceeds every finite bound instead of the
/// bad element being silently skipped.
fn compute_max_relative_error(original: &[f32], decoded: &[f32]) -> f32 {
    const EPSILON: f32 = 1e-6;

    let mut max_err: f32 = 0.0;

    // `zip` stops at the shorter slice, which gives the prefix comparison.
    for (&orig, &dec) in original.iter().zip(decoded) {
        // Exact matches (including identical infinities) carry no error;
        // computing `inf - inf` for them would produce a spurious NaN.
        if orig == dec {
            continue;
        }

        let abs_err = (orig - dec).abs();
        let rel_err = if orig.abs() > EPSILON {
            abs_err / orig.abs()
        } else {
            abs_err
        };

        // NaN never compares greater than anything, so it must be handled
        // explicitly or corrupted data would look like a perfect round trip.
        if rel_err.is_nan() {
            return f32::INFINITY;
        }

        max_err = max_err.max(rel_err);
    }

    max_err
}
175
176// ---------------------------------------------------------------------------
177// EpochTracker
178// ---------------------------------------------------------------------------
179
180/// Monotonic write-epoch tracker keyed by [`BlockKey`].
181///
182/// Each call to [`record_write`](EpochTracker::record_write) increments a
183/// global counter and associates the new epoch with the given key.  Readers
184/// can later check whether their snapshot is stale via
185/// [`is_stale`](EpochTracker::is_stale).
186#[derive(Clone, Debug)]
187pub struct EpochTracker {
188    /// Global monotonically increasing write counter.
189    next_epoch: u64,
190    /// Per-key latest write epoch.
191    epochs: HashMap<BlockKey, u64>,
192}
193
194impl EpochTracker {
195    /// Create a new tracker with epoch starting at 1.
196    pub fn new() -> Self {
197        Self {
198            next_epoch: 1,
199            epochs: HashMap::new(),
200        }
201    }
202
203    /// Record a write for `key`, returning the new epoch number.
204    ///
205    /// The epoch is strictly monotonically increasing across all keys.
206    pub fn record_write(&mut self, key: BlockKey) -> u64 {
207        let epoch = self.next_epoch;
208        self.next_epoch += 1;
209        self.epochs.insert(key, epoch);
210        epoch
211    }
212
213    /// Return the latest write epoch for `key`, if any write has been recorded.
214    pub fn check_epoch(&self, key: BlockKey) -> Option<u64> {
215        self.epochs.get(&key).copied()
216    }
217
218    /// Returns `true` if the block identified by `key` has been written
219    /// after `read_epoch`, meaning the reader's snapshot is stale.
220    ///
221    /// Returns `false` if no write has been recorded for `key` (the key
222    /// does not exist in the tracker).
223    pub fn is_stale(&self, key: BlockKey, read_epoch: u64) -> bool {
224        match self.epochs.get(&key) {
225            Some(&write_epoch) => write_epoch > read_epoch,
226            None => false,
227        }
228    }
229}
230
231impl Default for EpochTracker {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
237// ---------------------------------------------------------------------------
238// Tests
239// ---------------------------------------------------------------------------
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::{BlockKey, Tier, TieredStore};

    /// Build a `BlockKey` from a tensor id and a block index.
    fn make_key(tid: u128, idx: u32) -> BlockKey {
        BlockKey {
            tensor_id: tid,
            block_index: idx,
        }
    }

    /// Construct the store used by every store-backed test below.
    fn new_store() -> TieredStore {
        TieredStore::new(4096)
    }

    // -- CoherenceCheck -----------------------------------------------------

    #[test]
    fn test_coherence_check_default_bounds() {
        let defaults = CoherenceCheck::default().max_relative_errors;
        assert_eq!(defaults[0], f32::MAX);
        assert!((defaults[1] - 0.01).abs() < 1e-9);
        assert!((defaults[2] - 0.02).abs() < 1e-9);
        assert!((defaults[3] - 0.35).abs() < 1e-9);
    }

    #[test]
    fn test_coherence_check_custom_bounds() {
        let custom = [0.0, 0.05, 0.10, 0.50];
        assert_eq!(CoherenceCheck::new(custom).max_relative_errors, custom);
    }

    #[test]
    fn test_check_coherence_tier1_passes() {
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.25).collect();
        let key = make_key(1, 0);
        let mut store = new_store();
        store.put(key, &data, Tier::Tier1, 0).unwrap();

        let checker = CoherenceCheck::default();
        let tier1_bound = checker.max_relative_errors[1];
        let outcome = checker.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(outcome.tier, Tier::Tier1);
        assert!(
            outcome.passed,
            "Tier1 coherence should pass; max_error={}, bound={}",
            outcome.max_error, tier1_bound,
        );
        assert!(
            outcome.max_error < tier1_bound,
            "max_error {} should be < bound {}",
            outcome.max_error,
            tier1_bound,
        );
    }

    #[test]
    fn test_check_coherence_tier3_passes() {
        // Large-magnitude values keep the relative error small under 3-bit
        // quantization (only 7 levels); near-zero values are avoided because
        // even a small absolute error there becomes a large relative one.
        let data: Vec<f32> = (0..32).map(|i| 10.0 + (i as f32) * 0.1).collect();
        let key = make_key(2, 0);
        let mut store = new_store();
        store.put(key, &data, Tier::Tier3, 0).unwrap();

        let outcome = CoherenceCheck::default()
            .check_coherence(&mut store, key, &data, 1)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier3);
        assert!(
            outcome.passed,
            "Tier3 coherence should pass with default 0.35 bound; max_error={}",
            outcome.max_error,
        );
    }

    #[test]
    fn test_check_coherence_missing_block() {
        let mut store = new_store();
        let data = vec![1.0f32; 8];

        // Nothing was ever stored under this key.
        let outcome = CoherenceCheck::default().check_coherence(&mut store, make_key(99, 0), &data, 0);
        assert_eq!(outcome, Err(StoreError::BlockNotFound));
    }

    #[test]
    fn test_check_coherence_evicted_block() {
        use crate::store::ReconstructPolicy;

        let key = make_key(3, 0);
        let data = vec![1.0f32; 16];
        let mut store = new_store();

        // Write the block, then evict it so the read-back has no payload.
        store.put(key, &data, Tier::Tier1, 0).unwrap();
        store.evict(key, ReconstructPolicy::None).unwrap();

        let outcome = CoherenceCheck::default().check_coherence(&mut store, key, &data, 1);
        assert_eq!(outcome, Err(StoreError::TensorEvicted));
    }

    #[test]
    fn test_check_coherence_tight_bound_fails() {
        // Wide dynamic range, stored at Tier3 (3-bit), to maximize the
        // quantization error.
        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 10.0).collect();
        let key = make_key(4, 0);
        let mut store = new_store();
        store.put(key, &data, Tier::Tier3, 0).unwrap();

        // A bound this tight cannot be met by 3-bit quantization.
        let strict = CoherenceCheck::new([f32::MAX, 0.001, 0.001, 0.001]);
        let outcome = strict.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(outcome.tier, Tier::Tier3);
        assert!(
            !outcome.passed,
            "Tier3 with 0.001 bound should fail; max_error={}",
            outcome.max_error,
        );
    }

    // -- verify_put ---------------------------------------------------------

    #[test]
    fn test_verify_put_tier1() {
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect();
        let mut store = new_store();

        let outcome = CoherenceCheck::default()
            .verify_put(&mut store, make_key(10, 0), &data, Tier::Tier1, 0)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier1);
        assert!(outcome.passed, "verify_put Tier1 should pass");
        assert_eq!(store.block_count(), 1);
    }

    #[test]
    fn test_verify_put_tier0_rejected() {
        let data = vec![1.0f32; 16];
        let mut store = new_store();

        let outcome =
            CoherenceCheck::default().verify_put(&mut store, make_key(11, 0), &data, Tier::Tier0, 0);
        assert_eq!(outcome, Err(StoreError::InvalidBlock));
    }

    #[test]
    fn test_verify_put_tier2() {
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.3).collect();
        let mut store = new_store();

        let outcome = CoherenceCheck::default()
            .verify_put(&mut store, make_key(12, 0), &data, Tier::Tier2, 0)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier2);
        assert!(
            outcome.passed,
            "verify_put Tier2 should pass; max_error={}",
            outcome.max_error
        );
    }

    // -- compute_max_relative_error -----------------------------------------

    #[test]
    fn test_relative_error_identical() {
        let values = [1.0f32, 2.0, 3.0];
        assert_eq!(compute_max_relative_error(&values, &values), 0.0);
    }

    #[test]
    fn test_relative_error_known() {
        // Per element: 0.5 / 10.0 = 0.05, then 0.0, then 2.0 / 50.0 = 0.04.
        let err = compute_max_relative_error(&[10.0, 20.0, 50.0], &[10.5, 20.0, 48.0]);
        assert!((err - 0.05).abs() < 1e-6, "expected 0.05, got {err}");
    }

    #[test]
    fn test_relative_error_near_zero() {
        // The first two originals are below epsilon, so their absolute error
        // is used (0.001 and 1e-8); the third element matches exactly.
        let err = compute_max_relative_error(&[0.0, 1e-8, 1.0], &[0.001, 0.0, 1.0]);
        assert!((err - 0.001).abs() < 1e-6, "expected ~0.001, got {err}");
    }

    #[test]
    fn test_relative_error_empty() {
        assert_eq!(compute_max_relative_error(&[], &[]), 0.0);
    }

    #[test]
    fn test_relative_error_mismatched_lengths() {
        // Only the first min(3, 2) = 2 elements are compared.
        let longer = [1.0f32, 2.0, 3.0];
        let shorter = [1.0f32, 2.0];
        assert_eq!(compute_max_relative_error(&longer, &shorter), 0.0);
    }

    // -- EpochTracker -------------------------------------------------------

    #[test]
    fn test_epoch_tracker_new() {
        let tracker = EpochTracker::new();
        let key = make_key(1, 0);
        assert_eq!(tracker.check_epoch(key), None);
        assert!(!tracker.is_stale(key, 0));
    }

    #[test]
    fn test_epoch_tracker_record_write() {
        let key = make_key(1, 0);
        let mut tracker = EpochTracker::new();

        let first = tracker.record_write(key);
        assert_eq!(first, 1);
        assert_eq!(tracker.check_epoch(key), Some(1));

        let second = tracker.record_write(key);
        assert_eq!(second, 2);
        assert_eq!(tracker.check_epoch(key), Some(2));
    }

    #[test]
    fn test_epoch_tracker_monotonic_across_keys() {
        let (key_a, key_b) = (make_key(1, 0), make_key(2, 0));
        let mut tracker = EpochTracker::new();

        let assigned = [
            tracker.record_write(key_a),
            tracker.record_write(key_b),
            tracker.record_write(key_a),
        ];
        assert_eq!(assigned, [1, 2, 3]);

        assert_eq!(tracker.check_epoch(key_a), Some(3));
        assert_eq!(tracker.check_epoch(key_b), Some(2));
    }

    #[test]
    fn test_epoch_tracker_is_stale() {
        let key = make_key(1, 0);
        let mut tracker = EpochTracker::new();

        let seen = tracker.record_write(key);
        assert!(
            !tracker.is_stale(key, seen),
            "same epoch should not be stale"
        );
        assert!(
            !tracker.is_stale(key, seen + 1),
            "future epoch should not be stale"
        );

        // A second write advances the key's epoch past `seen`.
        tracker.record_write(key);
        assert!(
            tracker.is_stale(key, seen),
            "old epoch should now be stale after a new write"
        );
    }

    #[test]
    fn test_epoch_tracker_unknown_key_not_stale() {
        let tracker = EpochTracker::new();
        let unknown = make_key(99, 0);
        for read_epoch in [0, u64::MAX] {
            assert!(!tracker.is_stale(unknown, read_epoch));
        }
    }

    #[test]
    fn test_epoch_tracker_multiple_keys_independent() {
        let (key_a, key_b) = (make_key(1, 0), make_key(2, 0));
        let mut tracker = EpochTracker::new();

        let epoch_a = tracker.record_write(key_a);
        tracker.record_write(key_b);

        // A write to key_b must not make key_a stale at its own epoch.
        assert!(!tracker.is_stale(key_a, epoch_a));
    }

    #[test]
    fn test_epoch_tracker_default_trait() {
        let tracker = EpochTracker::default();
        assert_eq!(tracker.check_epoch(make_key(1, 0)), None);
    }
}