1use crate::error::{SparseError, SparseResult};
24use scirs2_core::numeric::{SparseElement, Zero};
25use std::fmt::Debug;
26
/// Sparse tensor stored in Compressed Sparse Fiber (CSF) format.
///
/// The stored entries form a prefix tree with one level per mode. Level `l` holds
/// coordinates of mode `mode_order[l]`. A root-to-leaf path spells out the full
/// coordinate of one stored entry, and the leaves (last level) are aligned with `values`.
#[derive(Debug, Clone)]
pub struct CsfTensor<T> {
    /// Size of each mode, indexed by the tensor's original (user-facing) mode number.
    pub shape: Vec<usize>,
    /// Permutation of `0..ndim`: `mode_order[l]` is the mode stored at tree level `l`.
    pub mode_order: Vec<usize>,
    /// Child pointers for levels `0..ndim - 1`. The children of node `i` at level `l`
    /// occupy `fib_idx[l + 1][fib_ptr[l][i]..fib_ptr[l][i + 1]]`. When built by
    /// `from_coo`, each array ends with a sentinel equal to `fib_idx[l + 1].len()`.
    pub fib_ptr: Vec<Vec<usize>>,
    /// Node coordinates for every level. Within one parent's child range they are
    /// sorted ascending, which lets lookups binary-search them.
    pub fib_idx: Vec<Vec<usize>>,
    /// Stored values, one per leaf node, aligned with the last level of `fib_idx`.
    pub values: Vec<T>,
}
47
/// One coordinate-format entry used while building a [`CsfTensor`].
///
/// `coords` holds the entry's coordinates already permuted into storage (tree-level)
/// order. There is no trait bound on the definition: only the impls that copy values
/// out of an entry need `T: Copy`.
#[derive(Debug, Clone)]
struct CooEntry<T> {
    /// Coordinates of the entry, one per tree level.
    coords: Vec<usize>,
    /// Value stored at `coords`.
    value: T,
}
54
55impl<T> CsfTensor<T>
56where
57 T: Clone + Copy + Zero + SparseElement + Debug,
58{
59 pub fn from_coo(
70 indices: &[Vec<usize>],
71 values: &[T],
72 shape: &[usize],
73 mode_order: Option<&[usize]>,
74 ) -> SparseResult<Self> {
75 let ndim = shape.len();
76 if indices.len() != ndim {
77 return Err(SparseError::ValueError(format!(
78 "indices length {} != ndim {}",
79 indices.len(),
80 ndim
81 )));
82 }
83 let nnz = values.len();
84 if ndim > 0 && indices[0].len() != nnz {
85 return Err(SparseError::ValueError(
86 "indices and values length mismatch".to_string(),
87 ));
88 }
89
90 let order: Vec<usize> = match mode_order {
91 Some(o) => {
92 if o.len() != ndim {
93 return Err(SparseError::ValueError(
94 "mode_order length must match ndim".to_string(),
95 ));
96 }
97 let mut sorted = o.to_vec();
98 sorted.sort_unstable();
99 for (i, &v) in sorted.iter().enumerate() {
100 if v != i {
101 return Err(SparseError::ValueError(
102 "mode_order must be a permutation of 0..ndim".to_string(),
103 ));
104 }
105 }
106 o.to_vec()
107 }
108 None => (0..ndim).collect(),
109 };
110
111 if nnz == 0 {
112 let fib_ptr = if ndim > 1 {
113 (0..ndim - 1).map(|_| vec![0usize]).collect()
114 } else {
115 Vec::new()
116 };
117 let fib_idx = (0..ndim).map(|_| Vec::new()).collect();
118 return Ok(Self {
119 shape: shape.to_vec(),
120 mode_order: order,
121 fib_ptr,
122 fib_idx,
123 values: Vec::new(),
124 });
125 }
126
127 let mut entries: Vec<CooEntry<T>> = (0..nnz)
129 .map(|i| {
130 let coords: Vec<usize> = order.iter().map(|&m| indices[m][i]).collect();
131 CooEntry {
132 coords,
133 value: values[i],
134 }
135 })
136 .collect();
137 entries.sort_by(|a, b| a.coords.cmp(&b.coords));
138
139 let mut fib_ptr: Vec<Vec<usize>> = Vec::new();
141 let mut fib_idx: Vec<Vec<usize>> = Vec::new();
142 let mut leaf_values: Vec<T> = Vec::new();
143
144 for _ in 0..ndim {
145 fib_idx.push(Vec::new());
146 }
147 for _ in 0..ndim.saturating_sub(1) {
148 fib_ptr.push(Vec::new());
149 }
150
151 Self::build_levels(
152 &entries,
153 &mut fib_ptr,
154 &mut fib_idx,
155 &mut leaf_values,
156 0,
157 ndim,
158 );
159
160 for l in 0..ndim.saturating_sub(1) {
162 fib_ptr[l].push(fib_idx[l + 1].len());
163 }
164
165 Ok(Self {
166 shape: shape.to_vec(),
167 mode_order: order,
168 fib_ptr,
169 fib_idx,
170 values: leaf_values,
171 })
172 }
173
174 fn build_levels(
176 entries: &[CooEntry<T>],
177 fib_ptr: &mut Vec<Vec<usize>>,
178 fib_idx: &mut Vec<Vec<usize>>,
179 values: &mut Vec<T>,
180 level: usize,
181 ndim: usize,
182 ) {
183 if entries.is_empty() {
184 return;
185 }
186
187 if level == ndim - 1 {
188 for entry in entries {
190 fib_idx[level].push(entry.coords[level]);
191 values.push(entry.value);
192 }
193 return;
194 }
195
196 let mut group_start = 0usize;
198 while group_start < entries.len() {
199 let coord = entries[group_start].coords[level];
200 let mut group_end = group_start + 1;
201 while group_end < entries.len() && entries[group_end].coords[level] == coord {
202 group_end += 1;
203 }
204
205 fib_idx[level].push(coord);
206 fib_ptr[level].push(fib_idx[level + 1].len());
207 Self::build_levels(
208 &entries[group_start..group_end],
209 fib_ptr,
210 fib_idx,
211 values,
212 level + 1,
213 ndim,
214 );
215
216 group_start = group_end;
217 }
218 }
219
220 pub fn ndim(&self) -> usize {
222 self.shape.len()
223 }
224
225 pub fn nnz(&self) -> usize {
227 self.values.len()
228 }
229
230 pub fn get(&self, indices: &[usize]) -> Option<T> {
234 let ndim = self.ndim();
235 if indices.len() != ndim {
236 return None;
237 }
238
239 let ordered: Vec<usize> = self.mode_order.iter().map(|&m| indices[m]).collect();
241 self.search_tree(&ordered, 0, 0)
242 }
243
244 fn search_tree(&self, ordered: &[usize], level: usize, fiber_idx: usize) -> Option<T> {
246 let ndim = self.ndim();
247
248 let (start, end) = if level == 0 {
250 (0, self.fib_idx[0].len())
251 } else {
252 let s = self.fib_ptr[level - 1].get(fiber_idx).copied().unwrap_or(0);
253 let e = self.fib_ptr[level - 1]
254 .get(fiber_idx + 1)
255 .copied()
256 .unwrap_or(self.fib_idx[level].len());
257 (s, e)
258 };
259
260 let target = ordered[level];
261
262 let range = &self.fib_idx[level][start..end];
264 match range.binary_search(&target) {
265 Ok(pos) => {
266 let abs_pos = start + pos;
267 if level == ndim - 1 {
268 self.values.get(abs_pos).copied()
270 } else {
271 self.search_tree(ordered, level + 1, abs_pos)
272 }
273 }
274 Err(_) => Some(T::sparse_zero()),
275 }
276 }
277
278 pub fn fiber(
289 &self,
290 free_mode: usize,
291 fixed_indices: &[usize],
292 ) -> SparseResult<Vec<(usize, T)>> {
293 let ndim = self.ndim();
294 if free_mode >= ndim {
295 return Err(SparseError::ValueError(format!(
296 "free_mode {} >= ndim {}",
297 free_mode, ndim
298 )));
299 }
300 if fixed_indices.len() != ndim - 1 {
301 return Err(SparseError::ValueError(format!(
302 "fixed_indices length {} != ndim-1 = {}",
303 fixed_indices.len(),
304 ndim - 1
305 )));
306 }
307
308 let mut result = Vec::new();
310 self.collect_fiber(0, 0, free_mode, fixed_indices, &mut Vec::new(), &mut result);
311
312 Ok(result)
313 }
314
315 fn collect_fiber(
317 &self,
318 level: usize,
319 fiber_idx: usize,
320 free_mode: usize,
321 fixed_indices: &[usize],
322 coord_stack: &mut Vec<usize>,
323 result: &mut Vec<(usize, T)>,
324 ) {
325 let ndim = self.ndim();
326 let (start, end) = if level == 0 {
327 (0, self.fib_idx[0].len())
328 } else {
329 let s = self.fib_ptr[level - 1].get(fiber_idx).copied().unwrap_or(0);
330 let e = self.fib_ptr[level - 1]
331 .get(fiber_idx + 1)
332 .copied()
333 .unwrap_or(self.fib_idx[level].len());
334 (s, e)
335 };
336
337 let current_mode = self.mode_order[level];
338
339 for i in start..end {
340 if i >= self.fib_idx[level].len() {
341 break;
342 }
343 let coord = self.fib_idx[level][i];
344
345 if current_mode == free_mode {
346 coord_stack.push(coord);
348 if level == ndim - 1 {
349 if self.check_fixed_coords(coord_stack, free_mode, fixed_indices) {
351 if let Some(&val) = self.values.get(i) {
352 result.push((coord, val));
353 }
354 }
355 } else {
356 self.collect_fiber(level + 1, i, free_mode, fixed_indices, coord_stack, result);
357 }
358 coord_stack.pop();
359 } else {
360 let fixed_idx = self.fixed_index_for_mode(current_mode, free_mode);
362 if let Some(fidx) = fixed_idx {
363 if fidx < fixed_indices.len() && coord == fixed_indices[fidx] {
364 coord_stack.push(coord);
365 if level == ndim - 1 {
366 if self.check_fixed_coords(coord_stack, free_mode, fixed_indices) {
367 if let Some(&val) = self.values.get(i) {
368 let free_coord = self.find_free_coord(coord_stack, free_mode);
371 if let Some(fc) = free_coord {
372 result.push((fc, val));
373 }
374 }
375 }
376 } else {
377 self.collect_fiber(
378 level + 1,
379 i,
380 free_mode,
381 fixed_indices,
382 coord_stack,
383 result,
384 );
385 }
386 coord_stack.pop();
387 }
388 }
390 }
391 }
392 }
393
394 fn fixed_index_for_mode(&self, mode: usize, free_mode: usize) -> Option<usize> {
396 if mode == free_mode {
397 return None;
398 }
399 let mut idx = 0usize;
400 for m in 0..self.ndim() {
401 if m == free_mode {
402 continue;
403 }
404 if m == mode {
405 return Some(idx);
406 }
407 idx += 1;
408 }
409 None
410 }
411
412 fn check_fixed_coords(
414 &self,
415 coord_stack: &[usize],
416 free_mode: usize,
417 fixed_indices: &[usize],
418 ) -> bool {
419 let mut fix_idx = 0usize;
420 for (level, &coord) in coord_stack.iter().enumerate() {
421 if level >= self.mode_order.len() {
422 break;
423 }
424 let mode = self.mode_order[level];
425 if mode == free_mode {
426 continue;
427 }
428 if fix_idx >= fixed_indices.len() || coord != fixed_indices[fix_idx] {
429 return false;
430 }
431 fix_idx += 1;
432 }
433 true
434 }
435
436 fn find_free_coord(&self, coord_stack: &[usize], free_mode: usize) -> Option<usize> {
438 for (level, &coord) in coord_stack.iter().enumerate() {
439 if level < self.mode_order.len() && self.mode_order[level] == free_mode {
440 return Some(coord);
441 }
442 }
443 None
444 }
445
446 pub fn matricize(&self, mode: usize) -> SparseResult<(Vec<usize>, Vec<usize>, Vec<T>)> {
455 let ndim = self.ndim();
456 if mode >= ndim {
457 return Err(SparseError::ValueError(format!(
458 "mode {} >= ndim {}",
459 mode, ndim
460 )));
461 }
462
463 let other_modes: Vec<usize> = (0..ndim).filter(|&m| m != mode).collect();
465 let mut col_strides: Vec<usize> = Vec::with_capacity(other_modes.len());
466 let mut stride = 1usize;
467 for &m in other_modes.iter().rev() {
468 col_strides.push(stride);
469 stride = stride.saturating_mul(self.shape[m]);
470 }
471 col_strides.reverse();
472
473 let mut rows = Vec::new();
475 let mut cols = Vec::new();
476 let mut vals = Vec::new();
477 let mut coord_stack: Vec<usize> = vec![0; ndim];
478
479 self.traverse_for_matricize(
480 0,
481 0,
482 &mut coord_stack,
483 mode,
484 &other_modes,
485 &col_strides,
486 &mut rows,
487 &mut cols,
488 &mut vals,
489 );
490
491 Ok((rows, cols, vals))
492 }
493
494 fn traverse_for_matricize(
496 &self,
497 level: usize,
498 fiber_idx: usize,
499 coord_stack: &mut Vec<usize>,
500 mode: usize,
501 other_modes: &[usize],
502 col_strides: &[usize],
503 rows: &mut Vec<usize>,
504 cols: &mut Vec<usize>,
505 vals: &mut Vec<T>,
506 ) {
507 let ndim = self.ndim();
508 let (start, end) = if level == 0 {
509 (0, self.fib_idx[0].len())
510 } else {
511 let s = self.fib_ptr[level - 1].get(fiber_idx).copied().unwrap_or(0);
512 let e = self.fib_ptr[level - 1]
513 .get(fiber_idx + 1)
514 .copied()
515 .unwrap_or(self.fib_idx[level].len());
516 (s, e)
517 };
518
519 for i in start..end {
520 if i >= self.fib_idx[level].len() {
521 break;
522 }
523 coord_stack[level] = self.fib_idx[level][i];
524
525 if level == ndim - 1 {
526 if let Some(&val) = self.values.get(i) {
528 let mut orig_coords = vec![0usize; ndim];
530 for (l, &c) in coord_stack.iter().enumerate().take(ndim) {
531 orig_coords[self.mode_order[l]] = c;
532 }
533
534 let row = orig_coords[mode];
535 let mut col = 0usize;
536 for (idx, &m) in other_modes.iter().enumerate() {
537 col += orig_coords[m] * col_strides[idx];
538 }
539
540 rows.push(row);
541 cols.push(col);
542 vals.push(val);
543 }
544 } else {
545 self.traverse_for_matricize(
546 level + 1,
547 i,
548 coord_stack,
549 mode,
550 other_modes,
551 col_strides,
552 rows,
553 cols,
554 vals,
555 );
556 }
557 }
558 }
559
560 pub fn memory_usage(&self) -> usize {
562 let mut total = 0usize;
563 for fp in &self.fib_ptr {
564 total += fp.len() * std::mem::size_of::<usize>();
565 }
566 for fi in &self.fib_idx {
567 total += fi.len() * std::mem::size_of::<usize>();
568 }
569 total += self.values.len() * std::mem::size_of::<T>();
570 total += self.shape.len() * std::mem::size_of::<usize>();
571 total += self.mode_order.len() * std::mem::size_of::<usize>();
572 total
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_csf_3d_construction_and_access() {
        // Five non-zeros of a 2 x 3 x 4 tensor, one coordinate vector per mode.
        let indices = vec![
            vec![0, 0, 0, 1, 1],
            vec![0, 1, 2, 0, 2],
            vec![0, 1, 3, 2, 0],
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![2, 3, 4];

        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        assert_eq!(csf.ndim(), 3);
        assert_eq!(csf.nnz(), 5);

        assert_relative_eq!(csf.get(&[0, 0, 0]).unwrap_or(0.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 1, 1]).unwrap_or(0.0), 2.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 2, 3]).unwrap_or(0.0), 3.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 0, 2]).unwrap_or(0.0), 4.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 2, 0]).unwrap_or(0.0), 5.0, epsilon = 1e-12);
        // A position with no stored entry reads as zero.
        assert_relative_eq!(csf.get(&[0, 0, 1]).unwrap_or(0.0), 0.0, epsilon = 1e-12);
        // A coordinate list of the wrong length is rejected.
        assert!(csf.get(&[0, 0]).is_none());
    }

    #[test]
    fn test_csf_fiber_extraction() {
        // 3 x 3 matrix with entries (0,0)=1, (0,2)=2, (1,1)=3, (2,0)=4, (2,2)=5.
        let indices = vec![vec![0, 0, 1, 2, 2], vec![0, 2, 1, 0, 2]];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");

        // Row 0: mode 1 is free, mode 0 fixed to 0.
        let fiber = csf.fiber(1, &[0]).expect("fiber");
        assert_eq!(fiber.len(), 2);
        let fiber_map: std::collections::HashMap<usize, f64> = fiber.into_iter().collect();
        assert_relative_eq!(*fiber_map.get(&0).unwrap_or(&0.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(*fiber_map.get(&2).unwrap_or(&0.0), 2.0, epsilon = 1e-12);

        // Column 2: mode 0 is free, mode 1 fixed to 2.
        let fiber = csf.fiber(0, &[2]).expect("fiber");
        assert_eq!(fiber, vec![(0, 2.0), (2, 5.0)]);

        // Invalid free mode or wrong number of fixed coordinates.
        assert!(csf.fiber(2, &[0]).is_err());
        assert!(csf.fiber(0, &[0, 1]).is_err());
    }

    #[test]
    fn test_csf_fiber_with_mode_order() {
        // 2 x 2 matrix (0,0)=1, (0,1)=2, (1,0)=3 stored column-major.
        let indices = vec![vec![0, 0, 1], vec![0, 1, 0]];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 2];

        let csf = CsfTensor::from_coo(&indices, &values, &shape, Some(&[1, 0])).expect("csf");
        assert_eq!(csf.fiber(1, &[0]).expect("row"), vec![(0, 1.0), (1, 2.0)]);
        assert_eq!(csf.fiber(0, &[0]).expect("col"), vec![(0, 1.0), (1, 3.0)]);
    }

    #[test]
    fn test_csf_matricize() {
        // Entries (0,0,0)=1, (0,1,1)=2, (1,0,0)=3, (1,2,1)=4 of a 2 x 3 x 2 tensor.
        let indices = vec![vec![0, 0, 1, 1], vec![0, 1, 0, 2], vec![0, 1, 0, 1]];
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 3, 2];

        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");

        // Mode-0 unfolding: col = i1 * 2 + i2.
        let (rows, cols, vals) = csf.matricize(0).expect("matricize");
        assert_eq!(rows, vec![0, 0, 1, 1]);
        assert_eq!(cols, vec![0, 3, 0, 5]);
        assert_eq!(vals, vec![1.0, 2.0, 3.0, 4.0]);

        // Mode-2 unfolding: col = i0 * 3 + i1.
        let (rows, cols, vals) = csf.matricize(2).expect("matricize");
        assert_eq!(rows, vec![0, 1, 0, 1]);
        assert_eq!(cols, vec![0, 1, 3, 5]);
        assert_eq!(vals, vec![1.0, 2.0, 3.0, 4.0]);

        assert!(csf.matricize(3).is_err());
    }

    #[test]
    fn test_csf_empty() {
        let indices: Vec<Vec<usize>> = vec![Vec::new(), Vec::new()];
        let values: Vec<f64> = Vec::new();
        let shape = vec![3, 4];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        assert_eq!(csf.nnz(), 0);
        assert_eq!(csf.ndim(), 2);

        // Every in-range position reads as zero, and every fiber is empty.
        assert_eq!(csf.get(&[1, 2]), Some(0.0));
        assert!(csf.fiber(1, &[0]).expect("fiber").is_empty());
        let (rows, cols, vals) = csf.matricize(1).expect("matricize");
        assert!(rows.is_empty() && cols.is_empty() && vals.is_empty());
    }

    #[test]
    fn test_csf_with_mode_order() {
        let indices = vec![vec![0, 0, 1], vec![0, 1, 0]];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 2];

        let csf = CsfTensor::from_coo(&indices, &values, &shape, Some(&[1, 0])).expect("csf");
        assert_eq!(csf.nnz(), 3);

        // Lookups still take coordinates in the original mode order.
        assert_relative_eq!(csf.get(&[0, 0]).unwrap_or(0.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 1]).unwrap_or(0.0), 2.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 0]).unwrap_or(0.0), 3.0, epsilon = 1e-12);
    }

    #[test]
    fn test_csf_memory_usage() {
        let indices = vec![vec![0, 1], vec![0, 1]];
        let values = vec![1.0, 2.0];
        let shape = vec![2, 2];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        assert!(csf.memory_usage() > 0);
    }

    #[test]
    fn test_csf_invalid_mode_order() {
        let indices = vec![vec![0], vec![0]];
        let values = vec![1.0];
        let shape = vec![2, 2];
        // Not a permutation, and wrong length.
        assert!(CsfTensor::from_coo(&indices, &values, &shape, Some(&[0, 0])).is_err());
        assert!(CsfTensor::from_coo(&indices, &values, &shape, Some(&[0])).is_err());
    }
}
705}