trustformers_core/tensor/sparse.rs
1//! Sparse tensor operations.
2//!
3//! This module contains functions for working with sparse tensors.
4
5use super::Tensor;
6use crate::errors::{Result, TrustformersError};
7use crate::sparse_tensor::SparseTensor;
8
9impl Tensor {
10 /// Convert a dense tensor to sparse format.
11 ///
12 /// # Arguments
13 ///
14 /// * `threshold` - Values below this threshold will be considered zero
15 ///
16 /// # Returns
17 ///
18 /// A sparse tensor representation.
19 pub fn to_sparse(&self, threshold: f32) -> Result<Tensor> {
20 match self {
21 Tensor::F32(a) => {
22 let sparse = SparseTensor::from_dense(&Tensor::F32(a.clone()), threshold)?;
23 Ok(Tensor::Sparse(sparse))
24 },
25 Tensor::Sparse(_) => {
26 // Already sparse
27 Ok(self.clone())
28 },
29 _ => Err(TrustformersError::tensor_op_error(
30 "Cannot convert this tensor type to sparse",
31 "Tensor::to_sparse",
32 )),
33 }
34 }
35
36 /// Convert a sparse tensor to dense format.
37 ///
38 /// # Returns
39 ///
40 /// A dense tensor representation.
41 pub fn to_dense(&self) -> Result<Tensor> {
42 match self {
43 Tensor::Sparse(s) => s.to_dense(),
44 Tensor::F32(_) | Tensor::F64(_) | Tensor::I64(_) => {
45 // Already dense
46 Ok(self.clone())
47 },
48 _ => Err(TrustformersError::tensor_op_error(
49 "Cannot convert this tensor type to dense",
50 "Tensor::to_dense",
51 )),
52 }
53 }
54
55 /// Check if the tensor is sparse.
56 ///
57 /// # Returns
58 ///
59 /// True if the tensor is sparse, false otherwise.
60 pub fn is_sparse(&self) -> bool {
61 matches!(self, Tensor::Sparse(_))
62 }
63
64 /// Get the sparsity ratio of the tensor.
65 ///
66 /// # Returns
67 ///
68 /// The ratio of zero elements to total elements.
69 pub fn sparsity(&self) -> Result<f32> {
70 match self {
71 Tensor::Sparse(s) => Ok(s.sparsity()),
72 Tensor::F32(a) => {
73 let total = a.len() as f32;
74 let zeros = a.iter().filter(|&&x| x == 0.0).count() as f32;
75 Ok(zeros / total)
76 },
77 _ => Err(TrustformersError::tensor_op_error(
78 "Sparsity calculation not supported for this tensor type",
79 "Tensor::sparsity",
80 )),
81 }
82 }
83
84 /// Get the number of non-zero elements.
85 ///
86 /// # Returns
87 ///
88 /// The number of non-zero elements.
89 pub fn nnz(&self) -> Result<usize> {
90 match self {
91 Tensor::Sparse(s) => Ok(s.nnz()),
92 Tensor::F32(a) => Ok(a.iter().filter(|&&x| x != 0.0).count()),
93 _ => Err(TrustformersError::tensor_op_error(
94 "NNZ calculation not supported for this tensor type",
95 "Tensor::nnz",
96 )),
97 }
98 }
99
100 /// Create a sparse tensor in COO format.
101 ///
102 /// # Arguments
103 ///
104 /// * `indices` - Coordinate indices
105 /// * `values` - Non-zero values
106 /// * `shape` - Tensor shape
107 ///
108 /// # Returns
109 ///
110 /// A sparse tensor in COO format.
111 pub fn sparse_coo(
112 indices: Vec<Vec<usize>>,
113 values: Vec<f32>,
114 shape: Vec<usize>,
115 ) -> Result<Tensor> {
116 let sparse = SparseTensor::new_coo(shape, indices[0].clone(), indices[1].clone(), values)?;
117 Ok(Tensor::Sparse(sparse))
118 }
119
120 /// Create a sparse tensor in CSR format.
121 ///
122 /// # Arguments
123 ///
124 /// * `row_ptr` - Row pointers
125 /// * `col_indices` - Column indices
126 /// * `values` - Non-zero values
127 /// * `shape` - Tensor shape
128 ///
129 /// # Returns
130 ///
131 /// A sparse tensor in CSR format.
132 pub fn sparse_csr(
133 row_ptr: Vec<usize>,
134 col_indices: Vec<usize>,
135 values: Vec<f32>,
136 shape: Vec<usize>,
137 ) -> Result<Tensor> {
138 let sparse = SparseTensor::new_csr(shape, row_ptr, col_indices, values)?;
139 Ok(Tensor::Sparse(sparse))
140 }
141}