// sciforge_lib/maths/tensor/ops.rs

1use super::storage::Tensor;
2
3pub fn reshape(t: &Tensor, new_shape: &[usize]) -> Tensor {
4    assert_eq!(new_shape.iter().product::<usize>(), t.data.len());
5    Tensor::from_vec(new_shape, t.data.clone())
6}
7
8pub fn transpose(t: &Tensor, axes: &[usize]) -> Tensor {
9    assert_eq!(axes.len(), t.rank());
10    let new_shape: Vec<usize> = axes.iter().map(|&a| t.shape()[a]).collect();
11    let new_strides = Tensor::compute_strides(&new_shape);
12    let size = t.size();
13    let mut new_data = vec![0.0; size];
14    let mut indices = vec![0usize; t.rank()];
15    for _ in 0..size {
16        let new_indices: Vec<usize> = axes.iter().map(|&a| indices[a]).collect();
17        let new_flat: usize = new_indices
18            .iter()
19            .zip(&new_strides)
20            .map(|(i, s)| i * s)
21            .sum();
22        new_data[new_flat] = t.get(&indices);
23        for k in (0..t.rank()).rev() {
24            indices[k] += 1;
25            if indices[k] < t.shape()[k] {
26                break;
27            }
28            indices[k] = 0;
29        }
30    }
31    Tensor {
32        data: new_data,
33        shape: new_shape,
34        strides: new_strides,
35    }
36}
37
38pub fn contract(a: &Tensor, b: &Tensor, axis_a: usize, axis_b: usize) -> Tensor {
39    assert_eq!(a.shape()[axis_a], b.shape()[axis_b]);
40    let contract_size = a.shape()[axis_a];
41    let mut result_shape = Vec::new();
42    for (i, &s) in a.shape().iter().enumerate() {
43        if i != axis_a {
44            result_shape.push(s);
45        }
46    }
47    for (i, &s) in b.shape().iter().enumerate() {
48        if i != axis_b {
49            result_shape.push(s);
50        }
51    }
52    let a_rank = a.rank();
53    let b_rank = b.rank();
54    Tensor::from_fn(&result_shape, |indices| {
55        let split = a_rank - 1;
56        let idx_a = &indices[..split];
57        let idx_b = &indices[split..];
58        let mut sum = 0.0;
59        for k in 0..contract_size {
60            let mut full_a = Vec::with_capacity(a_rank);
61            let mut ai = 0;
62            for i in 0..a_rank {
63                if i == axis_a {
64                    full_a.push(k);
65                } else {
66                    full_a.push(idx_a[ai]);
67                    ai += 1;
68                }
69            }
70            let mut full_b = Vec::with_capacity(b_rank);
71            let mut bi = 0;
72            for i in 0..b_rank {
73                if i == axis_b {
74                    full_b.push(k);
75                } else {
76                    full_b.push(idx_b[bi]);
77                    bi += 1;
78                }
79            }
80            sum += a.get(&full_a) * b.get(&full_b);
81        }
82        sum
83    })
84}
85
86pub fn outer(a: &Tensor, b: &Tensor) -> Tensor {
87    let mut new_shape = a.shape().to_vec();
88    new_shape.extend_from_slice(b.shape());
89    let a_rank = a.rank();
90    Tensor::from_fn(&new_shape, |indices| {
91        let (ia, ib) = indices.split_at(a_rank);
92        a.get(ia) * b.get(ib)
93    })
94}
95
96pub fn kronecker(a: &Tensor, b: &Tensor) -> Tensor {
97    assert!(a.rank() == 2 && b.rank() == 2);
98    let (am, an) = (a.shape()[0], a.shape()[1]);
99    let (bm, bn) = (b.shape()[0], b.shape()[1]);
100    Tensor::from_fn(&[am * bm, an * bn], |idx| {
101        a.get(&[idx[0] / bm, idx[1] / bn]) * b.get(&[idx[0] % bm, idx[1] % bn])
102    })
103}
104
105pub fn levi_civita(n: usize) -> Tensor {
106    let shape: Vec<usize> = vec![n; n];
107    Tensor::from_fn(&shape, |indices| {
108        for i in 0..n {
109            for j in (i + 1)..n {
110                if indices[i] == indices[j] {
111                    return 0.0;
112                }
113            }
114        }
115        let mut parity = 0;
116        let mut perm: Vec<usize> = indices.to_vec();
117        for i in 0..n {
118            while perm[i] != i {
119                let target = perm[i];
120                perm.swap(i, target);
121                parity += 1;
122            }
123        }
124        if parity % 2 == 0 { 1.0 } else { -1.0 }
125    })
126}
127
128pub fn metric_raise(t: &Tensor, metric_inv: &Tensor, index: usize) -> Tensor {
129    contract(t, metric_inv, index, 1)
130}
131
132pub fn metric_lower(t: &Tensor, metric: &Tensor, index: usize) -> Tensor {
133    contract(t, metric, index, 0)
134}