1use std::collections::HashMap;
20
21use crate::store::{BlockKey, StoreError, Tier, TieredStore};
22
23#[derive(Clone, Debug, PartialEq)]
29pub struct CoherenceResult {
30 pub max_error: f32,
32 pub tier: Tier,
34 pub passed: bool,
36}
37
/// Per-tier relative-error bounds used to validate blocks read back from a store.
#[derive(Debug, Clone)]
pub struct CoherenceCheck {
    /// Maximum allowed relative error, indexed by `tier as usize`.
    ///
    /// A block passes when its maximum relative error is `<=` the entry for its tier.
    pub max_relative_errors: [f32; 4],
}
57
58impl Default for CoherenceCheck {
59 fn default() -> Self {
60 Self {
61 max_relative_errors: [f32::MAX, 0.01, 0.02, 0.35],
66 }
67 }
68}
69
70impl CoherenceCheck {
71 pub fn new(max_relative_errors: [f32; 4]) -> Self {
73 Self {
74 max_relative_errors,
75 }
76 }
77
78 pub fn check_coherence(
90 &self,
91 store: &mut TieredStore,
92 key: BlockKey,
93 original_data: &[f32],
94 now: u64,
95 ) -> Result<CoherenceResult, StoreError> {
96 let tier = store.meta(key).ok_or(StoreError::BlockNotFound)?.tier;
98
99 let mut buf = vec![0.0f32; original_data.len()];
101 let n = store.get(key, &mut buf, now)?;
102
103 let max_error = compute_max_relative_error(original_data, &buf[..n]);
105
106 let tier_idx = tier as usize;
107 let bound = if tier_idx < self.max_relative_errors.len() {
108 self.max_relative_errors[tier_idx]
109 } else {
110 f32::MAX
111 };
112
113 Ok(CoherenceResult {
114 max_error,
115 tier,
116 passed: max_error <= bound,
117 })
118 }
119
120 pub fn verify_put(
130 &self,
131 store: &mut TieredStore,
132 key: BlockKey,
133 data: &[f32],
134 tier: Tier,
135 now: u64,
136 ) -> Result<CoherenceResult, StoreError> {
137 store.put(key, data, tier, now)?;
138 self.check_coherence(store, key, data, now)
139 }
140}
141
/// Returns the largest element-wise relative error between `original` and `decoded`.
///
/// For each index present in *both* slices, the error is
/// `|original[i] - decoded[i]| / |original[i]|`. When `|original[i]|` is at or below
/// `EPSILON` (`1e-6`), the absolute error `|original[i] - decoded[i]|` is used instead,
/// to avoid dividing by (near-)zero.
///
/// Only the overlapping prefix (`min(original.len(), decoded.len())` elements) is
/// compared; trailing elements of the longer slice are ignored. Empty input yields `0.0`.
///
/// Exactly equal values (including equal infinities) contribute zero error. Any element
/// whose error would be NaN (e.g. a NaN in either slice) is reported as `f32::INFINITY`,
/// so it can never satisfy a bound, not even `f32::MAX`.
fn compute_max_relative_error(original: &[f32], decoded: &[f32]) -> f32 {
    // Reference magnitudes at or below this are treated as zero.
    const EPSILON: f32 = 1e-6;

    original
        .iter()
        .zip(decoded)
        .map(|(&orig, &dec)| {
            if orig == dec {
                // Exact match; also covers equal infinities, where `orig - dec` is NaN.
                return 0.0;
            }
            let abs_err = (orig - dec).abs();
            let rel_err = if orig.abs() > EPSILON {
                abs_err / orig.abs()
            } else {
                abs_err
            };
            // Every comparison with NaN is false. Without this mapping, a NaN would be
            // skipped and a corrupted decode could report zero error.
            if rel_err.is_nan() {
                f32::INFINITY
            } else {
                rel_err
            }
        })
        .fold(0.0, f32::max)
}
175
176#[derive(Clone, Debug)]
187pub struct EpochTracker {
188 next_epoch: u64,
190 epochs: HashMap<BlockKey, u64>,
192}
193
194impl EpochTracker {
195 pub fn new() -> Self {
197 Self {
198 next_epoch: 1,
199 epochs: HashMap::new(),
200 }
201 }
202
203 pub fn record_write(&mut self, key: BlockKey) -> u64 {
207 let epoch = self.next_epoch;
208 self.next_epoch += 1;
209 self.epochs.insert(key, epoch);
210 epoch
211 }
212
213 pub fn check_epoch(&self, key: BlockKey) -> Option<u64> {
215 self.epochs.get(&key).copied()
216 }
217
218 pub fn is_stale(&self, key: BlockKey, read_epoch: u64) -> bool {
224 match self.epochs.get(&key) {
225 Some(&write_epoch) => write_epoch > read_epoch,
226 None => false,
227 }
228 }
229}
230
231impl Default for EpochTracker {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::{BlockKey, Tier, TieredStore};

    /// Shorthand for constructing a `BlockKey`.
    fn make_key(tid: u128, idx: u32) -> BlockKey {
        BlockKey {
            tensor_id: tid,
            block_index: idx,
        }
    }

    // ----- CoherenceCheck construction -----------------------------------

    #[test]
    fn test_coherence_check_default_bounds() {
        let bounds = CoherenceCheck::default().max_relative_errors;
        assert_eq!(bounds[0], f32::MAX);
        for (tier_idx, expected) in [(1usize, 0.01f32), (2, 0.02), (3, 0.35)] {
            assert!((bounds[tier_idx] - expected).abs() < 1e-9);
        }
    }

    #[test]
    fn test_coherence_check_custom_bounds() {
        let bounds = [0.0, 0.05, 0.10, 0.50];
        let checker = CoherenceCheck::new(bounds);
        assert_eq!(checker.max_relative_errors, bounds);
    }

    // ----- check_coherence ------------------------------------------------

    #[test]
    fn test_check_coherence_tier1_passes() {
        let key = make_key(1, 0);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.25).collect();
        let mut store = TieredStore::new(4096);
        store.put(key, &data, Tier::Tier1, 0).unwrap();

        let checker = CoherenceCheck::default();
        let bound = checker.max_relative_errors[1];
        let outcome = checker.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(outcome.tier, Tier::Tier1);
        assert!(
            outcome.passed,
            "Tier1 coherence should pass; max_error={}, bound={}",
            outcome.max_error, bound,
        );
        assert!(
            outcome.max_error < bound,
            "max_error {} should be < bound {}",
            outcome.max_error,
            bound,
        );
    }

    #[test]
    fn test_check_coherence_tier3_passes() {
        let key = make_key(2, 0);
        let data: Vec<f32> = (0..32).map(|i| 10.0 + (i as f32) * 0.1).collect();
        let mut store = TieredStore::new(4096);
        store.put(key, &data, Tier::Tier3, 0).unwrap();

        let outcome = CoherenceCheck::default()
            .check_coherence(&mut store, key, &data, 1)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier3);
        assert!(
            outcome.passed,
            "Tier3 coherence should pass with default 0.35 bound; max_error={}",
            outcome.max_error,
        );
    }

    #[test]
    fn test_check_coherence_missing_block() {
        let mut store = TieredStore::new(4096);
        let data = vec![1.0f32; 8];

        let outcome =
            CoherenceCheck::default().check_coherence(&mut store, make_key(99, 0), &data, 0);

        assert_eq!(outcome, Err(StoreError::BlockNotFound));
    }

    #[test]
    fn test_check_coherence_evicted_block() {
        use crate::store::ReconstructPolicy;

        let key = make_key(3, 0);
        let data = vec![1.0f32; 16];
        let mut store = TieredStore::new(4096);
        store.put(key, &data, Tier::Tier1, 0).unwrap();
        store.evict(key, ReconstructPolicy::None).unwrap();

        let outcome = CoherenceCheck::default().check_coherence(&mut store, key, &data, 1);

        assert_eq!(outcome, Err(StoreError::TensorEvicted));
    }

    #[test]
    fn test_check_coherence_tight_bound_fails() {
        let key = make_key(4, 0);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 10.0).collect();
        let mut store = TieredStore::new(4096);
        store.put(key, &data, Tier::Tier3, 0).unwrap();

        let strict = CoherenceCheck::new([f32::MAX, 0.001, 0.001, 0.001]);
        let outcome = strict.check_coherence(&mut store, key, &data, 1).unwrap();

        assert_eq!(outcome.tier, Tier::Tier3);
        assert!(
            !outcome.passed,
            "Tier3 with 0.001 bound should fail; max_error={}",
            outcome.max_error,
        );
    }

    // ----- verify_put -----------------------------------------------------

    #[test]
    fn test_verify_put_tier1() {
        let mut store = TieredStore::new(4096);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let outcome = CoherenceCheck::default()
            .verify_put(&mut store, make_key(10, 0), &data, Tier::Tier1, 0)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier1);
        assert!(outcome.passed, "verify_put Tier1 should pass");
        assert_eq!(store.block_count(), 1);
    }

    #[test]
    fn test_verify_put_tier0_rejected() {
        let mut store = TieredStore::new(4096);
        let data = vec![1.0f32; 16];

        let outcome =
            CoherenceCheck::default().verify_put(&mut store, make_key(11, 0), &data, Tier::Tier0, 0);

        assert_eq!(outcome, Err(StoreError::InvalidBlock));
    }

    #[test]
    fn test_verify_put_tier2() {
        let mut store = TieredStore::new(4096);
        let data: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.3).collect();

        let outcome = CoherenceCheck::default()
            .verify_put(&mut store, make_key(12, 0), &data, Tier::Tier2, 0)
            .unwrap();

        assert_eq!(outcome.tier, Tier::Tier2);
        assert!(
            outcome.passed,
            "verify_put Tier2 should pass; max_error={}",
            outcome.max_error
        );
    }

    // ----- compute_max_relative_error -------------------------------------

    #[test]
    fn test_relative_error_identical() {
        let values = vec![1.0, 2.0, 3.0];
        assert_eq!(compute_max_relative_error(&values, &values.clone()), 0.0);
    }

    #[test]
    fn test_relative_error_known() {
        let err = compute_max_relative_error(&[10.0, 20.0, 50.0], &[10.5, 20.0, 48.0]);
        assert!((err - 0.05).abs() < 1e-6, "expected 0.05, got {err}");
    }

    #[test]
    fn test_relative_error_near_zero() {
        let err = compute_max_relative_error(&[0.0, 1e-8, 1.0], &[0.001, 0.0, 1.0]);
        assert!((err - 0.001).abs() < 1e-6, "expected ~0.001, got {err}");
    }

    #[test]
    fn test_relative_error_empty() {
        assert_eq!(compute_max_relative_error(&[], &[]), 0.0);
    }

    #[test]
    fn test_relative_error_mismatched_lengths() {
        let err = compute_max_relative_error(&[1.0, 2.0, 3.0], &[1.0, 2.0]);
        assert_eq!(err, 0.0);
    }

    // ----- EpochTracker ---------------------------------------------------

    #[test]
    fn test_epoch_tracker_new() {
        let tracker = EpochTracker::new();
        let key = make_key(1, 0);
        assert_eq!(tracker.check_epoch(key), None);
        assert!(!tracker.is_stale(key, 0));
    }

    #[test]
    fn test_epoch_tracker_record_write() {
        let mut tracker = EpochTracker::new();
        let key = make_key(1, 0);

        for expected in 1..=2u64 {
            assert_eq!(tracker.record_write(key), expected);
            assert_eq!(tracker.check_epoch(key), Some(expected));
        }
    }

    #[test]
    fn test_epoch_tracker_monotonic_across_keys() {
        let mut tracker = EpochTracker::new();
        let (key_a, key_b) = (make_key(1, 0), make_key(2, 0));

        let assigned = [key_a, key_b, key_a].map(|k| tracker.record_write(k));

        assert_eq!(assigned, [1, 2, 3]);
        assert_eq!(tracker.check_epoch(key_a), Some(3));
        assert_eq!(tracker.check_epoch(key_b), Some(2));
    }

    #[test]
    fn test_epoch_tracker_is_stale() {
        let mut tracker = EpochTracker::new();
        let key = make_key(1, 0);

        let first = tracker.record_write(key);
        assert!(
            !tracker.is_stale(key, first),
            "same epoch should not be stale"
        );
        assert!(
            !tracker.is_stale(key, first + 1),
            "future epoch should not be stale"
        );

        tracker.record_write(key);
        assert!(
            tracker.is_stale(key, first),
            "old epoch should now be stale after a new write"
        );
    }

    #[test]
    fn test_epoch_tracker_unknown_key_not_stale() {
        let tracker = EpochTracker::new();
        let key = make_key(99, 0);
        for read_epoch in [0, u64::MAX] {
            assert!(!tracker.is_stale(key, read_epoch));
        }
    }

    #[test]
    fn test_epoch_tracker_multiple_keys_independent() {
        let mut tracker = EpochTracker::new();
        let key_a = make_key(1, 0);

        let epoch_a = tracker.record_write(key_a);
        tracker.record_write(make_key(2, 0));

        assert!(!tracker.is_stale(key_a, epoch_a));
    }

    #[test]
    fn test_epoch_tracker_default_trait() {
        assert_eq!(EpochTracker::default().check_epoch(make_key(1, 0)), None);
    }
}