sciforge_lib/maths/tensor/
ops.rs1use super::storage::Tensor;
2
3pub fn reshape(t: &Tensor, new_shape: &[usize]) -> Tensor {
4 assert_eq!(new_shape.iter().product::<usize>(), t.data.len());
5 Tensor::from_vec(new_shape, t.data.clone())
6}
7
8pub fn transpose(t: &Tensor, axes: &[usize]) -> Tensor {
9 assert_eq!(axes.len(), t.rank());
10 let new_shape: Vec<usize> = axes.iter().map(|&a| t.shape()[a]).collect();
11 let new_strides = Tensor::compute_strides(&new_shape);
12 let size = t.size();
13 let mut new_data = vec![0.0; size];
14 let mut indices = vec![0usize; t.rank()];
15 for _ in 0..size {
16 let new_indices: Vec<usize> = axes.iter().map(|&a| indices[a]).collect();
17 let new_flat: usize = new_indices
18 .iter()
19 .zip(&new_strides)
20 .map(|(i, s)| i * s)
21 .sum();
22 new_data[new_flat] = t.get(&indices);
23 for k in (0..t.rank()).rev() {
24 indices[k] += 1;
25 if indices[k] < t.shape()[k] {
26 break;
27 }
28 indices[k] = 0;
29 }
30 }
31 Tensor {
32 data: new_data,
33 shape: new_shape,
34 strides: new_strides,
35 }
36}
37
38pub fn contract(a: &Tensor, b: &Tensor, axis_a: usize, axis_b: usize) -> Tensor {
39 assert_eq!(a.shape()[axis_a], b.shape()[axis_b]);
40 let contract_size = a.shape()[axis_a];
41 let mut result_shape = Vec::new();
42 for (i, &s) in a.shape().iter().enumerate() {
43 if i != axis_a {
44 result_shape.push(s);
45 }
46 }
47 for (i, &s) in b.shape().iter().enumerate() {
48 if i != axis_b {
49 result_shape.push(s);
50 }
51 }
52 let a_rank = a.rank();
53 let b_rank = b.rank();
54 Tensor::from_fn(&result_shape, |indices| {
55 let split = a_rank - 1;
56 let idx_a = &indices[..split];
57 let idx_b = &indices[split..];
58 let mut sum = 0.0;
59 for k in 0..contract_size {
60 let mut full_a = Vec::with_capacity(a_rank);
61 let mut ai = 0;
62 for i in 0..a_rank {
63 if i == axis_a {
64 full_a.push(k);
65 } else {
66 full_a.push(idx_a[ai]);
67 ai += 1;
68 }
69 }
70 let mut full_b = Vec::with_capacity(b_rank);
71 let mut bi = 0;
72 for i in 0..b_rank {
73 if i == axis_b {
74 full_b.push(k);
75 } else {
76 full_b.push(idx_b[bi]);
77 bi += 1;
78 }
79 }
80 sum += a.get(&full_a) * b.get(&full_b);
81 }
82 sum
83 })
84}
85
86pub fn outer(a: &Tensor, b: &Tensor) -> Tensor {
87 let mut new_shape = a.shape().to_vec();
88 new_shape.extend_from_slice(b.shape());
89 let a_rank = a.rank();
90 Tensor::from_fn(&new_shape, |indices| {
91 let (ia, ib) = indices.split_at(a_rank);
92 a.get(ia) * b.get(ib)
93 })
94}
95
96pub fn kronecker(a: &Tensor, b: &Tensor) -> Tensor {
97 assert!(a.rank() == 2 && b.rank() == 2);
98 let (am, an) = (a.shape()[0], a.shape()[1]);
99 let (bm, bn) = (b.shape()[0], b.shape()[1]);
100 Tensor::from_fn(&[am * bm, an * bn], |idx| {
101 a.get(&[idx[0] / bm, idx[1] / bn]) * b.get(&[idx[0] % bm, idx[1] % bn])
102 })
103}
104
105pub fn levi_civita(n: usize) -> Tensor {
106 let shape: Vec<usize> = vec![n; n];
107 Tensor::from_fn(&shape, |indices| {
108 for i in 0..n {
109 for j in (i + 1)..n {
110 if indices[i] == indices[j] {
111 return 0.0;
112 }
113 }
114 }
115 let mut parity = 0;
116 let mut perm: Vec<usize> = indices.to_vec();
117 for i in 0..n {
118 while perm[i] != i {
119 let target = perm[i];
120 perm.swap(i, target);
121 parity += 1;
122 }
123 }
124 if parity % 2 == 0 { 1.0 } else { -1.0 }
125 })
126}
127
128pub fn metric_raise(t: &Tensor, metric_inv: &Tensor, index: usize) -> Tensor {
129 contract(t, metric_inv, index, 1)
130}
131
132pub fn metric_lower(t: &Tensor, metric: &Tensor, index: usize) -> Tensor {
133 contract(t, metric, index, 0)
134}