1#![allow(unused_variables)] use crate::errors::{Result, TrustformersError};
10use crate::tensor::{DType, Tensor};
11use scirs2_core::ndarray::{ArrayD, IxDyn};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Storage layouts supported by [`SparseTensor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SparseFormat {
    /// Coordinate list: parallel `(row, col, value)` triples.
    COO,
    /// Compressed Sparse Row: per-row pointer array + column indices.
    CSR,
    /// Compressed Sparse Column: per-column pointer array + row indices.
    CSC,
    /// Block Sparse Row: CSR over fixed-size dense blocks.
    BSR,
    /// Dictionary Of Keys: `(row, col)` map into the value array.
    DOK,
}
29
/// A sparse 2-D tensor: the stored (non-zero) values plus format-specific indices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseTensor {
    /// Storage layout that `indices` follows.
    pub format: SparseFormat,
    /// Logical dense dimensions, `[rows, cols]`.
    pub shape: Vec<usize>,
    /// Element dtype of `values` (the constructors in this file always set `DType::F32`).
    pub dtype: DType,
    /// The stored values.
    pub values: Vec<f32>,
    /// Positions of the stored values, in the layout named by `format`.
    pub indices: SparseIndices,
    /// Number of stored values (kept equal to `values.len()` by the constructors here).
    pub nnz: usize,
}
46
/// Format-specific index storage for [`SparseTensor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SparseIndices {
    /// Coordinate format: entry `i` lives at `(row_indices[i], col_indices[i])`.
    COO {
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
    },
    /// Compressed rows: row `r`'s entries occupy
    /// `col_indices[row_ptr[r]..row_ptr[r + 1]]` (and the same span of `values`).
    CSR {
        row_ptr: Vec<usize>,
        col_indices: Vec<usize>,
    },
    /// Compressed columns: column `c`'s entries occupy
    /// `row_indices[col_ptr[c]..col_ptr[c + 1]]`.
    CSC {
        col_ptr: Vec<usize>,
        row_indices: Vec<usize>,
    },
    /// Block CSR: like CSR, but each entry addresses a dense
    /// `block_shape.0 x block_shape.1` slab of `values`.
    BSR {
        row_ptr: Vec<usize>,
        col_indices: Vec<usize>,
        block_shape: (usize, usize),
    },
    /// Dictionary of keys: maps `(row, col)` to an index into `values`.
    DOK {
        indices_map: HashMap<(usize, usize), usize>,
    },
}
76
77impl SparseTensor {
78 pub fn new_coo(
80 shape: Vec<usize>,
81 row_indices: Vec<usize>,
82 col_indices: Vec<usize>,
83 values: Vec<f32>,
84 ) -> Result<Self> {
85 if row_indices.len() != col_indices.len() || col_indices.len() != values.len() {
86 return Err(TrustformersError::shape_error(
87 "Indices and values must have the same length".to_string(),
88 ));
89 }
90
91 if shape.len() != 2 {
92 return Err(TrustformersError::shape_error(
93 "COO format currently supports only 2D tensors".to_string(),
94 ));
95 }
96
97 for &row in &row_indices {
99 if row >= shape[0] {
100 return Err(TrustformersError::shape_error(format!(
101 "Row index {} out of bounds for shape {:?}",
102 row, shape
103 )));
104 }
105 }
106
107 for &col in &col_indices {
108 if col >= shape[1] {
109 return Err(TrustformersError::shape_error(format!(
110 "Column index {} out of bounds for shape {:?}",
111 col, shape
112 )));
113 }
114 }
115
116 Ok(SparseTensor {
117 format: SparseFormat::COO,
118 shape,
119 dtype: DType::F32,
120 nnz: values.len(),
121 values,
122 indices: SparseIndices::COO {
123 row_indices,
124 col_indices,
125 },
126 })
127 }
128
129 pub fn new_csr(
131 shape: Vec<usize>,
132 row_ptr: Vec<usize>,
133 col_indices: Vec<usize>,
134 values: Vec<f32>,
135 ) -> Result<Self> {
136 if col_indices.len() != values.len() {
137 return Err(TrustformersError::shape_error(
138 "Column indices and values must have the same length".to_string(),
139 ));
140 }
141
142 if shape.len() != 2 {
143 return Err(TrustformersError::shape_error(
144 "CSR format currently supports only 2D tensors".to_string(),
145 ));
146 }
147
148 if row_ptr.len() != shape[0] + 1 {
149 return Err(TrustformersError::shape_error(format!(
150 "Row pointer length {} must be {} for shape {:?}",
151 row_ptr.len(),
152 shape[0] + 1,
153 shape
154 )));
155 }
156
157 Ok(SparseTensor {
158 format: SparseFormat::CSR,
159 shape,
160 dtype: DType::F32,
161 nnz: values.len(),
162 values,
163 indices: SparseIndices::CSR {
164 row_ptr,
165 col_indices,
166 },
167 })
168 }
169
170 pub fn from_dense(tensor: &Tensor, threshold: f32) -> Result<Self> {
172 match tensor {
173 Tensor::F32(arr) => {
174 let shape = arr.shape().to_vec();
175 if shape.len() != 2 {
176 return Err(TrustformersError::shape_error(
177 "Dense to sparse conversion currently supports only 2D tensors".to_string(),
178 ));
179 }
180
181 let mut row_indices = Vec::new();
182 let mut col_indices = Vec::new();
183 let mut values = Vec::new();
184
185 for (i, row) in arr.outer_iter().enumerate() {
186 for (j, &val) in row.iter().enumerate() {
187 if val.abs() > threshold {
188 row_indices.push(i);
189 col_indices.push(j);
190 values.push(val);
191 }
192 }
193 }
194
195 Self::new_coo(shape, row_indices, col_indices, values)
196 },
197 _ => Err(TrustformersError::tensor_op_error(
198 "Dense to sparse conversion only supports F32 tensors",
199 "dense to sparse conversion",
200 )),
201 }
202 }
203
204 pub fn to_dense(&self) -> Result<Tensor> {
206 match self.format {
207 SparseFormat::COO => {
208 if let SparseIndices::COO {
209 row_indices,
210 col_indices,
211 } = &self.indices
212 {
213 let mut dense = ArrayD::zeros(IxDyn(&self.shape));
214
215 for ((&row, &col), &val) in
216 row_indices.iter().zip(col_indices.iter()).zip(self.values.iter())
217 {
218 dense[[row, col]] = val;
219 }
220
221 Ok(Tensor::F32(dense))
222 } else {
223 Err(TrustformersError::tensor_op_error(
224 "Invalid indices format for COO tensor",
225 "COO to dense conversion",
226 ))
227 }
228 },
229 SparseFormat::CSR => {
230 if let SparseIndices::CSR {
231 row_ptr,
232 col_indices,
233 } = &self.indices
234 {
235 let mut dense = ArrayD::zeros(IxDyn(&self.shape));
236
237 for (row, window) in row_ptr.windows(2).enumerate() {
238 let start = window[0];
239 let end = window[1];
240 for (offset, &col) in col_indices[start..end].iter().enumerate() {
241 let val = self.values[start + offset];
242 dense[[row, col]] = val;
243 }
244 }
245
246 Ok(Tensor::F32(dense))
247 } else {
248 Err(TrustformersError::tensor_op_error(
249 "Invalid indices format for CSR tensor",
250 "CSR to dense conversion",
251 ))
252 }
253 },
254 SparseFormat::CSC => {
255 if let SparseIndices::CSC {
256 col_ptr,
257 row_indices,
258 } = &self.indices
259 {
260 let mut dense = ArrayD::zeros(IxDyn(&self.shape));
261
262 for (col, window) in col_ptr.windows(2).enumerate() {
263 let start = window[0];
264 let end = window[1];
265 for (offset, &row) in row_indices[start..end].iter().enumerate() {
266 let val = self.values[start + offset];
267 dense[[row, col]] = val;
268 }
269 }
270
271 Ok(Tensor::F32(dense))
272 } else {
273 Err(TrustformersError::tensor_op_error(
274 "Invalid indices format for CSC tensor",
275 "CSC to dense conversion",
276 ))
277 }
278 },
279 SparseFormat::BSR => {
280 if let SparseIndices::BSR {
281 row_ptr,
282 col_indices,
283 block_shape,
284 } = &self.indices
285 {
286 let mut dense = ArrayD::zeros(IxDyn(&self.shape));
287 let (block_rows, block_cols) = *block_shape;
288 let values_per_block = block_rows * block_cols;
289
290 for (block_row, window) in row_ptr.windows(2).enumerate() {
291 let start = window[0];
292 let end = window[1];
293 for (offset, &block_col) in col_indices[start..end].iter().enumerate() {
294 let block_idx = start + offset;
295
296 let row_start = block_row * block_rows;
298 let row_end = (row_start + block_rows).min(self.shape[0]);
299 let col_start = block_col * block_cols;
300 let col_end = (col_start + block_cols).min(self.shape[1]);
301
302 let values_start = block_idx * values_per_block;
304 let mut value_idx = 0;
305
306 for row in row_start..row_end {
307 for col in col_start..col_end {
308 if values_start + value_idx < self.values.len() {
309 dense[[row, col]] = self.values[values_start + value_idx];
310 value_idx += 1;
311 }
312 }
313 }
314 }
315 }
316
317 Ok(Tensor::F32(dense))
318 } else {
319 Err(TrustformersError::tensor_op_error(
320 "Invalid indices format for BSR tensor",
321 "BSR to dense conversion",
322 ))
323 }
324 },
325 SparseFormat::DOK => {
326 if let SparseIndices::DOK { indices_map } = &self.indices {
327 let mut dense = ArrayD::zeros(IxDyn(&self.shape));
328
329 for (&(row, col), &value_idx) in indices_map.iter() {
330 if value_idx < self.values.len() {
331 dense[[row, col]] = self.values[value_idx];
332 }
333 }
334
335 Ok(Tensor::F32(dense))
336 } else {
337 Err(TrustformersError::tensor_op_error(
338 "Invalid indices format for DOK tensor",
339 "DOK to dense conversion",
340 ))
341 }
342 },
343 }
344 }
345
346 pub fn to_format(&self, target_format: SparseFormat) -> Result<Self> {
348 if self.format == target_format {
349 return Ok(self.clone());
350 }
351
352 match (self.format, target_format) {
353 (SparseFormat::COO, SparseFormat::CSR) => self.coo_to_csr(),
354 (SparseFormat::CSR, SparseFormat::COO) => self.csr_to_coo(),
355 _ => Err(TrustformersError::tensor_op_error(
356 &format!(
357 "Conversion from {:?} to {:?} not implemented",
358 self.format, target_format
359 ),
360 "sparse format conversion",
361 )),
362 }
363 }
364
365 fn coo_to_csr(&self) -> Result<Self> {
367 if let SparseIndices::COO {
368 row_indices,
369 col_indices,
370 } = &self.indices
371 {
372 let nrows = self.shape[0];
373 let nnz = self.nnz;
374
375 let mut row_ptr = vec![0; nrows + 1];
377
378 for &row in row_indices {
380 row_ptr[row + 1] += 1;
381 }
382
383 for i in 1..=nrows {
385 row_ptr[i] += row_ptr[i - 1];
386 }
387
388 let mut sorted_col_indices = vec![0; nnz];
390 let mut sorted_values = vec![0.0; nnz];
391 let mut temp_ptr = row_ptr.clone();
392
393 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
394 let dest = temp_ptr[row];
395 sorted_col_indices[dest] = col;
396 sorted_values[dest] = self.values[i];
397 temp_ptr[row] += 1;
398 }
399
400 Ok(SparseTensor {
401 format: SparseFormat::CSR,
402 shape: self.shape.clone(),
403 dtype: self.dtype,
404 nnz: self.nnz,
405 values: sorted_values,
406 indices: SparseIndices::CSR {
407 row_ptr,
408 col_indices: sorted_col_indices,
409 },
410 })
411 } else {
412 Err(TrustformersError::tensor_op_error(
413 "Invalid indices format for COO tensor",
414 "COO to CSR conversion",
415 ))
416 }
417 }
418
419 fn csr_to_coo(&self) -> Result<Self> {
421 if let SparseIndices::CSR {
422 row_ptr,
423 col_indices,
424 } = &self.indices
425 {
426 let mut row_indices = Vec::with_capacity(self.nnz);
427
428 for (row, window) in row_ptr.windows(2).enumerate() {
429 let start = window[0];
430 let end = window[1];
431 for _ in start..end {
432 row_indices.push(row);
433 }
434 }
435
436 Ok(SparseTensor {
437 format: SparseFormat::COO,
438 shape: self.shape.clone(),
439 dtype: self.dtype,
440 nnz: self.nnz,
441 values: self.values.clone(),
442 indices: SparseIndices::COO {
443 row_indices,
444 col_indices: col_indices.clone(),
445 },
446 })
447 } else {
448 Err(TrustformersError::tensor_op_error(
449 "Invalid indices format for CSR tensor",
450 "CSR to COO conversion",
451 ))
452 }
453 }
454
455 pub fn sparse_matmul(&self, other: &SparseTensor) -> Result<SparseTensor> {
457 let lhs = self.to_format(SparseFormat::CSR)?;
459 let rhs = other.to_format(SparseFormat::CSR)?;
460
461 if lhs.shape[1] != rhs.shape[0] {
462 return Err(TrustformersError::shape_error(format!(
463 "Matrix dimensions incompatible: {:?} x {:?}",
464 lhs.shape, rhs.shape
465 )));
466 }
467
468 let result_shape = vec![lhs.shape[0], rhs.shape[1]];
469
470 if let (
478 SparseIndices::CSR {
479 row_ptr: lhs_row_ptr,
480 col_indices: lhs_col_indices,
481 },
482 SparseIndices::CSR {
483 row_ptr: rhs_row_ptr,
484 col_indices: rhs_col_indices,
485 },
486 ) = (&lhs.indices, &rhs.indices)
487 {
488 let (result_row_ptr, result_col_indices) = Self::symbolic_sparse_matmul(
490 lhs_row_ptr,
491 lhs_col_indices,
492 rhs_row_ptr,
493 rhs_col_indices,
494 lhs.shape[0],
495 rhs.shape[1],
496 );
497
498 let result_values = Self::numerical_sparse_matmul(
500 &lhs.values,
501 lhs_row_ptr,
502 lhs_col_indices,
503 &rhs.values,
504 rhs_row_ptr,
505 rhs_col_indices,
506 &result_row_ptr,
507 &result_col_indices,
508 );
509
510 let mut row_indices = Vec::new();
512 let mut col_indices = Vec::new();
513 let mut values = Vec::new();
514
515 for i in 0..result_row_ptr.len() - 1 {
516 for idx in result_row_ptr[i]..result_row_ptr[i + 1] {
517 let val = result_values[idx];
518 if val.abs() > f32::EPSILON * 10.0 {
519 row_indices.push(i);
521 col_indices.push(result_col_indices[idx]);
522 values.push(val);
523 }
524 }
525 }
526
527 return SparseTensor::new_coo(result_shape, row_indices, col_indices, values);
528 }
529
530 let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
532
533 let (lhs_rows, lhs_cols, lhs_vals) = match &lhs.indices {
535 SparseIndices::COO {
536 row_indices,
537 col_indices,
538 } => (
539 row_indices.as_slice(),
540 col_indices.as_slice(),
541 lhs.values.as_slice(),
542 ),
543 _ => {
544 return Err(crate::errors::compute_error(
545 "sparse matrix multiplication",
546 "Unsupported sparse format combination for matrix multiplication",
547 ))
548 },
549 };
550
551 let (rhs_rows, rhs_cols, rhs_vals) = match &rhs.indices {
553 SparseIndices::COO {
554 row_indices,
555 col_indices,
556 } => (
557 row_indices.as_slice(),
558 col_indices.as_slice(),
559 rhs.values.as_slice(),
560 ),
561 _ => {
562 return Err(crate::errors::compute_error(
563 "sparse matrix multiplication",
564 "Unsupported sparse format combination for matrix multiplication",
565 ))
566 },
567 };
568
569 for (idx_a, (&i, (&j, &lhs_val))) in
571 lhs_rows.iter().zip(lhs_cols.iter().zip(lhs_vals.iter())).enumerate()
572 {
573 for (idx_b, (&k, (&l, &rhs_val))) in
574 rhs_rows.iter().zip(rhs_cols.iter().zip(rhs_vals.iter())).enumerate()
575 {
576 if j == k {
577 *result_map.entry((i, l)).or_insert(0.0) += lhs_val * rhs_val;
578 }
579 }
580 }
581
582 let mut row_indices = Vec::new();
584 let mut col_indices = Vec::new();
585 let mut values = Vec::new();
586
587 for ((row, col), val) in result_map.iter() {
588 if val.abs() > f32::EPSILON * 10.0 {
589 row_indices.push(*row);
590 col_indices.push(*col);
591 values.push(*val);
592 }
593 }
594
595 SparseTensor::new_coo(result_shape, row_indices, col_indices, values)
596 }
597
598 pub fn dense_matmul(&self, dense: &Tensor) -> Result<Tensor> {
600 let dense_shape = dense.shape();
601 if self.shape[1] != dense_shape[0] {
602 return Err(TrustformersError::shape_error(format!(
603 "Matrix dimensions incompatible: {:?} x {:?}",
604 self.shape, dense_shape
605 )));
606 }
607
608 match (self.format, dense) {
609 (SparseFormat::CSR, Tensor::F32(dense_arr)) => {
610 if let SparseIndices::CSR {
611 row_ptr,
612 col_indices,
613 } = &self.indices
614 {
615 let result_shape = vec![self.shape[0], dense_shape[1]];
616 let mut result = ArrayD::zeros(IxDyn(&result_shape));
617
618 for i in 0..self.shape[0] {
619 let start = row_ptr[i];
620 let end = row_ptr[i + 1];
621 for (offset, &k) in col_indices[start..end].iter().enumerate() {
622 let sparse_idx = start + offset;
623 let sparse_val = self.values[sparse_idx];
624
625 for j in 0..dense_shape[1] {
626 result[[i, j]] += sparse_val * dense_arr[[k, j]];
627 }
628 }
629 }
630
631 Ok(Tensor::F32(result))
632 } else {
633 Err(TrustformersError::tensor_op_error(
634 "Invalid indices format for CSR tensor",
635 "CSR dense matmul",
636 ))
637 }
638 },
639 (SparseFormat::COO, Tensor::F32(dense_arr)) => {
640 if let SparseIndices::COO {
641 row_indices,
642 col_indices,
643 } = &self.indices
644 {
645 let result_shape = vec![self.shape[0], dense_shape[1]];
646 let mut result = ArrayD::zeros(IxDyn(&result_shape));
647
648 for ((row, col), val) in
649 row_indices.iter().zip(col_indices.iter()).zip(self.values.iter())
650 {
651 for j in 0..dense_shape[1] {
652 result[[*row, j]] += val * dense_arr[[*col, j]];
653 }
654 }
655
656 Ok(Tensor::F32(result))
657 } else {
658 Err(TrustformersError::tensor_op_error(
659 "Invalid indices format for COO tensor",
660 "COO dense matmul",
661 ))
662 }
663 },
664 _ => Err(TrustformersError::tensor_op_error(
665 "Sparse-dense multiplication not implemented for this format",
666 "sparse-dense matmul",
667 )),
668 }
669 }
670
671 pub fn add(&self, other: &SparseTensor) -> Result<SparseTensor> {
673 if self.shape != other.shape {
674 return Err(TrustformersError::shape_error(format!(
675 "Shape mismatch: {:?} vs {:?}",
676 self.shape, other.shape
677 )));
678 }
679
680 let lhs = self.to_format(SparseFormat::COO)?;
682 let rhs = other.to_format(SparseFormat::COO)?;
683
684 let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
685
686 if let SparseIndices::COO {
688 row_indices,
689 col_indices,
690 } = &lhs.indices
691 {
692 for ((&row, &col), &val) in
693 row_indices.iter().zip(col_indices.iter()).zip(lhs.values.iter())
694 {
695 result_map.insert((row, col), val);
696 }
697 }
698
699 if let SparseIndices::COO {
701 row_indices,
702 col_indices,
703 } = &rhs.indices
704 {
705 for ((&row, &col), &val) in
706 row_indices.iter().zip(col_indices.iter()).zip(rhs.values.iter())
707 {
708 *result_map.entry((row, col)).or_insert(0.0) += val;
709 }
710 }
711
712 let mut row_indices = Vec::new();
714 let mut col_indices = Vec::new();
715 let mut values = Vec::new();
716
717 for ((row, col), val) in result_map.iter() {
718 if val.abs() > 1e-10 {
719 row_indices.push(*row);
721 col_indices.push(*col);
722 values.push(*val);
723 }
724 }
725
726 SparseTensor::new_coo(self.shape.clone(), row_indices, col_indices, values)
727 }
728
729 pub fn mul_scalar(&self, scalar: f32) -> Result<SparseTensor> {
731 let mut result = self.clone();
732 for val in &mut result.values {
733 *val *= scalar;
734 }
735 Ok(result)
736 }
737
738 pub fn sparsity(&self) -> f32 {
740 let total_elements: usize = self.shape.iter().product();
741 1.0 - (self.nnz as f32 / total_elements as f32)
742 }
743
744 pub fn density(&self) -> f32 {
746 1.0 - self.sparsity()
747 }
748
749 pub fn shape(&self) -> &[usize] {
751 &self.shape
752 }
753
754 pub fn nnz(&self) -> usize {
756 self.nnz
757 }
758
759 pub fn memory_usage(&self) -> usize {
761 let values_size = self.values.len() * std::mem::size_of::<f32>();
762 let indices_size = match &self.indices {
763 SparseIndices::COO {
764 row_indices,
765 col_indices,
766 } => (row_indices.len() + col_indices.len()) * std::mem::size_of::<usize>(),
767 SparseIndices::CSR {
768 row_ptr,
769 col_indices,
770 } => (row_ptr.len() + col_indices.len()) * std::mem::size_of::<usize>(),
771 SparseIndices::CSC {
772 col_ptr,
773 row_indices,
774 } => (col_ptr.len() + row_indices.len()) * std::mem::size_of::<usize>(),
775 SparseIndices::BSR {
776 row_ptr,
777 col_indices,
778 ..
779 } => (row_ptr.len() + col_indices.len()) * std::mem::size_of::<usize>(),
780 SparseIndices::DOK { indices_map } => {
781 indices_map.len()
782 * (2 * std::mem::size_of::<usize>() + std::mem::size_of::<usize>())
783 },
784 };
785 values_size + indices_size
786 }
787
788 fn symbolic_sparse_matmul(
809 lhs_row_ptr: &[usize],
810 lhs_col_indices: &[usize],
811 rhs_row_ptr: &[usize],
812 rhs_col_indices: &[usize],
813 n_rows: usize,
814 n_cols: usize,
815 ) -> (Vec<usize>, Vec<usize>) {
816 let mut result_row_ptr = vec![0; n_rows + 1];
817 let mut column_markers = vec![false; n_cols];
818 let mut column_buffer = Vec::new();
819
820 for i in 0..n_rows {
822 column_buffer.clear();
823
824 for &lhs_k in &lhs_col_indices[lhs_row_ptr[i]..lhs_row_ptr[i + 1]] {
826 for &rhs_j in &rhs_col_indices[rhs_row_ptr[lhs_k]..rhs_row_ptr[lhs_k + 1]] {
828 if !column_markers[rhs_j] {
829 column_markers[rhs_j] = true;
830 column_buffer.push(rhs_j);
831 }
832 }
833 }
834
835 result_row_ptr[i + 1] = result_row_ptr[i] + column_buffer.len();
837 for &col in &column_buffer {
838 column_markers[col] = false;
839 }
840 }
841
842 let total_nnz = result_row_ptr[n_rows];
844 let mut result_col_indices = vec![0; total_nnz];
845 let mut current_idx = 0;
846
847 for i in 0..n_rows {
848 column_buffer.clear();
849
850 for &lhs_k in &lhs_col_indices[lhs_row_ptr[i]..lhs_row_ptr[i + 1]] {
852 for &rhs_j in &rhs_col_indices[rhs_row_ptr[lhs_k]..rhs_row_ptr[lhs_k + 1]] {
853 if !column_markers[rhs_j] {
854 column_markers[rhs_j] = true;
855 column_buffer.push(rhs_j);
856 }
857 }
858 }
859
860 column_buffer.sort_unstable();
862
863 for &col in &column_buffer {
865 result_col_indices[current_idx] = col;
866 current_idx += 1;
867 column_markers[col] = false;
868 }
869 }
870
871 (result_row_ptr, result_col_indices)
872 }
873
874 fn numerical_sparse_matmul(
894 lhs_values: &[f32],
895 lhs_row_ptr: &[usize],
896 lhs_col_indices: &[usize],
897 rhs_values: &[f32],
898 rhs_row_ptr: &[usize],
899 rhs_col_indices: &[usize],
900 result_row_ptr: &[usize],
901 result_col_indices: &[usize],
902 ) -> Vec<f32> {
903 let total_nnz = result_col_indices.len();
904 let mut result_values = vec![0.0; total_nnz];
905
906 let max_row_nnz = result_row_ptr.windows(2).map(|w| w[1] - w[0]).max().unwrap_or(0);
908
909 let mut workspace = vec![0.0; max_row_nnz];
910 let mut workspace_markers = vec![usize::MAX; max_row_nnz];
911
912 for i in 0..result_row_ptr.len() - 1 {
913 let row_start = result_row_ptr[i];
914 let row_end = result_row_ptr[i + 1];
915 let row_nnz = row_end - row_start;
916
917 workspace.fill(0.0);
919
920 for (pos, &col) in result_col_indices[row_start..row_end].iter().enumerate() {
922 workspace_markers[pos] = col;
923 }
924
925 for lhs_idx in lhs_row_ptr[i]..lhs_row_ptr[i + 1] {
927 let k = lhs_col_indices[lhs_idx];
928 let lhs_val = lhs_values[lhs_idx];
929
930 if rhs_row_ptr[k + 1] - rhs_row_ptr[k] > 32 {
932 Self::accumulate_with_binary_search(
934 &mut workspace,
935 &workspace_markers[0..row_nnz],
936 lhs_val,
937 &rhs_values[rhs_row_ptr[k]..rhs_row_ptr[k + 1]],
938 &rhs_col_indices[rhs_row_ptr[k]..rhs_row_ptr[k + 1]],
939 );
940 } else {
941 for rhs_idx in rhs_row_ptr[k]..rhs_row_ptr[k + 1] {
943 let j = rhs_col_indices[rhs_idx];
944 let rhs_val = rhs_values[rhs_idx];
945
946 for pos in 0..row_nnz {
948 if workspace_markers[pos] == j {
949 workspace[pos] += lhs_val * rhs_val;
950 break;
951 }
952 }
953 }
954 }
955 }
956
957 for (pos, &val) in workspace[0..row_nnz].iter().enumerate() {
959 result_values[row_start + pos] = val;
960 }
961 }
962
963 result_values
964 }
965
966 fn accumulate_with_binary_search(
972 workspace: &mut [f32],
973 workspace_cols: &[usize],
974 lhs_val: f32,
975 rhs_values: &[f32],
976 rhs_cols: &[usize],
977 ) {
978 for (rhs_idx, &rhs_col) in rhs_cols.iter().enumerate() {
979 let rhs_val = rhs_values[rhs_idx];
980
981 if let Ok(pos) = workspace_cols.binary_search(&rhs_col) {
983 workspace[pos] += lhs_val * rhs_val;
984 }
985 }
986 }
987}
988
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_sparse_tensor_creation() {
        // A 3x3 diagonal built from COO triples.
        let created = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![1.0, 2.0, 3.0],
        );
        assert!(created.is_ok());

        let tensor = created.expect("operation failed in test");
        assert_eq!(tensor.nnz(), 3);
        assert_eq!(tensor.shape(), &[3, 3]);
    }

    #[test]
    fn test_sparse_to_dense() {
        let sparse = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        let dense = sparse.to_dense().expect("operation failed in test");
        assert_eq!(dense.shape(), vec![2, 2]);

        // Row-major layout of [[1, 0], [0, 2]].
        let data = dense.data().expect("operation failed in test");
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 0.0);
        assert_eq!(data[2], 0.0);
        assert_eq!(data[3], 2.0);
    }

    #[test]
    fn test_dense_to_sparse() {
        let flat = Tensor::new(vec![1.0, 0.0, 0.0, 2.0]).expect("tensor operation failed");
        let matrix = flat.reshape(&[2, 2]).expect("Reshape failed");

        // Threshold 0.5 keeps only the two entries 1.0 and 2.0.
        let sparse = SparseTensor::from_dense(&matrix, 0.5).expect("tensor operation failed");
        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.sparsity(), 0.5);
    }

    #[test]
    fn test_coo_to_csr_conversion() {
        let coo = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![1.0, 2.0, 3.0],
        )
        .expect("operation failed in test");

        let csr = coo.to_format(SparseFormat::CSR).expect("operation failed in test");
        assert_eq!(csr.format, SparseFormat::CSR);
        assert_eq!(csr.nnz(), 3);

        // Round-trip through dense to confirm the structure survived.
        let dense = csr.to_dense().expect("operation failed in test");
        assert_eq!(dense.shape(), vec![3, 3]);
    }

    #[test]
    fn test_sparse_addition() {
        let diag = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");
        let anti = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![1, 0], vec![3.0, 4.0])
            .expect("tensor operation failed");

        // Disjoint supports, so every entry of both operands survives.
        let sum = diag.add(&anti).expect("Addition failed");
        assert_eq!(sum.nnz(), 4);
    }

    #[test]
    fn test_sparse_scalar_multiplication() {
        let sparse = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        let scaled = sparse.mul_scalar(3.0).expect("operation failed in test");
        assert_eq!(scaled.values[0], 3.0);
        assert_eq!(scaled.values[1], 6.0);
    }

    #[test]
    fn test_sparsity_calculation() {
        // 2 stored entries out of 16 total elements.
        let sparse = SparseTensor::new_coo(vec![4, 4], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        assert_eq!(sparse.sparsity(), 0.875);
        assert_eq!(sparse.density(), 0.125);
    }

    #[test]
    fn test_sparse_dense_matmul() {
        let sparse = SparseTensor::new_csr(vec![2, 2], vec![0, 1, 2], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        // Multiply by the 2x2 identity.
        let identity = Tensor::new(vec![1.0, 0.0, 0.0, 1.0]).expect("tensor operation failed");
        let identity_2d = identity.reshape(&[2, 2]).expect("Reshape failed");

        let product = sparse.dense_matmul(&identity_2d).expect("operation failed in test");
        assert_eq!(product.shape(), vec![2, 2]);
    }

    #[test]
    fn test_memory_usage() {
        let sparse =
            SparseTensor::new_coo(vec![1000, 1000], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
                .expect("operation failed in test");

        let usage = sparse.memory_usage();
        assert!(usage > 0);

        // Two stored entries should be far cheaper than a dense 1000x1000 buffer.
        let dense_usage = 1000 * 1000 * std::mem::size_of::<f32>();
        assert!(usage < dense_usage / 10);
    }
}