Skip to main content

ruqu_core/
decomposition.rs

//! Hybrid classical-quantum circuit decomposition engine.
//!
//! Performs structural decomposition of quantum circuits across simulation
//! paradigms using graph-based partitioning. Most quantum simulation systems
//! commit to a single backend for an entire circuit. This engine partitions
//! a circuit into segments that are independently routed to the optimal
//! backend (StateVector, Stabilizer, CliffordT, or TensorNetwork), yielding
//! significant performance gains for heterogeneous circuits.
//!
//! # Decomposition strategies
//!
//! | Strategy | Description |
//! |----------|-------------|
//! | `Temporal` | Split by time slices (barrier gates or natural idle boundaries) |
//! | `Spatial` | Split by qubit subsets (connected components or min-cut partitioning) |
//! | `Hybrid` | Both temporal and spatial decomposition applied in sequence |
//! | `None` | No decomposition; the whole circuit is a single segment |
//!
//! # Example
//!
//! ```
//! use ruqu_core::circuit::QuantumCircuit;
//! use ruqu_core::decomposition::decompose;
//!
//! // Two independent Bell pairs on disjoint qubits.
//! let mut circ = QuantumCircuit::new(4);
//! circ.h(0).cnot(0, 1);   // Bell pair on qubits 0-1
//! circ.h(2).cnot(2, 3);   // Bell pair on qubits 2-3
//!
//! let partition = decompose(&circ, 25);
//! assert_eq!(partition.segments.len(), 2);
//! ```
33
34use std::collections::{HashMap, HashSet, VecDeque};
35
36use crate::backend::BackendType;
37use crate::circuit::QuantumCircuit;
38use crate::gate::Gate;
39use crate::stabilizer::StabilizerState;
40
41// ---------------------------------------------------------------------------
42// Public data structures
43// ---------------------------------------------------------------------------
44
45/// The result of decomposing a circuit into independently-simulable segments.
46#[derive(Debug, Clone)]
47pub struct CircuitPartition {
48    /// Ordered list of circuit segments to simulate.
49    pub segments: Vec<CircuitSegment>,
50    /// Total qubit count of the original circuit.
51    pub total_qubits: u32,
52    /// Strategy that was used for decomposition.
53    pub strategy: DecompositionStrategy,
54}
55
56/// A single segment of a decomposed circuit, ready for backend dispatch.
57#[derive(Debug, Clone)]
58pub struct CircuitSegment {
59    /// The sub-circuit to simulate.
60    pub circuit: QuantumCircuit,
61    /// The backend selected for this segment.
62    pub backend: BackendType,
63    /// Inclusive range of original qubit indices covered by this segment.
64    pub qubit_range: (u32, u32),
65    /// Start and end gate indices in the original circuit (end is exclusive).
66    pub gate_range: (usize, usize),
67    /// Estimated simulation cost of this segment.
68    pub estimated_cost: SegmentCost,
69}
70
/// Estimated resource consumption for simulating a circuit segment.
///
/// Implements `PartialEq`/`Eq` so that two estimates can be compared
/// directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentCost {
    /// Estimated memory consumption in bytes.
    pub memory_bytes: u64,
    /// Estimated floating-point operations.
    pub estimated_flops: u64,
    /// Number of qubits in this segment.
    pub qubit_count: u32,
}
81
/// Strategy used for circuit decomposition.
///
/// Note that the `None` variant shares its name with `Option::None`; it must
/// always be written as `DecompositionStrategy::None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DecompositionStrategy {
    /// Split by time slices (gate layers / barriers).
    Temporal,
    /// Split by qubit subsets (connected components / partitioning).
    Spatial,
    /// Both temporal and spatial decomposition applied.
    Hybrid,
    /// No decomposition; the circuit is a single segment.
    None,
}
94
95// ---------------------------------------------------------------------------
96// Interaction graph
97// ---------------------------------------------------------------------------
98
/// Qubit interaction graph extracted from a quantum circuit.
///
/// Nodes are qubits. Edges are two-qubit gates, weighted by the number of
/// such gates between each pair.
///
/// The functions in this module index `adjacency` (and matrices sized by
/// `num_qubits`) with the qubit indices stored in `edges` and `adjacency`,
/// so every stored index must be smaller than `num_qubits`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InteractionGraph {
    /// Number of qubits (nodes) in the graph.
    pub num_qubits: u32,
    /// Edges as `(qubit_a, qubit_b, gate_count)`.
    pub edges: Vec<(u32, u32, usize)>,
    /// Adjacency list: `adjacency[q]` contains the neighbours of qubit `q`.
    pub adjacency: Vec<Vec<u32>>,
}
112
113/// Build the qubit interaction graph for a circuit.
114///
115/// Every two-qubit gate contributes an edge (or increments the weight of an
116/// existing edge) between the two qubits it acts on.
117pub fn build_interaction_graph(circuit: &QuantumCircuit) -> InteractionGraph {
118    let n = circuit.num_qubits();
119    let mut edge_counts: HashMap<(u32, u32), usize> = HashMap::new();
120
121    for gate in circuit.gates() {
122        let qubits = gate.qubits();
123        if qubits.len() == 2 {
124            let (a, b) = if qubits[0] <= qubits[1] {
125                (qubits[0], qubits[1])
126            } else {
127                (qubits[1], qubits[0])
128            };
129            *edge_counts.entry((a, b)).or_insert(0) += 1;
130        }
131    }
132
133    let mut adjacency: Vec<Vec<u32>> = vec![Vec::new(); n as usize];
134    let mut edges: Vec<(u32, u32, usize)> = Vec::with_capacity(edge_counts.len());
135
136    for (&(a, b), &count) in &edge_counts {
137        edges.push((a, b, count));
138        if !adjacency[a as usize].contains(&b) {
139            adjacency[a as usize].push(b);
140        }
141        if !adjacency[b as usize].contains(&a) {
142            adjacency[b as usize].push(a);
143        }
144    }
145
146    // Sort adjacency lists for deterministic traversal.
147    for adj in &mut adjacency {
148        adj.sort_unstable();
149    }
150
151    InteractionGraph {
152        num_qubits: n,
153        edges,
154        adjacency,
155    }
156}
157
158// ---------------------------------------------------------------------------
159// Connected components (BFS)
160// ---------------------------------------------------------------------------
161
162/// Find connected components of the qubit interaction graph using BFS.
163///
164/// Returns a list of components, each being a sorted list of qubit indices.
165/// Isolated qubits (those with no two-qubit gate interactions) are each
166/// returned as their own singleton component.
167pub fn find_connected_components(graph: &InteractionGraph) -> Vec<Vec<u32>> {
168    let n = graph.num_qubits as usize;
169    let mut visited = vec![false; n];
170    let mut components: Vec<Vec<u32>> = Vec::new();
171
172    for start in 0..n {
173        if visited[start] {
174            continue;
175        }
176        visited[start] = true;
177        let mut component = vec![start as u32];
178        let mut queue = VecDeque::new();
179        queue.push_back(start as u32);
180
181        while let Some(node) = queue.pop_front() {
182            for &neighbor in &graph.adjacency[node as usize] {
183                if !visited[neighbor as usize] {
184                    visited[neighbor as usize] = true;
185                    component.push(neighbor);
186                    queue.push_back(neighbor);
187                }
188            }
189        }
190
191        component.sort_unstable();
192        components.push(component);
193    }
194
195    components
196}
197
198// ---------------------------------------------------------------------------
199// Temporal decomposition
200// ---------------------------------------------------------------------------
201
202/// Split a circuit at `Barrier` gates or at natural breakpoints where no
203/// qubit is active across the boundary.
204///
205/// A natural breakpoint occurs when all qubits that have been touched in the
206/// current slice have been measured or reset, making them logically idle.
207///
208/// Returns a list of sub-circuits. Each sub-circuit preserves the original
209/// qubit count so that qubit indices remain valid.
210pub fn temporal_decomposition(circuit: &QuantumCircuit) -> Vec<QuantumCircuit> {
211    let gates = circuit.gates();
212    if gates.is_empty() {
213        return vec![QuantumCircuit::new(circuit.num_qubits())];
214    }
215
216    let n = circuit.num_qubits();
217    let mut slices: Vec<QuantumCircuit> = Vec::new();
218    let mut current = QuantumCircuit::new(n);
219    let mut current_has_gates = false;
220
221    // Track which qubits have been used (touched) in the current slice
222    // and which of those have been subsequently measured/reset.
223    let mut active_qubits: HashSet<u32> = HashSet::new();
224    let mut measured_qubits: HashSet<u32> = HashSet::new();
225
226    for gate in gates {
227        match gate {
228            Gate::Barrier => {
229                // Barrier always forces a slice boundary.
230                if current_has_gates {
231                    slices.push(current);
232                    current = QuantumCircuit::new(n);
233                    current_has_gates = false;
234                    active_qubits.clear();
235                    measured_qubits.clear();
236                }
237            }
238            _ => {
239                let qubits = gate.qubits();
240
241                // Before adding this gate, check if we have a natural breakpoint:
242                // All previously-active qubits have been measured/reset, and this
243                // gate touches at least one qubit not yet in the active set.
244                if current_has_gates
245                    && !active_qubits.is_empty()
246                    && active_qubits.iter().all(|q| measured_qubits.contains(q))
247                {
248                    // All active qubits are measured/reset -- natural boundary.
249                    slices.push(current);
250                    current = QuantumCircuit::new(n);
251                    active_qubits.clear();
252                    measured_qubits.clear();
253                }
254
255                // Track measurement/reset operations.
256                match gate {
257                    Gate::Measure(q) => {
258                        measured_qubits.insert(*q);
259                    }
260                    Gate::Reset(q) => {
261                        measured_qubits.insert(*q);
262                    }
263                    _ => {}
264                }
265
266                // Mark touched qubits as active.
267                for &q in &qubits {
268                    active_qubits.insert(q);
269                }
270
271                current.add_gate(gate.clone());
272                current_has_gates = true;
273            }
274        }
275    }
276
277    // Push the final slice if it has any gates.
278    if current_has_gates {
279        slices.push(current);
280    }
281
282    // Guarantee at least one circuit is returned.
283    if slices.is_empty() {
284        slices.push(QuantumCircuit::new(n));
285    }
286
287    slices
288}
289
290// ---------------------------------------------------------------------------
291// Stoer-Wagner minimum cut
292// ---------------------------------------------------------------------------
293
/// Result of a Stoer-Wagner minimum cut computation.
///
/// Implements `PartialEq`/`Eq` so that cuts can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MinCutResult {
    /// The minimum cut value (sum of edge weights crossing the cut).
    pub cut_value: usize,
    /// One side of the partition (qubit indices).
    pub partition_a: Vec<u32>,
    /// Other side of the partition.
    pub partition_b: Vec<u32>,
}
304
305/// Compute the minimum cut of an interaction graph using Stoer-Wagner.
306///
307/// Time complexity: O(V * E + V^2 * log V) which is O(V^3) for dense graphs.
308/// This is optimal for finding a global minimum cut without specifying s and t.
309///
310/// Returns `None` if the graph has 0 or 1 nodes.
311pub fn stoer_wagner_mincut(graph: &InteractionGraph) -> Option<MinCutResult> {
312    let n = graph.num_qubits as usize;
313    if n <= 1 {
314        return None;
315    }
316
317    // Build a weighted adjacency matrix.
318    let mut adj = vec![vec![0usize; n]; n];
319    for &(a, b, w) in &graph.edges {
320        let (a, b) = (a as usize, b as usize);
321        adj[a][b] += w;
322        adj[b][a] += w;
323    }
324
325    // Track which original vertices are merged into each super-vertex.
326    let mut merged: Vec<Vec<u32>> = (0..n).map(|i| vec![i as u32]).collect();
327    let mut active: Vec<bool> = vec![true; n];
328
329    let mut best_cut_value = usize::MAX;
330    let mut best_partition: Vec<u32> = Vec::new();
331
332    for _ in 0..(n - 1) {
333        // Stoer-Wagner phase: find the most tightly connected vertex ordering.
334        let active_nodes: Vec<usize> = (0..n).filter(|&i| active[i]).collect();
335        if active_nodes.len() < 2 {
336            break;
337        }
338
339        let mut in_a = vec![false; n];
340        let mut weight_to_a = vec![0usize; n];
341
342        // Start with the first active node.
343        let start = active_nodes[0];
344        in_a[start] = true;
345
346        // Update weights for neighbors of start.
347        for &node in &active_nodes {
348            if node != start {
349                weight_to_a[node] = adj[start][node];
350            }
351        }
352
353        let mut prev = start;
354        let mut last = start;
355
356        for _ in 1..active_nodes.len() {
357            // Find the most tightly connected vertex not yet in A.
358            let next = active_nodes
359                .iter()
360                .filter(|&&v| !in_a[v])
361                .max_by_key(|&&v| weight_to_a[v])
362                .copied()
363                .unwrap();
364
365            prev = last;
366            last = next;
367            in_a[next] = true;
368
369            // Update weights.
370            for &node in &active_nodes {
371                if !in_a[node] {
372                    weight_to_a[node] += adj[next][node];
373                }
374            }
375        }
376
377        // The cut-of-the-phase is the weight of last vertex added.
378        let cut_of_phase = weight_to_a[last];
379
380        if cut_of_phase < best_cut_value {
381            best_cut_value = cut_of_phase;
382            best_partition = merged[last].clone();
383        }
384
385        // Merge last into prev.
386        for &node in &active_nodes {
387            if node != last && node != prev {
388                adj[prev][node] += adj[last][node];
389                adj[node][prev] += adj[node][last];
390            }
391        }
392        active[last] = false;
393        let last_merged = std::mem::take(&mut merged[last]);
394        merged[prev].extend(last_merged);
395    }
396
397    let partition_a_set: HashSet<u32> = best_partition.iter().copied().collect();
398    let mut partition_a: Vec<u32> = best_partition;
399    partition_a.sort_unstable();
400    let mut partition_b: Vec<u32> = (0..n as u32)
401        .filter(|q| !partition_a_set.contains(q))
402        .collect();
403    partition_b.sort_unstable();
404
405    Some(MinCutResult {
406        cut_value: best_cut_value,
407        partition_a,
408        partition_b,
409    })
410}
411
412/// Spatial decomposition using Stoer-Wagner minimum cut.
413///
414/// Recursively bisects the circuit along minimum cuts until all segments
415/// have at most `max_qubits` qubits. Produces better partitions than the
416/// greedy approach by minimizing the number of cross-partition entangling
417/// gates.
418pub fn spatial_decomposition_mincut(
419    circuit: &QuantumCircuit,
420    graph: &InteractionGraph,
421    max_qubits: u32,
422) -> Vec<(Vec<u32>, QuantumCircuit)> {
423    let n = graph.num_qubits;
424    if n == 0 || max_qubits == 0 {
425        return Vec::new();
426    }
427    if n <= max_qubits {
428        let all_qubits: Vec<u32> = (0..n).collect();
429        return vec![(all_qubits, circuit.clone())];
430    }
431
432    // Recursively bisect using Stoer-Wagner.
433    let mut result = Vec::new();
434    recursive_mincut_partition(circuit, graph, max_qubits, &mut result);
435    result
436}
437
438/// Recursively partition using min-cut bisection.
439fn recursive_mincut_partition(
440    circuit: &QuantumCircuit,
441    graph: &InteractionGraph,
442    max_qubits: u32,
443    result: &mut Vec<(Vec<u32>, QuantumCircuit)>,
444) {
445    let n = graph.num_qubits;
446    if n <= max_qubits {
447        let all_qubits: Vec<u32> = (0..n).collect();
448        result.push((all_qubits, circuit.clone()));
449        return;
450    }
451
452    match stoer_wagner_mincut(graph) {
453        Some(cut) => {
454            // Extract subcircuits for each partition.
455            let set_a: HashSet<u32> = cut.partition_a.iter().copied().collect();
456            let set_b: HashSet<u32> = cut.partition_b.iter().copied().collect();
457
458            let circ_a = extract_component_circuit(circuit, &set_a);
459            let circ_b = extract_component_circuit(circuit, &set_b);
460
461            let graph_a = build_interaction_graph(&circ_a);
462            let graph_b = build_interaction_graph(&circ_b);
463
464            // Recurse on each half.
465            if cut.partition_a.len() as u32 > max_qubits {
466                recursive_mincut_partition(&circ_a, &graph_a, max_qubits, result);
467            } else {
468                result.push((cut.partition_a, circ_a));
469            }
470
471            if cut.partition_b.len() as u32 > max_qubits {
472                recursive_mincut_partition(&circ_b, &graph_b, max_qubits, result);
473            } else {
474                result.push((cut.partition_b, circ_b));
475            }
476        }
477        None => {
478            // Cannot partition further.
479            let all_qubits: Vec<u32> = (0..n).collect();
480            result.push((all_qubits, circuit.clone()));
481        }
482    }
483}
484
485// ---------------------------------------------------------------------------
486// Spatial decomposition (greedy heuristic)
487// ---------------------------------------------------------------------------
488
489/// Partition qubits into groups of at most `max_qubits` using a greedy
490/// min-cut heuristic, then extract subcircuits for each group.
491///
492/// Algorithm:
493/// 1. Pick the highest-degree unassigned qubit as a seed.
494/// 2. Greedily add adjacent qubits (preferring those with more edges into
495///    the current group) until the group reaches `max_qubits` or no more
496///    connected qubits remain.
497/// 3. Repeat until all qubits in the interaction graph are assigned.
498/// 4. For each group, extract the gates that operate exclusively within
499///    the group. Cross-group gates (whose qubits span multiple groups)
500///    are included in the group that contains the majority of their qubits,
501///    with the remote qubit added to the subcircuit.
502///
503/// Returns `(qubit_group, subcircuit)` pairs.
504pub fn spatial_decomposition(
505    circuit: &QuantumCircuit,
506    graph: &InteractionGraph,
507    max_qubits: u32,
508) -> Vec<(Vec<u32>, QuantumCircuit)> {
509    let n = graph.num_qubits;
510    if n == 0 || max_qubits == 0 {
511        return Vec::new();
512    }
513
514    // If the circuit fits within max_qubits, return it as a single group.
515    if n <= max_qubits {
516        let all_qubits: Vec<u32> = (0..n).collect();
517        return vec![(all_qubits, circuit.clone())];
518    }
519
520    // Compute degree for each qubit.
521    let mut degree: Vec<usize> = vec![0; n as usize];
522    for &(a, b, count) in &graph.edges {
523        degree[a as usize] += count;
524        degree[b as usize] += count;
525    }
526
527    let mut assigned = vec![false; n as usize];
528    let mut groups: Vec<Vec<u32>> = Vec::new();
529
530    while assigned.iter().any(|&a| !a) {
531        // Pick the highest-degree unassigned qubit as seed.
532        let seed = (0..n as usize)
533            .filter(|&q| !assigned[q])
534            .max_by_key(|&q| degree[q])
535            .unwrap() as u32;
536
537        let mut group = vec![seed];
538        assigned[seed as usize] = true;
539
540        // Greedily expand the group.
541        while (group.len() as u32) < max_qubits {
542            // Find the unassigned neighbor with the most connections into group.
543            let mut best_candidate: Option<u32> = Option::None;
544            let mut best_score: usize = 0;
545
546            for &member in &group {
547                for &neighbor in &graph.adjacency[member as usize] {
548                    if assigned[neighbor as usize] {
549                        continue;
550                    }
551                    // Score = number of edges from this neighbor into group members.
552                    let score: usize = graph
553                        .adjacency[neighbor as usize]
554                        .iter()
555                        .filter(|&&adj| group.contains(&adj))
556                        .count();
557                    if score > best_score
558                        || (score == best_score
559                            && best_candidate.map_or(true, |bc| neighbor < bc))
560                    {
561                        best_score = score;
562                        best_candidate = Some(neighbor);
563                    }
564                }
565            }
566
567            match best_candidate {
568                Some(candidate) => {
569                    assigned[candidate as usize] = true;
570                    group.push(candidate);
571                }
572                Option::None => break, // No more connected unassigned neighbors.
573            }
574        }
575
576        group.sort_unstable();
577        groups.push(group);
578    }
579
580    // For each group, build a subcircuit with remapped qubit indices.
581    let mut result: Vec<(Vec<u32>, QuantumCircuit)> = Vec::new();
582
583    // Build a lookup: original qubit -> group index.
584    let mut qubit_to_group: Vec<usize> = vec![0; n as usize];
585    for (gi, group) in groups.iter().enumerate() {
586        for &q in group {
587            qubit_to_group[q as usize] = gi;
588        }
589    }
590
591    for group in &groups {
592        let group_set: HashSet<u32> = group.iter().copied().collect();
593
594        // Build the qubit remapping: original index -> local index.
595        // We may need to include extra qubits for cross-group gates.
596        let mut local_qubits: Vec<u32> = group.clone();
597
598        // First pass: identify any extra qubits needed for cross-group gates
599        // that have at least one qubit in this group.
600        for gate in circuit.gates() {
601            let gate_qubits = gate.qubits();
602            if gate_qubits.is_empty() {
603                continue;
604            }
605            let in_group = gate_qubits.iter().filter(|q| group_set.contains(q)).count();
606            let out_group = gate_qubits.len() - in_group;
607            if in_group > 0 && out_group > 0 {
608                // This is a cross-group gate. If the majority of qubits are in
609                // this group, include the remote qubits.
610                if in_group >= out_group {
611                    for &q in &gate_qubits {
612                        if !local_qubits.contains(&q) {
613                            local_qubits.push(q);
614                        }
615                    }
616                }
617            }
618        }
619
620        local_qubits.sort_unstable();
621        let num_local = local_qubits.len() as u32;
622        let remap: HashMap<u32, u32> = local_qubits
623            .iter()
624            .enumerate()
625            .map(|(i, &q)| (q, i as u32))
626            .collect();
627
628        let mut sub_circuit = QuantumCircuit::new(num_local);
629
630        // Second pass: add gates that belong to this group.
631        for gate in circuit.gates() {
632            let gate_qubits = gate.qubits();
633
634            // Barrier: include in every sub-circuit.
635            if matches!(gate, Gate::Barrier) {
636                sub_circuit.add_gate(Gate::Barrier);
637                continue;
638            }
639
640            if gate_qubits.is_empty() {
641                continue;
642            }
643
644            let in_group = gate_qubits.iter().filter(|q| group_set.contains(q)).count();
645            if in_group == 0 {
646                continue; // Gate does not touch this group at all.
647            }
648
649            let out_group = gate_qubits.len() - in_group;
650            if out_group > 0 && in_group < out_group {
651                continue; // Gate is majority in another group.
652            }
653
654            // All qubits must be in our local remap.
655            if gate_qubits.iter().all(|q| remap.contains_key(q)) {
656                let remapped = remap_gate(gate, &remap);
657                sub_circuit.add_gate(remapped);
658            }
659        }
660
661        result.push((group.clone(), sub_circuit));
662    }
663
664    result
665}
666
667/// Remap qubit indices in a gate according to the given mapping.
668fn remap_gate(gate: &Gate, remap: &HashMap<u32, u32>) -> Gate {
669    match gate {
670        Gate::H(q) => Gate::H(remap[q]),
671        Gate::X(q) => Gate::X(remap[q]),
672        Gate::Y(q) => Gate::Y(remap[q]),
673        Gate::Z(q) => Gate::Z(remap[q]),
674        Gate::S(q) => Gate::S(remap[q]),
675        Gate::Sdg(q) => Gate::Sdg(remap[q]),
676        Gate::T(q) => Gate::T(remap[q]),
677        Gate::Tdg(q) => Gate::Tdg(remap[q]),
678        Gate::Rx(q, a) => Gate::Rx(remap[q], *a),
679        Gate::Ry(q, a) => Gate::Ry(remap[q], *a),
680        Gate::Rz(q, a) => Gate::Rz(remap[q], *a),
681        Gate::Phase(q, a) => Gate::Phase(remap[q], *a),
682        Gate::CNOT(c, t) => Gate::CNOT(remap[c], remap[t]),
683        Gate::CZ(a, b) => Gate::CZ(remap[a], remap[b]),
684        Gate::SWAP(a, b) => Gate::SWAP(remap[a], remap[b]),
685        Gate::Rzz(a, b, angle) => Gate::Rzz(remap[a], remap[b], *angle),
686        Gate::Measure(q) => Gate::Measure(remap[q]),
687        Gate::Reset(q) => Gate::Reset(remap[q]),
688        Gate::Barrier => Gate::Barrier,
689        Gate::Unitary1Q(q, m) => Gate::Unitary1Q(remap[q], *m),
690    }
691}
692
693// ---------------------------------------------------------------------------
694// Backend classification
695// ---------------------------------------------------------------------------
696
697/// Determine the best backend for a circuit segment based on its gate composition.
698///
699/// Decision rules:
700/// 1. If all gates are Clifford (or non-unitary) -> `Stabilizer`
701/// 2. If `num_qubits <= 25` -> `StateVector`
702/// 3. If `num_qubits > 25` and T-count <= 40 -> `CliffordT`
703/// 4. If `num_qubits > 25` and T-count > 40 -> `TensorNetwork`
704/// 5. Otherwise -> `StateVector`
705pub fn classify_segment(segment: &QuantumCircuit) -> BackendType {
706    let mut has_non_clifford = false;
707    let mut t_count: usize = 0;
708
709    for gate in segment.gates() {
710        if gate.is_non_unitary() {
711            continue;
712        }
713        if !StabilizerState::is_clifford_gate(gate) {
714            has_non_clifford = true;
715            t_count += 1;
716        }
717    }
718
719    if !has_non_clifford {
720        return BackendType::Stabilizer;
721    }
722
723    if segment.num_qubits() <= 25 {
724        return BackendType::StateVector;
725    }
726
727    // Moderate T-count on large circuits -> CliffordT (Bravyi-Gosset).
728    // 2^t stabilizer terms; practical up to ~40 T-gates.
729    if t_count <= 40 {
730        return BackendType::CliffordT;
731    }
732
733    // High T-count with > 25 qubits -> TensorNetwork
734    BackendType::TensorNetwork
735}
736
737// ---------------------------------------------------------------------------
738// Cost estimation
739// ---------------------------------------------------------------------------
740
741/// Estimate the simulation cost of a circuit segment on a given backend.
742///
743/// The estimates are order-of-magnitude correct and intended for comparing
744/// relative costs between decomposition options, not for precise prediction.
745pub fn estimate_segment_cost(segment: &QuantumCircuit, backend: BackendType) -> SegmentCost {
746    let n = segment.num_qubits();
747    let gate_count = segment.gate_count() as u64;
748
749    match backend {
750        BackendType::StateVector => {
751            // Memory: 2^n complex amplitudes * 16 bytes each.
752            let state_size = if n <= 63 { 1u64 << n } else { u64::MAX / 16 };
753            let memory_bytes = state_size.saturating_mul(16);
754            // FLOPs: each gate touches O(2^n) amplitudes with a few ops each.
755            // Single-qubit: ~4 * 2^(n-1) FLOPs; two-qubit: ~8 * 2^(n-2).
756            // Simplified to 8 * 2^n per gate.
757            let flops_per_gate = if n <= 60 {
758                8u64.saturating_mul(1u64 << n)
759            } else {
760                u64::MAX / gate_count.max(1)
761            };
762            let estimated_flops = gate_count.saturating_mul(flops_per_gate);
763            SegmentCost {
764                memory_bytes,
765                estimated_flops,
766                qubit_count: n,
767            }
768        }
769        BackendType::Stabilizer => {
770            // Memory: tableau of 2n rows x (2n+1) bits, stored as bools.
771            let tableau_size = 2 * (n as u64) * (2 * (n as u64) + 1);
772            let memory_bytes = tableau_size; // 1 byte per bool in practice
773            // FLOPs: O(n^2) per gate (row operations over 2n rows of width 2n+1).
774            let flops_per_gate = 4 * (n as u64) * (n as u64);
775            let estimated_flops = gate_count.saturating_mul(flops_per_gate);
776            SegmentCost {
777                memory_bytes,
778                estimated_flops,
779                qubit_count: n,
780            }
781        }
782        BackendType::TensorNetwork => {
783            // Memory: n tensors, each of dimension up to chi^2 * 4 (bond dim).
784            // Default chi ~ 64 for moderate entanglement.
785            let chi: u64 = 64;
786            let tensor_bytes = (n as u64) * chi * chi * 16; // complex entries
787            let memory_bytes = tensor_bytes;
788            // FLOPs: each gate requires SVD truncation ~ O(chi^3).
789            let flops_per_gate = chi * chi * chi;
790            let estimated_flops = gate_count.saturating_mul(flops_per_gate);
791            SegmentCost {
792                memory_bytes,
793                estimated_flops,
794                qubit_count: n,
795            }
796        }
797        BackendType::CliffordT => {
798            // Memory: 2^t stabiliser tableaux, each n^2 / 4 bytes.
799            let analysis = crate::backend::analyze_circuit(segment);
800            let t = analysis.non_clifford_gates as u32;
801            let terms: u64 = 1u64.checked_shl(t).unwrap_or(u64::MAX);
802            let tableau_bytes = (n as u64).saturating_mul(n as u64) / 4;
803            let memory_bytes = terms.saturating_mul(tableau_bytes).max(1);
804            // FLOPs: each of 2^t terms processes every gate at O(n^2).
805            let flops_per_gate = 4 * (n as u64) * (n as u64);
806            let estimated_flops = terms
807                .saturating_mul(gate_count)
808                .saturating_mul(flops_per_gate);
809            SegmentCost {
810                memory_bytes,
811                estimated_flops,
812                qubit_count: n,
813            }
814        }
815        BackendType::Auto => {
816            // For Auto, classify first, then estimate with the resolved backend.
817            let resolved = classify_segment(segment);
818            estimate_segment_cost(segment, resolved)
819        }
820    }
821}
822
823// ---------------------------------------------------------------------------
824// Result stitching
825// ---------------------------------------------------------------------------
826
/// Combine per-segment measurement distributions into one joint distribution.
///
/// The segments are assumed to be statistically independent, so the
/// probability of a combined bitstring is the product of the probabilities
/// of its per-segment pieces:
///
/// ```text
/// P(combined) = P(segment_0) * P(segment_1) * ...
/// ```
///
/// # Input layout
///
/// `partitions` is a flat list of `(bitstring, probability)` pairs. Segment
/// boundaries are inferred from the data: a new segment starts whenever the
/// bitstring length differs from that of the previous entry. Consequently:
///
/// * entries of one segment must be contiguous and share a single length;
/// * two adjacent segments whose bitstrings have the *same* length cannot be
///   told apart and are treated as one segment.
///
/// # Returns
///
/// A map from each concatenated bitstring (segment pieces in input order) to
/// its joint probability. Identical combined bitstrings have their
/// probabilities summed. An empty input yields an empty map.
pub fn stitch_results(
    partitions: &[(Vec<bool>, f64)],
) -> HashMap<Vec<bool>, f64> {
    if partitions.is_empty() {
        return HashMap::new();
    }

    // Carve the flat list into runs of consecutive entries that share a
    // bitstring length; each run is taken to be one segment's distribution.
    let mut runs: Vec<&[(Vec<bool>, f64)]> = Vec::new();
    let mut run_start = 0usize;
    for idx in 1..partitions.len() {
        if partitions[idx].0.len() != partitions[run_start].0.len() {
            runs.push(&partitions[run_start..idx]);
            run_start = idx;
        }
    }
    runs.push(&partitions[run_start..]);

    // Grow the Cartesian product one segment at a time, starting from the
    // neutral element (empty bitstring, probability 1).
    let mut joint: Vec<(Vec<bool>, f64)> = vec![(Vec::new(), 1.0)];
    for run in &runs {
        joint = joint
            .iter()
            .flat_map(|(prefix, prefix_prob)| {
                run.iter().map(move |(suffix, suffix_prob)| {
                    let bits: Vec<bool> =
                        prefix.iter().chain(suffix.iter()).copied().collect();
                    (bits, prefix_prob * suffix_prob)
                })
            })
            .collect();
    }

    // Fold into a map, accumulating mass for repeated bitstrings.
    let mut stitched: HashMap<Vec<bool>, f64> = HashMap::new();
    for (bits, prob) in joint {
        *stitched.entry(bits).or_insert(0.0) += prob;
    }

    stitched
}
908
909// ---------------------------------------------------------------------------
910// Fidelity-aware stitching
911// ---------------------------------------------------------------------------
912
/// Heuristic fidelity estimate for a stitched (partitioned) simulation.
///
/// Splitting a circuit discards the two-qubit gates that act across
/// segments. This record summarises how many such gates were cut and the
/// fidelity penalty attributed to them, as computed by
/// `estimate_stitch_fidelity` in this module.
#[derive(Clone, Debug)]
pub struct StitchFidelity {
    /// Overall estimate: the product of all entries in `per_cut_fidelity`
    /// (`1.0` when nothing was cut).
    pub fidelity: f64,
    /// Total number of two-qubit gates whose qubits fall in different
    /// segments.
    pub cut_gates: usize,
    /// One fidelity factor per pair of segments that has at least one cut
    /// gate between them.
    pub per_cut_fidelity: Vec<f64>,
}
928
929/// Stitch results with fidelity estimation.
930///
931/// Like [`stitch_results`], but also estimates the fidelity loss from
932/// partitioning. Each entangling gate that crosses a partition boundary
933/// contributes a fidelity penalty:
934///
935/// ```text
936/// F_cut = 1 / sqrt(2^k)
937/// ```
938///
939/// where k is the number of entangling gates crossing that particular
940/// boundary. This is a conservative upper bound derived from the fact
941/// that each maximally entangling gate can create at most 1 ebit of
942/// entanglement, and cutting it loses at most 1 bit of mutual information.
943///
944/// # Arguments
945///
946/// * `partitions` - Flat list of (bitstring, probability) pairs from all segments.
947/// * `partition_info` - The `CircuitPartition` used to understand cut structure.
948/// * `original_circuit` - The original (undecomposed) circuit for cut analysis.
949pub fn stitch_with_fidelity(
950    partitions: &[(Vec<bool>, f64)],
951    partition_info: &CircuitPartition,
952    original_circuit: &QuantumCircuit,
953) -> (HashMap<Vec<bool>, f64>, StitchFidelity) {
954    // Get the basic stitched distribution.
955    let distribution = stitch_results(partitions);
956
957    // Compute fidelity from the partition structure.
958    let fidelity = estimate_stitch_fidelity(partition_info, original_circuit);
959
960    (distribution, fidelity)
961}
962
963/// Estimate fidelity loss from circuit partitioning.
964///
965/// Analyzes the original circuit to count how many entangling gates
966/// cross each partition boundary.
967fn estimate_stitch_fidelity(
968    partition_info: &CircuitPartition,
969    original_circuit: &QuantumCircuit,
970) -> StitchFidelity {
971    if partition_info.segments.len() <= 1 {
972        return StitchFidelity {
973            fidelity: 1.0,
974            cut_gates: 0,
975            per_cut_fidelity: Vec::new(),
976        };
977    }
978
979    // Build a map: original qubit -> segment index.
980    let mut qubit_to_segment: HashMap<u32, usize> = HashMap::new();
981    for (seg_idx, segment) in partition_info.segments.iter().enumerate() {
982        let (lo, hi) = segment.qubit_range;
983        for q in lo..=hi {
984            qubit_to_segment.entry(q).or_insert(seg_idx);
985        }
986    }
987
988    // Count entangling gates that cross segment boundaries.
989    // Group by boundary pair (seg_a, seg_b) to compute per-boundary fidelity.
990    let mut boundary_cuts: HashMap<(usize, usize), usize> = HashMap::new();
991    let mut total_cut_gates = 0usize;
992
993    for gate in original_circuit.gates() {
994        let qubits = gate.qubits();
995        if qubits.len() != 2 {
996            continue;
997        }
998        let seg_a = qubit_to_segment.get(&qubits[0]).copied();
999        let seg_b = qubit_to_segment.get(&qubits[1]).copied();
1000
1001        if let (Some(a), Some(b)) = (seg_a, seg_b) {
1002            if a != b {
1003                let key = if a < b { (a, b) } else { (b, a) };
1004                *boundary_cuts.entry(key).or_insert(0) += 1;
1005                total_cut_gates += 1;
1006            }
1007        }
1008    }
1009
1010    // Compute per-boundary fidelity: F = 1/sqrt(2^k) where k is cut gate count.
1011    // This is conservative -- assumes each cut gate creates maximal entanglement.
1012    let per_cut_fidelity: Vec<f64> = boundary_cuts
1013        .values()
1014        .map(|&k| {
1015            if k == 0 {
1016                1.0
1017            } else {
1018                // F = 2^(-k/2)
1019                2.0_f64.powf(-(k as f64) / 2.0)
1020            }
1021        })
1022        .collect();
1023
1024    let overall_fidelity = per_cut_fidelity.iter().product::<f64>();
1025
1026    StitchFidelity {
1027        fidelity: overall_fidelity,
1028        cut_gates: total_cut_gates,
1029        per_cut_fidelity,
1030    }
1031}
1032
1033// ---------------------------------------------------------------------------
1034// Main decomposition entry point
1035// ---------------------------------------------------------------------------
1036
1037/// Decompose a quantum circuit into segments for multi-backend simulation.
1038///
1039/// This is the primary entry point for the decomposition engine. The
1040/// algorithm proceeds as follows:
1041///
1042/// 1. Build the qubit interaction graph (nodes = qubits, edges = two-qubit
1043///    gates).
1044/// 2. Identify connected components. Disconnected components become separate
1045///    spatial segments immediately.
1046/// 3. For each connected component, attempt temporal decomposition at
1047///    barriers and natural breakpoints.
1048/// 4. Classify each resulting segment to select the optimal backend.
1049/// 5. If any segment exceeds `max_segment_qubits`, attempt further spatial
1050///    decomposition using a greedy min-cut heuristic.
1051/// 6. Estimate costs for every final segment.
1052///
1053/// # Arguments
1054///
1055/// * `circuit` - The circuit to decompose.
1056/// * `max_segment_qubits` - Maximum number of qubits allowed per segment.
1057///   Segments exceeding this limit are spatially subdivided.
1058pub fn decompose(circuit: &QuantumCircuit, max_segment_qubits: u32) -> CircuitPartition {
1059    let n = circuit.num_qubits();
1060    let gates = circuit.gates();
1061
1062    // Trivial case: empty circuit or single qubit.
1063    if gates.is_empty() || n <= 1 {
1064        let backend = classify_segment(circuit);
1065        let cost = estimate_segment_cost(circuit, backend);
1066        return CircuitPartition {
1067            segments: vec![CircuitSegment {
1068                circuit: circuit.clone(),
1069                backend,
1070                qubit_range: (0, n.saturating_sub(1)),
1071                gate_range: (0, gates.len()),
1072                estimated_cost: cost,
1073            }],
1074            total_qubits: n,
1075            strategy: DecompositionStrategy::None,
1076        };
1077    }
1078
1079    // Step 1: Build the interaction graph.
1080    let graph = build_interaction_graph(circuit);
1081
1082    // Step 2: Find connected components.
1083    let components = find_connected_components(&graph);
1084
1085    let mut used_spatial = false;
1086    let mut used_temporal = false;
1087    let mut final_segments: Vec<CircuitSegment> = Vec::new();
1088
1089    if components.len() > 1 {
1090        used_spatial = true;
1091    }
1092
1093    // Step 3: For each connected component, extract its subcircuit and
1094    // attempt temporal decomposition.
1095    for component in &components {
1096        let comp_set: HashSet<u32> = component.iter().copied().collect();
1097
1098        // Extract the subcircuit for this component.
1099        let comp_circuit = extract_component_circuit(circuit, &comp_set);
1100
1101        // Find the gate index range in the original circuit for this component.
1102        let gate_indices = gate_indices_for_component(circuit, &comp_set);
1103        let gate_range_start = gate_indices.first().copied().unwrap_or(0);
1104        let _gate_range_end = gate_indices
1105            .last()
1106            .map(|&i| i + 1)
1107            .unwrap_or(0);
1108
1109        // Temporal decomposition within the component.
1110        let time_slices = temporal_decomposition(&comp_circuit);
1111
1112        if time_slices.len() > 1 {
1113            used_temporal = true;
1114        }
1115
1116        // Track cumulative gate offset for slices.
1117        let mut slice_gate_offset = gate_range_start;
1118
1119        for slice_circuit in &time_slices {
1120            let slice_gate_count = slice_circuit.gate_count();
1121
1122            // Step 4: Classify the segment.
1123            let backend = classify_segment(slice_circuit);
1124
1125            // Step 5: If the segment is too large, attempt spatial decomposition.
1126            if slice_circuit.num_qubits() > max_segment_qubits
1127                && active_qubit_count(slice_circuit) > max_segment_qubits
1128            {
1129                used_spatial = true;
1130                let sub_graph = build_interaction_graph(slice_circuit);
1131                let sub_parts =
1132                    spatial_decomposition(slice_circuit, &sub_graph, max_segment_qubits);
1133
1134                for (qubit_group, sub_circ) in &sub_parts {
1135                    let sub_backend = classify_segment(sub_circ);
1136                    let cost = estimate_segment_cost(sub_circ, sub_backend);
1137                    let qmin = qubit_group.iter().copied().min().unwrap_or(0);
1138                    let qmax = qubit_group.iter().copied().max().unwrap_or(0);
1139
1140                    final_segments.push(CircuitSegment {
1141                        circuit: sub_circ.clone(),
1142                        backend: sub_backend,
1143                        qubit_range: (qmin, qmax),
1144                        gate_range: (slice_gate_offset, slice_gate_offset + slice_gate_count),
1145                        estimated_cost: cost,
1146                    });
1147                }
1148            } else {
1149                let cost = estimate_segment_cost(slice_circuit, backend);
1150                let qmin = component.iter().copied().min().unwrap_or(0);
1151                let qmax = component.iter().copied().max().unwrap_or(0);
1152
1153                final_segments.push(CircuitSegment {
1154                    circuit: slice_circuit.clone(),
1155                    backend,
1156                    qubit_range: (qmin, qmax),
1157                    gate_range: (slice_gate_offset, slice_gate_offset + slice_gate_count),
1158                    estimated_cost: cost,
1159                });
1160            }
1161
1162            slice_gate_offset += slice_gate_count;
1163        }
1164    }
1165
1166    // Determine the overall strategy.
1167    let strategy = match (used_temporal, used_spatial) {
1168        (true, true) => DecompositionStrategy::Hybrid,
1169        (true, false) => DecompositionStrategy::Temporal,
1170        (false, true) => DecompositionStrategy::Spatial,
1171        (false, false) => DecompositionStrategy::None,
1172    };
1173
1174    CircuitPartition {
1175        segments: final_segments,
1176        total_qubits: n,
1177        strategy,
1178    }
1179}
1180
1181// ---------------------------------------------------------------------------
1182// Internal helpers
1183// ---------------------------------------------------------------------------
1184
1185/// Count the number of qubits that are actually used (touched by at least one
1186/// gate) in a circuit.
1187fn active_qubit_count(circuit: &QuantumCircuit) -> u32 {
1188    let mut active: HashSet<u32> = HashSet::new();
1189    for gate in circuit.gates() {
1190        for &q in &gate.qubits() {
1191            active.insert(q);
1192        }
1193    }
1194    active.len() as u32
1195}
1196
1197/// Extract a subcircuit containing only the gates that act on qubits in the
1198/// given component set. The subcircuit has `num_qubits` equal to the size of
1199/// the component, with qubit indices remapped to `0..component.len()`.
1200fn extract_component_circuit(
1201    circuit: &QuantumCircuit,
1202    component: &HashSet<u32>,
1203) -> QuantumCircuit {
1204    // Build a sorted list for deterministic remapping.
1205    let mut sorted_qubits: Vec<u32> = component.iter().copied().collect();
1206    sorted_qubits.sort_unstable();
1207    let remap: HashMap<u32, u32> = sorted_qubits
1208        .iter()
1209        .enumerate()
1210        .map(|(i, &q)| (q, i as u32))
1211        .collect();
1212
1213    let num_local = sorted_qubits.len() as u32;
1214    let mut sub_circuit = QuantumCircuit::new(num_local);
1215
1216    for gate in circuit.gates() {
1217        match gate {
1218            Gate::Barrier => {
1219                // Include barriers in every component subcircuit.
1220                sub_circuit.add_gate(Gate::Barrier);
1221            }
1222            _ => {
1223                let qubits = gate.qubits();
1224                if qubits.is_empty() {
1225                    continue;
1226                }
1227                // Include the gate only if all its qubits are in this component.
1228                if qubits.iter().all(|q| component.contains(q)) {
1229                    sub_circuit.add_gate(remap_gate(gate, &remap));
1230                }
1231            }
1232        }
1233    }
1234
1235    sub_circuit
1236}
1237
1238/// Find the gate indices in the original circuit that belong to a given
1239/// qubit component.
1240fn gate_indices_for_component(circuit: &QuantumCircuit, component: &HashSet<u32>) -> Vec<usize> {
1241    circuit
1242        .gates()
1243        .iter()
1244        .enumerate()
1245        .filter_map(|(i, gate)| {
1246            let qubits = gate.qubits();
1247            if qubits.is_empty() {
1248                return Some(i); // Barrier belongs to all components.
1249            }
1250            if qubits.iter().any(|q| component.contains(q)) {
1251                Some(i)
1252            } else {
1253                Option::None
1254            }
1255        })
1256        .collect()
1257}
1258
1259// ---------------------------------------------------------------------------
1260// Tests
1261// ---------------------------------------------------------------------------
1262
1263#[cfg(test)]
1264mod tests {
1265    use super::*;
1266
1267    /// Helper: create two independent Bell pairs on qubits (0,1) and (2,3).
1268    fn two_bell_pairs() -> QuantumCircuit {
1269        let mut circ = QuantumCircuit::new(4);
1270        circ.h(0).cnot(0, 1); // Bell pair on 0,1
1271        circ.h(2).cnot(2, 3); // Bell pair on 2,3
1272        circ
1273    }
1274
1275    // ----- Test 1: Two independent Bell states decompose into 2 spatial segments -----
1276
1277    #[test]
1278    fn two_independent_bell_states_decompose_into_two_segments() {
1279        let circ = two_bell_pairs();
1280        let partition = decompose(&circ, 25);
1281
1282        assert_eq!(
1283            partition.segments.len(),
1284            2,
1285            "expected 2 segments for two independent Bell pairs, got {}",
1286            partition.segments.len()
1287        );
1288        assert_eq!(partition.strategy, DecompositionStrategy::Spatial);
1289
1290        // Each segment should have 2 qubits.
1291        for seg in &partition.segments {
1292            assert_eq!(
1293                seg.circuit.num_qubits(),
1294                2,
1295                "each Bell pair segment should have 2 qubits"
1296            );
1297        }
1298    }
1299
1300    // ----- Test 2: Pure Clifford segment is classified as Stabilizer -----
1301
1302    #[test]
1303    fn pure_clifford_classified_as_stabilizer() {
1304        let mut circ = QuantumCircuit::new(4);
1305        circ.h(0).cnot(0, 1).s(2).cz(2, 3).x(1).y(3).z(0);
1306
1307        let backend = classify_segment(&circ);
1308        assert_eq!(
1309            backend,
1310            BackendType::Stabilizer,
1311            "all-Clifford circuit should be classified as Stabilizer"
1312        );
1313    }
1314
1315    // ----- Test 3: Temporal decomposition splits at barriers -----
1316
1317    #[test]
1318    fn temporal_decomposition_splits_at_barriers() {
1319        let mut circ = QuantumCircuit::new(2);
1320        circ.h(0).cnot(0, 1);
1321        circ.barrier();
1322        circ.x(0).z(1);
1323
1324        let slices = temporal_decomposition(&circ);
1325        assert_eq!(
1326            slices.len(),
1327            2,
1328            "expected 2 time slices around barrier, got {}",
1329            slices.len()
1330        );
1331
1332        // First slice: H + CNOT = 2 gates.
1333        assert_eq!(slices[0].gate_count(), 2);
1334        // Second slice: X + Z = 2 gates.
1335        assert_eq!(slices[1].gate_count(), 2);
1336    }
1337
1338    // ----- Test 4: Connected circuit stays as single segment -----
1339
1340    #[test]
1341    fn connected_circuit_stays_as_single_segment() {
1342        let mut circ = QuantumCircuit::new(4);
1343        circ.h(0).cnot(0, 1).cnot(1, 2).cnot(2, 3);
1344
1345        let partition = decompose(&circ, 25);
1346        assert_eq!(
1347            partition.segments.len(),
1348            1,
1349            "fully connected circuit should remain a single segment"
1350        );
1351        assert_eq!(partition.strategy, DecompositionStrategy::None);
1352    }
1353
1354    // ----- Test 5: Interaction graph correctly counts two-qubit gate edges -----
1355
1356    #[test]
1357    fn interaction_graph_counts_edges() {
1358        let mut circ = QuantumCircuit::new(3);
1359        circ.cnot(0, 1); // edge (0,1)
1360        circ.cnot(0, 1); // edge (0,1) again
1361        circ.cz(1, 2); // edge (1,2)
1362
1363        let graph = build_interaction_graph(&circ);
1364
1365        assert_eq!(graph.num_qubits, 3);
1366        assert_eq!(graph.edges.len(), 2, "should have 2 distinct edges");
1367
1368        // Find the (0,1) edge and check its count.
1369        let edge_01 = graph
1370            .edges
1371            .iter()
1372            .find(|&&(a, b, _)| a == 0 && b == 1);
1373        assert!(edge_01.is_some(), "edge (0,1) should exist");
1374        assert_eq!(edge_01.unwrap().2, 2, "edge (0,1) should have count 2");
1375
1376        // Find the (1,2) edge.
1377        let edge_12 = graph
1378            .edges
1379            .iter()
1380            .find(|&&(a, b, _)| a == 1 && b == 2);
1381        assert!(edge_12.is_some(), "edge (1,2) should exist");
1382        assert_eq!(edge_12.unwrap().2, 1, "edge (1,2) should have count 1");
1383
1384        // Check adjacency.
1385        assert!(graph.adjacency[0].contains(&1));
1386        assert!(graph.adjacency[1].contains(&0));
1387        assert!(graph.adjacency[1].contains(&2));
1388        assert!(graph.adjacency[2].contains(&1));
1389    }
1390
1391    // ----- Test 6: Spatial decomposition respects max_qubits limit -----
1392
1393    #[test]
1394    fn spatial_decomposition_respects_max_qubits() {
1395        // Create a 6-qubit circuit with a chain of CNOT gates.
1396        let mut circ = QuantumCircuit::new(6);
1397        for q in 0..5 {
1398            circ.cnot(q, q + 1);
1399        }
1400
1401        let graph = build_interaction_graph(&circ);
1402        let parts = spatial_decomposition(&circ, &graph, 3);
1403
1404        // Every group should have at most 3 qubits.
1405        for (group, _sub_circ) in &parts {
1406            assert!(
1407                group.len() <= 3,
1408                "group {:?} has {} qubits, expected at most 3",
1409                group,
1410                group.len()
1411            );
1412        }
1413
1414        // All 6 qubits should be covered.
1415        let mut all_qubits: Vec<u32> = parts
1416            .iter()
1417            .flat_map(|(group, _)| group.iter().copied())
1418            .collect();
1419        all_qubits.sort_unstable();
1420        all_qubits.dedup();
1421        assert_eq!(all_qubits.len(), 6, "all 6 qubits should be covered");
1422    }
1423
1424    // ----- Test 7: Segment cost estimation produces reasonable values -----
1425
1426    #[test]
1427    fn segment_cost_estimation_reasonable() {
1428        let mut circ = QuantumCircuit::new(10);
1429        circ.h(0).cnot(0, 1).t(2);
1430
1431        // StateVector cost.
1432        let sv_cost = estimate_segment_cost(&circ, BackendType::StateVector);
1433        assert_eq!(sv_cost.qubit_count, 10);
1434        // 2^10 * 16 = 16384 bytes.
1435        assert_eq!(sv_cost.memory_bytes, 16384);
1436        assert!(sv_cost.estimated_flops > 0);
1437
1438        // Stabilizer cost.
1439        let stab_cost = estimate_segment_cost(&circ, BackendType::Stabilizer);
1440        assert_eq!(stab_cost.qubit_count, 10);
1441        // Tableau: 2*10*(2*10+1) = 420 bytes.
1442        assert_eq!(stab_cost.memory_bytes, 420);
1443        assert!(stab_cost.estimated_flops > 0);
1444
1445        // TensorNetwork cost.
1446        let tn_cost = estimate_segment_cost(&circ, BackendType::TensorNetwork);
1447        assert_eq!(tn_cost.qubit_count, 10);
1448        // 10 * 64 * 64 * 16 = 655360.
1449        assert_eq!(tn_cost.memory_bytes, 655_360);
1450        assert!(tn_cost.estimated_flops > 0);
1451
1452        // StateVector memory should be much less than TN for small qubit counts,
1453        // and stabilizer should be the smallest.
1454        assert!(stab_cost.memory_bytes < sv_cost.memory_bytes);
1455    }
1456
1457    // ----- Test 8: 10-qubit GHZ circuit stays as one segment (fully connected) -----
1458
1459    #[test]
1460    fn ghz_10_qubit_single_segment() {
1461        let mut circ = QuantumCircuit::new(10);
1462        circ.h(0);
1463        for q in 0..9 {
1464            circ.cnot(q, q + 1);
1465        }
1466
1467        let partition = decompose(&circ, 25);
1468        assert_eq!(
1469            partition.segments.len(),
1470            1,
1471            "10-qubit GHZ circuit should stay as one segment"
1472        );
1473
1474        // The GHZ circuit is all Clifford, so backend should be Stabilizer.
1475        assert_eq!(partition.segments[0].backend, BackendType::Stabilizer);
1476    }
1477
1478    // ----- Test 9: Disconnected 20-qubit circuit decomposes -----
1479
1480    #[test]
1481    fn disconnected_20_qubit_circuit_decomposes() {
1482        let mut circ = QuantumCircuit::new(20);
1483
1484        // Block A: qubits 0..9 (GHZ-like).
1485        circ.h(0);
1486        for q in 0..9 {
1487            circ.cnot(q, q + 1);
1488        }
1489
1490        // Block B: qubits 10..19 (GHZ-like).
1491        circ.h(10);
1492        for q in 10..19 {
1493            circ.cnot(q, q + 1);
1494        }
1495
1496        let partition = decompose(&circ, 25);
1497        assert_eq!(
1498            partition.segments.len(),
1499            2,
1500            "two disconnected 10-qubit blocks should yield 2 segments, got {}",
1501            partition.segments.len()
1502        );
1503        assert_eq!(partition.total_qubits, 20);
1504        assert_eq!(partition.strategy, DecompositionStrategy::Spatial);
1505
1506        // Each segment should have 10 qubits.
1507        for seg in &partition.segments {
1508            assert_eq!(seg.circuit.num_qubits(), 10);
1509        }
1510    }
1511
1512    // ----- Additional tests for edge cases and coverage -----
1513
1514    #[test]
1515    fn empty_circuit_produces_single_segment() {
1516        let circ = QuantumCircuit::new(4);
1517        let partition = decompose(&circ, 25);
1518        assert_eq!(partition.segments.len(), 1);
1519        assert_eq!(partition.strategy, DecompositionStrategy::None);
1520    }
1521
1522    #[test]
1523    fn single_qubit_circuit() {
1524        let mut circ = QuantumCircuit::new(1);
1525        circ.h(0).t(0);
1526        let partition = decompose(&circ, 25);
1527        assert_eq!(partition.segments.len(), 1);
1528        assert_eq!(partition.segments[0].backend, BackendType::StateVector);
1529    }
1530
1531    #[test]
1532    fn mixed_clifford_non_clifford_classification() {
1533        // Circuit with one T gate among Cliffords.
1534        let mut circ = QuantumCircuit::new(5);
1535        circ.h(0).cnot(0, 1).t(2).s(3);
1536
1537        let backend = classify_segment(&circ);
1538        assert_eq!(
1539            backend,
1540            BackendType::StateVector,
1541            "mixed circuit with <= 25 qubits should use StateVector"
1542        );
1543    }
1544
1545    #[test]
1546    fn connected_components_isolated_qubits() {
1547        // Circuit where qubit 2 has no two-qubit gates.
1548        let mut circ = QuantumCircuit::new(3);
1549        circ.cnot(0, 1).h(2);
1550
1551        let graph = build_interaction_graph(&circ);
1552        let components = find_connected_components(&graph);
1553
1554        assert_eq!(
1555            components.len(),
1556            2,
1557            "qubit 2 is isolated, should form its own component"
1558        );
1559
1560        // One component should be {0, 1}, the other {2}.
1561        let has_pair = components.iter().any(|c| c == &vec![0, 1]);
1562        let has_single = components.iter().any(|c| c == &vec![2]);
1563        assert!(has_pair, "component {{0, 1}} should exist");
1564        assert!(has_single, "component {{2}} should exist");
1565    }
1566
1567    #[test]
1568    fn stitch_results_independent_segments() {
1569        // Segment 1: 1-qubit outcomes.
        // Segment 2: 2-qubit outcomes.
1571        let partitions = vec![
1572            (vec![false], 0.5),
1573            (vec![true], 0.5),
1574            (vec![false, false], 0.25),
1575            (vec![true, true], 0.75),
1576        ];
1577
1578        let combined = stitch_results(&partitions);
1579
1580        // Combined bitstrings: 1-bit x 2-bit.
1581        // (false, false, false) = 0.5 * 0.25 = 0.125
1582        // (false, true, true)   = 0.5 * 0.75 = 0.375
1583        // (true, false, false)  = 0.5 * 0.25 = 0.125
1584        // (true, true, true)    = 0.5 * 0.75 = 0.375
1585        assert_eq!(combined.len(), 4);
1586
1587        let prob_fff = combined.get(&vec![false, false, false]).copied().unwrap_or(0.0);
1588        let prob_ftt = combined.get(&vec![false, true, true]).copied().unwrap_or(0.0);
1589        let prob_tff = combined.get(&vec![true, false, false]).copied().unwrap_or(0.0);
1590        let prob_ttt = combined.get(&vec![true, true, true]).copied().unwrap_or(0.0);
1591
1592        assert!((prob_fff - 0.125).abs() < 1e-10);
1593        assert!((prob_ftt - 0.375).abs() < 1e-10);
1594        assert!((prob_tff - 0.125).abs() < 1e-10);
1595        assert!((prob_ttt - 0.375).abs() < 1e-10);
1596    }
1597
1598    #[test]
1599    fn stitch_results_empty() {
1600        let combined = stitch_results(&[]);
1601        assert!(combined.is_empty());
1602    }
1603
1604    #[test]
1605    fn classify_large_moderate_t_as_clifford_t() {
1606        // 30 qubits with 1 T-gate -> CliffordT (moderate T-count, large circuit).
1607        let mut circ = QuantumCircuit::new(30);
1608        circ.h(0);
1609        circ.t(1); // non-Clifford
1610        for q in 0..29 {
1611            circ.cnot(q, q + 1);
1612        }
1613
1614        let backend = classify_segment(&circ);
1615        assert_eq!(
1616            backend,
1617            BackendType::CliffordT,
1618            "moderate T-count on > 25 qubits should use CliffordT"
1619        );
1620    }
1621
1622    #[test]
1623    fn classify_large_high_t_as_tensor_network() {
1624        // 30 qubits with 50 T-gates -> TensorNetwork (too many for CliffordT).
1625        let mut circ = QuantumCircuit::new(30);
1626        for q in 0..29 {
1627            circ.cnot(q, q + 1);
1628        }
1629        for _ in 0..50 {
1630            circ.rx(0, 1.0); // non-Clifford
1631        }
1632
1633        let backend = classify_segment(&circ);
1634        assert_eq!(
1635            backend,
1636            BackendType::TensorNetwork,
1637            "high T-count on > 25 qubits should use TensorNetwork"
1638        );
1639    }
1640
1641    #[test]
1642    fn temporal_decomposition_no_barriers_single_slice() {
1643        let mut circ = QuantumCircuit::new(2);
1644        circ.h(0).cnot(0, 1);
1645
1646        let slices = temporal_decomposition(&circ);
1647        assert_eq!(
1648            slices.len(),
1649            1,
1650            "circuit without barriers should produce a single time slice"
1651        );
1652        assert_eq!(slices[0].gate_count(), 2);
1653    }
1654
/// Two barriers separating three gates must yield three time slices.
#[test]
fn temporal_decomposition_multiple_barriers() {
    let mut circuit = QuantumCircuit::new(2);

    // Layout: H | barrier | CNOT | barrier | X
    circuit.h(0);
    circuit.barrier();
    circuit.cnot(0, 1);
    circuit.barrier();
    circuit.x(0);

    let slice_count = temporal_decomposition(&circuit).len();
    assert_eq!(slice_count, 3, "two barriers should produce three time slices");
}
1671
/// Estimating cost with `BackendType::Auto` on an all-Clifford circuit must
/// give the same memory and FLOP figures as asking for `Stabilizer` directly.
#[test]
fn cost_auto_backend_resolves() {
    let mut bell = QuantumCircuit::new(4);
    bell.h(0);
    bell.cnot(0, 1);

    let auto_cost = estimate_segment_cost(&bell, BackendType::Auto);
    // Reference figures from the backend that Auto is expected to resolve to.
    let stabilizer_cost = estimate_segment_cost(&bell, BackendType::Stabilizer);

    assert_eq!(auto_cost.memory_bytes, stabilizer_cost.memory_bytes);
    assert_eq!(auto_cost.estimated_flops, stabilizer_cost.estimated_flops);
}
1683
/// Two measured Bell pairs on disjoint qubit pairs must decompose into two
/// segments.
#[test]
fn decompose_with_measurements() {
    let mut circuit = QuantumCircuit::new(4);
    // Build and measure one Bell pair on (0, 1), then one on (2, 3).
    for (first, second) in [(0, 1), (2, 3)] {
        circuit
            .h(first)
            .cnot(first, second)
            .measure(first)
            .measure(second);
    }

    // The pairs never interact, so each is expected to become its own segment.
    let segment_count = decompose(&circuit, 25).segments.len();
    assert_eq!(segment_count, 2);
}
1694
/// The interaction graph of a gate-free 5-qubit circuit keeps the qubit count
/// but has no edges and no adjacency entries.
#[test]
fn interaction_graph_empty_circuit() {
    let graph = build_interaction_graph(&QuantumCircuit::new(5));

    assert_eq!(graph.num_qubits, 5);
    assert!(graph.edges.is_empty());
    // Every per-qubit neighbour list must be empty as well.
    assert!(graph.adjacency.iter().all(|neighbors| neighbors.is_empty()));
}
1706
/// A CNOT chain 0-1-2-3 must be reported as one component holding all four
/// qubits in ascending order.
#[test]
fn connected_components_fully_connected() {
    let mut chain = QuantumCircuit::new(4);
    for q in 0..3 {
        chain.cnot(q, q + 1);
    }

    let graph = build_interaction_graph(&chain);
    let components = find_connected_components(&graph);

    let component_count = components.len();
    assert_eq!(component_count, 1, "fully connected chain should be one component");
    assert_eq!(components[0], vec![0, 1, 2, 3]);
}
1722
/// When the whole circuit fits under the qubit limit, spatial decomposition
/// must return one group listing every qubit.
#[test]
fn spatial_decomposition_returns_single_group_if_fits() {
    let mut circuit = QuantumCircuit::new(4);
    circuit.cnot(0, 1);
    circuit.cnot(2, 3);

    let graph = build_interaction_graph(&circuit);
    // Limit of 10 qubits comfortably covers the 4-qubit circuit.
    let groups = spatial_decomposition(&circuit, &graph, 10);

    assert_eq!(groups.len(), 1);
    let first_group_qubits = &groups[0].0;
    assert_eq!(*first_group_qubits, vec![0, 1, 2, 3]);
}
1735
/// Every segment produced for the two-Bell-pair fixture must report a
/// non-inverted qubit range that stays below `total_qubits`.
#[test]
fn segment_qubit_ranges_are_valid() {
    let partition = decompose(&two_bell_pairs(), 25);
    let qubit_limit = partition.total_qubits;

    for segment in partition.segments.iter() {
        let (lowest, highest) = segment.qubit_range;
        assert!(lowest <= highest, "qubit_range should be non-inverted");
        assert!(highest < qubit_limit, "qubit_range max should be within total_qubits");
    }
}
1750
/// A circuit made only of measurements contains no non-Clifford gate and is
/// therefore expected to classify as `Stabilizer`.
#[test]
fn classify_segment_measure_only() {
    let mut circuit = QuantumCircuit::new(3);
    // Measure each of the three qubits in order.
    for qubit in 0..3 {
        circuit.measure(qubit);
    }

    assert_eq!(classify_segment(&circuit), BackendType::Stabilizer);
}
1761
/// A gate-free circuit is expected to classify as `Stabilizer`.
#[test]
fn classify_segment_empty_circuit() {
    let gate_free = QuantumCircuit::new(5);
    assert_eq!(
        classify_segment(&gate_free),
        BackendType::Stabilizer,
        "empty circuit has no non-Clifford gates"
    );
}
1772
1773    // ----- Stoer-Wagner min-cut tests -----
1774
/// On the linear chain 0-1-2-3-4 the minimum cut severs one edge, and both
/// sides of the cut must be non-empty.
#[test]
fn test_stoer_wagner_mincut_linear() {
    let mut chain = QuantumCircuit::new(5);
    for q in 0..4 {
        chain.cnot(q, q + 1);
    }

    let graph = build_interaction_graph(&chain);
    let min_cut = stoer_wagner_mincut(&graph).unwrap();

    // Removing any single edge disconnects the chain.
    assert_eq!(min_cut.cut_value, 1);
    assert!(!min_cut.partition_a.is_empty());
    assert!(!min_cut.partition_b.is_empty());
}
1787
/// On a triangle with unit-weight edges, isolating any vertex cuts two edges,
/// so the minimum cut value must be 2.
#[test]
fn test_stoer_wagner_mincut_triangle() {
    let mut triangle = QuantumCircuit::new(3);
    // Edges: 0-1, 1-2, 0-2.
    for (a, b) in [(0, 1), (1, 2), (0, 2)] {
        triangle.cnot(a, b);
    }

    let graph = build_interaction_graph(&triangle);
    let min_cut = stoer_wagner_mincut(&graph).unwrap();
    assert_eq!(min_cut.cut_value, 2);
}
1798
/// On a barbell graph (two 3-cliques joined by one bridge edge) the minimum
/// cut is the bridge, so the cut value must be 1.
#[test]
fn test_stoer_wagner_mincut_barbell() {
    let left_clique = [(0, 1), (1, 2), (0, 2)];
    let bridge = [(2, 3)];
    let right_clique = [(3, 4), (4, 5), (3, 5)];

    let mut barbell = QuantumCircuit::new(6);
    // Gates are added left clique first, then the bridge, then the right clique.
    for (a, b) in left_clique.into_iter().chain(bridge).chain(right_clique) {
        barbell.cnot(a, b);
    }

    let graph = build_interaction_graph(&barbell);
    let min_cut = stoer_wagner_mincut(&graph).unwrap();
    assert_eq!(min_cut.cut_value, 1);
}
1814
/// Min-cut spatial decomposition of a 6-qubit barbell with a 3-qubit limit
/// must produce at least two groups, none larger than the limit.
#[test]
fn test_spatial_decomposition_mincut() {
    // 6-qubit barbell, max 3 qubits per segment.
    let mut circ = QuantumCircuit::new(6);
    // Left clique.
    circ.cnot(0, 1).cnot(1, 2).cnot(0, 2);
    // Bridge.
    circ.cnot(2, 3);
    // Right clique.
    circ.cnot(3, 4).cnot(4, 5).cnot(3, 5);
    let graph = build_interaction_graph(&circ);
    let parts = spatial_decomposition_mincut(&circ, &graph, 3);
    assert!(parts.len() >= 2, "Should partition into at least 2 groups");
    for (qubits, _sub_circ) in &parts {
        // Compare the length as `usize`: an `as u32` cast could truncate and
        // let an oversized group slip past the check.
        assert!(qubits.len() <= 3, "Each group should have at most 3 qubits");
    }
}
1829
1830    // ----- Fidelity-aware stitching tests -----
1831
/// Stitching a partition that has a single segment must report fidelity 1.0,
/// zero cut gates, and a non-empty distribution.
#[test]
fn test_stitch_with_fidelity_single_segment() {
    let empty_circuit = QuantumCircuit::new(2);

    let zero_cost = SegmentCost {
        memory_bytes: 0,
        estimated_flops: 0,
        qubit_count: 2,
    };
    let only_segment = CircuitSegment {
        circuit: empty_circuit.clone(),
        backend: BackendType::Stabilizer,
        qubit_range: (0, 1),
        gate_range: (0, 0),
        estimated_cost: zero_cost,
    };
    let partition = CircuitPartition {
        segments: vec![only_segment],
        total_qubits: 2,
        strategy: DecompositionStrategy::None,
    };

    // One outcome (both bits false) carrying weight 1.0.
    let segment_results = vec![(vec![false, false], 1.0)];
    let (distribution, report) =
        stitch_with_fidelity(&segment_results, &partition, &empty_circuit);

    assert_eq!(report.fidelity, 1.0);
    assert_eq!(report.cut_gates, 0);
    assert!(!distribution.is_empty());
}
1856
/// When a CNOT in the original circuit spans two segments, stitching must
/// report fidelity below 1.0 and at least one cut gate.
#[test]
fn test_stitch_with_fidelity_cut_circuit() {
    // Original circuit: two Bell pairs joined by a CNOT between qubits 1 and 2.
    let mut full_circuit = QuantumCircuit::new(4);
    full_circuit.h(0);
    full_circuit.cnot(0, 1); // Bell pair 0-1
    full_circuit.h(2);
    full_circuit.cnot(2, 3); // Bell pair 2-3
    full_circuit.cnot(1, 2); // Cross-partition gate

    // Both segments hold the same local 2-qubit Bell circuit and differ only
    // in the qubit and gate ranges they claim within the original circuit.
    let bell_segment = |qubit_range, gate_range| {
        let mut bell = QuantumCircuit::new(2);
        bell.h(0).cnot(0, 1);
        CircuitSegment {
            circuit: bell,
            backend: BackendType::Stabilizer,
            qubit_range,
            gate_range,
            estimated_cost: SegmentCost {
                memory_bytes: 0,
                estimated_flops: 0,
                qubit_count: 2,
            },
        }
    };

    let partition = CircuitPartition {
        segments: vec![bell_segment((0, 1), (0, 2)), bell_segment((2, 3), (2, 4))],
        total_qubits: 4,
        strategy: DecompositionStrategy::Spatial,
    };

    // Four entries: all-false and all-true outcomes at weight 0.5, listed twice.
    let all_false = (vec![false, false], 0.5);
    let all_true = (vec![true, true], 0.5);
    let segment_results = vec![all_false.clone(), all_true.clone(), all_false, all_true];

    let (_distribution, report) =
        stitch_with_fidelity(&segment_results, &partition, &full_circuit);
    assert!(report.fidelity < 1.0, "Cut circuit should have fidelity < 1.0");
    assert!(report.cut_gates >= 1, "Should detect at least 1 cut gate");
}
1904}