Skip to main content

scirs2_sparse/
tensor_sparse.rs

1//! Tensor-based sparse operations
2//!
3//! This module provides operations for sparse tensors (multi-dimensional sparse arrays):
4//! - Sparse tensor construction and manipulation
5//! - Tensor contractions and products
6//! - Mode-n unfolding and folding
7//! - Tucker and CP decompositions
8
9use crate::csr_array::CsrArray;
10use crate::error::{SparseError, SparseResult};
11use crate::sparray::SparseArray;
12use scirs2_core::numeric::{Float, SparseElement, Zero};
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::ops::Div;
16
/// Sparse tensor in COO (coordinate) format
///
/// Non-zeros are stored as parallel arrays: `values[k]` lives at coordinate
/// (`indices[0][k]`, `indices[1][k]`, …, `indices[ndim-1][k]`).
/// Invariant (enforced by `SparseTensor::new`): `indices.len() == shape.len()`
/// and every `indices[d].len() == values.len()`.
#[derive(Debug, Clone)]
pub struct SparseTensor<T> {
    /// Indices for each non-zero element (one vector per dimension)
    pub indices: Vec<Vec<usize>>,
    /// Values of non-zero elements
    pub values: Vec<T>,
    /// Shape of the tensor (extent of each dimension)
    pub shape: Vec<usize>,
}
27
28impl<T> SparseTensor<T>
29where
30    T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
31{
32    /// Create a new sparse tensor
33    pub fn new(indices: Vec<Vec<usize>>, values: Vec<T>, shape: Vec<usize>) -> SparseResult<Self> {
34        // Validate inputs
35        if indices.is_empty() {
36            return Err(SparseError::ValueError(
37                "Indices cannot be empty".to_string(),
38            ));
39        }
40
41        let ndim = indices.len();
42        if ndim != shape.len() {
43            return Err(SparseError::ValueError(
44                "Number of index dimensions must match shape dimensions".to_string(),
45            ));
46        }
47
48        let nnz = values.len();
49        for idx_dim in &indices {
50            if idx_dim.len() != nnz {
51                return Err(SparseError::ValueError(
52                    "All index dimensions must have same length as values".to_string(),
53                ));
54            }
55        }
56
57        // Validate indices are within bounds
58        for (dim, idx_vec) in indices.iter().enumerate() {
59            for &idx in idx_vec {
60                if idx >= shape[dim] {
61                    return Err(SparseError::ValueError(format!(
62                        "Index {} in dimension {} exceeds shape {}",
63                        idx, dim, shape[dim]
64                    )));
65                }
66            }
67        }
68
69        Ok(Self {
70            indices,
71            values,
72            shape,
73        })
74    }
75
76    /// Get the number of dimensions
77    pub fn ndim(&self) -> usize {
78        self.shape.len()
79    }
80
81    /// Get the number of non-zero elements
82    pub fn nnz(&self) -> usize {
83        self.values.len()
84    }
85
86    /// Get the total number of elements
87    pub fn size(&self) -> usize {
88        self.shape.iter().product()
89    }
90
91    /// Get an element at the specified indices
92    pub fn get(&self, indices: &[usize]) -> T {
93        if indices.len() != self.ndim() {
94            return T::sparse_zero();
95        }
96
97        // Search for the element
98        for i in 0..self.nnz() {
99            let mut found = true;
100            for (dim, &idx) in indices.iter().enumerate() {
101                if self.indices[dim][i] != idx {
102                    found = false;
103                    break;
104                }
105            }
106            if found {
107                return self.values[i];
108            }
109        }
110
111        T::sparse_zero()
112    }
113
114    /// Mode-n unfolding (matricization)
115    ///
116    /// Unfolds the tensor along the specified mode into a matrix.
117    pub fn unfold(&self, mode: usize) -> SparseResult<CsrArray<T>> {
118        if mode >= self.ndim() {
119            return Err(SparseError::ValueError(format!(
120                "Mode {} exceeds tensor dimensions {}",
121                mode,
122                self.ndim()
123            )));
124        }
125
126        // Calculate matrix dimensions
127        let nrows = self.shape[mode];
128        let ncols: usize = self
129            .shape
130            .iter()
131            .enumerate()
132            .filter(|(i, _)| *i != mode)
133            .map(|(_, &s)| s)
134            .product();
135
136        // Build row/col/data for matrix
137        let mut rows = Vec::new();
138        let mut cols = Vec::new();
139        let mut data = Vec::new();
140
141        for elem_idx in 0..self.nnz() {
142            let row = self.indices[mode][elem_idx];
143
144            // Calculate column index from other dimensions
145            let mut col = 0;
146            let mut stride = 1;
147
148            for dim in (0..self.ndim()).rev() {
149                if dim != mode {
150                    col += self.indices[dim][elem_idx] * stride;
151                    stride *= self.shape[dim];
152                }
153            }
154
155            rows.push(row);
156            cols.push(col);
157            data.push(self.values[elem_idx]);
158        }
159
160        CsrArray::from_triplets(&rows, &cols, &data, (nrows, ncols), false)
161    }
162
163    /// Fold a matrix back into a tensor along the specified mode
164    pub fn fold(matrix: &dyn SparseArray<T>, shape: Vec<usize>, mode: usize) -> SparseResult<Self> {
165        if mode >= shape.len() {
166            return Err(SparseError::ValueError(format!(
167                "Mode {} exceeds tensor dimensions {}",
168                mode,
169                shape.len()
170            )));
171        }
172
173        let (nrows, ncols) = matrix.shape();
174
175        if nrows != shape[mode] {
176            return Err(SparseError::ValueError(
177                "Matrix rows must match mode dimension".to_string(),
178            ));
179        }
180
181        let expected_cols: usize = shape
182            .iter()
183            .enumerate()
184            .filter(|(i, _)| *i != mode)
185            .map(|(_, &s)| s)
186            .product();
187
188        if ncols != expected_cols {
189            return Err(SparseError::ValueError(
190                "Matrix columns must match product of other dimensions".to_string(),
191            ));
192        }
193
194        // Get non-zero elements from matrix
195        let (mat_rows, mat_cols, mat_values) = matrix.find();
196
197        let ndim = shape.len();
198        let mut indices = vec![Vec::new(); ndim];
199        let mut values = Vec::new();
200
201        for (i, (&row, &col)) in mat_rows.iter().zip(mat_cols.iter()).enumerate() {
202            // Set mode index
203            indices[mode].push(row);
204
205            // Decode column into other dimension indices
206            let mut remaining = col;
207            let mut other_dims: Vec<usize> = (0..ndim).filter(|&d| d != mode).collect();
208            other_dims.reverse();
209
210            for &dim in &other_dims {
211                let idx = remaining % shape[dim];
212                indices[dim].push(idx);
213                remaining /= shape[dim];
214            }
215
216            values.push(mat_values[i]);
217        }
218
219        Self::new(indices, values, shape)
220    }
221
222    /// Tensor-matrix product along specified mode
223    ///
224    /// Multiplies the tensor by a matrix along the given mode.
225    pub fn mode_product(&self, matrix: &CsrArray<T>, mode: usize) -> SparseResult<Self> {
226        if mode >= self.ndim() {
227            return Err(SparseError::ValueError(format!(
228                "Mode {} exceeds tensor dimensions {}",
229                mode,
230                self.ndim()
231            )));
232        }
233
234        let (mat_rows, mat_cols) = matrix.shape();
235        if mat_cols != self.shape[mode] {
236            return Err(SparseError::ValueError(
237                "Matrix columns must match tensor mode dimension".to_string(),
238            ));
239        }
240
241        // Unfold tensor along mode
242        let unfolded = self.unfold(mode)?;
243
244        // Multiply: result = matrix * unfolded
245        let result_matrix = matrix.dot(&unfolded)?;
246
247        // Update shape with new mode dimension
248        let mut new_shape = self.shape.clone();
249        new_shape[mode] = mat_rows;
250
251        // Fold back into tensor
252        Self::fold(result_matrix.as_ref(), new_shape, mode)
253    }
254
255    /// Inner product of two sparse tensors
256    pub fn inner_product(&self, other: &Self) -> SparseResult<T> {
257        if self.shape != other.shape {
258            return Err(SparseError::ValueError(
259                "Tensors must have the same shape for inner product".to_string(),
260            ));
261        }
262
263        let mut result = T::sparse_zero();
264
265        // Build index map for efficient lookup
266        let mut index_map: HashMap<Vec<usize>, T> = HashMap::new();
267        for i in 0..other.nnz() {
268            let indices: Vec<usize> = (0..self.ndim()).map(|d| other.indices[d][i]).collect();
269            index_map.insert(indices, other.values[i]);
270        }
271
272        // Sum products of matching non-zeros
273        for i in 0..self.nnz() {
274            let indices: Vec<usize> = (0..self.ndim()).map(|d| self.indices[d][i]).collect();
275
276            if let Some(&other_val) = index_map.get(&indices) {
277                result = result + self.values[i] * other_val;
278            }
279        }
280
281        Ok(result)
282    }
283
284    /// Frobenius norm of the tensor
285    pub fn frobenius_norm(&self) -> T {
286        let sum_sq: T = self.values.iter().map(|&v| v * v).sum();
287        sum_sq.sqrt()
288    }
289}
290
/// Tucker decomposition result
///
/// Represents a tensor as a core tensor multiplied by one factor matrix
/// along each mode.
#[derive(Debug, Clone)]
pub struct TuckerDecomposition<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
    /// Core tensor
    pub core: SparseTensor<T>,
    /// Factor matrices (one per mode)
    pub factors: Vec<CsrArray<T>>,
}
302
/// CP (CANDECOMP/PARAFAC) decomposition result
///
/// Represents a tensor as a weighted sum of `rank` rank-1 components, each
/// the outer product of one column from every factor matrix.
#[derive(Debug, Clone)]
pub struct CPDecomposition<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
    /// Weights of rank-1 components
    pub weights: Vec<T>,
    /// Factor matrices (one per mode)
    pub factors: Vec<CsrArray<T>>,
    /// Rank of decomposition (number of rank-1 components)
    pub rank: usize,
}
316
317/// Compute Khatri-Rao product of two matrices
318///
319/// The Khatri-Rao product is a column-wise Kronecker product.
320pub fn khatri_rao_product<T>(a: &CsrArray<T>, b: &CsrArray<T>) -> SparseResult<CsrArray<T>>
321where
322    T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
323{
324    let (rows_a, cols_a) = a.shape();
325    let (rows_b, cols_b) = b.shape();
326
327    if cols_a != cols_b {
328        return Err(SparseError::ValueError(
329            "Matrices must have the same number of columns for Khatri-Rao product".to_string(),
330        ));
331    }
332
333    let ncols = cols_a;
334    let nrows = rows_a * rows_b;
335
336    let mut result_rows = Vec::new();
337    let mut result_cols = Vec::new();
338    let mut result_data = Vec::new();
339
340    // For each column
341    for col in 0..ncols {
342        // Get column vectors
343        let mut col_a = vec![T::sparse_zero(); rows_a];
344        let mut col_b = vec![T::sparse_zero(); rows_b];
345
346        for row in 0..rows_a {
347            col_a[row] = a.get(row, col);
348        }
349
350        for row in 0..rows_b {
351            col_b[row] = b.get(row, col);
352        }
353
354        // Compute Kronecker product of columns
355        for i in 0..rows_a {
356            for j in 0..rows_b {
357                let value = col_a[i] * col_b[j];
358                if !scirs2_core::SparseElement::is_zero(&value) {
359                    result_rows.push(i * rows_b + j);
360                    result_cols.push(col);
361                    result_data.push(value);
362                }
363            }
364        }
365    }
366
367    CsrArray::from_triplets(
368        &result_rows,
369        &result_cols,
370        &result_data,
371        (nrows, ncols),
372        false,
373    )
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Build a 2x3x4 tensor with four non-zeros:
    /// (0,0,0)=1, (0,1,1)=2, (1,0,2)=3, (1,2,3)=4.
    fn create_test_tensor() -> SparseTensor<f64> {
        // Create a simple 2x3x4 sparse tensor with a few non-zero elements
        let indices = vec![
            vec![0, 0, 1, 1], // dimension 0
            vec![0, 1, 0, 2], // dimension 1
            vec![0, 1, 2, 3], // dimension 2
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 3, 4];

        SparseTensor::new(indices, values, shape).expect("Failed to create tensor")
    }

    #[test]
    fn test_tensor_creation() {
        let tensor = create_test_tensor();

        assert_eq!(tensor.ndim(), 3);
        assert_eq!(tensor.nnz(), 4);
        assert_eq!(tensor.size(), 24); // 2 * 3 * 4
        assert_eq!(tensor.shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_tensor_get() {
        let tensor = create_test_tensor();

        assert_relative_eq!(tensor.get(&[0, 0, 0]), 1.0);
        assert_relative_eq!(tensor.get(&[0, 1, 1]), 2.0);
        assert_relative_eq!(tensor.get(&[1, 0, 2]), 3.0);
        assert_relative_eq!(tensor.get(&[1, 2, 3]), 4.0);
        assert_relative_eq!(tensor.get(&[0, 0, 1]), 0.0); // zero element
    }

    #[test]
    fn test_unfold() {
        let tensor = create_test_tensor();

        // Mode-n unfolding: rows = shape[mode], cols = product of the rest.

        // Unfold along mode 0
        let unfolded = tensor.unfold(0).expect("Failed to unfold");
        assert_eq!(unfolded.shape(), (2, 12)); // 2 x (3*4)

        // Unfold along mode 1
        let unfolded1 = tensor.unfold(1).expect("Failed to unfold");
        assert_eq!(unfolded1.shape(), (3, 8)); // 3 x (2*4)

        // Unfold along mode 2
        let unfolded2 = tensor.unfold(2).expect("Failed to unfold");
        assert_eq!(unfolded2.shape(), (4, 6)); // 4 x (2*3)
    }

    #[test]
    fn test_fold_unfold_roundtrip() {
        let tensor = create_test_tensor();

        // fold(unfold(t, mode), shape, mode) must reproduce t for every mode.
        for mode in 0..tensor.ndim() {
            let unfolded = tensor.unfold(mode).expect("Failed to unfold");
            let refolded =
                SparseTensor::fold(&unfolded, tensor.shape.clone(), mode).expect("Failed to fold");

            // Check that we get the same values back
            assert_eq!(refolded.nnz(), tensor.nnz());

            for i in 0..tensor.nnz() {
                let indices: Vec<usize> =
                    (0..tensor.ndim()).map(|d| tensor.indices[d][i]).collect();
                assert_relative_eq!(
                    tensor.get(&indices),
                    refolded.get(&indices),
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_inner_product() {
        let tensor1 = create_test_tensor();
        let tensor2 = create_test_tensor();

        let ip = tensor1.inner_product(&tensor2).expect("Failed");

        // Inner product with itself should equal sum of squares
        let sum_sq: f64 = tensor1.values.iter().map(|&v| v * v).sum();
        assert_relative_eq!(ip, sum_sq, epsilon = 1e-10);
    }

    #[test]
    fn test_frobenius_norm() {
        let tensor = create_test_tensor();

        let norm = tensor.frobenius_norm();

        // Should be sqrt(1^2 + 2^2 + 3^2 + 4^2) = sqrt(30)
        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert_relative_eq!(norm, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_khatri_rao_product() {
        // Create two small matrices
        // a = [[1, 2], [3, 0]]
        let rows_a = vec![0, 0, 1];
        let cols_a = vec![0, 1, 0];
        let data_a = vec![1.0, 2.0, 3.0];
        let a = CsrArray::from_triplets(&rows_a, &cols_a, &data_a, (2, 2), false).expect("Failed");

        // b = [[4, 0], [5, 6]]
        let rows_b = vec![0, 1, 1];
        let cols_b = vec![0, 0, 1];
        let data_b = vec![4.0, 5.0, 6.0];
        let b = CsrArray::from_triplets(&rows_b, &cols_b, &data_b, (2, 2), false).expect("Failed");

        let result = khatri_rao_product(&a, &b).expect("Failed");

        // Result should be 4x2 (2*2 rows, same columns)
        assert_eq!(result.shape(), (4, 2));
        assert!(result.nnz() > 0);
    }
}