sklears_cross_decomposition/tensor_methods/
sparse_tensor.rs1use super::common::{Trained, Untrained};
4use scirs2_core::ndarray::{Array1, Array2, Array3};
5use scirs2_core::random::{thread_rng, Rng};
6use sklears_core::{
7 error::{Result, SklearsError},
8 traits::{Estimator, Fit, Transform},
9 types::Float,
10};
11use std::marker::PhantomData;
12
13#[derive(Debug, Clone)]
31pub struct SparseTensorDecomposition<State = Untrained> {
32 pub n_factors: usize,
34 pub max_iter: usize,
36 pub tol: Float,
38 pub sparsity_penalty: Float,
40 pub regularization: Float,
42 pub sparsity_threshold: Float,
44 factor_matrices_: Option<Vec<Array2<Float>>>,
46 original_shape_: Option<Vec<usize>>,
48 sparsity_levels_: Option<Array1<Float>>,
50 reconstruction_error_: Option<Float>,
52 n_iter_: Option<usize>,
54 _state: PhantomData<State>,
56}
57
58impl SparseTensorDecomposition<Untrained> {
59 pub fn new(n_factors: usize) -> Self {
61 Self {
62 n_factors,
63 max_iter: 100,
64 tol: 1e-6,
65 sparsity_penalty: 0.01,
66 regularization: 0.001,
67 sparsity_threshold: 1e-8,
68 factor_matrices_: None,
69 original_shape_: None,
70 sparsity_levels_: None,
71 reconstruction_error_: None,
72 n_iter_: None,
73 _state: PhantomData,
74 }
75 }
76
77 pub fn sparsity_penalty(mut self, penalty: Float) -> Self {
79 self.sparsity_penalty = penalty;
80 self
81 }
82
83 pub fn regularization(mut self, regularization: Float) -> Self {
85 self.regularization = regularization;
86 self
87 }
88
89 pub fn sparsity_threshold(mut self, threshold: Float) -> Self {
91 self.sparsity_threshold = threshold;
92 self
93 }
94
95 pub fn max_iter(mut self, max_iter: usize) -> Self {
97 self.max_iter = max_iter;
98 self
99 }
100
101 pub fn tol(mut self, tol: Float) -> Self {
103 self.tol = tol;
104 self
105 }
106}
107
108impl Estimator for SparseTensorDecomposition<Untrained> {
109 type Config = ();
110 type Error = SklearsError;
111 type Float = Float;
112
113 fn config(&self) -> &Self::Config {
114 &()
115 }
116}
117
118impl Fit<Array3<Float>, ()> for SparseTensorDecomposition<Untrained> {
119 type Fitted = SparseTensorDecomposition<Trained>;
120
121 fn fit(self, tensor: &Array3<Float>, _target: &()) -> Result<Self::Fitted> {
122 let shape = tensor.shape();
123
124 let mut factor_matrices = Vec::new();
126 for mode in 0..3 {
127 let mut factor = Array2::zeros((shape[mode], self.n_factors));
128 for i in 0..shape[mode] {
129 for j in 0..self.n_factors {
130 factor[[i, j]] = thread_rng().random::<Float>() * 0.01;
131 }
132 }
133 factor_matrices.push(factor);
134 }
135
136 let mut converged = false;
137 let mut n_iter = 0;
138 let mut prev_error = Float::INFINITY;
139
140 while !converged && n_iter < self.max_iter {
142 let old_factors = factor_matrices.clone();
143
144 for mode in 0..3 {
146 factor_matrices[mode] =
147 self.update_sparse_factor(tensor, &factor_matrices, mode)?;
148
149 self.apply_soft_thresholding(&mut factor_matrices[mode]);
151 }
152
153 let reconstructed = self.reconstruct_sparse_tensor(&factor_matrices, shape)?;
155 let error = (tensor - &reconstructed).mapv(|x| x * x).sum().sqrt();
156
157 if (prev_error - error).abs() < self.tol {
159 converged = true;
160 }
161
162 let mut max_factor_change: Float = 0.0;
164 for mode in 0..3 {
165 let change = (&factor_matrices[mode] - &old_factors[mode])
166 .mapv(|x| x.abs())
167 .sum();
168 max_factor_change = max_factor_change.max(change);
169 }
170
171 if max_factor_change < self.tol {
172 converged = true;
173 }
174
175 prev_error = error;
176 n_iter += 1;
177 }
178
179 let mut sparsity_levels = Array1::zeros(3);
181 for mode in 0..3 {
182 let total_elements = factor_matrices[mode].len();
183 let sparse_elements = factor_matrices[mode]
184 .iter()
185 .filter(|&&x| x.abs() < self.sparsity_threshold)
186 .count();
187 sparsity_levels[mode] = sparse_elements as Float / total_elements as Float;
188 }
189
190 Ok(SparseTensorDecomposition {
191 n_factors: self.n_factors,
192 max_iter: self.max_iter,
193 tol: self.tol,
194 sparsity_penalty: self.sparsity_penalty,
195 regularization: self.regularization,
196 sparsity_threshold: self.sparsity_threshold,
197 factor_matrices_: Some(factor_matrices),
198 original_shape_: Some(shape.to_vec()),
199 sparsity_levels_: Some(sparsity_levels),
200 reconstruction_error_: Some(prev_error),
201 n_iter_: Some(n_iter),
202 _state: PhantomData,
203 })
204 }
205}
206
207impl SparseTensorDecomposition<Untrained> {
208 fn update_sparse_factor(
210 &self,
211 tensor: &Array3<Float>,
212 factors: &[Array2<Float>],
213 mode: usize,
214 ) -> Result<Array2<Float>> {
215 let shape = tensor.shape();
216 let mut new_factor = Array2::zeros((shape[mode], self.n_factors));
217
218 for r in 0..self.n_factors {
220 let mut factor_col = Array1::zeros(shape[mode]);
221
222 match mode {
223 0 => {
224 for i in 0..shape[0] {
225 let mut numerator = 0.0;
226 let mut denominator = 0.0;
227
228 for j in 0..shape[1] {
229 for k in 0..shape[2] {
230 let coeff = factors[1][[j, r]] * factors[2][[k, r]];
231 numerator += tensor[[i, j, k]] * coeff;
232 denominator += coeff * coeff;
233 }
234 }
235
236 if denominator > self.tol {
237 factor_col[i] = numerator / (denominator + self.regularization);
238 }
239 }
240 }
241 1 => {
242 for j in 0..shape[1] {
243 let mut numerator = 0.0;
244 let mut denominator = 0.0;
245
246 for i in 0..shape[0] {
247 for k in 0..shape[2] {
248 let coeff = factors[0][[i, r]] * factors[2][[k, r]];
249 numerator += tensor[[i, j, k]] * coeff;
250 denominator += coeff * coeff;
251 }
252 }
253
254 if denominator > self.tol {
255 factor_col[j] = numerator / (denominator + self.regularization);
256 }
257 }
258 }
259 2 => {
260 for k in 0..shape[2] {
261 let mut numerator = 0.0;
262 let mut denominator = 0.0;
263
264 for i in 0..shape[0] {
265 for j in 0..shape[1] {
266 let coeff = factors[0][[i, r]] * factors[1][[j, r]];
267 numerator += tensor[[i, j, k]] * coeff;
268 denominator += coeff * coeff;
269 }
270 }
271
272 if denominator > self.tol {
273 factor_col[k] = numerator / (denominator + self.regularization);
274 }
275 }
276 }
277 _ => return Err(SklearsError::InvalidInput("Invalid mode".to_string())),
278 }
279
280 new_factor.column_mut(r).assign(&factor_col);
281 }
282
283 Ok(new_factor)
284 }
285
286 fn apply_soft_thresholding(&self, factor: &mut Array2<Float>) {
288 let threshold = self.sparsity_penalty;
289 factor.mapv_inplace(|x| {
290 if x > threshold {
291 x - threshold
292 } else if x < -threshold {
293 x + threshold
294 } else {
295 0.0
296 }
297 });
298 }
299
300 fn reconstruct_sparse_tensor(
302 &self,
303 factors: &[Array2<Float>],
304 shape: &[usize],
305 ) -> Result<Array3<Float>> {
306 let mut reconstructed = Array3::zeros((shape[0], shape[1], shape[2]));
307
308 for r in 0..self.n_factors {
309 let a = factors[0].column(r);
310 let b = factors[1].column(r);
311 let c = factors[2].column(r);
312
313 for i in 0..shape[0] {
314 for j in 0..shape[1] {
315 for k in 0..shape[2] {
316 reconstructed[[i, j, k]] += a[i] * b[j] * c[k];
317 }
318 }
319 }
320 }
321
322 Ok(reconstructed)
323 }
324}
325
326impl Transform<Array3<Float>, Array3<Float>> for SparseTensorDecomposition<Trained> {
327 fn transform(&self, tensor: &Array3<Float>) -> Result<Array3<Float>> {
329 let factors = self.factor_matrices_.as_ref().unwrap();
330 let shape = tensor.shape();
331 self.reconstruct_sparse_tensor(factors, shape)
332 }
333}
334
335impl SparseTensorDecomposition<Trained> {
336 pub fn factor_matrices(&self) -> &Vec<Array2<Float>> {
338 self.factor_matrices_.as_ref().unwrap()
339 }
340
341 pub fn sparsity_levels(&self) -> &Array1<Float> {
343 self.sparsity_levels_.as_ref().unwrap()
344 }
345
346 pub fn reconstruction_error(&self) -> Float {
348 self.reconstruction_error_.unwrap()
349 }
350
351 pub fn n_iter(&self) -> usize {
353 self.n_iter_.unwrap()
354 }
355
356 fn reconstruct_sparse_tensor(
358 &self,
359 factors: &[Array2<Float>],
360 shape: &[usize],
361 ) -> Result<Array3<Float>> {
362 let mut reconstructed = Array3::zeros((shape[0], shape[1], shape[2]));
363
364 for r in 0..self.n_factors {
365 let a = factors[0].column(r);
366 let b = factors[1].column(r);
367 let c = factors[2].column(r);
368
369 for i in 0..shape[0] {
370 for j in 0..shape[1] {
371 for k in 0..shape[2] {
372 reconstructed[[i, j, k]] += a[i] * b[j] * c[k];
373 }
374 }
375 }
376 }
377
378 Ok(reconstructed)
379 }
380}
381
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;
    use sklears_core::traits::Fit;

    /// Fitting a small tensor yields three factor matrices of the expected
    /// shapes plus sane diagnostics.
    #[test]
    fn test_sparse_tensor_decomposition_basic() {
        let tensor = Array3::from_shape_fn((5, 4, 3), |(i, j, k)| {
            let index_sum = i + j + k;
            if index_sum % 3 == 0 {
                index_sum as Float
            } else {
                0.0
            }
        });

        let fitted = SparseTensorDecomposition::new(2)
            .sparsity_penalty(0.1)
            .max_iter(50)
            .fit(&tensor, &())
            .unwrap();

        let factors = fitted.factor_matrices();
        assert_eq!(factors.len(), 3);
        assert_eq!(factors[0].shape(), &[5, 2]);
        assert_eq!(factors[1].shape(), &[4, 2]);
        assert_eq!(factors[2].shape(), &[3, 2]);
        assert!(fitted.n_iter() > 0);
        assert!(fitted.reconstruction_error() >= 0.0);

        // Each sparsity level is a fraction.
        let sparsity = fitted.sparsity_levels();
        assert_eq!(sparsity.len(), 3);
        for &level in sparsity.iter() {
            assert!((0.0..=1.0).contains(&level));
        }
    }

    /// A super-diagonal tensor fitted with a sparsity penalty reports a
    /// non-zero average sparsity.
    #[test]
    fn test_sparse_tensor_decomposition_sparsity() {
        let tensor = Array3::from_shape_fn((4, 4, 4), |(i, j, k)| {
            let on_diagonal = i == j && j == k;
            if on_diagonal {
                1.0
            } else {
                0.0
            }
        });

        let fitted = SparseTensorDecomposition::new(1)
            .sparsity_penalty(0.05)
            .regularization(0.01)
            .sparsity_threshold(1e-6)
            .fit(&tensor, &())
            .unwrap();

        let avg_sparsity = fitted.sparsity_levels().mean().unwrap();
        assert!(
            avg_sparsity > 0.0,
            "Expected some sparsity but got {}",
            avg_sparsity
        );
    }

    /// `transform` returns a tensor with the same shape as its input.
    #[test]
    fn test_sparse_tensor_decomposition_transform() {
        let tensor = Array3::from_shape_fn((4, 3, 2), |(i, j, k)| (i + j + k) as Float * 0.1);

        let fitted = SparseTensorDecomposition::new(2)
            .fit(&tensor, &())
            .unwrap();

        let reconstructed = fitted.transform(&tensor).unwrap();
        assert_eq!(reconstructed.shape(), tensor.shape());
    }

    /// Every builder setter can be chained and the iteration cap is honoured.
    #[test]
    fn test_sparse_tensor_configuration() {
        let tensor = Array3::ones((3, 3, 3));

        let configured = SparseTensorDecomposition::new(1)
            .sparsity_penalty(0.2)
            .regularization(0.05)
            .sparsity_threshold(1e-5)
            .max_iter(20)
            .tol(1e-4);

        let fitted = configured.fit(&tensor, &()).unwrap();
        assert!(fitted.n_iter() <= 20);
    }
}