scivex_core/tensor/mod.rs
1//! N-dimensional tensor type with dynamic shape and contiguous storage.
2//!
3//! The [`Tensor`] type is the fundamental data structure in Scivex, analogous
4//! to `NumPy`'s `ndarray`. It stores elements in row-major (C) order by default
5//! and is generic over any type implementing [`Scalar`].
6
7mod create;
8mod display;
9mod ops;
10mod reshape;
11mod sort;
12
13pub mod einsum;
14pub mod einsum_path;
15pub mod indexing;
16pub mod named;
17pub mod sparse;
18
19pub use indexing::SliceRange;
20
21use crate::Scalar;
22use crate::dtype::Float;
23use crate::error::{CoreError, Result};
24
/// An N-dimensional tensor with dynamic shape.
///
/// Data is stored contiguously in row-major (C) order. The tensor owns its
/// data and cloning performs a deep copy.
///
/// Invariant maintained by the constructors in this module: `data.len()`
/// equals the product of `shape` (checked by `from_vec`; `scalar` uses one
/// element with an empty shape), and `strides` is computed from `shape` by
/// `compute_strides`.
///
/// # Type Parameters
///
/// - `T`: The element type, which must implement [`Scalar`].
///
/// # Examples
///
/// ```
/// # use scivex_core::Tensor;
/// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
/// assert_eq!(t.shape(), &[2, 2]);
/// assert_eq!(t.numel(), 4);
/// ```
// NOTE(review): the derived `Deserialize` (feature `serde-support`) builds the
// struct field-by-field and does not re-validate the invariant above — confirm
// whether untrusted input can reach it.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct Tensor<T: Scalar> {
    /// Flat element storage in row-major order.
    data: Vec<T>,
    /// Extent of each dimension; empty for a 0-dimensional (scalar) tensor.
    shape: Vec<usize>,
    /// Per-dimension step, in elements, used by indexing to turn a
    /// multi-dimensional index into an offset into `data`.
    strides: Vec<usize>,
}
52
53impl<T: Scalar> Tensor<T> {
54 // ------------------------------------------------------------------
55 // Construction from raw parts
56 // ------------------------------------------------------------------
57
58 /// Create a tensor from a flat data vector and a shape.
59 ///
60 /// Returns an error if the product of `shape` does not equal `data.len()`.
61 ///
62 /// # Examples
63 ///
64 /// ```
65 /// # use scivex_core::Tensor;
66 /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
67 /// assert_eq!(t.shape(), &[2, 3]);
68 /// assert_eq!(t.numel(), 6);
69 /// ```
70 pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Result<Self> {
71 let numel: usize = shape.iter().product();
72 if numel != data.len() {
73 return Err(CoreError::InvalidShape {
74 shape: shape.clone(),
75 reason: "shape product does not match data length",
76 });
77 }
78 let strides = compute_strides(&shape);
79 Ok(Self {
80 data,
81 shape,
82 strides,
83 })
84 }
85
86 /// Create a tensor from a flat slice and a shape (copies the data).
87 ///
88 /// # Examples
89 ///
90 /// ```
91 /// # use scivex_core::Tensor;
92 /// let data = [1, 2, 3, 4];
93 /// let t = Tensor::from_slice(&data, vec![2, 2]).unwrap();
94 /// assert_eq!(t.shape(), &[2, 2]);
95 /// assert_eq!(*t.get(&[1, 0]).unwrap(), 3);
96 /// ```
97 pub fn from_slice(data: &[T], shape: Vec<usize>) -> Result<Self> {
98 Self::from_vec(data.to_vec(), shape)
99 }
100
101 /// Create a scalar (0-dimensional) tensor.
102 ///
103 /// # Examples
104 ///
105 /// ```
106 /// # use scivex_core::Tensor;
107 /// let t = Tensor::scalar(42.0_f64);
108 /// assert_eq!(t.ndim(), 0);
109 /// assert_eq!(t.numel(), 1);
110 /// assert_eq!(t.as_slice(), &[42.0]);
111 /// ```
112 pub fn scalar(value: T) -> Self {
113 Self {
114 data: vec![value],
115 shape: vec![],
116 strides: vec![],
117 }
118 }
119
120 // ------------------------------------------------------------------
121 // Accessors
122 // ------------------------------------------------------------------
123
124 /// The shape of the tensor as a slice.
125 ///
126 /// # Examples
127 ///
128 /// ```
129 /// # use scivex_core::Tensor;
130 /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
131 /// assert_eq!(t.shape(), &[2, 3]);
132 /// ```
133 #[inline]
134 pub fn shape(&self) -> &[usize] {
135 &self.shape
136 }
137
138 /// The strides of the tensor as a slice (in number of elements).
139 ///
140 /// # Examples
141 ///
142 /// ```
143 /// # use scivex_core::Tensor;
144 /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
145 /// assert_eq!(t.strides(), &[3, 1]);
146 /// ```
147 #[inline]
148 pub fn strides(&self) -> &[usize] {
149 &self.strides
150 }
151
152 /// The number of dimensions (rank) of the tensor.
153 ///
154 /// # Examples
155 ///
156 /// ```
157 /// # use scivex_core::Tensor;
158 /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
159 /// assert_eq!(t.ndim(), 2);
160 /// ```
161 #[inline]
162 pub fn ndim(&self) -> usize {
163 self.shape.len()
164 }
165
166 /// The total number of elements.
167 ///
168 /// # Examples
169 ///
170 /// ```
171 /// # use scivex_core::Tensor;
172 /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
173 /// assert_eq!(t.numel(), 6);
174 /// ```
175 #[inline]
176 pub fn numel(&self) -> usize {
177 self.data.len()
178 }
179
180 /// Whether the tensor has zero elements.
181 ///
182 /// # Examples
183 ///
184 /// ```
185 /// # use scivex_core::Tensor;
186 /// let empty = Tensor::<f64>::zeros(vec![0]);
187 /// assert!(empty.is_empty());
188 /// let nonempty = Tensor::<f64>::ones(vec![3]);
189 /// assert!(!nonempty.is_empty());
190 /// ```
191 #[inline]
192 pub fn is_empty(&self) -> bool {
193 self.data.is_empty()
194 }
195
196 /// A flat slice of all elements in storage order.
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// # use scivex_core::Tensor;
202 /// let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
203 /// assert_eq!(t.as_slice(), &[1, 2, 3]);
204 /// ```
205 #[inline]
206 pub fn as_slice(&self) -> &[T] {
207 &self.data
208 }
209
210 /// A mutable flat slice of all elements in storage order.
211 ///
212 /// # Examples
213 ///
214 /// ```
215 /// # use scivex_core::Tensor;
216 /// let mut t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
217 /// t.as_mut_slice()[0] = 99;
218 /// assert_eq!(t.as_slice(), &[99, 2, 3]);
219 /// ```
220 #[inline]
221 pub fn as_mut_slice(&mut self) -> &mut [T] {
222 &mut self.data
223 }
224
225 /// Consume the tensor and return the underlying `Vec<T>`.
226 ///
227 /// # Examples
228 ///
229 /// ```
230 /// # use scivex_core::Tensor;
231 /// let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
232 /// let v: Vec<i32> = t.into_vec();
233 /// assert_eq!(v, vec![1, 2, 3]);
234 /// ```
235 #[inline]
236 pub fn into_vec(self) -> Vec<T> {
237 self.data
238 }
239
240 // ------------------------------------------------------------------
241 // Element access
242 // ------------------------------------------------------------------
243
244 /// Compute the flat index for a multi-dimensional index.
245 fn flat_index(&self, index: &[usize]) -> Result<usize> {
246 if index.len() != self.ndim() {
247 return Err(CoreError::IndexOutOfBounds {
248 index: index.to_vec(),
249 shape: self.shape.clone(),
250 });
251 }
252 let mut flat = 0;
253 for (i, (&idx, &dim)) in index.iter().zip(self.shape.iter()).enumerate() {
254 if idx >= dim {
255 return Err(CoreError::IndexOutOfBounds {
256 index: index.to_vec(),
257 shape: self.shape.clone(),
258 });
259 }
260 flat += idx * self.strides[i];
261 }
262 Ok(flat)
263 }
264
265 /// Get a reference to the element at the given multi-dimensional index.
266 ///
267 /// # Examples
268 ///
269 /// ```
270 /// # use scivex_core::Tensor;
271 /// let t = Tensor::from_vec(vec![10, 20, 30, 40], vec![2, 2]).unwrap();
272 /// assert_eq!(*t.get(&[0, 1]).unwrap(), 20);
273 /// assert_eq!(*t.get(&[1, 0]).unwrap(), 30);
274 /// ```
275 pub fn get(&self, index: &[usize]) -> Result<&T> {
276 let flat = self.flat_index(index)?;
277 Ok(&self.data[flat])
278 }
279
280 /// Get a mutable reference to the element at the given index.
281 ///
282 /// # Examples
283 ///
284 /// ```
285 /// # use scivex_core::Tensor;
286 /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
287 /// *t.get_mut(&[0, 0]).unwrap() = 42;
288 /// assert_eq!(*t.get(&[0, 0]).unwrap(), 42);
289 /// ```
290 pub fn get_mut(&mut self, index: &[usize]) -> Result<&mut T> {
291 let flat = self.flat_index(index)?;
292 Ok(&mut self.data[flat])
293 }
294
295 /// Set the element at the given multi-dimensional index.
296 ///
297 /// # Examples
298 ///
299 /// ```
300 /// # use scivex_core::Tensor;
301 /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
302 /// t.set(&[0, 1], 99).unwrap();
303 /// assert_eq!(*t.get(&[0, 1]).unwrap(), 99);
304 /// ```
305 pub fn set(&mut self, index: &[usize], value: T) -> Result<()> {
306 let flat = self.flat_index(index)?;
307 self.data[flat] = value;
308 Ok(())
309 }
310
311 // ------------------------------------------------------------------
312 // Iterators
313 // ------------------------------------------------------------------
314
315 /// Iterate over all elements in storage order.
316 ///
317 /// # Examples
318 ///
319 /// ```
320 /// # use scivex_core::Tensor;
321 /// let t = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
322 /// let sum: i32 = t.iter().sum();
323 /// assert_eq!(sum, 60);
324 /// ```
325 pub fn iter(&self) -> impl Iterator<Item = &T> {
326 self.data.iter()
327 }
328
329 /// Iterate mutably over all elements in storage order.
330 ///
331 /// # Examples
332 ///
333 /// ```
334 /// # use scivex_core::Tensor;
335 /// let mut t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
336 /// for x in t.iter_mut() {
337 /// *x *= 10;
338 /// }
339 /// assert_eq!(t.as_slice(), &[10, 20, 30]);
340 /// ```
341 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
342 self.data.iter_mut()
343 }
344
345 // ------------------------------------------------------------------
346 // Map / apply
347 // ------------------------------------------------------------------
348
349 /// Apply a function to every element, returning a new tensor.
350 ///
351 /// # Examples
352 ///
353 /// ```
354 /// # use scivex_core::Tensor;
355 /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
356 /// let doubled = t.map(|x| x * 2);
357 /// assert_eq!(doubled.as_slice(), &[2, 4, 6, 8]);
358 /// ```
359 pub fn map<F>(&self, f: F) -> Tensor<T>
360 where
361 F: Fn(T) -> T,
362 {
363 Tensor {
364 data: self.data.iter().map(|&x| f(x)).collect(),
365 shape: self.shape.clone(),
366 strides: self.strides.clone(),
367 }
368 }
369
370 /// Cast every element to a different scalar type, preserving shape.
371 ///
372 /// Uses `to_f64()` / `from_f64()` for the conversion, which is lossless for
373 /// f32→f64 and lossy (but intentionally so) for f64→f32 or f32→f16.
374 ///
375 /// # Examples
376 ///
377 /// ```
378 /// # use scivex_core::Tensor;
379 /// let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
380 /// let t32: Tensor<f32> = t.cast();
381 /// assert_eq!(t32.as_slice(), &[1.0_f32, 2.0, 3.0]);
382 /// ```
383 pub fn cast<U: Scalar + Float>(&self) -> Tensor<U>
384 where
385 T: Float,
386 {
387 Tensor {
388 data: self.data.iter().map(|&x| U::from_f64(x.to_f64())).collect(),
389 shape: self.shape.clone(),
390 strides: self.strides.clone(),
391 }
392 }
393
394 /// Apply a function element-wise to two tensors of the same shape.
395 ///
396 /// # Examples
397 ///
398 /// ```
399 /// # use scivex_core::Tensor;
400 /// let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
401 /// let b = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
402 /// let c = a.zip_map(&b, |x, y| x + y).unwrap();
403 /// assert_eq!(c.as_slice(), &[11, 22, 33]);
404 /// ```
405 pub fn zip_map<F>(&self, other: &Tensor<T>, f: F) -> Result<Tensor<T>>
406 where
407 F: Fn(T, T) -> T,
408 {
409 if self.shape != other.shape {
410 return Err(CoreError::DimensionMismatch {
411 expected: self.shape.clone(),
412 got: other.shape.clone(),
413 });
414 }
415 let data = self
416 .data
417 .iter()
418 .zip(other.data.iter())
419 .map(|(&a, &b)| f(a, b))
420 .collect();
421 Ok(Tensor {
422 data,
423 shape: self.shape.clone(),
424 strides: self.strides.clone(),
425 })
426 }
427
428 /// Apply a function to every element in place.
429 ///
430 /// # Examples
431 ///
432 /// ```
433 /// # use scivex_core::Tensor;
434 /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
435 /// t.apply(|x| x * x);
436 /// assert_eq!(t.as_slice(), &[1, 4, 9, 16]);
437 /// ```
438 pub fn apply<F>(&mut self, f: F)
439 where
440 F: Fn(T) -> T,
441 {
442 for x in &mut self.data {
443 *x = f(*x);
444 }
445 }
446}
447
448impl<T: Scalar> PartialEq for Tensor<T> {
449 fn eq(&self, other: &Self) -> bool {
450 self.shape == other.shape && self.data == other.data
451 }
452}
453
454// ======================================================================
455// Utility functions
456// ======================================================================
457
/// Compute row-major (C-order) strides, in elements, for `shape`.
///
/// The last dimension always has stride 1, and every other dimension's stride
/// is the product of all extents to its right; the leading extent is never
/// multiplied in. An empty shape yields an empty stride vector.
///
/// The running product is unchecked: with overflow checks enabled an
/// overflowing product of trailing extents panics, otherwise it wraps.
pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let trailing = match shape.split_first() {
        Some((_, rest)) => rest,
        None => return Vec::new(),
    };
    // Build the strides innermost-first: 1 for the last axis, then the
    // running product of extents walking right-to-left, and flip at the end.
    let mut strides = Vec::with_capacity(shape.len());
    strides.push(1);
    strides.extend(trailing.iter().rev().scan(1usize, |running, &dim| {
        *running *= dim;
        Some(*running)
    }));
    strides.reverse();
    strides
}
470
#[cfg(test)]
mod tests {
    //! Unit tests for the core `Tensor` type defined in this module.

    use super::*;

    #[test]
    fn test_from_vec() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]);
        assert_eq!(t.ndim(), 2);
        assert_eq!(t.numel(), 6);
        assert!(!t.is_empty());
    }

    #[test]
    fn test_from_vec_shape_mismatch() {
        let r = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![2, 3]);
        assert!(matches!(r, Err(CoreError::InvalidShape { .. })));
    }

    #[test]
    fn test_from_slice_copies_data() {
        let data = [1, 2, 3, 4];
        let t = Tensor::from_slice(&data, vec![2, 2]).unwrap();
        assert_eq!(t.as_slice(), &data);
        assert!(Tensor::from_slice(&data, vec![3]).is_err());
    }

    #[test]
    fn test_scalar_tensor() {
        let t = Tensor::scalar(42.0_f64);
        assert_eq!(t.ndim(), 0);
        assert_eq!(t.numel(), 1);
        assert_eq!(t.as_slice(), &[42.0]);
        assert!(t.strides().is_empty());
    }

    #[test]
    fn test_scalar_get_with_empty_index() {
        let t = Tensor::scalar(7);
        assert_eq!(*t.get(&[]).unwrap(), 7);
        assert!(t.get(&[0]).is_err());
    }

    #[test]
    fn test_get_set() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        assert_eq!(*t.get(&[0, 0]).unwrap(), 1);
        assert_eq!(*t.get(&[1, 2]).unwrap(), 6);
        t.set(&[0, 1], 99).unwrap();
        assert_eq!(*t.get(&[0, 1]).unwrap(), 99);
    }

    #[test]
    fn test_get_mut() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        *t.get_mut(&[1, 1]).unwrap() += 10;
        assert_eq!(t.as_slice(), &[1, 2, 3, 14]);
        assert!(t.get_mut(&[0, 2]).is_err());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(matches!(
            t.get(&[2, 0]),
            Err(CoreError::IndexOutOfBounds { .. })
        ));
        assert!(matches!(
            t.get(&[0]),
            Err(CoreError::IndexOutOfBounds { .. })
        ));
        assert!(t.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn test_set_out_of_bounds_leaves_data_unchanged() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(t.set(&[2, 0], 99).is_err());
        assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_iter_and_iter_mut() {
        let mut t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        for x in t.iter_mut() {
            *x *= 2;
        }
        assert_eq!(t.iter().copied().collect::<Vec<_>>(), vec![2, 4, 6]);
    }

    #[test]
    fn test_into_vec() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert_eq!(t.into_vec(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_compute_strides() {
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(compute_strides(&[5]), vec![1]);
        assert_eq!(compute_strides(&[]), Vec::<usize>::new());
        assert_eq!(compute_strides(&[2, 0, 3]), vec![0, 3, 1]);
    }

    #[test]
    fn test_map() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let t2 = t.map(|x| x * 10);
        assert_eq!(t2.as_slice(), &[10, 20, 30, 40]);
        assert_eq!(t2.shape(), &[2, 2]);
        assert_eq!(t2.strides(), t.strides());
    }

    #[test]
    fn test_apply() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        t.apply(|x| x * x);
        assert_eq!(t.as_slice(), &[1, 4, 9, 16]);
        assert_eq!(t.shape(), &[2, 2]);
    }

    #[test]
    fn test_cast_f64_to_f32() {
        let t = Tensor::from_vec(vec![1.5_f64, -2.0, 0.25], vec![3]).unwrap();
        let c: Tensor<f32> = t.cast();
        assert_eq!(c.as_slice(), &[1.5_f32, -2.0, 0.25]);
        assert_eq!(c.shape(), t.shape());
    }

    #[test]
    fn test_zip_map() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
        let c = a.zip_map(&b, |x, y| x + y).unwrap();
        assert_eq!(c.as_slice(), &[11, 22, 33]);
    }

    #[test]
    fn test_zip_map_shape_mismatch() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1, 2], vec![2]).unwrap();
        assert!(matches!(
            a.zip_map(&b, |x, y| x + y),
            Err(CoreError::DimensionMismatch { .. })
        ));
        // Same element count but different shape is still a mismatch.
        let m = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let f = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert!(m.zip_map(&f, |x, y| x + y).is_err());
    }

    #[test]
    fn test_partial_eq() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let c = Tensor::from_vec(vec![1, 2, 4], vec![3]).unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_partial_eq_shape_differs() {
        let a = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_ne!(a, b);
    }
}