Skip to main content

sciforge_lib/maths/tensor/
storage.rs

1use std::ops::{Add, Div, Mul, Neg, Sub};
2
/// A dense n-dimensional array of `f64` values stored in one flat buffer.
///
/// `strides[k]` is how many positions in `data` to advance when the index
/// along axis `k` grows by one; the constructors in this module fill it with
/// row-major strides derived from `shape`.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Flat element buffer (the constructors keep its length equal to the
    /// product of `shape`).
    pub(crate) data: Vec<f64>,
    /// Extent of each axis; an empty shape is used for scalars.
    pub(crate) shape: Vec<usize>,
    /// Per-axis step sizes into `data`.
    pub(crate) strides: Vec<usize>,
}
9
10impl Tensor {
11    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
12        let mut strides = vec![1; shape.len()];
13        for i in (0..shape.len().saturating_sub(1)).rev() {
14            strides[i] = strides[i + 1] * shape[i + 1];
15        }
16        strides
17    }
18
19    pub(crate) fn flat_index(&self, indices: &[usize]) -> usize {
20        indices.iter().zip(&self.strides).map(|(i, s)| i * s).sum()
21    }
22
23    pub fn zeros(shape: &[usize]) -> Self {
24        let size: usize = shape.iter().product();
25        Self {
26            data: vec![0.0; size],
27            shape: shape.to_vec(),
28            strides: Self::compute_strides(shape),
29        }
30    }
31
32    pub fn ones(shape: &[usize]) -> Self {
33        let size: usize = shape.iter().product();
34        Self {
35            data: vec![1.0; size],
36            shape: shape.to_vec(),
37            strides: Self::compute_strides(shape),
38        }
39    }
40
41    pub fn from_vec(shape: &[usize], data: Vec<f64>) -> Self {
42        assert_eq!(shape.iter().product::<usize>(), data.len());
43        Self {
44            data,
45            shape: shape.to_vec(),
46            strides: Self::compute_strides(shape),
47        }
48    }
49
50    pub fn from_fn(shape: &[usize], f: impl Fn(&[usize]) -> f64) -> Self {
51        let size: usize = shape.iter().product();
52        let mut data = Vec::with_capacity(size);
53        let mut indices = vec![0usize; shape.len()];
54        for _ in 0..size {
55            data.push(f(&indices));
56            for k in (0..shape.len()).rev() {
57                indices[k] += 1;
58                if indices[k] < shape[k] {
59                    break;
60                }
61                indices[k] = 0;
62            }
63        }
64        Self {
65            data,
66            shape: shape.to_vec(),
67            strides: Self::compute_strides(shape),
68        }
69    }
70
71    pub fn scalar(val: f64) -> Self {
72        Self {
73            data: vec![val],
74            shape: vec![],
75            strides: vec![],
76        }
77    }
78
79    pub fn identity(n: usize) -> Self {
80        Self::from_fn(&[n, n], |idx| if idx[0] == idx[1] { 1.0 } else { 0.0 })
81    }
82
83    pub fn shape(&self) -> &[usize] {
84        &self.shape
85    }
86    pub fn rank(&self) -> usize {
87        self.shape.len()
88    }
89    pub fn size(&self) -> usize {
90        self.data.len()
91    }
92    pub fn data(&self) -> &[f64] {
93        &self.data
94    }
95
96    pub fn get(&self, indices: &[usize]) -> f64 {
97        self.data[self.flat_index(indices)]
98    }
99
100    pub fn set(&mut self, indices: &[usize], value: f64) {
101        let idx = self.flat_index(indices);
102        self.data[idx] = value;
103    }
104
105    pub fn scale(&self, s: f64) -> Self {
106        Self {
107            data: self.data.iter().map(|x| x * s).collect(),
108            shape: self.shape.clone(),
109            strides: self.strides.clone(),
110        }
111    }
112
113    pub fn map(&self, f: impl Fn(f64) -> f64) -> Self {
114        Self {
115            data: self.data.iter().map(|&x| f(x)).collect(),
116            shape: self.shape.clone(),
117            strides: self.strides.clone(),
118        }
119    }
120
121    pub fn elementwise(&self, other: &Tensor, f: impl Fn(f64, f64) -> f64) -> Self {
122        assert_eq!(self.shape, other.shape);
123        Self::from_vec(
124            &self.shape,
125            self.data
126                .iter()
127                .zip(&other.data)
128                .map(|(&a, &b)| f(a, b))
129                .collect(),
130        )
131    }
132
133    pub fn sum(&self) -> f64 {
134        self.data.iter().sum()
135    }
136    pub fn max(&self) -> f64 {
137        self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
138    }
139    pub fn min(&self) -> f64 {
140        self.data.iter().cloned().fold(f64::INFINITY, f64::min)
141    }
142    pub fn frobenius_norm(&self) -> f64 {
143        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
144    }
145}
146
147impl Add for &Tensor {
148    type Output = Tensor;
149    fn add(self, rhs: Self) -> Tensor {
150        self.elementwise(rhs, |a, b| a + b)
151    }
152}
153
154impl Sub for &Tensor {
155    type Output = Tensor;
156    fn sub(self, rhs: Self) -> Tensor {
157        self.elementwise(rhs, |a, b| a - b)
158    }
159}
160
161impl Mul<f64> for &Tensor {
162    type Output = Tensor;
163    fn mul(self, s: f64) -> Tensor {
164        self.scale(s)
165    }
166}
167
168impl Div<f64> for &Tensor {
169    type Output = Tensor;
170    fn div(self, s: f64) -> Tensor {
171        self.scale(1.0 / s)
172    }
173}
174
175impl Neg for &Tensor {
176    type Output = Tensor;
177    fn neg(self) -> Tensor {
178        self.scale(-1.0)
179    }
180}