// ruvector_math/tensor_networks/contraction.rs

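//! Tensor Network Contraction
//!
//! General tensor-network operations for quantum-inspired algorithms: building networks
//! of labeled tensors, contracting them pairwise or down to a scalar, and planning a
//! greedy contraction order.
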
use std::collections::HashMap;

/// One tensor in a [`TensorNetwork`]: a flat buffer of values plus, for every leg
/// (axis), its dimension and the label used to pair it with legs of other nodes.
///
/// The contraction routines in this module address `data` in row-major order
/// (the last leg varies fastest).
#[derive(Debug, Clone)]
pub struct TensorNode {
    /// Identifier of the node; [`TensorNetwork`] hands these out sequentially.
    pub id: usize,
    /// Tensor elements; `TensorNode::new` checks that the length is the product of `leg_dims`.
    pub data: Vec<f64>,
    /// Dimension of each leg.
    pub leg_dims: Vec<usize>,
    /// Label of each leg. Legs on different nodes that share a label are contracted
    /// together, except labels starting with `"open_"`.
    pub leg_labels: Vec<String>,
}

impl TensorNode {
    /// Build a node from its flat data, leg dimensions and leg labels.
    ///
    /// # Panics
    /// If `data.len()` is not the product of `leg_dims` (1 for a leg-less scalar), or if
    /// `leg_dims` and `leg_labels` have different lengths.
    pub fn new(
        id: usize,
        data: Vec<f64>,
        leg_dims: Vec<usize>,
        leg_labels: Vec<String>,
    ) -> Self {
        let element_count = leg_dims.iter().product::<usize>();
        assert_eq!(data.len(), element_count);
        assert_eq!(leg_dims.len(), leg_labels.len());

        TensorNode {
            id,
            data,
            leg_dims,
            leg_labels,
        }
    }

    /// Rank of the tensor, i.e. how many legs it has.
    pub fn num_legs(&self) -> usize {
        self.leg_dims.len()
    }

    /// Number of stored elements.
    pub fn size(&self) -> usize {
        self.data.len()
    }
}

41/// Tensor network for contraction operations
42#[derive(Debug, Clone)]
43pub struct TensorNetwork {
44    /// Nodes in the network
45    nodes: Vec<TensorNode>,
46    /// Next node ID
47    next_id: usize,
48}
49
50impl TensorNetwork {
51    /// Create empty network
52    pub fn new() -> Self {
53        Self { nodes: Vec::new(), next_id: 0 }
54    }
55
56    /// Add a tensor node
57    pub fn add_node(&mut self, data: Vec<f64>, leg_dims: Vec<usize>, leg_labels: Vec<String>) -> usize {
58        let id = self.next_id;
59        self.next_id += 1;
60        self.nodes.push(TensorNode::new(id, data, leg_dims, leg_labels));
61        id
62    }
63
64    /// Get node by ID
65    pub fn get_node(&self, id: usize) -> Option<&TensorNode> {
66        self.nodes.iter().find(|n| n.id == id)
67    }
68
69    /// Number of nodes
70    pub fn num_nodes(&self) -> usize {
71        self.nodes.len()
72    }
73
74    /// Contract two nodes on matching labels
75    pub fn contract(&mut self, id1: usize, id2: usize) -> Option<usize> {
76        let node1_idx = self.nodes.iter().position(|n| n.id == id1)?;
77        let node2_idx = self.nodes.iter().position(|n| n.id == id2)?;
78
79        // Find matching labels
80        let node1 = &self.nodes[node1_idx];
81        let node2 = &self.nodes[node2_idx];
82
83        let mut contract_pairs: Vec<(usize, usize)> = Vec::new();
84
85        for (i1, label1) in node1.leg_labels.iter().enumerate() {
86            for (i2, label2) in node2.leg_labels.iter().enumerate() {
87                if label1 == label2 && !label1.starts_with("open_") {
88                    assert_eq!(node1.leg_dims[i1], node2.leg_dims[i2], "Dimension mismatch");
89                    contract_pairs.push((i1, i2));
90                }
91            }
92        }
93
94        if contract_pairs.is_empty() {
95            // Outer product
96            return self.outer_product(id1, id2);
97        }
98
99        // Perform contraction
100        let result = contract_tensors(node1, node2, &contract_pairs);
101
102        // Remove old nodes and add new
103        self.nodes.retain(|n| n.id != id1 && n.id != id2);
104
105        let new_id = self.next_id;
106        self.next_id += 1;
107        self.nodes.push(TensorNode::new(new_id, result.0, result.1, result.2));
108
109        Some(new_id)
110    }
111
112    /// Outer product of two nodes
113    fn outer_product(&mut self, id1: usize, id2: usize) -> Option<usize> {
114        let node1 = self.nodes.iter().find(|n| n.id == id1)?;
115        let node2 = self.nodes.iter().find(|n| n.id == id2)?;
116
117        let mut new_data = Vec::with_capacity(node1.size() * node2.size());
118        for &a in &node1.data {
119            for &b in &node2.data {
120                new_data.push(a * b);
121            }
122        }
123
124        let mut new_dims = node1.leg_dims.clone();
125        new_dims.extend(node2.leg_dims.iter());
126
127        let mut new_labels = node1.leg_labels.clone();
128        new_labels.extend(node2.leg_labels.iter().cloned());
129
130        self.nodes.retain(|n| n.id != id1 && n.id != id2);
131
132        let new_id = self.next_id;
133        self.next_id += 1;
134        self.nodes.push(TensorNode::new(new_id, new_data, new_dims, new_labels));
135
136        Some(new_id)
137    }
138
139    /// Contract entire network to scalar (if possible)
140    pub fn contract_all(&mut self) -> Option<f64> {
141        while self.nodes.len() > 1 {
142            // Find a pair with matching labels
143            let mut found = None;
144            'outer: for i in 0..self.nodes.len() {
145                for j in i + 1..self.nodes.len() {
146                    for label in &self.nodes[i].leg_labels {
147                        if !label.starts_with("open_") && self.nodes[j].leg_labels.contains(label) {
148                            found = Some((self.nodes[i].id, self.nodes[j].id));
149                            break 'outer;
150                        }
151                    }
152                }
153            }
154
155            if let Some((id1, id2)) = found {
156                self.contract(id1, id2)?;
157            } else {
158                // No more contractions possible
159                break;
160            }
161        }
162
163        if self.nodes.len() == 1 && self.nodes[0].leg_dims.is_empty() {
164            Some(self.nodes[0].data[0])
165        } else {
166            None
167        }
168    }
169}
170
171impl Default for TensorNetwork {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177/// Contract two tensors on specified index pairs
178fn contract_tensors(
179    node1: &TensorNode,
180    node2: &TensorNode,
181    contract_pairs: &[(usize, usize)],
182) -> (Vec<f64>, Vec<usize>, Vec<String>) {
183    // Determine output shape and labels
184    let mut out_dims = Vec::new();
185    let mut out_labels = Vec::new();
186
187    let contracted1: Vec<usize> = contract_pairs.iter().map(|p| p.0).collect();
188    let contracted2: Vec<usize> = contract_pairs.iter().map(|p| p.1).collect();
189
190    for (i, (dim, label)) in node1.leg_dims.iter().zip(node1.leg_labels.iter()).enumerate() {
191        if !contracted1.contains(&i) {
192            out_dims.push(*dim);
193            out_labels.push(label.clone());
194        }
195    }
196
197    for (i, (dim, label)) in node2.leg_dims.iter().zip(node2.leg_labels.iter()).enumerate() {
198        if !contracted2.contains(&i) {
199            out_dims.push(*dim);
200            out_labels.push(label.clone());
201        }
202    }
203
204    let out_size: usize = if out_dims.is_empty() { 1 } else { out_dims.iter().product() };
205    let mut out_data = vec![0.0; out_size];
206
207    // Contract by enumeration
208    let size1 = node1.size();
209    let size2 = node2.size();
210
211    let strides1 = compute_strides(&node1.leg_dims);
212    let strides2 = compute_strides(&node2.leg_dims);
213    let out_strides = compute_strides(&out_dims);
214
215    // For each element of output
216    let mut out_indices = vec![0usize; out_dims.len()];
217    for out_flat in 0..out_size {
218        // Map to input indices
219        // Sum over contracted indices
220        let contract_sizes: Vec<usize> = contract_pairs.iter().map(|p| node1.leg_dims[p.0]).collect();
221        let contract_total: usize = if contract_sizes.is_empty() { 1 } else { contract_sizes.iter().product() };
222
223        let mut sum = 0.0;
224
225        for contract_flat in 0..contract_total {
226            // Build indices for node1 and node2
227            let mut idx1 = vec![0usize; node1.num_legs()];
228            let mut idx2 = vec![0usize; node2.num_legs()];
229
230            // Set contracted indices
231            let mut cf = contract_flat;
232            for (pi, &(i1, i2)) in contract_pairs.iter().enumerate() {
233                let ci = cf % contract_sizes[pi];
234                cf /= contract_sizes[pi];
235                idx1[i1] = ci;
236                idx2[i2] = ci;
237            }
238
239            // Set free indices from output
240            let mut out_idx_copy = out_flat;
241            let mut free1_pos = 0;
242            let mut free2_pos = 0;
243
244            for i in 0..node1.num_legs() {
245                if !contracted1.contains(&i) {
246                    if free1_pos < out_dims.len() {
247                        idx1[i] = (out_idx_copy / out_strides.get(free1_pos).unwrap_or(&1)) % node1.leg_dims[i];
248                    }
249                    free1_pos += 1;
250                }
251            }
252
253            for i in 0..node2.num_legs() {
254                if !contracted2.contains(&i) {
255                    let pos = (node1.num_legs() - contracted1.len()) + free2_pos;
256                    if pos < out_dims.len() {
257                        idx2[i] = (out_flat / out_strides.get(pos).unwrap_or(&1)) % node2.leg_dims[i];
258                    }
259                    free2_pos += 1;
260                }
261            }
262
263            // Compute linear indices
264            let lin1: usize = idx1.iter().zip(strides1.iter()).map(|(i, s)| i * s).sum();
265            let lin2: usize = idx2.iter().zip(strides2.iter()).map(|(i, s)| i * s).sum();
266
267            sum += node1.data[lin1.min(node1.data.len() - 1)] * node2.data[lin2.min(node2.data.len() - 1)];
268        }
269
270        out_data[out_flat] = sum;
271    }
272
273    (out_data, out_dims, out_labels)
274}
275
/// Row-major strides for a tensor of shape `dims`.
///
/// The last axis has stride 1, and every earlier axis's stride is the product of all
/// later dimensions. An empty `dims` yields an empty vector.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![0usize; dims.len()];
    let mut running = 1usize;
    // Walk from the innermost axis outwards, accumulating the block size.
    for (slot, &extent) in strides.iter_mut().zip(dims).rev() {
        *slot = running;
        running *= extent;
    }
    strides
}

287/// Optimal contraction order finder
288#[derive(Debug, Clone)]
289pub struct NetworkContraction {
290    /// Estimated contraction cost
291    pub estimated_cost: f64,
292}
293
294impl NetworkContraction {
295    /// Find greedy contraction order (not optimal but fast)
296    pub fn greedy_order(network: &TensorNetwork) -> Vec<(usize, usize)> {
297        let mut order = Vec::new();
298        let mut remaining: Vec<usize> = network.nodes.iter().map(|n| n.id).collect();
299
300        while remaining.len() > 1 {
301            // Find pair with smallest contraction cost
302            let mut best_pair = None;
303            let mut best_cost = f64::INFINITY;
304
305            for i in 0..remaining.len() {
306                for j in i + 1..remaining.len() {
307                    let id1 = remaining[i];
308                    let id2 = remaining[j];
309
310                    if let (Some(n1), Some(n2)) = (network.get_node(id1), network.get_node(id2)) {
311                        let cost = estimate_contraction_cost(n1, n2);
312                        if cost < best_cost {
313                            best_cost = cost;
314                            best_pair = Some((i, j));
315                        }
316                    }
317                }
318            }
319
320            if let Some((i, j)) = best_pair {
321                let id1 = remaining[i];
322                let id2 = remaining[j];
323                order.push((id1, id2));
324
325                // Remove j first (larger index)
326                remaining.remove(j);
327                remaining.remove(i);
328                // In real implementation, we'd add the result node ID
329            } else {
330                break;
331            }
332        }
333
334        order
335    }
336}
337
338fn estimate_contraction_cost(n1: &TensorNode, n2: &TensorNode) -> f64 {
339    // Simple cost estimate: product of all dimension sizes
340    let size1: usize = n1.leg_dims.iter().product();
341    let size2: usize = n2.leg_dims.iter().product();
342
343    // Find contracted dimensions
344    let mut contracted_size = 1usize;
345    for (i1, label1) in n1.leg_labels.iter().enumerate() {
346        for (i2, label2) in n2.leg_labels.iter().enumerate() {
347            if label1 == label2 && !label1.starts_with("open_") {
348                contracted_size *= n1.leg_dims[i1];
349            }
350        }
351    }
352
353    // Cost ≈ output_size × contracted_size
354    (size1 * size2 / contracted_size.max(1)) as f64
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience: owned leg labels from string literals.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut network = TensorNetwork::new();

        let id1 = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        let id2 = network.add_node(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2], labels(&["j", "k"]));

        assert_eq!(network.num_nodes(), 2);
        assert_ne!(id1, id2);
        assert_eq!(network.get_node(id1).unwrap().leg_labels, labels(&["i", "j"]));
        assert_eq!(network.get_node(id2).unwrap().size(), 4);
    }

    #[test]
    fn test_matrix_contraction() {
        let mut network = TensorNetwork::new();

        // A = [[1, 2], [3, 4]]
        let id1 = network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        // B = [[1, 0], [0, 1]] (identity)
        let id2 = network.add_node(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2], labels(&["j", "k"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        // A * I = A, with legs (i, k)
        assert_eq!(result.leg_dims, vec![2, 2]);
        assert_eq!(result.leg_labels, labels(&["i", "k"]));
        assert_eq!(result.data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(network.num_nodes(), 1);
    }

    #[test]
    fn test_rectangular_matrix_product() {
        let mut network = TensorNetwork::new();

        // (2x3) · (3x2)
        let a = network.add_node(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            labels(&["i", "j"]),
        );
        let b = network.add_node(
            vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
            vec![3, 2],
            labels(&["j", "k"]),
        );

        let result_id = network.contract(a, b).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 2]);
        assert_eq!(result.data, vec![58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_vector_dot_product() {
        let mut network = TensorNetwork::new();

        // v1 = [1, 2, 3]
        let id1 = network.add_node(vec![1.0, 2.0, 3.0], vec![3], labels(&["i"]));
        // v2 = [1, 1, 1]
        let id2 = network.add_node(vec![1.0, 1.0, 1.0], vec![3], labels(&["i"]));

        let result_id = network.contract(id1, id2).unwrap();
        let result = network.get_node(result_id).unwrap();

        // Dot product = 1 + 2 + 3 = 6
        assert_eq!(result.data.len(), 1);
        assert!((result.data[0] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_product_when_no_labels_shared() {
        let mut network = TensorNetwork::new();
        let a = network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        let b = network.add_node(vec![3.0, 4.0, 5.0], vec![3], labels(&["j"]));

        let result_id = network.contract(a, b).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 3]);
        assert_eq!(result.leg_labels, labels(&["i", "j"]));
        assert_eq!(result.data, vec![3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_open_labels_are_not_contracted() {
        let mut network = TensorNetwork::new();
        let a = network.add_node(vec![1.0, 2.0], vec![2], labels(&["open_x"]));
        let b = network.add_node(vec![3.0, 4.0], vec![2], labels(&["open_x"]));

        let result_id = network.contract(a, b).unwrap();
        let result = network.get_node(result_id).unwrap();

        assert_eq!(result.leg_dims, vec![2, 2]);
        assert_eq!(result.data, vec![3.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_contract_missing_node_returns_none() {
        let mut network = TensorNetwork::new();
        let id = network.add_node(vec![1.0], vec![1], labels(&["i"]));

        assert!(network.contract(id, id + 100).is_none());
        assert_eq!(network.num_nodes(), 1);
    }

    #[test]
    fn test_contract_all_closed_chain() {
        let mut network = TensorNetwork::new();
        // v1^T · M · v2 with v1 = [1, 2], M = [[1, 2], [3, 4]], v2 = [1, 1] -> 17
        network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));
        network.add_node(vec![1.0, 1.0], vec![2], labels(&["j"]));

        assert_eq!(network.contract_all(), Some(17.0));
        assert_eq!(network.num_nodes(), 1);
    }

    #[test]
    fn test_contract_all_with_dangling_leg_is_none() {
        let mut network = TensorNetwork::new();
        network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        network.add_node(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], labels(&["i", "j"]));

        assert_eq!(network.contract_all(), None);
    }

    #[test]
    fn test_greedy_order_two_nodes() {
        let mut network = TensorNetwork::new();
        let a = network.add_node(vec![1.0, 2.0], vec![2], labels(&["i"]));
        let b = network.add_node(vec![3.0, 4.0], vec![2], labels(&["i"]));

        assert_eq!(NetworkContraction::greedy_order(&network), vec![(a, b)]);
    }
}
431}