Skip to main content

scivex_core/tensor/
named.rs

1//! Named dimensions for tensors.
2//!
3//! [`NamedTensor`] wraps a [`Tensor`] and associates optional string names
4//! with each dimension, enabling dimension lookup and reordering by name.
5
6use crate::Scalar;
7use crate::dtype::Float;
8use crate::error::{CoreError, Result};
9
10use super::{Tensor, compute_strides};
11
12/// A tensor with optional names attached to each dimension.
13///
14/// Unnamed dimensions use `None`. Named dimensions must be unique within
15/// a single tensor.
16///
17/// # Examples
18///
19/// ```
20/// # use scivex_core::tensor::named::NamedTensor;
21/// # use scivex_core::Tensor;
22/// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
23/// let nt = NamedTensor::new(t, vec![Some("batch".into()), Some("feature".into())]).unwrap();
24/// assert_eq!(nt.dim_index("batch").unwrap(), 0);
25/// assert_eq!(nt.dim_index("feature").unwrap(), 1);
26/// ```
27#[derive(Debug, Clone)]
28pub struct NamedTensor<T: Scalar> {
29    tensor: Tensor<T>,
30    names: Vec<Option<String>>,
31}
32
33impl<T: Scalar> NamedTensor<T> {
34    /// Create a named tensor, validating that the number of names matches
35    /// the tensor's rank and that named dimensions are unique.
36    ///
37    /// # Errors
38    ///
39    /// Returns [`CoreError::InvalidArgument`] if `names.len() != tensor.ndim()`
40    /// or if duplicate dimension names are found.
41    pub fn new(tensor: Tensor<T>, names: Vec<Option<String>>) -> Result<Self> {
42        if names.len() != tensor.ndim() {
43            return Err(CoreError::InvalidArgument {
44                reason: "number of dimension names must match tensor rank",
45            });
46        }
47        // Check for duplicate named dimensions.
48        let named: Vec<&str> = names.iter().filter_map(|n| n.as_deref()).collect();
49        let mut sorted = named.clone();
50        sorted.sort_unstable();
51        for window in sorted.windows(2) {
52            if window[0] == window[1] {
53                return Err(CoreError::InvalidArgument {
54                    reason: "duplicate dimension names are not allowed",
55                });
56            }
57        }
58        Ok(Self { tensor, names })
59    }
60
61    /// Wrap a tensor with all dimensions unnamed.
62    pub fn from_tensor(tensor: Tensor<T>) -> Self {
63        let ndim = tensor.ndim();
64        Self {
65            tensor,
66            names: vec![None; ndim],
67        }
68    }
69
70    /// Borrow the inner tensor.
71    #[inline]
72    pub fn tensor(&self) -> &Tensor<T> {
73        &self.tensor
74    }
75
76    /// Get the dimension names.
77    #[inline]
78    pub fn names(&self) -> &[Option<String>] {
79        &self.names
80    }
81
82    /// Consume the wrapper and return the plain tensor.
83    #[inline]
84    pub fn into_tensor(self) -> Tensor<T> {
85        self.tensor
86    }
87
88    /// Look up the axis index for a named dimension.
89    ///
90    /// # Errors
91    ///
92    /// Returns [`CoreError::InvalidArgument`] if no dimension has the given name.
93    pub fn dim_index(&self, name: &str) -> Result<usize> {
94        self.names
95            .iter()
96            .position(|n| n.as_deref() == Some(name))
97            .ok_or(CoreError::InvalidArgument {
98                reason: "dimension name not found",
99            })
100    }
101
102    /// Rename a dimension from `old` to `new`.
103    ///
104    /// # Errors
105    ///
106    /// Returns an error if `old` is not found or `new` already exists.
107    pub fn rename(&mut self, old: &str, new: &str) -> Result<()> {
108        // Check that new name doesn't already exist.
109        if self.names.iter().any(|n| n.as_deref() == Some(new)) {
110            return Err(CoreError::InvalidArgument {
111                reason: "new dimension name already exists",
112            });
113        }
114        let idx = self.dim_index(old)?;
115        self.names[idx] = Some(new.to_string());
116        Ok(())
117    }
118
119    /// Replace all dimension names at once.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if the length does not match the tensor rank.
124    pub fn set_names(&mut self, names: Vec<Option<String>>) -> Result<()> {
125        if names.len() != self.tensor.ndim() {
126            return Err(CoreError::InvalidArgument {
127                reason: "number of dimension names must match tensor rank",
128            });
129        }
130        self.names = names;
131        Ok(())
132    }
133
134    /// Reorder dimensions by name, producing a permuted copy.
135    ///
136    /// Every named dimension in the current tensor must appear in `target_names`,
137    /// and the lengths must match the tensor rank.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if a name is not found or if the number of names does
142    /// not match the rank.
143    pub fn align_to(&self, target_names: &[&str]) -> Result<NamedTensor<T>> {
144        if target_names.len() != self.tensor.ndim() {
145            return Err(CoreError::InvalidArgument {
146                reason: "target names length must match tensor rank",
147            });
148        }
149
150        // Build permutation: perm[i] = source axis for target axis i.
151        let perm: Vec<usize> = target_names
152            .iter()
153            .map(|name| self.dim_index(name))
154            .collect::<Result<Vec<_>>>()?;
155
156        let src_shape = self.tensor.shape();
157        let src_strides = self.tensor.strides();
158        let src_data = self.tensor.as_slice();
159
160        // New shape and names according to the permutation.
161        let new_shape: Vec<usize> = perm.iter().map(|&p| src_shape[p]).collect();
162        let new_names: Vec<Option<String>> = perm.iter().map(|&p| self.names[p].clone()).collect();
163        let new_strides = compute_strides(&new_shape);
164
165        let numel: usize = new_shape.iter().product();
166        let mut new_data = vec![T::zero(); numel];
167
168        // Iterate over every element in the output tensor, compute source index.
169        for (out_flat, dest) in new_data.iter_mut().enumerate() {
170            // Convert out_flat to multi-dim index in the output.
171            let mut remaining = out_flat;
172            let mut src_flat = 0usize;
173            for (dim, &stride) in new_strides.iter().enumerate() {
174                let idx = remaining / stride;
175                remaining %= stride;
176                src_flat += idx * src_strides[perm[dim]];
177            }
178            *dest = src_data[src_flat];
179        }
180
181        let new_tensor = Tensor::from_vec(new_data, new_shape)?;
182        Ok(NamedTensor {
183            tensor: new_tensor,
184            names: new_names,
185        })
186    }
187
188    /// Select a single index along a named dimension, reducing the rank by one.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if the name is not found or the index is out of bounds.
193    pub fn select(&self, name: &str, index: usize) -> Result<NamedTensor<T>> {
194        let axis = self.dim_index(name)?;
195        let shape = self.tensor.shape();
196        if index >= shape[axis] {
197            return Err(CoreError::IndexOutOfBounds {
198                index: vec![index],
199                shape: shape.to_vec(),
200            });
201        }
202
203        let ndim = shape.len();
204        let strides = self.tensor.strides();
205        let src_data = self.tensor.as_slice();
206
207        // New shape: remove the selected axis.
208        let new_shape: Vec<usize> = shape
209            .iter()
210            .enumerate()
211            .filter(|&(i, _)| i != axis)
212            .map(|(_, &s)| s)
213            .collect();
214        let new_names: Vec<Option<String>> = self
215            .names
216            .iter()
217            .enumerate()
218            .filter(|&(i, _)| i != axis)
219            .map(|(_, n)| n.clone())
220            .collect();
221
222        let numel: usize = new_shape.iter().product();
223        let new_strides = compute_strides(&new_shape);
224        let mut new_data = vec![T::zero(); numel];
225
226        // Build a mapping from output dim to source dim (skipping `axis`).
227        let dim_map: Vec<usize> = (0..ndim).filter(|&d| d != axis).collect();
228
229        for (out_flat, dest) in new_data.iter_mut().enumerate() {
230            let mut remaining = out_flat;
231            let mut src_flat = index * strides[axis];
232            for (out_dim, &src_dim) in dim_map.iter().enumerate() {
233                let idx = if out_dim < new_strides.len() {
234                    let i = remaining / new_strides[out_dim];
235                    remaining %= new_strides[out_dim];
236                    i
237                } else {
238                    remaining
239                };
240                src_flat += idx * strides[src_dim];
241            }
242            *dest = src_data[src_flat];
243        }
244
245        let new_tensor = Tensor::from_vec(new_data, new_shape)?;
246        Ok(NamedTensor {
247            tensor: new_tensor,
248            names: new_names,
249        })
250    }
251}
252
253impl<T: Scalar + Float> NamedTensor<T> {
254    /// Sum along a named dimension, removing it from the result.
255    ///
256    /// # Errors
257    ///
258    /// Returns an error if the name is not found.
259    pub fn sum_dim(&self, name: &str) -> Result<NamedTensor<T>> {
260        let axis = self.dim_index(name)?;
261        let shape = self.tensor.shape();
262        let strides = self.tensor.strides();
263        let src_data = self.tensor.as_slice();
264        let ndim = shape.len();
265        let axis_len = shape[axis];
266
267        let new_shape: Vec<usize> = shape
268            .iter()
269            .enumerate()
270            .filter(|&(i, _)| i != axis)
271            .map(|(_, &s)| s)
272            .collect();
273        let new_names: Vec<Option<String>> = self
274            .names
275            .iter()
276            .enumerate()
277            .filter(|&(i, _)| i != axis)
278            .map(|(_, n)| n.clone())
279            .collect();
280
281        let numel: usize = new_shape.iter().product();
282        let new_strides = compute_strides(&new_shape);
283        let mut new_data = vec![T::zero(); numel];
284
285        // Build a mapping from output dim to source dim (skipping `axis`).
286        let dim_map: Vec<usize> = (0..ndim).filter(|&d| d != axis).collect();
287
288        for (out_flat, dest) in new_data.iter_mut().enumerate() {
289            let mut remaining = out_flat;
290            // Decode the output flat index into per-source-dim indices (skipping axis).
291            let mut out_indices = vec![0usize; ndim];
292            for (out_dim, &src_dim) in dim_map.iter().enumerate() {
293                let idx = if out_dim < new_strides.len() {
294                    let i = remaining / new_strides[out_dim];
295                    remaining %= new_strides[out_dim];
296                    i
297                } else {
298                    remaining
299                };
300                out_indices[src_dim] = idx;
301            }
302
303            let mut acc = T::zero();
304            for k in 0..axis_len {
305                out_indices[axis] = k;
306                let src_flat: usize = out_indices
307                    .iter()
308                    .zip(strides.iter())
309                    .map(|(&idx, &s)| idx * s)
310                    .sum();
311                acc += src_data[src_flat];
312            }
313            *dest = acc;
314        }
315
316        let new_tensor = Tensor::from_vec(new_data, new_shape)?;
317        Ok(NamedTensor {
318            tensor: new_tensor,
319            names: new_names,
320        })
321    }
322
323    /// Mean along a named dimension, removing it from the result.
324    ///
325    /// # Errors
326    ///
327    /// Returns an error if the name is not found.
328    pub fn mean_dim(&self, name: &str) -> Result<NamedTensor<T>> {
329        let axis = self.dim_index(name)?;
330        let axis_len = self.tensor.shape()[axis];
331        let summed = self.sum_dim(name)?;
332        let divisor = T::from_usize(axis_len);
333        let result_tensor = summed.tensor.map(|x| x / divisor);
334        Ok(NamedTensor {
335            tensor: result_tensor,
336            names: summed.names,
337        })
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x3 `f64` tensor holding `1.0..=6.0` in row-major order.
    fn sample_2x3() -> Tensor<f64> {
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap()
    }

    /// `sample_2x3` with axes named `rows` and `cols`.
    fn rows_cols() -> NamedTensor<f64> {
        NamedTensor::new(sample_2x3(), vec![Some("rows".into()), Some("cols".into())]).unwrap()
    }

    /// Element-wise approximate equality for small float results.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-12, "{a} != {e}");
        }
    }

    #[test]
    fn test_named_tensor_basic() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let nt = NamedTensor::new(t, vec![Some("batch".into()), Some("feature".into())]).unwrap();
        assert_eq!(nt.names().len(), 2);
        assert_eq!(nt.names()[0].as_deref(), Some("batch"));
        assert_eq!(nt.names()[1].as_deref(), Some("feature"));
        assert_eq!(nt.tensor().shape(), &[2, 3]);
    }

    #[test]
    fn test_duplicate_names_rejected() {
        let result = NamedTensor::new(sample_2x3(), vec![Some("x".into()), Some("x".into())]);
        assert!(result.is_err());
    }

    #[test]
    fn test_unnamed_dims_may_repeat() {
        let nt = NamedTensor::new(sample_2x3(), vec![None, None]).unwrap();
        assert!(nt.names().iter().all(Option::is_none));
    }

    #[test]
    fn test_from_tensor_is_unnamed() {
        let nt = NamedTensor::from_tensor(sample_2x3());
        assert_eq!(nt.names().len(), 2);
        assert!(nt.names().iter().all(Option::is_none));
        assert!(nt.dim_index("rows").is_err());
    }

    #[test]
    fn test_into_tensor_roundtrip() {
        let t = rows_cols().into_tensor();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_rename_dimension() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let mut nt = NamedTensor::new(t, vec![Some("rows".into()), Some("cols".into())]).unwrap();
        nt.rename("rows", "samples").unwrap();
        assert_eq!(nt.names()[0].as_deref(), Some("samples"));
        assert_eq!(nt.dim_index("samples").unwrap(), 0);
        assert!(nt.dim_index("rows").is_err());
    }

    #[test]
    fn test_rename_to_existing_name_fails() {
        let mut nt = rows_cols();
        assert!(nt.rename("rows", "cols").is_err());
        assert_eq!(nt.names()[0].as_deref(), Some("rows"));
    }

    #[test]
    fn test_set_names_wrong_length_fails() {
        let mut nt = rows_cols();
        assert!(nt.set_names(vec![Some("only".into())]).is_err());
        assert_eq!(nt.names()[1].as_deref(), Some("cols"));
    }

    #[test]
    fn test_align_to() {
        // 2x3x4 tensor with named dims
        let numel = 2 * 3 * 4;
        let data: Vec<f64> = (0..numel).map(f64::from).collect();
        let t = Tensor::from_vec(data, vec![2, 3, 4]).unwrap();
        let nt = NamedTensor::new(
            t.clone(),
            vec![
                Some("batch".into()),
                Some("channel".into()),
                Some("width".into()),
            ],
        )
        .unwrap();

        // Reorder to (channel, width, batch) = (3, 4, 2)
        let aligned = nt.align_to(&["channel", "width", "batch"]).unwrap();
        assert_eq!(aligned.tensor().shape(), &[3, 4, 2]);
        assert_eq!(aligned.names()[0].as_deref(), Some("channel"));
        assert_eq!(aligned.names()[1].as_deref(), Some("width"));
        assert_eq!(aligned.names()[2].as_deref(), Some("batch"));

        // Verify a specific element: original [1, 2, 3] should appear at aligned [2, 3, 1]
        let original_val = *t.get(&[1, 2, 3]).unwrap();
        let aligned_val = *aligned.tensor().get(&[2, 3, 1]).unwrap();
        assert!((original_val - aligned_val).abs() < 1e-15);
    }

    #[test]
    fn test_align_to_rejects_bad_targets() {
        let nt = rows_cols();
        assert!(nt.align_to(&["rows"]).is_err());
        assert!(nt.align_to(&["rows", "depth"]).is_err());
    }

    #[test]
    fn test_dim_index() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let nt = NamedTensor::new(t, vec![Some("time".into())]).unwrap();
        assert_eq!(nt.dim_index("time").unwrap(), 0);
        assert!(nt.dim_index("space").is_err());
    }

    #[test]
    fn test_sum_dim() {
        // 2x3 tensor, sum along "rows" (axis 0)
        let summed = rows_cols().sum_dim("rows").unwrap();
        assert_eq!(summed.tensor().shape(), &[3]);
        assert_close(summed.tensor().as_slice(), &[5.0, 7.0, 9.0]);
        assert_eq!(summed.names()[0].as_deref(), Some("cols"));
    }

    #[test]
    fn test_sum_dim_inner_axis() {
        let summed = rows_cols().sum_dim("cols").unwrap();
        assert_eq!(summed.tensor().shape(), &[2]);
        assert_close(summed.tensor().as_slice(), &[6.0, 15.0]);
        assert_eq!(summed.names()[0].as_deref(), Some("rows"));
    }

    #[test]
    fn test_mean_dim() {
        let nt = rows_cols();
        let by_rows = nt.mean_dim("rows").unwrap();
        assert_close(by_rows.tensor().as_slice(), &[2.5, 3.5, 4.5]);
        assert_eq!(by_rows.names()[0].as_deref(), Some("cols"));

        let by_cols = nt.mean_dim("cols").unwrap();
        assert_close(by_cols.tensor().as_slice(), &[2.0, 5.0]);
        assert!(nt.mean_dim("depth").is_err());
    }

    #[test]
    fn test_select() {
        // 2x3 tensor, select row 1 along "rows"
        let selected = rows_cols().select("rows", 1).unwrap();
        assert_eq!(selected.tensor().shape(), &[3]);
        assert_eq!(selected.tensor().as_slice(), &[4.0, 5.0, 6.0]);
        assert_eq!(selected.names()[0].as_deref(), Some("cols"));
    }

    #[test]
    fn test_select_inner_axis() {
        let selected = rows_cols().select("cols", 1).unwrap();
        assert_eq!(selected.tensor().shape(), &[2]);
        assert_eq!(selected.tensor().as_slice(), &[2.0, 5.0]);
        assert_eq!(selected.names()[0].as_deref(), Some("rows"));
    }

    #[test]
    fn test_select_out_of_bounds() {
        let nt = rows_cols();
        assert!(nt.select("rows", 2).is_err());
        assert!(nt.select("depth", 0).is_err());
    }

    #[test]
    fn test_invalid_names_length() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = NamedTensor::new(t, vec![Some("a".into()), Some("b".into())]);
        assert!(result.is_err());
    }
}