1mod static_array;
4
5use derive_more::{Display, Error, From};
6use hugr::algorithms::replace_types::{Linearizer, NodeTemplate, ReplaceTypesError};
7use hugr::algorithms::{
8 ComposablePass, ReplaceTypes, ensure_no_nonlocal_edges, non_local::FindNonLocalEdgesError,
9};
10use hugr::builder::{
11 BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
12 inout_sig,
13};
14use hugr::extension::prelude::{bool_t, option_type, qb_t, usize_t};
15use hugr::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
16use hugr::ops::{ExtensionOp, Tag, Value, handle::ConditionalID};
17use hugr::std_extensions::arithmetic::{
18 conversions::ConvertOpDef, int_ops::IntOpDef, int_types::ConstInt,
19};
20use hugr::std_extensions::collections::{
21 array::{self, ARRAY_CLONE_OP_ID, ARRAY_DISCARD_OP_ID, GenericArrayOpDef, array_type},
22 borrow_array::{self, BArrayUnsafeOpDef, BorrowArray, borrow_array_type},
23};
24use hugr::std_extensions::logic::LogicOp;
25use hugr::types::{SumType, Term, Type};
26use hugr::{Hugr, Node, Wire, hugr::hugrmut::HugrMut, type_row};
27use static_array::{ReplaceStaticArrayBoolPass, ReplaceStaticArrayBoolPassError};
28use tket::TketOp;
29use tket::extension::{
30 bool::{BoolOp, ConstBool, bool_type},
31 guppy::{DROP_OP_NAME, GUPPY_EXTENSION},
32};
33
34use crate::extension::{
35 futures::{FutureOp, FutureOpBuilder, FutureOpDef, future_type},
36 qsystem::QSystemOp,
37};
38
/// Errors that can be produced by [`ReplaceBoolPass`].
#[derive(Error, Debug, Display, From)]
#[non_exhaustive]
pub enum ReplaceBoolPassError<N> {
    /// The initial `ensure_no_nonlocal_edges` check on the input HUGR failed.
    NonLocalEdgesError(FindNonLocalEdgesError<N>),
    /// The `ReplaceTypes` lowering of bool types and operations failed.
    ReplacementError(ReplaceTypesError),
    /// The preliminary `ReplaceStaticArrayBoolPass` failed.
    ReplaceStaticArrayBoolPassError(ReplaceStaticArrayBoolPassError),
}
51
/// Pass that lowers tket `bool` values and the operations acting on them.
///
/// Every tket `bool` is replaced by a sum type whose variant 0 holds a plain
/// `bool` and whose variant 1 holds a `future<bool>` that must be read to obtain
/// the value. Bool operations are expanded into plain logic operations on read
/// values, and measurements are replaced by lazy measurements whose results are
/// stored in the future variant.
#[derive(Clone, Debug, Default)]
pub struct ReplaceBoolPass;
70
71impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceBoolPass {
72 type Error = ReplaceBoolPassError<H::Node>;
73 type Result = ();
74
75 fn run(&self, hugr: &mut H) -> Result<(), Self::Error> {
76 ensure_no_nonlocal_edges(hugr)?;
77 ReplaceStaticArrayBoolPass::default().run(hugr)?;
78 let lowerer = lowerer();
79 lowerer.run(hugr)?;
80 Ok(())
81 }
82}
83
84fn bool_dest() -> Type {
86 SumType::new([bool_t(), future_type(bool_t())]).into()
87}
88
89fn read_builder(dfb: &mut DFGBuilder<Hugr>, sum_wire: Wire) -> BuildHandle<ConditionalID> {
90 let mut cb = dfb
91 .conditional_builder(
92 ([bool_t().into(), future_type(bool_t()).into()], sum_wire),
93 [],
94 vec![bool_t()].into(),
95 )
96 .unwrap();
97
98 let case0 = cb.case_builder(0).unwrap();
100 let [b] = case0.input_wires_arr();
101 case0.finish_with_outputs([b]).unwrap();
102 let mut case1 = cb.case_builder(1).unwrap();
104 let [f] = case1.input_wires_arr();
105 let [f] = case1.add_read(f, bool_t()).unwrap();
106 case1.finish_with_outputs([f]).unwrap();
107
108 cb.finish_sub_container().unwrap()
109}
110
111fn read_op_dest() -> NodeTemplate {
112 let mut dfb = DFGBuilder::new(inout_sig(vec![bool_dest()], vec![bool_t()])).unwrap();
113 let [sum_wire] = dfb.input_wires_arr();
114 let cond = read_builder(&mut dfb, sum_wire);
115 let h = dfb.finish_hugr_with_outputs(cond.outputs()).unwrap();
116 NodeTemplate::CompoundOp(Box::new(h))
117}
118
119fn make_opaque_op_dest() -> NodeTemplate {
120 let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_dest()])).unwrap();
121 let [inp] = dfb.input_wires_arr();
122 let out = dfb
123 .add_dataflow_op(
124 Tag::new(0, vec![bool_t().into(), future_type(bool_t()).into()]),
125 vec![inp],
126 )
127 .unwrap();
128 let h = dfb.finish_hugr_with_outputs(out.outputs()).unwrap();
129 NodeTemplate::CompoundOp(Box::new(h))
130}
131
132fn binary_logic_op_dest(op: &BoolOp) -> NodeTemplate {
133 let mut dfb =
134 DFGBuilder::new(inout_sig(vec![bool_dest(), bool_dest()], vec![bool_dest()])).unwrap();
135 let [sum_wire1, sum_wire2] = dfb.input_wires_arr();
136 let cond1 = read_builder(&mut dfb, sum_wire1);
137 let cond2 = read_builder(&mut dfb, sum_wire2);
138 let result = match op {
139 BoolOp::eq => dfb
140 .add_dataflow_op(LogicOp::Eq, [cond1.out_wire(0), cond2.out_wire(0)])
141 .unwrap(),
142 BoolOp::and => dfb
143 .add_dataflow_op(LogicOp::And, [cond1.out_wire(0), cond2.out_wire(0)])
144 .unwrap(),
145 BoolOp::or => dfb
146 .add_dataflow_op(LogicOp::Or, [cond1.out_wire(0), cond2.out_wire(0)])
147 .unwrap(),
148 BoolOp::xor => dfb
149 .add_dataflow_op(LogicOp::Xor, [cond1.out_wire(0), cond2.out_wire(0)])
150 .unwrap(),
151 op => panic!("Unknown op name: {op:?}"),
152 };
153 let out = dfb
154 .add_dataflow_op(
155 Tag::new(0, vec![bool_t().into(), future_type(bool_t()).into()]),
156 vec![result.out_wire(0)],
157 )
158 .unwrap();
159
160 let h = dfb.finish_hugr_with_outputs(out.outputs()).unwrap();
161 NodeTemplate::CompoundOp(Box::new(h))
162}
163
164fn not_op_dest() -> NodeTemplate {
165 let mut dfb = DFGBuilder::new(inout_sig(vec![bool_dest()], vec![bool_dest()])).unwrap();
166 let [sum_wire] = dfb.input_wires_arr();
167 let cond = read_builder(&mut dfb, sum_wire);
168 let result = dfb
169 .add_dataflow_op(LogicOp::Not, [cond.out_wire(0)])
170 .unwrap();
171 let out = dfb
172 .add_dataflow_op(
173 Tag::new(0, vec![bool_t().into(), future_type(bool_t()).into()]),
174 vec![result.out_wire(0)],
175 )
176 .unwrap();
177 let h = dfb.finish_hugr_with_outputs(out.outputs()).unwrap();
178 NodeTemplate::CompoundOp(Box::new(h))
179}
180
181fn measure_dest() -> NodeTemplate {
182 let lazy_measure = QSystemOp::LazyMeasure.to_extension_op().unwrap();
183
184 let mut dfb = DFGBuilder::new(inout_sig(vec![qb_t()], vec![bool_dest()])).unwrap();
185 let [q] = dfb.input_wires_arr();
186 let measure = dfb.add_dataflow_op(lazy_measure, vec![q]).unwrap();
187 let tagged_output = dfb
188 .add_dataflow_op(
189 Tag::new(1, vec![bool_t().into(), future_type(bool_t()).into()]),
190 vec![measure.out_wire(0)],
191 )
192 .unwrap();
193 let h = dfb
194 .finish_hugr_with_outputs(tagged_output.outputs())
195 .unwrap();
196 NodeTemplate::CompoundOp(Box::new(h))
197}
198
199fn measure_reset_dest() -> NodeTemplate {
200 let lazy_measure_reset = QSystemOp::LazyMeasureReset.to_extension_op().unwrap();
201
202 let mut dfb = DFGBuilder::new(inout_sig(vec![qb_t()], vec![qb_t(), bool_dest()])).unwrap();
203 let [q] = dfb.input_wires_arr();
204 let measure = dfb.add_dataflow_op(lazy_measure_reset, vec![q]).unwrap();
205 let tagged_output = dfb
206 .add_dataflow_op(
207 Tag::new(1, vec![bool_t().into(), future_type(bool_t()).into()]),
208 vec![measure.out_wire(1)],
209 )
210 .unwrap();
211 let h = dfb
212 .finish_hugr_with_outputs(vec![measure.out_wire(0), tagged_output.out_wire(0)])
213 .unwrap();
214 NodeTemplate::CompoundOp(Box::new(h))
215}
216
/// Builds the replacement for borrow-array `get`. `lowerer` uses it when the
/// element type it receives is not copyable.
///
/// The template has signature
/// `(borrow_array<size, elem_ty>, usize) -> (option<elem_ty>, borrow_array<size, elem_ty>)`.
/// It expands `get` into an explicit bounds check followed by a conditional:
///
/// * index out of range: outputs `None` together with the unchanged array;
/// * index in range: borrows the element, copies it with the linearizer of
///   `rt`, returns one copy to the array and outputs the other as `Some`.
///
/// # Panics
///
/// Panics (via `unwrap`) if any builder step fails, or if the linearizer of `rt`
/// cannot produce a copy operation for `elem_ty`.
fn barray_get_dest(rt: &ReplaceTypes, size: u64, elem_ty: Type) -> NodeTemplate {
    let array_ty = borrow_array_type(size, elem_ty.clone());
    let opt_el = option_type(elem_ty.clone());
    let mut dfb = DFGBuilder::new(inout_sig(
        vec![array_ty.clone(), usize_t()],
        vec![opt_el.clone().into(), array_ty.clone()],
    ))
    .unwrap();
    let [arr_in, idx] = dfb.input_wires_arr();
    // Bounds check: convert the index to an integer of log width 6 (64 bits)
    // and compare it against `size` with `ilt_u`.
    let [idx_as_int] = dfb
        .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [idx])
        .unwrap()
        .outputs_arr();
    let bound = dfb.add_load_value(ConstInt::new_u(6, size).unwrap());
    let [is_in_range] = dfb
        .add_dataflow_op(IntOpDef::ilt_u.with_log_width(6), [idx_as_int, bound])
        .unwrap()
        .outputs_arr();
    // Branch on the comparison (case 0 = out of range, case 1 = in range). The
    // array and the index are passed into both cases.
    let mut cb = dfb
        .conditional_builder(
            (vec![type_row![]; 2], is_in_range),
            [(array_ty.clone(), arr_in), (usize_t(), idx)],
            vec![opt_el.clone().into(), array_ty.clone()].into(),
        )
        .unwrap();

    // Out of range: output `None` and the untouched array.
    let mut out_of_range = cb.case_builder(0).unwrap();
    let [arr_in, _] = out_of_range.input_wires_arr();
    let [none] = out_of_range
        .add_dataflow_op(Tag::new(0, vec![type_row![], elem_ty.clone().into()]), [])
        .unwrap()
        .outputs_arr();
    out_of_range.finish_with_outputs([none, arr_in]).unwrap();

    // In range: borrow the element, copy it, and put one copy back in the array.
    let mut in_range = cb.case_builder(1).unwrap();
    let [arr_in, idx] = in_range.input_wires_arr();
    let [arr, elem] = in_range
        .add_dataflow_op(
            BArrayUnsafeOpDef::borrow.to_concrete(elem_ty.clone(), size),
            [arr_in, idx],
        )
        .unwrap()
        .outputs_arr();

    // The element is not copyable, so an explicit copy op from the linearizer is used.
    let [elem1, elem2] = rt
        .get_linearizer()
        .copy_discard_op(&elem_ty, 2)
        .unwrap()
        .add(&mut in_range, [elem])
        .unwrap()
        .outputs_arr();

    let [arr] = in_range
        .add_dataflow_op(
            BArrayUnsafeOpDef::r#return.to_concrete(elem_ty.clone(), size),
            [arr, idx, elem1],
        )
        .unwrap()
        .outputs_arr();
    // The second copy becomes the `Some` result.
    let [some] = in_range
        .add_dataflow_op(Tag::new(1, vec![type_row![], elem_ty.into()]), [elem2])
        .unwrap()
        .outputs_arr();
    in_range.finish_with_outputs([some, arr]).unwrap();

    let outs = cb.finish_sub_container().unwrap().outputs();
    // NOTE(review): the outputs are set directly and the HUGR is moved out of
    // the builder instead of calling `finish_hugr_with_outputs`. This is
    // presumably to skip the validation that method performs; confirm.
    dfb.set_outputs(outs).unwrap();
    let h = std::mem::take(dfb.hugr_mut());
    NodeTemplate::CompoundOp(Box::new(h))
}
288
289fn lowerer() -> ReplaceTypes {
291 let mut lw = ReplaceTypes::default();
292
293 lw.set_replace_type(bool_type().as_extension().unwrap().clone(), bool_dest());
295 let dup_op = FutureOp {
296 op: FutureOpDef::Dup,
297 typ: bool_t(),
298 }
299 .to_extension_op()
300 .unwrap();
301 let free_op = FutureOp {
302 op: FutureOpDef::Free,
303 typ: bool_t(),
304 }
305 .to_extension_op()
306 .unwrap();
307 lw.linearizer_mut()
308 .register_simple(
309 future_type(bool_t()).as_extension().unwrap().clone(),
310 NodeTemplate::SingleOp(dup_op.into()),
311 NodeTemplate::SingleOp(free_op.into()),
312 )
313 .unwrap();
314
315 lw.replace_consts(
317 bool_type().as_extension().unwrap().clone(),
318 |const_bool, _| {
319 Ok(Value::sum(
320 0,
321 [Value::from_bool(
322 const_bool
323 .value()
324 .downcast_ref::<ConstBool>()
325 .unwrap()
326 .value(),
327 )],
328 SumType::new([vec![bool_t()], vec![future_type(bool_t())]]),
329 )
330 .unwrap())
331 },
332 );
333
334 let read_op = BoolOp::read.to_extension_op().unwrap();
336 lw.set_replace_op(&read_op, read_op_dest());
337 let make_opaque_op = BoolOp::make_opaque.to_extension_op().unwrap();
338 lw.set_replace_op(&make_opaque_op, make_opaque_op_dest());
339 for op in [BoolOp::eq, BoolOp::and, BoolOp::or, BoolOp::xor] {
340 lw.set_replace_op(&op.to_extension_op().unwrap(), binary_logic_op_dest(&op));
341 }
342 let not_op = BoolOp::not.to_extension_op().unwrap();
343 lw.set_replace_op(¬_op, not_op_dest());
344
345 let tket_measure_free = TketOp::MeasureFree.to_extension_op().unwrap();
347 let qsystem_measure = QSystemOp::Measure.to_extension_op().unwrap();
348 let qsystem_measure_reset = QSystemOp::MeasureReset.to_extension_op().unwrap();
349 lw.set_replace_op(&tket_measure_free, measure_dest());
350 lw.set_replace_op(&qsystem_measure, measure_dest());
351 lw.set_replace_op(&qsystem_measure_reset, measure_reset_dest());
352
353 for (array_ext, type_fn) in [
356 (
357 array::EXTENSION.to_owned(),
358 array_type as fn(u64, Type) -> Type,
359 ),
360 (
361 borrow_array::EXTENSION.to_owned(),
362 borrow_array_type as fn(u64, Type) -> Type,
363 ),
364 ] {
365 lw.set_replace_parametrized_op(
366 array_ext.get_op(ARRAY_CLONE_OP_ID.as_str()).unwrap(),
367 move |args, rt| {
368 let [size, elem_ty] = args else {
369 unreachable!()
370 };
371 let size = size.as_nat().unwrap();
372 let elem_ty = elem_ty.as_runtime().unwrap();
373 if elem_ty.copyable() {
374 return Ok(None);
375 }
376
377 let array_ty = type_fn(size, elem_ty);
378 Ok(Some(rt.get_linearizer().copy_discard_op(&array_ty, 2)?))
379 },
380 );
381 let drop_op_def = GUPPY_EXTENSION.get_op(DROP_OP_NAME.as_str()).unwrap();
382
383 lw.set_replace_parametrized_op(
384 array_ext.get_op(ARRAY_DISCARD_OP_ID.as_str()).unwrap(),
385 move |args, _| {
386 let [size, elem_ty] = args else {
387 unreachable!()
388 };
389 let size = size.as_nat().unwrap();
390 let elem_ty = elem_ty.as_runtime().unwrap();
391 if elem_ty.copyable() {
392 return Ok(None);
393 }
394 let drop_op = ExtensionOp::new(
395 drop_op_def.clone(),
396 vec![type_fn(size, elem_ty.clone()).into()],
397 )
398 .unwrap();
399 Ok(Some(NodeTemplate::SingleOp(drop_op.into())))
400 },
401 );
402 }
403
404 lw.set_replace_parametrized_op(
405 borrow_array::EXTENSION
406 .get_op(GenericArrayOpDef::<BorrowArray>::get.opdef_id().as_str())
407 .unwrap(),
408 |args, rt| {
409 let [Term::BoundedNat(size), Term::Runtime(elem_ty)] = args else {
410 unreachable!()
411 };
412 if elem_ty.copyable() {
413 return Ok(None);
414 }
415 Ok(Some(barray_get_dest(rt, *size, elem_ty.clone())))
416 },
417 );
418
419 lw
420}
421
#[cfg(test)]
mod test {
    use crate::extension::qsystem::{QSystemOp, QSystemOpBuilder};

    use super::*;
    use hugr::extension::prelude::{UnwrapBuilder, option_type, usize_t};
    use hugr::extension::simple_op::HasDef;
    use hugr::ops::OpType;
    use hugr::std_extensions::collections::array::op_builder::GenericArrayOpBuilder;
    use hugr::std_extensions::collections::array::{Array, ArrayKind};
    use hugr::std_extensions::collections::borrow_array::{
        BArrayOpBuilder, BorrowArray, borrow_array_type,
    };
    use hugr::type_row;
    use hugr::{
        HugrView,
        builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig},
        extension::prelude::qb_t,
        types::TypeRow,
    };
    use rstest::rstest;
    use tket::{
        TketOp,
        extension::bool::{BoolOp, BoolOpBuilder},
    };

    /// The tket bool type, i.e. the type the pass lowers.
    fn tket_bool_t() -> Type {
        bool_type()
    }

    #[test]
    fn test_consts() {
        let mut dfb = DFGBuilder::new(inout_sig(vec![], vec![tket_bool_t()])).unwrap();
        let const_wire = dfb.add_load_value(ConstBool::new(true));
        let mut h = dfb.finish_hugr_with_outputs([const_wire]).unwrap();

        h.validate().unwrap();
        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();
        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.output(), &TypeRow::from(vec![bool_dest()]));
    }

    #[test]
    fn test_read() {
        let mut dfb = DFGBuilder::new(inout_sig(vec![tket_bool_t()], vec![bool_t()])).unwrap();
        let [b] = dfb.input_wires_arr();
        let output = dfb.add_bool_read(b).unwrap();
        let mut h = dfb.finish_hugr_with_outputs(output).unwrap();

        assert_eq!(h.num_nodes(), 8);

        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.input(), &TypeRow::from(vec![bool_dest()]));
        assert_eq!(sig.output(), &TypeRow::from(vec![bool_t()]));

        assert_eq!(h.num_nodes(), 18);
    }

    #[test]
    fn test_make_opaque() {
        let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t()], vec![tket_bool_t()])).unwrap();
        let [b] = dfb.input_wires_arr();
        let output = dfb.add_bool_make_opaque(b).unwrap();
        let mut h = dfb.finish_hugr_with_outputs(output).unwrap();

        assert_eq!(h.num_nodes(), 8);

        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.input(), &TypeRow::from(vec![bool_t()]));
        assert_eq!(sig.output(), &TypeRow::from(vec![bool_dest()]));

        assert_eq!(h.num_nodes(), 11);
    }

    #[rstest]
    #[case(BoolOp::eq)]
    #[case(BoolOp::and)]
    #[case(BoolOp::or)]
    #[case(BoolOp::xor)]
    fn test_logic(#[case] logic_op: BoolOp) {
        let mut dfb = DFGBuilder::new(inout_sig(
            vec![tket_bool_t(), tket_bool_t()],
            vec![tket_bool_t()],
        ))
        .unwrap();
        let [b1, b2] = dfb.input_wires_arr();
        let result = dfb.add_dataflow_op(logic_op, [b1, b2]).unwrap();
        let mut h = dfb.finish_hugr_with_outputs(result.outputs()).unwrap();

        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.input(), &TypeRow::from(vec![bool_dest(), bool_dest()]));
        assert_eq!(sig.output(), &TypeRow::from(vec![bool_dest()]));
    }

    #[test]
    fn test_not() {
        let mut dfb = DFGBuilder::new(inout_sig(vec![tket_bool_t()], vec![tket_bool_t()])).unwrap();
        let [b] = dfb.input_wires_arr();
        let result = dfb.add_dataflow_op(BoolOp::not, [b]).unwrap();
        let mut h = dfb.finish_hugr_with_outputs(result.outputs()).unwrap();

        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.input(), &TypeRow::from(vec![bool_dest()]));
        assert_eq!(sig.output(), &TypeRow::from(vec![bool_dest()]));
    }

    #[rstest]
    #[case(TketOp::MeasureFree)]
    #[case(QSystemOp::Measure)]
    fn test_measure<T: Into<OpType>>(#[case] measure_op: T) {
        let mut dfb = DFGBuilder::new(inout_sig(vec![qb_t()], vec![bool_type()])).unwrap();
        let [q] = dfb.input_wires_arr();
        let output = dfb.add_dataflow_op(measure_op, [q]).unwrap();
        let mut h = dfb.finish_hugr_with_outputs(output.outputs()).unwrap();

        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.input(), &TypeRow::from(vec![qb_t()]));
        assert_eq!(sig.output(), &TypeRow::from(vec![bool_dest()]));
    }

    #[test]
    fn test_measure_reset() {
        let mut dfb = DFGBuilder::new(inout_sig(vec![qb_t()], vec![qb_t(), bool_type()])).unwrap();
        let [q] = dfb.input_wires_arr();
        let output = dfb.add_measure_reset(q).unwrap();
        let mut h = dfb.finish_hugr_with_outputs(output).unwrap();

        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(sig.input(), &TypeRow::from(vec![qb_t()]));
        assert_eq!(sig.output(), &TypeRow::from(vec![qb_t(), bool_dest()]));
    }

    #[rstest]
    #[case(Array)]
    #[case(BorrowArray)]
    fn test_array_clone_bool<AK: ArrayKind>(#[case] _ak: AK) {
        let elem_ty = bool_type();
        let size = 4;
        let arr_ty = AK::ty(size, elem_ty.clone());
        let mut dfb = DFGBuilder::new(inout_sig(
            vec![arr_ty.clone()],
            vec![arr_ty.clone(), arr_ty.clone()],
        ))
        .unwrap();
        let [arr_in] = dfb.input_wires_arr();
        let (arr1, arr2) = dfb
            .add_generic_array_clone::<AK>(elem_ty, size, arr_in)
            .unwrap();
        let mut h = dfb.finish_hugr_with_outputs([arr1, arr2]).unwrap();

        h.validate().unwrap();
        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        let bool_dest_ty = bool_dest();
        let arr_dest_ty = AK::ty(size, bool_dest_ty);
        assert_eq!(sig.input(), &TypeRow::from(vec![arr_dest_ty.clone()]));
        assert_eq!(
            sig.output(),
            &TypeRow::from(vec![arr_dest_ty.clone(), arr_dest_ty])
        );
    }

    #[rstest]
    #[case(Array)]
    #[case(BorrowArray)]
    fn test_array_discard_bool<AK: ArrayKind>(#[case] _ak: AK) {
        let elem_ty = bool_type();
        let size = 4;
        let arr_ty = AK::ty(size, elem_ty.clone());
        let mut dfb = DFGBuilder::new(inout_sig(vec![arr_ty.clone()], type_row![])).unwrap();
        let [arr_in] = dfb.input_wires_arr();
        dfb.add_generic_array_discard::<AK>(elem_ty, size, arr_in)
            .unwrap();
        let mut h = dfb.finish_hugr_with_outputs([]).unwrap();

        h.validate().unwrap();
        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        // The array element type must have been lowered; nothing is output.
        let sig = h.signature(h.entrypoint()).unwrap();
        assert_eq!(
            sig.input(),
            &TypeRow::from(vec![AK::ty(size, bool_dest())])
        );
        assert!(sig.output().is_empty());
    }

    #[rstest]
    #[case(Type::new_tuple(vec![tket_bool_t(), usize_t()]), Type::new_tuple(vec![bool_dest(), usize_t()]), true)]
    #[case(tket_bool_t(), bool_dest(), true)]
    #[case(usize_t(), usize_t(), false)]
    fn test_barray_get(#[case] src_ty: Type, #[case] dest_ty: Type, #[case] expect_dup: bool) {
        let arr_ty = borrow_array_type(4, src_ty.clone());
        let mut dfb = DFGBuilder::new(inout_sig(
            vec![arr_ty.clone(), usize_t()],
            vec![arr_ty, src_ty.clone()],
        ))
        .unwrap();
        let [arr_in, idx] = dfb.input_wires_arr();
        let (opt_elem, arr) = dfb
            .add_borrow_array_get(src_ty.clone(), 4, arr_in, idx)
            .unwrap();
        let [elem] = dfb
            .build_unwrap_sum(1, option_type(src_ty.clone()), opt_elem)
            .unwrap();
        let mut h = dfb.finish_hugr_with_outputs([arr, elem]).unwrap();

        h.validate().unwrap();
        let pass = ReplaceBoolPass;
        pass.run(&mut h).unwrap();
        h.validate().unwrap();

        let sig = h.signature(h.entrypoint()).unwrap();
        let dest_arr_ty = borrow_array_type(4, dest_ty.clone());
        assert_eq!(
            sig.as_ref(),
            &inout_sig(
                vec![dest_arr_ty.clone(), usize_t()],
                vec![dest_arr_ty, dest_ty]
            )
        );
        // A future `Dup` appears exactly when the element contains a lowered bool.
        let contains_dup = h.nodes().any(|n| {
            h.get_optype(n).as_extension_op().is_some_and(|eop| {
                FutureOp::from_op(eop).is_ok_and(|fop| fop.op == FutureOpDef::Dup)
            })
        });
        assert_eq!(contains_dup, expect_dup);
    }
}