1use super::error::TensorError;
2use super::tensor::Tensor;
3
4impl Tensor {
5 pub fn trace(&self) -> Result<f32, TensorError> {
7 if self.rank() != 2 {
8 return Err(TensorError::UnsupportedOperation {
9 msg: format!("trace requires a 2D tensor, got rank {}", self.rank()),
10 });
11 }
12 let (rows, cols) = (self.shape()[0], self.shape()[1]);
13 if rows != cols {
14 return Err(TensorError::ShapeMismatch {
15 left: self.shape().to_vec(),
16 right: vec![rows, rows],
17 });
18 }
19 let data = self.data();
20 let mut sum = 0.0f32;
21 for i in 0..rows {
22 sum += data[i * cols + i];
23 }
24 Ok(sum)
25 }
26
27 pub fn dot(&self, rhs: &Self) -> Result<f32, TensorError> {
29 if self.rank() != 1 || rhs.rank() != 1 {
30 return Err(TensorError::UnsupportedOperation {
31 msg: format!(
32 "dot requires two 1D tensors, got ranks {} and {}",
33 self.rank(),
34 rhs.rank()
35 ),
36 });
37 }
38 if self.shape() != rhs.shape() {
39 return Err(TensorError::ShapeMismatch {
40 left: self.shape().to_vec(),
41 right: rhs.shape().to_vec(),
42 });
43 }
44 let a = self.data();
45 let b = rhs.data();
46 let mut sum = 0.0f32;
47 for i in 0..a.len() {
48 sum += a[i] * b[i];
49 }
50 Ok(sum)
51 }
52
53 pub fn cross(&self, rhs: &Self) -> Result<Self, TensorError> {
55 if self.rank() != 1 || rhs.rank() != 1 {
56 return Err(TensorError::UnsupportedOperation {
57 msg: "cross requires two 1D tensors".into(),
58 });
59 }
60 if self.shape()[0] != 3 || rhs.shape()[0] != 3 {
61 return Err(TensorError::ShapeMismatch {
62 left: self.shape().to_vec(),
63 right: rhs.shape().to_vec(),
64 });
65 }
66 let a = self.data();
67 let b = rhs.data();
68 let result = vec![
69 a[1] * b[2] - a[2] * b[1],
70 a[2] * b[0] - a[0] * b[2],
71 a[0] * b[1] - a[1] * b[0],
72 ];
73 Tensor::from_vec(vec![3], result)
74 }
75
76 pub fn norm(&self, p: f32) -> f32 {
78 let data = self.data();
79 if p == 1.0 {
80 data.iter().map(|x| x.abs()).sum()
81 } else if p == 2.0 {
82 data.iter().map(|x| x * x).sum::<f32>().sqrt()
83 } else {
84 data.iter()
85 .map(|x| x.abs().powf(p))
86 .sum::<f32>()
87 .powf(1.0 / p)
88 }
89 }
90
91 pub fn det(&self) -> Result<f32, TensorError> {
93 if self.rank() != 2 {
94 return Err(TensorError::UnsupportedOperation {
95 msg: format!("det requires a 2D tensor, got rank {}", self.rank()),
96 });
97 }
98 let n = self.shape()[0];
99 if n != self.shape()[1] {
100 return Err(TensorError::ShapeMismatch {
101 left: self.shape().to_vec(),
102 right: vec![n, n],
103 });
104 }
105
106 let mut a: Vec<f32> = self.data().to_vec();
108 let mut sign = 1.0f32;
109
110 for col in 0..n {
111 let mut max_row = col;
113 let mut max_val = a[col * n + col].abs();
114 for row in (col + 1)..n {
115 let v = a[row * n + col].abs();
116 if v > max_val {
117 max_val = v;
118 max_row = row;
119 }
120 }
121 if max_val < 1e-12 {
122 return Ok(0.0);
123 }
124 if max_row != col {
125 for j in 0..n {
127 a.swap(col * n + j, max_row * n + j);
128 }
129 sign = -sign;
130 }
131 let pivot = a[col * n + col];
132 for row in (col + 1)..n {
133 let factor = a[row * n + col] / pivot;
134 for j in col..n {
135 let val = a[col * n + j];
136 a[row * n + j] -= factor * val;
137 }
138 }
139 }
140
141 let mut det = sign;
142 for i in 0..n {
143 det *= a[i * n + i];
144 }
145 Ok(det)
146 }
147
    /// Inverse of a square 2-D tensor via Gauss–Jordan elimination on the
    /// augmented matrix [A | I], using partial pivoting for stability.
    ///
    /// # Errors
    /// * `UnsupportedOperation` if the tensor is not rank 2, or if the matrix
    ///   is numerically singular (best pivot magnitude below 1e-12).
    /// * `ShapeMismatch` if the matrix is not square.
    pub fn inv(&self) -> Result<Self, TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::UnsupportedOperation {
                msg: format!("inv requires a 2D tensor, got rank {}", self.rank()),
            });
        }
        let n = self.shape()[0];
        if n != self.shape()[1] {
            return Err(TensorError::ShapeMismatch {
                left: self.shape().to_vec(),
                right: vec![n, n],
            });
        }

        // Build the n x 2n augmented matrix [A | I], row-major with row
        // stride nn = 2n. Columns 0..n hold A, columns n..2n the identity.
        let data = self.data();
        let nn = 2 * n;
        let mut aug = vec![0.0f32; n * nn];
        for i in 0..n {
            for j in 0..n {
                aug[i * nn + j] = data[i * n + j];
            }
            aug[i * nn + n + i] = 1.0;
        }

        for col in 0..n {
            // Partial pivoting: pick the row (at or below the diagonal) with
            // the largest magnitude in this column.
            let mut max_row = col;
            let mut max_val = aug[col * nn + col].abs();
            for row in (col + 1)..n {
                let v = aug[row * nn + col].abs();
                if v > max_val {
                    max_val = v;
                    max_row = row;
                }
            }
            // No usable pivot anywhere in the column: the matrix has no inverse.
            if max_val < 1e-12 {
                return Err(TensorError::UnsupportedOperation {
                    msg: "matrix is singular".into(),
                });
            }
            if max_row != col {
                for j in 0..nn {
                    aug.swap(col * nn + j, max_row * nn + j);
                }
            }

            // Normalize the pivot row so the pivot becomes exactly the unit
            // entry of the reduced row-echelon form.
            let pivot = aug[col * nn + col];
            for j in 0..nn {
                aug[col * nn + j] /= pivot;
            }

            // Eliminate this column from every OTHER row (Gauss–Jordan, not
            // just below the diagonal), so columns 0..n converge to I.
            for row in 0..n {
                if row == col {
                    continue;
                }
                let factor = aug[row * nn + col];
                for j in 0..nn {
                    let val = aug[col * nn + j];
                    aug[row * nn + j] -= factor * val;
                }
            }
        }

        // The right half of the augmented matrix now holds A^-1.
        let mut result = vec![0.0f32; n * n];
        for i in 0..n {
            for j in 0..n {
                result[i * n + j] = aug[i * nn + n + j];
            }
        }
        Tensor::from_vec(vec![n, n], result)
    }
225
    /// Solves the linear system `A x = b` for a square `A` (self) and a 1-D
    /// right-hand side `b`, using LU factorization with partial pivoting
    /// followed by forward and back substitution.
    ///
    /// # Errors
    /// * `UnsupportedOperation` if `A` is not rank 2, or if `A` is
    ///   numerically singular (best pivot magnitude below 1e-12).
    /// * `ShapeMismatch` if `A` is not square, or if `b` is not a 1-D tensor
    ///   of length n.
    pub fn solve(&self, b: &Self) -> Result<Self, TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::UnsupportedOperation {
                msg: format!("solve requires A to be 2D, got rank {}", self.rank()),
            });
        }
        let n = self.shape()[0];
        if n != self.shape()[1] {
            return Err(TensorError::ShapeMismatch {
                left: self.shape().to_vec(),
                right: vec![n, n],
            });
        }
        if b.rank() != 1 || b.shape()[0] != n {
            return Err(TensorError::ShapeMismatch {
                left: self.shape().to_vec(),
                right: b.shape().to_vec(),
            });
        }

        // In-place LU: after factorization, `a` holds U on and above the
        // diagonal and the multipliers of unit-lower-triangular L below it.
        // `perm[i]` records which ORIGINAL row ended up at position i.
        let data = self.data();
        let mut a = data.to_vec();
        let mut perm: Vec<usize> = (0..n).collect();

        for col in 0..n {
            // Partial pivoting: row with the largest magnitude in this column.
            let mut max_row = col;
            let mut max_val = a[col * n + col].abs();
            for row in (col + 1)..n {
                let v = a[row * n + col].abs();
                if v > max_val {
                    max_val = v;
                    max_row = row;
                }
            }
            if max_val < 1e-12 {
                return Err(TensorError::UnsupportedOperation {
                    msg: "matrix is singular".into(),
                });
            }
            if max_row != col {
                // Swap full rows (including already-stored L multipliers in
                // columns < col) and mirror the swap in the permutation.
                for j in 0..n {
                    a.swap(col * n + j, max_row * n + j);
                }
                perm.swap(col, max_row);
            }
            let pivot = a[col * n + col];
            for row in (col + 1)..n {
                let factor = a[row * n + col] / pivot;
                // Store the multiplier where the zero would go — this is L.
                a[row * n + col] = factor;
                for j in (col + 1)..n {
                    let val = a[col * n + j];
                    a[row * n + j] -= factor * val;
                }
            }
        }

        // Apply the row permutation to b: pb = P * b.
        let bd = b.data();
        let mut pb = vec![0.0f32; n];
        for i in 0..n {
            pb[i] = bd[perm[i]];
        }

        // Forward substitution: solve L y = P b (L has an implicit unit
        // diagonal, so no division is needed).
        let mut y = pb;
        for i in 1..n {
            for j in 0..i {
                let l_ij = a[i * n + j];
                y[i] -= l_ij * y[j];
            }
        }

        // Back substitution: solve U x = y.
        let mut x = y;
        for i in (0..n).rev() {
            for j in (i + 1)..n {
                let u_ij = a[i * n + j];
                x[i] -= u_ij * x[j];
            }
            x[i] /= a[i * n + i];
        }

        Tensor::from_vec(vec![n], x)
    }
313
    /// QR decomposition of an m x n 2-D tensor via Householder reflections.
    /// Returns `(Q, R)` with Q of shape (m, m) and R of shape (m, n); the
    /// reflections drive R's subdiagonal entries toward zero column by
    /// column while Q accumulates the product of the reflectors.
    ///
    /// # Errors
    /// * `UnsupportedOperation` if the tensor is not rank 2.
    pub fn qr(&self) -> Result<(Self, Self), TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::UnsupportedOperation {
                msg: format!("qr requires a 2D tensor, got rank {}", self.rank()),
            });
        }
        let m = self.shape()[0];
        let n = self.shape()[1];
        let data = self.data();

        // R starts as a copy of A and is reduced in place.
        let mut r = data.to_vec();

        // Q starts as the m x m identity.
        let mut q = vec![0.0f32; m * m];
        for i in 0..m {
            q[i * m + i] = 1.0;
        }

        // One reflector per column, up to min(m, n).
        let k = m.min(n);
        for j in 0..k {
            // Extract the subcolumn R[j.., j] that the reflector must map
            // onto a multiple of e1.
            let mut col = vec![0.0f32; m - j];
            for i in j..m {
                col[i - j] = r[i * n + j];
            }

            // A (near-)zero subcolumn needs no reflection.
            let norm_col: f32 = col.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm_col < 1e-12 {
                continue;
            }
            // Householder vector v = x + sign(x0) * ||x|| * e1; the sign
            // choice avoids cancellation in the leading component.
            let sign = if col[0] >= 0.0 { 1.0 } else { -1.0 };
            col[0] += sign * norm_col;

            // ||v||^2 for the reflection H = I - 2 v v^T / ||v||^2.
            let norm_v: f32 = col.iter().map(|x| x * x).sum::<f32>();
            if norm_v < 1e-24 {
                continue;
            }

            // Apply H to R from the left. Columns before j already have
            // zeros in rows j.., so starting at jj = j skips no-op work.
            for jj in j..n {
                let mut dot = 0.0f32;
                for i in j..m {
                    dot += col[i - j] * r[i * n + jj];
                }
                let factor = 2.0 * dot / norm_v;
                for i in j..m {
                    r[i * n + jj] -= factor * col[i - j];
                }
            }

            // Accumulate Q = Q * H (H only touches columns j..m of Q).
            for i in 0..m {
                let mut dot = 0.0f32;
                for jj in j..m {
                    dot += q[i * m + jj] * col[jj - j];
                }
                let factor = 2.0 * dot / norm_v;
                for jj in j..m {
                    q[i * m + jj] -= factor * col[jj - j];
                }
            }
        }

        let q_tensor = Tensor::from_vec(vec![m, m], q)?;
        let r_tensor = Tensor::from_vec(vec![m, n], r)?;
        Ok((q_tensor, r_tensor))
    }
386
387 pub fn cholesky(&self) -> Result<Self, TensorError> {
389 if self.rank() != 2 {
390 return Err(TensorError::UnsupportedOperation {
391 msg: format!("cholesky requires a 2D tensor, got rank {}", self.rank()),
392 });
393 }
394 let n = self.shape()[0];
395 if n != self.shape()[1] {
396 return Err(TensorError::ShapeMismatch {
397 left: self.shape().to_vec(),
398 right: vec![n, n],
399 });
400 }
401 let data = self.data();
402 let mut l = vec![0.0f32; n * n];
403
404 for i in 0..n {
405 for j in 0..=i {
406 let mut sum = 0.0f32;
407 for k in 0..j {
408 sum += l[i * n + k] * l[j * n + k];
409 }
410 if i == j {
411 let diag = data[i * n + i] - sum;
412 if diag <= 0.0 {
413 return Err(TensorError::UnsupportedOperation {
414 msg: "matrix is not positive definite".into(),
415 });
416 }
417 l[i * n + j] = diag.sqrt();
418 } else {
419 l[i * n + j] = (data[i * n + j] - sum) / l[j * n + j];
420 }
421 }
422 }
423
424 Tensor::from_vec(vec![n, n], l)
425 }
426}