Skip to main content

tket_qsystem/
lib.rs

1//! Provides a preparation and validation workflow for Hugrs targeting
2//! Quantinuum H-series quantum computers.
3
4#[cfg(feature = "cli")]
5pub mod cli;
6pub mod extension;
7#[cfg(feature = "llvm")]
8pub mod llvm;
9pub mod lower_drops;
10pub mod pytket;
11pub mod replace_bools;
12
13use derive_more::{Display, Error, From};
14use hugr::algorithms::const_fold::{ConstFoldError, ConstantFoldPass};
15use hugr::algorithms::{
16    ComposablePass as _, MonomorphizePass, RemoveDeadFuncsError, RemoveDeadFuncsPass, force_order,
17    replace_types::ReplaceTypesError,
18};
19use hugr::hugr::{HugrError, hugrmut::HugrMut};
20use hugr::{Hugr, HugrView, Node, core::Visibility, ops::OpType};
21use hugr_core::hugr::internal::HugrMutInternals;
22use std::collections::HashSet;
23
24use lower_drops::LowerDropsPass;
25use replace_bools::{ReplaceBoolPass, ReplaceBoolPassError};
26use tket::TketOp;
27
28use extension::{
29    futures::FutureOpDef,
30    qsystem::{LowerTk2Error, LowerTketToQSystemPass, QSystemOp},
31};
32
33#[cfg(feature = "llvm")]
34#[expect(deprecated)]
35// TODO: We still want to run this as long as deserialized hugrs are allowed to contain Value::Function
36// Once that variant is removed, we can remove this pass step.
37use hugr::llvm::utils::inline_constant_functions;
38
/// Prepares a [hugr::Hugr] for ingress into a Q-System, returning an error if
/// the HUGR cannot be brought into an acceptable form.
///
/// Obtain one via [Default::default] and tweak it with the `with_*` methods.
#[derive(Clone, Copy, Debug)]
pub struct QSystemPass {
    // Whether to run constant folding after lowering.
    constant_fold: bool,
    // Whether to monomorphize (and then remove dead functions) before lowering.
    monomorphize: bool,
    // Whether to insert order edges so that ops become totally ordered.
    force_order: bool,
    // Whether to replace strict measurements/bools with lazy equivalents.
    lazify: bool,
    // Whether to make module-level helper functions added during lowering private.
    hide_funcs: bool,
}
51
52impl Default for QSystemPass {
53    fn default() -> Self {
54        Self {
55            constant_fold: false,
56            monomorphize: true,
57            force_order: true,
58            lazify: true,
59            hide_funcs: true,
60        }
61    }
62}
63
64#[derive(Error, Debug, Display, From)]
65#[non_exhaustive]
66/// An error reported from [QSystemPass].
67pub enum QSystemPassError<N = Node> {
68    /// An error from the component [ReplaceBoolPass].
69    ReplaceBoolError(ReplaceBoolPassError<N>),
70    /// An error from the component [force_order()] pass.
71    ForceOrderError(HugrError),
72    /// An error from the component [LowerTketToQSystemPass] pass.
73    LowerTk2Error(LowerTk2Error),
74    /// An error from the component [ConstantFoldPass] pass.
75    ConstantFoldError(ConstFoldError),
76    /// An error from the component [LowerDropsPass] pass.
77    LinearizeArrayError(ReplaceTypesError),
78    #[cfg(feature = "llvm")]
79    /// An error from the component [inline_constant_functions()] pass.
80    InlineConstantFunctionsError(anyhow::Error),
81    /// An error when running [RemoveDeadFuncsPass] after the monomorphisation
82    /// pass.
83    ///
84    ///  [RemoveDeadFuncsPass]: hugr::algorithms::RemoveDeadFuncsError
85    DCEError(RemoveDeadFuncsError),
86    /// No [FuncDefn] named "main" in [Module].
87    ///
88    /// [FuncDefn]: hugr::ops::FuncDefn
89    /// [Module]: hugr::ops::Module
90    #[display("No function named 'main' in module.")]
91    NoMain,
92}
93
94impl QSystemPass {
95    /// Run `QSystemPass` on the given [Hugr]. `registry` is used for
96    /// validation, if enabled.
97    /// Expects the HUGR to have a function entrypoint.
98    pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
99        let entrypoint = if hugr.entrypoint_optype().is_module() {
100            // backwards compatibility: if the entrypoint is a module, we look for
101            // a function named "main" in the module and use that as the entrypoint.
102            hugr.children(hugr.entrypoint())
103                .find(|&n| {
104                    hugr.get_optype(n)
105                        .as_func_defn()
106                        .is_some_and(|fd| fd.func_name() == "main")
107                })
108                .ok_or(QSystemPassError::NoMain)?
109        } else {
110            hugr.entrypoint()
111        };
112
113        // passes that run on whole module
114        hugr.set_entrypoint(hugr.module_root());
115        if self.monomorphize {
116            self.monomorphization().run(hugr).unwrap();
117
118            let rdfp = RemoveDeadFuncsPass::default().with_module_entry_points([entrypoint]);
119            rdfp.run(hugr)?
120        }
121
122        // ReplaceTypes steps (there are several below) can introduce new helper
123        // functions that are public to enable linking/sharing. We'll make these private
124        // once we're done so that LLVM is not forced to compile them as callable.
125        let pubfuncs = self.hide_funcs.then(|| {
126            hugr.children(hugr.module_root())
127                .filter(|n| {
128                    hugr.get_optype(*n)
129                        .as_func_defn()
130                        .is_some_and(|fd| fd.visibility() == &Visibility::Public)
131                })
132                .collect::<HashSet<_>>()
133        });
134
135        self.lower_tk2().run(hugr)?;
136        if self.lazify {
137            self.replace_bools().run(hugr)?;
138        }
139        self.lower_drops().run(hugr)?;
140
141        if let Some(pubfuncs) = pubfuncs {
142            for n in hugr
143                .children(hugr.module_root())
144                .filter(|n| !pubfuncs.contains(n))
145                .collect::<Vec<_>>()
146            {
147                if let OpType::FuncDefn(fd) = hugr.optype_mut(n) {
148                    *fd.visibility_mut() = Visibility::Private;
149                }
150            }
151        }
152
153        #[cfg(feature = "llvm")]
154        {
155            // TODO: We still want to run this as long as deserialized hugrs are allowed to contain Value::Function
156            // Once that variant is removed, we can remove this pass step.
157            #[expect(deprecated)]
158            inline_constant_functions(hugr)?;
159        }
160        if self.constant_fold {
161            self.constant_fold().run(hugr)?;
162        }
163        if self.force_order {
164            self.force_order(hugr)?;
165        }
166        // restore the entrypoint
167        hugr.set_entrypoint(entrypoint);
168        Ok(())
169    }
170
171    fn force_order(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
172        force_order(hugr, hugr.entrypoint(), |hugr, node| {
173            let optype = hugr.get_optype(node);
174
175            let is_quantum =
176                optype.cast::<TketOp>().is_some() || optype.cast::<QSystemOp>().is_some();
177            let is_qalloc = matches!(
178                optype.cast(),
179                Some(TketOp::QAlloc) | Some(TketOp::TryQAlloc)
180            ) || optype.cast() == Some(QSystemOp::TryQAlloc);
181            let is_qfree =
182                optype.cast() == Some(TketOp::QFree) || optype.cast() == Some(QSystemOp::QFree);
183            let is_read = optype.cast() == Some(FutureOpDef::Read);
184
185            // HACK: for now qallocs and qfrees are not adequately ordered,
186            // see <https://github.com/quantinuum/guppylang/issues/778>. To
187            // mitigate this we push qfrees as early as possible and qallocs
188            // as late as possible
189            //
190            // To maximise laziness we push quantum ops (including
191            // LazyMeasure) as early as possible and Future::Read as late as
192            // possible.
193            if is_qfree {
194                -3
195            } else if is_quantum && !is_qalloc {
196                // non-qalloc quantum ops
197                -2
198            } else if is_qalloc {
199                -1
200            } else if !is_read {
201                // all other ops
202                0
203            } else {
204                // Future::Read ops
205                1
206            }
207        })?;
208        Ok::<_, QSystemPassError>(())
209    }
210
211    fn lower_tk2(&self) -> LowerTketToQSystemPass {
212        LowerTketToQSystemPass
213    }
214
215    fn replace_bools(&self) -> ReplaceBoolPass {
216        ReplaceBoolPass
217    }
218
219    fn constant_fold(&self) -> ConstantFoldPass {
220        ConstantFoldPass::default()
221    }
222
223    fn monomorphization(&self) -> MonomorphizePass {
224        MonomorphizePass
225    }
226
227    fn lower_drops(&self) -> LowerDropsPass {
228        LowerDropsPass
229    }
230
231    /// Returns a new `QSystemPass` with constant folding enabled according to
232    /// `constant_fold`.
233    ///
234    /// Off by default.
235    pub fn with_constant_fold(mut self, constant_fold: bool) -> Self {
236        self.constant_fold = constant_fold;
237        self
238    }
239
240    /// Returns a new `QSystemPass` with monomorphization enabled according to
241    /// `monomorphize`.
242    ///
243    /// On by default.
244    pub fn with_monormophize(mut self, monomorphize: bool) -> Self {
245        self.monomorphize = monomorphize;
246        self
247    }
248
249    /// Returns a new `QSystemPass` with forcing the HUGR to have
250    /// totally-ordered ops enabled according to `force_order`.
251    ///
252    /// On by default.
253    ///
254    /// When enabled, we push quantum ops as early as possible, and we push
255    /// `tket.futures.read` ops as late as possible.
256    pub fn with_force_order(mut self, force_order: bool) -> Self {
257        self.force_order = force_order;
258        self
259    }
260
261    /// Returns a new `QSystemPass` with lazification enabled according to
262    /// `lazify`.
263    ///
264    /// On by default.
265    ///
266    /// When enabled we replace strict measurement ops with lazy equivalents
267    /// from `tket.qsystem`.
268    pub fn with_lazify(mut self, lazify: bool) -> Self {
269        self.lazify = lazify;
270        self
271    }
272}
273
#[cfg(test)]
mod test {
    use hugr::{
        Hugr, HugrView as _,
        builder::{Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder},
        core::Visibility,
        extension::prelude::qb_t,
        hugr::hugrmut::HugrMut,
        ops::{ExtensionOp, OpType, handle::NodeHandle},
        std_extensions::arithmetic::float_types::ConstF64,
        std_extensions::collections::array::{ArrayOpBuilder, array_type},
        type_row,
        types::Signature,
    };

    use itertools::Itertools as _;
    use petgraph::visit::{Topo, Walker as _};
    use rstest::rstest;
    use tket::extension::{
        bool::bool_type,
        guppy::{DROP_OP_NAME, GUPPY_EXTENSION},
    };

    use crate::{
        QSystemPass,
        extension::{futures::FutureOpDef, qsystem::QSystemOp},
    };

    /// Runs the default [QSystemPass] on a small module and checks the relative
    /// topological order of selected nodes afterwards.
    ///
    /// With `set_entrypoint == true`, `main` is made the entrypoint explicitly.
    /// Otherwise the entrypoint stays at the module root, exercising the
    /// "look for `main`" fallback.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn qsystem_pass(#[case] set_entrypoint: bool) {
        let mut mb = hugr::builder::ModuleBuilder::new();
        // An empty helper function, called from `main` below.
        let func = mb
            .define_function("func", Signature::new_endo(type_row![]))
            .unwrap()
            .finish_with_outputs([])
            .unwrap();

        let (mut hugr, [call_node, h_node, f_node, rx_node, main_node]) = {
            let mut builder = mb
                .define_function(
                    "main",
                    Signature::new(qb_t(), vec![bool_type(), bool_type()]),
                )
                .unwrap();
            let [qb] = builder.input_wires_arr();

            // This call node has no dependencies, so it should be lifted above
            // Future Reads and sunk below quantum ops.
            let call_node = builder.call(func.handle(), &[], []).unwrap().node();

            // this LoadConstant should be pushed below the quantum ops where possible
            let angle = builder.add_load_value(ConstF64::new(0.0));
            let f_node = angle.node();

            // with no dependencies, this Reset should be lifted to the start
            // (note: despite its name, `h_node` is this Reset op)
            let [qb] = builder
                .add_dataflow_op(QSystemOp::Reset, [qb])
                .unwrap()
                .outputs_arr();
            let h_node = qb.node();

            // depending on the angle means this op can't be lifted above the angle ops
            // (note: despite its name, `rx_node` is this Rz op)
            let [qb] = builder
                .add_dataflow_op(QSystemOp::Rz, [qb, angle])
                .unwrap()
                .outputs_arr();
            let rx_node = qb.node();

            // the Measure node will be removed. A Lazy Measure and two Future
            // Reads will be added. The Lazy Measure will be lifted and the
            // reads will be sunk.
            let [measure_result] = builder
                .add_dataflow_op(QSystemOp::Measure, [qb])
                .unwrap()
                .outputs_arr();

            let main_n = builder
                .finish_with_outputs([measure_result, measure_result])
                .unwrap()
                .node();
            let hugr = mb.finish_hugr().unwrap();
            (hugr, [call_node, h_node, f_node, rx_node, main_n])
        };
        if set_entrypoint {
            // set the entrypoint to the main function
            // if this is not done the "backwards compatibility" code is triggered
            hugr.set_entrypoint(main_node);
        }
        QSystemPass::default().run(&mut hugr).unwrap();

        // Topological order of the whole graph, including the order edges
        // added by the pass.
        let topo_sorted = Topo::new(&hugr.as_petgraph())
            .iter(&hugr.as_petgraph())
            .collect_vec();

        // Position of a node in `topo_sorted`; panics if the node is absent.
        let get_pos = |x| topo_sorted.iter().position(|&y| y == x).unwrap();
        // Reset comes before the angle's LoadConstant and before the call.
        assert!(get_pos(h_node) < get_pos(f_node));
        assert!(get_pos(h_node) < get_pos(call_node));
        // Rz comes before the call.
        assert!(get_pos(rx_node) < get_pos(call_node));

        // Every Future Read comes after the call.
        for &n in topo_sorted
            .iter()
            .filter(|&&n| FutureOpDef::try_from(hugr.get_optype(n)) == Ok(FutureOpDef::Read))
        {
            assert!(get_pos(call_node) < get_pos(n));
        }
    }

    /// Checks function hiding on a `main` that clones an array of bools and
    /// drops one copy. With the default pass no public functions remain. With
    /// `hide_funcs: false`, 4 public functions remain. Either way the two
    /// results have the same number of module children and nodes.
    #[test]
    fn hide_funcs() {
        let orig = {
            let arr_t = || array_type(4, bool_type());
            let mut dfb = FunctionBuilder::new("main", Signature::new_endo(arr_t())).unwrap();
            let [arr] = dfb.input_wires_arr();
            let (arr1, arr2) = dfb.add_array_clone(bool_type(), 4, arr).unwrap();
            // Drop one clone via the guppy extension's drop op.
            let dop = GUPPY_EXTENSION.get_op(&DROP_OP_NAME).unwrap();
            dfb.add_dataflow_op(
                ExtensionOp::new(dop.clone(), [arr_t().into()]).unwrap(),
                [arr1],
            )
            .unwrap();
            dfb.finish_hugr_with_outputs([arr2]).unwrap()
        };

        // Number of module-level FuncDefn/FuncDecl nodes with public visibility.
        let count_pub_funcs = |hugr: &Hugr| {
            hugr.children(hugr.module_root())
                .filter(|n| match hugr.get_optype(*n) {
                    OpType::FuncDefn(fd) => fd.visibility() == &Visibility::Public,
                    OpType::FuncDecl(fd) => fd.visibility() == &Visibility::Public,
                    _ => false,
                })
                .count()
        };

        // Check there are no public funcs (after hiding)
        let mut hugr = orig.clone();
        QSystemPass::default().run(&mut hugr).unwrap();
        assert_eq!(count_pub_funcs(&hugr), 0);

        // Run again without hiding...
        let mut hugr_public = orig;
        QSystemPass {
            hide_funcs: false,
            ..Default::default()
        }
        .run(&mut hugr_public)
        .unwrap();

        // ...the pass leaves 4 public functions, and hiding only changed
        // visibility: the structure is otherwise the same size.
        assert_eq!(count_pub_funcs(&hugr_public), 4);
        assert_eq!(
            hugr.children(hugr.module_root()).count(),
            hugr_public.children(hugr_public.module_root()).count()
        );
        assert_eq!(hugr.num_nodes(), hugr_public.num_nodes());
    }

    /// Checks that, with the `llvm` feature, QSystemPass leaves no `Const`
    /// node behind for a module-level constant `Value::function`. In this test,
    /// `main` loads that constant and calls it via `CallIndirect`.
    #[cfg(feature = "llvm")]
    #[test]
    // TODO: We still want to test this as long as deserialized hugrs are allowed to contain Value::Function
    // Once that variant is removed, we can remove this test.
    #[expect(deprecated)]
    fn const_function() {
        use hugr::builder::{Container, DFGBuilder, DataflowHugr, ModuleBuilder};
        use hugr::ops::{CallIndirect, Value};

        let qb_sig: Signature = Signature::new_endo(qb_t());
        let mut hugr = {
            let mut builder = ModuleBuilder::new();
            // A constant holding an identity function on one qubit.
            let val = Value::function({
                let builder = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
                let [r] = builder.input_wires_arr();
                builder.finish_hugr_with_outputs([r]).unwrap()
            })
            .unwrap();
            let const_node = builder.add_constant(val);
            {
                let mut builder = builder.define_function("main", qb_sig.clone()).unwrap();
                let [i] = builder.input_wires_arr();
                let fun = builder.load_const(&const_node);
                let [r] = builder
                    .add_dataflow_op(
                        CallIndirect {
                            signature: qb_sig.clone(),
                        },
                        [fun, i],
                    )
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([r]).unwrap();
            };
            builder.finish_hugr().unwrap()
        };

        QSystemPass::default().run(&mut hugr).unwrap();

        // QSystemPass should have removed the const function
        for n in hugr.descendants(hugr.module_root()) {
            if hugr.get_optype(n).as_const().is_some() {
                panic!("Const function is still there!");
            }
        }
    }
}