Skip to main content

scirs2_sparse/
csf_tensor.rs

1//! Compressed Sparse Fiber (CSF) format for sparse tensors
2//!
3//! CSF is a hierarchical compressed format for sparse tensors, analogous to
4//! CSR/CSC for matrices but generalized to arbitrary dimensions. Each level
5//! in the hierarchy compresses one tensor mode.
6//!
7//! For an N-dimensional tensor, CSF uses N levels:
8//! - Level 0: the coarsest (outermost) mode
9//! - Level N-1: the finest (innermost) mode, where values are stored
10//!
11//! Each level `l` has:
12//! - `fptr[l]`: fiber pointers (like `indptr` in CSR)
13//! - `fids[l]`: fiber indices (like `indices` in CSR)
14//!
15//! The values are associated with the leaf-level fibers.
16//!
17//! This provides excellent compression for tensors with hierarchical sparsity
18//! patterns (e.g., most real-world tensors).
19
20use crate::csr_array::CsrArray;
21use crate::error::{SparseError, SparseResult};
22use crate::sparray::SparseArray;
23use crate::tensor_sparse::SparseTensor;
24use scirs2_core::numeric::{Float, SparseElement};
25use std::fmt::Debug;
26use std::ops::Div;
27
28/// Compressed Sparse Fiber (CSF) tensor format
29///
30/// Stores a sparse tensor using a hierarchical compressed structure.
31/// The modes are ordered from coarsest (level 0) to finest (last level).
32/// Values are stored only at the leaf level.
33#[derive(Debug, Clone)]
34pub struct CsfTensor<T> {
35    /// Fiber pointers for each level.
36    /// `fptr[l]` has length = (number of fibers at level l) + 1
37    fptr: Vec<Vec<usize>>,
38    /// Fiber indices for each level.
39    /// `fids[l]` contains the mode-l coordinate for each fiber.
40    fids: Vec<Vec<usize>>,
41    /// Values at the leaf level
42    values: Vec<T>,
43    /// Tensor shape
44    shape: Vec<usize>,
45    /// Mode ordering (which tensor mode corresponds to each CSF level)
46    mode_order: Vec<usize>,
47}
48
49impl<T> CsfTensor<T>
50where
51    T: Float + SparseElement + Debug + Copy + 'static,
52{
53    /// Create a CSF tensor from a COO-format SparseTensor.
54    ///
55    /// The `mode_order` specifies which tensor mode corresponds to each CSF level.
56    /// For example, for a 3D tensor with `mode_order = [0, 1, 2]`, level 0
57    /// compresses mode 0, level 1 compresses mode 1, and values are indexed by mode 2.
58    ///
59    /// # Arguments
60    /// * `tensor` - The sparse tensor in COO format
61    /// * `mode_order` - Ordering of modes (must be a permutation of 0..ndim)
62    pub fn from_sparse_tensor(tensor: &SparseTensor<T>, mode_order: &[usize]) -> SparseResult<Self>
63    where
64        T: std::iter::Sum,
65    {
66        let ndim = tensor.ndim();
67        if mode_order.len() != ndim {
68            return Err(SparseError::ValueError(format!(
69                "mode_order length {} must equal tensor ndim {}",
70                mode_order.len(),
71                ndim
72            )));
73        }
74
75        // Validate mode_order is a permutation
76        let mut sorted_modes = mode_order.to_vec();
77        sorted_modes.sort();
78        for (i, &m) in sorted_modes.iter().enumerate() {
79            if m != i {
80                return Err(SparseError::ValueError(
81                    "mode_order must be a permutation of 0..ndim".to_string(),
82                ));
83            }
84        }
85
86        let nnz = tensor.nnz();
87        if nnz == 0 {
88            // Empty tensor
89            let fptr = vec![vec![0, 0]; ndim];
90            let fids = vec![Vec::new(); ndim];
91            return Ok(Self {
92                fptr,
93                fids,
94                values: Vec::new(),
95                shape: tensor.shape.clone(),
96                mode_order: mode_order.to_vec(),
97            });
98        }
99
100        // Build sorted list of non-zero entries, sorted by the CSF level order
101        // Each entry: (coordinates_in_mode_order, value)
102        let mut entries: Vec<(Vec<usize>, T)> = (0..nnz)
103            .map(|i| {
104                let coords: Vec<usize> = mode_order.iter().map(|&m| tensor.indices[m][i]).collect();
105                (coords, tensor.values[i])
106            })
107            .collect();
108
109        // Sort entries lexicographically by coordinates
110        entries.sort_by(|a, b| a.0.cmp(&b.0));
111
112        // Build CSF structure level by level
113        let mut fptr: Vec<Vec<usize>> = Vec::with_capacity(ndim);
114        let mut fids: Vec<Vec<usize>> = Vec::with_capacity(ndim);
115        let mut values: Vec<T> = Vec::new();
116
117        // Level 0 (root level): unique values in the first coordinate
118        let mut level0_ids: Vec<usize> = Vec::new();
119        let mut level0_ptr: Vec<usize> = vec![0];
120
121        // We build all levels simultaneously by tracking coordinate groups
122        // at each level.
123        //
124        // Algorithm: Process entries in sorted order. At each level, track
125        // the current coordinate prefix. When it changes, we close the
126        // current fiber and start a new one.
127
128        // Initialize per-level structures
129        for _ in 0..ndim {
130            fptr.push(Vec::new());
131            fids.push(Vec::new());
132        }
133
134        // For each level, we track the current prefix
135        let mut prev_prefix: Vec<Option<usize>> = vec![None; ndim];
136        let mut level_counts: Vec<usize> = vec![0; ndim];
137
138        // Process each level from 0 to ndim-1
139        // We use a recursive-like approach: for each entry, check which levels
140        // need new fibers.
141        for (entry_idx, (coords, val)) in entries.iter().enumerate() {
142            // Determine the first level where the prefix changes
143            let mut change_level = ndim; // no change needed
144            for l in 0..ndim {
145                if prev_prefix[l] != Some(coords[l]) {
146                    change_level = l;
147                    break;
148                }
149            }
150
151            // Close fibers at levels below change_level (from deepest to change_level)
152            // and open new fibers from change_level downward
153
154            for l in change_level..ndim {
155                // At level l, we need to start a new coordinate
156                fids[l].push(coords[l]);
157
158                // If this is not the leaf level, record pointer for next level
159                if l < ndim - 1 {
160                    // The pointer for level l+1 is the current size of fids[l+1]
161                    if fptr[l].is_empty() || l == change_level {
162                        // Only push a new pointer when we actually start a new fiber at this level
163                        fptr[l].push(fids[l].len() - 1);
164                    }
165                }
166
167                prev_prefix[l] = Some(coords[l]);
168            }
169
170            // Store value at leaf level
171            values.push(*val);
172        }
173
174        // Now build proper fptr structures
175        // Reset and rebuild using the sorted entries
176        fptr = vec![Vec::new(); ndim];
177        fids = vec![Vec::new(); ndim];
178        values = Vec::new();
179
180        // Rebuild properly using group detection
181        self::build_csf_levels(&entries, &mut fptr, &mut fids, &mut values, ndim);
182
183        Ok(Self {
184            fptr,
185            fids,
186            values,
187            shape: tensor.shape.clone(),
188            mode_order: mode_order.to_vec(),
189        })
190    }
191
192    /// Get the tensor shape
193    pub fn shape(&self) -> &[usize] {
194        &self.shape
195    }
196
197    /// Get the number of dimensions
198    pub fn ndim(&self) -> usize {
199        self.shape.len()
200    }
201
202    /// Get the number of stored non-zero values
203    pub fn nnz(&self) -> usize {
204        self.values.len()
205    }
206
207    /// Get the mode ordering
208    pub fn mode_order(&self) -> &[usize] {
209        &self.mode_order
210    }
211
212    /// Get the fiber pointers at a given level
213    pub fn fiber_pointers(&self, level: usize) -> Option<&[usize]> {
214        self.fptr.get(level).map(|v| v.as_slice())
215    }
216
217    /// Get the fiber indices at a given level
218    pub fn fiber_indices(&self, level: usize) -> Option<&[usize]> {
219        self.fids.get(level).map(|v| v.as_slice())
220    }
221
222    /// Get the stored values
223    pub fn values(&self) -> &[T] {
224        &self.values
225    }
226
227    /// Look up a value by coordinate
228    pub fn get(&self, coords: &[usize]) -> T {
229        if coords.len() != self.ndim() {
230            return T::sparse_zero();
231        }
232
233        // Reorder coordinates according to mode_order
234        let ordered_coords: Vec<usize> = self.mode_order.iter().map(|&m| coords[m]).collect();
235
236        // Navigate the CSF tree
237        self.search_tree(&ordered_coords, 0, 0)
238    }
239
240    /// Recursive tree search for a coordinate
241    fn search_tree(&self, ordered_coords: &[usize], level: usize, fiber_idx: usize) -> T {
242        let ndim = self.ndim();
243
244        if level == ndim - 1 {
245            // Leaf level: search fids for the coordinate
246            let start = if level == 0 {
247                0
248            } else {
249                self.fptr[level - 1].get(fiber_idx).copied().unwrap_or(0)
250            };
251            let end = if level == 0 {
252                self.fids[level].len()
253            } else {
254                self.fptr[level - 1]
255                    .get(fiber_idx + 1)
256                    .copied()
257                    .unwrap_or(self.fids[level].len())
258            };
259
260            for i in start..end {
261                if i < self.fids[level].len() && self.fids[level][i] == ordered_coords[level] {
262                    let val_idx = i; // leaf fids index maps directly to values
263                    if val_idx < self.values.len() {
264                        return self.values[val_idx];
265                    }
266                }
267            }
268            return T::sparse_zero();
269        }
270
271        // Non-leaf level: find the matching fiber and recurse
272        let start = if level == 0 {
273            0
274        } else {
275            self.fptr[level - 1].get(fiber_idx).copied().unwrap_or(0)
276        };
277        let end = if level == 0 {
278            self.fids[level].len()
279        } else {
280            self.fptr[level - 1]
281                .get(fiber_idx + 1)
282                .copied()
283                .unwrap_or(self.fids[level].len())
284        };
285
286        for i in start..end {
287            if i < self.fids[level].len() && self.fids[level][i] == ordered_coords[level] {
288                return self.search_tree(ordered_coords, level + 1, i);
289            }
290        }
291
292        T::sparse_zero()
293    }
294
295    /// Convert back to a COO-format SparseTensor
296    pub fn to_sparse_tensor(&self) -> SparseResult<SparseTensor<T>>
297    where
298        T: std::iter::Sum,
299    {
300        let ndim = self.ndim();
301        let mut indices: Vec<Vec<usize>> = vec![Vec::new(); ndim];
302        let mut values: Vec<T> = Vec::new();
303
304        // Build inverse mode order
305        let mut inv_mode_order = vec![0usize; ndim];
306        for (csf_level, &tensor_mode) in self.mode_order.iter().enumerate() {
307            inv_mode_order[tensor_mode] = csf_level;
308        }
309
310        // Traverse the CSF tree to extract all non-zero entries
311        let mut coord_stack: Vec<usize> = vec![0; ndim];
312        self.traverse_tree(
313            0,
314            0,
315            &mut coord_stack,
316            &mut indices,
317            &mut values,
318            &inv_mode_order,
319        );
320
321        SparseTensor::new(indices, values, self.shape.clone())
322    }
323
324    /// Recursive traversal to extract entries
325    fn traverse_tree(
326        &self,
327        level: usize,
328        fiber_idx: usize,
329        coord_stack: &mut Vec<usize>,
330        indices: &mut Vec<Vec<usize>>,
331        values: &mut Vec<T>,
332        inv_mode_order: &[usize],
333    ) {
334        let ndim = self.ndim();
335
336        let start = if level == 0 {
337            0
338        } else {
339            self.fptr[level - 1].get(fiber_idx).copied().unwrap_or(0)
340        };
341        let end = if level == 0 {
342            self.fids[level].len()
343        } else {
344            self.fptr[level - 1]
345                .get(fiber_idx + 1)
346                .copied()
347                .unwrap_or(self.fids[level].len())
348        };
349
350        for i in start..end {
351            if i >= self.fids[level].len() {
352                break;
353            }
354            coord_stack[level] = self.fids[level][i];
355
356            if level == ndim - 1 {
357                // Leaf: emit the entry
358                if i < self.values.len() {
359                    for mode in 0..ndim {
360                        let csf_level = inv_mode_order[mode];
361                        indices[mode].push(coord_stack[csf_level]);
362                    }
363                    values.push(self.values[i]);
364                }
365            } else {
366                self.traverse_tree(level + 1, i, coord_stack, indices, values, inv_mode_order);
367            }
368        }
369    }
370
371    /// Tensor contraction along a specified mode with a vector.
372    ///
373    /// Computes the mode-n product with a vector: result = T x_n v
374    /// The result is a tensor of one lower dimension.
375    ///
376    /// # Arguments
377    /// * `mode` - The tensor mode to contract
378    /// * `vector` - The vector to contract with (length = `shape[mode]`)
379    pub fn contract_vector(&self, mode: usize, vector: &[T]) -> SparseResult<SparseTensor<T>>
380    where
381        T: std::iter::Sum,
382    {
383        let ndim = self.ndim();
384        if mode >= ndim {
385            return Err(SparseError::ValueError(format!(
386                "Mode {} exceeds tensor dimensions {}",
387                mode, ndim
388            )));
389        }
390        if vector.len() != self.shape[mode] {
391            return Err(SparseError::DimensionMismatch {
392                expected: self.shape[mode],
393                found: vector.len(),
394            });
395        }
396        if ndim < 2 {
397            return Err(SparseError::ValueError(
398                "Cannot contract a 1D tensor to 0D".to_string(),
399            ));
400        }
401
402        // Convert to COO, perform contraction, return result
403        let coo = self.to_sparse_tensor()?;
404
405        // Build result tensor: shape without the contracted mode
406        let new_shape: Vec<usize> = (0..ndim)
407            .filter(|&m| m != mode)
408            .map(|m| self.shape[m])
409            .collect();
410        let new_ndim = new_shape.len();
411
412        // Accumulate contracted values
413        let mut result_map: std::collections::HashMap<Vec<usize>, T> =
414            std::collections::HashMap::new();
415
416        for i in 0..coo.nnz() {
417            let mode_idx = coo.indices[mode][i];
418            let scale = vector[mode_idx];
419
420            if SparseElement::is_zero(&scale) {
421                continue;
422            }
423
424            let key: Vec<usize> = (0..ndim)
425                .filter(|&m| m != mode)
426                .map(|m| coo.indices[m][i])
427                .collect();
428
429            let entry = result_map.entry(key).or_insert(T::sparse_zero());
430            *entry = *entry + coo.values[i] * scale;
431        }
432
433        // Build result tensor
434        let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); new_ndim];
435        let mut new_values: Vec<T> = Vec::new();
436
437        for (key, val) in &result_map {
438            if !SparseElement::is_zero(val) {
439                for (d, &k) in key.iter().enumerate() {
440                    new_indices[d].push(k);
441                }
442                new_values.push(*val);
443            }
444        }
445
446        if new_values.is_empty() {
447            // Return empty tensor
448            new_indices = vec![Vec::new(); new_ndim];
449        }
450
451        SparseTensor::new(new_indices, new_values, new_shape)
452    }
453
454    /// Mode-n product with a matrix: result = T x_n M
455    ///
456    /// The result has the same number of dimensions as the input tensor,
457    /// but the size of mode `n` changes from `shape[n]` to `M.nrows`.
458    ///
459    /// # Arguments
460    /// * `mode` - The tensor mode to multiply along
461    /// * `matrix` - The matrix (nrows x `shape[mode]`)
462    pub fn mode_n_product(&self, mode: usize, matrix: &CsrArray<T>) -> SparseResult<SparseTensor<T>>
463    where
464        T: Float + SparseElement + Div<Output = T> + std::iter::Sum + 'static,
465    {
466        let ndim = self.ndim();
467        if mode >= ndim {
468            return Err(SparseError::ValueError(format!(
469                "Mode {} exceeds tensor dimensions {}",
470                mode, ndim
471            )));
472        }
473        let (mat_rows, mat_cols) = matrix.shape();
474        if mat_cols != self.shape[mode] {
475            return Err(SparseError::DimensionMismatch {
476                expected: self.shape[mode],
477                found: mat_cols,
478            });
479        }
480
481        let coo = self.to_sparse_tensor()?;
482
483        let mut new_shape = self.shape.clone();
484        new_shape[mode] = mat_rows;
485
486        // For each non-zero in the tensor, distribute it across matrix rows
487        let mut result_map: std::collections::HashMap<Vec<usize>, T> =
488            std::collections::HashMap::new();
489
490        for i in 0..coo.nnz() {
491            let mode_idx = coo.indices[mode][i];
492            let tensor_val = coo.values[i];
493
494            // Multiply by each non-zero in column mode_idx of the matrix
495            for new_mode_idx in 0..mat_rows {
496                let m_val = matrix.get(new_mode_idx, mode_idx);
497                if SparseElement::is_zero(&m_val) {
498                    continue;
499                }
500
501                let mut key: Vec<usize> = (0..ndim).map(|m| coo.indices[m][i]).collect();
502                key[mode] = new_mode_idx;
503
504                let entry = result_map.entry(key).or_insert(T::sparse_zero());
505                *entry = *entry + tensor_val * m_val;
506            }
507        }
508
509        let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); ndim];
510        let mut new_values: Vec<T> = Vec::new();
511
512        for (key, val) in &result_map {
513            if !SparseElement::is_zero(val) {
514                for (d, &k) in key.iter().enumerate() {
515                    new_indices[d].push(k);
516                }
517                new_values.push(*val);
518            }
519        }
520
521        if new_values.is_empty() {
522            // Return an empty tensor with valid indices
523            return Ok(SparseTensor {
524                indices: (0..ndim).map(|_| Vec::new()).collect(),
525                values: Vec::new(),
526                shape: new_shape,
527            });
528        }
529
530        SparseTensor::new(new_indices, new_values, new_shape)
531    }
532
533    /// Get memory usage estimate (bytes)
534    pub fn memory_usage(&self) -> usize {
535        let mut total = 0usize;
536        for fp in &self.fptr {
537            total += fp.len() * std::mem::size_of::<usize>();
538        }
539        for fi in &self.fids {
540            total += fi.len() * std::mem::size_of::<usize>();
541        }
542        total += self.values.len() * std::mem::size_of::<T>();
543        total += self.shape.len() * std::mem::size_of::<usize>();
544        total += self.mode_order.len() * std::mem::size_of::<usize>();
545        total
546    }
547}
548
/// Build CSF level structures from sorted entries
///
/// The CSF structure uses `fptr[level]` to index into `fids[level+1]`.
/// For each fiber at level `l` (identified by `fids[l][i]`), its children
/// in `fids[l+1]` span the range `fptr[l][i]..fptr[l][i+1]`.
///
/// Every non-leaf `fptr[l]` ends up with exactly `fids[l].len() + 1` entries;
/// the leaf level gets no `fptr` entries.
/// Values are stored at the leaf level, indexed in parallel with `fids[ndim-1]`.
///
/// `entries` must be sorted lexicographically by coordinates, and `fptr` /
/// `fids` must each already contain at least `ndim` vectors.
fn build_csf_levels<T: Copy>(
    entries: &[(Vec<usize>, T)],
    fptr: &mut Vec<Vec<usize>>,
    fids: &mut Vec<Vec<usize>>,
    values: &mut Vec<T>,
    ndim: usize,
) {
    if entries.is_empty() || ndim == 0 {
        return;
    }

    // Start from clean per-level buffers (capacity is retained for reuse).
    for l in 0..ndim {
        fptr[l].clear();
        fids[l].clear();
    }

    // Recursive grouping; no sentinel is pushed inside the recursion.
    build_level(entries, fptr, fids, values, 0, ndim);

    // Append the final sentinel so each non-leaf fptr[l] has
    // fids[l].len() + 1 entries (the end of the last fiber's children).
    for l in 0..(ndim - 1) {
        fptr[l].push(fids[l + 1].len());
    }
}

/// Recursively build one level of the CSF structure.
///
/// For each contiguous run of entries sharing the level-`level` coordinate,
/// push the coordinate to `fids[level]`, push the start pointer to
/// `fptr[level]`, then recurse into the run.
/// Does NOT push a sentinel; the caller handles that.
fn build_level<T: Copy>(
    entries: &[(Vec<usize>, T)],
    fptr: &mut Vec<Vec<usize>>,
    fids: &mut Vec<Vec<usize>>,
    values: &mut Vec<T>,
    level: usize,
    ndim: usize,
) {
    if entries.is_empty() {
        return;
    }

    if level == ndim - 1 {
        // Leaf level: store coordinates and values in parallel.
        for (coords, val) in entries {
            fids[level].push(coords[level]);
            values.push(*val);
        }
        return;
    }

    // Entries are sorted, so equal coordinates at this level form contiguous
    // runs; recurse on subslices instead of cloning entries into group
    // vectors (the previous version copied every coordinate vector at every
    // level — O(nnz * ndim) allocations per level).
    let mut start = 0;
    while start < entries.len() {
        let coord = entries[start].0[level];
        let mut end = start + 1;
        while end < entries.len() && entries[end].0[level] == coord {
            end += 1;
        }

        fids[level].push(coord);
        // Pointer: where the children of this fiber start in the next level's fids.
        fptr[level].push(fids[level + 1].len());
        build_level(&entries[start..end], fptr, fids, values, level + 1, ndim);

        start = end;
    }
    // No sentinel here -- the caller (build_csf_levels) adds it after all recursion is done.
}
638
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Shared fixture: a 2x3x4 tensor holding 5 non-zero entries.
    fn create_test_tensor_3d() -> SparseTensor<f64> {
        // 2x3x4 tensor with 5 non-zeros
        let indices = vec![
            vec![0, 0, 0, 1, 1], // mode 0
            vec![0, 1, 2, 0, 2], // mode 1
            vec![0, 1, 3, 2, 0], // mode 2
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![2, 3, 4];
        SparseTensor::new(indices, values, shape).expect("failed to create tensor")
    }

    // Construction records ndim, nnz, shape, and mode order.
    #[test]
    fn test_csf_from_sparse_tensor() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");

        assert_eq!(csf.ndim(), 3);
        assert_eq!(csf.nnz(), 5);
        assert_eq!(csf.shape(), &[2, 3, 4]);
        assert_eq!(csf.mode_order(), &[0, 1, 2]);
    }

    // COO -> CSF -> COO must preserve every stored value.
    #[test]
    fn test_csf_roundtrip() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");

        let recovered = csf.to_sparse_tensor().expect("to_sparse_tensor failed");

        assert_eq!(recovered.nnz(), tensor.nnz());

        // Check all values match
        for i in 0..tensor.nnz() {
            let coords: Vec<usize> = (0..3).map(|d| tensor.indices[d][i]).collect();
            let orig_val = tensor.get(&coords);
            let rec_val = recovered.get(&coords);
            assert_relative_eq!(orig_val, rec_val, epsilon = 1e-12);
        }
    }

    // Point lookups hit every stored entry and return zero elsewhere.
    #[test]
    fn test_csf_get() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");

        assert_relative_eq!(csf.get(&[0, 0, 0]), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 1, 1]), 2.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 2, 3]), 3.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 0, 2]), 4.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 2, 0]), 5.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 0, 1]), 0.0, epsilon = 1e-12); // zero
    }

    // Lookups and roundtrips are independent of the chosen mode order.
    #[test]
    fn test_csf_different_mode_order() {
        let tensor = create_test_tensor_3d();

        // Create with reversed mode order
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[2, 1, 0]).expect("CSF creation failed");
        assert_eq!(csf.nnz(), 5);
        assert_eq!(csf.mode_order(), &[2, 1, 0]);

        // Should still look up correctly by original coordinates
        assert_relative_eq!(csf.get(&[0, 0, 0]), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 2, 0]), 5.0, epsilon = 1e-12);

        // Roundtrip should work
        let recovered = csf.to_sparse_tensor().expect("roundtrip failed");
        for i in 0..tensor.nnz() {
            let coords: Vec<usize> = (0..3).map(|d| tensor.indices[d][i]).collect();
            assert_relative_eq!(tensor.get(&coords), recovered.get(&coords), epsilon = 1e-12);
        }
    }

    // Contracting a 2x3 tensor along mode 1 yields a shape-[2] tensor.
    #[test]
    fn test_csf_contract_vector() {
        // Simple 2x3 tensor contracted along mode 1
        let indices = vec![
            vec![0, 0, 1], // mode 0
            vec![0, 1, 0], // mode 1
        ];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 3];
        let tensor = SparseTensor::new(indices, values, shape).expect("create tensor");

        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1]).expect("CSF creation");

        let vector = vec![1.0, 2.0, 0.0]; // contract mode 1
        let result = csf.contract_vector(1, &vector).expect("contract_vector");

        // Result should be 1D tensor of shape [2]
        assert_eq!(result.shape, vec![2]);

        // result[0] = 1*1 + 2*2 = 5
        // result[1] = 3*1 = 3
        assert_relative_eq!(result.get(&[0]), 5.0, epsilon = 1e-12);
        assert_relative_eq!(result.get(&[1]), 3.0, epsilon = 1e-12);
    }

    // Mode-1 product with a 2x3 CSR matrix resizes mode 1 from 3 to 2.
    #[test]
    fn test_csf_mode_n_product() {
        // 2x3 tensor, mode-1 product with a 2x3 matrix
        let indices = vec![
            vec![0, 0, 1], // mode 0
            vec![0, 2, 1], // mode 1
        ];
        let values = vec![1.0, 3.0, 2.0];
        let shape = vec![2, 3];
        let tensor = SparseTensor::new(indices, values, shape).expect("create tensor");
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1]).expect("CSF");

        // 2x3 matrix
        let m_rows = vec![0, 0, 1, 1];
        let m_cols = vec![0, 2, 0, 1];
        let m_vals = vec![1.0, 1.0, 0.5, 2.0];
        let matrix =
            CsrArray::from_triplets(&m_rows, &m_cols, &m_vals, (2, 3), false).expect("matrix");

        let result = csf.mode_n_product(1, &matrix).expect("mode_n_product");

        // Result shape: [2, 2] (mode 1 changes from 3 to 2)
        assert_eq!(result.shape, vec![2, 2]);

        // result[0,0] = tensor[0,0]*M[0,0] + tensor[0,2]*M[0,2] = 1*1 + 3*1 = 4
        // result[0,1] = tensor[0,0]*M[1,0] + tensor[0,1]*M[1,1] = 1*0.5 + 0*2 = 0.5
        // result[1,0] = tensor[1,1]*M[0,1] = 0 (M[0,1] = 0)
        // result[1,1] = tensor[1,1]*M[1,1] = 2*2 = 4
        assert_relative_eq!(result.get(&[0, 0]), 4.0, epsilon = 1e-12);
        assert_relative_eq!(result.get(&[0, 1]), 0.5, epsilon = 1e-12);
        assert_relative_eq!(result.get(&[1, 1]), 4.0, epsilon = 1e-12);
    }

    // The byte estimate is positive for a non-empty tensor.
    #[test]
    fn test_csf_memory_usage() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");
        let mem = csf.memory_usage();
        assert!(mem > 0);
    }

    // An all-zero tensor converts cleanly with nnz == 0.
    #[test]
    fn test_csf_empty_tensor() {
        let indices = vec![Vec::<usize>::new(), Vec::<usize>::new()];
        let values: Vec<f64> = Vec::new();
        let shape = vec![3, 4];
        let tensor = SparseTensor::new(indices, values, shape).expect("empty tensor");
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1]).expect("CSF");
        assert_eq!(csf.nnz(), 0);
    }

    // Roundtrip a larger 3x4x5 tensor under several mode permutations.
    #[test]
    fn test_csf_3d_roundtrip_with_permutation() {
        // Larger 3x4x5 tensor
        let indices = vec![
            vec![0, 0, 1, 2, 2, 2], // mode 0
            vec![0, 3, 1, 0, 2, 3], // mode 1
            vec![0, 4, 2, 1, 3, 4], // mode 2
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![3, 4, 5];
        let tensor = SparseTensor::new(indices, values, shape).expect("tensor");

        for perm in &[[0, 1, 2], [1, 0, 2], [2, 0, 1], [0, 2, 1]] {
            let csf = CsfTensor::from_sparse_tensor(&tensor, perm).expect("CSF");
            assert_eq!(csf.nnz(), 6);

            let recovered = csf.to_sparse_tensor().expect("roundtrip");
            assert_eq!(recovered.nnz(), 6);

            for i in 0..tensor.nnz() {
                let coords: Vec<usize> = (0..3).map(|d| tensor.indices[d][i]).collect();
                assert_relative_eq!(tensor.get(&coords), recovered.get(&coords), epsilon = 1e-12,);
            }
        }
    }
}