// tensorlogic_scirs_backend/decomposition.rs

//! Tensor decomposition algorithms for the SciRS2 backend.
//!
//! Provides production-quality implementations of:
//! - **Truncated SVD** via deflation/power iteration (no BLAS/LAPACK dependency)
//! - **Mode-n unfolding / folding** for tensor-matrix conversions
//! - **Tucker-1** single-mode compression
//! - **CP / PARAFAC** via Alternating Least Squares (ALS) for 3-mode tensors
//! - **HOSVD** (Higher-Order SVD) multilinear compression
//!
//! All operations are pure Rust; no C or Fortran linkage is required.
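//!
//! # Example
//!
//! A minimal usage sketch. The `use` path below is assumed from the crate
//! layout rather than taken from a published API, so the block is `ignore`d.
//!
//! ```ignore
//! use scirs2_core::ndarray::{ArrayD, IxDyn};
//! use tensorlogic_scirs_backend::decomposition::{hosvd, tucker1};
//!
//! let tensor = ArrayD::<f64>::from_elem(IxDyn(&[6, 5, 4]), 1.0);
//!
//! // Compress mode 0 down to rank 3.
//! let t1 = tucker1(&tensor, 0, 3).expect("tucker1");
//! assert_eq!(t1.core.shape(), &[3, 5, 4][..]);
//!
//! // Multilinear compression of every mode.
//! let h = hosvd(&tensor, &[3, 2, 2]).expect("hosvd");
//! assert_eq!(h.core.shape(), &[3, 2, 2][..]);
//! ```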

use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
use std::fmt;

// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------

/// Errors that can arise during tensor decomposition.
#[derive(Debug, Clone)]
pub enum DecompositionError {
    /// Input shape is incompatible with the requested operation.
    ShapeError(String),
    /// Iterative algorithm did not converge within the allowed iterations.
    ConvergenceFailure { iterations: usize, residual: f64 },
    /// A matrix is numerically singular and cannot be inverted.
    SingularMatrix,
    /// Requested rank exceeds the maximum feasible rank for the dimension.
    InvalidRank { rank: usize, max_rank: usize },
    /// Operation applied to an empty tensor (at least one dim is 0).
    EmptyTensor,
    /// Operation requires a matrix (2-D array) but received an n-D tensor.
    NonMatrixInput { ndim: usize },
}

impl fmt::Display for DecompositionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ShapeError(msg) => write!(f, "Shape error: {msg}"),
            Self::ConvergenceFailure {
                iterations,
                residual,
            } => write!(
                f,
                "Convergence failure after {iterations} iterations (residual={residual:.3e})"
            ),
            Self::SingularMatrix => write!(f, "Matrix is numerically singular"),
            Self::InvalidRank { rank, max_rank } => {
                write!(f, "Invalid rank {rank}: must be in 1..={max_rank}")
            }
            Self::EmptyTensor => write!(f, "Tensor has at least one zero-length dimension"),
            Self::NonMatrixInput { ndim } => {
                write!(f, "Expected a 2-D matrix, got a {ndim}-D tensor")
            }
        }
    }
}

impl std::error::Error for DecompositionError {}

// ---------------------------------------------------------------------------
// Truncated SVD
// ---------------------------------------------------------------------------

/// Result of a truncated SVD decomposition on a 2-D matrix.
#[derive(Debug, Clone)]
pub struct TruncatedSvd {
    /// Left singular vectors `[m, k]`.
    pub u: Array2<f64>,
    /// Singular values `[k]`, in descending order.
    pub s: Array1<f64>,
    /// Right singular vectors (transposed) `[k, n]`.
    pub vt: Array2<f64>,
    /// Number of singular components retained.
    pub rank: usize,
    /// Fraction of total variance explained: `Σ s[:k]² / Σ s_full²`.
    pub explained_variance_ratio: f64,
}

impl TruncatedSvd {
    /// Reconstruct the approximate matrix `U @ diag(s) @ Vt`.
    pub fn reconstruct(&self) -> Array2<f64> {
        let m = self.u.nrows();
        let n = self.vt.ncols();
        let mut result = Array2::<f64>::zeros((m, n));
        for i in 0..self.rank {
            let u_col = self.u.column(i);
            let vt_row = self.vt.row(i);
            let s_i = self.s[i];
            for r in 0..m {
                for c in 0..n {
                    result[[r, c]] += s_i * u_col[r] * vt_row[c];
                }
            }
        }
        result
    }

    /// Frobenius norm of `original - reconstructed`.
    pub fn reconstruction_error(&self, original: &Array2<f64>) -> f64 {
        let approx = self.reconstruct();
        let diff = original - &approx;
        frobenius_norm_2d(&diff)
    }
}

/// Compute truncated SVD using the deflation / power-iteration method.
///
/// The algorithm iterates over components one by one:
/// 1. Start with a random unit vector `v ∈ ℝⁿ`.
/// 2. Power-iterate: `v ← Mᵀ(Mv) / ‖Mᵀ(Mv)‖` for up to `n_iter` steps.
/// 3. Recover `u = Mv / ‖Mv‖`, `σ = ‖Mv‖` (equivalently `uᵀMv`).
/// 4. Deflate: `M ← M − σ·uvᵀ`.
///
/// # Parameters
/// - `matrix`: Input `[m, n]` matrix.
/// - `k`: Number of singular triplets to compute.
/// - `n_iter`: Power iterations per component (higher = more accurate; 20-30 is typical).
/// - `tol`: Convergence tolerance for power iteration.
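///
/// # Example
///
/// A minimal sketch (`ignore`d because the surrounding import path is an
/// assumption, not taken from this crate's published API):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// // A 6×4 matrix of rank 2; keep the top 2 singular triplets.
/// let m = Array2::from_shape_fn((6, 4), |(i, j)| (i + 2 * j) as f64);
/// let svd = truncated_svd(&m, 2, 20, 1e-10).expect("svd");
/// assert_eq!(svd.u.shape(), &[6, 2]);
/// assert_eq!(svd.s.len(), 2);
/// assert_eq!(svd.vt.shape(), &[2, 4]);
/// let approx = svd.reconstruct(); // U · diag(s) · Vᵀ, shape [6, 4]
/// ```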
pub fn truncated_svd(
    matrix: &Array2<f64>,
    k: usize,
    n_iter: usize,
    tol: f64,
) -> Result<TruncatedSvd, DecompositionError> {
    let (m, n) = (matrix.nrows(), matrix.ncols());
    if m == 0 || n == 0 {
        return Err(DecompositionError::EmptyTensor);
    }
    let max_rank = m.min(n);
    if k == 0 || k > max_rank {
        return Err(DecompositionError::InvalidRank { rank: k, max_rank });
    }

    // Compute full Frobenius norm for the explained-variance denominator
    let total_sq: f64 = matrix.iter().map(|x| x * x).sum();

    let mut u_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
    let mut s_vals: Vec<f64> = Vec::with_capacity(k);
    let mut v_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);

    // Working copy — we deflate this in-place
    let mut work = matrix.to_owned();

    for comp in 0..k {
        // Initialise v with a deterministic pseudo-random vector
        let mut v = init_vector(n, comp);
        normalize_vec(&mut v);

        let mut prev_sigma = f64::INFINITY;

        for _iter in 0..n_iter {
            // u = M v
            let u = mat_vec_mul(&work, &v);
            // v_new = Mᵀ u
            let mut v_new = mat_t_vec_mul(&work, &u);
            normalize_vec(&mut v_new);

            // Convergence check via change in v
            let diff: f64 = v_new
                .iter()
                .zip(v.iter())
                .map(|(a, b)| (a - b).abs())
                .sum::<f64>();
            v = v_new;

            let sigma_est = {
                let u_tmp = mat_vec_mul(&work, &v);
                dot_product(&u_tmp, &u_tmp).sqrt()
            };
            if (sigma_est - prev_sigma).abs() < tol && diff < tol {
                break;
            }
            prev_sigma = sigma_est;
        }

        // Final extraction
        let mut u = mat_vec_mul(&work, &v);
        let sigma = vec_norm(&u);
        if sigma < 1e-14 {
            // All remaining singular values are negligible
            u = vec![0.0; m];
        } else {
            u.iter_mut().for_each(|x| *x /= sigma);
        }

        // Deflate: work -= sigma * outer(u, v)
        for r in 0..m {
            for c in 0..n {
                work[[r, c]] -= sigma * u[r] * v[c];
            }
        }

        u_vecs.push(u);
        s_vals.push(sigma);
        v_vecs.push(v);
    }

    // Build result arrays
    let u_arr = Array2::from_shape_fn((m, k), |(r, c)| u_vecs[c][r]);
    let s_arr = Array1::from_vec(s_vals.clone());
    let vt_arr = Array2::from_shape_fn((k, n), |(r, c)| v_vecs[r][c]);

    let captured_sq: f64 = s_vals.iter().map(|s| s * s).sum();
    let explained_variance_ratio = if total_sq < 1e-30 {
        1.0
    } else {
        (captured_sq / total_sq).min(1.0)
    };

    Ok(TruncatedSvd {
        u: u_arr,
        s: s_arr,
        vt: vt_arr,
        rank: k,
        explained_variance_ratio,
    })
}

// ---------------------------------------------------------------------------
// Tensor unfolding / folding
// ---------------------------------------------------------------------------

/// Mode-n unfolding of a tensor: reshape to a 2-D matrix whose rows are
/// indexed by mode `n` and whose columns enumerate all remaining modes in
/// increasing axis order, last axis fastest (row-major).
///
/// Axis permutation: `[mode, 0, 1, …, mode-1, mode+1, …, ndim-1]`.
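///
/// # Example
///
/// For a `[2, 3, 4]` tensor, the mode-1 unfolding has shape `[3, 2·4]` and
/// maps `t[[i, j, k]]` to `m[[j, i*4 + k]]` (a sketch, `ignore`d because the
/// import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let t = ArrayD::from_shape_fn(IxDyn(&[2, 3, 4]), |ix| {
///     (ix[0] * 12 + ix[1] * 4 + ix[2]) as f64
/// });
/// let m = unfold(&t, 1).expect("unfold");
/// assert_eq!(m.shape(), &[3, 8]);
/// assert_eq!(m[[2, 1 * 4 + 3]], t[[1, 2, 3]]);
/// ```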
pub fn unfold(tensor: &ArrayD<f64>, mode: usize) -> Result<Array2<f64>, DecompositionError> {
    let ndim = tensor.ndim();
    if ndim == 0 {
        return Err(DecompositionError::ShapeError(
            "Cannot unfold a 0-D scalar tensor".into(),
        ));
    }
    if mode >= ndim {
        return Err(DecompositionError::ShapeError(format!(
            "mode {mode} out of range for {ndim}-D tensor"
        )));
    }
    if tensor.shape().contains(&0) {
        return Err(DecompositionError::EmptyTensor);
    }

    let shape = tensor.shape();
    let n_rows = shape[mode];
    let n_cols: usize = shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != mode)
        .map(|(_, &d)| d)
        .product();

    // Build permutation: [mode, 0, 1, …, mode-1, mode+1, …]
    let mut perm: Vec<usize> = Vec::with_capacity(ndim);
    perm.push(mode);
    for i in 0..ndim {
        if i != mode {
            perm.push(i);
        }
    }

    let permuted = tensor.view().permuted_axes(perm);
    // Collect in permuted (row-major) order
    let data: Vec<f64> = permuted.iter().copied().collect();

    Array2::from_shape_vec((n_rows, n_cols), data)
        .map_err(|e| DecompositionError::ShapeError(e.to_string()))
}

/// Inverse of [`unfold`]: fold a 2-D matrix back to a tensor of the given shape.
///
/// The `mode` axis corresponds to the rows; all other axes are columns in the
/// same iteration order as [`unfold`].
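///
/// # Example
///
/// `fold` exactly inverts [`unfold`] for matching `mode` and `shape`
/// (a sketch, `ignore`d because the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let t = ArrayD::from_shape_fn(IxDyn(&[2, 3, 4]), |ix| (ix[0] + ix[1] + ix[2]) as f64);
/// let m = unfold(&t, 2).expect("unfold");
/// let back = fold(&m, 2, &[2, 3, 4]).expect("fold");
/// assert_eq!(back, t);
/// ```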
pub fn fold(
    matrix: &Array2<f64>,
    mode: usize,
    shape: &[usize],
) -> Result<ArrayD<f64>, DecompositionError> {
    let ndim = shape.len();
    if ndim == 0 {
        return Err(DecompositionError::ShapeError(
            "Cannot fold to a 0-D tensor".into(),
        ));
    }
    if mode >= ndim {
        return Err(DecompositionError::ShapeError(format!(
            "mode {mode} out of range for {ndim}-D shape"
        )));
    }

    let n_rows = shape[mode];
    let n_cols: usize = shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != mode)
        .map(|(_, &d)| d)
        .product();

    if matrix.nrows() != n_rows || matrix.ncols() != n_cols {
        return Err(DecompositionError::ShapeError(format!(
            "matrix shape {}×{} does not match expected {}×{} for mode={mode}, shape={shape:?}",
            matrix.nrows(),
            matrix.ncols(),
            n_rows,
            n_cols
        )));
    }

    // Build permuted shape: [mode, others...]
    let mut perm_shape: Vec<usize> = Vec::with_capacity(ndim);
    perm_shape.push(shape[mode]);
    for (i, &d) in shape.iter().enumerate() {
        if i != mode {
            perm_shape.push(d);
        }
    }

    let data: Vec<f64> = matrix.iter().copied().collect();
    let permuted = ArrayD::from_shape_vec(IxDyn(&perm_shape), data)
        .map_err(|e| DecompositionError::ShapeError(e.to_string()))?;

    // Invert the unfold permutation [mode, 0, …, mode-1, mode+1, …]:
    // original axis `mode` came from permuted axis 0, original axis j < mode
    // from permuted axis j+1, and original axis j > mode from permuted axis j.
    let mut inv_perm = vec![0usize; ndim];
    inv_perm[mode] = 0;
    let mut pos = 1usize;
    for (i, slot) in inv_perm.iter_mut().enumerate() {
        if i != mode {
            *slot = pos;
            pos += 1;
        }
    }

    let result = permuted.view().permuted_axes(inv_perm).to_owned();
    Ok(result)
}

// ---------------------------------------------------------------------------
// Tucker-1
// ---------------------------------------------------------------------------

/// Result of a Tucker-1 (single-mode) tensor decomposition.
#[derive(Debug, Clone)]
pub struct Tucker1Result {
    /// Compressed core tensor (same shape as original, except `shape[mode] = rank`).
    pub core: ArrayD<f64>,
    /// Factor matrix `[original_dim, rank]`.
    pub factor: Array2<f64>,
    /// The mode that was compressed.
    pub mode: usize,
    /// Rank used for this mode.
    pub rank: usize,
    /// `original_elements / core_elements`.
    pub compression_ratio: f64,
}

impl Tucker1Result {
    /// Reconstruct the original tensor: `core ×_mode factor` (Tucker mode
    /// product along `mode`, with `factor: [original_dim, rank]`).
    ///
    /// Equivalent to: unfold(core, mode) → factor @ unfolded → fold back.
    pub fn reconstruct(&self) -> ArrayD<f64> {
        // unfold core along mode → [rank, rest]
        let core_unfolded = match unfold(&self.core, self.mode) {
            Ok(m) => m,
            Err(_) => return self.core.clone(),
        };
        // factor: [orig_dim, rank], factor @ core_unfolded: [orig_dim, rest]
        let reconstructed_mat = self.factor.dot(&core_unfolded);

        // shape: replace mode dim with orig_dim
        let mut orig_shape: Vec<usize> = self.core.shape().to_vec();
        orig_shape[self.mode] = self.factor.nrows();

        match fold(&reconstructed_mat, self.mode, &orig_shape) {
            Ok(t) => t,
            Err(_) => ArrayD::zeros(IxDyn(&orig_shape)),
        }
    }

    /// Frobenius norm of `original - reconstruct()`.
    pub fn reconstruction_error(&self, original: &ArrayD<f64>) -> f64 {
        let approx = self.reconstruct();
        let diff = original - &approx;
        frobenius_norm_nd(&diff)
    }
}

/// Tucker-1 decomposition: compress tensor along a single mode using truncated SVD.
///
/// Steps:
/// 1. Unfold the tensor along `mode` → matrix `X_(mode)` of shape `[dim, rest]`.
/// 2. Compute its truncated SVD of rank `rank`.
/// 3. `factor = U` (`[dim, rank]`).
/// 4. Core unfolding = `Uᵀ @ X_(mode)` → fold back.
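///
/// # Example
///
/// A minimal sketch (`ignore`d because the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let tensor = ArrayD::from_shape_fn(IxDyn(&[10, 8, 6]), |ix| (ix[0] * ix[1] + ix[2]) as f64);
/// let result = tucker1(&tensor, 0, 3).expect("tucker1");
/// assert_eq!(result.core.shape(), &[3, 8, 6][..]);
/// assert_eq!(result.factor.shape(), &[10, 3]);
/// let approx = result.reconstruct(); // same shape as `tensor`
/// assert_eq!(approx.shape(), tensor.shape());
/// ```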
pub fn tucker1(
    tensor: &ArrayD<f64>,
    mode: usize,
    rank: usize,
) -> Result<Tucker1Result, DecompositionError> {
    let ndim = tensor.ndim();
    if ndim == 0 {
        return Err(DecompositionError::ShapeError("0-D tensor".into()));
    }
    if mode >= ndim {
        return Err(DecompositionError::ShapeError(format!(
            "mode {mode} out of range for {ndim}-D tensor"
        )));
    }

    let orig_dim = tensor.shape()[mode];
    if orig_dim == 0 {
        return Err(DecompositionError::EmptyTensor);
    }
    if rank == 0 || rank > orig_dim {
        return Err(DecompositionError::InvalidRank {
            rank,
            max_rank: orig_dim,
        });
    }

    let unfolded = unfold(tensor, mode)?; // [orig_dim, rest]
    let svd = truncated_svd(&unfolded, rank, 30, 1e-10)?; // U:[m,k], s:[k], Vt:[k,n]

    // factor = U  ([orig_dim, rank])
    let factor = svd.u.clone();

    // core_unfolded = Uᵀ @ unfolded = [rank, rest]
    // U is [orig_dim, rank], so Uᵀ is [rank, orig_dim]
    let u_t = svd.u.t().to_owned(); // [rank, orig_dim]
    let core_unfolded = u_t.dot(&unfolded); // [rank, rest]

    // Build core shape
    let mut core_shape: Vec<usize> = tensor.shape().to_vec();
    core_shape[mode] = rank;

    let core = fold(&core_unfolded, mode, &core_shape)?;

    let original_elements: usize = tensor.shape().iter().product();
    let core_elements: usize = core_shape.iter().product();
    let compression_ratio = if core_elements == 0 {
        1.0
    } else {
        original_elements as f64 / core_elements as f64
    };

    Ok(Tucker1Result {
        core,
        factor,
        mode,
        rank,
        compression_ratio,
    })
}

// ---------------------------------------------------------------------------
// CP / PARAFAC decomposition (ALS, 3-mode)
// ---------------------------------------------------------------------------

/// Result of a CP (PARAFAC) tensor decomposition.
#[derive(Debug, Clone)]
pub struct CpDecomposition {
    /// Factor matrices, one per mode, each `[dim_i, rank]`.
    pub factors: Vec<Array2<f64>>,
    /// Normalization weights per component `[rank]`.
    pub weights: Array1<f64>,
    /// Number of components.
    pub rank: usize,
    /// Number of modes (tensor order).
    pub num_modes: usize,
    /// Number of ALS iterations actually performed.
    pub iterations: usize,
    /// Residual ‖X - X̂‖_F at convergence.
    pub final_residual: f64,
    /// Whether ALS converged within tolerance.
    pub converged: bool,
}

impl CpDecomposition {
    /// Reconstruct the tensor from factor matrices and weights.
    ///
    /// For each component r, accumulates `weights[r] * outer(A[:,r], B[:,r], C[:,r], …)`.
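    ///
    /// For a 3-mode decomposition this reads
    /// `X̂[i, j, k] = Σ_r weights[r] · A[i, r] · B[j, r] · C[k, r]`.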
    pub fn reconstruct(&self) -> ArrayD<f64> {
        if self.factors.is_empty() {
            return ArrayD::zeros(IxDyn(&[]));
        }
        let shape: Vec<usize> = self.factors.iter().map(|f| f.nrows()).collect();
        let total: usize = shape.iter().product();
        let ndim = self.num_modes;

        let mut data = vec![0.0f64; total];

        for (flat, slot) in data.iter_mut().enumerate() {
            // Decode the flat (row-major) index into a multi-index
            let mut idx = vec![0usize; ndim];
            let mut remaining = flat;
            for d in (0..ndim).rev() {
                idx[d] = remaining % shape[d];
                remaining /= shape[d];
            }
            // Accumulate the weighted rank-1 contribution of every component
            for r in 0..self.rank {
                let contrib: f64 = self
                    .factors
                    .iter()
                    .zip(idx.iter())
                    .map(|(f, &i)| f[[i, r]])
                    .product();
                *slot += self.weights[r] * contrib;
            }
        }

        ArrayD::from_shape_vec(IxDyn(&shape), data).unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)))
    }

    /// Frobenius reconstruction error ‖original - reconstructed‖_F.
    pub fn reconstruction_error(&self, original: &ArrayD<f64>) -> f64 {
        let approx = self.reconstruct();
        let diff = original - &approx;
        frobenius_norm_nd(&diff)
    }

    /// Fraction of variance explained: `1 - ‖residual‖² / ‖original‖²`.
    pub fn explained_variance(&self, original: &ArrayD<f64>) -> f64 {
        let original_sq: f64 = original.iter().map(|x| x * x).sum();
        if original_sq < 1e-30 {
            return 1.0;
        }
        let residual = self.reconstruction_error(original);
        let explained = 1.0 - (residual * residual) / original_sq;
        explained.clamp(0.0, 1.0)
    }
}

/// Khatri-Rao product (column-wise Kronecker product).
///
/// For `A: [m, r]` and `B: [n, r]` → output `[mn, r]`.
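///
/// Element-wise: `(A ⊙ B)[i*n + j, c] = A[i, c] * B[j, c]`, i.e. `B`'s row
/// index is the fast (inner) one.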
fn khatri_rao(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let (m, r_a) = (a.nrows(), a.ncols());
    let (n, r_b) = (b.nrows(), b.ncols());
    debug_assert_eq!(
        r_a, r_b,
        "Khatri-Rao: both matrices must have same number of columns"
    );
    let r = r_a.min(r_b);
    let mut result = Array2::<f64>::zeros((m * n, r));
    for col in 0..r {
        for i in 0..m {
            for j in 0..n {
                result[[i * n + j, col]] = a[[i, col]] * b[[j, col]];
            }
        }
    }
    result
}

/// Element-wise (Hadamard) product of two 2-D arrays (same shape).
fn hadamard(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    a * b
}

/// Gram matrix: `Aᵀ A` for `A: [m, r]` → `[r, r]`.
fn gram(a: &Array2<f64>) -> Array2<f64> {
    a.t().dot(a)
}

/// Pseudo-inverse of a small square matrix via direct inversion with regularisation.
///
/// For the small `[r, r]` matrices in CP-ALS this is sufficient and avoids
/// pulling in a full linear-algebra solver.
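///
/// Concretely: solves `(M + λI) · X = I` by Gauss-Jordan elimination with
/// partial pivoting, which is O(r³) but cheap for the small Gram matrices
/// used here.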
fn pinv_small(m: &Array2<f64>, lambda: f64) -> Result<Array2<f64>, DecompositionError> {
    let n = m.nrows();
    debug_assert_eq!(n, m.ncols());

    // Regularised matrix: M + λI
    let mut reg = m.to_owned();
    for i in 0..n {
        reg[[i, i]] += lambda;
    }

    // Gauss-Jordan elimination on the augmented system [M + λI | I]
    let mut aug: Vec<Vec<f64>> = (0..n)
        .map(|i| {
            let mut row: Vec<f64> = reg.row(i).to_vec();
            // Append identity column
            for j in 0..n {
                row.push(if i == j { 1.0 } else { 0.0 });
            }
            row
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: pick the row with the largest entry in this column
        let pivot = (col..n).max_by(|&a, &b| {
            aug[a][col]
                .abs()
                .partial_cmp(&aug[b][col].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let pivot = pivot.ok_or(DecompositionError::SingularMatrix)?;
        aug.swap(col, pivot);

        let diag = aug[col][col];
        if diag.abs() < 1e-15 {
            return Err(DecompositionError::SingularMatrix);
        }

        let inv_diag = 1.0 / diag;
        for elem in aug[col].iter_mut() {
            *elem *= inv_diag;
        }

        for row in 0..n {
            if row != col {
                let factor = aug[row][col];
                let col_row: Vec<f64> = aug[col].clone();
                for (j, &cv) in col_row.iter().enumerate() {
                    aug[row][j] -= factor * cv;
                }
            }
        }
    }

    // Extract the inverse from the right half of the augmented rows
    let inv = Array2::from_shape_fn((n, n), |(i, j)| aug[i][n + j]);
    Ok(inv)
}

/// CP (PARAFAC) decomposition of a 3-mode tensor via Alternating Least Squares.
///
/// # Algorithm (ALS)
/// Given tensor X ∈ ℝ^{I×J×K} and rank R:
/// 1. Initialise factor matrices A, B, C via truncated SVD of each unfolding
///    (an HOSVD-style warm start, which converges faster than random init).
/// 2. Repeat:
///    - `A ← X_(0) (B ⊙ C) (BᵀB * CᵀC)⁺`
///    - `B ← X_(1) (A ⊙ C) (AᵀA * CᵀC)⁺`
///    - `C ← X_(2) (A ⊙ B) (AᵀA * BᵀB)⁺`
///    - Normalise columns, track weights.
/// 3. Check convergence via relative change in residual.
///
/// ⊙ denotes the Khatri-Rao (column-wise Kronecker) product and * the
/// element-wise (Hadamard) product; the ⊙ operand order matches the
/// row-major unfolding convention used by [`unfold`].
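///
/// # Example
///
/// A minimal sketch (`ignore`d because the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let tensor = ArrayD::from_shape_fn(IxDyn(&[4, 3, 3]), |ix| (ix[0] + ix[1] * ix[2]) as f64);
/// let cp = cp_als(&tensor, 2, 200, 1e-6).expect("cp_als");
/// assert_eq!(cp.factors[0].shape(), &[4, 2]); // A
/// assert_eq!(cp.factors[1].shape(), &[3, 2]); // B
/// assert_eq!(cp.factors[2].shape(), &[3, 2]); // C
/// println!("residual={} converged={}", cp.final_residual, cp.converged);
/// ```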
pub fn cp_als(
    tensor: &ArrayD<f64>,
    rank: usize,
    max_iter: usize,
    tol: f64,
) -> Result<CpDecomposition, DecompositionError> {
    if tensor.ndim() != 3 {
        return Err(DecompositionError::ShapeError(format!(
            "cp_als currently supports only 3-mode tensors, got {}-mode",
            tensor.ndim()
        )));
    }
    let shape = tensor.shape();
    let (i_dim, j_dim, k_dim) = (shape[0], shape[1], shape[2]);
    if i_dim == 0 || j_dim == 0 || k_dim == 0 {
        return Err(DecompositionError::EmptyTensor);
    }
    let max_rank = i_dim.min(j_dim).min(k_dim);
    if rank == 0 || rank > max_rank {
        return Err(DecompositionError::InvalidRank { rank, max_rank });
    }

    // Precompute unfoldings (read-only, computed once)
    let x0 = unfold(tensor, 0)?; // [I, J*K]
    let x1 = unfold(tensor, 1)?; // [J, I*K]
    let x2 = unfold(tensor, 2)?; // [K, I*J]

    let tensor_norm_sq: f64 = tensor.iter().map(|x| x * x).sum();

    // Initialise factors via truncated SVD (HOSVD-style warm start)
    let mut a = init_factor_svd(&x0, rank)?; // [I, rank]
    let mut b = init_factor_svd(&x1, rank)?; // [J, rank]
    let mut c = init_factor_svd(&x2, rank)?; // [K, rank]

    let mut weights = Array1::<f64>::ones(rank);
    let mut prev_residual = f64::INFINITY;
    let mut converged = false;
    let mut iter = 0usize;
    let regularization = 1e-10;

    for it in 0..max_iter {
        iter = it + 1;

        // --- Update A ---
        // X_(0) shape [I, J*K]: J is the outer (slow) index, K is inner (fast).
        // The matching KR product is B ⊙ C: rows indexed as [j*K + k] → shape [J*K, R].
        {
            let kr_bc = khatri_rao(&b, &c); // [J*K, rank]
            let gram_prod = hadamard(&gram(&b), &gram(&c)); // [rank, rank]
            let gram_inv = pinv_small(&gram_prod, regularization)?;
            let rhs = kr_bc.dot(&gram_inv); // [J*K, rank]
            a = x0.dot(&rhs); // [I, rank]
        }

        // --- Update B ---
        // X_(1) shape [J, I*K]: I is outer, K is inner.
        // Matching KR product is A ⊙ C: rows [i*K + k] → shape [I*K, R].
        {
            let kr_ac = khatri_rao(&a, &c); // [I*K, rank]
            let gram_prod = hadamard(&gram(&a), &gram(&c));
            let gram_inv = pinv_small(&gram_prod, regularization)?;
            let rhs = kr_ac.dot(&gram_inv);
            b = x1.dot(&rhs); // [J, rank]
        }

        // --- Update C ---
        // X_(2) shape [K, I*J]: I is outer, J is inner.
        // Matching KR product is A ⊙ B: rows [i*J + j] → shape [I*J, R].
        {
            let kr_ab = khatri_rao(&a, &b); // [I*J, rank]
            let gram_prod = hadamard(&gram(&a), &gram(&b));
            let gram_inv = pinv_small(&gram_prod, regularization)?;
            let rhs = kr_ab.dot(&gram_inv);
            c = x2.dot(&rhs); // [K, rank]
        }

        // --- Normalise columns, store weights ---
        for r in 0..rank {
            let norm_a = col_norm(&a, r);
            let norm_b = col_norm(&b, r);
            let norm_c = col_norm(&c, r);
            let w = norm_a * norm_b * norm_c;
            weights[r] = w;
            if norm_a > 1e-14 {
                a.column_mut(r).mapv_inplace(|x| x / norm_a);
            }
            if norm_b > 1e-14 {
                b.column_mut(r).mapv_inplace(|x| x / norm_b);
            }
            if norm_c > 1e-14 {
                c.column_mut(r).mapv_inplace(|x| x / norm_c);
            }
        }

        // --- Convergence check ---
        // Residual via the identity ‖X − X̂‖² = ‖X‖² − 2⟨X, X̂⟩ + ‖X̂‖².
        // Both terms are computed from the Gram matrices and an MTTKRP, so
        // no explicit reconstruction of X̂ is required.
        let approx_norm_sq = compute_cp_norm_sq(&a, &b, &c, &weights, rank);
        let inner_xhat = compute_inner_x_xhat(&x0, &a, &b, &c, &weights, rank);
        let residual_sq = (tensor_norm_sq - 2.0 * inner_xhat + approx_norm_sq).max(0.0);
        let residual = residual_sq.sqrt();

        let rel_change = if prev_residual.is_finite() && prev_residual > 1e-30 {
            (prev_residual - residual).abs() / prev_residual
        } else {
            f64::INFINITY
        };

        if rel_change < tol && it > 0 {
            converged = true;
            prev_residual = residual;
            break;
        }
        prev_residual = residual;
    }

    Ok(CpDecomposition {
        factors: vec![a, b, c],
        weights,
        rank,
        num_modes: 3,
        iterations: iter,
        final_residual: prev_residual,
        converged,
    })
}

// ---------------------------------------------------------------------------
// HOSVD
// ---------------------------------------------------------------------------

/// Result of a Higher-Order SVD (HOSVD) decomposition.
#[derive(Debug, Clone)]
pub struct HosvdResult {
    /// Compressed core tensor with shape `ranks`.
    pub core: ArrayD<f64>,
    /// Factor matrices, one per mode, each `[dim_i, rank_i]`.
    pub factors: Vec<Array2<f64>>,
    /// Target ranks per mode.
    pub ranks: Vec<usize>,
    /// `original_elements / core_elements`.
    pub compression_ratio: f64,
}

impl HosvdResult {
    /// Reconstruct the original tensor via successive Tucker products along each mode.
    pub fn reconstruct(&self) -> ArrayD<f64> {
        let mut current = self.core.clone();
        for (mode, factor) in self.factors.iter().enumerate() {
            current = match tucker_product(&current, factor, mode) {
                Ok(t) => t,
                Err(_) => return self.core.clone(),
            };
        }
        current
    }

    /// Frobenius reconstruction error.
    pub fn reconstruction_error(&self, original: &ArrayD<f64>) -> f64 {
        let approx = self.reconstruct();
        let diff = original - &approx;
        frobenius_norm_nd(&diff)
    }
}

/// Higher-Order SVD: compute a truncated SVD along each mode independently.
///
/// For each mode n:
///   1. Unfold tensor along mode n → `X_(n)`.
///   2. Compute `truncated_svd(X_(n), ranks[n])` → `U_n`.
///
/// Core = X ×₀ U₀ᵀ ×₁ U₁ᵀ … ×_(N-1) U_(N-1)ᵀ (multi-mode product, 0-based modes).
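///
/// # Example
///
/// A minimal sketch (`ignore`d because the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let tensor = ArrayD::from_shape_fn(IxDyn(&[6, 5, 4]), |ix| (ix[0] * ix[1] * ix[2]) as f64);
/// let h = hosvd(&tensor, &[3, 2, 2]).expect("hosvd");
/// assert_eq!(h.core.shape(), &[3, 2, 2][..]);
/// assert_eq!(h.factors[0].shape(), &[6, 3]);
/// assert_eq!(h.reconstruct().shape(), &[6, 5, 4][..]);
/// ```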
pub fn hosvd(tensor: &ArrayD<f64>, ranks: &[usize]) -> Result<HosvdResult, DecompositionError> {
    let ndim = tensor.ndim();
    if ndim == 0 {
        return Err(DecompositionError::ShapeError("0-D tensor".into()));
    }
    if ranks.len() != ndim {
        return Err(DecompositionError::ShapeError(format!(
            "ranks length {} must match tensor ndim {}",
            ranks.len(),
            ndim
        )));
    }
    if tensor.shape().contains(&0) {
        return Err(DecompositionError::EmptyTensor);
    }
    for (mode, &r) in ranks.iter().enumerate() {
        let dim = tensor.shape()[mode];
        if r == 0 || r > dim {
            return Err(DecompositionError::InvalidRank {
                rank: r,
                max_rank: dim,
            });
        }
    }

    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(ndim);

    for (mode, &rank_m) in ranks.iter().enumerate() {
        let unfolded = unfold(tensor, mode)?;
        let svd = truncated_svd(&unfolded, rank_m, 30, 1e-10)?;
        factors.push(svd.u); // [dim_mode, rank_mode]
    }

    // Compute core = X ×₀ U₀ᵀ ×₁ U₁ᵀ … ×_(N-1) U_(N-1)ᵀ
    // We do this mode by mode: multiply the current tensor by Uᵀ along each mode.
    let mut core = tensor.to_owned();
    for (mode, factor) in factors.iter().enumerate() {
        // unfold along mode, multiply by Uᵀ ([rank, dim] @ [dim, rest]) → fold back
        let unfolded_core = unfold(&core, mode)?; // [dim_mode, rest]
        let u_t = factor.t().to_owned(); // [rank_mode, dim_mode]
        let compressed = u_t.dot(&unfolded_core); // [rank_mode, rest]

        let mut new_shape: Vec<usize> = core.shape().to_vec();
        new_shape[mode] = ranks[mode];

        core = fold(&compressed, mode, &new_shape)?;
    }

    let original_elements: usize = tensor.shape().iter().product();
    let core_elements: usize = ranks.iter().product();
    let compression_ratio = if core_elements == 0 {
        1.0
    } else {
        original_elements as f64 / core_elements as f64
    };

    Ok(HosvdResult {
        core,
        factors,
        ranks: ranks.to_vec(),
        compression_ratio,
    })
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Tucker mode product: tensor ×_mode factor, where `factor: [new_dim, old_dim]`
/// replaces the mode-`mode` dimension (`old_dim → new_dim`).
///
/// Equivalent to: unfold(tensor, mode) → factor @ unfolded → fold.
fn tucker_product(
    tensor: &ArrayD<f64>,
    factor: &Array2<f64>,
    mode: usize,
) -> Result<ArrayD<f64>, DecompositionError> {
    let unfolded = unfold(tensor, mode)?; // [old_dim, rest]
    let result_mat = factor.dot(&unfolded); // [new_dim, rest]

    let mut new_shape: Vec<usize> = tensor.shape().to_vec();
    new_shape[mode] = factor.nrows();

    fold(&result_mat, mode, &new_shape)
}

/// Initialise a factor matrix from the left singular vectors of `matrix`.
fn init_factor_svd(matrix: &Array2<f64>, rank: usize) -> Result<Array2<f64>, DecompositionError> {
    let effective_rank = rank.min(matrix.nrows().min(matrix.ncols()));
    if effective_rank == 0 {
        return Err(DecompositionError::EmptyTensor);
    }
    let svd = truncated_svd(matrix, effective_rank, 20, 1e-10)?;
    if effective_rank == rank {
        return Ok(svd.u);
    }
    // Otherwise pad: copy the SVD columns, then fill the remainder
    let m = matrix.nrows();
    let mut factor = Array2::<f64>::zeros((m, rank));
    for c in 0..effective_rank {
        for r in 0..m {
            factor[[r, c]] = svd.u[[r, c]];
        }
    }
    // Fill remaining columns with pseudo-random unit vectors
    for c in effective_rank..rank {
        let mut v = init_vector(m, c);
        normalize_vec(&mut v);
        for r in 0..m {
            factor[[r, c]] = v[r];
        }
    }
    Ok(factor)
}

/// Deterministic pseudo-random vector of length `n` with entries in `[-1, 1]`,
/// seeded by `seed` (callers normalise it as needed).
fn init_vector(n: usize, seed: usize) -> Vec<f64> {
    // LCG-style mixing done in u64 so the constants also fit on 32-bit targets
    (0..n as u64)
        .map(|i| {
            let x = i
                .wrapping_mul(6364136223846793005u64)
                .wrapping_add((seed as u64).wrapping_mul(1442695040888963407u64));
            (x as f64 / u64::MAX as f64) * 2.0 - 1.0
        })
        .collect()
}

fn normalize_vec(v: &mut [f64]) {
    let norm = vec_norm(v);
    if norm > 1e-14 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

fn vec_norm(v: &[f64]) -> f64 {
    dot_product(v, v).sqrt()
}

fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Matrix-vector product: `A @ v` where `A: [m, n]`, `v: [n]` → `[m]`.
fn mat_vec_mul(a: &Array2<f64>, v: &[f64]) -> Vec<f64> {
    let m = a.nrows();
    let n = a.ncols();
    let mut result = vec![0.0f64; m];
    for i in 0..m {
        let mut s = 0.0;
        for j in 0..n {
            s += a[[i, j]] * v[j];
        }
        result[i] = s;
    }
    result
}

/// Matrix-transpose-vector product: `Aᵀ @ v` where `A: [m, n]`, `v: [m]` → `[n]`.
fn mat_t_vec_mul(a: &Array2<f64>, v: &[f64]) -> Vec<f64> {
    let m = a.nrows();
    let n = a.ncols();
    let mut result = vec![0.0f64; n];
    for j in 0..n {
        let mut s = 0.0;
        for i in 0..m {
            s += a[[i, j]] * v[i];
        }
        result[j] = s;
    }
    result
}

fn frobenius_norm_2d(m: &Array2<f64>) -> f64 {
    m.iter().map(|x| x * x).sum::<f64>().sqrt()
}

fn frobenius_norm_nd(t: &ArrayD<f64>) -> f64 {
    t.iter().map(|x| x * x).sum::<f64>().sqrt()
}

fn col_norm(a: &Array2<f64>, col: usize) -> f64 {
    a.column(col).iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// ‖X̂‖²_F computed directly from the factor matrices (Gram approach).
///
/// ‖X̂‖²_F = Σ_{r,s} w_r w_s (AᵀA)_{rs} (BᵀB)_{rs} (CᵀC)_{rs}, i.e. the sum of
/// all entries of the Hadamard product (AᵀA) ∘ (BᵀB) ∘ (CᵀC) ∘ (w wᵀ).
fn compute_cp_norm_sq(
    a: &Array2<f64>,
    b: &Array2<f64>,
    c: &Array2<f64>,
    weights: &Array1<f64>,
    rank: usize,
) -> f64 {
    let ga = gram(a); // [rank, rank]
    let gb = gram(b);
    let gc = gram(c);

    let mut norm_sq = 0.0f64;
    for r1 in 0..rank {
        for r2 in 0..rank {
            norm_sq += ga[[r1, r2]] * gb[[r1, r2]] * gc[[r1, r2]] * weights[r1] * weights[r2];
        }
    }
    norm_sq
}

/// Inner product ⟨X, X̂⟩ using the mode-0 unfolding.
///
/// X_(0) has shape [I, J*K] with J-outer/K-inner ordering.
/// The matching KR product is B ⊙ C (B outer, C inner) → [J*K, rank].
/// ⟨X, X̂⟩ = Σ_r w_r · (X_(0) (b_r ⊗ c_r)) · a_r
fn compute_inner_x_xhat(
    x0: &Array2<f64>, // [I, J*K]
    a: &Array2<f64>,  // [I, rank]
    b: &Array2<f64>,  // [J, rank]
    c: &Array2<f64>,  // [K, rank]
    weights: &Array1<f64>,
    rank: usize,
) -> f64 {
    // MTTKRP: x0 @ (B ⊙ C) → [I, rank]; B is the outer (slow) index, C inner
    let kr = khatri_rao(b, c); // [J*K, rank]
    let mttkrp = x0.dot(&kr); // [I, rank]
    // inner = Σ_r weights[r] * (A[:,r] · mttkrp[:,r])
    let mut inner = 0.0f64;
    for r in 0..rank {
        let dot: f64 = a
            .column(r)
            .iter()
            .zip(mttkrp.column(r).iter())
            .map(|(x, y)| x * y)
            .sum();
        inner += weights[r] * dot;
    }
    inner
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::IxDyn;

    const TOL: f64 = 1e-6;
    const ALS_TOL: f64 = 1e-5;

    // Helper: build a deterministic random-ish ArrayD from a seed
    fn make_tensor(shape: &[usize], seed: usize) -> ArrayD<f64> {
        let n: usize = shape.iter().product();
        let data: Vec<f64> = (0..n as u64)
            .map(|i| {
                let x = i
                    .wrapping_mul(6364136223846793005u64)
                    .wrapping_add((seed as u64).wrapping_mul(1442695040888963407u64));
                (x as f64 / u64::MAX as f64) * 2.0 - 1.0
            })
            .collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape ok")
    }

    fn make_matrix(rows: usize, cols: usize, seed: usize) -> Array2<f64> {
        let n = rows * cols;
        let data: Vec<f64> = (0..n as u64)
            .map(|i| {
                let x = i
                    .wrapping_mul(6364136223846793005u64)
                    .wrapping_add((seed as u64).wrapping_mul(1442695040888963407u64));
                (x as f64 / u64::MAX as f64) * 2.0 - 1.0
            })
            .collect();
        Array2::from_shape_vec((rows, cols), data).expect("shape ok")
    }

    // --- Truncated SVD tests ---

    #[test]
    fn test_truncated_svd_rank1() {
        // Rank-1 matrix: outer product of two vectors
        let u_true: Vec<f64> = (0..5).map(|i| (i + 1) as f64).collect();
        let v_true: Vec<f64> = (0..4).map(|i| (i + 1) as f64).collect();
        let mut mat = Array2::<f64>::zeros((5, 4));
        for i in 0..5 {
            for j in 0..4 {
                mat[[i, j]] = u_true[i] * v_true[j];
            }
        }
        let svd = truncated_svd(&mat, 1, 40, 1e-12).expect("svd ok");
        let recon = svd.reconstruct();
        let err = (mat - recon).iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(err < 1e-6, "rank-1 reconstruction error too large: {err}");
    }

    #[test]
    fn test_truncated_svd_reconstruction_error_decreases_with_rank() {
        let mat = make_matrix(8, 6, 42);
        let err1 = truncated_svd(&mat, 1, 30, 1e-12)
            .expect("ok")
            .reconstruction_error(&mat);
        let err3 = truncated_svd(&mat, 3, 30, 1e-12)
            .expect("ok")
            .reconstruction_error(&mat);
        let err6 = truncated_svd(&mat, 6, 30, 1e-12)
            .expect("ok")
            .reconstruction_error(&mat);
        assert!(
            err3 <= err1 + 1e-8,
            "rank-3 err ({err3}) should be <= rank-1 err ({err1})"
        );
        assert!(
            err6 <= err3 + 1e-8,
            "rank-6 err ({err6}) should be <= rank-3 err ({err3})"
        );
    }

    #[test]
    fn test_truncated_svd_singular_values_descending() {
        let mat = make_matrix(7, 5, 99);
        let svd = truncated_svd(&mat, 4, 30, 1e-12).expect("ok");
        for i in 0..svd.rank - 1 {
            assert!(
                svd.s[i] >= svd.s[i + 1] - 1e-8,
                "singular values not descending: s[{i}]={} < s[{}]={}",
                svd.s[i],
                i + 1,
                svd.s[i + 1]
            );
        }
    }

    #[test]
    fn test_truncated_svd_explained_variance_full_rank() {
        let mat = make_matrix(4, 4, 7);
        let max_rank = 4;
        let svd = truncated_svd(&mat, max_rank, 50, 1e-14).expect("ok");
        assert!(
            svd.explained_variance_ratio > 0.99,
            "full-rank EVR should be ≈1, got {}",
            svd.explained_variance_ratio
        );
    }

    #[test]
    fn test_truncated_svd_invalid_rank() {
        let mat = make_matrix(3, 4, 1);
        let result = truncated_svd(&mat, 0, 10, 1e-10);
        assert!(result.is_err(), "rank=0 should fail");
        let result2 = truncated_svd(&mat, 5, 10, 1e-10);
        assert!(result2.is_err(), "rank > min(m,n) should fail");
    }

    // --- Unfold / fold tests ---

    #[test]
    fn test_unfold_mode0_shape() {
        let tensor = make_tensor(&[3, 4, 5], 1);
        let mat = unfold(&tensor, 0).expect("ok");
        assert_eq!(mat.nrows(), 3, "mode-0 rows should be dim-0");
        assert_eq!(mat.ncols(), 4 * 5, "mode-0 cols should be dim-1 * dim-2");
    }

    #[test]
    fn test_unfold_mode1_shape() {
        let tensor = make_tensor(&[3, 4, 5], 2);
        let mat = unfold(&tensor, 1).expect("ok");
        assert_eq!(mat.nrows(), 4, "mode-1 rows should be dim-1");
        assert_eq!(mat.ncols(), 3 * 5, "mode-1 cols should be dim-0 * dim-2");
    }

    #[test]
    fn test_fold_roundtrip() {
        let original = make_tensor(&[3, 4, 5], 3);
        for mode in 0..3usize {
            let mat = unfold(&original, mode).expect("unfold ok");
            let recovered = fold(&mat, mode, &[3, 4, 5]).expect("fold ok");
            let err = (&original - &recovered)
                .iter()
                .map(|x| x * x)
                .sum::<f64>()
                .sqrt();
            assert!(err < TOL, "fold(unfold(x, {mode})) != x, error={err}");
        }
    }

    // --- Tucker-1 tests ---

    #[test]
    fn test_tucker1_compression_ratio() {
        let tensor = make_tensor(&[10, 8, 6], 5);
        let result = tucker1(&tensor, 0, 3).expect("ok");
        assert!(
            result.compression_ratio > 1.0,
            "compressed core should be smaller than original, ratio={}",
            result.compression_ratio
        );
        assert_eq!(
            result.core.shape()[0],
            3,
            "core mode-0 dim should equal rank"
        );
    }

    #[test]
    fn test_tucker1_reconstruction_error_small() {
        // Build a tensor that is exactly rank-2 along mode 0
        let factor_true =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.3, 0.7])
                .expect("ok");
        // core: shape [2, 3, 3]
        let core_true = make_tensor(&[2, 3, 3], 11);
        // Reconstruct tensor via Tucker-product
        let core_unfolded = unfold(&core_true, 0).expect("ok"); // [2, 9]
        let t_unfolded = factor_true.dot(&core_unfolded); // [4, 9]
        let tensor = fold(&t_unfolded, 0, &[4, 3, 3]).expect("ok");

        let result = tucker1(&tensor, 0, 2).expect("ok");
        let err = result.reconstruction_error(&tensor);
        assert!(
            err < 1e-5,
            "Tucker-1 error too large for rank-2 tensor: {err}"
        );
    }

    #[test]
    fn test_tucker1_invalid_rank() {
        let tensor = make_tensor(&[3, 4, 5], 6);
        let result = tucker1(&tensor, 0, 10); // rank > dim-0 = 3
        assert!(result.is_err(), "rank > dim should return error");
    }

    #[test]
    fn test_tucker1_reconstruct_shape() {
        let tensor = make_tensor(&[5, 4, 3], 8);
        let result = tucker1(&tensor, 1, 2).expect("ok");
        let recon = result.reconstruct();
        assert_eq!(
            recon.shape(),
            tensor.shape(),
            "reconstructed shape must match original"
        );
    }

    // --- CP-ALS tests ---

    #[test]
    fn test_cp_als_3mode() {
        let tensor = make_tensor(&[4, 3, 3], 20);
        let result = cp_als(&tensor, 2, 200, ALS_TOL).expect("ok");
        assert_eq!(result.num_modes, 3);
        assert_eq!(result.rank, 2);
        assert!(result.iterations > 0);
    }

    #[test]
    fn test_cp_als_reconstruction_error() {
        // Build a rank-2 tensor exactly
        let a = make_matrix(4, 2, 30);
        let b = make_matrix(3, 2, 31);
        let c = make_matrix(3, 2, 32);
        // Reconstruct exactly
        let mut data = vec![0.0f64; 4 * 3 * 3];
        for i in 0..4 {
            for j in 0..3 {
                for k in 0..3 {
                    let v: f64 = (0..2).map(|r| a[[i, r]] * b[[j, r]] * c[[k, r]]).sum();
                    data[i * 9 + j * 3 + k] = v;
                }
            }
        }
        let tensor = ArrayD::from_shape_vec(IxDyn(&[4, 3, 3]), data).expect("ok");
        let result = cp_als(&tensor, 2, 300, 1e-8).expect("ok");
        let err = result.reconstruction_error(&tensor);
        // Allow some numerical tolerance
        assert!(
            err < 0.5,
            "CP-ALS reconstruction error too large for rank-2 tensor: {err}"
        );
    }

    #[test]
    fn test_cp_als_factors_shape() {
        let tensor = make_tensor(&[5, 4, 3], 40);
        let result = cp_als(&tensor, 2, 100, ALS_TOL).expect("ok");
        assert_eq!(result.factors.len(), 3);
        assert_eq!(result.factors[0].shape(), &[5, 2]);
        assert_eq!(result.factors[1].shape(), &[4, 2]);
        assert_eq!(result.factors[2].shape(), &[3, 2]);
    }

    #[test]
    fn test_cp_als_weights_positive() {
        let tensor = make_tensor(&[4, 3, 3], 50);
        let result = cp_als(&tensor, 2, 100, ALS_TOL).expect("ok");
        for (i, &w) in result.weights.iter().enumerate() {
            assert!(w >= 0.0, "weight[{i}] should be non-negative, got {w}");
        }
    }

    // --- HOSVD tests ---

    #[test]
    fn test_hosvd_shape() {
        let tensor = make_tensor(&[6, 5, 4], 60);
        let ranks = vec![3, 2, 2];
        let result = hosvd(&tensor, &ranks).expect("ok");
        assert_eq!(result.core.shape(), &[3usize, 2, 2][..]);
    }

    #[test]
    fn test_hosvd_factors_shape() {
        let tensor = make_tensor(&[6, 5, 4], 61);
        let ranks = vec![3, 2, 2];
        let result = hosvd(&tensor, &ranks).expect("ok");
        assert_eq!(result.factors.len(), 3);
        assert_eq!(result.factors[0].shape(), &[6, 3]);
        assert_eq!(result.factors[1].shape(), &[5, 2]);
        assert_eq!(result.factors[2].shape(), &[4, 2]);
    }

    #[test]
    fn test_hosvd_reconstruction_error_small() {
        // Full-rank HOSVD should reconstruct perfectly
        let tensor = make_tensor(&[3, 3, 3], 62);
        let ranks = vec![3, 3, 3];
        let result = hosvd(&tensor, &ranks).expect("ok");
        let err = result.reconstruction_error(&tensor);
        assert!(err < 1e-5, "Full-rank HOSVD error should be near 0: {err}");
    }

    // --- Error display tests ---

    #[test]
    fn test_decomposition_error_display() {
        let errors = vec![
            DecompositionError::ShapeError("bad shape".into()),
            DecompositionError::ConvergenceFailure {
                iterations: 100,
                residual: 1e-3,
            },
            DecompositionError::SingularMatrix,
            DecompositionError::InvalidRank {
                rank: 10,
                max_rank: 5,
            },
            DecompositionError::EmptyTensor,
            DecompositionError::NonMatrixInput { ndim: 3 },
        ];
        for e in &errors {
            let s = format!("{e}");
            assert!(!s.is_empty(), "Display for {e:?} should not be empty");
        }
    }

    // --- CP explained variance ---

    #[test]
    fn test_cp_decomp_explained_variance() {
        let tensor = make_tensor(&[4, 3, 3], 70);
        let result = cp_als(&tensor, 2, 100, ALS_TOL).expect("ok");
        let ev = result.explained_variance(&tensor);
        assert!(
            (0.0..=1.0).contains(&ev),
            "explained variance must be in [0,1], got {ev}"
        );
    }
}