torsh_functional/sparse/core.rs
//! Core sparse tensor implementation
//!
//! This module provides the SparseTensor struct and basic operations for sparse tensors
//! using COO (coordinate) format.

use std::collections::HashMap;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;

/// Sparse tensor representation using COO (coordinate) format.
///
/// Non-zero values are stored in `values` (length `nnz`) and their positions
/// in `indices` (shape `[ndim, nnz]`): row `d` of `indices` holds the
/// coordinate along dimension `d` for every stored element.
#[derive(Debug, Clone)]
pub struct SparseTensor {
    /// Tensor values (non-zero elements); a 1-D tensor of length `nnz`.
    pub values: Tensor,
    /// Indices of non-zero elements [2, nnz] for 2D, [3, nnz] for 3D, etc.
    /// Stored dimension-major: element `i`'s coordinate along dimension `j`
    /// is at flat position `j * nnz + i`.
    pub indices: Tensor,
    /// Shape of the full (dense-equivalent) tensor.
    pub shape: Vec<usize>,
    /// Number of dimensions (`shape.len()`).
    pub ndim: usize,
    /// Number of non-zero elements.
    pub nnz: usize,
    /// Whether the tensor is coalesced (indices are ordered and unique).
    pub is_coalesced: bool,
}
26
27impl SparseTensor {
28    /// Create a new sparse tensor from values and indices
29    pub fn new(values: Tensor, indices: Tensor, shape: Vec<usize>) -> TorshResult<Self> {
30        let values_shape = values.shape().dims().to_vec();
31        let indices_shape = indices.shape().dims().to_vec();
32
33        if values_shape.len() != 1 {
34            return Err(TorshError::invalid_argument_with_context(
35                "Values must be a 1D tensor",
36                "SparseTensor::new",
37            ));
38        }
39
40        if indices_shape.len() != 2 {
41            return Err(TorshError::invalid_argument_with_context(
42                "Indices must be a 2D tensor",
43                "SparseTensor::new",
44            ));
45        }
46
47        let nnz = values_shape[0];
48        let ndim = shape.len();
49
50        if indices_shape[0] != ndim {
51            return Err(TorshError::invalid_argument_with_context(
52                "Indices first dimension must equal tensor ndim",
53                "SparseTensor::new",
54            ));
55        }
56
57        if indices_shape[1] != nnz {
58            return Err(TorshError::invalid_argument_with_context(
59                "Indices second dimension must equal number of values",
60                "SparseTensor::new",
61            ));
62        }
63
64        Ok(SparseTensor {
65            values,
66            indices,
67            shape,
68            ndim,
69            nnz,
70            is_coalesced: false,
71        })
72    }
73
74    /// Create a sparse tensor from dense tensor
75    pub fn from_dense(dense: &Tensor) -> TorshResult<Self> {
76        let shape = dense.shape().dims().to_vec();
77        let ndim = shape.len();
78
79        // Find non-zero elements
80        let dense_data = dense.to_vec()?;
81        let mut values_vec = Vec::new();
82        let mut coords_vec = Vec::new(); // Store all coordinates temporarily
83
84        // Iterate through all elements
85        let total_elements: usize = shape.iter().product();
86        for flat_idx in 0..total_elements {
87            let value = dense_data[flat_idx];
88            if value.abs() > 1e-8 {
89                // Consider as non-zero
90                values_vec.push(value);
91
92                // Convert flat index to multi-dimensional indices
93                let mut remaining = flat_idx;
94                let mut coords = Vec::with_capacity(ndim);
95                for &dim_size in shape.iter().rev() {
96                    coords.push(remaining % dim_size);
97                    remaining /= dim_size;
98                }
99                coords.reverse();
100
101                coords_vec.push(coords);
102            }
103        }
104
105        let nnz = values_vec.len();
106
107        // Build indices in dimension-by-dimension order (not element-by-element)
108        // Indices should be [ndim, nnz] where row i contains all coords for dimension i
109        let mut indices_vec = Vec::with_capacity(ndim * nnz);
110        for dim in 0..ndim {
111            for coords in &coords_vec {
112                indices_vec.push(coords[dim] as f32);
113            }
114        }
115
116        let values = Tensor::from_data(values_vec, vec![nnz], dense.device())?;
117        let indices = Tensor::from_data(indices_vec, vec![ndim, nnz], dense.device())?;
118
119        Ok(SparseTensor {
120            values,
121            indices,
122            shape,
123            ndim,
124            nnz,
125            is_coalesced: false,
126        })
127    }
128
129    /// Convert sparse tensor to dense tensor
130    pub fn to_dense(&self) -> TorshResult<Tensor> {
131        let total_elements: usize = self.shape.iter().product();
132        let mut dense_data = vec![0.0f32; total_elements];
133
134        let values_data = self.values.to_vec()?;
135        let indices_data = self.indices.to_vec()?;
136
137        for i in 0..self.nnz {
138            // Extract coordinates for this non-zero element
139            let mut flat_idx = 0;
140            let mut stride = 1;
141
142            for j in (0..self.ndim).rev() {
143                let coord = indices_data[j * self.nnz + i] as usize;
144                flat_idx += coord * stride;
145                stride *= self.shape[j];
146            }
147
148            dense_data[flat_idx] = values_data[i];
149        }
150
151        Tensor::from_data(dense_data, self.shape.clone(), self.values.device())
152    }
153
154    /// Coalesce the sparse tensor (combine duplicate indices)
155    pub fn coalesce(&mut self) -> TorshResult<()> {
156        if self.is_coalesced {
157            return Ok(());
158        }
159
160        let values_data = self.values.to_vec()?;
161        let indices_data = self.indices.to_vec()?;
162
163        // Group by indices and sum values
164        let mut index_to_value: HashMap<Vec<usize>, f32> = HashMap::new();
165
166        for i in 0..self.nnz {
167            let mut coords = Vec::with_capacity(self.ndim);
168            for j in 0..self.ndim {
169                coords.push(indices_data[j * self.nnz + i] as usize);
170            }
171
172            *index_to_value.entry(coords).or_insert(0.0) += values_data[i];
173        }
174
175        // Filter out zero values and create new arrays
176        let mut new_values = Vec::new();
177        let mut new_indices = Vec::new();
178
179        for (coords, value) in index_to_value {
180            if value.abs() > 1e-8 {
181                new_values.push(value);
182                for coord in coords {
183                    new_indices.push(coord as f32);
184                }
185            }
186        }
187
188        let new_nnz = new_values.len();
189        self.values = Tensor::from_data(new_values, vec![new_nnz], self.values.device())?;
190        self.indices =
191            Tensor::from_data(new_indices, vec![self.ndim, new_nnz], self.indices.device())?;
192        self.nnz = new_nnz;
193        self.is_coalesced = true;
194
195        Ok(())
196    }
197
198    /// Get the number of non-zero elements
199    pub fn nnz(&self) -> usize {
200        self.nnz
201    }
202
203    /// Get the shape of the tensor
204    pub fn shape(&self) -> &[usize] {
205        &self.shape
206    }
207
208    /// Get the number of dimensions
209    pub fn ndim(&self) -> usize {
210        self.ndim
211    }
212
213    /// Check if the tensor is coalesced
214    pub fn is_coalesced(&self) -> bool {
215        self.is_coalesced
216    }
217}
218
219/// Create a sparse tensor from COO format
220pub fn sparse_coo_tensor(
221    indices: &Tensor,
222    values: &Tensor,
223    shape: &[usize],
224) -> TorshResult<SparseTensor> {
225    SparseTensor::new(values.clone(), indices.clone(), shape.to_vec())
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a 3x3 sparse tensor from matching values/indices should
    /// succeed and report consistent metadata.
    #[test]
    fn test_sparse_tensor_creation() -> TorshResult<()> {
        let values = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], torsh_core::DeviceType::Cpu)?;
        let indices = Tensor::from_data(
            vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
            vec![2, 3],
            torsh_core::DeviceType::Cpu,
        )?;

        let sparse = SparseTensor::new(values, indices, vec![3, 3])?;

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
        assert_eq!(sparse.ndim(), 2);

        Ok(())
    }

    /// A 2x2 sparse tensor with entries on the diagonal should densify to
    /// [[1, 0], [0, 2]].
    #[test]
    fn test_sparse_to_dense() -> TorshResult<()> {
        let values = Tensor::from_data(vec![1.0, 2.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let indices = Tensor::from_data(
            vec![0.0, 1.0, 0.0, 1.0],
            vec![2, 2],
            torsh_core::DeviceType::Cpu,
        )?;

        let dense = SparseTensor::new(values, indices, vec![2, 2])?.to_dense()?;
        let dense_data = dense.to_vec()?;

        let expected_data = [1.0f32, 0.0, 0.0, 2.0];
        dense_data
            .iter()
            .zip(expected_data.iter())
            .for_each(|(actual, expected)| assert!((actual - expected).abs() < 1e-6));

        Ok(())
    }
}