ruvector_temporal_tensor/delta.rs

//! Delta compression, delta chains, and reconstruction policies (ADR-021).
//!
//! Sparse delta encoding for incremental tensor updates, bounded-depth delta
//! chain management with automatic compaction, and SVD-based low-rank factor
//! reconstruction. All structures are WASM-safe (no `f64` in hot paths).
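//!
//! A minimal end-to-end sketch (not compiled as a doctest; assumes the items
//! defined in this module are in scope):
//!
//! ```ignore
//! let old = vec![1.0f32, 2.0, 3.0, 4.0];
//! let mut new = old.clone();
//! new[2] = 3.5;
//! // `None` would mean the change is too dense to store as a sparse delta.
//! if let Some(delta) = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5) {
//!     let mut reconstructed = old.clone();
//!     apply_delta(&mut reconstructed, &delta);
//!     // `reconstructed` now approximates `new` up to quantization error.
//! }
//! ```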

use crate::store::StoreError;

#[allow(unused_imports)]
use crate::store::{BlockKey, ReconstructPolicy};

/// Size of the fixed portion of a serialized delta (header + scale).
const DELTA_HEADER_BYTES: usize = 34;
/// Size of a single serialized sparse entry (index: u16 + value: i16).
const DELTA_ENTRY_BYTES: usize = 4;
/// Maximum power-iteration steps per singular component.
const POWER_ITER_MAX: usize = 30;
/// Convergence threshold for power iteration.
const POWER_ITER_EPS: f32 = 1e-10;

/// Header for a delta record.
#[derive(Clone, Debug)]
pub struct DeltaHeader {
    pub tensor_id: u128,
    pub block_index: u32,
    pub base_epoch: u64,
    pub nnz: u16,
}

/// A single sparse delta entry: index + quantized value.
#[derive(Clone, Copy, Debug)]
pub struct SparseEntry {
    pub index: u16,
    pub value: i16,
}

/// Complete delta record: header + sparse entries + scale.
///
/// Actual diff = `entry.value as f32 * delta_scale`.
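///
/// For example, with `delta_scale = 0.001`, an entry value of `500` encodes a diff of `0.5`.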
#[derive(Clone, Debug)]
pub struct DeltaRecord {
    pub header: DeltaHeader,
    pub delta_scale: f32,
    pub entries: Vec<SparseEntry>,
}

/// Compute a sparse delta between `old` and `new` data.
///
/// Keeps entries whose absolute change meets or exceeds `threshold`. Returns `None`
/// if the changed fraction meets or exceeds `max_change_fraction`.
///
/// # Panics
///
/// Panics if `old.len() != new.len()`.
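///
/// # Examples
///
/// A sketch of the sparse path (not compiled as a doctest; items assumed in scope):
///
/// ```ignore
/// let old = vec![0.0f32; 8];
/// let mut new = old.clone();
/// new[3] = 0.25;
/// let delta = compute_delta(&old, &new, 7, 0, 0, 0.01, 0.5).expect("change is sparse");
/// assert_eq!(delta.entries.len(), 1);
/// assert_eq!(delta.entries[0].index, 3);
/// ```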
pub fn compute_delta(
    old: &[f32],
    new: &[f32],
    tensor_id: u128,
    block_index: u32,
    base_epoch: u64,
    threshold: f32,
    max_change_fraction: f32,
) -> Option<DeltaRecord> {
    assert_eq!(old.len(), new.len(), "old and new must have equal length");
    let n = old.len();
    if n == 0 {
        return Some(DeltaRecord {
            header: DeltaHeader { tensor_id, block_index, base_epoch, nnz: 0 },
            delta_scale: 0.0,
            entries: Vec::new(),
        });
    }

    let mut changed: Vec<(u16, f32)> = Vec::new();
    let mut max_abs = 0.0f32;
    for i in 0..n {
        let diff = new[i] - old[i];
        if diff.abs() >= threshold {
            changed.push((i as u16, diff));
            if diff.abs() > max_abs { max_abs = diff.abs(); }
        }
    }

    if changed.len() as f32 / n as f32 >= max_change_fraction {
        return None;
    }

    let delta_scale = if max_abs == 0.0 { 1.0 } else { max_abs / i16::MAX as f32 };
    let inv_scale = 1.0 / delta_scale;
    let entries: Vec<SparseEntry> = changed
        .iter()
        .map(|&(idx, diff)| {
            let q = (diff * inv_scale).round() as i32;
            SparseEntry { index: idx, value: q.clamp(i16::MIN as i32, i16::MAX as i32) as i16 }
        })
        .collect();

    Some(DeltaRecord {
        header: DeltaHeader { tensor_id, block_index, base_epoch, nnz: entries.len() as u16 },
        delta_scale,
        entries,
    })
}

/// Apply a delta to a base data slice in place.
///
/// Entries whose indices fall outside the base slice are silently skipped.
pub fn apply_delta(base: &mut [f32], delta: &DeltaRecord) {
    let scale = delta.delta_scale;
    for entry in &delta.entries {
        let idx = entry.index as usize;
        if idx < base.len() {
            base[idx] += entry.value as f32 * scale;
        }
    }
}

/// A chain of deltas applied to a base block.
///
/// Invariant: `deltas.len() <= max_chain_len`.
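///
/// A usage sketch (not compiled as a doctest; `delta` stands for a previously
/// computed [`DeltaRecord`]):
///
/// ```ignore
/// let mut chain = DeltaChain::new(vec![0.0f32; 4], 4);
/// chain.append(delta).unwrap();          // errors with DeltaChainTooLong once at max length
/// let current = chain.reconstruct();     // base with all deltas applied
/// if chain.needs_compaction() {
///     chain.compact();                   // fold deltas into the base and clear the chain
/// }
/// ```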
#[derive(Clone, Debug)]
pub struct DeltaChain {
    base_data: Vec<f32>,
    deltas: Vec<DeltaRecord>,
    max_chain_len: u8,
}

impl DeltaChain {
    /// Create a new chain with a base block.
    pub fn new(base_data: Vec<f32>, max_chain_len: u8) -> Self {
        Self { base_data, deltas: Vec::new(), max_chain_len }
    }

    /// Append a delta. Returns `Err(StoreError::DeltaChainTooLong)` at max length.
    pub fn append(&mut self, delta: DeltaRecord) -> Result<(), StoreError> {
        if self.deltas.len() >= self.max_chain_len as usize {
            return Err(StoreError::DeltaChainTooLong);
        }
        self.deltas.push(delta);
        Ok(())
    }

    /// Reconstruct the current state by applying all deltas to the base.
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut result = self.base_data.clone();
        for delta in &self.deltas {
            apply_delta(&mut result, delta);
        }
        result
    }

    /// Compact the chain: apply all deltas to base, clear delta list.
    pub fn compact(&mut self) {
        if self.deltas.is_empty() { return; }
        for delta in &self.deltas {
            apply_delta(&mut self.base_data, delta);
        }
        self.deltas.clear();
    }

    /// Number of deltas in the chain.
    #[inline]
    pub fn chain_len(&self) -> usize { self.deltas.len() }

    /// Whether the chain needs compaction (at max length).
    #[inline]
    pub fn needs_compaction(&self) -> bool {
        self.deltas.len() >= self.max_chain_len as usize
    }

    /// Total storage bytes: base + serialized size of all deltas.
    pub fn total_bytes(&self) -> usize {
        let base_bytes = self.base_data.len() * 4;
        let delta_bytes: usize = self.deltas.iter()
            .map(|d| DELTA_HEADER_BYTES + d.entries.len() * DELTA_ENTRY_BYTES)
            .sum();
        base_bytes + delta_bytes
    }
}

/// Low-rank factor representation for reconstruction.
///
/// Stores U (m x k), S (k), V (k x n) such that data ~ U * diag(S) * V.
/// All matrices are row-major.
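///
/// For example, with `m = 2`, `n = 2`, `k = 1`, `u_data = [1, 2]`, `s_data = [3]`,
/// and `v_data = [4, 5]`, the reconstruction is `[12, 15, 24, 30]` in row-major
/// order, since `data[i * n + j] = u[i] * s[0] * v[j]`.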
#[derive(Clone, Debug)]
pub struct FactorSet {
    pub m: usize,
    pub n: usize,
    pub k: usize,
    pub u_data: Vec<f32>,  // m * k elements
    pub s_data: Vec<f32>,  // k elements
    pub v_data: Vec<f32>,  // k * n elements
}

impl FactorSet {
    /// Reconstruct the full data from factors: U * diag(S) * V.
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut out = vec![0.0f32; self.m * self.n];
        for r in 0..self.k {
            let s_r = self.s_data[r];
            for i in 0..self.m {
                let u_s = self.u_data[i * self.k + r] * s_r;
                let row = i * self.n;
                let v_off = r * self.n;
                for j in 0..self.n {
                    out[row + j] += u_s * self.v_data[v_off + j];
                }
            }
        }
        out
    }

    /// Compute storage size in bytes: (m*k + k + k*n) * 4.
    pub fn storage_bytes(&self) -> usize {
        (self.m * self.k + self.k + self.k * self.n) * 4
    }

    /// Create from a flat data vector using truncated SVD via power iteration.
    ///
    /// Simplified implementation suitable for moderate-sized matrices.
    /// Extracts top-`rank` singular triplets with successive deflation.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
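    ///
    /// # Examples
    ///
    /// A sketch with exactly rank-1 data (not compiled as a doctest; items assumed
    /// in scope):
    ///
    /// ```ignore
    /// // data[i][j] = (i + 1) * (j + 1) has rank 1, so one component captures it.
    /// let data: Vec<f32> = (0..12).map(|idx| ((idx / 4 + 1) * (idx % 4 + 1)) as f32).collect();
    /// let factors = FactorSet::from_data(&data, 3, 4, 1);
    /// assert!(factors.reconstruction_error(&data) < 0.01);
    /// ```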
    pub fn from_data(data: &[f32], rows: usize, cols: usize, rank: usize) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        let (m, n) = (rows, cols);
        let k = rank.min(m).min(n);
        let mut work = data.to_vec();
        let mut u_data = vec![0.0f32; m * k];
        let mut s_data = vec![0.0f32; k];
        let mut v_data = vec![0.0f32; k * n];

        for r in 0..k {
            // Deterministic initial vector: Fibonacci-hash sign pattern.
            let inv_sqrt_n = 1.0 / (n as f32).sqrt();
            let mut v = vec![0.0f32; n];
            for j in 0..n {
                let seed = (j as u32).wrapping_mul(2_654_435_761)
                    .wrapping_add((r as u32).wrapping_mul(0x9E37_79B9));
                v[j] = if seed & 1 == 0 { inv_sqrt_n } else { -inv_sqrt_n };
            }
            let mut u = vec![0.0f32; m];
            let mut sigma = 0.0f32;

            for _ in 0..POWER_ITER_MAX {
                // u = work * v
                for i in 0..m {
                    let mut acc = 0.0f32;
                    let row = i * n;
                    for j in 0..n { acc += work[row + j] * v[j]; }
                    u[i] = acc;
                }
                let su: f32 = u.iter().map(|x| x * x).sum::<f32>().sqrt();
                if su < POWER_ITER_EPS { sigma = 0.0; break; }
                let inv = 1.0 / su;
                for x in u.iter_mut() { *x *= inv; }

                // v = work^T * u
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for i in 0..m { acc += work[i * n + j] * u[i]; }
                    v[j] = acc;
                }
                let sv: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                if sv < POWER_ITER_EPS { sigma = su; break; }
                sigma = sv;
                let inv = 1.0 / sv;
                for x in v.iter_mut() { *x *= inv; }
            }

            s_data[r] = sigma;
            for i in 0..m { u_data[i * k + r] = u[i]; }
            for j in 0..n { v_data[r * n + j] = v[j]; }

            // Deflate: work -= sigma * u * v^T
            if sigma > POWER_ITER_EPS {
                for i in 0..m {
                    let us = u[i] * sigma;
                    let row = i * n;
                    for j in 0..n { work[row + j] -= us * v[j]; }
                }
            }
        }
        Self { m, n, k, u_data, s_data, v_data }
    }

    /// Compute the relative reconstruction error (Frobenius norm).
    ///
    /// Returns `||original - reconstructed|| / ||original||`.
    /// Returns 0.0 if the original has zero norm.
    pub fn reconstruction_error(&self, original: &[f32]) -> f32 {
        let reconstructed = self.reconstruct();
        let mut diff_sq = 0.0f32;
        let mut orig_sq = 0.0f32;
        for (i, &o) in original.iter().enumerate() {
            let r = if i < reconstructed.len() { reconstructed[i] } else { 0.0 };
            diff_sq += (o - r) * (o - r);
            orig_sq += o * o;
        }
        if orig_sq < 1e-30 {
            return 0.0;
        }
        (diff_sq / orig_sq).sqrt()
    }

    /// Estimate the fraction of total energy (squared Frobenius norm) captured by the factors.
    ///
    /// Uses `sum(s_i^2)` as the captured energy. Requires the original data to compute
    /// the total energy as `||data||_F^2`. Returns 1.0 if the total energy is near zero.
    pub fn energy_captured(&self, original: &[f32]) -> f32 {
        let total_energy: f32 = original.iter().map(|x| x * x).sum();
        if total_energy < 1e-30 {
            return 1.0;
        }
        let captured: f32 = self.s_data.iter().map(|s| s * s).sum();
        (captured / total_energy).min(1.0)
    }

    /// Compression ratio: original_elements * 4 bytes / storage_bytes.
    ///
    /// Returns 0.0 if storage_bytes is zero.
    pub fn compression_ratio(&self, original_elements: usize) -> f32 {
        let raw = original_elements * 4;
        let stored = self.storage_bytes();
        if stored == 0 {
            return 0.0;
        }
        raw as f32 / stored as f32
    }

    /// Create factors with adaptive rank selection.
    ///
    /// Starts at rank 1 and increases the rank until the relative reconstruction
    /// error is at or below `target_error`, or `max_rank` (clamped to the matrix
    /// dimensions) is reached.
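    ///
    /// A sketch (not compiled as a doctest; `data`, `rows`, and `cols` are placeholders):
    ///
    /// ```ignore
    /// let factors = FactorSet::from_data_adaptive(&data, rows, cols, 8, 0.05);
    /// // `factors.k` grows from 1 until the relative error meets the target,
    /// // up to at most `max_rank.min(rows).min(cols)`.
    /// ```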
    pub fn from_data_adaptive(
        data: &[f32],
        rows: usize,
        cols: usize,
        max_rank: usize,
        target_error: f32,
    ) -> Self {
        let max_k = max_rank.min(rows).min(cols);
        let mut best = Self::from_data(data, rows, cols, 1);
        for rank in 2..=max_k {
            let err = best.reconstruction_error(data);
            if err <= target_error {
                break;
            }
            best = Self::from_data(data, rows, cols, rank);
        }
        best
    }
}

/// Encode a [`DeltaRecord`] to bytes (little-endian, ADR-021 section 4.1).
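///
/// Serialized layout (byte offsets, mirroring [`decode_delta`]): `0..16` `tensor_id` (`u128`),
/// `16..20` `block_index` (`u32`), `20..28` `base_epoch` (`u64`), `28..30` `nnz` (`u16`),
/// `30..34` `delta_scale` (`f32`), followed by `nnz` entries of (`index: u16`, `value: i16`).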
pub fn encode_delta(delta: &DeltaRecord) -> Vec<u8> {
    let mut buf = Vec::with_capacity(DELTA_HEADER_BYTES + delta.entries.len() * DELTA_ENTRY_BYTES);
    buf.extend_from_slice(&delta.header.tensor_id.to_le_bytes());
    buf.extend_from_slice(&delta.header.block_index.to_le_bytes());
    buf.extend_from_slice(&delta.header.base_epoch.to_le_bytes());
    buf.extend_from_slice(&delta.header.nnz.to_le_bytes());
    buf.extend_from_slice(&delta.delta_scale.to_le_bytes());
    for entry in &delta.entries {
        buf.extend_from_slice(&entry.index.to_le_bytes());
        buf.extend_from_slice(&entry.value.to_le_bytes());
    }
    buf
}

/// Decode a [`DeltaRecord`] from bytes.
///
/// Returns `Err(StoreError::InvalidBlock)` on truncated or malformed input.
pub fn decode_delta(data: &[u8]) -> Result<DeltaRecord, StoreError> {
    if data.len() < DELTA_HEADER_BYTES { return Err(StoreError::InvalidBlock); }
    let tensor_id = u128::from_le_bytes(data[0..16].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let block_index = u32::from_le_bytes(data[16..20].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let base_epoch = u64::from_le_bytes(data[20..28].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let nnz = u16::from_le_bytes(data[28..30].try_into().map_err(|_| StoreError::InvalidBlock)?);
    let delta_scale = f32::from_le_bytes(data[30..34].try_into().map_err(|_| StoreError::InvalidBlock)?);

    if data.len() < DELTA_HEADER_BYTES + (nnz as usize) * DELTA_ENTRY_BYTES {
        return Err(StoreError::InvalidBlock);
    }
    let mut entries = Vec::with_capacity(nnz as usize);
    let mut off = DELTA_HEADER_BYTES;
    for _ in 0..nnz {
        let index = u16::from_le_bytes(data[off..off + 2].try_into().map_err(|_| StoreError::InvalidBlock)?);
        let value = i16::from_le_bytes(data[off + 2..off + 4].try_into().map_err(|_| StoreError::InvalidBlock)?);
        entries.push(SparseEntry { index, value });
        off += DELTA_ENTRY_BYTES;
    }

    Ok(DeltaRecord {
        header: DeltaHeader { tensor_id, block_index, base_epoch, nnz },
        delta_scale,
        entries,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_delta(entries: Vec<(u16, i16)>, scale: f32) -> DeltaRecord {
        let sparse: Vec<SparseEntry> = entries.iter()
            .map(|&(i, v)| SparseEntry { index: i, value: v }).collect();
        DeltaRecord {
            header: DeltaHeader { tensor_id: 42, block_index: 0, base_epoch: 1, nnz: sparse.len() as u16 },
            delta_scale: scale,
            entries: sparse,
        }
    }

    #[test]
    fn test_compute_delta_small_change() {
        let old = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut new = old.clone();
        new[2] = 3.5;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        assert_eq!(d.entries.len(), 1);
        assert_eq!(d.entries[0].index, 2);
        assert!(d.delta_scale > 0.0);
    }

    #[test]
    fn test_compute_delta_large_change_returns_none() {
        let old = vec![1.0; 10];
        let new = vec![5.0; 10];
        assert!(compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).is_none());
    }

    #[test]
    fn test_apply_delta_modifies_base() {
        let mut base = vec![1.0, 2.0, 3.0, 4.0];
        apply_delta(&mut base, &make_delta(vec![(1, 100), (3, -50)], 0.01));
        assert!((base[0] - 1.0).abs() < 1e-6);
        assert!((base[1] - 3.0).abs() < 1e-6); // 2.0 + 100*0.01
        assert!((base[2] - 3.0).abs() < 1e-6);
        assert!((base[3] - 3.5).abs() < 1e-6); // 4.0 - 50*0.01
    }

    #[test]
    fn test_chain_append_and_reconstruct() {
        let mut chain = DeltaChain::new(vec![1.0, 2.0, 3.0, 4.0], 4);
        chain.append(make_delta(vec![(0, 1000)], 0.001)).unwrap(); // +1.0
        assert_eq!(chain.chain_len(), 1);
        let r = chain.reconstruct();
        assert!((r[0] - 2.0).abs() < 1e-3);
        assert!((r[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_chain_compact_preserves_state() {
        let mut chain = DeltaChain::new(vec![0.0; 4], 8);
        chain.append(make_delta(vec![(0, 100)], 0.1)).unwrap(); // +10.0
        chain.append(make_delta(vec![(1, 200)], 0.1)).unwrap(); // +20.0
        let before = chain.reconstruct();
        chain.compact();
        assert_eq!(chain.chain_len(), 0);
        let after = chain.reconstruct();
        for (a, b) in before.iter().zip(after.iter()) { assert!((a - b).abs() < 1e-6); }
    }

    #[test]
    fn test_chain_max_length_enforcement() {
        let mut chain = DeltaChain::new(vec![1.0; 4], 2);
        assert!(chain.append(make_delta(vec![(0, 1)], 0.1)).is_ok());
        assert!(chain.append(make_delta(vec![(1, 1)], 0.1)).is_ok());
        assert!(chain.append(make_delta(vec![(2, 1)], 0.1)).is_err());
    }

    #[test]
    fn test_chain_needs_compaction() {
        let mut chain = DeltaChain::new(vec![1.0; 4], 2);
        assert!(!chain.needs_compaction());
        chain.append(make_delta(vec![(0, 1)], 0.1)).unwrap();
        assert!(!chain.needs_compaction());
        chain.append(make_delta(vec![(1, 1)], 0.1)).unwrap();
        assert!(chain.needs_compaction());
    }

    #[test]
    fn test_factor_reconstruct() {
        let (u, v, s) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0], 2.0);
        let f = FactorSet { m: 3, n: 2, k: 1, u_data: u.clone(), s_data: vec![s], v_data: v.clone() };
        let r = f.reconstruct();
        assert_eq!(r.len(), 6);
        for i in 0..3 {
            for j in 0..2 {
                assert!((r[i * 2 + j] - u[i] * s * v[j]).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_factor_from_data_approximation() {
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let reconstructed = FactorSet::from_data(&data, m, n, 1).reconstruct();
        let max_err = data.iter().zip(reconstructed.iter())
            .map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_err < 0.5, "max error {max_err} too large for rank-1 input");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let orig = DeltaRecord {
            header: DeltaHeader { tensor_id: 0xDEADBEEFCAFEBABE, block_index: 42, base_epoch: 100, nnz: 3 },
            delta_scale: 0.001,
            entries: vec![
                SparseEntry { index: 10, value: 500 },
                SparseEntry { index: 20, value: -300 },
                SparseEntry { index: 30, value: 1 },
            ],
        };
        let bytes = encode_delta(&orig);
        assert_eq!(bytes.len(), DELTA_HEADER_BYTES + 3 * DELTA_ENTRY_BYTES);
        let dec = decode_delta(&bytes).unwrap();
        assert_eq!(dec.header.tensor_id, orig.header.tensor_id);
        assert_eq!(dec.header.block_index, orig.header.block_index);
        assert_eq!(dec.header.nnz, orig.header.nnz);
        assert!((dec.delta_scale - orig.delta_scale).abs() < 1e-10);
        for (a, b) in dec.entries.iter().zip(orig.entries.iter()) {
            assert_eq!(a.index, b.index);
            assert_eq!(a.value, b.value);
        }
    }

    #[test]
    fn test_decode_truncated_header() { assert!(decode_delta(&vec![0u8; 20]).is_err()); }

    #[test]
    fn test_decode_truncated_entries() {
        let mut bytes = encode_delta(&make_delta(vec![(0, 1), (1, 2)], 1.0));
        bytes[28] = 5; bytes[29] = 0; // claim 5 entries, only 2 present
        assert!(decode_delta(&bytes).is_err());
    }

    #[test]
    fn test_empty_delta_roundtrip() {
        let d = DeltaRecord {
            header: DeltaHeader { tensor_id: 99, block_index: 7, base_epoch: 50, nnz: 0 },
            delta_scale: 0.0, entries: Vec::new(),
        };
        let dec = decode_delta(&encode_delta(&d)).unwrap();
        assert_eq!(dec.entries.len(), 0);
    }

    #[test]
    fn test_single_entry_delta() {
        let old = vec![1.0; 100];
        let mut new = old.clone();
        new[50] = 2.0;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        assert_eq!(d.entries.len(), 1);
        assert_eq!(d.entries[0].index, 50);
        let mut base = old.clone();
        apply_delta(&mut base, &d);
        assert!((base[50] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_full_density_delta() {
        let old = vec![0.0; 4];
        let new = vec![0.1, 0.2, 0.3, 0.4];
        let d = compute_delta(&old, &new, 1, 0, 0, 0.001, 1.1).unwrap();
        assert_eq!(d.entries.len(), 4);
        let mut base = old.clone();
        apply_delta(&mut base, &d);
        for i in 0..4 { assert!((base[i] - new[i]).abs() < 0.01, "index {i}"); }
    }

    #[test]
    fn test_compute_apply_roundtrip_64() {
        let old: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let mut new = old.clone();
        new[5] += 0.5; new[10] -= 0.3; new[60] += 1.0;
        let d = compute_delta(&old, &new, 1, 0, 0, 0.01, 0.5).unwrap();
        let mut recon = old.clone();
        apply_delta(&mut recon, &d);
        for i in 0..64 { assert!((recon[i] - new[i]).abs() < 0.01, "index {i}"); }
    }

    #[test]
    fn test_reconstruction_error_zero_for_exact() {
        // Rank-1 data should be exactly reconstructed with rank-1 factors
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let factors = FactorSet::from_data(&data, m, n, 1);
        let err = factors.reconstruction_error(&data);
        assert!(err < 0.01, "err={err} too large for rank-1 data");
    }

    #[test]
    fn test_reconstruction_error_decreases_with_rank() {
        let (m, n) = (8, 6);
        let data: Vec<f32> = (0..m * n).map(|i| (i as f32 * 0.7).sin()).collect();
        let err1 = FactorSet::from_data(&data, m, n, 1).reconstruction_error(&data);
        let err3 = FactorSet::from_data(&data, m, n, 3).reconstruction_error(&data);
        assert!(err3 <= err1 + 1e-6, "err3={err3} > err1={err1}");
    }

    #[test]
    fn test_energy_captured_rank1_data() {
        let (m, n) = (4, 3);
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let factors = FactorSet::from_data(&data, m, n, 1);
        let energy = factors.energy_captured(&data);
        assert!(energy > 0.95, "energy={energy} too low for rank-1 data");
    }

    #[test]
    fn test_compression_ratio_meaningful() {
        let (m, n) = (16, 16);
        let data: Vec<f32> = (0..m * n).map(|i| i as f32).collect();
        let factors = FactorSet::from_data(&data, m, n, 2);
        let ratio = factors.compression_ratio(m * n);
        // rank-2 storage: (16*2 + 2 + 2*16) * 4 = 264 bytes vs 16*16*4 = 1024 bytes
        assert!(ratio > 1.0, "ratio={ratio} should be > 1");
    }

    #[test]
    fn test_from_data_adaptive_stops_early() {
        let (m, n) = (4, 3);
        // Rank-1 data: adaptive should stop at rank 1
        let data: Vec<f32> = (0..m * n).map(|idx| {
            let (i, j) = (idx / n, idx % n);
            (i as f32 + 1.0) * (j as f32 + 1.0)
        }).collect();
        let factors = FactorSet::from_data_adaptive(&data, m, n, 5, 0.05);
        // Should use rank 1 since data is rank 1
        assert!(factors.k <= 2, "k={} should be small for rank-1 data", factors.k);
    }

    #[test]
    fn test_from_data_adaptive_increases_rank() {
        let (m, n) = (8, 6);
        // Multi-rank data
        let data: Vec<f32> = (0..m * n).map(|i| (i as f32 * 0.3).sin() + (i as f32 * 0.7).cos()).collect();
        let factors = FactorSet::from_data_adaptive(&data, m, n, 6, 0.01);
        let err = factors.reconstruction_error(&data);
        // Should achieve close to target error or use max rank
        assert!(err < 0.1 || factors.k == 6, "err={err}, k={}", factors.k);
    }
}