//! Sparse delta encoding and low-rank factorization for tensor blocks.
//!
//! Deltas record only the elements of a block that changed between two
//! epochs, quantized to i16 against a per-record scale. `FactorSet`
//! stores a truncated low-rank approximation computed by power
//! iteration with deflation.

use crate::store::StoreError;

#[allow(unused_imports)]
use crate::store::{BlockKey, ReconstructPolicy};

/// Serialized delta header size: 16 (tensor_id) + 4 (block_index)
/// + 8 (base_epoch) + 2 (nnz) + 4 (delta_scale) = 34 bytes.
const DELTA_HEADER_BYTES: usize = 34;
/// Serialized sparse entry size: 2 (index) + 2 (value) = 4 bytes.
const DELTA_ENTRY_BYTES: usize = 4;
/// Maximum power-iteration sweeps per singular triplet.
const POWER_ITER_MAX: usize = 30;
/// Norms below this are treated as zero (convergence floor).
const POWER_ITER_EPS: f32 = 1e-10;
/// Identifies the block a delta applies to and the epoch it was
/// computed against.
#[derive(Clone, Debug)]
pub struct DeltaHeader {
    pub tensor_id: u128,
    pub block_index: u32,
    pub base_epoch: u64,
    pub nnz: u16,
}

/// One changed element: position within the block plus the quantized
/// difference (multiply by `delta_scale` to recover the f32 change).
#[derive(Clone, Copy, Debug)]
pub struct SparseEntry {
    pub index: u16,
    pub value: i16,
}

/// A sparse, quantized update against a base block: header, the
/// dequantization scale, and the changed entries.
#[derive(Clone, Debug)]
pub struct DeltaRecord {
    pub header: DeltaHeader,
    pub delta_scale: f32,
    pub entries: Vec<SparseEntry>,
}
/// Computes a sparse quantized delta taking `old` to `new`.
///
/// Elements whose absolute change is below `threshold` are dropped.
/// Returns `None` when at least `max_change_fraction` of the block
/// changed; in that case a delta would not beat re-storing the block,
/// and the caller should write a fresh base instead. Surviving
/// changes are quantized to i16 against a scale chosen so the largest
/// change maps to `i16::MAX`.
pub fn compute_delta(
    old: &[f32],
    new: &[f32],
    tensor_id: u128,
    block_index: u32,
    base_epoch: u64,
    threshold: f32,
    max_change_fraction: f32,
) -> Option<DeltaRecord> {
    assert_eq!(old.len(), new.len(), "old and new must have equal length");
    let n = old.len();
    if n == 0 {
        return Some(DeltaRecord {
            header: DeltaHeader { tensor_id, block_index, base_epoch, nnz: 0 },
            delta_scale: 0.0,
            entries: Vec::new(),
        });
    }

    // Collect changed positions and track the largest magnitude, which
    // fixes the quantization scale. Indices are u16, so blocks are
    // limited to 65_536 elements.
    let mut changed: Vec<(u16, f32)> = Vec::new();
    let mut max_abs = 0.0f32;
    for i in 0..n {
        let diff = new[i] - old[i];
        if diff.abs() >= threshold {
            changed.push((i as u16, diff));
            if diff.abs() > max_abs { max_abs = diff.abs(); }
        }
    }

    if changed.len() as f32 / n as f32 >= max_change_fraction {
        return None;
    }

    // Map the largest change to i16::MAX; a zero max_abs means no
    // entries survived the threshold, so the scale is arbitrary.
    let delta_scale = if max_abs == 0.0 { 1.0 } else { max_abs / i16::MAX as f32 };
    let inv_scale = 1.0 / delta_scale;
    let entries: Vec<SparseEntry> = changed
        .iter()
        .map(|&(idx, diff)| {
            let q = (diff * inv_scale).round() as i32;
            SparseEntry { index: idx, value: q.clamp(i16::MIN as i32, i16::MAX as i32) as i16 }
        })
        .collect();

    Some(DeltaRecord {
        header: DeltaHeader { tensor_id, block_index, base_epoch, nnz: entries.len() as u16 },
        delta_scale,
        entries,
    })
}

/// Applies a delta in place: `base[index] += value * delta_scale` for
/// each entry. Out-of-range indices are ignored rather than panicking.
pub fn apply_delta(base: &mut [f32], delta: &DeltaRecord) {
    let scale = delta.delta_scale;
    for entry in &delta.entries {
        let idx = entry.index as usize;
        if idx < base.len() {
            base[idx] += entry.value as f32 * scale;
        }
    }
}
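
// A minimal round-trip sketch for `compute_delta`/`apply_delta`
// (illustrative only; the half-step error bound below follows from
// the scale choice above and is an assumption, not a documented
// guarantee): each quantized entry is off by at most `delta_scale / 2`.
#[cfg(test)]
mod compute_apply_example {
    use super::*;

    #[test]
    fn quantization_error_stays_within_half_a_step() {
        let old: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut new = old.clone();
        new[3] += 0.25;
        new[17] -= 1.0;

        let d = compute_delta(&old, &new, 7, 0, 0, 0.01, 0.5).unwrap();
        let mut recon = old.clone();
        apply_delta(&mut recon, &d);

        // Half a quantization step, plus slack for f32 roundoff.
        let bound = d.delta_scale * 0.5 + 1e-6;
        for (r, target) in recon.iter().zip(new.iter()) {
            assert!((r - target).abs() <= bound);
        }
    }
}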

/// A base block plus an ordered list of deltas. Reads replay the
/// deltas over the base; once the chain reaches `max_chain_len`, it
/// must be compacted before further appends.
#[derive(Clone, Debug)]
pub struct DeltaChain {
    base_data: Vec<f32>,
    deltas: Vec<DeltaRecord>,
    max_chain_len: u8,
}

impl DeltaChain {
    /// Creates a chain over `base_data` with an empty delta list.
    pub fn new(base_data: Vec<f32>, max_chain_len: u8) -> Self {
        Self { base_data, deltas: Vec::new(), max_chain_len }
    }

    /// Appends a delta, failing once the chain is at capacity.
    pub fn append(&mut self, delta: DeltaRecord) -> Result<(), StoreError> {
        if self.deltas.len() >= self.max_chain_len as usize {
            return Err(StoreError::DeltaChainTooLong);
        }
        self.deltas.push(delta);
        Ok(())
    }

    /// Materializes the current state: base plus all deltas, in order.
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut result = self.base_data.clone();
        for delta in &self.deltas {
            apply_delta(&mut result, delta);
        }
        result
    }

    /// Folds all deltas into the base and clears the chain. The
    /// reconstructed value is unchanged; only the representation is.
    pub fn compact(&mut self) {
        if self.deltas.is_empty() { return; }
        for delta in &self.deltas {
            apply_delta(&mut self.base_data, delta);
        }
        self.deltas.clear();
    }

    #[inline]
    pub fn chain_len(&self) -> usize { self.deltas.len() }

    /// True once the chain is full and the next `append` would fail.
    #[inline]
    pub fn needs_compaction(&self) -> bool {
        self.deltas.len() >= self.max_chain_len as usize
    }

    /// Footprint estimate: 4 bytes per base element plus the
    /// serialized size of each delta.
    pub fn total_bytes(&self) -> usize {
        let base_bytes = self.base_data.len() * 4;
        let delta_bytes: usize = self.deltas.iter()
            .map(|d| DELTA_HEADER_BYTES + d.entries.len() * DELTA_ENTRY_BYTES)
            .sum();
        base_bytes + delta_bytes
    }
}
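
// A lifecycle sketch for `DeltaChain` (illustrative; the capacity of
// two and the repeated delta are arbitrary choices): append until the
// chain reports it needs compaction, compact, and confirm the
// materialized state is identical while the chain length drops to 0.
#[cfg(test)]
mod chain_lifecycle_example {
    use super::*;

    #[test]
    fn compaction_is_transparent_to_readers() {
        let mut chain = DeltaChain::new(vec![0.0f32; 8], 2);
        let d = DeltaRecord {
            header: DeltaHeader { tensor_id: 1, block_index: 0, base_epoch: 0, nnz: 1 },
            delta_scale: 0.5,
            entries: vec![SparseEntry { index: 2, value: 4 }],
        };
        chain.append(d.clone()).unwrap();
        chain.append(d).unwrap();
        assert!(chain.needs_compaction());

        let before = chain.reconstruct();
        chain.compact();
        assert_eq!(chain.chain_len(), 0);
        // Same values, now held entirely in the base.
        assert_eq!(before, chain.reconstruct());
    }
}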

/// A truncated rank-`k` factorization of an `m x n` block:
/// `A ≈ U * diag(S) * V`, with `U` stored row-major as `m x k`,
/// `S` as `k` singular values, and `V` row-major as `k x n`.
#[derive(Clone, Debug)]
pub struct FactorSet {
    /// Rows of the original block.
    pub m: usize,
    /// Columns of the original block.
    pub n: usize,
    /// Retained rank.
    pub k: usize,
    /// Left factors, row-major `m x k`.
    pub u_data: Vec<f32>,
    /// Singular values, length `k`.
    pub s_data: Vec<f32>,
    /// Right factors, row-major `k x n`.
    pub v_data: Vec<f32>,
}

impl FactorSet {
    /// Materializes the dense `m x n` approximation by accumulating
    /// the `k` scaled outer products `s_r * u_r * v_r^T`.
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut out = vec![0.0f32; self.m * self.n];
        for r in 0..self.k {
            let s_r = self.s_data[r];
            for i in 0..self.m {
                let u_s = self.u_data[i * self.k + r] * s_r;
                let row = i * self.n;
                let v_off = r * self.n;
                for j in 0..self.n {
                    out[row + j] += u_s * self.v_data[v_off + j];
                }
            }
        }
        out
    }

    /// Bytes needed for the factors: `(m*k + k + k*n) * 4`.
    pub fn storage_bytes(&self) -> usize {
        (self.m * self.k + self.k + self.k * self.n) * 4
    }

    /// Computes a rank-`rank` approximation of the row-major
    /// `rows x cols` matrix `data` by power iteration with deflation:
    /// each round extracts the dominant singular triplet of the
    /// residual, then subtracts its contribution before the next.
    pub fn from_data(data: &[f32], rows: usize, cols: usize, rank: usize) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        let (m, n) = (rows, cols);
        let k = rank.min(m).min(n);
        let mut work = data.to_vec();
        let mut u_data = vec![0.0f32; m * k];
        let mut s_data = vec![0.0f32; k];
        let mut v_data = vec![0.0f32; k * n];

        for r in 0..k {
            // Deterministic pseudo-random unit-norm start vector so
            // the factorization is reproducible across runs.
            let inv_sqrt_n = 1.0 / (n as f32).sqrt();
            let mut v = vec![0.0f32; n];
            for j in 0..n {
                let seed = (j as u32).wrapping_mul(2_654_435_761)
                    .wrapping_add((r as u32).wrapping_mul(0x9E37_79B9));
                v[j] = if seed & 1 == 0 { inv_sqrt_n } else { -inv_sqrt_n };
            }
            let mut u = vec![0.0f32; m];
            let mut sigma = 0.0f32;

            for _ in 0..POWER_ITER_MAX {
                // u <- A v, then normalize.
                for i in 0..m {
                    let mut acc = 0.0f32;
                    let row = i * n;
                    for j in 0..n { acc += work[row + j] * v[j]; }
                    u[i] = acc;
                }
                let su: f32 = u.iter().map(|x| x * x).sum::<f32>().sqrt();
                if su < POWER_ITER_EPS { sigma = 0.0; break; }
                let inv = 1.0 / su;
                for x in u.iter_mut() { *x *= inv; }

                // v <- A^T u; its norm estimates the singular value.
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for i in 0..m { acc += work[i * n + j] * u[i]; }
                    v[j] = acc;
                }
                let sv: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                if sv < POWER_ITER_EPS { sigma = su; break; }
                sigma = sv;
                let inv = 1.0 / sv;
                for x in v.iter_mut() { *x *= inv; }
            }

            s_data[r] = sigma;
            for i in 0..m { u_data[i * k + r] = u[i]; }
            for j in 0..n { v_data[r * n + j] = v[j]; }

            // Deflate: remove this triplet from the residual so the
            // next round converges to the next singular direction.
            if sigma > POWER_ITER_EPS {
                for i in 0..m {
                    let us = u[i] * sigma;
                    let row = i * n;
                    for j in 0..n { work[row + j] -= us * v[j]; }
                }
            }
        }
        Self { m, n, k, u_data, s_data, v_data }
    }

    /// Relative Frobenius error: `||A - A_k||_F / ||A||_F`. Returns 0
    /// when the original is (numerically) all zeros.
    pub fn reconstruction_error(&self, original: &[f32]) -> f32 {
        let reconstructed = self.reconstruct();
        let mut diff_sq = 0.0f32;
        let mut orig_sq = 0.0f32;
        for (i, &o) in original.iter().enumerate() {
            let r = if i < reconstructed.len() { reconstructed[i] } else { 0.0 };
            diff_sq += (o - r) * (o - r);
            orig_sq += o * o;
        }
        if orig_sq < 1e-30 {
            return 0.0;
        }
        (diff_sq / orig_sq).sqrt()
    }

    /// Fraction of the squared Frobenius norm captured by the kept
    /// singular values: `sum(s_r^2) / ||A||_F^2`, using the identity
    /// that `||A||_F^2` equals the sum of all squared singular values.
    pub fn energy_captured(&self, original: &[f32]) -> f32 {
        let total_energy: f32 = original.iter().map(|x| x * x).sum();
        if total_energy < 1e-30 {
            return 1.0;
        }
        let captured: f32 = self.s_data.iter().map(|s| s * s).sum();
        (captured / total_energy).min(1.0)
    }

    /// Raw f32 bytes divided by factored bytes; values above 1.0 mean
    /// the factorization is smaller than the dense block.
    pub fn compression_ratio(&self, original_elements: usize) -> f32 {
        let raw = original_elements * 4;
        let stored = self.storage_bytes();
        if stored == 0 {
            return 0.0;
        }
        raw as f32 / stored as f32
    }

    /// Grows the rank from 1 toward `max_rank`, stopping as soon as
    /// the relative reconstruction error drops to `target_error`.
    pub fn from_data_adaptive(
        data: &[f32],
        rows: usize,
        cols: usize,
        max_rank: usize,
        target_error: f32,
    ) -> Self {
        let max_k = max_rank.min(rows).min(cols);
        let mut best = Self::from_data(data, rows, cols, 1);
        for rank in 2..=max_k {
            let err = best.reconstruction_error(data);
            if err <= target_error {
                break;
            }
            best = Self::from_data(data, rows, cols, rank);
        }
        best
    }
}
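
// A usage sketch for `FactorSet` (the matrix here is rank-1 by
// construction, so rank 1 should recover it almost exactly; the 1e-3
// tolerance is an assumption about f32 power-iteration accuracy, not
// a documented bound).
#[cfg(test)]
mod factorization_example {
    use super::*;

    #[test]
    fn rank_one_matrix_compresses_well() {
        let (m, n) = (16, 12);
        // data[i][j] = (i + 1) * (j + 1): an outer product, so rank 1.
        let data: Vec<f32> = (0..m * n)
            .map(|idx| ((idx / n) as f32 + 1.0) * ((idx % n) as f32 + 1.0))
            .collect();

        let f = FactorSet::from_data(&data, m, n, 1);
        assert!(f.reconstruction_error(&data) < 1e-3);
        assert!(f.energy_captured(&data) > 0.999);
        // 16*12 dense floats vs (16 + 1 + 12) floats of factors,
        // i.e. 768 bytes vs 116 bytes: ratio above 6.
        assert!(f.compression_ratio(m * n) > 6.0);
    }
}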

/// Serializes a delta: fixed 34-byte little-endian header followed by
/// `nnz` 4-byte entries.
pub fn encode_delta(delta: &DeltaRecord) -> Vec<u8> {
    let mut buf = Vec::with_capacity(DELTA_HEADER_BYTES + delta.entries.len() * DELTA_ENTRY_BYTES);
    buf.extend_from_slice(&delta.header.tensor_id.to_le_bytes());
    buf.extend_from_slice(&delta.header.block_index.to_le_bytes());
    buf.extend_from_slice(&delta.header.base_epoch.to_le_bytes());
    buf.extend_from_slice(&delta.header.nnz.to_le_bytes());
    buf.extend_from_slice(&delta.delta_scale.to_le_bytes());
    for entry in &delta.entries {
        buf.extend_from_slice(&entry.index.to_le_bytes());
        buf.extend_from_slice(&entry.value.to_le_bytes());
    }
    buf
}

/// Parses a serialized delta. Layout (all little-endian): bytes 0..16
/// tensor_id, 16..20 block_index, 20..28 base_epoch, 28..30 nnz,
/// 30..34 delta_scale, then `nnz` entries of (u16 index, i16 value).
/// Truncated input yields `InvalidBlock`.
pub fn decode_delta(data: &[u8]) -> Result<DeltaRecord, StoreError> {
    if data.len() < DELTA_HEADER_BYTES { return Err(StoreError::InvalidBlock); }
    let tensor_id = u128::from_le_bytes(data[0..16].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let block_index = u32::from_le_bytes(data[16..20].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let base_epoch = u64::from_le_bytes(data[20..28].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let nnz = u16::from_le_bytes(data[28..30].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let delta_scale = f32::from_le_bytes(data[30..34].try_into().map_err(|_| StoreError::InvalidBlock)?);

    if data.len() < DELTA_HEADER_BYTES + (nnz as usize) * DELTA_ENTRY_BYTES {
        return Err(StoreError::InvalidBlock);
    }
    let mut entries = Vec::with_capacity(nnz as usize);
    let mut off = DELTA_HEADER_BYTES;
    for _ in 0..nnz {
        let index = u16::from_le_bytes(data[off..off + 2].try_into().map_err(|_| StoreError::InvalidBlock)?);
        let value = i16::from_le_bytes(data[off + 2..off + 4].try_into().map_err(|_| StoreError::InvalidBlock)?);
        entries.push(SparseEntry { index, value });
        off += DELTA_ENTRY_BYTES;
    }

    Ok(DeltaRecord {
        header: DeltaHeader { tensor_id, block_index, base_epoch, nnz },
        delta_scale,
        entries,
    })
}
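
// A wire-format sketch: encode a one-entry delta and read the header
// fields back at their fixed offsets. The offsets mirror the layout
// documented on `decode_delta`; nothing here is part of the public
// API surface.
#[cfg(test)]
mod wire_format_example {
    use super::*;

    #[test]
    fn header_fields_sit_at_fixed_offsets() {
        let d = DeltaRecord {
            header: DeltaHeader { tensor_id: 5, block_index: 9, base_epoch: 3, nnz: 1 },
            delta_scale: 1.0,
            entries: vec![SparseEntry { index: 4, value: -2 }],
        };
        let bytes = encode_delta(&d);
        assert_eq!(bytes.len(), DELTA_HEADER_BYTES + DELTA_ENTRY_BYTES);
        assert_eq!(u128::from_le_bytes(bytes[0..16].try_into().unwrap()), 5);
        assert_eq!(u32::from_le_bytes(bytes[16..20].try_into().unwrap()), 9);
        assert_eq!(u64::from_le_bytes(bytes[20..28].try_into().unwrap()), 3);
        assert_eq!(u16::from_le_bytes(bytes[28..30].try_into().unwrap()), 1);
        // First entry starts right after the 34-byte header.
        assert_eq!(i16::from_le_bytes(bytes[36..38].try_into().unwrap()), -2);
    }
}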

#[cfg(test)]
mod tests {
    use super::*;

    // Helper: builds a DeltaRecord with fixed header fields so tests
    // only vary the (index, value) entries and the scale.
    fn make_delta(entries: Vec<(u16, i16)>, scale: f32) -> DeltaRecord {
        let sparse: Vec<SparseEntry> = entries.iter()
            .map(|&(i, v)| SparseEntry { index: i, value: v }).collect();
        DeltaRecord {
            header: DeltaHeader { tensor_id: 42, block_index: 0, base_epoch: 1, nnz: sparse.len() as u16 },
            delta_scale: scale,
            entries: sparse,
        }
    }

    #[test]
    fn test_compute_delta_small_change() {
        let old = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut new = old.clone();
        new[2] = 3.5;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        assert_eq!(d.entries.len(), 1);
        assert_eq!(d.entries[0].index, 2);
        assert!(d.delta_scale > 0.0);
    }

    #[test]
    fn test_compute_delta_large_change_returns_none() {
        let old = vec![1.0; 10];
        let new = vec![5.0; 10];
        assert!(compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).is_none());
    }

    #[test]
    fn test_apply_delta_modifies_base() {
        let mut base = vec![1.0, 2.0, 3.0, 4.0];
        apply_delta(&mut base, &make_delta(vec![(1, 100), (3, -50)], 0.01));
        assert!((base[0] - 1.0).abs() < 1e-6);
        assert!((base[1] - 3.0).abs() < 1e-6);
        assert!((base[2] - 3.0).abs() < 1e-6);
        assert!((base[3] - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_chain_append_and_reconstruct() {
        let mut chain = DeltaChain::new(vec![1.0, 2.0, 3.0, 4.0], 4);
        chain.append(make_delta(vec![(0, 1000)], 0.001)).unwrap();
        assert_eq!(chain.chain_len(), 1);
        let r = chain.reconstruct();
        assert!((r[0] - 2.0).abs() < 1e-3);
        assert!((r[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_chain_compact_preserves_state() {
        let mut chain = DeltaChain::new(vec![0.0; 4], 8);
        chain.append(make_delta(vec![(0, 100)], 0.1)).unwrap();
        chain.append(make_delta(vec![(1, 200)], 0.1)).unwrap();
        let before = chain.reconstruct();
        chain.compact();
        assert_eq!(chain.chain_len(), 0);
        let after = chain.reconstruct();
        for (a, b) in before.iter().zip(after.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_chain_max_length_enforcement() {
        let mut chain = DeltaChain::new(vec![1.0; 4], 2);
        assert!(chain.append(make_delta(vec![(0, 1)], 0.1)).is_ok());
        assert!(chain.append(make_delta(vec![(1, 1)], 0.1)).is_ok());
        assert!(chain.append(make_delta(vec![(2, 1)], 0.1)).is_err());
    }

    #[test]
    fn test_chain_needs_compaction() {
        let mut chain = DeltaChain::new(vec![1.0; 4], 2);
        assert!(!chain.needs_compaction());
        chain.append(make_delta(vec![(0, 1)], 0.1)).unwrap();
        assert!(!chain.needs_compaction());
        chain.append(make_delta(vec![(1, 1)], 0.1)).unwrap();
        assert!(chain.needs_compaction());
    }

    #[test]
    fn test_factor_reconstruct() {
        let (u, v, s) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0], 2.0);
        let f = FactorSet { m: 3, n: 2, k: 1, u_data: u.clone(), s_data: vec![s], v_data: v.clone() };
        let r = f.reconstruct();
        assert_eq!(r.len(), 6);
        for i in 0..3 {
            for j in 0..2 {
                assert!((r[i * 2 + j] - u[i] * s * v[j]).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_factor_from_data_approximation() {
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let reconstructed = FactorSet::from_data(&data, m, n, 1).reconstruct();
        let max_err = data.iter().zip(reconstructed.iter())
            .map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_err < 0.5, "max error {max_err} too large for rank-1 input");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let orig = DeltaRecord {
            header: DeltaHeader { tensor_id: 0xDEADBEEFCAFEBABE, block_index: 42, base_epoch: 100, nnz: 3 },
            delta_scale: 0.001,
            entries: vec![
                SparseEntry { index: 10, value: 500 },
                SparseEntry { index: 20, value: -300 },
                SparseEntry { index: 30, value: 1 },
            ],
        };
        let bytes = encode_delta(&orig);
        assert_eq!(bytes.len(), DELTA_HEADER_BYTES + 3 * DELTA_ENTRY_BYTES);
        let dec = decode_delta(&bytes).unwrap();
        assert_eq!(dec.header.tensor_id, orig.header.tensor_id);
        assert_eq!(dec.header.block_index, orig.header.block_index);
        assert_eq!(dec.header.nnz, orig.header.nnz);
        assert!((dec.delta_scale - orig.delta_scale).abs() < 1e-10);
        for (a, b) in dec.entries.iter().zip(orig.entries.iter()) {
            assert_eq!(a.index, b.index);
            assert_eq!(a.value, b.value);
        }
    }

    #[test]
    fn test_decode_truncated_header() {
        assert!(decode_delta(&[0u8; 20]).is_err());
    }

    #[test]
    fn test_decode_truncated_entries() {
        let mut bytes = encode_delta(&make_delta(vec![(0, 1), (1, 2)], 1.0));
        // Overwrite nnz (bytes 28..30, little-endian) to claim 5
        // entries while only 2 are present; decode must reject it.
        bytes[28] = 5;
        bytes[29] = 0;
        assert!(decode_delta(&bytes).is_err());
    }

    #[test]
    fn test_empty_delta_roundtrip() {
        let d = DeltaRecord {
            header: DeltaHeader { tensor_id: 99, block_index: 7, base_epoch: 50, nnz: 0 },
            delta_scale: 0.0,
            entries: Vec::new(),
        };
        let dec = decode_delta(&encode_delta(&d)).unwrap();
        assert_eq!(dec.entries.len(), 0);
    }

    #[test]
    fn test_single_entry_delta() {
        let old = vec![1.0; 100];
        let mut new = old.clone();
        new[50] = 2.0;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        assert_eq!(d.entries.len(), 1);
        assert_eq!(d.entries[0].index, 50);
        let mut base = old.clone();
        apply_delta(&mut base, &d);
        assert!((base[50] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_full_density_delta() {
        let old = vec![0.0; 4];
        let new = vec![0.1, 0.2, 0.3, 0.4];
        let d = compute_delta(&old, &new, 1, 0, 0, 0.001, 1.1).unwrap();
        assert_eq!(d.entries.len(), 4);
        let mut base = old.clone();
        apply_delta(&mut base, &d);
        for i in 0..4 {
            assert!((base[i] - new[i]).abs() < 0.01, "index {i}");
        }
    }

    #[test]
    fn test_compute_apply_roundtrip_64() {
        let old: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let mut new = old.clone();
        new[5] += 0.5;
        new[10] -= 0.3;
        new[60] += 1.0;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        let mut recon = old.clone();
        apply_delta(&mut recon, &d);
        for i in 0..64 {
            assert!((recon[i] - new[i]).abs() < 0.01, "index {i}");
        }
    }

    #[test]
    fn test_reconstruction_error_zero_for_exact() {
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let factors = FactorSet::from_data(&data, m, n, 1);
        let err = factors.reconstruction_error(&data);
        assert!(err < 0.01, "err={err} too large for rank-1 data");
    }

    #[test]
    fn test_reconstruction_error_decreases_with_rank() {
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n).map(|i| (i as f32 * 0.7).sin()).collect();
        let err1 = FactorSet::from_data(&data, m, n, 1).reconstruction_error(&data);
        let err3 = FactorSet::from_data(&data, m, n, 3).reconstruction_error(&data);
        assert!(err3 <= err1 + 1e-6, "err3={err3} > err1={err1}");
    }

    #[test]
    fn test_energy_captured_rank1_data() {
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let factors = FactorSet::from_data(&data, m, n, 1);
        let energy = factors.energy_captured(&data);
        assert!(energy > 0.95, "energy={energy} too low for rank-1 data");
    }

    #[test]
    fn test_compression_ratio_meaningful() {
        let (m, n) = (16, 16);
        let data: Vec<f32> = (0..m * n).map(|i| i as f32).collect();
        let factors = FactorSet::from_data(&data, m, n, 2);
        let ratio = factors.compression_ratio(m * n);
        assert!(ratio > 1.0, "ratio={ratio} should be > 1");
    }

    #[test]
    fn test_from_data_adaptive_stops_early() {
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let factors = FactorSet::from_data_adaptive(&data, m, n, 5, 0.05);
        assert!(factors.k <= 2, "k={} should be small for rank-1 data", factors.k);
    }

    #[test]
    fn test_from_data_adaptive_increases_rank() {
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n).map(|i| (i as f32 * 0.3).sin() + (i as f32 * 0.7).cos()).collect();
        let factors = FactorSet::from_data_adaptive(&data, m, n, 6, 0.01);
        let err = factors.reconstruction_error(&data);
        assert!(err < 0.1 || factors.k == 6, "err={err}, k={}", factors.k);
    }
}
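
// An end-to-end pipeline sketch tying the pieces together: compute a
// delta, ship it through the wire format, and apply it on the other
// side. The tolerance follows the half-quantization-step reasoning in
// `compute_apply_example` and is an assumption, not a documented
// guarantee.
#[cfg(test)]
mod end_to_end_example {
    use super::*;

    #[test]
    fn delta_survives_serialization() {
        let old: Vec<f32> = (0..16).map(|i| i as f32 * 0.5).collect();
        let mut new = old.clone();
        new[1] += 0.75;
        new[9] -= 0.25;

        let delta = compute_delta(&old, &new, 11, 3, 42, 0.01, 0.5).unwrap();
        let wire = encode_delta(&delta);
        let received = decode_delta(&wire).unwrap();
        assert_eq!(received.header.base_epoch, 42);

        let mut replica = old.clone();
        apply_delta(&mut replica, &received);
        for (r, target) in replica.iter().zip(new.iter()) {
            assert!((r - target).abs() <= received.delta_scale * 0.5 + 1e-6);
        }
    }
}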