1use num::complex::Complex64;
4use std::ops::{Add, BitOr, Div, Index, Mul, Sub};
5
/// Real scalar type (returned by norms).
type R = f64;
/// Complex scalar type of every tensor entry.
type C = Complex64;
/// Flat entry storage of a tensor, laid out row-major (see the `Index` impl).
type Data = Vec<C>;
/// Tensor dimensions as `(rows, columns)`.
type Shape = (usize, usize);

/// Static literal description of a tensor: its shape plus one `(re, im)` pair per entry.
type TensorData = (Shape, &'static [(f64, f64)]);
13
/// Conversion of a value into a newly built [`Tensor`].
pub trait ToTensor {
    /// Builds an owned `Tensor` from `self`.
    fn to_tensor(&self) -> Tensor;
}
17
18impl ToTensor for TensorData {
19 fn to_tensor(&self) -> Tensor {
20 Tensor::new(
21 self.1
22 .iter()
23 .map(|c| Complex64::new(c.0, c.1))
24 .collect::<Vec<Complex64>>(),
25 self.0,
26 )
27 }
28}
29
/// Dense two-dimensional complex tensor stored as a flat row-major vector.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Entries in row-major order; the expected length is `shape.0 * shape.1`
    /// (not enforced by `Tensor::new`).
    pub data: Data,
    /// Dimensions as `(rows, columns)`.
    pub shape: Shape,
}
35
36impl Tensor {
37 pub fn new(data: Vec<C>, shape: Shape) -> Tensor {
38 Tensor { data, shape }
39 }
40
41 pub fn eye(n: usize) -> Tensor {
42 let mut data = Vec::<C>::new();
43
44 for i in 0..n {
45 for j in 0..n {
46 if i == j {
47 data.push(C::new(1f64, 0f64));
48 } else {
49 data.push(C::new(0f64, 0f64));
50 }
51 }
52 }
53
54 Tensor {
55 data,
56 shape: (n, n),
57 }
58 }
59
60 pub fn item(&self) -> Option<C> {
61 if self.shape != (1, 1) {
62 return None;
63 }
64
65 Some(self.data[0])
66 }
67
68 pub fn norm_sqr(&self) -> R {
69 self.data.iter().map(|c| c.norm_sqr()).sum()
70 }
71
72 pub fn norm(&self) -> R {
73 self.norm_sqr().sqrt()
74 }
75
76 pub fn unit(&self) -> Tensor {
77 self / self.norm()
78 }
79
80 pub fn dag(&self) -> Tensor {
82 match self.shape {
83 (m, n) if m == 1 || n == 1 => {
84 Tensor::new(self.data.iter().map(|c| c.conj()).collect(), (n, m))
85 }
86 (m, n) => {
87 let mut data = Vec::<C>::new();
88
89 for j in 0..n {
90 for i in 0..m {
91 data.push(self[(i, j)].conj());
92 }
93 }
94
95 Tensor::new(data, (n, m))
96 }
97 }
98 }
99
100 pub fn proj(&self) -> Tensor {
102 self * &self.dag()
103 }
104
105 pub fn prod(&self, rhs: &Tensor) -> Tensor {
107 let shape = (self.shape.0 * rhs.shape.0, self.shape.1 * rhs.shape.1);
108 let mut data = vec![C::new(0f64, 0f64); shape.0 * shape.1];
109
110 for i in 0..self.shape.1 {
112 for j in 0..self.shape.0 {
113 for k in 0..rhs.shape.1 {
115 for l in 0..rhs.shape.0 {
116 let x = i * rhs.shape.1 + k;
117 let y = j * rhs.shape.0 + l;
118
119 data[x + y * shape.1] = self[(j, i)] * rhs[(l, k)];
120 }
121 }
122 }
123 }
124
125 Tensor::new(data, shape)
126 }
127
128 pub fn expand(&self, n: usize, i: usize) -> Tensor {
129 let eye = Tensor::eye(2);
130 let mut product = if i == 0 { self.clone() } else { eye.clone() };
131
132 for k in 1..n {
133 product = if k == i {
134 product.prod(self)
135 } else {
136 product.prod(&eye)
137 }
138 }
139
140 product
141 }
142}
143
/// Implements an elementwise binary operator `$trait::$op` for tensors of equal shape,
/// for both borrowed (`&Tensor op &Tensor`) and owned (`Tensor op Tensor`) operands.
///
/// The generated methods panic if the two operands have different shapes.
macro_rules! tensor_elementwise_op {
    ( $trait:ident, $op:ident ) => {
        impl $trait<&Tensor> for &Tensor {
            type Output = Tensor;

            fn $op(self, rhs: &Tensor) -> Tensor {
                assert!(
                    self.shape == rhs.shape,
                    "elementwise {}: shape mismatch {:?} vs {:?}",
                    stringify!($op),
                    self.shape,
                    rhs.shape
                );

                Tensor::new(
                    self.data
                        .iter()
                        .zip(rhs.data.iter())
                        .map(|(c1, c2)| c1.$op(c2))
                        .collect(),
                    self.shape,
                )
            }
        }

        impl $trait for Tensor {
            type Output = Tensor;

            // Owned operands delegate to the by-reference implementation above.
            fn $op(self, rhs: Tensor) -> Tensor {
                $trait::$op(&self, &rhs)
            }
        }
    };
}
164
165impl Index<(usize, usize)> for Tensor {
166 type Output = C;
167
168 fn index(&self, index: (usize, usize)) -> &Self::Output {
169 &self.data[index.1 + index.0 * self.shape.1]
170 }
171}
172
// Elementwise `+` and `-` for tensors of identical shape.
tensor_elementwise_op!(Add, add);
tensor_elementwise_op!(Sub, sub);
175
176impl Div<f64> for &Tensor {
177 type Output = Tensor;
178
179 fn div(self, rhs: f64) -> Tensor {
180 Tensor::new(self.data.iter().map(|c| c / rhs).collect(), self.shape)
181 }
182}
183
184impl Mul<f64> for &Tensor {
185 type Output = Tensor;
186
187 fn mul(self, rhs: f64) -> Tensor {
188 Tensor::new(self.data.iter().map(|c| c * rhs).collect(), self.shape)
189 }
190}
191
192impl Mul<C> for &Tensor {
193 type Output = Tensor;
194
195 fn mul(self, rhs: C) -> Tensor {
196 Tensor::new(self.data.iter().map(|c| c * rhs).collect(), self.shape)
197 }
198}
199
200impl BitOr for Tensor {
202 type Output = C;
203
204 fn bitor(self, rhs: Tensor) -> C {
205 self.data
206 .iter()
207 .zip(rhs.data.iter())
208 .map(|(c1, c2)| c1 * c2)
209 .sum()
210 }
211}
212
213impl Mul<&Tensor> for &Tensor {
215 type Output = Tensor;
216
217 fn mul(self, rhs: &Tensor) -> Tensor {
218 assert!(self.shape.1 == rhs.shape.0 || self.shape == (1, 1) || rhs.shape == (1, 1));
219
220 if self.shape == (1, 1) {
221 return rhs * self.item().unwrap();
222 }
223
224 if rhs.shape == (1, 1) {
225 return self * rhs.item().unwrap();
226 }
227
228 let shape = (self.shape.0, rhs.shape.1);
229 let mut data = Vec::<C>::new();
230 let n = self.shape.1;
231
232 for i in 0..shape.0 {
233 for j in 0..shape.1 {
234 data.push((0..n).map(|k| self[(i, k)] * rhs[(k, j)]).sum());
235 }
236 }
237
238 Tensor::new(data, shape)
239 }
240}
241
242impl Mul for Tensor {
243 type Output = Tensor;
244
245 fn mul(self, rhs: Tensor) -> Tensor {
246 &self * &rhs
247 }
248}
249
250impl Div<&Tensor> for &Tensor {
251 type Output = Tensor;
252
253 fn div(self, rhs: &Tensor) -> Tensor {
254 assert!(rhs.shape == (1, 1));
255
256 Tensor::new(
257 self.data.iter().map(|c| c / rhs.item().unwrap()).collect(),
258 self.shape,
259 )
260 }
261}
262
263impl Div for Tensor {
264 type Output = Tensor;
265
266 fn div(self, rhs: Tensor) -> Tensor {
267 &self / &rhs
268 }
269}
270
/// Decoding of a value into the [`Tensor`] it denotes.
pub trait AsTensor {
    /// Builds a new `Tensor` from `self`.
    fn as_tensor(&self) -> Tensor;
}
275
276impl AsTensor for char {
277 fn as_tensor(&self) -> Tensor {
278 match self {
279 '0' => Tensor::new(vec![C::new(1.0, 0.0), C::new(0.0, 0.0)], (2, 1)),
280 '1' => Tensor::new(vec![C::new(0.0, 0.0), C::new(1.0, 0.0)], (2, 1)),
281 '+' => Tensor::new(vec![C::new(1.0, 0.0), C::new(1.0, 0.0)], (2, 1)).unit(),
282 '-' => Tensor::new(vec![C::new(1.0, 0.0), C::new(0.0, -1.0)], (2, 1)).unit(),
283 not_well_known => {
284 panic!(
285 "Cannot decode '{}' into a qubit state: only (0, 1, +, -) supported",
286 not_well_known
287 );
288 }
289 }
290 }
291}
292
/// Kronecker product taken over a collection of tensors.
pub trait KroneckerProduct {
    /// Combines the tensors into a single Kronecker product.
    fn prod(&self) -> Tensor;
}
296
297impl KroneckerProduct for Vec<Tensor> {
298 fn prod(&self) -> Tensor {
299 match self
300 .iter()
301 .cloned()
302 .reduce(|product, tensor| product.prod(&tensor))
303 {
304 Some(tensor) => tensor,
305 None => panic!("Should always be called on a nonempty vector"),
306 }
307 }
308}