1#[cfg(feature = "cli")]
5pub mod cli;
6pub mod extension;
7#[cfg(feature = "llvm")]
8pub mod llvm;
9pub mod lower_drops;
10pub mod pytket;
11pub mod replace_bools;
12
13use derive_more::{Display, Error, From};
14use hugr::algorithms::const_fold::{ConstFoldError, ConstantFoldPass};
15use hugr::algorithms::{
16 ComposablePass as _, MonomorphizePass, RemoveDeadFuncsError, RemoveDeadFuncsPass, force_order,
17 replace_types::ReplaceTypesError,
18};
19use hugr::hugr::{HugrError, hugrmut::HugrMut};
20use hugr::{Hugr, HugrView, Node, core::Visibility, ops::OpType};
21use hugr_core::hugr::internal::HugrMutInternals;
22use std::collections::HashSet;
23
24use lower_drops::LowerDropsPass;
25use replace_bools::{ReplaceBoolPass, ReplaceBoolPassError};
26use tket::TketOp;
27
28use extension::{
29 futures::FutureOpDef,
30 qsystem::{LowerTk2Error, LowerTketToQSystemPass, QSystemOp},
31};
32
33#[cfg(feature = "llvm")]
34#[expect(deprecated)]
35use hugr::llvm::utils::inline_constant_functions;
38
/// Configuration for the sequence of passes applied by [`QSystemPass::run`].
///
/// Each flag enables or disables one stage of the pipeline; the defaults are
/// given by the [`Default`] implementation.
#[derive(Debug, Clone, Copy)]
pub struct QSystemPass {
    /// Run constant folding.
    constant_fold: bool,
    /// Run monomorphization followed by dead-function removal (keeping the
    /// program entry point).
    monomorphize: bool,
    /// Insert ordering edges ranking quantum operations, allocations, frees
    /// and future reads relative to other operations.
    force_order: bool,
    /// Run the bool-replacement pass (`ReplaceBoolPass`).
    lazify: bool,
    /// After lowering, make every module-level function that was not public
    /// beforehand private.
    hide_funcs: bool,
}
51
52impl Default for QSystemPass {
53 fn default() -> Self {
54 Self {
55 constant_fold: false,
56 monomorphize: true,
57 force_order: true,
58 lazify: true,
59 hide_funcs: true,
60 }
61 }
62}
63
/// Errors that can occur while running [`QSystemPass`].
///
/// Each wrapping variant can be built via `From` from the wrapped error type,
/// so `?` converts the individual pass errors automatically.
#[derive(Error, Debug, Display, From)]
#[non_exhaustive]
pub enum QSystemPassError<N = Node> {
    /// The bool-replacement pass (`ReplaceBoolPass`) failed.
    ReplaceBoolError(ReplaceBoolPassError<N>),
    /// Inserting ordering edges (`force_order`) failed.
    ForceOrderError(HugrError),
    /// Lowering tket operations to qsystem operations failed.
    LowerTk2Error(LowerTk2Error),
    /// Constant folding failed.
    ConstantFoldError(ConstFoldError),
    /// A type-replacement step failed.
    ///
    /// NOTE(review): presumably produced by the drop-lowering pass, since no
    /// other stage of `QSystemPass::run` visibly returns a `ReplaceTypesError`;
    /// confirm.
    LinearizeArrayError(ReplaceTypesError),
    /// Inlining constant functions failed (only with the `llvm` feature).
    #[cfg(feature = "llvm")]
    InlineConstantFunctionsError(anyhow::Error),
    /// Dead-function removal after monomorphization failed.
    DCEError(RemoveDeadFuncsError),
    /// The Hugr's entrypoint is the module root, but the module contains no
    /// function definition named `main`.
    #[display("No function named 'main' in module.")]
    NoMain,
}
93
94impl QSystemPass {
95 pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
99 let entrypoint = if hugr.entrypoint_optype().is_module() {
100 hugr.children(hugr.entrypoint())
103 .find(|&n| {
104 hugr.get_optype(n)
105 .as_func_defn()
106 .is_some_and(|fd| fd.func_name() == "main")
107 })
108 .ok_or(QSystemPassError::NoMain)?
109 } else {
110 hugr.entrypoint()
111 };
112
113 hugr.set_entrypoint(hugr.module_root());
115 if self.monomorphize {
116 self.monomorphization().run(hugr).unwrap();
117
118 let rdfp = RemoveDeadFuncsPass::default().with_module_entry_points([entrypoint]);
119 rdfp.run(hugr)?
120 }
121
122 let pubfuncs = self.hide_funcs.then(|| {
126 hugr.children(hugr.module_root())
127 .filter(|n| {
128 hugr.get_optype(*n)
129 .as_func_defn()
130 .is_some_and(|fd| fd.visibility() == &Visibility::Public)
131 })
132 .collect::<HashSet<_>>()
133 });
134
135 self.lower_tk2().run(hugr)?;
136 if self.lazify {
137 self.replace_bools().run(hugr)?;
138 }
139 self.lower_drops().run(hugr)?;
140
141 if let Some(pubfuncs) = pubfuncs {
142 for n in hugr
143 .children(hugr.module_root())
144 .filter(|n| !pubfuncs.contains(n))
145 .collect::<Vec<_>>()
146 {
147 if let OpType::FuncDefn(fd) = hugr.optype_mut(n) {
148 *fd.visibility_mut() = Visibility::Private;
149 }
150 }
151 }
152
153 #[cfg(feature = "llvm")]
154 {
155 #[expect(deprecated)]
158 inline_constant_functions(hugr)?;
159 }
160 if self.constant_fold {
161 self.constant_fold().run(hugr)?;
162 }
163 if self.force_order {
164 self.force_order(hugr)?;
165 }
166 hugr.set_entrypoint(entrypoint);
168 Ok(())
169 }
170
171 fn force_order(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
172 force_order(hugr, hugr.entrypoint(), |hugr, node| {
173 let optype = hugr.get_optype(node);
174
175 let is_quantum =
176 optype.cast::<TketOp>().is_some() || optype.cast::<QSystemOp>().is_some();
177 let is_qalloc = matches!(
178 optype.cast(),
179 Some(TketOp::QAlloc) | Some(TketOp::TryQAlloc)
180 ) || optype.cast() == Some(QSystemOp::TryQAlloc);
181 let is_qfree =
182 optype.cast() == Some(TketOp::QFree) || optype.cast() == Some(QSystemOp::QFree);
183 let is_read = optype.cast() == Some(FutureOpDef::Read);
184
185 if is_qfree {
194 -3
195 } else if is_quantum && !is_qalloc {
196 -2
198 } else if is_qalloc {
199 -1
200 } else if !is_read {
201 0
203 } else {
204 1
206 }
207 })?;
208 Ok::<_, QSystemPassError>(())
209 }
210
211 fn lower_tk2(&self) -> LowerTketToQSystemPass {
212 LowerTketToQSystemPass
213 }
214
215 fn replace_bools(&self) -> ReplaceBoolPass {
216 ReplaceBoolPass
217 }
218
219 fn constant_fold(&self) -> ConstantFoldPass {
220 ConstantFoldPass::default()
221 }
222
223 fn monomorphization(&self) -> MonomorphizePass {
224 MonomorphizePass
225 }
226
227 fn lower_drops(&self) -> LowerDropsPass {
228 LowerDropsPass
229 }
230
231 pub fn with_constant_fold(mut self, constant_fold: bool) -> Self {
236 self.constant_fold = constant_fold;
237 self
238 }
239
240 pub fn with_monormophize(mut self, monomorphize: bool) -> Self {
245 self.monomorphize = monomorphize;
246 self
247 }
248
249 pub fn with_force_order(mut self, force_order: bool) -> Self {
257 self.force_order = force_order;
258 self
259 }
260
261 pub fn with_lazify(mut self, lazify: bool) -> Self {
269 self.lazify = lazify;
270 self
271 }
272}
273
#[cfg(test)]
mod test {
    use hugr::{
        Hugr, HugrView as _,
        builder::{Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder},
        core::Visibility,
        extension::prelude::qb_t,
        hugr::hugrmut::HugrMut,
        ops::{ExtensionOp, OpType, handle::NodeHandle},
        std_extensions::arithmetic::float_types::ConstF64,
        std_extensions::collections::array::{ArrayOpBuilder, array_type},
        type_row,
        types::Signature,
    };

    use itertools::Itertools as _;
    use petgraph::visit::{Topo, Walker as _};
    use rstest::rstest;
    use tket::extension::{
        bool::bool_type,
        guppy::{DROP_OP_NAME, GUPPY_EXTENSION},
    };

    use crate::{
        QSystemPass,
        extension::{futures::FutureOpDef, qsystem::QSystemOp},
    };

    /// Runs the default pass on a small `main` and checks the node order
    /// imposed by `force_order`.
    ///
    /// `main` contains a call, a constant load, Reset, Rz and Measure.
    /// With `set_entrypoint` the Hugr's entrypoint is `main` itself; otherwise
    /// it is the module root and the pass must find `main` by name.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn qsystem_pass(#[case] set_entrypoint: bool) {
        let mut mb = hugr::builder::ModuleBuilder::new();
        // An empty function, called from `main` with no data dependencies.
        let func = mb
            .define_function("func", Signature::new_endo(type_row![]))
            .unwrap()
            .finish_with_outputs([])
            .unwrap();

        let (mut hugr, [call_node, h_node, f_node, rx_node, main_node]) = {
            let mut builder = mb
                .define_function(
                    "main",
                    Signature::new(qb_t(), vec![bool_type(), bool_type()]),
                )
                .unwrap();
            let [qb] = builder.input_wires_arr();

            let call_node = builder.call(func.handle(), &[], []).unwrap().node();

            // Classical constant load feeding the rotation angle.
            let angle = builder.add_load_value(ConstF64::new(0.0));
            let f_node = angle.node();

            // NOTE: despite its name, `h_node` is the `Reset` op.
            let [qb] = builder
                .add_dataflow_op(QSystemOp::Reset, [qb])
                .unwrap()
                .outputs_arr();
            let h_node = qb.node();

            // NOTE: despite its name, `rx_node` is an `Rz` op.
            let [qb] = builder
                .add_dataflow_op(QSystemOp::Rz, [qb, angle])
                .unwrap()
                .outputs_arr();
            let rx_node = qb.node();

            let [measure_result] = builder
                .add_dataflow_op(QSystemOp::Measure, [qb])
                .unwrap()
                .outputs_arr();

            let main_n = builder
                .finish_with_outputs([measure_result, measure_result])
                .unwrap()
                .node();
            let hugr = mb.finish_hugr().unwrap();
            (hugr, [call_node, h_node, f_node, rx_node, main_n])
        };
        if set_entrypoint {
            hugr.set_entrypoint(main_node);
        }
        QSystemPass::default().run(&mut hugr).unwrap();

        // A topological order of the whole graph respects the ordering edges
        // inserted by the pass.
        let topo_sorted = Topo::new(&hugr.as_petgraph())
            .iter(&hugr.as_petgraph())
            .collect_vec();

        let get_pos = |x| topo_sorted.iter().position(|&y| y == x).unwrap();
        // Quantum ops precede the classical constant load and the call.
        assert!(get_pos(h_node) < get_pos(f_node));
        assert!(get_pos(h_node) < get_pos(call_node));
        assert!(get_pos(rx_node) < get_pos(call_node));

        // Any future `Read` ops come after the call.
        for &n in topo_sorted
            .iter()
            .filter(|&&n| FutureOpDef::try_from(hugr.get_optype(n)) == Ok(FutureOpDef::Read))
        {
            assert!(get_pos(call_node) < get_pos(n));
        }
    }

    /// Checks that `hide_funcs` privatizes functions introduced by lowering.
    ///
    /// `main` clones an array of bools and drops one copy with the guppy drop
    /// op. With `hide_funcs` disabled, the pass leaves 4 public functions; with
    /// the default (enabled) none remain public. The two results otherwise have
    /// the same shape.
    #[test]
    fn hide_funcs() {
        let orig = {
            let arr_t = || array_type(4, bool_type());
            let mut dfb = FunctionBuilder::new("main", Signature::new_endo(arr_t())).unwrap();
            let [arr] = dfb.input_wires_arr();
            let (arr1, arr2) = dfb.add_array_clone(bool_type(), 4, arr).unwrap();
            let dop = GUPPY_EXTENSION.get_op(&DROP_OP_NAME).unwrap();
            dfb.add_dataflow_op(
                ExtensionOp::new(dop.clone(), [arr_t().into()]).unwrap(),
                [arr1],
            )
            .unwrap();
            dfb.finish_hugr_with_outputs([arr2]).unwrap()
        };

        // Counts public function definitions and declarations at module level.
        let count_pub_funcs = |hugr: &Hugr| {
            hugr.children(hugr.module_root())
                .filter(|n| match hugr.get_optype(*n) {
                    OpType::FuncDefn(fd) => fd.visibility() == &Visibility::Public,
                    OpType::FuncDecl(fd) => fd.visibility() == &Visibility::Public,
                    _ => false,
                })
                .count()
        };

        let mut hugr = orig.clone();
        QSystemPass::default().run(&mut hugr).unwrap();
        assert_eq!(count_pub_funcs(&hugr), 0);

        let mut hugr_public = orig;
        QSystemPass {
            hide_funcs: false,
            ..Default::default()
        }
        .run(&mut hugr_public)
        .unwrap();

        assert_eq!(count_pub_funcs(&hugr_public), 4);
        // Only visibility differs between the two runs, not the structure.
        assert_eq!(
            hugr.children(hugr.module_root()).count(),
            hugr_public.children(hugr_public.module_root()).count()
        );
        assert_eq!(hugr.num_nodes(), hugr_public.num_nodes());
    }

    /// With the `llvm` feature, checks that a constant function value, loaded
    /// and called through `CallIndirect`, leaves no `Const` node behind after
    /// the pass.
    #[cfg(feature = "llvm")]
    #[test]
    #[expect(deprecated)]
    fn const_function() {
        use hugr::builder::{Container, DFGBuilder, DataflowHugr, ModuleBuilder};
        use hugr::ops::{CallIndirect, Value};

        let qb_sig: Signature = Signature::new_endo(qb_t());
        let mut hugr = {
            let mut builder = ModuleBuilder::new();
            // An identity function on a qubit, stored as a constant value.
            let val = Value::function({
                let builder = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
                let [r] = builder.input_wires_arr();
                builder.finish_hugr_with_outputs([r]).unwrap()
            })
            .unwrap();
            let const_node = builder.add_constant(val);
            {
                let mut builder = builder.define_function("main", qb_sig.clone()).unwrap();
                let [i] = builder.input_wires_arr();
                let fun = builder.load_const(&const_node);
                let [r] = builder
                    .add_dataflow_op(
                        CallIndirect {
                            signature: qb_sig.clone(),
                        },
                        [fun, i],
                    )
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([r]).unwrap();
            };
            builder.finish_hugr().unwrap()
        };

        QSystemPass::default().run(&mut hugr).unwrap();

        for n in hugr.descendants(hugr.module_root()) {
            if hugr.get_optype(n).as_const().is_some() {
                panic!("Const function is still there!");
            }
        }
    }
}