Skip to main content

trustformers_core/tensor/
sparse.rs

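//! Sparse tensor operations.
//!
//! This module adds sparse-related methods to `Tensor`: conversion between
//! dense and sparse representations, sparsity statistics (`sparsity`, `nnz`),
//! and constructors for the COO and CSR formats.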
4
5use super::Tensor;
6use crate::errors::{Result, TrustformersError};
7use crate::sparse_tensor::SparseTensor;
8
9impl Tensor {
10    /// Convert a dense tensor to sparse format.
11    ///
12    /// # Arguments
13    ///
14    /// * `threshold` - Values below this threshold will be considered zero
15    ///
16    /// # Returns
17    ///
18    /// A sparse tensor representation.
19    pub fn to_sparse(&self, threshold: f32) -> Result<Tensor> {
20        match self {
21            Tensor::F32(a) => {
22                let sparse = SparseTensor::from_dense(&Tensor::F32(a.clone()), threshold)?;
23                Ok(Tensor::Sparse(sparse))
24            },
25            Tensor::Sparse(_) => {
26                // Already sparse
27                Ok(self.clone())
28            },
29            _ => Err(TrustformersError::tensor_op_error(
30                "Cannot convert this tensor type to sparse",
31                "Tensor::to_sparse",
32            )),
33        }
34    }
35
36    /// Convert a sparse tensor to dense format.
37    ///
38    /// # Returns
39    ///
40    /// A dense tensor representation.
41    pub fn to_dense(&self) -> Result<Tensor> {
42        match self {
43            Tensor::Sparse(s) => s.to_dense(),
44            Tensor::F32(_) | Tensor::F64(_) | Tensor::I64(_) => {
45                // Already dense
46                Ok(self.clone())
47            },
48            _ => Err(TrustformersError::tensor_op_error(
49                "Cannot convert this tensor type to dense",
50                "Tensor::to_dense",
51            )),
52        }
53    }
54
55    /// Check if the tensor is sparse.
56    ///
57    /// # Returns
58    ///
59    /// True if the tensor is sparse, false otherwise.
60    pub fn is_sparse(&self) -> bool {
61        matches!(self, Tensor::Sparse(_))
62    }
63
64    /// Get the sparsity ratio of the tensor.
65    ///
66    /// # Returns
67    ///
68    /// The ratio of zero elements to total elements.
69    pub fn sparsity(&self) -> Result<f32> {
70        match self {
71            Tensor::Sparse(s) => Ok(s.sparsity()),
72            Tensor::F32(a) => {
73                let total = a.len() as f32;
74                let zeros = a.iter().filter(|&&x| x == 0.0).count() as f32;
75                Ok(zeros / total)
76            },
77            _ => Err(TrustformersError::tensor_op_error(
78                "Sparsity calculation not supported for this tensor type",
79                "Tensor::sparsity",
80            )),
81        }
82    }
83
84    /// Get the number of non-zero elements.
85    ///
86    /// # Returns
87    ///
88    /// The number of non-zero elements.
89    pub fn nnz(&self) -> Result<usize> {
90        match self {
91            Tensor::Sparse(s) => Ok(s.nnz()),
92            Tensor::F32(a) => Ok(a.iter().filter(|&&x| x != 0.0).count()),
93            _ => Err(TrustformersError::tensor_op_error(
94                "NNZ calculation not supported for this tensor type",
95                "Tensor::nnz",
96            )),
97        }
98    }
99
100    /// Create a sparse tensor in COO format.
101    ///
102    /// # Arguments
103    ///
104    /// * `indices` - Coordinate indices
105    /// * `values` - Non-zero values
106    /// * `shape` - Tensor shape
107    ///
108    /// # Returns
109    ///
110    /// A sparse tensor in COO format.
111    pub fn sparse_coo(
112        indices: Vec<Vec<usize>>,
113        values: Vec<f32>,
114        shape: Vec<usize>,
115    ) -> Result<Tensor> {
116        let sparse = SparseTensor::new_coo(shape, indices[0].clone(), indices[1].clone(), values)?;
117        Ok(Tensor::Sparse(sparse))
118    }
119
120    /// Create a sparse tensor in CSR format.
121    ///
122    /// # Arguments
123    ///
124    /// * `row_ptr` - Row pointers
125    /// * `col_indices` - Column indices
126    /// * `values` - Non-zero values
127    /// * `shape` - Tensor shape
128    ///
129    /// # Returns
130    ///
131    /// A sparse tensor in CSR format.
132    pub fn sparse_csr(
133        row_ptr: Vec<usize>,
134        col_indices: Vec<usize>,
135        values: Vec<f32>,
136        shape: Vec<usize>,
137    ) -> Result<Tensor> {
138        let sparse = SparseTensor::new_csr(shape, row_ptr, col_indices, values)?;
139        Ok(Tensor::Sparse(sparse))
140    }
141}