ruvector_math/tensor_networks/
contraction.rs1use std::collections::HashMap;
6
/// One tensor of a network: a flat element buffer plus the extent and label of each leg.
#[derive(Debug, Clone)]
pub struct TensorNode {
    /// Identifier of this node.
    pub id: usize,
    /// Flat element buffer; [`TensorNode::new`] checks that its length equals the
    /// product of `leg_dims`.
    pub data: Vec<f64>,
    /// Extent of each leg.
    pub leg_dims: Vec<usize>,
    /// Label of each leg; [`TensorNode::new`] checks there is one label per entry
    /// of `leg_dims`.
    pub leg_labels: Vec<String>,
}

impl TensorNode {
    /// Builds a node from its id, element buffer, leg extents and leg labels.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `leg_dims` (an empty
    /// `leg_dims` therefore requires exactly one element), or if `leg_dims` and
    /// `leg_labels` have different lengths.
    pub fn new(id: usize, data: Vec<f64>, leg_dims: Vec<usize>, leg_labels: Vec<String>) -> Self {
        // The length check runs first so a mismatched buffer is reported before a
        // mismatched label list.
        assert_eq!(data.len(), leg_dims.iter().product::<usize>());
        assert_eq!(leg_dims.len(), leg_labels.len());

        TensorNode {
            id,
            data,
            leg_dims,
            leg_labels,
        }
    }

    /// Number of legs, taken from the length of `leg_dims`.
    pub fn num_legs(&self) -> usize {
        let legs = &self.leg_dims;
        legs.len()
    }

    /// Number of stored elements, taken from the length of `data`.
    pub fn size(&self) -> usize {
        let elements = &self.data;
        elements.len()
    }
}
45
46#[derive(Debug, Clone)]
48pub struct TensorNetwork {
49 nodes: Vec<TensorNode>,
51 next_id: usize,
53}
54
55impl TensorNetwork {
56 pub fn new() -> Self {
58 Self {
59 nodes: Vec::new(),
60 next_id: 0,
61 }
62 }
63
64 pub fn add_node(
66 &mut self,
67 data: Vec<f64>,
68 leg_dims: Vec<usize>,
69 leg_labels: Vec<String>,
70 ) -> usize {
71 let id = self.next_id;
72 self.next_id += 1;
73 self.nodes
74 .push(TensorNode::new(id, data, leg_dims, leg_labels));
75 id
76 }
77
78 pub fn get_node(&self, id: usize) -> Option<&TensorNode> {
80 self.nodes.iter().find(|n| n.id == id)
81 }
82
83 pub fn num_nodes(&self) -> usize {
85 self.nodes.len()
86 }
87
88 pub fn contract(&mut self, id1: usize, id2: usize) -> Option<usize> {
90 let node1_idx = self.nodes.iter().position(|n| n.id == id1)?;
91 let node2_idx = self.nodes.iter().position(|n| n.id == id2)?;
92
93 let node1 = &self.nodes[node1_idx];
95 let node2 = &self.nodes[node2_idx];
96
97 let mut contract_pairs: Vec<(usize, usize)> = Vec::new();
98
99 for (i1, label1) in node1.leg_labels.iter().enumerate() {
100 for (i2, label2) in node2.leg_labels.iter().enumerate() {
101 if label1 == label2 && !label1.starts_with("open_") {
102 assert_eq!(node1.leg_dims[i1], node2.leg_dims[i2], "Dimension mismatch");
103 contract_pairs.push((i1, i2));
104 }
105 }
106 }
107
108 if contract_pairs.is_empty() {
109 return self.outer_product(id1, id2);
111 }
112
113 let result = contract_tensors(node1, node2, &contract_pairs);
115
116 self.nodes.retain(|n| n.id != id1 && n.id != id2);
118
119 let new_id = self.next_id;
120 self.next_id += 1;
121 self.nodes
122 .push(TensorNode::new(new_id, result.0, result.1, result.2));
123
124 Some(new_id)
125 }
126
127 fn outer_product(&mut self, id1: usize, id2: usize) -> Option<usize> {
129 let node1 = self.nodes.iter().find(|n| n.id == id1)?;
130 let node2 = self.nodes.iter().find(|n| n.id == id2)?;
131
132 let mut new_data = Vec::with_capacity(node1.size() * node2.size());
133 for &a in &node1.data {
134 for &b in &node2.data {
135 new_data.push(a * b);
136 }
137 }
138
139 let mut new_dims = node1.leg_dims.clone();
140 new_dims.extend(node2.leg_dims.iter());
141
142 let mut new_labels = node1.leg_labels.clone();
143 new_labels.extend(node2.leg_labels.iter().cloned());
144
145 self.nodes.retain(|n| n.id != id1 && n.id != id2);
146
147 let new_id = self.next_id;
148 self.next_id += 1;
149 self.nodes
150 .push(TensorNode::new(new_id, new_data, new_dims, new_labels));
151
152 Some(new_id)
153 }
154
155 pub fn contract_all(&mut self) -> Option<f64> {
157 while self.nodes.len() > 1 {
158 let mut found = None;
160 'outer: for i in 0..self.nodes.len() {
161 for j in i + 1..self.nodes.len() {
162 for label in &self.nodes[i].leg_labels {
163 if !label.starts_with("open_") && self.nodes[j].leg_labels.contains(label) {
164 found = Some((self.nodes[i].id, self.nodes[j].id));
165 break 'outer;
166 }
167 }
168 }
169 }
170
171 if let Some((id1, id2)) = found {
172 self.contract(id1, id2)?;
173 } else {
174 break;
176 }
177 }
178
179 if self.nodes.len() == 1 && self.nodes[0].leg_dims.is_empty() {
180 Some(self.nodes[0].data[0])
181 } else {
182 None
183 }
184 }
185}
186
187impl Default for TensorNetwork {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193fn contract_tensors(
195 node1: &TensorNode,
196 node2: &TensorNode,
197 contract_pairs: &[(usize, usize)],
198) -> (Vec<f64>, Vec<usize>, Vec<String>) {
199 let mut out_dims = Vec::new();
201 let mut out_labels = Vec::new();
202
203 let contracted1: Vec<usize> = contract_pairs.iter().map(|p| p.0).collect();
204 let contracted2: Vec<usize> = contract_pairs.iter().map(|p| p.1).collect();
205
206 for (i, (dim, label)) in node1
207 .leg_dims
208 .iter()
209 .zip(node1.leg_labels.iter())
210 .enumerate()
211 {
212 if !contracted1.contains(&i) {
213 out_dims.push(*dim);
214 out_labels.push(label.clone());
215 }
216 }
217
218 for (i, (dim, label)) in node2
219 .leg_dims
220 .iter()
221 .zip(node2.leg_labels.iter())
222 .enumerate()
223 {
224 if !contracted2.contains(&i) {
225 out_dims.push(*dim);
226 out_labels.push(label.clone());
227 }
228 }
229
230 let out_size: usize = if out_dims.is_empty() {
231 1
232 } else {
233 out_dims.iter().product()
234 };
235 let mut out_data = vec![0.0; out_size];
236
237 let size1 = node1.size();
239 let size2 = node2.size();
240
241 let strides1 = compute_strides(&node1.leg_dims);
242 let strides2 = compute_strides(&node2.leg_dims);
243 let out_strides = compute_strides(&out_dims);
244
245 let mut out_indices = vec![0usize; out_dims.len()];
247 for out_flat in 0..out_size {
248 let contract_sizes: Vec<usize> =
251 contract_pairs.iter().map(|p| node1.leg_dims[p.0]).collect();
252 let contract_total: usize = if contract_sizes.is_empty() {
253 1
254 } else {
255 contract_sizes.iter().product()
256 };
257
258 let mut sum = 0.0;
259
260 for contract_flat in 0..contract_total {
261 let mut idx1 = vec![0usize; node1.num_legs()];
263 let mut idx2 = vec![0usize; node2.num_legs()];
264
265 let mut cf = contract_flat;
267 for (pi, &(i1, i2)) in contract_pairs.iter().enumerate() {
268 let ci = cf % contract_sizes[pi];
269 cf /= contract_sizes[pi];
270 idx1[i1] = ci;
271 idx2[i2] = ci;
272 }
273
274 let mut out_idx_copy = out_flat;
276 let mut free1_pos = 0;
277 let mut free2_pos = 0;
278
279 for i in 0..node1.num_legs() {
280 if !contracted1.contains(&i) {
281 if free1_pos < out_dims.len() {
282 idx1[i] = (out_idx_copy / out_strides.get(free1_pos).unwrap_or(&1))
283 % node1.leg_dims[i];
284 }
285 free1_pos += 1;
286 }
287 }
288
289 for i in 0..node2.num_legs() {
290 if !contracted2.contains(&i) {
291 let pos = (node1.num_legs() - contracted1.len()) + free2_pos;
292 if pos < out_dims.len() {
293 idx2[i] =
294 (out_flat / out_strides.get(pos).unwrap_or(&1)) % node2.leg_dims[i];
295 }
296 free2_pos += 1;
297 }
298 }
299
300 let lin1: usize = idx1.iter().zip(strides1.iter()).map(|(i, s)| i * s).sum();
302 let lin2: usize = idx2.iter().zip(strides2.iter()).map(|(i, s)| i * s).sum();
303
304 sum += node1.data[lin1.min(node1.data.len() - 1)]
305 * node2.data[lin2.min(node2.data.len() - 1)];
306 }
307
308 out_data[out_flat] = sum;
309 }
310
311 (out_data, out_dims, out_labels)
312}
313
/// Returns the element stride of each dimension for a layout in which the last
/// dimension varies fastest.
///
/// The stride of the last dimension is `1`, and each earlier stride is the product
/// of all extents after it. An empty `dims` yields an empty vector.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![0usize; dims.len()];
    let mut running = 1;
    // Walk from the innermost dimension outwards, carrying the running product.
    for (slot, &extent) in strides.iter_mut().zip(dims).rev() {
        *slot = running;
        running *= extent;
    }
    strides
}
324
325#[derive(Debug, Clone)]
327pub struct NetworkContraction {
328 pub estimated_cost: f64,
330}
331
332impl NetworkContraction {
333 pub fn greedy_order(network: &TensorNetwork) -> Vec<(usize, usize)> {
335 let mut order = Vec::new();
336 let mut remaining: Vec<usize> = network.nodes.iter().map(|n| n.id).collect();
337
338 while remaining.len() > 1 {
339 let mut best_pair = None;
341 let mut best_cost = f64::INFINITY;
342
343 for i in 0..remaining.len() {
344 for j in i + 1..remaining.len() {
345 let id1 = remaining[i];
346 let id2 = remaining[j];
347
348 if let (Some(n1), Some(n2)) = (network.get_node(id1), network.get_node(id2)) {
349 let cost = estimate_contraction_cost(n1, n2);
350 if cost < best_cost {
351 best_cost = cost;
352 best_pair = Some((i, j));
353 }
354 }
355 }
356 }
357
358 if let Some((i, j)) = best_pair {
359 let id1 = remaining[i];
360 let id2 = remaining[j];
361 order.push((id1, id2));
362
363 remaining.remove(j);
365 remaining.remove(i);
366 } else {
368 break;
369 }
370 }
371
372 order
373 }
374}
375
376fn estimate_contraction_cost(n1: &TensorNode, n2: &TensorNode) -> f64 {
377 let size1: usize = n1.leg_dims.iter().product();
379 let size2: usize = n2.leg_dims.iter().product();
380
381 let mut contracted_size = 1usize;
383 for (i1, label1) in n1.leg_labels.iter().enumerate() {
384 for (i2, label2) in n2.leg_labels.iter().enumerate() {
385 if label1 == label2 && !label1.starts_with("open_") {
386 contracted_size *= n1.leg_dims[i1];
387 }
388 }
389 }
390
391 (size1 * size2 / contracted_size.max(1)) as f64
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned label list from string literals.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Asserts element-wise equality within a small absolute tolerance.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < 1e-10,
                "expected {:?}, got {:?}",
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        let id2 = network.add_node(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2], labels(&["j", "k"]));

        assert_eq!(network.num_nodes(), 2);
        assert_ne!(id1, id2);
        assert_eq!(network.get_node(id1).unwrap().leg_labels, labels(&["i", "j"]));
        assert_eq!(network.get_node(id2).unwrap().leg_labels, labels(&["j", "k"]));
    }

    #[test]
    fn test_matrix_contraction() {
        let mut network = TensorNetwork::new();

        // A = [[1, 2], [3, 4]] with legs (i, j).
        let id1 = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        // Identity with legs (j, k), so the contraction over j must reproduce A.
        let id2 = network.add_node(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2], labels(&["j", "k"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 2]);
        assert_eq!(result.leg_labels, labels(&["i", "k"]));
        assert_close(&result.data, &[1.0, 2.0, 3.0, 4.0]);

        // Both operands are consumed by the contraction.
        assert_eq!(network.num_nodes(), 1);
        assert!(network.get_node(id1).is_none());
        assert!(network.get_node(id2).is_none());
    }

    #[test]
    fn test_vector_dot_product() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(vec![1.0, 2.0, 3.0], vec![3], labels(&["i"]));
        let id2 = network.add_node(vec![1.0, 1.0, 1.0], vec![3], labels(&["i"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert!(result.leg_dims.is_empty());
        assert_eq!(result.data.len(), 1);
        assert!((result.data[0] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_product_when_no_shared_label() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(vec![1.0, 2.0], vec![2], labels(&["a"]));
        let id2 = network.add_node(vec![3.0, 4.0, 5.0], vec![3], labels(&["b"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 3]);
        assert_eq!(result.leg_labels, labels(&["a", "b"]));
        assert_close(&result.data, &[3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_contract_unknown_id_returns_none() {
        let mut network = TensorNetwork::new();
        let id1 = network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));

        assert!(network.contract(id1, id1 + 99).is_none());
        assert_eq!(network.num_nodes(), 1);
    }

    #[test]
    fn test_contract_all_chain() {
        let mut network = TensorNetwork::new();

        // v^T M w with v = [1, 2], M = [[1, 2], [3, 4]], w = [1, 1]  =>  17.
        network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        network.add_node(vec![1.0, 1.0], vec![2], labels(&["j"]));

        let value = network.contract_all().unwrap();
        assert!((value - 17.0).abs() < 1e-10);
        assert_eq!(network.num_nodes(), 1);
    }

    #[test]
    fn test_contract_all_with_open_leg_is_none() {
        let mut network = TensorNetwork::new();

        network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "open_out"]));
        network.add_node(vec![1.0, 1.0], vec![2], labels(&["i"]));

        // The remaining node still has the open leg, so there is no scalar result.
        assert!(network.contract_all().is_none());
    }

    #[test]
    fn test_greedy_order_two_nodes() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        let id2 = network.add_node(vec![3.0, 4.0], vec![2], labels(&["i"]));

        assert_eq!(NetworkContraction::greedy_order(&network), vec![(id1, id2)]);
    }
}