sciforge_lib/maths/tensor/
storage.rs1use std::ops::{Add, Div, Mul, Neg, Sub};
2
/// Dense tensor of `f64` values stored in a single flat buffer.
///
/// `shape` holds the extent of every axis and `strides` the flat-buffer step
/// for every axis. All constructors in this block except [`Tensor::scalar`]
/// derive `strides` from `shape` with [`Tensor::compute_strides`], which makes
/// the last axis vary fastest; `scalar` leaves both vectors empty.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Flat element buffer.
    pub(crate) data: Vec<f64>,
    /// Extent of each axis; empty for a scalar.
    pub(crate) shape: Vec<usize>,
    /// Flat-buffer step for each axis, parallel to `shape`.
    pub(crate) strides: Vec<usize>,
}

/// Multiplies the axis extents together (an empty shape yields 1).
///
/// Panics instead of wrapping when the product does not fit in `usize`, so a
/// tensor can never be built whose buffer length disagrees with its shape.
fn element_count(shape: &[usize]) -> usize {
    shape
        .iter()
        .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
        .expect("tensor shape element count overflows usize")
}

/// Builds a tensor of the given shape with every element set to `value`.
fn filled(shape: &[usize], value: f64) -> Tensor {
    Tensor {
        data: vec![value; element_count(shape)],
        shape: shape.to_vec(),
        strides: Tensor::compute_strides(shape),
    }
}

/// Validates `indices` against the tensor's shape and returns the flat offset.
///
/// Panics when the number of indices differs from the rank or when any index
/// is not smaller than the extent of its axis. Without this check such an
/// index could still map to a valid flat offset and silently address a
/// different element.
fn checked_offset(tensor: &Tensor, indices: &[usize]) -> usize {
    assert!(
        indices.len() == tensor.shape.len(),
        "index rank {} does not match tensor rank {}",
        indices.len(),
        tensor.shape.len()
    );
    for (axis, (&index, &extent)) in indices.iter().zip(&tensor.shape).enumerate() {
        assert!(
            index < extent,
            "index {} out of bounds for axis {} with extent {}",
            index,
            axis,
            extent
        );
    }
    tensor.flat_index(indices)
}

impl Tensor {
    /// Computes the strides for `shape` with the last axis varying fastest.
    ///
    /// The last stride is 1 and every earlier stride is the product of all
    /// later extents. An empty shape yields an empty vector.
    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        // Walk from the innermost axis outwards, accumulating the extents.
        for axis in (1..shape.len()).rev() {
            strides[axis - 1] = strides[axis] * shape[axis];
        }
        strides
    }

    /// Returns the flat offset `sum(indices[k] * strides[k])`.
    ///
    /// This helper performs no validation: surplus entries on either side are
    /// ignored and per-axis bounds are not checked. [`Tensor::get`] and
    /// [`Tensor::set`] validate before calling it.
    pub(crate) fn flat_index(&self, indices: &[usize]) -> usize {
        indices
            .iter()
            .zip(&self.strides)
            .map(|(&index, &stride)| index * stride)
            .sum()
    }

    /// Creates a tensor of the given shape filled with `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements overflows `usize`.
    pub fn zeros(shape: &[usize]) -> Self {
        filled(shape, 0.0)
    }

    /// Creates a tensor of the given shape filled with `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements overflows `usize`.
    pub fn ones(shape: &[usize]) -> Self {
        filled(shape, 1.0)
    }

    /// Wraps an existing buffer as a tensor of the given shape.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of the extents in
    /// `shape`, or if that product overflows `usize`.
    pub fn from_vec(shape: &[usize], data: Vec<f64>) -> Self {
        assert_eq!(element_count(shape), data.len());
        Self {
            data,
            shape: shape.to_vec(),
            strides: Self::compute_strides(shape),
        }
    }

    /// Builds a tensor by calling `f` with the multi-index of every element.
    ///
    /// Elements are generated in storage order (last axis fastest). For an
    /// empty shape `f` is called once with an empty index slice.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements overflows `usize`.
    pub fn from_fn(shape: &[usize], f: impl Fn(&[usize]) -> f64) -> Self {
        let size = element_count(shape);
        let mut data = Vec::with_capacity(size);
        let mut indices = vec![0usize; shape.len()];
        for _ in 0..size {
            data.push(f(&indices));
            // Advance the multi-index like an odometer: bump the last axis and
            // carry into earlier axes whenever one reaches its extent.
            for (index, &extent) in indices.iter_mut().zip(shape).rev() {
                *index += 1;
                if *index < extent {
                    break;
                }
                *index = 0;
            }
        }
        Self {
            data,
            shape: shape.to_vec(),
            strides: Self::compute_strides(shape),
        }
    }

    /// Creates a rank-0 tensor holding the single value `val`.
    pub fn scalar(val: f64) -> Self {
        Self {
            data: vec![val],
            shape: Vec::new(),
            strides: Vec::new(),
        }
    }

    /// Creates the `n` x `n` identity matrix.
    pub fn identity(n: usize) -> Self {
        Self::from_fn(&[n, n], |idx| if idx[0] == idx[1] { 1.0 } else { 0.0 })
    }

    /// Extent of each axis.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of axes.
    pub fn rank(&self) -> usize {
        self.shape.len()
    }

    /// Number of stored elements.
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// The flat element buffer.
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Returns the element at `indices`.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the rank or if any index is out
    /// of range for its axis.
    pub fn get(&self, indices: &[usize]) -> f64 {
        self.data[checked_offset(self, indices)]
    }

    /// Overwrites the element at `indices` with `value`.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the rank or if any index is out
    /// of range for its axis.
    pub fn set(&mut self, indices: &[usize], value: f64) {
        let offset = checked_offset(self, indices);
        self.data[offset] = value;
    }

    /// Returns a new tensor with every element multiplied by `s`.
    pub fn scale(&self, s: f64) -> Self {
        self.map(|x| x * s)
    }

    /// Returns a new tensor with `f` applied to every element; shape and
    /// strides are copied unchanged.
    pub fn map(&self, f: impl Fn(f64) -> f64) -> Self {
        Self {
            data: self.data.iter().copied().map(f).collect(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
        }
    }

    /// Combines two tensors of identical shape element by element with `f`.
    ///
    /// Elements are paired in buffer order and the result is built with
    /// [`Tensor::from_vec`], so its strides are recomputed from the shape.
    ///
    /// # Panics
    ///
    /// Panics if the shapes differ.
    pub fn elementwise(&self, other: &Tensor, f: impl Fn(f64, f64) -> f64) -> Self {
        assert_eq!(self.shape, other.shape);
        let combined = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(&lhs, &rhs)| f(lhs, rhs))
            .collect();
        Self::from_vec(&self.shape, combined)
    }

    /// Sum of all elements.
    pub fn sum(&self) -> f64 {
        self.data.iter().sum()
    }

    /// Largest element, folded with `f64::max` starting from negative
    /// infinity (which is therefore the result for an empty buffer).
    pub fn max(&self) -> f64 {
        self.data.iter().copied().fold(f64::NEG_INFINITY, f64::max)
    }

    /// Smallest element, folded with `f64::min` starting from positive
    /// infinity (which is therefore the result for an empty buffer).
    pub fn min(&self) -> f64 {
        self.data.iter().copied().fold(f64::INFINITY, f64::min)
    }

    /// Square root of the sum of squared elements.
    pub fn frobenius_norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }
}
146
147impl Add for &Tensor {
148 type Output = Tensor;
149 fn add(self, rhs: Self) -> Tensor {
150 self.elementwise(rhs, |a, b| a + b)
151 }
152}
153
154impl Sub for &Tensor {
155 type Output = Tensor;
156 fn sub(self, rhs: Self) -> Tensor {
157 self.elementwise(rhs, |a, b| a - b)
158 }
159}
160
161impl Mul<f64> for &Tensor {
162 type Output = Tensor;
163 fn mul(self, s: f64) -> Tensor {
164 self.scale(s)
165 }
166}
167
168impl Div<f64> for &Tensor {
169 type Output = Tensor;
170 fn div(self, s: f64) -> Tensor {
171 self.scale(1.0 / s)
172 }
173}
174
175impl Neg for &Tensor {
176 type Output = Tensor;
177 fn neg(self) -> Tensor {
178 self.scale(-1.0)
179 }
180}