// ruvector_math/tensor_networks/contraction.rs
use std::collections::HashMap;
/// A dense tensor stored as a flat row-major buffer, with one dimension and
/// one label per leg.
#[derive(Debug, Clone)]
pub struct TensorNode {
    /// Identifier of the node; `TensorNetwork` hands these out sequentially.
    pub id: usize,
    /// Tensor entries in row-major order; length is the product of `leg_dims`.
    pub data: Vec<f64>,
    /// Dimension of each leg.
    pub leg_dims: Vec<usize>,
    /// Label of each leg. Nodes sharing a label that does not start with
    /// `open_` are bonded along that leg.
    pub leg_labels: Vec<String>,
}

impl TensorNode {
    /// Creates a node, checking that the buffer matches the declared shape.
    ///
    /// A node with no legs is a scalar and must hold exactly one value (the
    /// product of an empty `leg_dims` is 1).
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `leg_dims`, or if
    /// `leg_dims` and `leg_labels` have different lengths.
    pub fn new(id: usize, data: Vec<f64>, leg_dims: Vec<usize>, leg_labels: Vec<String>) -> Self {
        let expected_size: usize = leg_dims.iter().product();
        assert_eq!(
            data.len(),
            expected_size,
            "tensor data length must equal the product of leg_dims"
        );
        assert_eq!(
            leg_dims.len(),
            leg_labels.len(),
            "leg_dims and leg_labels must have one entry per leg"
        );

        Self {
            id,
            data,
            leg_dims,
            leg_labels,
        }
    }

    /// Number of legs (the tensor's rank).
    pub fn num_legs(&self) -> usize {
        self.leg_dims.len()
    }

    /// Number of stored elements.
    pub fn size(&self) -> usize {
        self.data.len()
    }
}

41#[derive(Debug, Clone)]
43pub struct TensorNetwork {
44 nodes: Vec<TensorNode>,
46 next_id: usize,
48}
49
50impl TensorNetwork {
51 pub fn new() -> Self {
53 Self { nodes: Vec::new(), next_id: 0 }
54 }
55
56 pub fn add_node(&mut self, data: Vec<f64>, leg_dims: Vec<usize>, leg_labels: Vec<String>) -> usize {
58 let id = self.next_id;
59 self.next_id += 1;
60 self.nodes.push(TensorNode::new(id, data, leg_dims, leg_labels));
61 id
62 }
63
64 pub fn get_node(&self, id: usize) -> Option<&TensorNode> {
66 self.nodes.iter().find(|n| n.id == id)
67 }
68
69 pub fn num_nodes(&self) -> usize {
71 self.nodes.len()
72 }
73
74 pub fn contract(&mut self, id1: usize, id2: usize) -> Option<usize> {
76 let node1_idx = self.nodes.iter().position(|n| n.id == id1)?;
77 let node2_idx = self.nodes.iter().position(|n| n.id == id2)?;
78
79 let node1 = &self.nodes[node1_idx];
81 let node2 = &self.nodes[node2_idx];
82
83 let mut contract_pairs: Vec<(usize, usize)> = Vec::new();
84
85 for (i1, label1) in node1.leg_labels.iter().enumerate() {
86 for (i2, label2) in node2.leg_labels.iter().enumerate() {
87 if label1 == label2 && !label1.starts_with("open_") {
88 assert_eq!(node1.leg_dims[i1], node2.leg_dims[i2], "Dimension mismatch");
89 contract_pairs.push((i1, i2));
90 }
91 }
92 }
93
94 if contract_pairs.is_empty() {
95 return self.outer_product(id1, id2);
97 }
98
99 let result = contract_tensors(node1, node2, &contract_pairs);
101
102 self.nodes.retain(|n| n.id != id1 && n.id != id2);
104
105 let new_id = self.next_id;
106 self.next_id += 1;
107 self.nodes.push(TensorNode::new(new_id, result.0, result.1, result.2));
108
109 Some(new_id)
110 }
111
112 fn outer_product(&mut self, id1: usize, id2: usize) -> Option<usize> {
114 let node1 = self.nodes.iter().find(|n| n.id == id1)?;
115 let node2 = self.nodes.iter().find(|n| n.id == id2)?;
116
117 let mut new_data = Vec::with_capacity(node1.size() * node2.size());
118 for &a in &node1.data {
119 for &b in &node2.data {
120 new_data.push(a * b);
121 }
122 }
123
124 let mut new_dims = node1.leg_dims.clone();
125 new_dims.extend(node2.leg_dims.iter());
126
127 let mut new_labels = node1.leg_labels.clone();
128 new_labels.extend(node2.leg_labels.iter().cloned());
129
130 self.nodes.retain(|n| n.id != id1 && n.id != id2);
131
132 let new_id = self.next_id;
133 self.next_id += 1;
134 self.nodes.push(TensorNode::new(new_id, new_data, new_dims, new_labels));
135
136 Some(new_id)
137 }
138
139 pub fn contract_all(&mut self) -> Option<f64> {
141 while self.nodes.len() > 1 {
142 let mut found = None;
144 'outer: for i in 0..self.nodes.len() {
145 for j in i + 1..self.nodes.len() {
146 for label in &self.nodes[i].leg_labels {
147 if !label.starts_with("open_") && self.nodes[j].leg_labels.contains(label) {
148 found = Some((self.nodes[i].id, self.nodes[j].id));
149 break 'outer;
150 }
151 }
152 }
153 }
154
155 if let Some((id1, id2)) = found {
156 self.contract(id1, id2)?;
157 } else {
158 break;
160 }
161 }
162
163 if self.nodes.len() == 1 && self.nodes[0].leg_dims.is_empty() {
164 Some(self.nodes[0].data[0])
165 } else {
166 None
167 }
168 }
169}
170
171impl Default for TensorNetwork {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177fn contract_tensors(
179 node1: &TensorNode,
180 node2: &TensorNode,
181 contract_pairs: &[(usize, usize)],
182) -> (Vec<f64>, Vec<usize>, Vec<String>) {
183 let mut out_dims = Vec::new();
185 let mut out_labels = Vec::new();
186
187 let contracted1: Vec<usize> = contract_pairs.iter().map(|p| p.0).collect();
188 let contracted2: Vec<usize> = contract_pairs.iter().map(|p| p.1).collect();
189
190 for (i, (dim, label)) in node1.leg_dims.iter().zip(node1.leg_labels.iter()).enumerate() {
191 if !contracted1.contains(&i) {
192 out_dims.push(*dim);
193 out_labels.push(label.clone());
194 }
195 }
196
197 for (i, (dim, label)) in node2.leg_dims.iter().zip(node2.leg_labels.iter()).enumerate() {
198 if !contracted2.contains(&i) {
199 out_dims.push(*dim);
200 out_labels.push(label.clone());
201 }
202 }
203
204 let out_size: usize = if out_dims.is_empty() { 1 } else { out_dims.iter().product() };
205 let mut out_data = vec![0.0; out_size];
206
207 let size1 = node1.size();
209 let size2 = node2.size();
210
211 let strides1 = compute_strides(&node1.leg_dims);
212 let strides2 = compute_strides(&node2.leg_dims);
213 let out_strides = compute_strides(&out_dims);
214
215 let mut out_indices = vec![0usize; out_dims.len()];
217 for out_flat in 0..out_size {
218 let contract_sizes: Vec<usize> = contract_pairs.iter().map(|p| node1.leg_dims[p.0]).collect();
221 let contract_total: usize = if contract_sizes.is_empty() { 1 } else { contract_sizes.iter().product() };
222
223 let mut sum = 0.0;
224
225 for contract_flat in 0..contract_total {
226 let mut idx1 = vec![0usize; node1.num_legs()];
228 let mut idx2 = vec![0usize; node2.num_legs()];
229
230 let mut cf = contract_flat;
232 for (pi, &(i1, i2)) in contract_pairs.iter().enumerate() {
233 let ci = cf % contract_sizes[pi];
234 cf /= contract_sizes[pi];
235 idx1[i1] = ci;
236 idx2[i2] = ci;
237 }
238
239 let mut out_idx_copy = out_flat;
241 let mut free1_pos = 0;
242 let mut free2_pos = 0;
243
244 for i in 0..node1.num_legs() {
245 if !contracted1.contains(&i) {
246 if free1_pos < out_dims.len() {
247 idx1[i] = (out_idx_copy / out_strides.get(free1_pos).unwrap_or(&1)) % node1.leg_dims[i];
248 }
249 free1_pos += 1;
250 }
251 }
252
253 for i in 0..node2.num_legs() {
254 if !contracted2.contains(&i) {
255 let pos = (node1.num_legs() - contracted1.len()) + free2_pos;
256 if pos < out_dims.len() {
257 idx2[i] = (out_flat / out_strides.get(pos).unwrap_or(&1)) % node2.leg_dims[i];
258 }
259 free2_pos += 1;
260 }
261 }
262
263 let lin1: usize = idx1.iter().zip(strides1.iter()).map(|(i, s)| i * s).sum();
265 let lin2: usize = idx2.iter().zip(strides2.iter()).map(|(i, s)| i * s).sum();
266
267 sum += node1.data[lin1.min(node1.data.len() - 1)] * node2.data[lin2.min(node2.data.len() - 1)];
268 }
269
270 out_data[out_flat] = sum;
271 }
272
273 (out_data, out_dims, out_labels)
274}
275
/// Row-major (C-order) strides for a tensor with the given leg dimensions.
///
/// The last leg has stride 1, and each earlier leg's stride is the product of
/// all later dimensions. A multi-index `idx` therefore maps to the flat offset
/// `sum(idx[k] * strides[k])`. An empty shape yields no strides.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    // Walk the shape from the innermost leg outwards, emitting the running
    // product *before* folding in the current dimension.
    let mut strides: Vec<usize> = dims
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let stride = *running;
            *running *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
}

287#[derive(Debug, Clone)]
289pub struct NetworkContraction {
290 pub estimated_cost: f64,
292}
293
294impl NetworkContraction {
295 pub fn greedy_order(network: &TensorNetwork) -> Vec<(usize, usize)> {
297 let mut order = Vec::new();
298 let mut remaining: Vec<usize> = network.nodes.iter().map(|n| n.id).collect();
299
300 while remaining.len() > 1 {
301 let mut best_pair = None;
303 let mut best_cost = f64::INFINITY;
304
305 for i in 0..remaining.len() {
306 for j in i + 1..remaining.len() {
307 let id1 = remaining[i];
308 let id2 = remaining[j];
309
310 if let (Some(n1), Some(n2)) = (network.get_node(id1), network.get_node(id2)) {
311 let cost = estimate_contraction_cost(n1, n2);
312 if cost < best_cost {
313 best_cost = cost;
314 best_pair = Some((i, j));
315 }
316 }
317 }
318 }
319
320 if let Some((i, j)) = best_pair {
321 let id1 = remaining[i];
322 let id2 = remaining[j];
323 order.push((id1, id2));
324
325 remaining.remove(j);
327 remaining.remove(i);
328 } else {
330 break;
331 }
332 }
333
334 order
335 }
336}
337
338fn estimate_contraction_cost(n1: &TensorNode, n2: &TensorNode) -> f64 {
339 let size1: usize = n1.leg_dims.iter().product();
341 let size2: usize = n2.leg_dims.iter().product();
342
343 let mut contracted_size = 1usize;
345 for (i1, label1) in n1.leg_labels.iter().enumerate() {
346 for (i2, label2) in n2.leg_labels.iter().enumerate() {
347 if label1 == label2 && !label1.starts_with("open_") {
348 contracted_size *= n1.leg_dims[i1];
349 }
350 }
351 }
352
353 (size1 * size2 / contracted_size.max(1)) as f64
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned leg labels from string literals.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        let id2 = network.add_node(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2], labels(&["j", "k"]));

        assert_eq!(network.num_nodes(), 2);
        assert_ne!(id1, id2);
        assert_eq!(network.get_node(id1).unwrap().leg_labels, labels(&["i", "j"]));
        assert_eq!(network.get_node(id2).unwrap().data, vec![1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_matrix_contraction() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let mut network = TensorNetwork::new();
        let id1 = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        let id2 = network.add_node(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2], labels(&["j", "k"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 2]);
        assert_eq!(result.leg_labels, labels(&["i", "k"]));
        assert_eq!(result.data, vec![19.0, 22.0, 43.0, 50.0]);
        assert_eq!(network.num_nodes(), 1);
    }

    #[test]
    fn test_vector_dot_product() {
        let mut network = TensorNetwork::new();
        let id1 = network.add_node(vec![1.0, 2.0, 3.0], vec![3], labels(&["i"]));
        let id2 = network.add_node(vec![1.0, 1.0, 1.0], vec![3], labels(&["i"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.data.len(), 1);
        assert!((result.data[0] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_transposed_double_bond() {
        // Contracting A(i,j) with B(j,i) over both legs gives trace(A * B) = 69.
        let mut network = TensorNetwork::new();
        let a = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        let b = network.add_node(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2], labels(&["j", "i"]));

        let result_id = network.contract(a, b).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert!(result.leg_dims.is_empty());
        assert_eq!(result.data, vec![69.0]);
    }

    #[test]
    fn test_outer_product_without_shared_labels() {
        let mut network = TensorNetwork::new();
        let a = network.add_node(vec![1.0, 2.0], vec![2], labels(&["a"]));
        let b = network.add_node(vec![3.0, 4.0, 5.0], vec![3], labels(&["b"]));

        let result_id = network.contract(a, b).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 3]);
        assert_eq!(result.leg_labels, labels(&["a", "b"]));
        assert_eq!(result.data, vec![3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_open_legs_are_never_contracted() {
        let mut network = TensorNetwork::new();
        let a = network.add_node(vec![1.0, 2.0], vec![2], labels(&["open_x"]));
        let b = network.add_node(vec![3.0, 4.0], vec![2], labels(&["open_x"]));

        let result_id = network.clone().contract(a, b).unwrap();
        assert!(result_id > b);

        // The only shared label is open, so there is nothing to contract.
        assert_eq!(network.contract_all(), None);
        assert_eq!(network.num_nodes(), 2);
    }

    #[test]
    fn test_contract_all_chain() {
        // u^T M v with u = [1, 2], M = [[1, 2], [3, 4]], v = [1, 1] equals 17.
        let mut network = TensorNetwork::new();
        network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        network.add_node(vec![1.0, 1.0], vec![2], labels(&["j"]));

        assert_eq!(network.contract_all(), Some(17.0));
        assert_eq!(network.num_nodes(), 1);
    }
}