sklears_cross_decomposition/tensor_methods/sparse_tensor.rs

//! Sparse Tensor Decomposition implementation

use super::common::{Trained, Untrained};
use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Transform},
    types::Float,
};
use std::marker::PhantomData;

/// Sparse Tensor Decomposition
///
/// Decomposes a 3-way tensor using CP decomposition with sparsity constraints.
/// It is intended for tensors with many zero entries and enforces sparsity in
/// the factor matrices through L1 regularization (soft-thresholding).
///
/// # Examples
///
/// ```rust
/// use scirs2_core::ndarray::Array3;
/// use sklears_cross_decomposition::SparseTensorDecomposition;
/// use sklears_core::traits::Fit;
///
/// let tensor = Array3::zeros((20, 15, 10));
/// let sparse_decomp = SparseTensorDecomposition::new(5).sparsity_penalty(0.1);
/// let fitted = sparse_decomp.fit(&tensor, &()).unwrap();
/// ```
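///
/// # Model
///
/// The fitted model approximates the input tensor with a sum of `n_factors`
/// rank-one terms, `X[i, j, k] ≈ Σ_r A[i, r] * B[j, r] * C[k, r]`, where the
/// factor matrices `A`, `B`, and `C` are sparsified by soft-thresholding after
/// each alternating least squares update.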
#[derive(Debug, Clone)]
pub struct SparseTensorDecomposition<State = Untrained> {
    /// Number of factors
    pub n_factors: usize,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// L1 sparsity penalty
    pub sparsity_penalty: Float,
    /// L2 regularization
    pub regularization: Float,
    /// Sparsity threshold (values below this are set to zero)
    pub sparsity_threshold: Float,
    /// Factor matrices
    factor_matrices_: Option<Vec<Array2<Float>>>,
    /// Original tensor shape
    original_shape_: Option<Vec<usize>>,
    /// Sparsity levels achieved
    sparsity_levels_: Option<Array1<Float>>,
    /// Reconstruction error
    reconstruction_error_: Option<Float>,
    /// Number of iterations
    n_iter_: Option<usize>,
    /// State marker
    _state: PhantomData<State>,
}

impl SparseTensorDecomposition<Untrained> {
    /// Create a new sparse tensor decomposition
    pub fn new(n_factors: usize) -> Self {
        Self {
            n_factors,
            max_iter: 100,
            tol: 1e-6,
            sparsity_penalty: 0.01,
            regularization: 0.001,
            sparsity_threshold: 1e-8,
            factor_matrices_: None,
            original_shape_: None,
            sparsity_levels_: None,
            reconstruction_error_: None,
            n_iter_: None,
            _state: PhantomData,
        }
    }

    /// Set sparsity penalty (L1 regularization)
    pub fn sparsity_penalty(mut self, penalty: Float) -> Self {
        self.sparsity_penalty = penalty;
        self
    }

    /// Set L2 regularization
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

    /// Set sparsity threshold
    pub fn sparsity_threshold(mut self, threshold: Float) -> Self {
        self.sparsity_threshold = threshold;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set tolerance
    pub fn tol(mut self, tol: Float) -> Self {
        self.tol = tol;
        self
    }
}

impl Estimator for SparseTensorDecomposition<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array3<Float>, ()> for SparseTensorDecomposition<Untrained> {
    type Fitted = SparseTensorDecomposition<Trained>;

    fn fit(self, tensor: &Array3<Float>, _target: &()) -> Result<Self::Fitted> {
        let shape = tensor.shape();

        // Initialize factor matrices with small random values
        let mut rng = thread_rng();
        let mut factor_matrices = Vec::new();
        for mode in 0..3 {
            let mut factor = Array2::zeros((shape[mode], self.n_factors));
            for i in 0..shape[mode] {
                for j in 0..self.n_factors {
                    factor[[i, j]] = rng.random::<Float>() * 0.01;
                }
            }
            factor_matrices.push(factor);
        }

        let mut converged = false;
        let mut n_iter = 0;
        let mut prev_error = Float::INFINITY;

        // Sparse alternating least squares
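        // Each pass fixes two of the three factor matrices, solves a
        // regularized least squares problem for the remaining one, and then
        // soft-thresholds that factor to push small entries to exactly zero.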
        while !converged && n_iter < self.max_iter {
            let old_factors = factor_matrices.clone();

            // Update each factor matrix with sparsity constraints
            for mode in 0..3 {
                factor_matrices[mode] =
                    self.update_sparse_factor(tensor, &factor_matrices, mode)?;

                // Apply soft thresholding for sparsity
                self.apply_soft_thresholding(&mut factor_matrices[mode]);
            }

            // Compute reconstruction error
            let reconstructed = self.reconstruct_sparse_tensor(&factor_matrices, shape)?;
            let error = (tensor - &reconstructed).mapv(|x| x * x).sum().sqrt();

            // Check convergence
            if (prev_error - error).abs() < self.tol {
                converged = true;
            }

            // Also check factor convergence
            let mut max_factor_change: Float = 0.0;
            for mode in 0..3 {
                let change = (&factor_matrices[mode] - &old_factors[mode])
                    .mapv(|x| x.abs())
                    .sum();
                max_factor_change = max_factor_change.max(change);
            }

            if max_factor_change < self.tol {
                converged = true;
            }

            prev_error = error;
            n_iter += 1;
        }

        // Compute sparsity levels
        let mut sparsity_levels = Array1::zeros(3);
        for mode in 0..3 {
            let total_elements = factor_matrices[mode].len();
            let sparse_elements = factor_matrices[mode]
                .iter()
                .filter(|&&x| x.abs() < self.sparsity_threshold)
                .count();
            sparsity_levels[mode] = sparse_elements as Float / total_elements as Float;
        }

        Ok(SparseTensorDecomposition {
            n_factors: self.n_factors,
            max_iter: self.max_iter,
            tol: self.tol,
            sparsity_penalty: self.sparsity_penalty,
            regularization: self.regularization,
            sparsity_threshold: self.sparsity_threshold,
            factor_matrices_: Some(factor_matrices),
            original_shape_: Some(shape.to_vec()),
            sparsity_levels_: Some(sparsity_levels),
            reconstruction_error_: Some(prev_error),
            n_iter_: Some(n_iter),
            _state: PhantomData,
        })
    }
}

impl SparseTensorDecomposition<Untrained> {
    /// Update factor matrix with sparsity constraints
    fn update_sparse_factor(
        &self,
        tensor: &Array3<Float>,
        factors: &[Array2<Float>],
        mode: usize,
    ) -> Result<Array2<Float>> {
        let shape = tensor.shape();
        let mut new_factor = Array2::zeros((shape[mode], self.n_factors));

        // Simplified sparse update using coordinate descent
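        // For mode 0 the entry A[i, r] gets the closed-form ridge solution
        //   A[i, r] = Σ_{j,k} X[i, j, k] * B[j, r] * C[k, r]
        //             / (Σ_{j,k} (B[j, r] * C[k, r])^2 + regularization),
        // with the other two factors held fixed; modes 1 and 2 are symmetric.
        // Each column r is fitted against the full tensor rather than the
        // residual of the other components, which keeps the update simple.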
        for r in 0..self.n_factors {
            let mut factor_col = Array1::zeros(shape[mode]);

            match mode {
                0 => {
                    for i in 0..shape[0] {
                        let mut numerator = 0.0;
                        let mut denominator = 0.0;

                        for j in 0..shape[1] {
                            for k in 0..shape[2] {
                                let coeff = factors[1][[j, r]] * factors[2][[k, r]];
                                numerator += tensor[[i, j, k]] * coeff;
                                denominator += coeff * coeff;
                            }
                        }

                        if denominator > self.tol {
                            factor_col[i] = numerator / (denominator + self.regularization);
                        }
                    }
                }
                1 => {
                    for j in 0..shape[1] {
                        let mut numerator = 0.0;
                        let mut denominator = 0.0;

                        for i in 0..shape[0] {
                            for k in 0..shape[2] {
                                let coeff = factors[0][[i, r]] * factors[2][[k, r]];
                                numerator += tensor[[i, j, k]] * coeff;
                                denominator += coeff * coeff;
                            }
                        }

                        if denominator > self.tol {
                            factor_col[j] = numerator / (denominator + self.regularization);
                        }
                    }
                }
                2 => {
                    for k in 0..shape[2] {
                        let mut numerator = 0.0;
                        let mut denominator = 0.0;

                        for i in 0..shape[0] {
                            for j in 0..shape[1] {
                                let coeff = factors[0][[i, r]] * factors[1][[j, r]];
                                numerator += tensor[[i, j, k]] * coeff;
                                denominator += coeff * coeff;
                            }
                        }

                        if denominator > self.tol {
                            factor_col[k] = numerator / (denominator + self.regularization);
                        }
                    }
                }
                _ => return Err(SklearsError::InvalidInput("Invalid mode".to_string())),
            }

            new_factor.column_mut(r).assign(&factor_col);
        }

        Ok(new_factor)
    }

    /// Apply soft thresholding for L1 sparsity
    fn apply_soft_thresholding(&self, factor: &mut Array2<Float>) {
        let threshold = self.sparsity_penalty;
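        // Soft-thresholding operator: S_t(x) = sign(x) * max(|x| - t, 0),
        // the proximal operator of the L1 penalty. Entries with |x| <= t
        // become exactly zero, which is where the factor sparsity comes from.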
        factor.mapv_inplace(|x| {
            if x > threshold {
                x - threshold
            } else if x < -threshold {
                x + threshold
            } else {
                0.0
            }
        });
    }

    /// Reconstruct tensor from sparse factors
    fn reconstruct_sparse_tensor(
        &self,
        factors: &[Array2<Float>],
        shape: &[usize],
    ) -> Result<Array3<Float>> {
        let mut reconstructed = Array3::zeros((shape[0], shape[1], shape[2]));
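        // Accumulate one rank-one term a_r ∘ b_r ∘ c_r per component, giving
        // the CP approximation Σ_r a_r ∘ b_r ∘ c_r of the original tensor.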

        for r in 0..self.n_factors {
            let a = factors[0].column(r);
            let b = factors[1].column(r);
            let c = factors[2].column(r);

            for i in 0..shape[0] {
                for j in 0..shape[1] {
                    for k in 0..shape[2] {
                        reconstructed[[i, j, k]] += a[i] * b[j] * c[k];
                    }
                }
            }
        }

        Ok(reconstructed)
    }
}

impl Transform<Array3<Float>, Array3<Float>> for SparseTensorDecomposition<Trained> {
    /// Reconstruct the tensor from the fitted sparse factors.
    ///
    /// Only the shape of the input is used, and it must match the shape of the
    /// tensor the model was fitted on.
    fn transform(&self, tensor: &Array3<Float>) -> Result<Array3<Float>> {
        let factors = self.factor_matrices_.as_ref().unwrap();
        let shape = tensor.shape();
        self.reconstruct_sparse_tensor(factors, shape)
    }
}

impl SparseTensorDecomposition<Trained> {
    /// Get the factor matrices
    pub fn factor_matrices(&self) -> &Vec<Array2<Float>> {
        self.factor_matrices_.as_ref().unwrap()
    }

    /// Get the sparsity levels for each mode
    pub fn sparsity_levels(&self) -> &Array1<Float> {
        self.sparsity_levels_.as_ref().unwrap()
    }

    /// Get the reconstruction error
    pub fn reconstruction_error(&self) -> Float {
        self.reconstruction_error_.unwrap()
    }

    /// Get the number of iterations
    pub fn n_iter(&self) -> usize {
        self.n_iter_.unwrap()
    }

    /// Helper method for sparse reconstruction
    fn reconstruct_sparse_tensor(
        &self,
        factors: &[Array2<Float>],
        shape: &[usize],
    ) -> Result<Array3<Float>> {
        let mut reconstructed = Array3::zeros((shape[0], shape[1], shape[2]));

        for r in 0..self.n_factors {
            let a = factors[0].column(r);
            let b = factors[1].column(r);
            let c = factors[2].column(r);

            for i in 0..shape[0] {
                for j in 0..shape[1] {
                    for k in 0..shape[2] {
                        reconstructed[[i, j, k]] += a[i] * b[j] * c[k];
                    }
                }
            }
        }

        Ok(reconstructed)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use sklears_core::traits::Fit;

    #[test]
    fn test_sparse_tensor_decomposition_basic() {
        let tensor = Array3::from_shape_fn((5, 4, 3), |(i, j, k)| {
            if (i + j + k) % 3 == 0 {
                (i + j + k) as Float
            } else {
                0.0
            }
        });

        let sparse_decomp = SparseTensorDecomposition::new(2)
            .sparsity_penalty(0.1)
            .max_iter(50);
        let fitted = sparse_decomp.fit(&tensor, &()).unwrap();

        assert_eq!(fitted.factor_matrices().len(), 3);
        assert_eq!(fitted.factor_matrices()[0].shape(), &[5, 2]);
        assert_eq!(fitted.factor_matrices()[1].shape(), &[4, 2]);
        assert_eq!(fitted.factor_matrices()[2].shape(), &[3, 2]);
        assert!(fitted.n_iter() > 0);
        assert!(fitted.reconstruction_error() >= 0.0);

        // Check sparsity levels
        let sparsity = fitted.sparsity_levels();
        assert_eq!(sparsity.len(), 3);
        for &level in sparsity.iter() {
            assert!(level >= 0.0 && level <= 1.0);
        }
    }

    #[test]
    fn test_sparse_tensor_decomposition_sparsity() {
        let tensor = Array3::from_shape_fn(
            (4, 4, 4),
            |(i, j, k)| {
                if i == j && j == k {
                    1.0
                } else {
                    0.0
                }
            },
        );

        let sparse_decomp = SparseTensorDecomposition::new(1)
            .sparsity_penalty(0.05)
            .regularization(0.01)
            .sparsity_threshold(1e-6);
        let fitted = sparse_decomp.fit(&tensor, &()).unwrap();

        // Should achieve some level of sparsity
        let sparsity = fitted.sparsity_levels();
        let avg_sparsity = sparsity.mean().unwrap();
        assert!(
            avg_sparsity > 0.0,
            "Expected some sparsity but got {}",
            avg_sparsity
        );
    }

    #[test]
    fn test_sparse_tensor_decomposition_transform() {
        let tensor = Array3::from_shape_fn((4, 3, 2), |(i, j, k)| (i + j + k) as Float * 0.1);

        let sparse_decomp = SparseTensorDecomposition::new(2);
        let fitted = sparse_decomp.fit(&tensor, &()).unwrap();

        let reconstructed = fitted.transform(&tensor).unwrap();
        assert_eq!(reconstructed.shape(), tensor.shape());
    }

    #[test]
    fn test_sparse_tensor_configuration() {
        let tensor = Array3::ones((3, 3, 3));

        let sparse_decomp = SparseTensorDecomposition::new(1)
            .sparsity_penalty(0.2)
            .regularization(0.05)
            .sparsity_threshold(1e-5)
            .max_iter(20)
            .tol(1e-4);

        let fitted = sparse_decomp.fit(&tensor, &()).unwrap();
        assert!(fitted.n_iter() <= 20);
    }
}