test_dirac_tensor/
lib.rs

//! Small CPU tensor library for compile-time macro tensor operations.

use num::complex::Complex64;
use std::ops::{Add, BitOr, Div, Index, Mul, Sub};

type R = f64;
type C = Complex64;
type Data = Vec<C>;
type Shape = (usize, usize);

// Static tensor format for data transfer between compile time and runtime
type TensorData = (Shape, &'static [(f64, f64)]);

pub trait ToTensor {
    fn to_tensor(&self) -> Tensor;
}

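/// A minimal doctest sketch of the compile-time round trip (assumes this
/// crate is named `test_dirac_tensor`, per the directory above; the binding
/// below is hypothetical, standing in for what a proc macro would emit):
///
/// ```
/// use test_dirac_tensor::ToTensor;
///
/// // Structurally the same tuple type as the private `TensorData` alias.
/// let ket_zero: ((usize, usize), &'static [(f64, f64)]) = ((2, 1), &[(1.0, 0.0), (0.0, 0.0)]);
/// let tensor = ket_zero.to_tensor();
/// assert_eq!(tensor.shape, (2, 1));
/// ```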
impl ToTensor for TensorData {
    fn to_tensor(&self) -> Tensor {
        Tensor::new(
            self.1
                .iter()
                .map(|c| Complex64::new(c.0, c.1))
                .collect::<Vec<Complex64>>(),
            self.0,
        )
    }
}

#[derive(Debug, Clone)]
pub struct Tensor {
    pub data: Data,
    pub shape: Shape,
}

impl Tensor {
    pub fn new(data: Vec<C>, shape: Shape) -> Tensor {
        Tensor { data, shape }
    }

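    /// A minimal doctest sketch of the identity constructor (assumes this
    /// crate is named `test_dirac_tensor`, per the directory above):
    ///
    /// ```
    /// use test_dirac_tensor::Tensor;
    /// use num::complex::Complex64;
    ///
    /// let i2 = Tensor::eye(2);
    /// assert_eq!(i2.shape, (2, 2));
    /// assert_eq!(i2[(0, 0)], Complex64::new(1.0, 0.0));
    /// assert_eq!(i2[(0, 1)], Complex64::new(0.0, 0.0));
    /// ```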
    pub fn eye(n: usize) -> Tensor {
        let mut data = Vec::<C>::new();

        for i in 0..n {
            for j in 0..n {
                if i == j {
                    data.push(C::new(1f64, 0f64));
                } else {
                    data.push(C::new(0f64, 0f64));
                }
            }
        }

        Tensor {
            data,
            shape: (n, n),
        }
    }

    pub fn item(&self) -> Option<C> {
        if self.shape != (1, 1) {
            return None;
        }

        Some(self.data[0])
    }

    pub fn norm_sqr(&self) -> R {
        self.data.iter().map(|c| c.norm_sqr()).sum()
    }

    pub fn norm(&self) -> R {
        self.norm_sqr().sqrt()
    }

    pub fn unit(&self) -> Tensor {
        self / self.norm()
    }

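    /// A minimal doctest sketch of the dagger (assumes this crate is named
    /// `test_dirac_tensor`, per the directory above):
    ///
    /// ```
    /// use test_dirac_tensor::Tensor;
    /// use num::complex::Complex64;
    ///
    /// // A 2x1 ket with amplitude i in the first slot...
    /// let ket = Tensor::new(vec![Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)], (2, 1));
    /// // ...becomes a 1x2 bra with amplitude -i.
    /// let bra = ket.dag();
    /// assert_eq!(bra.shape, (1, 2));
    /// assert_eq!(bra[(0, 0)], Complex64::new(0.0, -1.0));
    /// ```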
    // Dagger - conjugate transpose
    pub fn dag(&self) -> Tensor {
        match self.shape {
            (m, n) if m == 1 || n == 1 => {
                Tensor::new(self.data.iter().map(|c| c.conj()).collect(), (n, m))
            }
            (m, n) => {
                let mut data = Vec::<C>::new();

                for j in 0..n {
                    for i in 0..m {
                        data.push(self[(i, j)].conj());
                    }
                }

                Tensor::new(data, (n, m))
            }
        }
    }

    // Projector
    pub fn proj(&self) -> Tensor {
        self * &self.dag()
    }

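    /// A minimal doctest sketch of the Kronecker product (assumes this crate
    /// is named `test_dirac_tensor`, per the directory above):
    ///
    /// ```
    /// use test_dirac_tensor::Tensor;
    ///
    /// let i2 = Tensor::eye(2);
    /// let i4 = i2.prod(&i2); // I2 (x) I2 = I4
    /// assert_eq!(i4.shape, (4, 4));
    /// assert_eq!(i4.norm_sqr(), 4.0);
    /// ```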
    // Kronecker product
    pub fn prod(&self, rhs: &Tensor) -> Tensor {
        let shape = (self.shape.0 * rhs.shape.0, self.shape.1 * rhs.shape.1);
        let mut data = vec![C::new(0f64, 0f64); shape.0 * shape.1];

        // Walk the first matrix
        for i in 0..self.shape.1 {
            for j in 0..self.shape.0 {
                // For each element, walk the second matrix
                for k in 0..rhs.shape.1 {
                    for l in 0..rhs.shape.0 {
                        // (x, y) = (column, row) of the result element
                        let x = i * rhs.shape.1 + k;
                        let y = j * rhs.shape.0 + l;

                        data[x + y * shape.1] = self[(j, i)] * rhs[(l, k)];
                    }
                }
            }
        }

        Tensor::new(data, shape)
    }

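    /// Embeds `self` (assumed 2x2) at position `i` in a chain of `n` qubit
    /// factors: I (x) ... (x) self (x) ... (x) I.
    ///
    /// A minimal doctest sketch (assumes this crate is named
    /// `test_dirac_tensor`, per the directory above):
    ///
    /// ```
    /// use test_dirac_tensor::Tensor;
    ///
    /// let op = Tensor::eye(2); // stand-in for any single-qubit operator
    /// assert_eq!(op.expand(3, 1).shape, (8, 8)); // 2^3 x 2^3
    /// ```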
    pub fn expand(&self, n: usize, i: usize) -> Tensor {
        assert!(i < n);

        let eye = Tensor::eye(2);
        let mut product = if i == 0 { self.clone() } else { eye.clone() };

        for k in 1..n {
            product = if k == i {
                product.prod(self)
            } else {
                product.prod(&eye)
            }
        }

        product
    }
}

macro_rules! tensor_elementwise_op {
    ( $trait:ident, $op:ident ) => {
        impl $trait for Tensor {
            type Output = Tensor;

            fn $op(self, rhs: Tensor) -> Tensor {
                assert!(self.shape == rhs.shape);

                Tensor::new(
                    self.data
                        .iter()
                        .zip(rhs.data.iter())
                        .map(|(c1, c2)| c1.$op(c2))
                        .collect(),
                    self.shape,
                )
            }
        }
    };
}

// Row-major storage: element (row, col) lives at data[col + row * shape.1]
impl Index<(usize, usize)> for Tensor {
    type Output = C;

    fn index(&self, index: (usize, usize)) -> &Self::Output {
        &self.data[index.1 + index.0 * self.shape.1]
    }
}

tensor_elementwise_op!(Add, add);
tensor_elementwise_op!(Sub, sub);

impl Div<f64> for &Tensor {
    type Output = Tensor;

    fn div(self, rhs: f64) -> Tensor {
        Tensor::new(self.data.iter().map(|c| c / rhs).collect(), self.shape)
    }
}

impl Mul<f64> for &Tensor {
    type Output = Tensor;

    fn mul(self, rhs: f64) -> Tensor {
        Tensor::new(self.data.iter().map(|c| c * rhs).collect(), self.shape)
    }
}

impl Mul<C> for &Tensor {
    type Output = Tensor;

    fn mul(self, rhs: C) -> Tensor {
        Tensor::new(self.data.iter().map(|c| c * rhs).collect(), self.shape)
    }
}

// Plain (unconjugated) dot product over the flattened data; for the Hermitian
// inner product <a|b>, conjugate the left operand first, e.g. `a.dag() | b`
impl BitOr for Tensor {
    type Output = C;

    fn bitor(self, rhs: Tensor) -> C {
        assert!(self.data.len() == rhs.data.len());

        self.data
            .iter()
            .zip(rhs.data.iter())
            .map(|(c1, c2)| c1 * c2)
            .sum()
    }
}

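/// A minimal doctest sketch of matrix multiplication (assumes this crate is
/// named `test_dirac_tensor`, per the directory above):
///
/// ```
/// use test_dirac_tensor::Tensor;
/// use num::complex::Complex64;
///
/// let a = Tensor::eye(2);
/// let b = Tensor::new(vec![Complex64::new(2.0, 0.0), Complex64::new(3.0, 0.0)], (2, 1));
/// let c = &a * &b; // the identity leaves the column unchanged
/// assert_eq!(c.shape, (2, 1));
/// assert_eq!(c[(1, 0)], Complex64::new(3.0, 0.0));
/// ```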
// Matrix multiplication; a (1, 1) tensor broadcasts as a scalar on either side
impl Mul<&Tensor> for &Tensor {
    type Output = Tensor;

    fn mul(self, rhs: &Tensor) -> Tensor {
        assert!(self.shape.1 == rhs.shape.0 || self.shape == (1, 1) || rhs.shape == (1, 1));

        if self.shape == (1, 1) {
            return rhs * self.item().unwrap();
        }

        if rhs.shape == (1, 1) {
            return self * rhs.item().unwrap();
        }

        let shape = (self.shape.0, rhs.shape.1);
        let mut data = Vec::<C>::new();
        let n = self.shape.1;

        for i in 0..shape.0 {
            for j in 0..shape.1 {
                data.push((0..n).map(|k| self[(i, k)] * rhs[(k, j)]).sum());
            }
        }

        Tensor::new(data, shape)
    }
}

impl Mul for Tensor {
    type Output = Tensor;

    fn mul(self, rhs: Tensor) -> Tensor {
        &self * &rhs
    }
}

impl Div<&Tensor> for &Tensor {
    type Output = Tensor;

    fn div(self, rhs: &Tensor) -> Tensor {
        assert!(rhs.shape == (1, 1));

        let scalar = rhs.item().unwrap();

        Tensor::new(self.data.iter().map(|c| c / scalar).collect(), self.shape)
    }
}

impl Div for Tensor {
    type Output = Tensor;

    fn div(self, rhs: Tensor) -> Tensor {
        &self / &rhs
    }
}

/// Converts some type to a tensor
pub trait AsTensor {
    fn as_tensor(&self) -> Tensor;
}

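/// A minimal doctest sketch of qubit-state decoding (assumes this crate is
/// named `test_dirac_tensor`, per the directory above):
///
/// ```
/// use test_dirac_tensor::AsTensor;
///
/// let plus = '+'.as_tensor(); // |+> = (|0> + |1>)/sqrt(2)
/// assert_eq!(plus.shape, (2, 1));
/// assert!((plus.norm() - 1.0).abs() < 1e-12);
/// ```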
impl AsTensor for char {
    fn as_tensor(&self) -> Tensor {
        match self {
            '0' => Tensor::new(vec![C::new(1.0, 0.0), C::new(0.0, 0.0)], (2, 1)),
            '1' => Tensor::new(vec![C::new(0.0, 0.0), C::new(1.0, 0.0)], (2, 1)),
            '+' => Tensor::new(vec![C::new(1.0, 0.0), C::new(1.0, 0.0)], (2, 1)).unit(),
            // |-> = (|0> - |1>)/sqrt(2); the second amplitude is -1, not -i
            '-' => Tensor::new(vec![C::new(1.0, 0.0), C::new(-1.0, 0.0)], (2, 1)).unit(),
            not_well_known => {
                panic!(
                    "Cannot decode '{}' into a qubit state: only (0, 1, +, -) supported",
                    not_well_known
                );
            }
        }
    }
}

pub trait KroneckerProduct {
    fn prod(&self) -> Tensor;
}

impl KroneckerProduct for Vec<Tensor> {
    fn prod(&self) -> Tensor {
        self.iter()
            .cloned()
            .reduce(|product, tensor| product.prod(&tensor))
            .expect("KroneckerProduct::prod requires a nonempty vector")
    }
}
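
// Illustrative smoke tests sketching intended usage; the expected values
// follow directly from the conventions defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn minus_state_is_normalized() {
        let minus = '-'.as_tensor();
        assert_eq!(minus.shape, (2, 1));
        assert!((minus.norm() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn projector_of_ket_zero() {
        // |0><0| projects onto the first basis vector
        let p = '0'.as_tensor().proj();
        assert_eq!(p.shape, (2, 2));
        assert_eq!(p[(0, 0)], C::new(1.0, 0.0));
        assert_eq!(p[(1, 1)], C::new(0.0, 0.0));
    }

    #[test]
    fn inner_product_via_dag() {
        // `|` does not conjugate, so dagger the left operand for <a|a> = 1
        let plus = '+'.as_tensor();
        let ip = plus.dag() | plus.clone();
        assert!((ip - C::new(1.0, 0.0)).norm() < 1e-12);
    }

    #[test]
    fn kronecker_product_of_kets() {
        // |01> = |0> (x) |1> is a 4x1 column with a 1 in slot 1
        let ket01 = vec!['0'.as_tensor(), '1'.as_tensor()].prod();
        assert_eq!(ket01.shape, (4, 1));
        assert_eq!(ket01[(1, 0)], C::new(1.0, 0.0));
    }

    #[test]
    fn elementwise_add() {
        // Generated by `tensor_elementwise_op!(Add, add)`
        let sum = Tensor::eye(2) + Tensor::eye(2);
        assert_eq!(sum[(0, 0)], C::new(2.0, 0.0));
    }
}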