1use crate::csr_array::CsrArray;
21use crate::error::{SparseError, SparseResult};
22use crate::sparray::SparseArray;
23use crate::tensor_sparse::SparseTensor;
24use scirs2_core::numeric::{Float, SparseElement};
25use std::fmt::Debug;
26use std::ops::Div;
27
/// Compressed Sparse Fiber (CSF) storage for a sparse tensor.
///
/// The tensor is stored as a tree with one level per mode, levels ordered by
/// `mode_order`: each level keeps the coordinates of its fibers (`fids`) and
/// pointers into the next level (`fptr`).
#[derive(Debug, Clone)]
pub struct CsfTensor<T> {
    /// Per-level pointers: `fptr[l][j]..fptr[l][j+1]` brackets the children
    /// (stored in `fids[l + 1]`) of the j-th node at level `l`.
    /// Unused at the last level.
    fptr: Vec<Vec<usize>>,
    /// Per-level fiber coordinates; `fids[ndim - 1]` is aligned one-to-one
    /// with `values`.
    fids: Vec<Vec<usize>>,
    /// Stored nonzero values, in leaf (depth-first) order.
    values: Vec<T>,
    /// Extent of each mode, in the tensor's original (un-permuted) order.
    shape: Vec<usize>,
    /// Permutation mapping CSF level -> original tensor mode.
    mode_order: Vec<usize>,
}
48
49impl<T> CsfTensor<T>
50where
51 T: Float + SparseElement + Debug + Copy + 'static,
52{
53 pub fn from_sparse_tensor(tensor: &SparseTensor<T>, mode_order: &[usize]) -> SparseResult<Self>
63 where
64 T: std::iter::Sum,
65 {
66 let ndim = tensor.ndim();
67 if mode_order.len() != ndim {
68 return Err(SparseError::ValueError(format!(
69 "mode_order length {} must equal tensor ndim {}",
70 mode_order.len(),
71 ndim
72 )));
73 }
74
75 let mut sorted_modes = mode_order.to_vec();
77 sorted_modes.sort();
78 for (i, &m) in sorted_modes.iter().enumerate() {
79 if m != i {
80 return Err(SparseError::ValueError(
81 "mode_order must be a permutation of 0..ndim".to_string(),
82 ));
83 }
84 }
85
86 let nnz = tensor.nnz();
87 if nnz == 0 {
88 let fptr = vec![vec![0, 0]; ndim];
90 let fids = vec![Vec::new(); ndim];
91 return Ok(Self {
92 fptr,
93 fids,
94 values: Vec::new(),
95 shape: tensor.shape.clone(),
96 mode_order: mode_order.to_vec(),
97 });
98 }
99
100 let mut entries: Vec<(Vec<usize>, T)> = (0..nnz)
103 .map(|i| {
104 let coords: Vec<usize> = mode_order.iter().map(|&m| tensor.indices[m][i]).collect();
105 (coords, tensor.values[i])
106 })
107 .collect();
108
109 entries.sort_by(|a, b| a.0.cmp(&b.0));
111
112 let mut fptr: Vec<Vec<usize>> = Vec::with_capacity(ndim);
114 let mut fids: Vec<Vec<usize>> = Vec::with_capacity(ndim);
115 let mut values: Vec<T> = Vec::new();
116
117 let mut level0_ids: Vec<usize> = Vec::new();
119 let mut level0_ptr: Vec<usize> = vec![0];
120
121 for _ in 0..ndim {
130 fptr.push(Vec::new());
131 fids.push(Vec::new());
132 }
133
134 let mut prev_prefix: Vec<Option<usize>> = vec![None; ndim];
136 let mut level_counts: Vec<usize> = vec![0; ndim];
137
138 for (entry_idx, (coords, val)) in entries.iter().enumerate() {
142 let mut change_level = ndim; for l in 0..ndim {
145 if prev_prefix[l] != Some(coords[l]) {
146 change_level = l;
147 break;
148 }
149 }
150
151 for l in change_level..ndim {
155 fids[l].push(coords[l]);
157
158 if l < ndim - 1 {
160 if fptr[l].is_empty() || l == change_level {
162 fptr[l].push(fids[l].len() - 1);
164 }
165 }
166
167 prev_prefix[l] = Some(coords[l]);
168 }
169
170 values.push(*val);
172 }
173
174 fptr = vec![Vec::new(); ndim];
177 fids = vec![Vec::new(); ndim];
178 values = Vec::new();
179
180 self::build_csf_levels(&entries, &mut fptr, &mut fids, &mut values, ndim);
182
183 Ok(Self {
184 fptr,
185 fids,
186 values,
187 shape: tensor.shape.clone(),
188 mode_order: mode_order.to_vec(),
189 })
190 }
191
192 pub fn shape(&self) -> &[usize] {
194 &self.shape
195 }
196
    /// Number of tensor modes (dimensions), derived from the shape length.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
201
    /// Number of explicitly stored entries (one per leaf value).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }
206
207 pub fn mode_order(&self) -> &[usize] {
209 &self.mode_order
210 }
211
212 pub fn fiber_pointers(&self, level: usize) -> Option<&[usize]> {
214 self.fptr.get(level).map(|v| v.as_slice())
215 }
216
217 pub fn fiber_indices(&self, level: usize) -> Option<&[usize]> {
219 self.fids.get(level).map(|v| v.as_slice())
220 }
221
222 pub fn values(&self) -> &[T] {
224 &self.values
225 }
226
227 pub fn get(&self, coords: &[usize]) -> T {
229 if coords.len() != self.ndim() {
230 return T::sparse_zero();
231 }
232
233 let ordered_coords: Vec<usize> = self.mode_order.iter().map(|&m| coords[m]).collect();
235
236 self.search_tree(&ordered_coords, 0, 0)
238 }
239
240 fn search_tree(&self, ordered_coords: &[usize], level: usize, fiber_idx: usize) -> T {
242 let ndim = self.ndim();
243
244 if level == ndim - 1 {
245 let start = if level == 0 {
247 0
248 } else {
249 self.fptr[level - 1].get(fiber_idx).copied().unwrap_or(0)
250 };
251 let end = if level == 0 {
252 self.fids[level].len()
253 } else {
254 self.fptr[level - 1]
255 .get(fiber_idx + 1)
256 .copied()
257 .unwrap_or(self.fids[level].len())
258 };
259
260 for i in start..end {
261 if i < self.fids[level].len() && self.fids[level][i] == ordered_coords[level] {
262 let val_idx = i; if val_idx < self.values.len() {
264 return self.values[val_idx];
265 }
266 }
267 }
268 return T::sparse_zero();
269 }
270
271 let start = if level == 0 {
273 0
274 } else {
275 self.fptr[level - 1].get(fiber_idx).copied().unwrap_or(0)
276 };
277 let end = if level == 0 {
278 self.fids[level].len()
279 } else {
280 self.fptr[level - 1]
281 .get(fiber_idx + 1)
282 .copied()
283 .unwrap_or(self.fids[level].len())
284 };
285
286 for i in start..end {
287 if i < self.fids[level].len() && self.fids[level][i] == ordered_coords[level] {
288 return self.search_tree(ordered_coords, level + 1, i);
289 }
290 }
291
292 T::sparse_zero()
293 }
294
295 pub fn to_sparse_tensor(&self) -> SparseResult<SparseTensor<T>>
297 where
298 T: std::iter::Sum,
299 {
300 let ndim = self.ndim();
301 let mut indices: Vec<Vec<usize>> = vec![Vec::new(); ndim];
302 let mut values: Vec<T> = Vec::new();
303
304 let mut inv_mode_order = vec![0usize; ndim];
306 for (csf_level, &tensor_mode) in self.mode_order.iter().enumerate() {
307 inv_mode_order[tensor_mode] = csf_level;
308 }
309
310 let mut coord_stack: Vec<usize> = vec![0; ndim];
312 self.traverse_tree(
313 0,
314 0,
315 &mut coord_stack,
316 &mut indices,
317 &mut values,
318 &inv_mode_order,
319 );
320
321 SparseTensor::new(indices, values, self.shape.clone())
322 }
323
    /// Depth-first walk that appends every stored entry to `indices`/`values`.
    ///
    /// `coord_stack[l]` holds the coordinate chosen at level `l` along the
    /// current root-to-leaf path; `inv_mode_order[mode]` gives the CSF level
    /// storing tensor `mode`, so output indices come out in original mode
    /// order.
    fn traverse_tree(
        &self,
        level: usize,
        fiber_idx: usize,
        coord_stack: &mut Vec<usize>,
        indices: &mut Vec<Vec<usize>>,
        values: &mut Vec<T>,
        inv_mode_order: &[usize],
    ) {
        let ndim = self.ndim();

        // Child range of `fiber_idx` within fids[level]: the root level spans
        // the whole array; deeper levels consult the parent's pointer pair,
        // with defensive fallbacks if a pointer entry is missing.
        let start = if level == 0 {
            0
        } else {
            self.fptr[level - 1].get(fiber_idx).copied().unwrap_or(0)
        };
        let end = if level == 0 {
            self.fids[level].len()
        } else {
            self.fptr[level - 1]
                .get(fiber_idx + 1)
                .copied()
                .unwrap_or(self.fids[level].len())
        };

        for i in start..end {
            if i >= self.fids[level].len() {
                break;
            }
            coord_stack[level] = self.fids[level][i];

            if level == ndim - 1 {
                // Leaf level: fids[ndim - 1] is aligned with `values`, so
                // index i selects the stored value directly.
                if i < self.values.len() {
                    for mode in 0..ndim {
                        let csf_level = inv_mode_order[mode];
                        indices[mode].push(coord_stack[csf_level]);
                    }
                    values.push(self.values[i]);
                }
            } else {
                self.traverse_tree(level + 1, i, coord_stack, indices, values, inv_mode_order);
            }
        }
    }
370
371 pub fn contract_vector(&self, mode: usize, vector: &[T]) -> SparseResult<SparseTensor<T>>
380 where
381 T: std::iter::Sum,
382 {
383 let ndim = self.ndim();
384 if mode >= ndim {
385 return Err(SparseError::ValueError(format!(
386 "Mode {} exceeds tensor dimensions {}",
387 mode, ndim
388 )));
389 }
390 if vector.len() != self.shape[mode] {
391 return Err(SparseError::DimensionMismatch {
392 expected: self.shape[mode],
393 found: vector.len(),
394 });
395 }
396 if ndim < 2 {
397 return Err(SparseError::ValueError(
398 "Cannot contract a 1D tensor to 0D".to_string(),
399 ));
400 }
401
402 let coo = self.to_sparse_tensor()?;
404
405 let new_shape: Vec<usize> = (0..ndim)
407 .filter(|&m| m != mode)
408 .map(|m| self.shape[m])
409 .collect();
410 let new_ndim = new_shape.len();
411
412 let mut result_map: std::collections::HashMap<Vec<usize>, T> =
414 std::collections::HashMap::new();
415
416 for i in 0..coo.nnz() {
417 let mode_idx = coo.indices[mode][i];
418 let scale = vector[mode_idx];
419
420 if SparseElement::is_zero(&scale) {
421 continue;
422 }
423
424 let key: Vec<usize> = (0..ndim)
425 .filter(|&m| m != mode)
426 .map(|m| coo.indices[m][i])
427 .collect();
428
429 let entry = result_map.entry(key).or_insert(T::sparse_zero());
430 *entry = *entry + coo.values[i] * scale;
431 }
432
433 let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); new_ndim];
435 let mut new_values: Vec<T> = Vec::new();
436
437 for (key, val) in &result_map {
438 if !SparseElement::is_zero(val) {
439 for (d, &k) in key.iter().enumerate() {
440 new_indices[d].push(k);
441 }
442 new_values.push(*val);
443 }
444 }
445
446 if new_values.is_empty() {
447 new_indices = vec![Vec::new(); new_ndim];
449 }
450
451 SparseTensor::new(new_indices, new_values, new_shape)
452 }
453
454 pub fn mode_n_product(&self, mode: usize, matrix: &CsrArray<T>) -> SparseResult<SparseTensor<T>>
463 where
464 T: Float + SparseElement + Div<Output = T> + std::iter::Sum + 'static,
465 {
466 let ndim = self.ndim();
467 if mode >= ndim {
468 return Err(SparseError::ValueError(format!(
469 "Mode {} exceeds tensor dimensions {}",
470 mode, ndim
471 )));
472 }
473 let (mat_rows, mat_cols) = matrix.shape();
474 if mat_cols != self.shape[mode] {
475 return Err(SparseError::DimensionMismatch {
476 expected: self.shape[mode],
477 found: mat_cols,
478 });
479 }
480
481 let coo = self.to_sparse_tensor()?;
482
483 let mut new_shape = self.shape.clone();
484 new_shape[mode] = mat_rows;
485
486 let mut result_map: std::collections::HashMap<Vec<usize>, T> =
488 std::collections::HashMap::new();
489
490 for i in 0..coo.nnz() {
491 let mode_idx = coo.indices[mode][i];
492 let tensor_val = coo.values[i];
493
494 for new_mode_idx in 0..mat_rows {
496 let m_val = matrix.get(new_mode_idx, mode_idx);
497 if SparseElement::is_zero(&m_val) {
498 continue;
499 }
500
501 let mut key: Vec<usize> = (0..ndim).map(|m| coo.indices[m][i]).collect();
502 key[mode] = new_mode_idx;
503
504 let entry = result_map.entry(key).or_insert(T::sparse_zero());
505 *entry = *entry + tensor_val * m_val;
506 }
507 }
508
509 let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); ndim];
510 let mut new_values: Vec<T> = Vec::new();
511
512 for (key, val) in &result_map {
513 if !SparseElement::is_zero(val) {
514 for (d, &k) in key.iter().enumerate() {
515 new_indices[d].push(k);
516 }
517 new_values.push(*val);
518 }
519 }
520
521 if new_values.is_empty() {
522 return Ok(SparseTensor {
524 indices: (0..ndim).map(|_| Vec::new()).collect(),
525 values: Vec::new(),
526 shape: new_shape,
527 });
528 }
529
530 SparseTensor::new(new_indices, new_values, new_shape)
531 }
532
533 pub fn memory_usage(&self) -> usize {
535 let mut total = 0usize;
536 for fp in &self.fptr {
537 total += fp.len() * std::mem::size_of::<usize>();
538 }
539 for fi in &self.fids {
540 total += fi.len() * std::mem::size_of::<usize>();
541 }
542 total += self.values.len() * std::mem::size_of::<T>();
543 total += self.shape.len() * std::mem::size_of::<usize>();
544 total += self.mode_order.len() * std::mem::size_of::<usize>();
545 total
546 }
547}
548
549fn build_csf_levels<T: Copy>(
558 entries: &[(Vec<usize>, T)],
559 fptr: &mut Vec<Vec<usize>>,
560 fids: &mut Vec<Vec<usize>>,
561 values: &mut Vec<T>,
562 ndim: usize,
563) {
564 if entries.is_empty() || ndim == 0 {
565 return;
566 }
567
568 for l in 0..ndim {
570 fptr[l] = Vec::new();
571 fids[l] = Vec::new();
572 }
573
574 build_level(entries, fptr, fids, values, 0, ndim);
576
577 for l in 0..(ndim - 1) {
582 fptr[l].push(fids[l + 1].len());
583 }
584}
585
/// Recursively build one CSF level from lexicographically sorted entries.
///
/// Because `entries` is sorted, equal coordinates at `level` form contiguous
/// runs; each run becomes one fiber node and is recursed on as a subslice.
/// This avoids the previous implementation's per-level cloning of every
/// coordinate vector into temporary group buffers (O(nnz * ndim) extra
/// allocation per level), while producing identical `fptr`/`fids`/`values`.
fn build_level<T: Copy>(
    entries: &[(Vec<usize>, T)],
    fptr: &mut Vec<Vec<usize>>,
    fids: &mut Vec<Vec<usize>>,
    values: &mut Vec<T>,
    level: usize,
    ndim: usize,
) {
    if entries.is_empty() {
        return;
    }

    if level == ndim - 1 {
        // Leaf level: indices align one-to-one with values.
        for (coords, val) in entries {
            fids[level].push(coords[level]);
            values.push(*val);
        }
        return;
    }

    // Walk contiguous runs of equal coordinates at this level.
    let mut start = 0;
    while start < entries.len() {
        let coord = entries[start].0[level];
        let mut end = start + 1;
        while end < entries.len() && entries[end].0[level] == coord {
            end += 1;
        }

        // One node per run: record its coordinate and where its children
        // begin in the next level's index array.
        fids[level].push(coord);
        fptr[level].push(fids[level + 1].len());
        build_level(&entries[start..end], fptr, fids, values, level + 1, ndim);

        start = end;
    }
}
638
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a 2x3x4 tensor with five nonzeros; shared fixture for the
    /// tests below.
    fn create_test_tensor_3d() -> SparseTensor<f64> {
        let indices = vec![
            vec![0, 0, 0, 1, 1],
            vec![0, 1, 2, 0, 2],
            vec![0, 1, 3, 2, 0],
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![2, 3, 4];
        SparseTensor::new(indices, values, shape).expect("failed to create tensor")
    }

    // Construction with the identity mode order preserves basic metadata.
    #[test]
    fn test_csf_from_sparse_tensor() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");

        assert_eq!(csf.ndim(), 3);
        assert_eq!(csf.nnz(), 5);
        assert_eq!(csf.shape(), &[2, 3, 4]);
        assert_eq!(csf.mode_order(), &[0, 1, 2]);
    }

    // COO -> CSF -> COO must preserve every stored entry.
    #[test]
    fn test_csf_roundtrip() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");

        let recovered = csf.to_sparse_tensor().expect("to_sparse_tensor failed");

        assert_eq!(recovered.nnz(), tensor.nnz());

        for i in 0..tensor.nnz() {
            let coords: Vec<usize> = (0..3).map(|d| tensor.indices[d][i]).collect();
            let orig_val = tensor.get(&coords);
            let rec_val = recovered.get(&coords);
            assert_relative_eq!(orig_val, rec_val, epsilon = 1e-12);
        }
    }

    // Point lookups return stored values, and zero for absent coordinates.
    #[test]
    fn test_csf_get() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");

        assert_relative_eq!(csf.get(&[0, 0, 0]), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 1, 1]), 2.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 2, 3]), 3.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 0, 2]), 4.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 2, 0]), 5.0, epsilon = 1e-12);
        // Absent coordinate reads back as zero.
        assert_relative_eq!(csf.get(&[0, 0, 1]), 0.0, epsilon = 1e-12);
    }

    // A non-identity mode order must not change lookups or the roundtrip.
    #[test]
    fn test_csf_different_mode_order() {
        let tensor = create_test_tensor_3d();

        let csf = CsfTensor::from_sparse_tensor(&tensor, &[2, 1, 0]).expect("CSF creation failed");
        assert_eq!(csf.nnz(), 5);
        assert_eq!(csf.mode_order(), &[2, 1, 0]);

        assert_relative_eq!(csf.get(&[0, 0, 0]), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 2, 0]), 5.0, epsilon = 1e-12);

        let recovered = csf.to_sparse_tensor().expect("roundtrip failed");
        for i in 0..tensor.nnz() {
            let coords: Vec<usize> = (0..3).map(|d| tensor.indices[d][i]).collect();
            assert_relative_eq!(tensor.get(&coords), recovered.get(&coords), epsilon = 1e-12);
        }
    }

    // Contracting a 2x3 tensor with [1, 2, 0] along mode 1:
    // row 0 -> 1*1 + 2*2 = 5, row 1 -> 3*1 = 3.
    #[test]
    fn test_csf_contract_vector() {
        let indices = vec![
            vec![0, 0, 1],
            vec![0, 1, 0],
        ];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 3];
        let tensor = SparseTensor::new(indices, values, shape).expect("create tensor");

        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1]).expect("CSF creation");

        let vector = vec![1.0, 2.0, 0.0];
        let result = csf.contract_vector(1, &vector).expect("contract_vector");

        assert_eq!(result.shape, vec![2]);

        assert_relative_eq!(result.get(&[0]), 5.0, epsilon = 1e-12);
        assert_relative_eq!(result.get(&[1]), 3.0, epsilon = 1e-12);
    }

    // Mode-1 product with a 2x3 CSR matrix; expected values computed by hand.
    #[test]
    fn test_csf_mode_n_product() {
        let indices = vec![
            vec![0, 0, 1],
            vec![0, 2, 1],
        ];
        let values = vec![1.0, 3.0, 2.0];
        let shape = vec![2, 3];
        let tensor = SparseTensor::new(indices, values, shape).expect("create tensor");
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1]).expect("CSF");

        let m_rows = vec![0, 0, 1, 1];
        let m_cols = vec![0, 2, 0, 1];
        let m_vals = vec![1.0, 1.0, 0.5, 2.0];
        let matrix =
            CsrArray::from_triplets(&m_rows, &m_cols, &m_vals, (2, 3), false).expect("matrix");

        let result = csf.mode_n_product(1, &matrix).expect("mode_n_product");

        assert_eq!(result.shape, vec![2, 2]);

        // out[0,0] = 1*M[0,0] + 3*M[0,2] = 4; out[0,1] = 1*M[1,0] = 0.5;
        // out[1,1] = 2*M[1,1] = 4.
        assert_relative_eq!(result.get(&[0, 0]), 4.0, epsilon = 1e-12);
        assert_relative_eq!(result.get(&[0, 1]), 0.5, epsilon = 1e-12);
        assert_relative_eq!(result.get(&[1, 1]), 4.0, epsilon = 1e-12);
    }

    // memory_usage only needs to report something positive for nonempty data.
    #[test]
    fn test_csf_memory_usage() {
        let tensor = create_test_tensor_3d();
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1, 2]).expect("CSF creation failed");
        let mem = csf.memory_usage();
        assert!(mem > 0);
    }

    // An all-zero tensor converts to an empty CSF structure.
    #[test]
    fn test_csf_empty_tensor() {
        let indices = vec![Vec::<usize>::new(), Vec::<usize>::new()];
        let values: Vec<f64> = Vec::new();
        let shape = vec![3, 4];
        let tensor = SparseTensor::new(indices, values, shape).expect("empty tensor");
        let csf = CsfTensor::from_sparse_tensor(&tensor, &[0, 1]).expect("CSF");
        assert_eq!(csf.nnz(), 0);
    }

    // The roundtrip must hold for several level permutations of a 3-D tensor.
    #[test]
    fn test_csf_3d_roundtrip_with_permutation() {
        let indices = vec![
            vec![0, 0, 1, 2, 2, 2],
            vec![0, 3, 1, 0, 2, 3],
            vec![0, 4, 2, 1, 3, 4],
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![3, 4, 5];
        let tensor = SparseTensor::new(indices, values, shape).expect("tensor");

        for perm in &[[0, 1, 2], [1, 0, 2], [2, 0, 1], [0, 2, 1]] {
            let csf = CsfTensor::from_sparse_tensor(&tensor, perm).expect("CSF");
            assert_eq!(csf.nnz(), 6);

            let recovered = csf.to_sparse_tensor().expect("roundtrip");
            assert_eq!(recovered.nnz(), 6);

            for i in 0..tensor.nnz() {
                let coords: Vec<usize> = (0..3).map(|d| tensor.indices[d][i]).collect();
                assert_relative_eq!(tensor.get(&coords), recovered.get(&coords), epsilon = 1e-12,);
            }
        }
    }
}