Skip to main content

scivex_core/tensor/
mod.rs

1//! N-dimensional tensor type with dynamic shape and contiguous storage.
2//!
3//! The [`Tensor`] type is the fundamental data structure in Scivex, analogous
4//! to `NumPy`'s `ndarray`. It stores elements in row-major (C) order by default
5//! and is generic over any type implementing [`Scalar`].
6
7mod create;
8mod display;
9mod ops;
10mod reshape;
11mod sort;
12
13pub mod einsum;
14pub mod einsum_path;
15pub mod indexing;
16pub mod named;
17pub mod sparse;
18
19pub use indexing::SliceRange;
20
21use crate::Scalar;
22use crate::dtype::Float;
23use crate::error::{CoreError, Result};
24
25/// An N-dimensional tensor with dynamic shape.
26///
27/// Data is stored contiguously in row-major (C) order. The tensor owns its
28/// data and cloning performs a deep copy.
29///
30/// # Type Parameters
31///
32/// - `T`: The element type, which must implement [`Scalar`].
33///
34/// # Examples
35///
36/// ```
37/// # use scivex_core::Tensor;
38/// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
39/// assert_eq!(t.shape(), &[2, 2]);
40/// assert_eq!(t.numel(), 4);
41/// ```
42#[cfg_attr(
43    feature = "serde-support",
44    derive(serde::Serialize, serde::Deserialize)
45)]
46#[derive(Debug, Clone)]
47pub struct Tensor<T: Scalar> {
48    data: Vec<T>,
49    shape: Vec<usize>,
50    strides: Vec<usize>,
51}
52
53impl<T: Scalar> Tensor<T> {
54    // ------------------------------------------------------------------
55    // Construction from raw parts
56    // ------------------------------------------------------------------
57
58    /// Create a tensor from a flat data vector and a shape.
59    ///
60    /// Returns an error if the product of `shape` does not equal `data.len()`.
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// # use scivex_core::Tensor;
66    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
67    /// assert_eq!(t.shape(), &[2, 3]);
68    /// assert_eq!(t.numel(), 6);
69    /// ```
70    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Result<Self> {
71        let numel: usize = shape.iter().product();
72        if numel != data.len() {
73            return Err(CoreError::InvalidShape {
74                shape: shape.clone(),
75                reason: "shape product does not match data length",
76            });
77        }
78        let strides = compute_strides(&shape);
79        Ok(Self {
80            data,
81            shape,
82            strides,
83        })
84    }
85
86    /// Create a tensor from a flat slice and a shape (copies the data).
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// # use scivex_core::Tensor;
92    /// let data = [1, 2, 3, 4];
93    /// let t = Tensor::from_slice(&data, vec![2, 2]).unwrap();
94    /// assert_eq!(t.shape(), &[2, 2]);
95    /// assert_eq!(*t.get(&[1, 0]).unwrap(), 3);
96    /// ```
97    pub fn from_slice(data: &[T], shape: Vec<usize>) -> Result<Self> {
98        Self::from_vec(data.to_vec(), shape)
99    }
100
101    /// Create a scalar (0-dimensional) tensor.
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// # use scivex_core::Tensor;
107    /// let t = Tensor::scalar(42.0_f64);
108    /// assert_eq!(t.ndim(), 0);
109    /// assert_eq!(t.numel(), 1);
110    /// assert_eq!(t.as_slice(), &[42.0]);
111    /// ```
112    pub fn scalar(value: T) -> Self {
113        Self {
114            data: vec![value],
115            shape: vec![],
116            strides: vec![],
117        }
118    }
119
120    // ------------------------------------------------------------------
121    // Accessors
122    // ------------------------------------------------------------------
123
124    /// The shape of the tensor as a slice.
125    ///
126    /// # Examples
127    ///
128    /// ```
129    /// # use scivex_core::Tensor;
130    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
131    /// assert_eq!(t.shape(), &[2, 3]);
132    /// ```
133    #[inline]
134    pub fn shape(&self) -> &[usize] {
135        &self.shape
136    }
137
138    /// The strides of the tensor as a slice (in number of elements).
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// # use scivex_core::Tensor;
144    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
145    /// assert_eq!(t.strides(), &[3, 1]);
146    /// ```
147    #[inline]
148    pub fn strides(&self) -> &[usize] {
149        &self.strides
150    }
151
152    /// The number of dimensions (rank) of the tensor.
153    ///
154    /// # Examples
155    ///
156    /// ```
157    /// # use scivex_core::Tensor;
158    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
159    /// assert_eq!(t.ndim(), 2);
160    /// ```
161    #[inline]
162    pub fn ndim(&self) -> usize {
163        self.shape.len()
164    }
165
166    /// The total number of elements.
167    ///
168    /// # Examples
169    ///
170    /// ```
171    /// # use scivex_core::Tensor;
172    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
173    /// assert_eq!(t.numel(), 6);
174    /// ```
175    #[inline]
176    pub fn numel(&self) -> usize {
177        self.data.len()
178    }
179
180    /// Whether the tensor has zero elements.
181    ///
182    /// # Examples
183    ///
184    /// ```
185    /// # use scivex_core::Tensor;
186    /// let empty = Tensor::<f64>::zeros(vec![0]);
187    /// assert!(empty.is_empty());
188    /// let nonempty = Tensor::<f64>::ones(vec![3]);
189    /// assert!(!nonempty.is_empty());
190    /// ```
191    #[inline]
192    pub fn is_empty(&self) -> bool {
193        self.data.is_empty()
194    }
195
196    /// A flat slice of all elements in storage order.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// # use scivex_core::Tensor;
202    /// let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
203    /// assert_eq!(t.as_slice(), &[1, 2, 3]);
204    /// ```
205    #[inline]
206    pub fn as_slice(&self) -> &[T] {
207        &self.data
208    }
209
210    /// A mutable flat slice of all elements in storage order.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// # use scivex_core::Tensor;
216    /// let mut t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
217    /// t.as_mut_slice()[0] = 99;
218    /// assert_eq!(t.as_slice(), &[99, 2, 3]);
219    /// ```
220    #[inline]
221    pub fn as_mut_slice(&mut self) -> &mut [T] {
222        &mut self.data
223    }
224
225    /// Consume the tensor and return the underlying `Vec<T>`.
226    ///
227    /// # Examples
228    ///
229    /// ```
230    /// # use scivex_core::Tensor;
231    /// let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
232    /// let v: Vec<i32> = t.into_vec();
233    /// assert_eq!(v, vec![1, 2, 3]);
234    /// ```
235    #[inline]
236    pub fn into_vec(self) -> Vec<T> {
237        self.data
238    }
239
240    // ------------------------------------------------------------------
241    // Element access
242    // ------------------------------------------------------------------
243
244    /// Compute the flat index for a multi-dimensional index.
245    fn flat_index(&self, index: &[usize]) -> Result<usize> {
246        if index.len() != self.ndim() {
247            return Err(CoreError::IndexOutOfBounds {
248                index: index.to_vec(),
249                shape: self.shape.clone(),
250            });
251        }
252        let mut flat = 0;
253        for (i, (&idx, &dim)) in index.iter().zip(self.shape.iter()).enumerate() {
254            if idx >= dim {
255                return Err(CoreError::IndexOutOfBounds {
256                    index: index.to_vec(),
257                    shape: self.shape.clone(),
258                });
259            }
260            flat += idx * self.strides[i];
261        }
262        Ok(flat)
263    }
264
265    /// Get a reference to the element at the given multi-dimensional index.
266    ///
267    /// # Examples
268    ///
269    /// ```
270    /// # use scivex_core::Tensor;
271    /// let t = Tensor::from_vec(vec![10, 20, 30, 40], vec![2, 2]).unwrap();
272    /// assert_eq!(*t.get(&[0, 1]).unwrap(), 20);
273    /// assert_eq!(*t.get(&[1, 0]).unwrap(), 30);
274    /// ```
275    pub fn get(&self, index: &[usize]) -> Result<&T> {
276        let flat = self.flat_index(index)?;
277        Ok(&self.data[flat])
278    }
279
280    /// Get a mutable reference to the element at the given index.
281    ///
282    /// # Examples
283    ///
284    /// ```
285    /// # use scivex_core::Tensor;
286    /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
287    /// *t.get_mut(&[0, 0]).unwrap() = 42;
288    /// assert_eq!(*t.get(&[0, 0]).unwrap(), 42);
289    /// ```
290    pub fn get_mut(&mut self, index: &[usize]) -> Result<&mut T> {
291        let flat = self.flat_index(index)?;
292        Ok(&mut self.data[flat])
293    }
294
295    /// Set the element at the given multi-dimensional index.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// # use scivex_core::Tensor;
301    /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
302    /// t.set(&[0, 1], 99).unwrap();
303    /// assert_eq!(*t.get(&[0, 1]).unwrap(), 99);
304    /// ```
305    pub fn set(&mut self, index: &[usize], value: T) -> Result<()> {
306        let flat = self.flat_index(index)?;
307        self.data[flat] = value;
308        Ok(())
309    }
310
311    // ------------------------------------------------------------------
312    // Iterators
313    // ------------------------------------------------------------------
314
315    /// Iterate over all elements in storage order.
316    ///
317    /// # Examples
318    ///
319    /// ```
320    /// # use scivex_core::Tensor;
321    /// let t = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
322    /// let sum: i32 = t.iter().sum();
323    /// assert_eq!(sum, 60);
324    /// ```
325    pub fn iter(&self) -> impl Iterator<Item = &T> {
326        self.data.iter()
327    }
328
329    /// Iterate mutably over all elements in storage order.
330    ///
331    /// # Examples
332    ///
333    /// ```
334    /// # use scivex_core::Tensor;
335    /// let mut t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
336    /// for x in t.iter_mut() {
337    ///     *x *= 10;
338    /// }
339    /// assert_eq!(t.as_slice(), &[10, 20, 30]);
340    /// ```
341    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
342        self.data.iter_mut()
343    }
344
345    // ------------------------------------------------------------------
346    // Map / apply
347    // ------------------------------------------------------------------
348
349    /// Apply a function to every element, returning a new tensor.
350    ///
351    /// # Examples
352    ///
353    /// ```
354    /// # use scivex_core::Tensor;
355    /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
356    /// let doubled = t.map(|x| x * 2);
357    /// assert_eq!(doubled.as_slice(), &[2, 4, 6, 8]);
358    /// ```
359    pub fn map<F>(&self, f: F) -> Tensor<T>
360    where
361        F: Fn(T) -> T,
362    {
363        Tensor {
364            data: self.data.iter().map(|&x| f(x)).collect(),
365            shape: self.shape.clone(),
366            strides: self.strides.clone(),
367        }
368    }
369
370    /// Cast every element to a different scalar type, preserving shape.
371    ///
372    /// Uses `to_f64()` / `from_f64()` for the conversion, which is lossless for
373    /// f32→f64 and lossy (but intentionally so) for f64→f32 or f32→f16.
374    ///
375    /// # Examples
376    ///
377    /// ```
378    /// # use scivex_core::Tensor;
379    /// let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
380    /// let t32: Tensor<f32> = t.cast();
381    /// assert_eq!(t32.as_slice(), &[1.0_f32, 2.0, 3.0]);
382    /// ```
383    pub fn cast<U: Scalar + Float>(&self) -> Tensor<U>
384    where
385        T: Float,
386    {
387        Tensor {
388            data: self.data.iter().map(|&x| U::from_f64(x.to_f64())).collect(),
389            shape: self.shape.clone(),
390            strides: self.strides.clone(),
391        }
392    }
393
394    /// Apply a function element-wise to two tensors of the same shape.
395    ///
396    /// # Examples
397    ///
398    /// ```
399    /// # use scivex_core::Tensor;
400    /// let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
401    /// let b = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
402    /// let c = a.zip_map(&b, |x, y| x + y).unwrap();
403    /// assert_eq!(c.as_slice(), &[11, 22, 33]);
404    /// ```
405    pub fn zip_map<F>(&self, other: &Tensor<T>, f: F) -> Result<Tensor<T>>
406    where
407        F: Fn(T, T) -> T,
408    {
409        if self.shape != other.shape {
410            return Err(CoreError::DimensionMismatch {
411                expected: self.shape.clone(),
412                got: other.shape.clone(),
413            });
414        }
415        let data = self
416            .data
417            .iter()
418            .zip(other.data.iter())
419            .map(|(&a, &b)| f(a, b))
420            .collect();
421        Ok(Tensor {
422            data,
423            shape: self.shape.clone(),
424            strides: self.strides.clone(),
425        })
426    }
427
428    /// Apply a function to every element in place.
429    ///
430    /// # Examples
431    ///
432    /// ```
433    /// # use scivex_core::Tensor;
434    /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
435    /// t.apply(|x| x * x);
436    /// assert_eq!(t.as_slice(), &[1, 4, 9, 16]);
437    /// ```
438    pub fn apply<F>(&mut self, f: F)
439    where
440        F: Fn(T) -> T,
441    {
442        for x in &mut self.data {
443            *x = f(*x);
444        }
445    }
446}
447
448impl<T: Scalar> PartialEq for Tensor<T> {
449    fn eq(&self, other: &Self) -> bool {
450        self.shape == other.shape && self.data == other.data
451    }
452}
453
454// ======================================================================
455// Utility functions
456// ======================================================================
457
/// Compute row-major (C-order) strides, in elements, for `shape`.
///
/// The last axis gets stride 1, and each earlier axis's stride is the product
/// of all extents after it. A rank-0 shape yields an empty vector.
///
/// If `shape` contains a zero-length axis, the tensor holds no elements, but
/// a suffix product can still exceed `usize::MAX` (e.g. `[0, usize::MAX / 2, 4]`).
/// Such strides saturate at `usize::MAX` instead of overflowing, which would
/// panic in debug builds. For shapes with no zero-length axis whose total
/// product fits in `usize`, every suffix product fits too, so the result
/// equals the plain product.
pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk from the innermost axis outward; ranks 0 and 1 skip the loop.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis].saturating_mul(shape[axis]);
    }
    strides
}
470
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]);
        assert_eq!(t.ndim(), 2);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn test_from_vec_shape_mismatch() {
        let r = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![2, 3]);
        assert!(r.is_err());
    }

    #[test]
    fn test_from_slice_copies() {
        let src = [1, 2, 3, 4, 5, 6];
        let t = Tensor::from_slice(&src, vec![3, 2]).unwrap();
        assert_eq!(t.strides(), &[2, 1]);
        assert_eq!(*t.get(&[2, 1]).unwrap(), 6);
        assert!(Tensor::from_slice(&src, vec![4, 2]).is_err());
    }

    #[test]
    fn test_scalar_tensor() {
        let t = Tensor::scalar(42.0_f64);
        assert_eq!(t.ndim(), 0);
        assert_eq!(t.numel(), 1);
        assert_eq!(t.as_slice(), &[42.0]);
    }

    /// A rank-0 tensor is addressed with the empty index.
    #[test]
    fn test_scalar_index() {
        let mut t = Tensor::scalar(7);
        assert!(t.strides().is_empty());
        assert_eq!(*t.get(&[]).unwrap(), 7);
        t.set(&[], 8).unwrap();
        assert_eq!(t.as_slice(), &[8]);
        assert!(t.get(&[0]).is_err());
    }

    /// A zero-length axis yields an empty tensor with well-defined strides.
    #[test]
    fn test_zero_sized_axis() {
        let t = Tensor::<f64>::from_vec(vec![], vec![3, 0]).unwrap();
        assert!(t.is_empty());
        assert_eq!(t.numel(), 0);
        assert_eq!(t.strides(), &[0, 1]);
        assert!(t.get(&[0, 0]).is_err());
    }

    #[test]
    fn test_get_set() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        assert_eq!(*t.get(&[0, 0]).unwrap(), 1);
        assert_eq!(*t.get(&[1, 2]).unwrap(), 6);
        t.set(&[0, 1], 99).unwrap();
        assert_eq!(*t.get(&[0, 1]).unwrap(), 99);
    }

    #[test]
    fn test_get_out_of_bounds() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(t.get(&[2, 0]).is_err());
        assert!(t.get(&[0]).is_err());
    }

    /// Failed writes must not touch the data.
    #[test]
    fn test_set_out_of_bounds_leaves_data_unchanged() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(t.set(&[0, 2], 9).is_err());
        assert!(t.set(&[1, 1, 0], 9).is_err());
        assert!(t.get_mut(&[2, 0]).is_err());
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_compute_strides() {
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(compute_strides(&[5]), vec![1]);
        assert_eq!(compute_strides(&[]), Vec::<usize>::new());
        assert_eq!(compute_strides(&[3, 0, 2]), vec![0, 2, 1]);
    }

    #[test]
    fn test_map() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let t2 = t.map(|x| x * 10);
        assert_eq!(t2.as_slice(), &[10, 20, 30, 40]);
        assert_eq!(t2.shape(), &[2, 2]);
    }

    #[test]
    fn test_apply_and_iter_mut() {
        let mut t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        t.apply(|x| x * x);
        assert_eq!(t.as_slice(), &[1, 4, 9]);
        for x in t.iter_mut() {
            *x += 1;
        }
        assert_eq!(t.into_vec(), vec![2, 5, 10]);
    }

    #[test]
    fn test_cast_preserves_shape() {
        let t = Tensor::from_vec(vec![0.5_f64, -1.25, 2.0, 4.0], vec![2, 2]).unwrap();
        let t32: Tensor<f32> = t.cast();
        assert_eq!(t32.shape(), &[2, 2]);
        assert_eq!(t32.strides(), &[2, 1]);
        assert_eq!(t32.as_slice(), &[0.5_f32, -1.25, 2.0, 4.0]);
    }

    #[test]
    fn test_zip_map() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
        let c = a.zip_map(&b, |x, y| x + y).unwrap();
        assert_eq!(c.as_slice(), &[11, 22, 33]);
    }

    #[test]
    fn test_zip_map_shape_mismatch() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1, 2], vec![2]).unwrap();
        assert!(a.zip_map(&b, |x, y| x + y).is_err());
    }

    #[test]
    fn test_partial_eq() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let c = Tensor::from_vec(vec![1, 2, 4], vec![3]).unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    /// Same elements but a different shape must not compare equal.
    #[test]
    fn test_partial_eq_shape_matters() {
        let square = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let flat = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_ne!(square, flat);
    }
}