1use crate::csr_array::CsrArray;
10use crate::error::{SparseError, SparseResult};
11use crate::sparray::SparseArray;
12use scirs2_core::numeric::{Float, SparseElement, Zero};
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::ops::Div;
16
/// Sparse tensor stored in coordinate (COO) format.
///
/// Each stored element `k` is described by one coordinate per dimension
/// (`indices[d][k]`) and its value (`values[k]`).
#[derive(Debug, Clone)]
pub struct SparseTensor<T> {
    /// Per-dimension coordinate vectors; `indices[d][k]` is the position of
    /// the k-th stored element along dimension `d`.
    pub indices: Vec<Vec<usize>>,
    /// Values of the stored (typically nonzero) elements, one per entry.
    pub values: Vec<T>,
    /// Extent of the tensor along each dimension.
    pub shape: Vec<usize>,
}
27
28impl<T> SparseTensor<T>
29where
30 T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
31{
32 pub fn new(indices: Vec<Vec<usize>>, values: Vec<T>, shape: Vec<usize>) -> SparseResult<Self> {
34 if indices.is_empty() {
36 return Err(SparseError::ValueError(
37 "Indices cannot be empty".to_string(),
38 ));
39 }
40
41 let ndim = indices.len();
42 if ndim != shape.len() {
43 return Err(SparseError::ValueError(
44 "Number of index dimensions must match shape dimensions".to_string(),
45 ));
46 }
47
48 let nnz = values.len();
49 for idx_dim in &indices {
50 if idx_dim.len() != nnz {
51 return Err(SparseError::ValueError(
52 "All index dimensions must have same length as values".to_string(),
53 ));
54 }
55 }
56
57 for (dim, idx_vec) in indices.iter().enumerate() {
59 for &idx in idx_vec {
60 if idx >= shape[dim] {
61 return Err(SparseError::ValueError(format!(
62 "Index {} in dimension {} exceeds shape {}",
63 idx, dim, shape[dim]
64 )));
65 }
66 }
67 }
68
69 Ok(Self {
70 indices,
71 values,
72 shape,
73 })
74 }
75
76 pub fn ndim(&self) -> usize {
78 self.shape.len()
79 }
80
81 pub fn nnz(&self) -> usize {
83 self.values.len()
84 }
85
86 pub fn size(&self) -> usize {
88 self.shape.iter().product()
89 }
90
91 pub fn get(&self, indices: &[usize]) -> T {
93 if indices.len() != self.ndim() {
94 return T::sparse_zero();
95 }
96
97 for i in 0..self.nnz() {
99 let mut found = true;
100 for (dim, &idx) in indices.iter().enumerate() {
101 if self.indices[dim][i] != idx {
102 found = false;
103 break;
104 }
105 }
106 if found {
107 return self.values[i];
108 }
109 }
110
111 T::sparse_zero()
112 }
113
114 pub fn unfold(&self, mode: usize) -> SparseResult<CsrArray<T>> {
118 if mode >= self.ndim() {
119 return Err(SparseError::ValueError(format!(
120 "Mode {} exceeds tensor dimensions {}",
121 mode,
122 self.ndim()
123 )));
124 }
125
126 let nrows = self.shape[mode];
128 let ncols: usize = self
129 .shape
130 .iter()
131 .enumerate()
132 .filter(|(i, _)| *i != mode)
133 .map(|(_, &s)| s)
134 .product();
135
136 let mut rows = Vec::new();
138 let mut cols = Vec::new();
139 let mut data = Vec::new();
140
141 for elem_idx in 0..self.nnz() {
142 let row = self.indices[mode][elem_idx];
143
144 let mut col = 0;
146 let mut stride = 1;
147
148 for dim in (0..self.ndim()).rev() {
149 if dim != mode {
150 col += self.indices[dim][elem_idx] * stride;
151 stride *= self.shape[dim];
152 }
153 }
154
155 rows.push(row);
156 cols.push(col);
157 data.push(self.values[elem_idx]);
158 }
159
160 CsrArray::from_triplets(&rows, &cols, &data, (nrows, ncols), false)
161 }
162
163 pub fn fold(matrix: &dyn SparseArray<T>, shape: Vec<usize>, mode: usize) -> SparseResult<Self> {
165 if mode >= shape.len() {
166 return Err(SparseError::ValueError(format!(
167 "Mode {} exceeds tensor dimensions {}",
168 mode,
169 shape.len()
170 )));
171 }
172
173 let (nrows, ncols) = matrix.shape();
174
175 if nrows != shape[mode] {
176 return Err(SparseError::ValueError(
177 "Matrix rows must match mode dimension".to_string(),
178 ));
179 }
180
181 let expected_cols: usize = shape
182 .iter()
183 .enumerate()
184 .filter(|(i, _)| *i != mode)
185 .map(|(_, &s)| s)
186 .product();
187
188 if ncols != expected_cols {
189 return Err(SparseError::ValueError(
190 "Matrix columns must match product of other dimensions".to_string(),
191 ));
192 }
193
194 let (mat_rows, mat_cols, mat_values) = matrix.find();
196
197 let ndim = shape.len();
198 let mut indices = vec![Vec::new(); ndim];
199 let mut values = Vec::new();
200
201 for (i, (&row, &col)) in mat_rows.iter().zip(mat_cols.iter()).enumerate() {
202 indices[mode].push(row);
204
205 let mut remaining = col;
207 let mut other_dims: Vec<usize> = (0..ndim).filter(|&d| d != mode).collect();
208 other_dims.reverse();
209
210 for &dim in &other_dims {
211 let idx = remaining % shape[dim];
212 indices[dim].push(idx);
213 remaining /= shape[dim];
214 }
215
216 values.push(mat_values[i]);
217 }
218
219 Self::new(indices, values, shape)
220 }
221
222 pub fn mode_product(&self, matrix: &CsrArray<T>, mode: usize) -> SparseResult<Self> {
226 if mode >= self.ndim() {
227 return Err(SparseError::ValueError(format!(
228 "Mode {} exceeds tensor dimensions {}",
229 mode,
230 self.ndim()
231 )));
232 }
233
234 let (mat_rows, mat_cols) = matrix.shape();
235 if mat_cols != self.shape[mode] {
236 return Err(SparseError::ValueError(
237 "Matrix columns must match tensor mode dimension".to_string(),
238 ));
239 }
240
241 let unfolded = self.unfold(mode)?;
243
244 let result_matrix = matrix.dot(&unfolded)?;
246
247 let mut new_shape = self.shape.clone();
249 new_shape[mode] = mat_rows;
250
251 Self::fold(result_matrix.as_ref(), new_shape, mode)
253 }
254
255 pub fn inner_product(&self, other: &Self) -> SparseResult<T> {
257 if self.shape != other.shape {
258 return Err(SparseError::ValueError(
259 "Tensors must have the same shape for inner product".to_string(),
260 ));
261 }
262
263 let mut result = T::sparse_zero();
264
265 let mut index_map: HashMap<Vec<usize>, T> = HashMap::new();
267 for i in 0..other.nnz() {
268 let indices: Vec<usize> = (0..self.ndim()).map(|d| other.indices[d][i]).collect();
269 index_map.insert(indices, other.values[i]);
270 }
271
272 for i in 0..self.nnz() {
274 let indices: Vec<usize> = (0..self.ndim()).map(|d| self.indices[d][i]).collect();
275
276 if let Some(&other_val) = index_map.get(&indices) {
277 result = result + self.values[i] * other_val;
278 }
279 }
280
281 Ok(result)
282 }
283
284 pub fn frobenius_norm(&self) -> T {
286 let sum_sq: T = self.values.iter().map(|&v| v * v).sum();
287 sum_sq.sqrt()
288 }
289}
290
/// Components of a Tucker-style decomposition: a core tensor together with
/// one factor matrix per mode.
#[derive(Debug, Clone)]
pub struct TuckerDecomposition<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
    /// Core tensor of the decomposition.
    pub core: SparseTensor<T>,
    /// Factor matrices, one per mode of the original tensor.
    pub factors: Vec<CsrArray<T>>,
}
302
/// Components of a CP (CANDECOMP/PARAFAC) decomposition: per-component
/// weights and one factor matrix per mode.
#[derive(Debug, Clone)]
pub struct CPDecomposition<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
    /// Weight of each rank-one component.
    pub weights: Vec<T>,
    /// Factor matrices, one per mode of the original tensor.
    pub factors: Vec<CsrArray<T>>,
    /// Number of rank-one components in the decomposition.
    pub rank: usize,
}
316
317pub fn khatri_rao_product<T>(a: &CsrArray<T>, b: &CsrArray<T>) -> SparseResult<CsrArray<T>>
321where
322 T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
323{
324 let (rows_a, cols_a) = a.shape();
325 let (rows_b, cols_b) = b.shape();
326
327 if cols_a != cols_b {
328 return Err(SparseError::ValueError(
329 "Matrices must have the same number of columns for Khatri-Rao product".to_string(),
330 ));
331 }
332
333 let ncols = cols_a;
334 let nrows = rows_a * rows_b;
335
336 let mut result_rows = Vec::new();
337 let mut result_cols = Vec::new();
338 let mut result_data = Vec::new();
339
340 for col in 0..ncols {
342 let mut col_a = vec![T::sparse_zero(); rows_a];
344 let mut col_b = vec![T::sparse_zero(); rows_b];
345
346 for row in 0..rows_a {
347 col_a[row] = a.get(row, col);
348 }
349
350 for row in 0..rows_b {
351 col_b[row] = b.get(row, col);
352 }
353
354 for i in 0..rows_a {
356 for j in 0..rows_b {
357 let value = col_a[i] * col_b[j];
358 if !scirs2_core::SparseElement::is_zero(&value) {
359 result_rows.push(i * rows_b + j);
360 result_cols.push(col);
361 result_data.push(value);
362 }
363 }
364 }
365 }
366
367 CsrArray::from_triplets(
368 &result_rows,
369 &result_cols,
370 &result_data,
371 (nrows, ncols),
372 false,
373 )
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the shared 2x3x4 fixture tensor with four stored entries.
    fn create_test_tensor() -> SparseTensor<f64> {
        let indices = vec![vec![0, 0, 1, 1], vec![0, 1, 0, 2], vec![0, 1, 2, 3]];
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 3, 4];

        SparseTensor::new(indices, values, shape).expect("Failed to create tensor")
    }

    #[test]
    fn test_tensor_creation() {
        let t = create_test_tensor();

        assert_eq!(t.ndim(), 3);
        assert_eq!(t.nnz(), 4);
        assert_eq!(t.size(), 24);
        assert_eq!(t.shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_tensor_get() {
        let t = create_test_tensor();

        // Stored entries read back as-is; the unstored one reads as zero.
        let cases = [
            ([0usize, 0, 0], 1.0),
            ([0, 1, 1], 2.0),
            ([1, 0, 2], 3.0),
            ([1, 2, 3], 4.0),
            ([0, 0, 1], 0.0),
        ];
        for (coords, want) in cases {
            assert_relative_eq!(t.get(&coords), want);
        }
    }

    #[test]
    fn test_unfold() {
        let t = create_test_tensor();

        // Mode-n unfolding is (shape[n], product of the remaining dims).
        for (mode, want) in [(0usize, (2usize, 12usize)), (1, (3, 8)), (2, (4, 6))] {
            let unfolded = t.unfold(mode).expect("Failed to unfold");
            assert_eq!(unfolded.shape(), want);
        }
    }

    #[test]
    fn test_fold_unfold_roundtrip() {
        let t = create_test_tensor();

        for mode in 0..t.ndim() {
            let unfolded = t.unfold(mode).expect("Failed to unfold");
            let back =
                SparseTensor::fold(&unfolded, t.shape.clone(), mode).expect("Failed to fold");

            assert_eq!(back.nnz(), t.nnz());

            // Every stored entry must survive the round trip unchanged.
            for k in 0..t.nnz() {
                let coords: Vec<usize> = (0..t.ndim()).map(|d| t.indices[d][k]).collect();
                assert_relative_eq!(t.get(&coords), back.get(&coords), epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_inner_product() {
        let lhs = create_test_tensor();
        let rhs = create_test_tensor();

        let ip = lhs.inner_product(&rhs).expect("Failed");

        // <X, X> equals the sum of squared stored values.
        let want: f64 = lhs.values.iter().map(|&v| v * v).sum();
        assert_relative_eq!(ip, want, epsilon = 1e-10);
    }

    #[test]
    fn test_frobenius_norm() {
        let t = create_test_tensor();

        let want = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert_relative_eq!(t.frobenius_norm(), want, epsilon = 1e-10);
    }

    #[test]
    fn test_khatri_rao_product() {
        let a = CsrArray::from_triplets(&[0, 0, 1], &[0, 1, 0], &[1.0, 2.0, 3.0], (2, 2), false)
            .expect("Failed");
        let b = CsrArray::from_triplets(&[0, 1, 1], &[0, 0, 1], &[4.0, 5.0, 6.0], (2, 2), false)
            .expect("Failed");

        let kr = khatri_rao_product(&a, &b).expect("Failed");

        assert_eq!(kr.shape(), (4, 2));
        assert!(kr.nnz() > 0);
    }
}