Skip to main content

ruvector_math/tensor_networks/
contraction.rs

1//! Tensor Network Contraction
2//!
3//! General tensor network operations for quantum-inspired algorithms.
4
5use std::collections::HashMap;
6
/// One tensor of a network: a flat element buffer plus per-leg metadata.
#[derive(Debug, Clone)]
pub struct TensorNode {
    /// Identifier of this node
    pub id: usize,
    /// Flat element buffer of the tensor
    pub data: Vec<f64>,
    /// Extent of every leg, one entry per leg
    pub leg_dims: Vec<usize>,
    /// Label of every leg, one entry per leg (used to decide what gets contracted)
    pub leg_labels: Vec<String>,
}

impl TensorNode {
    /// Build a node after checking that the buffer and the labels agree with `leg_dims`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` is not the product of `leg_dims`, or if `leg_labels` does not
    /// hold exactly one label per entry of `leg_dims`.
    pub fn new(id: usize, data: Vec<f64>, leg_dims: Vec<usize>, leg_labels: Vec<String>) -> Self {
        // An empty `leg_dims` has product 1, so a scalar carries exactly one element.
        let element_count = leg_dims.iter().copied().product::<usize>();
        assert_eq!(data.len(), element_count);
        assert_eq!(leg_dims.len(), leg_labels.len());

        TensorNode {
            id,
            data,
            leg_dims,
            leg_labels,
        }
    }

    /// How many legs this tensor has
    pub fn num_legs(&self) -> usize {
        let dims = &self.leg_dims;
        dims.len()
    }

    /// How many elements the buffer holds
    pub fn size(&self) -> usize {
        let buffer = &self.data;
        buffer.len()
    }
}
45
46/// Tensor network for contraction operations
47#[derive(Debug, Clone)]
48pub struct TensorNetwork {
49    /// Nodes in the network
50    nodes: Vec<TensorNode>,
51    /// Next node ID
52    next_id: usize,
53}
54
55impl TensorNetwork {
56    /// Create empty network
57    pub fn new() -> Self {
58        Self {
59            nodes: Vec::new(),
60            next_id: 0,
61        }
62    }
63
64    /// Add a tensor node
65    pub fn add_node(
66        &mut self,
67        data: Vec<f64>,
68        leg_dims: Vec<usize>,
69        leg_labels: Vec<String>,
70    ) -> usize {
71        let id = self.next_id;
72        self.next_id += 1;
73        self.nodes
74            .push(TensorNode::new(id, data, leg_dims, leg_labels));
75        id
76    }
77
78    /// Get node by ID
79    pub fn get_node(&self, id: usize) -> Option<&TensorNode> {
80        self.nodes.iter().find(|n| n.id == id)
81    }
82
83    /// Number of nodes
84    pub fn num_nodes(&self) -> usize {
85        self.nodes.len()
86    }
87
88    /// Contract two nodes on matching labels
89    pub fn contract(&mut self, id1: usize, id2: usize) -> Option<usize> {
90        let node1_idx = self.nodes.iter().position(|n| n.id == id1)?;
91        let node2_idx = self.nodes.iter().position(|n| n.id == id2)?;
92
93        // Find matching labels
94        let node1 = &self.nodes[node1_idx];
95        let node2 = &self.nodes[node2_idx];
96
97        let mut contract_pairs: Vec<(usize, usize)> = Vec::new();
98
99        for (i1, label1) in node1.leg_labels.iter().enumerate() {
100            for (i2, label2) in node2.leg_labels.iter().enumerate() {
101                if label1 == label2 && !label1.starts_with("open_") {
102                    assert_eq!(node1.leg_dims[i1], node2.leg_dims[i2], "Dimension mismatch");
103                    contract_pairs.push((i1, i2));
104                }
105            }
106        }
107
108        if contract_pairs.is_empty() {
109            // Outer product
110            return self.outer_product(id1, id2);
111        }
112
113        // Perform contraction
114        let result = contract_tensors(node1, node2, &contract_pairs);
115
116        // Remove old nodes and add new
117        self.nodes.retain(|n| n.id != id1 && n.id != id2);
118
119        let new_id = self.next_id;
120        self.next_id += 1;
121        self.nodes
122            .push(TensorNode::new(new_id, result.0, result.1, result.2));
123
124        Some(new_id)
125    }
126
127    /// Outer product of two nodes
128    fn outer_product(&mut self, id1: usize, id2: usize) -> Option<usize> {
129        let node1 = self.nodes.iter().find(|n| n.id == id1)?;
130        let node2 = self.nodes.iter().find(|n| n.id == id2)?;
131
132        let mut new_data = Vec::with_capacity(node1.size() * node2.size());
133        for &a in &node1.data {
134            for &b in &node2.data {
135                new_data.push(a * b);
136            }
137        }
138
139        let mut new_dims = node1.leg_dims.clone();
140        new_dims.extend(node2.leg_dims.iter());
141
142        let mut new_labels = node1.leg_labels.clone();
143        new_labels.extend(node2.leg_labels.iter().cloned());
144
145        self.nodes.retain(|n| n.id != id1 && n.id != id2);
146
147        let new_id = self.next_id;
148        self.next_id += 1;
149        self.nodes
150            .push(TensorNode::new(new_id, new_data, new_dims, new_labels));
151
152        Some(new_id)
153    }
154
155    /// Contract entire network to scalar (if possible)
156    pub fn contract_all(&mut self) -> Option<f64> {
157        while self.nodes.len() > 1 {
158            // Find a pair with matching labels
159            let mut found = None;
160            'outer: for i in 0..self.nodes.len() {
161                for j in i + 1..self.nodes.len() {
162                    for label in &self.nodes[i].leg_labels {
163                        if !label.starts_with("open_") && self.nodes[j].leg_labels.contains(label) {
164                            found = Some((self.nodes[i].id, self.nodes[j].id));
165                            break 'outer;
166                        }
167                    }
168                }
169            }
170
171            if let Some((id1, id2)) = found {
172                self.contract(id1, id2)?;
173            } else {
174                // No more contractions possible
175                break;
176            }
177        }
178
179        if self.nodes.len() == 1 && self.nodes[0].leg_dims.is_empty() {
180            Some(self.nodes[0].data[0])
181        } else {
182            None
183        }
184    }
185}
186
187impl Default for TensorNetwork {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193/// Contract two tensors on specified index pairs
194fn contract_tensors(
195    node1: &TensorNode,
196    node2: &TensorNode,
197    contract_pairs: &[(usize, usize)],
198) -> (Vec<f64>, Vec<usize>, Vec<String>) {
199    // Determine output shape and labels
200    let mut out_dims = Vec::new();
201    let mut out_labels = Vec::new();
202
203    let contracted1: Vec<usize> = contract_pairs.iter().map(|p| p.0).collect();
204    let contracted2: Vec<usize> = contract_pairs.iter().map(|p| p.1).collect();
205
206    for (i, (dim, label)) in node1
207        .leg_dims
208        .iter()
209        .zip(node1.leg_labels.iter())
210        .enumerate()
211    {
212        if !contracted1.contains(&i) {
213            out_dims.push(*dim);
214            out_labels.push(label.clone());
215        }
216    }
217
218    for (i, (dim, label)) in node2
219        .leg_dims
220        .iter()
221        .zip(node2.leg_labels.iter())
222        .enumerate()
223    {
224        if !contracted2.contains(&i) {
225            out_dims.push(*dim);
226            out_labels.push(label.clone());
227        }
228    }
229
230    let out_size: usize = if out_dims.is_empty() {
231        1
232    } else {
233        out_dims.iter().product()
234    };
235    let mut out_data = vec![0.0; out_size];
236
237    // Contract by enumeration
238    let size1 = node1.size();
239    let size2 = node2.size();
240
241    let strides1 = compute_strides(&node1.leg_dims);
242    let strides2 = compute_strides(&node2.leg_dims);
243    let out_strides = compute_strides(&out_dims);
244
245    // For each element of output
246    let mut out_indices = vec![0usize; out_dims.len()];
247    for out_flat in 0..out_size {
248        // Map to input indices
249        // Sum over contracted indices
250        let contract_sizes: Vec<usize> =
251            contract_pairs.iter().map(|p| node1.leg_dims[p.0]).collect();
252        let contract_total: usize = if contract_sizes.is_empty() {
253            1
254        } else {
255            contract_sizes.iter().product()
256        };
257
258        let mut sum = 0.0;
259
260        for contract_flat in 0..contract_total {
261            // Build indices for node1 and node2
262            let mut idx1 = vec![0usize; node1.num_legs()];
263            let mut idx2 = vec![0usize; node2.num_legs()];
264
265            // Set contracted indices
266            let mut cf = contract_flat;
267            for (pi, &(i1, i2)) in contract_pairs.iter().enumerate() {
268                let ci = cf % contract_sizes[pi];
269                cf /= contract_sizes[pi];
270                idx1[i1] = ci;
271                idx2[i2] = ci;
272            }
273
274            // Set free indices from output
275            let mut out_idx_copy = out_flat;
276            let mut free1_pos = 0;
277            let mut free2_pos = 0;
278
279            for i in 0..node1.num_legs() {
280                if !contracted1.contains(&i) {
281                    if free1_pos < out_dims.len() {
282                        idx1[i] = (out_idx_copy / out_strides.get(free1_pos).unwrap_or(&1))
283                            % node1.leg_dims[i];
284                    }
285                    free1_pos += 1;
286                }
287            }
288
289            for i in 0..node2.num_legs() {
290                if !contracted2.contains(&i) {
291                    let pos = (node1.num_legs() - contracted1.len()) + free2_pos;
292                    if pos < out_dims.len() {
293                        idx2[i] =
294                            (out_flat / out_strides.get(pos).unwrap_or(&1)) % node2.leg_dims[i];
295                    }
296                    free2_pos += 1;
297                }
298            }
299
300            // Compute linear indices
301            let lin1: usize = idx1.iter().zip(strides1.iter()).map(|(i, s)| i * s).sum();
302            let lin2: usize = idx2.iter().zip(strides2.iter()).map(|(i, s)| i * s).sum();
303
304            sum += node1.data[lin1.min(node1.data.len() - 1)]
305                * node2.data[lin2.min(node2.data.len() - 1)];
306        }
307
308        out_data[out_flat] = sum;
309    }
310
311    (out_data, out_dims, out_labels)
312}
313
/// Strides for a tensor of shape `dims` in which the last leg varies fastest: the last stride is
/// 1 and every earlier stride is the product of all later dimensions.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    // Walk the shape back to front, emitting the running product seen so far.
    let mut strides: Vec<usize> = dims
        .iter()
        .rev()
        .scan(1usize, |running, &extent| {
            let current = *running;
            *running *= extent;
            Some(current)
        })
        .collect();
    strides.reverse();
    strides
}
324
325/// Optimal contraction order finder
326#[derive(Debug, Clone)]
327pub struct NetworkContraction {
328    /// Estimated contraction cost
329    pub estimated_cost: f64,
330}
331
332impl NetworkContraction {
333    /// Find greedy contraction order (not optimal but fast)
334    pub fn greedy_order(network: &TensorNetwork) -> Vec<(usize, usize)> {
335        let mut order = Vec::new();
336        let mut remaining: Vec<usize> = network.nodes.iter().map(|n| n.id).collect();
337
338        while remaining.len() > 1 {
339            // Find pair with smallest contraction cost
340            let mut best_pair = None;
341            let mut best_cost = f64::INFINITY;
342
343            for i in 0..remaining.len() {
344                for j in i + 1..remaining.len() {
345                    let id1 = remaining[i];
346                    let id2 = remaining[j];
347
348                    if let (Some(n1), Some(n2)) = (network.get_node(id1), network.get_node(id2)) {
349                        let cost = estimate_contraction_cost(n1, n2);
350                        if cost < best_cost {
351                            best_cost = cost;
352                            best_pair = Some((i, j));
353                        }
354                    }
355                }
356            }
357
358            if let Some((i, j)) = best_pair {
359                let id1 = remaining[i];
360                let id2 = remaining[j];
361                order.push((id1, id2));
362
363                // Remove j first (larger index)
364                remaining.remove(j);
365                remaining.remove(i);
366                // In real implementation, we'd add the result node ID
367            } else {
368                break;
369            }
370        }
371
372        order
373    }
374}
375
376fn estimate_contraction_cost(n1: &TensorNode, n2: &TensorNode) -> f64 {
377    // Simple cost estimate: product of all dimension sizes
378    let size1: usize = n1.leg_dims.iter().product();
379    let size2: usize = n2.leg_dims.iter().product();
380
381    // Find contracted dimensions
382    let mut contracted_size = 1usize;
383    for (i1, label1) in n1.leg_labels.iter().enumerate() {
384        for (i2, label2) in n2.leg_labels.iter().enumerate() {
385            if label1 == label2 && !label1.starts_with("open_") {
386                contracted_size *= n1.leg_dims[i1];
387            }
388        }
389    }
390
391    // Cost ≈ output_size × contracted_size
392    (size1 * size2 / contracted_size.max(1)) as f64
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise comparison with the tolerance used throughout these tests.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-10, "expected {e}, got {a}");
        }
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            vec!["i".into(), "j".into()],
        );

        let id2 = network.add_node(
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
            vec!["j".into(), "k".into()],
        );

        assert_eq!(network.num_nodes(), 2);
        assert_ne!(id1, id2);
        assert_eq!(network.get_node(id1).unwrap().leg_labels, vec!["i", "j"]);
        assert_eq!(network.get_node(id2).unwrap().leg_labels, vec!["j", "k"]);
    }

    #[test]
    fn test_matrix_contraction() {
        let mut network = TensorNetwork::new();

        // A = [[1, 2], [3, 4]]
        let id1 = network.add_node(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            vec!["i".into(), "j".into()],
        );

        // B = [[1, 0], [0, 1]] (identity)
        let id2 = network.add_node(
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
            vec!["j".into(), "k".into()],
        );

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        // A * I = A, with the free legs of A then B
        assert_eq!(result.leg_dims, vec![2, 2]);
        assert_eq!(result.leg_labels, vec!["i", "k"]);
        assert_close(&result.data, &[1.0, 2.0, 3.0, 4.0]);

        // Both inputs were replaced by the single result
        assert_eq!(network.num_nodes(), 1);
        assert!(network.get_node(id1).is_none());
        assert!(network.get_node(id2).is_none());
    }

    #[test]
    fn test_general_matrix_product() {
        let mut network = TensorNetwork::new();

        // A = [[1, 2], [3, 4]]
        let id1 = network.add_node(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            vec!["i".into(), "j".into()],
        );

        // B = [[5, 6], [7, 8]]; non-symmetric so a swapped index would show up
        let id2 = network.add_node(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![2, 2],
            vec!["j".into(), "k".into()],
        );

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        // A * B = [[19, 22], [43, 50]]
        assert_close(&result.data, &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_vector_dot_product() {
        let mut network = TensorNetwork::new();

        // v1 = [1, 2, 3]
        let id1 = network.add_node(vec![1.0, 2.0, 3.0], vec![3], vec!["i".into()]);

        // v2 = [1, 1, 1]
        let id2 = network.add_node(vec![1.0, 1.0, 1.0], vec![3], vec!["i".into()]);

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        // Dot product = 1 + 2 + 3 = 6
        assert_eq!(result.data.len(), 1);
        assert!(result.leg_dims.is_empty());
        assert!((result.data[0] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_contract_all_to_scalar() {
        let mut network = TensorNetwork::new();

        network.add_node(vec![1.0, 2.0, 3.0], vec![3], vec!["i".into()]);
        network.add_node(vec![1.0, 1.0, 1.0], vec![3], vec!["i".into()]);

        let scalar = network.contract_all().unwrap();
        assert!((scalar - 6.0).abs() < 1e-10);
        assert_eq!(network.num_nodes(), 1);
    }
}
461}