// File: ruvector_temporal_tensor/delta.rs

1//! Delta compression, delta chains, and reconstruction policies (ADR-021).
2//!
3//! Sparse delta encoding for incremental tensor updates, bounded-depth delta
4//! chain management with automatic compaction, and SVD-based low-rank factor
5//! reconstruction. All structures are WASM-safe (no `f64` in hot paths).
6
7use crate::store::StoreError;
8
9#[allow(unused_imports)]
10use crate::store::{BlockKey, ReconstructPolicy};
11
/// Size of the fixed portion of a serialized delta (header + scale):
/// tensor_id (16) + block_index (4) + base_epoch (8) + nnz (2) + delta_scale (4) = 34.
const DELTA_HEADER_BYTES: usize = 34;
/// Size of a single serialized sparse entry (index: u16 + value: i16).
const DELTA_ENTRY_BYTES: usize = 4;
/// Maximum power-iteration steps per singular component.
const POWER_ITER_MAX: usize = 30;
/// Convergence threshold for power iteration (squared-norm scale cutoff).
const POWER_ITER_EPS: f32 = 1e-10;
20
/// Header for a delta record.
#[derive(Clone, Debug)]
pub struct DeltaHeader {
    /// Identifier of the tensor this delta belongs to.
    pub tensor_id: u128,
    /// Index of the block within the tensor.
    pub block_index: u32,
    /// Epoch of the base block this delta applies on top of.
    pub base_epoch: u64,
    /// Number of sparse entries in the record.
    pub nnz: u16,
}

/// A single sparse delta entry: index + quantized value.
#[derive(Clone, Copy, Debug)]
pub struct SparseEntry {
    pub index: u16,
    pub value: i16,
}

/// Complete delta record: header + sparse entries + scale.
///
/// Actual diff = `entry.value as f32 * delta_scale`.
#[derive(Clone, Debug)]
pub struct DeltaRecord {
    pub header: DeltaHeader,
    pub delta_scale: f32,
    pub entries: Vec<SparseEntry>,
}

/// Compute a sparse delta between `old` and `new` data.
///
/// Keeps entries whose absolute change exceeds `threshold`. Returns `None`
/// if the changed fraction meets or exceeds `max_change_fraction`, or if the
/// block has more than `u16::MAX + 1` elements (indices would not fit in the
/// `u16` wire format) — in both cases the caller should store the block in
/// full instead of as a delta.
///
/// # Panics
///
/// Panics if `old.len() != new.len()`.
pub fn compute_delta(
    old: &[f32],
    new: &[f32],
    tensor_id: u128,
    block_index: u32,
    base_epoch: u64,
    threshold: f32,
    max_change_fraction: f32,
) -> Option<DeltaRecord> {
    assert_eq!(old.len(), new.len(), "old and new must have equal length");
    let n = old.len();
    if n == 0 {
        // Empty block: a trivially-valid empty delta.
        return Some(DeltaRecord {
            header: DeltaHeader {
                tensor_id,
                block_index,
                base_epoch,
                nnz: 0,
            },
            delta_scale: 0.0,
            entries: Vec::new(),
        });
    }
    // Fix: entry indices (and nnz) are serialized as u16. A block larger than
    // 65 536 elements would silently truncate `i as u16` and corrupt the delta,
    // so fall back to full storage, same as the density fallback below.
    if n > u16::MAX as usize + 1 {
        return None;
    }

    // Collect indices whose change exceeds the threshold, tracking the
    // largest magnitude for scale selection.
    let mut changed: Vec<(u16, f32)> = Vec::new();
    let mut max_abs = 0.0f32;
    for i in 0..n {
        let diff = new[i] - old[i];
        if diff.abs() >= threshold {
            changed.push((i as u16, diff));
            if diff.abs() > max_abs {
                max_abs = diff.abs();
            }
        }
    }

    // Too dense: a delta would not pay off, signal "store full block".
    if changed.len() as f32 / n as f32 >= max_change_fraction {
        return None;
    }

    // Scale maps the largest diff onto the full i16 range; 1.0 is a safe
    // placeholder when nothing changed (all quantized values are then 0).
    let delta_scale = if max_abs == 0.0 {
        1.0
    } else {
        max_abs / i16::MAX as f32
    };
    let inv_scale = 1.0 / delta_scale;
    let entries: Vec<SparseEntry> = changed
        .iter()
        .map(|&(idx, diff)| {
            let q = (diff * inv_scale).round() as i32;
            SparseEntry {
                index: idx,
                value: q.clamp(i16::MIN as i32, i16::MAX as i32) as i16,
            }
        })
        .collect();

    Some(DeltaRecord {
        header: DeltaHeader {
            tensor_id,
            block_index,
            base_epoch,
            nnz: entries.len() as u16,
        },
        delta_scale,
        entries,
    })
}

/// Apply a delta to a base data vector in-place.
///
/// Entries whose indices exceed the base length are silently skipped.
pub fn apply_delta(base: &mut [f32], delta: &DeltaRecord) {
    let scale = delta.delta_scale;
    for entry in &delta.entries {
        let idx = entry.index as usize;
        if idx < base.len() {
            base[idx] += entry.value as f32 * scale;
        }
    }
}
136
/// A chain of deltas applied to a base block.
/// Invariant: `deltas.len() <= max_chain_len`.
#[derive(Clone, Debug)]
pub struct DeltaChain {
    // Full-precision base block; deltas apply on top of this.
    base_data: Vec<f32>,
    // Pending deltas in application order (oldest first).
    deltas: Vec<DeltaRecord>,
    // Maximum chain depth before `append` refuses and compaction is required.
    max_chain_len: u8,
}
145
146impl DeltaChain {
147    /// Create a new chain with a base block.
148    pub fn new(base_data: Vec<f32>, max_chain_len: u8) -> Self {
149        Self {
150            base_data,
151            deltas: Vec::new(),
152            max_chain_len,
153        }
154    }
155
156    /// Append a delta. Returns `Err(StoreError::DeltaChainTooLong)` at max length.
157    pub fn append(&mut self, delta: DeltaRecord) -> Result<(), StoreError> {
158        if self.deltas.len() >= self.max_chain_len as usize {
159            return Err(StoreError::DeltaChainTooLong);
160        }
161        self.deltas.push(delta);
162        Ok(())
163    }
164
165    /// Reconstruct the current state by applying all deltas to the base.
166    pub fn reconstruct(&self) -> Vec<f32> {
167        let mut result = self.base_data.clone();
168        for delta in &self.deltas {
169            apply_delta(&mut result, delta);
170        }
171        result
172    }
173
174    /// Compact the chain: apply all deltas to base, clear delta list.
175    pub fn compact(&mut self) {
176        if self.deltas.is_empty() {
177            return;
178        }
179        for delta in &self.deltas {
180            apply_delta(&mut self.base_data, delta);
181        }
182        self.deltas.clear();
183    }
184
185    /// Number of deltas in the chain.
186    #[inline]
187    pub fn chain_len(&self) -> usize {
188        self.deltas.len()
189    }
190
191    /// Whether the chain needs compaction (at max length).
192    #[inline]
193    pub fn needs_compaction(&self) -> bool {
194        self.deltas.len() >= self.max_chain_len as usize
195    }
196
197    /// Total storage bytes: base + serialized size of all deltas.
198    pub fn total_bytes(&self) -> usize {
199        let base_bytes = self.base_data.len() * 4;
200        let delta_bytes: usize = self
201            .deltas
202            .iter()
203            .map(|d| DELTA_HEADER_BYTES + d.entries.len() * DELTA_ENTRY_BYTES)
204            .sum();
205        base_bytes + delta_bytes
206    }
207}
208
/// Low-rank factor representation for reconstruction.
///
/// Stores U (m x k), S (k), V (k x n) such that data ~ U * diag(S) * V.
/// All matrices are row-major.
#[derive(Clone, Debug)]
pub struct FactorSet {
    /// Number of rows of the original matrix.
    pub m: usize,
    /// Number of columns of the original matrix.
    pub n: usize,
    /// Retained rank (number of singular triplets).
    pub k: usize,
    pub u_data: Vec<f32>, // m * k elements, row-major left singular vectors
    pub s_data: Vec<f32>, // k elements, singular values
    pub v_data: Vec<f32>, // k * n elements, row-major right singular vectors
}
222
impl FactorSet {
    /// Reconstruct the full data from factors: U * diag(S) * V.
    ///
    /// Accumulates one rank-1 outer product `s_r * u_r * v_r^T` per retained
    /// component into an `m * n` row-major output buffer.
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut out = vec![0.0f32; self.m * self.n];
        for r in 0..self.k {
            let s_r = self.s_data[r];
            for i in 0..self.m {
                // Fold the singular value into the left factor once per row.
                let u_s = self.u_data[i * self.k + r] * s_r;
                let row = i * self.n;
                let v_off = r * self.n;
                for j in 0..self.n {
                    out[row + j] += u_s * self.v_data[v_off + j];
                }
            }
        }
        out
    }

    /// Compute storage size in bytes: (m*k + k + k*n) * 4.
    pub fn storage_bytes(&self) -> usize {
        (self.m * self.k + self.k + self.k * self.n) * 4
    }

    /// Create from a flat data vector using truncated SVD via power iteration.
    ///
    /// Simplified implementation suitable for moderate-sized matrices.
    /// Extracts top-`rank` singular triplets with successive deflation.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
    pub fn from_data(data: &[f32], rows: usize, cols: usize, rank: usize) -> Self {
        assert_eq!(
            data.len(),
            rows * cols,
            "data length must equal rows * cols"
        );
        let (m, n) = (rows, cols);
        // Rank can never exceed either matrix dimension.
        let k = rank.min(m).min(n);
        // Working copy that gets deflated after each extracted component.
        let mut work = data.to_vec();
        let mut u_data = vec![0.0f32; m * k];
        let mut s_data = vec![0.0f32; k];
        let mut v_data = vec![0.0f32; k * n];

        for r in 0..k {
            // Deterministic initial vector: Fibonacci-hash sign pattern.
            // Unit-norm by construction (each element is +/- 1/sqrt(n)),
            // seeded per-component so successive components start differently.
            let inv_sqrt_n = 1.0 / (n as f32).sqrt();
            let mut v = vec![0.0f32; n];
            for j in 0..n {
                let seed = (j as u32)
                    .wrapping_mul(2_654_435_761)
                    .wrapping_add((r as u32).wrapping_mul(0x9E37_79B9));
                v[j] = if seed & 1 == 0 {
                    inv_sqrt_n
                } else {
                    -inv_sqrt_n
                };
            }
            let mut u = vec![0.0f32; m];
            let mut sigma = 0.0f32;

            // Alternating power iteration: u <- Av / |Av|, v <- A^T u / |A^T u|.
            // The last norm computed approximates the singular value.
            for _ in 0..POWER_ITER_MAX {
                // u = work * v
                for i in 0..m {
                    let mut acc = 0.0f32;
                    let row = i * n;
                    for j in 0..n {
                        acc += work[row + j] * v[j];
                    }
                    u[i] = acc;
                }
                let su: f32 = u.iter().map(|x| x * x).sum::<f32>().sqrt();
                if su < POWER_ITER_EPS {
                    // Residual matrix is numerically zero in this direction.
                    sigma = 0.0;
                    break;
                }
                let inv = 1.0 / su;
                for x in u.iter_mut() {
                    *x *= inv;
                }

                // v = work^T * u
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for i in 0..m {
                        acc += work[i * n + j] * u[i];
                    }
                    v[j] = acc;
                }
                let sv: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                if sv < POWER_ITER_EPS {
                    // NOTE(review): v is left un-normalized here (near-zero);
                    // deflation below is skipped only when sigma <= EPS, so the
                    // stored v_data for this component is tiny but non-unit.
                    sigma = su;
                    break;
                }
                sigma = sv;
                let inv = 1.0 / sv;
                for x in v.iter_mut() {
                    *x *= inv;
                }
            }

            // Store the r-th singular triplet (column r of U, row r of V).
            s_data[r] = sigma;
            for i in 0..m {
                u_data[i * k + r] = u[i];
            }
            for j in 0..n {
                v_data[r * n + j] = v[j];
            }

            // Deflate: work -= sigma * u * v^T, so the next iteration
            // converges to the next-largest singular component.
            if sigma > POWER_ITER_EPS {
                for i in 0..m {
                    let us = u[i] * sigma;
                    let row = i * n;
                    for j in 0..n {
                        work[row + j] -= us * v[j];
                    }
                }
            }
        }
        Self {
            m,
            n,
            k,
            u_data,
            s_data,
            v_data,
        }
    }

    /// Compute the relative reconstruction error (Frobenius norm).
    ///
    /// Returns `||original - reconstructed|| / ||original||`.
    /// Returns 0.0 if the original has zero norm. If `original` is longer
    /// than the reconstruction, the missing tail is treated as zeros.
    pub fn reconstruction_error(&self, original: &[f32]) -> f32 {
        let reconstructed = self.reconstruct();
        let mut diff_sq = 0.0f32;
        let mut orig_sq = 0.0f32;
        for (i, &o) in original.iter().enumerate() {
            let r = if i < reconstructed.len() {
                reconstructed[i]
            } else {
                0.0
            };
            diff_sq += (o - r) * (o - r);
            orig_sq += o * o;
        }
        if orig_sq < 1e-30 {
            return 0.0;
        }
        (diff_sq / orig_sq).sqrt()
    }

    /// Estimate the fraction of total energy (Frobenius norm) captured by factors.
    ///
    /// Uses `sum(s_i^2)` as captured energy. Requires the original data to compute
    /// total energy as `||data||_F^2`. Returns 1.0 if total energy is near zero.
    pub fn energy_captured(&self, original: &[f32]) -> f32 {
        let total_energy: f32 = original.iter().map(|x| x * x).sum();
        if total_energy < 1e-30 {
            return 1.0;
        }
        let captured: f32 = self.s_data.iter().map(|s| s * s).sum();
        // Clamp: approximate sigma estimates can slightly exceed the total.
        (captured / total_energy).min(1.0)
    }

    /// Compression ratio: original_elements * 4 bytes / storage_bytes.
    ///
    /// Returns 0.0 if storage_bytes is zero.
    pub fn compression_ratio(&self, original_elements: usize) -> f32 {
        let raw = original_elements * 4;
        let stored = self.storage_bytes();
        if stored == 0 {
            return 0.0;
        }
        raw as f32 / stored as f32
    }

    /// Create factors with adaptive rank selection.
    ///
    /// Starts with rank 1 and increases until either `max_rank` is reached or
    /// the reconstruction error falls below `target_error`.
    ///
    /// NOTE(review): each candidate rank recomputes the factorization from
    /// scratch (up to `max_k` full power-iteration SVDs); acceptable for the
    /// moderate sizes this module targets.
    pub fn from_data_adaptive(
        data: &[f32],
        rows: usize,
        cols: usize,
        max_rank: usize,
        target_error: f32,
    ) -> Self {
        let max_k = max_rank.min(rows).min(cols);
        let mut best = Self::from_data(data, rows, cols, 1);
        for rank in 2..=max_k {
            // Check the error of the previous candidate before spending
            // another factorization on a higher rank.
            let err = best.reconstruction_error(data);
            if err <= target_error {
                break;
            }
            best = Self::from_data(data, rows, cols, rank);
        }
        best
    }
}
424
425/// Encode a [`DeltaRecord`] to bytes (little-endian, ADR-021 section 4.1).
426pub fn encode_delta(delta: &DeltaRecord) -> Vec<u8> {
427    let mut buf = Vec::with_capacity(DELTA_HEADER_BYTES + delta.entries.len() * DELTA_ENTRY_BYTES);
428    buf.extend_from_slice(&delta.header.tensor_id.to_le_bytes());
429    buf.extend_from_slice(&delta.header.block_index.to_le_bytes());
430    buf.extend_from_slice(&delta.header.base_epoch.to_le_bytes());
431    buf.extend_from_slice(&delta.header.nnz.to_le_bytes());
432    buf.extend_from_slice(&delta.delta_scale.to_le_bytes());
433    for entry in &delta.entries {
434        buf.extend_from_slice(&entry.index.to_le_bytes());
435        buf.extend_from_slice(&entry.value.to_le_bytes());
436    }
437    buf
438}
439
440/// Decode a [`DeltaRecord`] from bytes.
441///
442/// Returns `Err(StoreError::InvalidBlock)` on truncated or malformed input.
443pub fn decode_delta(data: &[u8]) -> Result<DeltaRecord, StoreError> {
444    if data.len() < DELTA_HEADER_BYTES {
445        return Err(StoreError::InvalidBlock);
446    }
447    let tensor_id = u128::from_le_bytes(
448        data[0..16]
449            .try_into()
450            .map_err(|_| StoreError::InvalidBlock)?,
451    );
452    let block_index = u32::from_le_bytes(
453        data[16..20]
454            .try_into()
455            .map_err(|_| StoreError::InvalidBlock)?,
456    );
457    let base_epoch = u64::from_le_bytes(
458        data[20..28]
459            .try_into()
460            .map_err(|_| StoreError::InvalidBlock)?,
461    );
462    let nnz = u16::from_le_bytes(
463        data[28..30]
464            .try_into()
465            .map_err(|_| StoreError::InvalidBlock)?,
466    );
467    let delta_scale = f32::from_le_bytes(
468        data[30..34]
469            .try_into()
470            .map_err(|_| StoreError::InvalidBlock)?,
471    );
472
473    if data.len() < DELTA_HEADER_BYTES + (nnz as usize) * DELTA_ENTRY_BYTES {
474        return Err(StoreError::InvalidBlock);
475    }
476    let mut entries = Vec::with_capacity(nnz as usize);
477    let mut off = DELTA_HEADER_BYTES;
478    for _ in 0..nnz {
479        let index = u16::from_le_bytes(
480            data[off..off + 2]
481                .try_into()
482                .map_err(|_| StoreError::InvalidBlock)?,
483        );
484        let value = i16::from_le_bytes(
485            data[off + 2..off + 4]
486                .try_into()
487                .map_err(|_| StoreError::InvalidBlock)?,
488        );
489        entries.push(SparseEntry { index, value });
490        off += DELTA_ENTRY_BYTES;
491    }
492
493    Ok(DeltaRecord {
494        header: DeltaHeader {
495            tensor_id,
496            block_index,
497            base_epoch,
498            nnz,
499        },
500        delta_scale,
501        entries,
502    })
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: builds a DeltaRecord with fixed header fields (tensor 42,
    // block 0, epoch 1) from (index, quantized value) pairs and a scale.
    fn make_delta(entries: Vec<(u16, i16)>, scale: f32) -> DeltaRecord {
        let sparse: Vec<SparseEntry> = entries
            .iter()
            .map(|&(i, v)| SparseEntry { index: i, value: v })
            .collect();
        DeltaRecord {
            header: DeltaHeader {
                tensor_id: 42,
                block_index: 0,
                base_epoch: 1,
                nnz: sparse.len() as u16,
            },
            delta_scale: scale,
            entries: sparse,
        }
    }

    // --- compute_delta / apply_delta ---

    #[test]
    fn test_compute_delta_small_change() {
        let old = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut new = old.clone();
        new[2] = 3.5;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        assert_eq!(d.entries.len(), 1);
        assert_eq!(d.entries[0].index, 2);
        assert!(d.delta_scale > 0.0);
    }

    #[test]
    fn test_compute_delta_large_change_returns_none() {
        // Every element changes: density exceeds max_change_fraction.
        let old = vec![1.0; 10];
        let new = vec![5.0; 10];
        assert!(compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).is_none());
    }

    #[test]
    fn test_apply_delta_modifies_base() {
        let mut base = vec![1.0, 2.0, 3.0, 4.0];
        apply_delta(&mut base, &make_delta(vec![(1, 100), (3, -50)], 0.01));
        assert!((base[0] - 1.0).abs() < 1e-6);
        assert!((base[1] - 3.0).abs() < 1e-6); // 2.0 + 100*0.01
        assert!((base[2] - 3.0).abs() < 1e-6);
        assert!((base[3] - 3.5).abs() < 1e-6); // 4.0 - 50*0.01
    }

    // --- DeltaChain ---

    #[test]
    fn test_chain_append_and_reconstruct() {
        let mut chain = DeltaChain::new(vec![1.0, 2.0, 3.0, 4.0], 4);
        chain.append(make_delta(vec![(0, 1000)], 0.001)).unwrap(); // +1.0
        assert_eq!(chain.chain_len(), 1);
        let r = chain.reconstruct();
        assert!((r[0] - 2.0).abs() < 1e-3);
        assert!((r[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_chain_compact_preserves_state() {
        let mut chain = DeltaChain::new(vec![0.0; 4], 8);
        chain.append(make_delta(vec![(0, 100)], 0.1)).unwrap(); // +10.0
        chain.append(make_delta(vec![(1, 200)], 0.1)).unwrap(); // +20.0
        let before = chain.reconstruct();
        chain.compact();
        assert_eq!(chain.chain_len(), 0);
        let after = chain.reconstruct();
        for (a, b) in before.iter().zip(after.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_chain_max_length_enforcement() {
        let mut chain = DeltaChain::new(vec![1.0; 4], 2);
        assert!(chain.append(make_delta(vec![(0, 1)], 0.1)).is_ok());
        assert!(chain.append(make_delta(vec![(1, 1)], 0.1)).is_ok());
        assert!(chain.append(make_delta(vec![(2, 1)], 0.1)).is_err());
    }

    #[test]
    fn test_chain_needs_compaction() {
        let mut chain = DeltaChain::new(vec![1.0; 4], 2);
        assert!(!chain.needs_compaction());
        chain.append(make_delta(vec![(0, 1)], 0.1)).unwrap();
        assert!(!chain.needs_compaction());
        chain.append(make_delta(vec![(1, 1)], 0.1)).unwrap();
        assert!(chain.needs_compaction());
    }

    // --- FactorSet ---

    #[test]
    fn test_factor_reconstruct() {
        // Hand-built rank-1 factor set: every element is u[i] * s * v[j].
        let (u, v, s) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0], 2.0);
        let f = FactorSet {
            m: 3,
            n: 2,
            k: 1,
            u_data: u.clone(),
            s_data: vec![s],
            v_data: v.clone(),
        };
        let r = f.reconstruct();
        assert_eq!(r.len(), 6);
        for i in 0..3 {
            for j in 0..2 {
                assert!((r[i * 2 + j] - u[i] * s * v[j]).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_factor_from_data_approximation() {
        // Outer-product data (i+1)*(j+1) is exactly rank 1.
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n)
            .map(|idx| {
                let (i, j) = (idx / n, idx % n);
                (i as f32 + 1.0) * (j as f32 + 1.0)
            })
            .collect();
        let reconstructed = FactorSet::from_data(&data, m, n, 1).reconstruct();
        let max_err = data
            .iter()
            .zip(reconstructed.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_err < 0.5,
            "max error {max_err} too large for rank-1 input"
        );
    }

    // --- encode_delta / decode_delta ---

    #[test]
    fn test_encode_decode_roundtrip() {
        let orig = DeltaRecord {
            header: DeltaHeader {
                tensor_id: 0xDEADBEEFCAFEBABE,
                block_index: 42,
                base_epoch: 100,
                nnz: 3,
            },
            delta_scale: 0.001,
            entries: vec![
                SparseEntry {
                    index: 10,
                    value: 500,
                },
                SparseEntry {
                    index: 20,
                    value: -300,
                },
                SparseEntry {
                    index: 30,
                    value: 1,
                },
            ],
        };
        let bytes = encode_delta(&orig);
        assert_eq!(bytes.len(), DELTA_HEADER_BYTES + 3 * DELTA_ENTRY_BYTES);
        let dec = decode_delta(&bytes).unwrap();
        assert_eq!(dec.header.tensor_id, orig.header.tensor_id);
        assert_eq!(dec.header.block_index, orig.header.block_index);
        assert_eq!(dec.header.nnz, orig.header.nnz);
        assert!((dec.delta_scale - orig.delta_scale).abs() < 1e-10);
        for (a, b) in dec.entries.iter().zip(orig.entries.iter()) {
            assert_eq!(a.index, b.index);
            assert_eq!(a.value, b.value);
        }
    }

    #[test]
    fn test_decode_truncated_header() {
        // 20 bytes is shorter than the 34-byte fixed header.
        assert!(decode_delta(&vec![0u8; 20]).is_err());
    }

    #[test]
    fn test_decode_truncated_entries() {
        let mut bytes = encode_delta(&make_delta(vec![(0, 1), (1, 2)], 1.0));
        // Overwrite the little-endian nnz field (bytes 28..30).
        bytes[28] = 5;
        bytes[29] = 0; // claim 5 entries, only 2 present
        assert!(decode_delta(&bytes).is_err());
    }

    #[test]
    fn test_empty_delta_roundtrip() {
        let d = DeltaRecord {
            header: DeltaHeader {
                tensor_id: 99,
                block_index: 7,
                base_epoch: 50,
                nnz: 0,
            },
            delta_scale: 0.0,
            entries: Vec::new(),
        };
        let dec = decode_delta(&encode_delta(&d)).unwrap();
        assert_eq!(dec.entries.len(), 0);
    }

    #[test]
    fn test_single_entry_delta() {
        let old = vec![1.0; 100];
        let mut new = old.clone();
        new[50] = 2.0;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        assert_eq!(d.entries.len(), 1);
        assert_eq!(d.entries[0].index, 50);
        let mut base = old.clone();
        apply_delta(&mut base, &d);
        assert!((base[50] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_full_density_delta() {
        // max_change_fraction > 1.0 lets every element change.
        let old = vec![0.0; 4];
        let new = vec![0.1, 0.2, 0.3, 0.4];
        let d = compute_delta(&old, &new, 1, 0, 0, 0.001, 1.1).unwrap();
        assert_eq!(d.entries.len(), 4);
        let mut base = old.clone();
        apply_delta(&mut base, &d);
        for i in 0..4 {
            assert!((base[i] - new[i]).abs() < 0.01, "index {i}");
        }
    }

    #[test]
    fn test_compute_apply_roundtrip_64() {
        let old: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let mut new = old.clone();
        new[5] += 0.5;
        new[10] -= 0.3;
        new[60] += 1.0;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        let mut recon = old.clone();
        apply_delta(&mut recon, &d);
        for i in 0..64 {
            assert!((recon[i] - new[i]).abs() < 0.01, "index {i}");
        }
    }

    #[test]
    fn test_reconstruction_error_zero_for_exact() {
        // Rank-1 data should be exactly reconstructed with rank-1 factors
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n)
            .map(|idx| {
                let (i, j) = (idx / n, idx % n);
                (i as f32 + 1.0) * (j as f32 + 1.0)
            })
            .collect();
        let factors = FactorSet::from_data(&data, m, n, 1);
        let err = factors.reconstruction_error(&data);
        assert!(err < 0.01, "err={err} too large for rank-1 data");
    }

    #[test]
    fn test_reconstruction_error_decreases_with_rank() {
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n).map(|i| (i as f32 * 0.7).sin()).collect();
        let err1 = FactorSet::from_data(&data, m, n, 1).reconstruction_error(&data);
        let err3 = FactorSet::from_data(&data, m, n, 3).reconstruction_error(&data);
        assert!(err3 <= err1 + 1e-6, "err3={err3} > err1={err1}");
    }

    #[test]
    fn test_energy_captured_rank1_data() {
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n)
            .map(|idx| {
                let (i, j) = (idx / n, idx % n);
                (i as f32 + 1.0) * (j as f32 + 1.0)
            })
            .collect();
        let factors = FactorSet::from_data(&data, m, n, 1);
        let energy = factors.energy_captured(&data);
        assert!(energy > 0.95, "energy={energy} too low for rank-1 data");
    }

    #[test]
    fn test_compression_ratio_meaningful() {
        let (m, n) = (16, 16);
        let data: Vec<f32> = (0..m * n).map(|i| i as f32).collect();
        let factors = FactorSet::from_data(&data, m, n, 2);
        let ratio = factors.compression_ratio(m * n);
        // rank-2 storage: (16*2 + 2 + 2*16) * 4 = 264 bytes vs 16*16*4 = 1024 bytes
        assert!(ratio > 1.0, "ratio={ratio} should be > 1");
    }

    #[test]
    fn test_from_data_adaptive_stops_early() {
        let (m, n) = (4, 3);
        // Rank-1 data: adaptive should stop at rank 1
        let data: Vec<f32> = (0..m * n)
            .map(|idx| {
                let (i, j) = (idx / n, idx % n);
                (i as f32 + 1.0) * (j as f32 + 1.0)
            })
            .collect();
        let factors = FactorSet::from_data_adaptive(&data, m, n, 5, 0.05);
        // Should use rank 1 since data is rank 1
        assert!(
            factors.k <= 2,
            "k={} should be small for rank-1 data",
            factors.k
        );
    }

    #[test]
    fn test_from_data_adaptive_increases_rank() {
        let (m, n) = (8, 6);
        // Multi-rank data
        let data: Vec<f32> = (0..m * n)
            .map(|i| (i as f32 * 0.3).sin() + (i as f32 * 0.7).cos())
            .collect();
        let factors = FactorSet::from_data_adaptive(&data, m, n, 6, 0.01);
        let err = factors.reconstruction_error(&data);
        // Should achieve close to target error or use max rank
        assert!(err < 0.1 || factors.k == 6, "err={err}, k={}", factors.k);
    }
}