1use std::collections::HashMap;
20
21use crate::store::{BlockKey, StoreError, Tier, TieredStore};
22
23#[derive(Clone, Debug, PartialEq)]
29pub struct CoherenceResult {
30 pub max_error: f32,
32 pub tier: Tier,
34 pub passed: bool,
36}
37
/// Per-tier tolerance configuration for read-back coherence checks.
///
/// Each entry is the largest acceptable error for blocks whose tier casts to
/// that index (`tier as usize`); a block passes when its measured error is
/// `<=` the entry.
#[derive(Clone, Debug)]
pub struct CoherenceCheck {
    /// Maximum allowed relative error, indexed by `tier as usize`.
    pub max_relative_errors: [f32; 4],
}
57
58impl Default for CoherenceCheck {
59 fn default() -> Self {
60 Self {
61 max_relative_errors: [f32::MAX, 0.01, 0.02, 0.35],
66 }
67 }
68}
69
70impl CoherenceCheck {
71 pub fn new(max_relative_errors: [f32; 4]) -> Self {
73 Self { max_relative_errors }
74 }
75
76 pub fn check_coherence(
88 &self,
89 store: &mut TieredStore,
90 key: BlockKey,
91 original_data: &[f32],
92 now: u64,
93 ) -> Result<CoherenceResult, StoreError> {
94 let tier = store
96 .meta(key)
97 .ok_or(StoreError::BlockNotFound)?
98 .tier;
99
100 let mut buf = vec![0.0f32; original_data.len()];
102 let n = store.get(key, &mut buf, now)?;
103
104 let max_error = compute_max_relative_error(original_data, &buf[..n]);
106
107 let tier_idx = tier as usize;
108 let bound = if tier_idx < self.max_relative_errors.len() {
109 self.max_relative_errors[tier_idx]
110 } else {
111 f32::MAX
112 };
113
114 Ok(CoherenceResult {
115 max_error,
116 tier,
117 passed: max_error <= bound,
118 })
119 }
120
121 pub fn verify_put(
131 &self,
132 store: &mut TieredStore,
133 key: BlockKey,
134 data: &[f32],
135 tier: Tier,
136 now: u64,
137 ) -> Result<CoherenceResult, StoreError> {
138 store.put(key, data, tier, now)?;
139 self.check_coherence(store, key, data, now)
140 }
141}
142
/// Returns the largest per-element error between `original` and `decoded`.
///
/// How each element's error is measured:
/// - When `|orig| > 1e-6`, the error is the relative error `|orig - dec| / |orig|`.
/// - Otherwise it falls back to the absolute error `|orig - dec|`, so
///   near-zero originals do not blow the ratio up.
///
/// Which elements are compared:
/// - Only the first `min(original.len(), decoded.len())` elements are compared.
/// - Empty input yields `0.0`.
///
/// Special float values:
/// - Elements that are identical contribute no error. This includes matching
///   infinities and a NaN reproduced as NaN.
/// - Any element whose error evaluates to NaN is reported as `f32::INFINITY`.
///   Examples are a finite value decoded as NaN, or an infinity decoded as a
///   finite value. Such elements can therefore never satisfy a finite bound,
///   instead of being silently ignored (NaN never compares greater than
///   anything).
fn compute_max_relative_error(original: &[f32], decoded: &[f32]) -> f32 {
    const EPSILON: f32 = 1e-6;

    original
        .iter()
        .zip(decoded)
        .map(|(&orig, &dec)| {
            if orig == dec || (orig.is_nan() && dec.is_nan()) {
                return 0.0;
            }
            let abs_err = (orig - dec).abs();
            let err = if orig.abs() > EPSILON {
                abs_err / orig.abs()
            } else {
                abs_err
            };
            if err.is_nan() {
                f32::INFINITY
            } else {
                err
            }
        })
        // No NaN reaches the fold, so `f32::max` is a plain maximum here.
        .fold(0.0f32, f32::max)
}
176
177#[derive(Clone, Debug)]
188pub struct EpochTracker {
189 next_epoch: u64,
191 epochs: HashMap<BlockKey, u64>,
193}
194
195impl EpochTracker {
196 pub fn new() -> Self {
198 Self {
199 next_epoch: 1,
200 epochs: HashMap::new(),
201 }
202 }
203
204 pub fn record_write(&mut self, key: BlockKey) -> u64 {
208 let epoch = self.next_epoch;
209 self.next_epoch += 1;
210 self.epochs.insert(key, epoch);
211 epoch
212 }
213
214 pub fn check_epoch(&self, key: BlockKey) -> Option<u64> {
216 self.epochs.get(&key).copied()
217 }
218
219 pub fn is_stale(&self, key: BlockKey, read_epoch: u64) -> bool {
225 match self.epochs.get(&key) {
226 Some(&write_epoch) => write_epoch > read_epoch,
227 None => false,
228 }
229 }
230}
231
232impl Default for EpochTracker {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::{BlockKey, Tier, TieredStore};

    /// Builds a `BlockKey` from a tensor id and block index.
    fn make_key(tid: u128, idx: u32) -> BlockKey {
        BlockKey {
            tensor_id: tid,
            block_index: idx,
        }
    }

    // ---------------------------------------------------------------------
    // CoherenceCheck configuration
    // ---------------------------------------------------------------------

    // The default table: index 0 unbounded, then 1%, 2%, 35%.
    #[test]
    fn test_coherence_check_default_bounds() {
        let cc = CoherenceCheck::default();
        assert_eq!(cc.max_relative_errors[0], f32::MAX);
        assert!((cc.max_relative_errors[1] - 0.01).abs() < 1e-9);
        assert!((cc.max_relative_errors[2] - 0.02).abs() < 1e-9);
        assert!((cc.max_relative_errors[3] - 0.35).abs() < 1e-9);
    }

    // `new` stores the caller's bounds verbatim.
    #[test]
    fn test_coherence_check_custom_bounds() {
        let bounds = [0.0, 0.05, 0.10, 0.50];
        let cc = CoherenceCheck::new(bounds);
        assert_eq!(cc.max_relative_errors, bounds);
    }

    // ---------------------------------------------------------------------
    // check_coherence against a real TieredStore
    // ---------------------------------------------------------------------

    // Tier1 round-trip of small positive values must stay strictly under the 1% default bound.
    #[test]
    fn test_check_coherence_tier1_passes() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.25).collect();

        store.put(key, &data, Tier::Tier1, 0).unwrap();

        let cc = CoherenceCheck::default();
        let result = cc.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(result.tier, Tier::Tier1);
        assert!(
            result.passed,
            "Tier1 coherence should pass; max_error={}, bound={}",
            result.max_error,
            cc.max_relative_errors[1],
        );
        assert!(
            result.max_error < cc.max_relative_errors[1],
            "max_error {} should be < bound {}",
            result.max_error,
            cc.max_relative_errors[1],
        );
    }

    // Tier3 round-trip of values clustered around 10.0 must fit the 35% default bound.
    #[test]
    fn test_check_coherence_tier3_passes() {
        let mut store = TieredStore::new(4096);
        let key = make_key(2, 0);
        let data: Vec<f32> = (0..32).map(|i| 10.0 + (i as f32) * 0.1).collect();

        store.put(key, &data, Tier::Tier3, 0).unwrap();

        let cc = CoherenceCheck::default();
        let result = cc.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(result.tier, Tier::Tier3);
        assert!(
            result.passed,
            "Tier3 coherence should pass with default 0.35 bound; max_error={}",
            result.max_error,
        );
    }

    // A key that was never written yields BlockNotFound from the metadata lookup.
    #[test]
    fn test_check_coherence_missing_block() {
        let mut store = TieredStore::new(4096);
        let key = make_key(99, 0);
        let data = vec![1.0f32; 8];
        let cc = CoherenceCheck::default();

        let err = cc.check_coherence(&mut store, key, &data, 0);
        assert_eq!(err, Err(StoreError::BlockNotFound));
    }

    // After eviction with ReconstructPolicy::None, the read-back reports TensorEvicted.
    #[test]
    fn test_check_coherence_evicted_block() {
        use crate::store::ReconstructPolicy;

        let mut store = TieredStore::new(4096);
        let key = make_key(3, 0);
        let data = vec![1.0f32; 16];

        store.put(key, &data, Tier::Tier1, 0).unwrap();
        store.evict(key, ReconstructPolicy::None).unwrap();

        let cc = CoherenceCheck::default();
        let err = cc.check_coherence(&mut store, key, &data, 1);
        assert_eq!(err, Err(StoreError::TensorEvicted));
    }

    // Wide-range data in Tier3 cannot meet a 0.1% bound, so the check must report failure.
    #[test]
    fn test_check_coherence_tight_bound_fails() {
        let mut store = TieredStore::new(4096);
        let key = make_key(4, 0);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 10.0).collect();

        store.put(key, &data, Tier::Tier3, 0).unwrap();

        let cc = CoherenceCheck::new([f32::MAX, 0.001, 0.001, 0.001]);
        let result = cc.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(result.tier, Tier::Tier3);
        assert!(
            !result.passed,
            "Tier3 with 0.001 bound should fail; max_error={}",
            result.max_error,
        );
    }

    // ---------------------------------------------------------------------
    // verify_put (put followed by check_coherence)
    // ---------------------------------------------------------------------

    // Tier1 verify_put passes and leaves exactly one block in the store.
    #[test]
    fn test_verify_put_tier1() {
        let mut store = TieredStore::new(4096);
        let key = make_key(10, 0);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let cc = CoherenceCheck::default();
        let result = cc.verify_put(&mut store, key, &data, Tier::Tier1, 0).unwrap();

        assert_eq!(result.tier, Tier::Tier1);
        assert!(result.passed, "verify_put Tier1 should pass");
        assert_eq!(store.block_count(), 1);
    }

    // The store rejects Tier0 writes; verify_put must propagate InvalidBlock.
    #[test]
    fn test_verify_put_tier0_rejected() {
        let mut store = TieredStore::new(4096);
        let key = make_key(11, 0);
        let data = vec![1.0f32; 16];

        let cc = CoherenceCheck::default();
        let err = cc.verify_put(&mut store, key, &data, Tier::Tier0, 0);
        assert_eq!(err, Err(StoreError::InvalidBlock));
    }

    // Tier2 verify_put must fit the 2% default bound.
    #[test]
    fn test_verify_put_tier2() {
        let mut store = TieredStore::new(4096);
        let key = make_key(12, 0);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.3).collect();

        let cc = CoherenceCheck::default();
        let result = cc.verify_put(&mut store, key, &data, Tier::Tier2, 0).unwrap();

        assert_eq!(result.tier, Tier::Tier2);
        assert!(result.passed, "verify_put Tier2 should pass; max_error={}", result.max_error);
    }

    // ---------------------------------------------------------------------
    // compute_max_relative_error
    // ---------------------------------------------------------------------

    // Identical slices have zero error.
    #[test]
    fn test_relative_error_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(compute_max_relative_error(&a, &b), 0.0);
    }

    // Errors are 0.5/10 = 0.05, 0, and 2/50 = 0.04; the maximum is 0.05.
    #[test]
    fn test_relative_error_known() {
        let original = vec![10.0, 20.0, 50.0];
        let decoded = vec![10.5, 20.0, 48.0];
        let err = compute_max_relative_error(&original, &decoded);
        assert!((err - 0.05).abs() < 1e-6, "expected 0.05, got {err}");
    }

    // Originals with |x| <= 1e-6 fall back to absolute error, so 0.0 -> 0.001 contributes 0.001.
    #[test]
    fn test_relative_error_near_zero() {
        let original = vec![0.0, 1e-8, 1.0];
        let decoded = vec![0.001, 0.0, 1.0];
        let err = compute_max_relative_error(&original, &decoded);
        assert!((err - 0.001).abs() < 1e-6, "expected ~0.001, got {err}");
    }

    // Empty input yields zero error.
    #[test]
    fn test_relative_error_empty() {
        assert_eq!(compute_max_relative_error(&[], &[]), 0.0);
    }

    // Only the overlapping prefix is compared; the trailing 3.0 in `a` is ignored.
    #[test]
    fn test_relative_error_mismatched_lengths() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let err = compute_max_relative_error(&a, &b);
        assert_eq!(err, 0.0);
    }

    // ---------------------------------------------------------------------
    // EpochTracker
    // ---------------------------------------------------------------------

    // A fresh tracker knows no keys and reports nothing as stale.
    #[test]
    fn test_epoch_tracker_new() {
        let tracker = EpochTracker::new();
        let key = make_key(1, 0);
        assert_eq!(tracker.check_epoch(key), None);
        assert!(!tracker.is_stale(key, 0));
    }

    // Epochs start at 1 and each write to the same key advances its recorded epoch.
    #[test]
    fn test_epoch_tracker_record_write() {
        let mut tracker = EpochTracker::new();
        let key = make_key(1, 0);

        let e1 = tracker.record_write(key);
        assert_eq!(e1, 1);
        assert_eq!(tracker.check_epoch(key), Some(1));

        let e2 = tracker.record_write(key);
        assert_eq!(e2, 2);
        assert_eq!(tracker.check_epoch(key), Some(2));
    }

    // The epoch counter is shared across keys, so interleaved writes get consecutive epochs.
    #[test]
    fn test_epoch_tracker_monotonic_across_keys() {
        let mut tracker = EpochTracker::new();
        let key_a = make_key(1, 0);
        let key_b = make_key(2, 0);

        let e1 = tracker.record_write(key_a);
        let e2 = tracker.record_write(key_b);
        let e3 = tracker.record_write(key_a);

        assert_eq!(e1, 1);
        assert_eq!(e2, 2);
        assert_eq!(e3, 3);

        assert_eq!(tracker.check_epoch(key_a), Some(3));
        assert_eq!(tracker.check_epoch(key_b), Some(2));
    }

    // A read epoch becomes stale only once a strictly newer write is recorded.
    #[test]
    fn test_epoch_tracker_is_stale() {
        let mut tracker = EpochTracker::new();
        let key = make_key(1, 0);

        let epoch = tracker.record_write(key);
        assert!(!tracker.is_stale(key, epoch), "same epoch should not be stale");
        assert!(!tracker.is_stale(key, epoch + 1), "future epoch should not be stale");

        let _e2 = tracker.record_write(key);
        assert!(
            tracker.is_stale(key, epoch),
            "old epoch should now be stale after a new write"
        );
    }

    // Keys never written are not stale, whatever the read epoch.
    #[test]
    fn test_epoch_tracker_unknown_key_not_stale() {
        let tracker = EpochTracker::new();
        let key = make_key(99, 0);
        assert!(!tracker.is_stale(key, 0));
        assert!(!tracker.is_stale(key, u64::MAX));
    }

    // A write to key_b does not make key_a's own epoch stale.
    #[test]
    fn test_epoch_tracker_multiple_keys_independent() {
        let mut tracker = EpochTracker::new();
        let key_a = make_key(1, 0);
        let key_b = make_key(2, 0);

        let ea = tracker.record_write(key_a);
        let _eb = tracker.record_write(key_b);

        assert!(!tracker.is_stale(key_a, ea));
    }

    // `Default` produces an empty tracker, like `new`.
    #[test]
    fn test_epoch_tracker_default_trait() {
        let tracker = EpochTracker::default();
        assert_eq!(tracker.check_epoch(make_key(1, 0)), None);
    }
}