1use anyhow::Result;
11use rand::Rng;
12use rand::SeedableRng;
13use rand::distributions::Distribution;
14use rand::distributions::WeightedIndex;
15use rand::prelude::SliceRandom;
16use rand_pcg::Pcg64Mcg;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use std::collections::BTreeMap;
20use std::io::Write as IoWrite;
21use std::panic::{AssertUnwindSafe, catch_unwind};
22use std::path::{Path, PathBuf};
23use std::process::Command;
24use std::sync::Arc;
25use std::sync::Mutex;
26use std::sync::atomic::AtomicUsize;
27use std::sync::atomic::Ordering;
28use std::sync::mpsc::Sender;
29use std::time::Instant;
30
31use clap::ValueEnum;
32
33use xlsynth::IrBits;
34use xlsynth::IrPackage;
35use xlsynth::IrValue;
36use xlsynth_g8r::aig::get_summary_stats;
37use xlsynth_g8r::aig::get_summary_stats::AigStats;
38use xlsynth_g8r::aig::graph_logical_effort::GraphLogicalEffortOptions;
39use xlsynth_g8r::aig::graph_logical_effort::analyze_graph_logical_effort;
40use xlsynth_g8r::aig_serdes::emit_aiger_binary::emit_aiger_binary;
41use xlsynth_g8r::aig_serdes::gate2ir::{
42 GateFnInterfaceSchema, repack_gate_fn_interface_with_schema,
43};
44use xlsynth_g8r::aig_serdes::load_aiger_auto::load_aiger_auto_from_path;
45use xlsynth_g8r::aig_sim::count_toggles;
46use xlsynth_g8r::gate_builder::GateBuilderOptions;
47use xlsynth_g8r::process_ir_path::{
48 CanonicalG8rOptions, canonical_ir_text_to_g8r_lowering_artifacts,
49};
50use xlsynth_mcmc::MIN_TEMPERATURE_RATIO;
51use xlsynth_mcmc::McmcIterationOutput as SharedMcmcIterationOutput;
52use xlsynth_mcmc::McmcOptions as SharedMcmcOptions;
53use xlsynth_mcmc::McmcStats as SharedMcmcStats;
54use xlsynth_mcmc::metropolis_accept;
55use xlsynth_mcmc::multichain::{ChainRole, ChainStrategy, SegmentOutcome, SegmentRunParams};
56use xlsynth_mcmc::multichain::{SegmentRunner, run_multichain};
57use xlsynth_pir::desugar_extensions::{self, ExtensionEmitMode};
58use xlsynth_pir::ir::FileTable as PirFileTable;
59use xlsynth_pir::ir::Fn as IrFn;
60use xlsynth_pir::ir::Package as PirPackage;
61use xlsynth_pir::ir::PackageMember as PirPackageMember;
62use xlsynth_pir::ir::Param as PirParam;
63use xlsynth_pir::ir::Type as PirType;
64use xlsynth_pir::ir_eval::{FnEvalResult, eval_fn_assuming_node_index_topological};
65use xlsynth_pir::ir_parser;
66use xlsynth_pir::ir_utils::compact_and_toposort_in_place;
67use xlsynth_pir::ir_value_utils::flatten_ir_value_to_lsb0_bits_for_type;
68use xlsynth_pir::random_inputs::generate_biased_irbits_with_rng;
69use xlsynth_pir::structural_similarity::collect_structural_entries;
70
71pub mod driver_cli;
72pub mod transforms;
73
74use crate::transforms::{
75 PirTransform, PirTransformKind, build_transform_weights, get_all_pir_transforms,
76};
77
/// Default count of "oracle" random samples.
///
/// NOTE(review): the consumer of this constant is outside this view; from the
/// name it presumably sets how many random input samples are drawn when an
/// oracle (equivalence) check is run — confirm at the call site.
const DEFAULT_ORACLE_RANDOM_SAMPLES: usize = 32;
79
80pub fn parse_irvals_tuple_lines(irvals_text: &str) -> Result<Vec<IrValue>> {
81 xlsynth_pir::irvals::parse_irvals_tuple_lines(irvals_text).map_err(anyhow::Error::msg)
82}
83
84pub fn parse_irvals_tuple_file(path: &Path) -> Result<Vec<IrValue>> {
85 xlsynth_pir::irvals::parse_irvals_tuple_file(path).map_err(anyhow::Error::msg)
86}
87
/// Process-wide atomic counter, presumably of "invalid bit slice" warnings
/// emitted so far.
///
/// NOTE(review): the readers/writers of this counter and of
/// `INVALID_BIT_SLICE_WARN_LIMIT` are outside this view; the names suggest the
/// warning stops being printed once the count reaches the limit — confirm at
/// the call site.
static INVALID_BIT_SLICE_WARN_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Presumed cap on how many invalid-bit-slice warnings are printed
/// (see the note on `INVALID_BIT_SLICE_WARN_COUNT`).
const INVALID_BIT_SLICE_WARN_LIMIT: usize = 8;
94
95use xlsynth_prover::prover::prove_ir_fn_equiv;
96use xlsynth_prover::prover::types::EquivResult;
97
/// Measured cost of one candidate PIR function.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare fields
/// lexicographically in declaration order, and serde (de)serializes by these
/// field names. Metrics that the selected objective does not request are
/// reported as `0`. The exception is the raw node and depth fields, which fall
/// back to `pir_nodes` for non-gate objectives.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Cost {
    /// Number of nodes in the PIR function (`f.nodes.len()`).
    pub pir_nodes: usize,
    /// Live node count of the raw g8r lowering; equals `pir_nodes` when the
    /// objective does not use gate costing.
    pub g8r_nodes: usize,
    /// Deepest path of the raw g8r lowering; equals `pir_nodes` when the
    /// objective does not use gate costing.
    pub g8r_depth: usize,
    /// Graph logical-effort delay of the raw lowering, scaled by 1000 and
    /// rounded (`usize::MAX` for non-finite delays); `0` unless requested.
    pub g8r_le_graph_milli: usize,
    /// Gate output toggle count of the raw lowering over the toggle stimulus;
    /// `0` unless requested.
    pub g8r_gate_output_toggles: usize,
    /// Weighted switching activity of the raw lowering, in milli-units;
    /// `0` unless requested.
    pub g8r_weighted_switching_milli: u128,
    /// AND-node count after the external postprocessor; `0` when no
    /// postprocessed statistics were computed.
    pub g8r_post_and_nodes: usize,
    /// Depth after the external postprocessor; `0` when not computed.
    pub g8r_post_depth: usize,
    /// Graph logical-effort delay (milli) after postprocessing; `0` when not
    /// computed.
    pub g8r_post_le_graph_milli: usize,
    /// Gate output toggles after postprocessing; `0` when not computed.
    pub g8r_post_gate_output_toggles: usize,
    /// Weighted switching (milli) after postprocessing; `0` when not computed.
    pub g8r_post_weighted_switching_milli: u128,
}
140
/// Selects how gate-level (g8r) metrics are evaluated.
///
/// Serialized with an internal `"kind"` tag in snake_case, e.g.
/// `{"kind":"builtin"}` or
/// `{"kind":"external_postprocess","program":"/path/to/tool"}`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum G8rEvaluationMode {
    /// Score only the built-in g8r lowering; no external program is invoked.
    Builtin,
    /// Additionally pass the lowered gate function through the external
    /// program `program`; the `g8r-post-*` objectives require this mode.
    ExternalPostprocess { program: String },
}
151
152impl Default for G8rEvaluationMode {
153 fn default() -> Self {
154 Self::Builtin
155 }
156}
157
158impl G8rEvaluationMode {
159 pub(crate) fn external_postprocess_program(&self) -> Option<&str> {
160 match self {
161 G8rEvaluationMode::Builtin => None,
162 G8rEvaluationMode::ExternalPostprocess { program } => Some(program.as_str()),
163 }
164 }
165
166 pub fn canonicalized_for_persistence(&self) -> Result<Self> {
168 match self {
169 G8rEvaluationMode::Builtin => Ok(Self::Builtin),
170 G8rEvaluationMode::ExternalPostprocess { program } => {
171 let path = std::fs::canonicalize(program).map_err(|e| {
172 anyhow::anyhow!(
173 "failed to canonicalize g8r postprocess program '{}': {}",
174 program,
175 e
176 )
177 })?;
178 Ok(Self::ExternalPostprocess {
179 program: path.display().to_string(),
180 })
181 }
182 }
183 }
184}
185
/// How PIR extension operations are emitted when a function is costed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
pub enum ExtensionCostingMode {
    /// Keep extension ops intact; they are emitted as FFI functions
    /// (`ExtensionEmitMode::AsFfiFunction`). CLI/artifact name: `preserve`.
    #[value(name = "preserve")]
    Preserve,
    /// Emit extension ops in desugared form (`ExtensionEmitMode::Desugared`).
    /// CLI/artifact name: `desugar`.
    #[value(name = "desugar")]
    Desugar,
}
198
199impl ExtensionCostingMode {
200 pub fn value_name(self) -> &'static str {
201 match self {
202 ExtensionCostingMode::Preserve => "preserve",
203 ExtensionCostingMode::Desugar => "desugar",
204 }
205 }
206
207 fn from_value_name(value: &str) -> Result<Self> {
208 match value {
209 "preserve" => Ok(ExtensionCostingMode::Preserve),
210 "desugar" => Ok(ExtensionCostingMode::Desugar),
211 _ => Err(anyhow::anyhow!(
212 "unknown extension costing mode in artifact: {}",
213 value
214 )),
215 }
216 }
217
218 fn extension_emit_mode(self) -> ExtensionEmitMode {
219 match self {
220 ExtensionCostingMode::Preserve => ExtensionEmitMode::AsFfiFunction,
221 ExtensionCostingMode::Desugar => ExtensionEmitMode::Desugared,
222 }
223 }
224}
225
/// Optional hard caps on delay and area used during search.
///
/// `max_delay` is compared with the objective's depth and `max_area` with its
/// node count. Both are raw or postprocessed, depending on the objective; see
/// `Objective::depth_for_constraints` / `Objective::area_for_constraints`.
/// `validate_constraint_configuration` rejects setting both at once.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ConstraintLimits {
    pub max_delay: Option<usize>,
    pub max_area: Option<usize>,
}
234
/// By how much a candidate exceeds each active cap.
///
/// As produced by `constraint_violation`, a field is `Some(over)` with
/// `over > 0` only when that cap is active and exceeded; otherwise `None`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConstraintViolationScore {
    pub delay_over: Option<usize>,
    pub area_over: Option<usize>,
}
244
/// Score compared during search; smaller is better.
///
/// The ordering is defined by the manual `Ord` impl. Every feasible score
/// (`violation == None`) ranks below every infeasible one. Feasible scores are
/// compared by `objective`. Infeasible scores are compared by `delay_over`,
/// then `area_over`, then `objective`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SearchScore {
    /// Objective metric value (`Objective::metric`).
    pub objective: u128,
    /// Cap violation, or `None` when every active cap is satisfied.
    pub violation: Option<ConstraintViolationScore>,
}
256
257impl SearchScore {
258 pub fn feasible(self) -> bool {
260 self.violation.is_none()
261 }
262}
263
264impl Ord for SearchScore {
265 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
266 match (self.violation, other.violation) {
267 (None, None) => self.objective.cmp(&other.objective),
268 (None, Some(_)) => std::cmp::Ordering::Less,
269 (Some(_), None) => std::cmp::Ordering::Greater,
270 (Some(lhs), Some(rhs)) => lhs
271 .delay_over
272 .cmp(&rhs.delay_over)
273 .then_with(|| lhs.area_over.cmp(&rhs.area_over))
274 .then_with(|| self.objective.cmp(&other.objective)),
275 }
276 }
277}
278
279impl PartialOrd for SearchScore {
280 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
281 Some(self.cmp(other))
282 }
283}
284
/// Snapshot guarded by [`Best`]: the best score seen so far and the function
/// that achieved it.
#[derive(Clone)]
struct BestState {
    score: SearchScore,
    value: IrFn,
}
290
/// Mutex-guarded record of the best-scoring function found so far.
///
/// Callers may share it across threads; updates only ever replace the record
/// with a strictly better [`SearchScore`] (see `Best::try_update`).
pub struct Best {
    inner: Mutex<BestState>,
}
295
296impl Best {
297 pub fn new(initial_score: SearchScore, value: IrFn) -> Self {
298 Self {
299 inner: Mutex::new(BestState {
300 score: initial_score,
301 value,
302 }),
303 }
304 }
305
306 pub fn try_update(&self, new_score: SearchScore, new_value: IrFn) -> bool {
307 let mut guard = self.inner.lock().unwrap();
308 if new_score < guard.score {
309 *guard = BestState {
310 score: new_score,
311 value: new_value,
312 };
313 true
314 } else {
315 false
316 }
317 }
318
319 pub fn get(&self) -> IrFn {
320 self.inner.lock().unwrap().value.clone()
321 }
322
323 pub fn score(&self) -> SearchScore {
324 self.inner.lock().unwrap().score
325 }
326}
327
328pub fn cost(f: &IrFn, objective: Objective) -> Result<Cost> {
334 cost_with_effort_options_and_toggle_stimulus(
335 f,
336 objective,
337 None,
338 &count_toggles::WeightedSwitchingOptions::default(),
339 )
340}
341
342pub fn cost_with_toggle_stimulus(
345 f: &IrFn,
346 objective: Objective,
347 toggle_stimulus: Option<&[Vec<IrBits>]>,
348) -> Result<Cost> {
349 cost_with_effort_options_and_toggle_stimulus(
350 f,
351 objective,
352 toggle_stimulus,
353 &count_toggles::WeightedSwitchingOptions::default(),
354 )
355}
356
357pub fn cost_with_effort_options_and_toggle_stimulus(
359 f: &IrFn,
360 objective: Objective,
361 toggle_stimulus: Option<&[Vec<IrBits>]>,
362 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
363) -> Result<Cost> {
364 cost_with_effort_options_toggle_stimulus_and_extension_mode(
365 f,
366 objective,
367 toggle_stimulus,
368 weighted_switching_options,
369 ExtensionCostingMode::Preserve,
370 )
371}
372
373pub fn cost_with_effort_options_toggle_stimulus_and_extension_mode(
376 f: &IrFn,
377 objective: Objective,
378 toggle_stimulus: Option<&[Vec<IrBits>]>,
379 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
380 extension_costing_mode: ExtensionCostingMode,
381) -> Result<Cost> {
382 cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
383 f,
384 objective,
385 toggle_stimulus,
386 weighted_switching_options,
387 extension_costing_mode,
388 &G8rEvaluationMode::Builtin,
389 )
390}
391
392pub fn cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
395 f: &IrFn,
396 objective: Objective,
397 toggle_stimulus: Option<&[Vec<IrBits>]>,
398 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
399 extension_costing_mode: ExtensionCostingMode,
400 g8r_evaluation_mode: &G8rEvaluationMode,
401) -> Result<Cost> {
402 cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
403 f,
404 objective,
405 toggle_stimulus,
406 weighted_switching_options,
407 extension_costing_mode,
408 g8r_evaluation_mode,
409 &CanonicalG8rOptions::default(),
410 )
411}
412
413pub fn cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
415 f: &IrFn,
416 objective: Objective,
417 toggle_stimulus: Option<&[Vec<IrBits>]>,
418 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
419 extension_costing_mode: ExtensionCostingMode,
420 g8r_evaluation_mode: &G8rEvaluationMode,
421 canonical_g8r_options: &CanonicalG8rOptions,
422) -> Result<Cost> {
423 let pir_nodes = f.nodes.len();
424
425 if objective.needs_toggle_stimulus() && toggle_stimulus.is_none() {
426 return Err(anyhow::anyhow!(
427 "objective {} requires toggle stimulus",
428 objective.value_name()
429 ));
430 }
431 if !objective.needs_toggle_stimulus() && toggle_stimulus.is_some() {
432 return Err(anyhow::anyhow!(
433 "toggle stimulus provided but objective {} does not use toggles",
434 objective.value_name()
435 ));
436 }
437 if objective.needs_weighted_switching() {
438 validate_weighted_switching_options(weighted_switching_options)?;
439 }
440 if objective.uses_postprocessed_costing()
441 && g8r_evaluation_mode.external_postprocess_program().is_none()
442 {
443 return Err(anyhow::anyhow!(
444 "objective {} requires an external g8r postprocessor",
445 objective.value_name()
446 ));
447 }
448
449 let gate_stats = if objective.uses_gate_costing() {
450 compute_g8r_stats_for_pir_fn(
451 f,
452 objective.needs_graph_logical_effort(),
453 objective.needs_gate_output_toggles(),
454 objective.needs_weighted_switching(),
455 toggle_stimulus,
456 weighted_switching_options,
457 extension_costing_mode,
458 g8r_evaluation_mode,
459 canonical_g8r_options,
460 objective.uses_postprocessed_costing(),
461 )?
462 } else {
463 GateCostStats {
464 raw: RawG8rStats {
465 nodes: pir_nodes,
466 depth: pir_nodes,
467 le_graph_milli: 0,
468 gate_output_toggles: 0,
469 weighted_switching_milli: 0,
470 },
471 post: None,
472 }
473 };
474 let post = gate_stats.post.unwrap_or_default();
475
476 Ok(Cost {
477 pir_nodes,
478 g8r_nodes: gate_stats.raw.nodes,
479 g8r_depth: gate_stats.raw.depth,
480 g8r_le_graph_milli: gate_stats.raw.le_graph_milli,
481 g8r_gate_output_toggles: gate_stats.raw.gate_output_toggles,
482 g8r_weighted_switching_milli: gate_stats.raw.weighted_switching_milli,
483 g8r_post_and_nodes: post.and_nodes,
484 g8r_post_depth: post.depth,
485 g8r_post_le_graph_milli: post.le_graph_milli,
486 g8r_post_gate_output_toggles: post.gate_output_toggles,
487 g8r_post_weighted_switching_milli: post.weighted_switching_milli,
488 })
489}
490
491fn validate_weighted_switching_options(
492 options: &count_toggles::WeightedSwitchingOptions,
493) -> Result<()> {
494 let checks = [
495 ("beta1", options.beta1),
496 ("beta2", options.beta2),
497 ("primary_output_load", options.primary_output_load),
498 ];
499 for (name, value) in checks {
500 if !value.is_finite() {
501 return Err(anyhow::anyhow!(
502 "weighted switching option '{}' must be finite; got {}",
503 name,
504 value
505 ));
506 }
507 }
508 Ok(())
509}
510
511pub(crate) fn optimize_pir_package_via_xls_with_extension_mode(
513 pkg: &PirPackage,
514 top: &str,
515 extension_costing_mode: ExtensionCostingMode,
516) -> Result<PirPackage> {
517 let wrapped_ir_text = desugar_extensions::emit_package_with_extension_mode(
518 pkg,
519 extension_costing_mode.extension_emit_mode(),
520 )
521 .map_err(|e| anyhow::anyhow!("emit_package_with_extension_mode failed: {}", e))?;
522
523 let ir_pkg = IrPackage::parse_ir(&wrapped_ir_text, None)
524 .map_err(|e| anyhow::anyhow!("IrPackage::parse_ir failed: {:?}", e))?;
525 let optimized_ir_pkg = xlsynth::optimize_ir(&ir_pkg, top)
526 .map_err(|e| anyhow::anyhow!("optimize_ir failed: {:?}", e))?;
527 let optimized_ir_text = optimized_ir_pkg.to_string();
528
529 let mut parser = ir_parser::Parser::new(&optimized_ir_text);
530 parser
531 .parse_and_validate_package()
532 .map_err(|e| anyhow::anyhow!("PIR parse_and_validate_package failed: {:?}", e))
533}
534
535pub(crate) fn optimize_pir_fn_via_xls_with_extension_mode(
538 f: &IrFn,
539 extension_costing_mode: ExtensionCostingMode,
540) -> Result<IrFn> {
541 let mut fn_for_text = f.clone();
545 compact_and_toposort_in_place(&mut fn_for_text)
546 .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
547
548 let mut pir_pkg = PirPackage {
549 name: "pir_mcmc".to_string(),
550 file_table: PirFileTable::new(),
551 members: vec![PirPackageMember::Function(fn_for_text)],
552 top: None,
553 };
554 pir_pkg
555 .set_top_fn(&f.name)
556 .map_err(|e| anyhow::anyhow!("set_top_fn failed: {}", e))?;
557
558 let pir_pkg = optimize_pir_package_via_xls_with_extension_mode(
559 &pir_pkg,
560 &f.name,
561 extension_costing_mode,
562 )?;
563 let top_fn = pir_pkg
564 .get_top_fn()
565 .ok_or_else(|| anyhow::anyhow!("No top function found in optimized PIR package"))?;
566 Ok(top_fn.clone())
567}
568
/// Search objective: which scalar `Objective::metric` computes from a [`Cost`]
/// (smaller is better).
///
/// `G8r*` variants score the raw g8r lowering. `G8rPost*` variants score the
/// gate function returned by the external postprocessor and count AND nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
pub enum Objective {
    /// PIR node count.
    Nodes,
    /// Live node count of the raw g8r lowering.
    G8rNodes,
    /// Raw nodes × raw depth.
    G8rNodesTimesDepth,
    /// Raw nodes × raw depth × gate output toggles (needs toggle stimulus).
    #[value(name = "g8r-nodes-times-depth-times-toggles")]
    G8rNodesTimesDepthTimesToggles,
    /// Raw graph logical-effort delay (milli). Also accepted on the command
    /// line as `g8r-graph-le` and `g8r-graph-logical-effort`.
    #[value(
        name = "g8r-le-graph",
        alias = "g8r-graph-le",
        alias = "g8r-graph-logical-effort"
    )]
    G8rLeGraph,
    /// Raw logical-effort delay × raw nodes.
    #[value(name = "g8r-le-graph-times-nodes")]
    G8rLeGraphTimesNodes,
    /// Raw logical-effort delay × raw nodes × raw depth.
    #[value(name = "g8r-le-graph-times-product")]
    G8rLeGraphTimesProduct,
    /// Raw weighted switching (milli; needs toggle stimulus).
    #[value(name = "g8r-weighted-switching")]
    G8rWeightedSwitching,
    /// Raw nodes × raw weighted switching. Depth is capped at the initial
    /// design's depth (see `effective_constraint_limits`).
    #[value(name = "g8r-nodes-times-weighted-switching-no-depth-regress")]
    G8rNodesTimesWeightedSwitchingNoDepthRegress,
    /// Postprocessed AND-node count.
    #[value(name = "g8r-post-and-nodes")]
    G8rPostAndNodes,
    /// Postprocessed AND nodes × postprocessed depth.
    #[value(name = "g8r-post-and-nodes-times-depth")]
    G8rPostAndNodesTimesDepth,
    /// Postprocessed AND nodes × depth × gate output toggles.
    #[value(name = "g8r-post-and-nodes-times-depth-times-toggles")]
    G8rPostAndNodesTimesDepthTimesToggles,
    /// Postprocessed graph logical-effort delay (milli).
    #[value(name = "g8r-post-le-graph")]
    G8rPostLeGraph,
    /// Postprocessed logical-effort delay × AND nodes.
    #[value(name = "g8r-post-le-graph-times-and-nodes")]
    G8rPostLeGraphTimesAndNodes,
    /// Postprocessed logical-effort delay × AND nodes × depth.
    #[value(name = "g8r-post-le-graph-times-product")]
    G8rPostLeGraphTimesProduct,
    /// Postprocessed weighted switching (milli).
    #[value(name = "g8r-post-weighted-switching")]
    G8rPostWeightedSwitching,
    /// Postprocessed AND nodes × weighted switching, with postprocessed depth
    /// capped at the initial design's depth.
    #[value(name = "g8r-post-and-nodes-times-weighted-switching-no-depth-regress")]
    G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress,
}
608
609impl Objective {
610 pub fn uses_g8r_costing(self) -> bool {
611 matches!(
612 self,
613 Objective::G8rNodes
614 | Objective::G8rNodesTimesDepth
615 | Objective::G8rNodesTimesDepthTimesToggles
616 | Objective::G8rLeGraph
617 | Objective::G8rLeGraphTimesNodes
618 | Objective::G8rLeGraphTimesProduct
619 | Objective::G8rWeightedSwitching
620 | Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
621 )
622 }
623
624 pub fn uses_postprocessed_costing(self) -> bool {
625 matches!(
626 self,
627 Objective::G8rPostAndNodes
628 | Objective::G8rPostAndNodesTimesDepth
629 | Objective::G8rPostAndNodesTimesDepthTimesToggles
630 | Objective::G8rPostLeGraph
631 | Objective::G8rPostLeGraphTimesAndNodes
632 | Objective::G8rPostLeGraphTimesProduct
633 | Objective::G8rPostWeightedSwitching
634 | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
635 )
636 }
637
638 pub fn uses_gate_costing(self) -> bool {
639 self.uses_g8r_costing() || self.uses_postprocessed_costing()
640 }
641
642 pub fn needs_graph_logical_effort(self) -> bool {
643 matches!(
644 self,
645 Objective::G8rLeGraph
646 | Objective::G8rLeGraphTimesNodes
647 | Objective::G8rLeGraphTimesProduct
648 | Objective::G8rPostLeGraph
649 | Objective::G8rPostLeGraphTimesAndNodes
650 | Objective::G8rPostLeGraphTimesProduct
651 )
652 }
653
654 pub fn needs_toggle_stimulus(self) -> bool {
655 self.needs_gate_output_toggles() || self.needs_weighted_switching()
656 }
657
658 pub fn needs_gate_output_toggles(self) -> bool {
659 matches!(
660 self,
661 Objective::G8rNodesTimesDepthTimesToggles
662 | Objective::G8rPostAndNodesTimesDepthTimesToggles
663 )
664 }
665
666 pub fn needs_weighted_switching(self) -> bool {
667 matches!(
668 self,
669 Objective::G8rWeightedSwitching
670 | Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
671 | Objective::G8rPostWeightedSwitching
672 | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
673 )
674 }
675
676 pub fn enforces_non_regressing_depth(self) -> bool {
677 matches!(
678 self,
679 Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
680 | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
681 )
682 }
683
684 pub fn value_name(self) -> &'static str {
685 match self {
686 Objective::Nodes => "nodes",
687 Objective::G8rNodes => "g8r-nodes",
688 Objective::G8rNodesTimesDepth => "g8r-nodes-times-depth",
689 Objective::G8rNodesTimesDepthTimesToggles => "g8r-nodes-times-depth-times-toggles",
690 Objective::G8rLeGraph => "g8r-le-graph",
691 Objective::G8rLeGraphTimesNodes => "g8r-le-graph-times-nodes",
692 Objective::G8rLeGraphTimesProduct => "g8r-le-graph-times-product",
693 Objective::G8rWeightedSwitching => "g8r-weighted-switching",
694 Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress => {
695 "g8r-nodes-times-weighted-switching-no-depth-regress"
696 }
697 Objective::G8rPostAndNodes => "g8r-post-and-nodes",
698 Objective::G8rPostAndNodesTimesDepth => "g8r-post-and-nodes-times-depth",
699 Objective::G8rPostAndNodesTimesDepthTimesToggles => {
700 "g8r-post-and-nodes-times-depth-times-toggles"
701 }
702 Objective::G8rPostLeGraph => "g8r-post-le-graph",
703 Objective::G8rPostLeGraphTimesAndNodes => "g8r-post-le-graph-times-and-nodes",
704 Objective::G8rPostLeGraphTimesProduct => "g8r-post-le-graph-times-product",
705 Objective::G8rPostWeightedSwitching => "g8r-post-weighted-switching",
706 Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress => {
707 "g8r-post-and-nodes-times-weighted-switching-no-depth-regress"
708 }
709 }
710 }
711
712 fn from_value_name(value: &str) -> Result<Self> {
713 match value {
714 "nodes" => Ok(Objective::Nodes),
715 "g8r-nodes" => Ok(Objective::G8rNodes),
716 "g8r-nodes-times-depth" => Ok(Objective::G8rNodesTimesDepth),
717 "g8r-nodes-times-depth-times-toggles" => Ok(Objective::G8rNodesTimesDepthTimesToggles),
718 "g8r-le-graph" => Ok(Objective::G8rLeGraph),
719 "g8r-le-graph-times-nodes" => Ok(Objective::G8rLeGraphTimesNodes),
720 "g8r-le-graph-times-product" => Ok(Objective::G8rLeGraphTimesProduct),
721 "g8r-weighted-switching" => Ok(Objective::G8rWeightedSwitching),
722 "g8r-nodes-times-weighted-switching-no-depth-regress" => {
723 Ok(Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress)
724 }
725 "g8r-post-and-nodes" => Ok(Objective::G8rPostAndNodes),
726 "g8r-post-and-nodes-times-depth" => Ok(Objective::G8rPostAndNodesTimesDepth),
727 "g8r-post-and-nodes-times-depth-times-toggles" => {
728 Ok(Objective::G8rPostAndNodesTimesDepthTimesToggles)
729 }
730 "g8r-post-le-graph" => Ok(Objective::G8rPostLeGraph),
731 "g8r-post-le-graph-times-and-nodes" => Ok(Objective::G8rPostLeGraphTimesAndNodes),
732 "g8r-post-le-graph-times-product" => Ok(Objective::G8rPostLeGraphTimesProduct),
733 "g8r-post-weighted-switching" => Ok(Objective::G8rPostWeightedSwitching),
734 "g8r-post-and-nodes-times-weighted-switching-no-depth-regress" => {
735 Ok(Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress)
736 }
737 _ => Err(anyhow::anyhow!("unknown objective in artifact: {}", value)),
738 }
739 }
740
741 pub fn metric(self, c: &Cost) -> u128 {
742 match self {
743 Objective::Nodes => c.pir_nodes as u128,
744 Objective::G8rNodes => c.g8r_nodes as u128,
745 Objective::G8rNodesTimesDepth => {
746 (c.g8r_nodes as u128).saturating_mul(c.g8r_depth as u128)
747 }
748 Objective::G8rNodesTimesDepthTimesToggles => (c.g8r_nodes as u128)
749 .saturating_mul(c.g8r_depth as u128)
750 .saturating_mul(c.g8r_gate_output_toggles as u128),
751 Objective::G8rLeGraph => c.g8r_le_graph_milli as u128,
752 Objective::G8rLeGraphTimesNodes => {
753 (c.g8r_le_graph_milli as u128).saturating_mul(c.g8r_nodes as u128)
754 }
755 Objective::G8rLeGraphTimesProduct => {
756 let product = (c.g8r_nodes as u128).saturating_mul(c.g8r_depth as u128);
757 (c.g8r_le_graph_milli as u128).saturating_mul(product)
758 }
759 Objective::G8rWeightedSwitching => c.g8r_weighted_switching_milli,
760 Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress => {
761 (c.g8r_nodes as u128).saturating_mul(c.g8r_weighted_switching_milli)
762 }
763 Objective::G8rPostAndNodes => c.g8r_post_and_nodes as u128,
764 Objective::G8rPostAndNodesTimesDepth => {
765 (c.g8r_post_and_nodes as u128).saturating_mul(c.g8r_post_depth as u128)
766 }
767 Objective::G8rPostAndNodesTimesDepthTimesToggles => (c.g8r_post_and_nodes as u128)
768 .saturating_mul(c.g8r_post_depth as u128)
769 .saturating_mul(c.g8r_post_gate_output_toggles as u128),
770 Objective::G8rPostLeGraph => c.g8r_post_le_graph_milli as u128,
771 Objective::G8rPostLeGraphTimesAndNodes => {
772 (c.g8r_post_le_graph_milli as u128).saturating_mul(c.g8r_post_and_nodes as u128)
773 }
774 Objective::G8rPostLeGraphTimesProduct => {
775 let product =
776 (c.g8r_post_and_nodes as u128).saturating_mul(c.g8r_post_depth as u128);
777 (c.g8r_post_le_graph_milli as u128).saturating_mul(product)
778 }
779 Objective::G8rPostWeightedSwitching => c.g8r_post_weighted_switching_milli,
780 Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress => {
781 (c.g8r_post_and_nodes as u128).saturating_mul(c.g8r_post_weighted_switching_milli)
782 }
783 }
784 }
785
786 fn area_for_constraints(self, c: &Cost) -> usize {
787 if self.uses_postprocessed_costing() {
788 c.g8r_post_and_nodes
789 } else {
790 c.g8r_nodes
791 }
792 }
793
794 fn depth_for_constraints(self, c: &Cost) -> usize {
795 if self.uses_postprocessed_costing() {
796 c.g8r_post_depth
797 } else {
798 c.g8r_depth
799 }
800 }
801}
802
803pub(crate) fn validate_constraint_configuration(
804 objective: Objective,
805 limits: ConstraintLimits,
806) -> Result<()> {
807 if limits.max_delay.is_some() && limits.max_area.is_some() {
808 return Err(anyhow::anyhow!(
809 "at most one of --max-delay and --max-area may be specified"
810 ));
811 }
812 if !objective.uses_gate_costing() && (limits.max_delay.is_some() || limits.max_area.is_some()) {
813 return Err(anyhow::anyhow!(
814 "area/delay caps require a gate-based objective; got {}",
815 objective.value_name()
816 ));
817 }
818 if objective.enforces_non_regressing_depth() && limits.max_area.is_some() {
819 return Err(anyhow::anyhow!(
820 "--max-area is not compatible with objective {} because it already enforces a non-regressing depth cap",
821 objective.value_name()
822 ));
823 }
824 Ok(())
825}
826
827pub fn constraint_violation(
829 c: &Cost,
830 objective: Objective,
831 limits: ConstraintLimits,
832) -> Option<ConstraintViolationScore> {
833 let delay_over = limits
834 .max_delay
835 .map(|max_delay| objective.depth_for_constraints(c).saturating_sub(max_delay))
836 .filter(|over| *over > 0);
837 let area_over = limits
838 .max_area
839 .map(|max_area| objective.area_for_constraints(c).saturating_sub(max_area))
840 .filter(|over| *over > 0);
841
842 if delay_over.is_none() && area_over.is_none() {
843 return None;
844 }
845
846 Some(ConstraintViolationScore {
847 delay_over,
848 area_over,
849 })
850}
851
852pub fn search_score(c: &Cost, objective: Objective, limits: ConstraintLimits) -> SearchScore {
855 SearchScore {
856 objective: objective.metric(c),
857 violation: constraint_violation(c, objective, limits),
858 }
859}
860
861pub(crate) fn effective_constraint_limits(
862 objective: Objective,
863 user_limits: ConstraintLimits,
864 initial_cost: &Cost,
865) -> ConstraintLimits {
866 let limits = ConstraintLimits {
867 max_delay: match (
868 user_limits.max_delay,
869 objective.enforces_non_regressing_depth(),
870 ) {
871 (Some(user_cap), true) => {
872 Some(user_cap.min(objective.depth_for_constraints(initial_cost)))
873 }
874 (Some(user_cap), false) => Some(user_cap),
875 (None, true) => Some(objective.depth_for_constraints(initial_cost)),
876 (None, false) => None,
877 },
878 max_area: user_limits.max_area,
879 };
880 debug_assert!(
881 limits.max_delay.is_none() || limits.max_area.is_none(),
882 "effective constraints must keep at most one active cap"
883 );
884 limits
885}
886
887fn repair_energy(v: ConstraintViolationScore) -> u128 {
888 match (v.delay_over, v.area_over) {
889 (Some(over), None) => over as u128,
890 (None, Some(over)) => over as u128,
891 (Some(_), Some(_)) => unreachable!("constraint configuration validation rejects dual caps"),
892 (None, None) => 0,
893 }
894}
895
896pub(crate) fn format_search_score(score: SearchScore) -> String {
897 match score.violation {
898 None => format!("feasible(obj={})", score.objective),
899 Some(v) => format!(
900 "infeasible(obj={}, delay_over={:?}, area_over={:?})",
901 score.objective, v.delay_over, v.area_over,
902 ),
903 }
904}
905
/// Metrics of the raw (non-postprocessed) g8r lowering.
///
/// Metrics that were not requested are left at `0` by
/// `compute_g8r_stats_for_pir_fn_impl`.
#[derive(Clone, Copy, Debug)]
struct RawG8rStats {
    /// Live node count of the lowered gate function.
    nodes: usize,
    /// Deepest path in the lowered gate function.
    depth: usize,
    /// Graph logical-effort delay × 1000, rounded.
    le_graph_milli: usize,
    /// Gate output toggles across the toggle stimulus.
    gate_output_toggles: usize,
    /// Weighted switching activity in milli-units.
    weighted_switching_milli: u128,
}
914
/// Metrics of the gate function returned by the external postprocessor.
///
/// `Default` (all zeros) stands in when no postprocessing was performed.
#[derive(Clone, Copy, Debug, Default)]
struct G8rPostStats {
    /// AND-node count after postprocessing.
    and_nodes: usize,
    /// Depth after postprocessing.
    depth: usize,
    /// Graph logical-effort delay × 1000 after postprocessing.
    le_graph_milli: usize,
    /// Gate output toggles after postprocessing.
    gate_output_toggles: usize,
    /// Weighted switching (milli) after postprocessing.
    weighted_switching_milli: u128,
}
923
/// Gate-level statistics for one function: raw lowering metrics, plus
/// postprocessed metrics when a postprocessed objective requested them.
#[derive(Clone, Copy, Debug)]
struct GateCostStats {
    raw: RawG8rStats,
    /// `None` unless postprocessed statistics were computed.
    post: Option<G8rPostStats>,
}
929
/// Canonical input to g8r scoring, built by
/// `canonical_g8r_scoring_input_for_pir_fn`.
pub struct CanonicalG8rScoringInput {
    /// The XLS-optimized version of the candidate function.
    pub top_fn: IrFn,
    /// IR text of a package `pir_mcmc` whose only member is `top_fn`, marked
    /// as `top`.
    pub ir_text: String,
}
935
936pub fn canonical_g8r_scoring_input_for_pir_fn(
938 f: &IrFn,
939 extension_costing_mode: ExtensionCostingMode,
940) -> Result<CanonicalG8rScoringInput> {
941 let top_fn = optimize_pir_fn_via_xls_with_extension_mode(f, extension_costing_mode)?;
942 let ir_text = format!("package pir_mcmc\n\ntop {}", top_fn);
943 Ok(CanonicalG8rScoringInput { top_fn, ir_text })
944}
945
946fn canonical_g8r_scoring_lowering_options(
948 canonical_g8r_options: &CanonicalG8rOptions,
949) -> CanonicalG8rOptions {
950 let mut scoring_g8r_options = canonical_g8r_options.clone();
951 scoring_g8r_options.compute_graph_logical_effort = false;
954 scoring_g8r_options
955}
956
/// A postprocessed AIG with its summary statistics.
///
/// NOTE(review): the producer and consumers of this type are outside this
/// view; the field meanings below are inferred from names and types and
/// should be confirmed.
pub(crate) struct PostprocessedAigArtifact {
    /// Serialized AIG bytes (presumably AIGER, as emitted or returned by the
    /// postprocessor — confirm).
    pub bytes: Vec<u8>,
    /// Summary statistics of the AIG.
    pub stats: AigStats,
    /// Worst-case graph logical-effort delay, when it was computed.
    pub graph_logical_effort_worst_case_delay: Option<f64>,
}
963
964fn compute_g8r_stats_for_pir_fn(
967 f: &IrFn,
968 compute_graph_logical_effort: bool,
969 compute_gate_output_toggles: bool,
970 compute_weighted_switching: bool,
971 toggle_stimulus: Option<&[Vec<IrBits>]>,
972 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
973 extension_costing_mode: ExtensionCostingMode,
974 g8r_evaluation_mode: &G8rEvaluationMode,
975 canonical_g8r_options: &CanonicalG8rOptions,
976 compute_postprocessed_stats: bool,
977) -> Result<GateCostStats> {
978 let result = catch_unwind(AssertUnwindSafe(|| {
985 compute_g8r_stats_for_pir_fn_impl(
986 f,
987 compute_graph_logical_effort,
988 compute_gate_output_toggles,
989 compute_weighted_switching,
990 toggle_stimulus,
991 weighted_switching_options,
992 extension_costing_mode,
993 g8r_evaluation_mode,
994 canonical_g8r_options,
995 compute_postprocessed_stats,
996 )
997 }));
998 match result {
999 Ok(r) => r,
1000 Err(_panic) => Err(anyhow::anyhow!(
1001 "panic during g8r-stats pipeline (likely a cycle)"
1002 )),
1003 }
1004}
1005
/// Converts a logical-effort delay into integer milli-units.
///
/// - Non-finite delays (NaN, ±∞) map to `usize::MAX`, the worst score.
/// - Zero or negative delays map to `0`.
/// - Otherwise the result is `delay * 1000`, rounded, saturating at
///   `usize::MAX`.
fn graph_le_delay_to_milli(delay: f64) -> usize {
    match delay {
        d if !d.is_finite() => usize::MAX,
        d if d <= 0.0 => 0,
        d => {
            let milli = d * 1000.0;
            if milli < usize::MAX as f64 {
                milli.round() as usize
            } else {
                usize::MAX
            }
        }
    }
}
1020
1021fn compute_g8r_stats_for_pir_fn_impl(
1022 f: &IrFn,
1023 compute_graph_logical_effort: bool,
1024 compute_gate_output_toggles: bool,
1025 compute_weighted_switching: bool,
1026 toggle_stimulus: Option<&[Vec<IrBits>]>,
1027 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1028 extension_costing_mode: ExtensionCostingMode,
1029 g8r_evaluation_mode: &G8rEvaluationMode,
1030 canonical_g8r_options: &CanonicalG8rOptions,
1031 compute_postprocessed_stats: bool,
1032) -> Result<GateCostStats> {
1033 let scoring_input = canonical_g8r_scoring_input_for_pir_fn(f, extension_costing_mode)?;
1035 let top_fn = scoring_input.top_fn;
1036 let scoring_g8r_options = canonical_g8r_scoring_lowering_options(canonical_g8r_options);
1037 let artifacts = canonical_ir_text_to_g8r_lowering_artifacts(
1038 &scoring_input.ir_text,
1039 Some(&top_fn.name),
1040 &scoring_g8r_options,
1041 )
1042 .map_err(|e| anyhow::anyhow!("canonical g8r lowering failed: {}", e))?;
1043 let gate_fn = artifacts.gate_fn;
1044 let stats = artifacts.stats;
1045 let g8r_le_graph_milli = if compute_graph_logical_effort {
1046 let graph_le = analyze_graph_logical_effort(
1047 &gate_fn,
1048 &GraphLogicalEffortOptions {
1049 beta1: canonical_g8r_options.graph_logical_effort_beta1,
1050 beta2: canonical_g8r_options.graph_logical_effort_beta2,
1051 },
1052 );
1053 graph_le_delay_to_milli(graph_le.delay)
1054 } else {
1055 0
1056 };
1057 let (g8r_gate_output_toggles, g8r_weighted_switching_milli) = if compute_gate_output_toggles
1058 || compute_weighted_switching
1059 {
1060 let batch = toggle_stimulus.ok_or_else(|| {
1061 anyhow::anyhow!("toggle-based objective requires prepared toggle stimulus")
1062 })?;
1063 if batch.len() < 2 {
1064 return Err(anyhow::anyhow!(
1065 "toggle stimulus must contain at least two samples; got {}",
1066 batch.len()
1067 ));
1068 }
1069 let expected_input_count = gate_fn.inputs.len();
1070 for (sample_idx, sample) in batch.iter().enumerate() {
1071 if sample.len() != expected_input_count {
1072 return Err(anyhow::anyhow!(
1073 "toggle sample {} has {} inputs, expected {}",
1074 sample_idx + 1,
1075 sample.len(),
1076 expected_input_count
1077 ));
1078 }
1079 for (input_idx, (bits, gate_input)) in
1080 sample.iter().zip(gate_fn.inputs.iter()).enumerate()
1081 {
1082 let expected_width = gate_input.get_bit_count();
1083 if bits.get_bit_count() != expected_width {
1084 return Err(anyhow::anyhow!(
1085 "toggle sample {} input {} has width {}, expected {}",
1086 sample_idx + 1,
1087 input_idx,
1088 bits.get_bit_count(),
1089 expected_width
1090 ));
1091 }
1092 }
1093 }
1094 let mut gate_output_toggles = 0usize;
1095 let mut weighted_switching_milli = 0u128;
1096 if compute_weighted_switching {
1097 let weighted_stats = count_toggles::count_weighted_switching(
1098 &gate_fn,
1099 batch,
1100 weighted_switching_options,
1101 );
1102 weighted_switching_milli = weighted_stats.weighted_switching_milli;
1103 if compute_gate_output_toggles {
1104 gate_output_toggles = weighted_stats.gate_output_toggles;
1105 }
1106 }
1107 if compute_gate_output_toggles && !compute_weighted_switching {
1108 gate_output_toggles = count_toggles::count_toggles(&gate_fn, batch).gate_output_toggles;
1109 }
1110 (gate_output_toggles, weighted_switching_milli)
1111 } else {
1112 (0, 0u128)
1113 };
1114 let post = if compute_postprocessed_stats {
1115 let schema = GateFnInterfaceSchema::from_pir_fn(&top_fn)
1116 .map_err(|e| anyhow::anyhow!("failed to derive gate interface schema: {}", e))?;
1117 let post_gate_fn =
1118 postprocess_gate_fn_with_external_program(&gate_fn, &schema, g8r_evaluation_mode)?
1119 .gate_fn;
1120 Some(compute_post_stats_for_gate_fn(
1121 &post_gate_fn,
1122 compute_graph_logical_effort,
1123 compute_gate_output_toggles,
1124 compute_weighted_switching,
1125 toggle_stimulus,
1126 weighted_switching_options,
1127 canonical_g8r_options,
1128 )?)
1129 } else {
1130 None
1131 };
1132
1133 Ok(GateCostStats {
1134 raw: RawG8rStats {
1135 nodes: stats.live_nodes,
1136 depth: stats.deepest_path,
1137 le_graph_milli: g8r_le_graph_milli,
1138 gate_output_toggles: g8r_gate_output_toggles,
1139 weighted_switching_milli: g8r_weighted_switching_milli,
1140 },
1141 post,
1142 })
1143}
1144
/// Result of running the external g8r postprocessor over a gate function.
///
/// Produced by `postprocess_gate_fn_with_external_program`.
struct LoadedPostprocessedGateFn {
    /// Postprocessed gate function, loaded from the program's AIGER output and
    /// repacked to the caller-supplied interface schema.
    gate_fn: xlsynth_g8r::aig::gate::GateFn,
    /// Raw bytes of the AIGER file written by the postprocessor.
    output_bytes: Vec<u8>,
}
1149
1150fn postprocess_gate_fn_with_external_program(
1153 gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1154 schema: &GateFnInterfaceSchema,
1155 g8r_evaluation_mode: &G8rEvaluationMode,
1156) -> Result<LoadedPostprocessedGateFn> {
1157 let program = g8r_evaluation_mode
1158 .external_postprocess_program()
1159 .ok_or_else(|| {
1160 anyhow::anyhow!("g8r postprocessing requested without an external postprocessor")
1161 })?;
1162 let temp_dir = tempfile::Builder::new()
1163 .prefix("pir_mcmc_g8r_postprocess_")
1164 .tempdir()
1165 .map_err(|e| anyhow::anyhow!("failed to create g8r postprocess tempdir: {}", e))?;
1166 let input_path = temp_dir.path().join("input.aig");
1167 let output_path = temp_dir.path().join("output.aig");
1168 let input_bytes = emit_aiger_binary(gate_fn, true)
1169 .map_err(|e| anyhow::anyhow!("emit AIGER failed: {}", e))?;
1170 std::fs::write(&input_path, input_bytes).map_err(|e| {
1171 anyhow::anyhow!(
1172 "failed to write g8r postprocess input {}: {}",
1173 input_path.display(),
1174 e
1175 )
1176 })?;
1177
1178 let output = Command::new(program)
1179 .arg(&input_path)
1180 .arg("--output-path")
1181 .arg(&output_path)
1182 .output()
1183 .map_err(|e| anyhow::anyhow!("failed to run g8r postprocessor '{}': {}", program, e))?;
1184 if !output.status.success() {
1185 return Err(anyhow::anyhow!(
1186 "g8r postprocessor '{}' failed with status {}: {}",
1187 program,
1188 output.status,
1189 String::from_utf8_lossy(&output.stderr).trim()
1190 ));
1191 }
1192 if !output_path.exists() {
1193 return Err(anyhow::anyhow!(
1194 "g8r postprocessor '{}' did not create {}",
1195 program,
1196 output_path.display()
1197 ));
1198 }
1199 let output_bytes = std::fs::read(&output_path).map_err(|e| {
1200 anyhow::anyhow!(
1201 "failed to read g8r postprocess output {}: {}",
1202 output_path.display(),
1203 e
1204 )
1205 })?;
1206 let loaded =
1207 load_aiger_auto_from_path(&output_path, GateBuilderOptions::no_opt()).map_err(|e| {
1208 anyhow::anyhow!(
1209 "failed to load g8r postprocess output {}: {}",
1210 output_path.display(),
1211 e
1212 )
1213 })?;
1214 let gate_fn = repack_gate_fn_interface_with_schema(loaded.gate_fn, schema)
1215 .map_err(|e| anyhow::anyhow!("failed to repack postprocessed AIGER interface: {}", e))?;
1216 Ok(LoadedPostprocessedGateFn {
1217 gate_fn,
1218 output_bytes,
1219 })
1220}
1221
1222pub(crate) fn postprocess_gate_fn_for_artifact(
1225 gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1226 schema: &GateFnInterfaceSchema,
1227 g8r_evaluation_mode: &G8rEvaluationMode,
1228 canonical_g8r_options: &CanonicalG8rOptions,
1229 compute_graph_logical_effort: bool,
1230) -> Result<PostprocessedAigArtifact> {
1231 let loaded = postprocess_gate_fn_with_external_program(gate_fn, schema, g8r_evaluation_mode)?;
1232 let stats = get_summary_stats::get_aig_stats(&loaded.gate_fn);
1233 let graph_logical_effort_worst_case_delay = compute_graph_logical_effort.then(|| {
1234 analyze_graph_logical_effort(
1235 &loaded.gate_fn,
1236 &GraphLogicalEffortOptions {
1237 beta1: canonical_g8r_options.graph_logical_effort_beta1,
1238 beta2: canonical_g8r_options.graph_logical_effort_beta2,
1239 },
1240 )
1241 .delay
1242 });
1243 Ok(PostprocessedAigArtifact {
1244 bytes: loaded.output_bytes,
1245 stats,
1246 graph_logical_effort_worst_case_delay,
1247 })
1248}
1249
1250fn compute_post_stats_for_gate_fn(
1252 gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1253 compute_graph_logical_effort: bool,
1254 compute_gate_output_toggles: bool,
1255 compute_weighted_switching: bool,
1256 toggle_stimulus: Option<&[Vec<IrBits>]>,
1257 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1258 canonical_g8r_options: &CanonicalG8rOptions,
1259) -> Result<G8rPostStats> {
1260 let stats = get_summary_stats::get_aig_stats(gate_fn);
1261 let le_graph_milli = if compute_graph_logical_effort {
1262 let graph_le = analyze_graph_logical_effort(
1263 gate_fn,
1264 &GraphLogicalEffortOptions {
1265 beta1: canonical_g8r_options.graph_logical_effort_beta1,
1266 beta2: canonical_g8r_options.graph_logical_effort_beta2,
1267 },
1268 );
1269 graph_le_delay_to_milli(graph_le.delay)
1270 } else {
1271 0
1272 };
1273 let (gate_output_toggles, weighted_switching_milli) = compute_toggle_stats_for_gate_fn(
1274 gate_fn,
1275 compute_gate_output_toggles,
1276 compute_weighted_switching,
1277 toggle_stimulus,
1278 weighted_switching_options,
1279 )?;
1280 Ok(G8rPostStats {
1281 and_nodes: stats.and_nodes,
1282 depth: stats.max_depth,
1283 le_graph_milli,
1284 gate_output_toggles,
1285 weighted_switching_milli,
1286 })
1287}
1288
1289fn compute_toggle_stats_for_gate_fn(
1290 gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1291 compute_gate_output_toggles: bool,
1292 compute_weighted_switching: bool,
1293 toggle_stimulus: Option<&[Vec<IrBits>]>,
1294 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1295) -> Result<(usize, u128)> {
1296 if !compute_gate_output_toggles && !compute_weighted_switching {
1297 return Ok((0, 0));
1298 }
1299 let batch = validate_toggle_stimulus_for_gate_fn(gate_fn, toggle_stimulus)?;
1300 let mut gate_output_toggles = 0usize;
1301 let mut weighted_switching_milli = 0u128;
1302 if compute_weighted_switching {
1303 let weighted_stats =
1304 count_toggles::count_weighted_switching(gate_fn, batch, weighted_switching_options);
1305 weighted_switching_milli = weighted_stats.weighted_switching_milli;
1306 if compute_gate_output_toggles {
1307 gate_output_toggles = weighted_stats.gate_output_toggles;
1308 }
1309 }
1310 if compute_gate_output_toggles && !compute_weighted_switching {
1311 gate_output_toggles = count_toggles::count_toggles(gate_fn, batch).gate_output_toggles;
1312 }
1313 Ok((gate_output_toggles, weighted_switching_milli))
1314}
1315
1316fn validate_toggle_stimulus_for_gate_fn<'a>(
1317 gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1318 toggle_stimulus: Option<&'a [Vec<IrBits>]>,
1319) -> Result<&'a [Vec<IrBits>]> {
1320 let batch = toggle_stimulus.ok_or_else(|| {
1321 anyhow::anyhow!("toggle-based objective requires prepared toggle stimulus")
1322 })?;
1323 if batch.len() < 2 {
1324 return Err(anyhow::anyhow!(
1325 "toggle stimulus must contain at least two samples; got {}",
1326 batch.len()
1327 ));
1328 }
1329 xlsynth_g8r::aig_sim::gate_simd::validate_ordered_batch_inputs(gate_fn, batch)
1330 .map_err(anyhow::Error::msg)?;
1331 Ok(batch)
1332}
1333
1334pub fn lower_toggle_stimulus_for_fn(samples: &[IrValue], f: &IrFn) -> Result<Vec<Vec<IrBits>>> {
1337 if samples.len() < 2 {
1338 return Err(anyhow::anyhow!(
1339 "toggle stimulus must contain at least two samples; got {}",
1340 samples.len()
1341 ));
1342 }
1343
1344 let mut lowered: Vec<Vec<IrBits>> = Vec::with_capacity(samples.len());
1345 for (sample_idx, tuple_val) in samples.iter().enumerate() {
1346 let elems = tuple_val.get_elements().map_err(|e| {
1347 anyhow::anyhow!("sample {} is not a tuple value: {}", sample_idx + 1, e)
1348 })?;
1349 if elems.len() != f.params.len() {
1350 return Err(anyhow::anyhow!(
1351 "sample {} tuple arity mismatch: expected {}, got {}",
1352 sample_idx + 1,
1353 f.params.len(),
1354 elems.len()
1355 ));
1356 }
1357
1358 let mut sample_bits = Vec::with_capacity(f.params.len());
1359 for (param_idx, (elem, param)) in elems.iter().zip(f.params.iter()).enumerate() {
1360 let mut flat_bits: Vec<bool> = Vec::with_capacity(param.ty.bit_count());
1361 flatten_ir_value_to_lsb0_bits_for_type(elem, ¶m.ty, &mut flat_bits).map_err(
1362 |e| {
1363 anyhow::anyhow!(
1364 "sample {} param {} ('{}') incompatible with {}: {}",
1365 sample_idx + 1,
1366 param_idx,
1367 param.name,
1368 param.ty.to_string(),
1369 e
1370 )
1371 },
1372 )?;
1373 if flat_bits.len() != param.ty.bit_count() {
1374 return Err(anyhow::anyhow!(
1375 "sample {} param {} ('{}') flattened width mismatch: expected {}, got {}",
1376 sample_idx + 1,
1377 param_idx,
1378 param.name,
1379 param.ty.bit_count(),
1380 flat_bits.len()
1381 ));
1382 }
1383 sample_bits.push(IrBits::from_lsb_is_0(&flat_bits));
1384 }
1385 lowered.push(sample_bits);
1386 }
1387 Ok(lowered)
1388}
1389
/// Shared MCMC statistics specialized to PIR transform kinds.
pub type McmcStats = SharedMcmcStats<PirTransformKind>;
/// Per-iteration outcome details specialized to PIR transform kinds.
pub type IterationOutcomeDetails = xlsynth_mcmc::IterationOutcomeDetails<PirTransformKind>;
/// Output of one PIR MCMC iteration (state is an `IrFn`, cost is a `Cost`).
pub type McmcIterationOutput = SharedMcmcIterationOutput<IrFn, Cost, PirTransformKind>;
/// MCMC options, re-exported unchanged from `xlsynth_mcmc`.
pub type McmcOptions = SharedMcmcOptions;
1396
/// Cached oracle inputs together with the baseline function's results on them.
///
/// Holds an all-zeros sample, an all-ones sample, and `random_samples` random
/// samples, plus the result of evaluating the baseline on each. The cache is
/// reused while the parameter types and `random_samples` count match (see
/// `matches_signature`); the baseline's body is not compared.
#[derive(Default)]
pub struct EvalFnBaselineResults {
    /// Argument lists, one per sample.
    samples: Vec<Vec<IrValue>>,
    /// `eval_fn_safe` result of the baseline per sample, index-aligned with `samples`.
    expected_values: Vec<Result<IrValue, ()>>,
    /// Number of random samples requested when the cache was populated.
    random_samples: usize,
    /// Parameter types of the baseline the cache was built from.
    param_types: Vec<PirType>,
}
1410
1411impl EvalFnBaselineResults {
1412 fn clear(&mut self) {
1413 self.samples.clear();
1414 self.expected_values.clear();
1415 self.random_samples = 0;
1416 self.param_types.clear();
1417 }
1418
1419 fn matches_signature(&self, baseline: &IrFn, random_samples: usize) -> bool {
1420 self.random_samples == random_samples
1421 && self.param_types.len() == baseline.params.len()
1422 && self
1423 .param_types
1424 .iter()
1425 .zip(baseline.params.iter())
1426 .all(|(cached_ty, param)| cached_ty == ¶m.ty)
1427 }
1428
1429 fn populate_from_baseline<R: Rng>(
1430 &mut self,
1431 baseline: &IrFn,
1432 rng: &mut R,
1433 random_samples: usize,
1434 ) -> Result<()> {
1435 self.clear();
1436 self.random_samples = random_samples;
1437 self.param_types = baseline.params.iter().map(|p| p.ty.clone()).collect();
1438
1439 self.samples.push(make_oracle_args(
1441 &baseline.params,
1442 "all-zeros",
1443 make_all_zeros_value,
1444 )?);
1445 self.samples.push(make_oracle_args(
1446 &baseline.params,
1447 "all-ones",
1448 make_all_ones_value,
1449 )?);
1450
1451 for _ in 0..random_samples {
1452 self.samples
1453 .push(make_oracle_args(&baseline.params, "random", |ty| {
1454 arbitrary_value_for_type(rng, ty)
1455 })?);
1456 }
1457
1458 self.expected_values = self
1459 .samples
1460 .iter()
1461 .map(|args| eval_fn_safe(baseline, args))
1462 .collect();
1463 Ok(())
1464 }
1465
1466 fn ensure_populated<R: Rng>(
1467 &mut self,
1468 baseline_if_empty: &IrFn,
1469 rng: &mut R,
1470 random_samples: usize,
1471 ) -> Result<()> {
1472 if !self.matches_signature(baseline_if_empty, random_samples) || self.samples.is_empty() {
1473 self.populate_from_baseline(baseline_if_empty, rng, random_samples)?;
1474 }
1475 Ok(())
1476 }
1477}
1478
/// Mutable per-chain state threaded through `mcmc_iteration`.
pub struct PirMcmcContext<'a> {
    /// Random source used for transform selection, candidate choice, Metropolis
    /// acceptance, and oracle sampling.
    pub rng: &'a mut Pcg64Mcg,
    /// Transforms available for proposals.
    pub all_transforms: Vec<Box<dyn PirTransform>>,
    /// Sampling weights used to build a `WeightedIndex`; a sampled index `i`
    /// selects `all_transforms[i]`, so the two must be index-aligned.
    pub weights: Vec<f64>,
    /// When false, only candidates marked `always_equivalent` are considered.
    pub enable_formal_oracle: bool,
    /// Baseline samples/results reused by the equivalence oracle across iterations.
    pub oracle_baseline_cache: EvalFnBaselineResults,
}
1487
/// Configuration for a PIR MCMC run.
///
/// NOTE(review): several field meanings below are inferred from names; confirm
/// against the run driver.
#[derive(Clone, Debug)]
pub struct RunOptions {
    /// Iteration budget for the run.
    pub max_iters: u64,
    /// Degree of parallelism (presumably the number of chains/threads).
    pub threads: u64,
    /// Multi-chain coordination strategy.
    pub chain_strategy: ChainStrategy,
    /// Checkpoint interval, in iterations.
    pub checkpoint_iters: u64,
    /// Progress-reporting interval, in iterations.
    pub progress_iters: u64,
    /// Seed for the run's random number generation.
    pub seed: u64,
    /// Starting temperature for Metropolis acceptance.
    pub initial_temperature: f64,
    /// Cost metric the search minimizes.
    pub objective: Objective,
    /// How extension ops are treated when costing.
    pub extension_costing_mode: ExtensionCostingMode,
    /// How g8r costs are evaluated (including any external postprocessor).
    pub g8r_evaluation_mode: G8rEvaluationMode,
    /// Options for canonical g8r lowering and graph logical-effort analysis.
    pub canonical_g8r_options: CanonicalG8rOptions,
    /// Optional depth limit (presumably a search constraint).
    pub max_allowed_depth: Option<usize>,
    /// Optional area limit (presumably a search constraint).
    pub max_allowed_area: Option<usize>,
    /// Options for the weighted-switching cost.
    pub weighted_switching_options: count_toggles::WeightedSwitchingOptions,
    /// Whether non-always-equivalent rewrites may be checked by the equivalence oracle.
    pub enable_formal_oracle: bool,

    /// Optional directory for trajectory output.
    pub trajectory_dir: Option<PathBuf>,

    /// Optional tuple-valued stimulus samples for toggle-based objectives; see
    /// `lower_toggle_stimulus_for_fn`.
    pub toggle_stimulus: Option<Vec<IrValue>>,
}
1538
/// Reason a checkpoint message was emitted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CheckpointKind {
    /// Regular interval checkpoint (presumably every `checkpoint_iters`).
    Periodic,
    /// Checkpoint presumably triggered by an improvement of the shared global best.
    GlobalBestUpdate,
}

/// Notification that a chain reached a checkpoint.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CheckpointMsg {
    /// Chain that emitted the checkpoint.
    pub chain_no: usize,
    /// Global iteration at which it was emitted.
    pub global_iter: u64,
    /// Why the checkpoint was emitted.
    pub kind: CheckpointKind,
}

/// Outcome of a PIR MCMC run.
pub struct PirMcmcResult {
    /// Best function found.
    pub best_fn: IrFn,
    /// Cost recorded for the best function. NOTE(review): `mcmc_iteration` stores
    /// the accepted candidate's cost here while `best_fn` holds its XLS-optimized
    /// form, so the two may not correspond exactly.
    pub best_cost: Cost,
    /// Aggregated run statistics.
    pub stats: McmcStats,
}
1567
/// One step in the lineage that leads to a search state.
#[derive(Clone, Debug)]
pub enum PirMcmcProvenanceAction {
    /// A PIR rewrite accepted by a chain.
    AcceptedRewrite {
        /// Position of this action within its provenance sequence (presumably).
        action_index: usize,
        /// Chain that accepted the rewrite.
        chain_no: usize,
        /// Global iteration at which the rewrite was accepted.
        global_iter: u64,
        /// Transform that produced the rewrite.
        transform_kind: PirTransformKind,
        /// Function state after the rewrite.
        state: IrFn,
        /// Cost of `state`.
        cost: Cost,
    },
    /// A state handed to a chain after XLS optimization; no PIR transform involved.
    XlsOptimizedHandoff {
        /// Position of this action within its provenance sequence (presumably).
        action_index: usize,
        /// Chain that received the handoff.
        chain_no: usize,
        /// Global iteration at which the handoff happened.
        global_iter: u64,
        /// Function state after the handoff.
        state: IrFn,
        /// Cost of `state`.
        cost: Cost,
    },
}
1601
1602impl PirMcmcProvenanceAction {
1603 fn action_index(&self) -> usize {
1604 match self {
1605 Self::AcceptedRewrite { action_index, .. }
1606 | Self::XlsOptimizedHandoff { action_index, .. } => *action_index,
1607 }
1608 }
1609
1610 fn state(&self) -> &IrFn {
1611 match self {
1612 Self::AcceptedRewrite { state, .. } | Self::XlsOptimizedHandoff { state, .. } => state,
1613 }
1614 }
1615
1616 fn cost(&self) -> Cost {
1617 match self {
1618 Self::AcceptedRewrite { cost, .. } | Self::XlsOptimizedHandoff { cost, .. } => *cost,
1619 }
1620 }
1621
1622 fn transform_kind(&self) -> Option<PirTransformKind> {
1623 match self {
1624 Self::AcceptedRewrite { transform_kind, .. } => Some(*transform_kind),
1625 Self::XlsOptimizedHandoff { .. } => None,
1626 }
1627 }
1628}
1629
/// In-memory record of a run's winning lineage.
///
/// Holds the origin, the raw winner, the provenance actions between them, and the
/// options that produced them.
#[derive(Clone, Debug)]
pub struct PirMcmcArtifact {
    /// Starting function of the run.
    pub origin_fn: IrFn,
    /// Cost of `origin_fn`.
    pub origin_cost: Cost,
    /// Options the run was executed with.
    pub run_options: RunOptions,
    /// Winning function as found by search (NOTE(review): "raw" presumably means
    /// before any final XLS optimization; confirm).
    pub raw_winner_fn: IrFn,
    /// Cost of `raw_winner_fn`.
    pub raw_winner_cost: Cost,
    /// Actions leading from `origin_fn` to `raw_winner_fn`.
    pub winning_provenance: Vec<PirMcmcProvenanceAction>,
}
1646
/// Options for minimizing a winning provenance down to a shorter prefix witness.
#[derive(Clone, Copy, Debug)]
pub struct PirMcmcPrefixMinimizeOptions {
    /// Fraction of the origin-to-winner improvement the witness must retain
    /// (NOTE(review): presumably in `[0, 1]`; confirm against the caller).
    pub retained_win_fraction: f64,
}

/// Result of prefix minimization over a winning provenance.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Clone, Debug)]
pub struct PirMcmcPrefixMinimizeResult {
    /// Function state selected as the minimized witness.
    pub witness_fn: IrFn,
    /// Cost of `witness_fn`.
    pub witness_cost: Cost,
    /// Number of provenance actions needed to reach the witness.
    pub provenance_action_count: usize,
    /// Length of the full winning provenance before minimization.
    pub original_winning_provenance_len: usize,
    /// The `retained_win_fraction` that was requested.
    pub requested_retained_win_fraction: f64,
    /// Fraction of the win the witness actually retains.
    pub actual_retained_win_fraction: f64,
    /// Objective metric of the origin.
    pub origin_metric: u128,
    /// Objective metric of the raw winner.
    pub winner_metric: u128,
    /// Objective metric of the witness.
    pub witness_metric: u128,
}
1676
/// Options for computing a win-vs-action-budget frontier.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Clone, Copy, Debug)]
pub struct PirMcmcBudgetFrontierOptions {
    /// Spacing between successive action budgets.
    pub budget_step: usize,
    /// Largest action budget to evaluate.
    pub max_actions: usize,
    /// Number of guided rollouts attempted per budget.
    pub rollouts_per_budget: usize,
    /// Seed for the rollout random number generator.
    pub seed: u64,
    /// Weight boost for transform kinds present in the winning provenance.
    pub witness_kind_boost: f64,
    /// Proposal attempts allowed per rewrite before giving up.
    pub proposal_attempts_per_rewrite: usize,
}

impl PirMcmcBudgetFrontierOptions {
    /// Default value for `witness_kind_boost`.
    pub const DEFAULT_WITNESS_KIND_BOOST: f64 = 4.0;
    /// Default value for `proposal_attempts_per_rewrite`.
    pub const DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE: usize = 64;
}
1699
/// A witness function reached within some action budget, with its win metrics.
///
/// NOTE(review): field meanings are inferred from names.
#[derive(Clone, Debug)]
pub struct PirMcmcBudgetWitness {
    /// Witness function state.
    pub witness_fn: IrFn,
    /// Cost of `witness_fn`.
    pub witness_cost: Cost,
    /// Number of provenance actions used to reach the witness.
    pub provenance_action_count: usize,
    /// Objective metric of the witness.
    pub metric: u128,
    /// Absolute improvement over the origin metric.
    pub absolute_win: u128,
    /// Improvement over the origin, as a percentage of the origin metric.
    pub win_percent_vs_origin: f64,
    /// Fraction of the full winner's improvement this witness retains.
    pub retained_win_fraction: f64,
}

/// One point of the budget frontier: guided and prefix-baseline witnesses for
/// the same action budget.
#[derive(Clone, Debug)]
pub struct PirMcmcBudgetFrontierPoint {
    /// Action budget for this point.
    pub action_budget: usize,
    /// Witness found by guided rollouts.
    pub guided: PirMcmcBudgetWitness,
    /// Witness obtained from a prefix of the winning provenance.
    pub prefix_baseline: PirMcmcBudgetWitness,
}

/// Full budget-frontier result.
#[derive(Clone, Debug)]
pub struct PirMcmcBudgetFrontierResult {
    /// Objective metric of the origin.
    pub origin_metric: u128,
    /// Objective metric of the raw winner.
    pub winner_metric: u128,
    /// Length of the full winning provenance.
    pub original_winning_provenance_len: usize,
    /// Frontier points, one per evaluated budget.
    pub points: Vec<PirMcmcBudgetFrontierPoint>,
}

/// Run result paired with the winning-lineage artifact produced alongside it.
struct PirMcmcArtifactRunOutput {
    result: PirMcmcResult,
    artifact: PirMcmcArtifact,
}
1733
/// Directory name for the persisted winning-lineage artifact.
const PIR_MCMC_ARTIFACT_DIR_NAME: &str = "winning-lineage";
/// File name of the artifact's JSON manifest.
const PIR_MCMC_ARTIFACT_MANIFEST_FILE: &str = "manifest.json";
/// Subdirectory holding the persisted function states.
const PIR_MCMC_ARTIFACT_STATES_DIR_NAME: &str = "states";
/// Manifest schema version written into `PersistedPirMcmcArtifactManifest`.
const PIR_MCMC_ARTIFACT_SCHEMA_VERSION: u32 = 4;
1738
/// A winning-lineage artifact loaded back from disk.
pub struct LoadedPirMcmcArtifact {
    /// Reconstructed in-memory artifact.
    pub artifact: PirMcmcArtifact,
    /// Package the states were loaded from (presumably used as a template when
    /// re-emitting states; confirm).
    pub package_template: PirPackage,
    /// Name of the top function within the package.
    pub top_fn_name: String,
}
1745
/// Serialized manifest describing a persisted winning-lineage artifact.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedPirMcmcArtifactManifest {
    /// Manifest schema version (see `PIR_MCMC_ARTIFACT_SCHEMA_VERSION`).
    schema_version: u32,
    /// Name of the top function in the persisted states.
    top_fn_name: String,
    /// Options the run was executed with.
    run_options: PersistedRunOptions,
    /// Origin state of the run.
    origin: PersistedArtifactState,
    /// Raw winning state of the run.
    raw_winner: PersistedArtifactState,
    /// Provenance actions from the origin to the raw winner.
    winning_provenance: Vec<PersistedPirMcmcProvenanceAction>,
}

/// Serialized form of `RunOptions`.
///
/// Enum-valued options are stored as strings. The `switching_*` fields presumably
/// flatten `RunOptions::weighted_switching_options` (TODO confirm).
/// NOTE(review): `RunOptions::trajectory_dir` and `toggle_stimulus` have no
/// counterpart here.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedRunOptions {
    max_iters: u64,
    threads: u64,
    chain_strategy: String,
    checkpoint_iters: u64,
    progress_iters: u64,
    seed: u64,
    initial_temperature: f64,
    objective: String,
    extension_costing_mode: String,
    g8r_evaluation_mode: G8rEvaluationMode,
    canonical_g8r_options: CanonicalG8rOptions,
    max_allowed_depth: Option<usize>,
    max_allowed_area: Option<usize>,
    switching_beta1: f64,
    switching_beta2: f64,
    switching_primary_output_load: f64,
    enable_formal_oracle: bool,
}

/// Reference to a persisted function state and its cost.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedArtifactState {
    /// File holding the state (presumably relative to the artifact directory).
    file: String,
    /// Cost of the state.
    cost: Cost,
}

/// Serialized discriminant of `PirMcmcProvenanceAction`; written as
/// `accepted_rewrite` / `xls_optimized_handoff`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum PersistedPirMcmcProvenanceActionKind {
    AcceptedRewrite,
    XlsOptimizedHandoff,
}

/// Serialized form of one `PirMcmcProvenanceAction`.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedPirMcmcProvenanceAction {
    kind: PersistedPirMcmcProvenanceActionKind,
    action_index: usize,
    chain_no: usize,
    global_iter: u64,
    /// Present for accepted rewrites; presumably `None` for handoffs, matching
    /// `PirMcmcProvenanceAction::transform_kind`.
    transform_kind: Option<PirTransformKind>,
    state: PersistedArtifactState,
}
1799
/// Message describing a sample accepted by a chain.
#[derive(Clone, Debug)]
pub struct AcceptedSampleMsg {
    /// Chain that accepted the sample.
    pub chain_no: usize,
    /// Global iteration of acceptance.
    pub global_iter: u64,
    /// 32-byte structural digest of `func` (presumably from
    /// `compute_fn_structural_digest`).
    pub digest: [u8; 32],
    /// Cost of `func`.
    pub cost: Cost,
    /// The accepted function.
    pub func: IrFn,
}
1812
1813fn compute_fn_structural_digest(f: &IrFn) -> Option<[u8; 32]> {
1814 let result = catch_unwind(AssertUnwindSafe(|| {
1815 let ret = f
1816 .ret_node_ref
1817 .expect("PIR functions must have a return node");
1818 let (entries, _depths) = collect_structural_entries(f);
1819 let h = entries[ret.index].hash.as_bytes();
1820 let mut out = [0u8; 32];
1821 out.copy_from_slice(h);
1822 out
1823 }));
1824 match result {
1825 Ok(v) => Some(v),
1826 Err(_panic) => None,
1827 }
1828}
1829
1830fn canonicalize_fn_for_sample(f: &IrFn) -> Result<IrFn> {
1831 let mut f = f.clone();
1832 compact_and_toposort_in_place(&mut f)
1833 .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
1834 Ok(f)
1835}
1836
1837fn iteration_outcome_tag<K>(o: &xlsynth_mcmc::IterationOutcomeDetails<K>) -> &'static str {
1838 match o {
1839 xlsynth_mcmc::IterationOutcomeDetails::CandidateFailure => "CandidateFailure",
1840 xlsynth_mcmc::IterationOutcomeDetails::ApplyFailure => "ApplyFailure",
1841 xlsynth_mcmc::IterationOutcomeDetails::SimFailure => "SimFailure",
1842 xlsynth_mcmc::IterationOutcomeDetails::OracleFailure => "OracleFailure",
1843 xlsynth_mcmc::IterationOutcomeDetails::MetropolisReject => "MetropolisReject",
1844 xlsynth_mcmc::IterationOutcomeDetails::Accepted { .. } => "Accepted",
1845 }
1846}
1847
/// Formats a 32-byte digest as a 64-character lowercase hexadecimal string.
///
/// Uses a nibble lookup table instead of `format!` per byte. This avoids one
/// temporary `String` allocation for each of the 32 bytes.
fn hash_to_hex(bytes: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        s.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
        s.push(char::from(HEX_DIGITS[usize::from(b & 0x0f)]));
    }
    s
}
1855
/// Chain state that carries provenance for both the current search state and the
/// best ("raw winner") state recorded so far.
#[derive(Clone, Debug)]
struct ProvenancedChainState {
    /// Current search state.
    search_fn: IrFn,
    /// Actions leading from the origin to `search_fn`.
    search_provenance: Vec<PirMcmcProvenanceAction>,
    /// Best recorded function (NOTE(review): "raw" presumably means not
    /// XLS-optimized; confirm).
    raw_winner_fn: IrFn,
    /// Cost of `raw_winner_fn`.
    raw_winner_cost: Cost,
    /// Actions leading from the origin to `raw_winner_fn`.
    raw_winner_provenance: Vec<PirMcmcProvenanceAction>,
    /// Pending XLS-optimized handoff as `(receiving_chain_no, global_iter)`; set by
    /// `with_xls_optimized_handoff`.
    pending_handoff: Option<(usize, u64)>,
}
1872
1873impl ProvenancedChainState {
1874 fn origin(origin_fn: IrFn, origin_cost: Cost) -> Self {
1875 Self {
1876 search_fn: origin_fn.clone(),
1877 search_provenance: Vec::new(),
1878 raw_winner_fn: origin_fn,
1879 raw_winner_cost: origin_cost,
1880 raw_winner_provenance: Vec::new(),
1881 pending_handoff: None,
1882 }
1883 }
1884
1885 fn with_xls_optimized_handoff(&self, receiving_chain_no: usize, global_iter: u64) -> Self {
1886 let mut next = self.clone();
1887 next.pending_handoff = Some((receiving_chain_no, global_iter));
1888 next
1889 }
1890}
1891
/// Per-run configuration used to execute MCMC segments for the multichain driver.
///
/// NOTE(review): roles of the optional channels and shared state are inferred
/// from names/types; confirm against the `SegmentRunner` impl.
struct PirSegmentRunner {
    objective: Objective,
    extension_costing_mode: ExtensionCostingMode,
    g8r_evaluation_mode: G8rEvaluationMode,
    canonical_g8r_options: CanonicalG8rOptions,
    weighted_switching_options: count_toggles::WeightedSwitchingOptions,
    initial_temperature: f64,
    constraints: ConstraintLimits,
    enable_formal_oracle: bool,
    progress_iters: u64,
    checkpoint_iters: u64,
    /// Optional sink for checkpoint notifications.
    checkpoint_tx: Option<Sender<CheckpointMsg>>,
    /// Optional sink for accepted-sample notifications.
    accepted_sample_tx: Option<Sender<AcceptedSampleMsg>>,
    /// Optional best-so-far state shared across chains.
    shared_best: Option<Arc<Best>>,
    trajectory_dir: Option<PathBuf>,
    /// Toggle stimulus already lowered to per-input bits, shared across chains.
    prepared_toggle_stimulus: Option<Arc<Vec<Vec<IrBits>>>>,
}
1909
/// Shareable, thread-safe factory that builds a fresh `Vec` of PIR transforms per call.
type PirTransformFactory = Arc<dyn Fn() -> Vec<Box<dyn PirTransform>> + Send + Sync>;

/// Segment runner variant used when producing winning-lineage artifacts.
///
/// Unlike `PirSegmentRunner`, it carries a `transform_factory` and has no
/// accepted-sample channel. NOTE(review): purpose inferred from the name.
struct PirArtifactSegmentRunner {
    objective: Objective,
    extension_costing_mode: ExtensionCostingMode,
    g8r_evaluation_mode: G8rEvaluationMode,
    canonical_g8r_options: CanonicalG8rOptions,
    weighted_switching_options: count_toggles::WeightedSwitchingOptions,
    initial_temperature: f64,
    constraints: ConstraintLimits,
    enable_formal_oracle: bool,
    progress_iters: u64,
    checkpoint_iters: u64,
    /// Optional sink for checkpoint notifications.
    checkpoint_tx: Option<Sender<CheckpointMsg>>,
    /// Optional best-so-far state shared across chains.
    shared_best: Option<Arc<Best>>,
    trajectory_dir: Option<PathBuf>,
    /// Toggle stimulus already lowered to per-input bits, shared across chains.
    prepared_toggle_stimulus: Option<Arc<Vec<Vec<IrBits>>>>,
    /// Builds the transform set used by each segment.
    transform_factory: PirTransformFactory,
}
1929
1930pub fn mcmc_iteration(
1932 current_fn: IrFn, current_cost: Cost,
1935 best_fn: &mut IrFn, best_cost: &mut Cost, best_score: &mut SearchScore,
1938 context: &mut PirMcmcContext,
1939 temp: f64,
1940 objective: Objective,
1941 extension_costing_mode: ExtensionCostingMode,
1942 g8r_evaluation_mode: &G8rEvaluationMode,
1943 canonical_g8r_options: &CanonicalG8rOptions,
1944 toggle_stimulus: Option<&[Vec<IrBits>]>,
1945 weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1946 constraints: ConstraintLimits,
1947) -> McmcIterationOutput {
1948 let mut iteration_best_updated = false;
1949
1950 if context.all_transforms.is_empty() {
1951 return McmcIterationOutput {
1953 output_state: current_fn,
1954 output_cost: current_cost,
1955 best_updated: false,
1956 outcome: IterationOutcomeDetails::CandidateFailure,
1957 oracle_time_micros: 0,
1958 transform_always_equivalent: true,
1959 transform: None,
1960 };
1961 }
1962
1963 let dist = WeightedIndex::new(&context.weights).expect("non-empty weights");
1964 let chosen_transform_idx = dist.sample(context.rng);
1965 let chosen_transform = &mut context.all_transforms[chosen_transform_idx];
1966 let current_transform_kind = chosen_transform.kind();
1967
1968 let mut candidates = chosen_transform.find_candidates(¤t_fn);
1969 if !context.enable_formal_oracle {
1970 candidates.retain(|c| c.always_equivalent);
1971 }
1972
1973 log::trace!(
1974 "Found {} PIR candidates for {:?}",
1975 candidates.len(),
1976 current_transform_kind,
1977 );
1978
1979 if candidates.is_empty() {
1980 return McmcIterationOutput {
1981 output_state: current_fn,
1982 output_cost: current_cost,
1983 best_updated: false,
1984 outcome: IterationOutcomeDetails::CandidateFailure,
1985 oracle_time_micros: 0,
1986 transform_always_equivalent: true,
1987 transform: Some(current_transform_kind),
1988 };
1989 }
1990
1991 let chosen_candidate = candidates.choose(context.rng).unwrap();
1992
1993 log::trace!("Chosen PIR candidate: {:?}", chosen_candidate);
1994
1995 let mut candidate_fn = current_fn.clone();
1996
1997 log::trace!(
1998 "Applying PIR transform {:?} at {:?}",
1999 current_transform_kind,
2000 chosen_candidate.location
2001 );
2002
2003 match chosen_transform.apply(&mut candidate_fn, &chosen_candidate.location) {
2004 Ok(()) => {
2005 if let Err(e) = compact_and_toposort_in_place(&mut candidate_fn) {
2010 log::debug!(
2011 "[pir-mcmc] compact/toposort failed for '{}' after {:?} at {:?}: {}; \
2012 rejecting candidate",
2013 candidate_fn.name,
2014 current_transform_kind,
2015 chosen_candidate.location,
2016 e
2017 );
2018 return McmcIterationOutput {
2019 output_state: current_fn,
2020 output_cost: current_cost,
2021 best_updated: false,
2022 outcome: IterationOutcomeDetails::CandidateFailure,
2023 oracle_time_micros: 0,
2024 transform_always_equivalent: chosen_candidate.always_equivalent,
2025 transform: Some(current_transform_kind),
2026 };
2027 }
2028
2029 log::trace!("PIR transform applied successfully; determining cost...");
2030 let (is_equiv, oracle_time_micros) = if chosen_candidate.always_equivalent {
2031 (true, 0u128)
2033 } else {
2034 let oracle_start = Instant::now();
2035 let ok = pir_equiv_oracle(
2043 ¤t_fn,
2044 &candidate_fn,
2045 context.rng,
2046 DEFAULT_ORACLE_RANDOM_SAMPLES,
2047 context.enable_formal_oracle,
2048 &mut context.oracle_baseline_cache,
2049 );
2050 let micros = oracle_start.elapsed().as_micros();
2051 (ok, micros)
2052 };
2053
2054 if !is_equiv {
2055 McmcIterationOutput {
2056 output_state: current_fn,
2057 output_cost: current_cost,
2058 best_updated: false,
2059 outcome: IterationOutcomeDetails::OracleFailure,
2060 oracle_time_micros,
2061 transform_always_equivalent: chosen_candidate.always_equivalent,
2062 transform: Some(current_transform_kind),
2063 }
2064 } else {
2065 let cost_start = Instant::now();
2066 let new_candidate_cost =
2067 match cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
2068 &candidate_fn,
2069 objective,
2070 toggle_stimulus,
2071 weighted_switching_options,
2072 extension_costing_mode,
2073 g8r_evaluation_mode,
2074 canonical_g8r_options,
2075 ) {
2076 Ok(c) => c,
2077 Err(e) => {
2078 let sim_micros = cost_start.elapsed().as_micros();
2079 let msg = e.to_string();
2080 let is_invalid_bit_slice = msg
2081 .contains("Expected operand 0 of bit_slice")
2082 || msg.contains("invalid bit_slice");
2083
2084 if is_invalid_bit_slice {
2085 let n = INVALID_BIT_SLICE_WARN_COUNT
2092 .fetch_add(1, Ordering::Relaxed)
2093 .saturating_add(1);
2094 if n <= INVALID_BIT_SLICE_WARN_LIMIT {
2095 log::warn!(
2096 "[pir-mcmc] cost evaluation failed for '{}' under {:?}: {}; rejecting candidate (invalid bit_slice; warning {}/{})",
2097 candidate_fn.name,
2098 objective,
2099 e,
2100 n,
2101 INVALID_BIT_SLICE_WARN_LIMIT
2102 );
2103 } else {
2104 log::debug!(
2105 "[pir-mcmc] cost evaluation failed for '{}' under {:?}: {}; rejecting candidate (invalid bit_slice; further occurrences silenced to debug)",
2106 candidate_fn.name,
2107 objective,
2108 e
2109 );
2110 }
2111 } else {
2112 log::warn!(
2113 "[pir-mcmc] cost evaluation failed for '{}' under {:?}: {}; rejecting candidate",
2114 candidate_fn.name,
2115 objective,
2116 e
2117 );
2118 }
2119 return McmcIterationOutput {
2120 output_state: current_fn,
2121 output_cost: current_cost,
2122 best_updated: false,
2123 outcome: IterationOutcomeDetails::SimFailure,
2124 oracle_time_micros: sim_micros,
2125 transform_always_equivalent: chosen_candidate.always_equivalent,
2126 transform: Some(current_transform_kind),
2127 };
2128 }
2129 };
2130
2131 let current_score = search_score(¤t_cost, objective, constraints);
2132 let new_score = search_score(&new_candidate_cost, objective, constraints);
2133 let curr_metric_u128 = objective.metric(¤t_cost);
2134 let new_metric_u128 = objective.metric(&new_candidate_cost);
2135 let accept = match (current_score.violation, new_score.violation) {
2136 (Some(_), None) => true,
2137 (None, Some(_)) => false,
2138 (Some(curr_violation), Some(new_violation)) => metropolis_accept(
2139 repair_energy(curr_violation) as f64,
2140 repair_energy(new_violation) as f64,
2141 temp,
2142 context.rng,
2143 ),
2144 (None, None) => {
2145 if new_metric_u128 == curr_metric_u128
2146 && new_candidate_cost.pir_nodes > current_cost.pir_nodes
2147 {
2148 metropolis_accept(
2151 current_cost.pir_nodes as f64,
2152 new_candidate_cost.pir_nodes as f64,
2153 temp,
2154 context.rng,
2155 )
2156 } else {
2157 metropolis_accept(
2158 curr_metric_u128 as f64,
2159 new_metric_u128 as f64,
2160 temp,
2161 context.rng,
2162 )
2163 }
2164 }
2165 };
2166
2167 if accept {
2168 if new_score < *best_score {
2169 *best_fn = match optimize_pir_fn_via_xls_with_extension_mode(
2174 &candidate_fn,
2175 extension_costing_mode,
2176 ) {
2177 Ok(opt) => opt,
2178 Err(e) => {
2179 log::warn!(
2180 "[pir-mcmc] failed to optimize new best candidate '{}': {}; storing unoptimized function",
2181 candidate_fn.name,
2182 e
2183 );
2184 candidate_fn.clone()
2185 }
2186 };
2187 *best_cost = new_candidate_cost;
2188 *best_score = new_score;
2189 iteration_best_updated = true;
2190 }
2191 McmcIterationOutput {
2192 output_state: candidate_fn,
2193 output_cost: new_candidate_cost,
2194 best_updated: iteration_best_updated,
2195 outcome: IterationOutcomeDetails::Accepted {
2196 kind: current_transform_kind,
2197 },
2198 oracle_time_micros,
2199 transform_always_equivalent: chosen_candidate.always_equivalent,
2200 transform: Some(current_transform_kind),
2201 }
2202 } else {
2203 McmcIterationOutput {
2204 output_state: current_fn,
2205 output_cost: current_cost,
2206 best_updated: false,
2207 outcome: IterationOutcomeDetails::MetropolisReject,
2208 oracle_time_micros,
2209 transform_always_equivalent: chosen_candidate.always_equivalent,
2210 transform: Some(current_transform_kind),
2211 }
2212 }
2213 }
2214 }
2215 Err(e) => {
2216 log::debug!(
2217 "Error applying PIR transform {:?}: {:?}",
2218 current_transform_kind,
2219 e
2220 );
2221 McmcIterationOutput {
2222 output_state: current_fn,
2223 output_cost: current_cost,
2224 best_updated: false,
2225 outcome: IterationOutcomeDetails::ApplyFailure,
2226 oracle_time_micros: 0,
2227 transform_always_equivalent: chosen_candidate.always_equivalent,
2228 transform: Some(current_transform_kind),
2229 }
2230 }
2231 }
2232}
2233
2234fn make_all_zeros_value(ty: &PirType) -> Result<IrValue> {
2235 match ty {
2236 PirType::Token => Ok(IrValue::make_token()),
2237 PirType::Bits(width) => {
2238 if *width == 0 {
2239 Ok(IrValue::from_bits(&IrBits::make_ubits(0, 0).unwrap()))
2240 } else {
2241 Ok(IrValue::from_bits(&IrBits::make_ubits(*width, 0).unwrap()))
2242 }
2243 }
2244 PirType::Tuple(elem_types) => {
2245 let elems: Result<Vec<IrValue>> =
2246 elem_types.iter().map(|t| make_all_zeros_value(t)).collect();
2247 Ok(IrValue::make_tuple(&elems?))
2248 }
2249 PirType::Array(arr) => {
2250 if arr.element_count == 0 {
2251 return Err(anyhow::anyhow!(
2252 "cannot construct all-zeros oracle sample for zero-length array type {}",
2253 ty
2254 ));
2255 }
2256 let mut elems: Vec<IrValue> = Vec::with_capacity(arr.element_count);
2257 for _ in 0..arr.element_count {
2258 elems.push(make_all_zeros_value(&arr.element_type)?);
2259 }
2260 IrValue::make_array(&elems).map_err(|e| {
2261 anyhow::anyhow!("failed to construct all-zeros array oracle sample: {}", e)
2262 })
2263 }
2264 }
2265}
2266
2267fn make_all_ones_value(ty: &PirType) -> Result<IrValue> {
2268 match ty {
2269 PirType::Token => Ok(IrValue::make_token()),
2270 PirType::Bits(width) => {
2271 if *width == 0 {
2272 Ok(IrValue::from_bits(&IrBits::make_ubits(0, 0).unwrap()))
2273 } else if *width <= 64 {
2274 let mask = if *width == 64 {
2275 u64::MAX
2276 } else {
2277 (1u64 << *width) - 1
2278 };
2279 Ok(IrValue::from_bits(
2280 &IrBits::make_ubits(*width, mask).unwrap(),
2281 ))
2282 } else {
2283 let ones: Vec<bool> = vec![true; *width];
2285 Ok(IrValue::from_bits(&IrBits::from_lsb_is_0(&ones)))
2286 }
2287 }
2288 PirType::Tuple(elem_types) => {
2289 let elems: Result<Vec<IrValue>> =
2290 elem_types.iter().map(|t| make_all_ones_value(t)).collect();
2291 Ok(IrValue::make_tuple(&elems?))
2292 }
2293 PirType::Array(arr) => {
2294 if arr.element_count == 0 {
2295 return Err(anyhow::anyhow!(
2296 "cannot construct all-ones oracle sample for zero-length array type {}",
2297 ty
2298 ));
2299 }
2300 let mut elems: Vec<IrValue> = Vec::with_capacity(arr.element_count);
2301 for _ in 0..arr.element_count {
2302 elems.push(make_all_ones_value(&arr.element_type)?);
2303 }
2304 IrValue::make_array(&elems).map_err(|e| {
2305 anyhow::anyhow!("failed to construct all-ones array oracle sample: {}", e)
2306 })
2307 }
2308 }
2309}
2310
2311fn arbitrary_value_for_type<R: Rng>(rng: &mut R, ty: &PirType) -> Result<IrValue> {
2312 match ty {
2313 PirType::Token => Ok(IrValue::make_token()),
2314 PirType::Bits(width) => {
2315 let bits = generate_biased_irbits_with_rng(rng, *width);
2316 Ok(IrValue::from_bits(&bits))
2317 }
2318 PirType::Tuple(elem_types) => {
2319 let elems: Result<Vec<IrValue>> = elem_types
2320 .iter()
2321 .map(|t| arbitrary_value_for_type(rng, t))
2322 .collect();
2323 Ok(IrValue::make_tuple(&elems?))
2324 }
2325 PirType::Array(arr) => {
2326 if arr.element_count == 0 {
2327 return Err(anyhow::anyhow!(
2328 "cannot construct random oracle sample for zero-length array type {}",
2329 ty
2330 ));
2331 }
2332 let mut elems: Vec<IrValue> = Vec::with_capacity(arr.element_count);
2333 for _ in 0..arr.element_count {
2334 elems.push(arbitrary_value_for_type(rng, &arr.element_type)?);
2335 }
2336 IrValue::make_array(&elems).map_err(|e| {
2337 anyhow::anyhow!("failed to construct random array oracle sample: {}", e)
2338 })
2339 }
2340 }
2341}
2342
2343fn eval_fn_safe(f: &IrFn, args: &[IrValue]) -> Result<IrValue, ()> {
2344 let result = catch_unwind(AssertUnwindSafe(|| {
2353 eval_fn_assuming_node_index_topological(f, args)
2354 }));
2355 match result {
2356 Ok(FnEvalResult::Success(s)) => Ok(s.value),
2357 Ok(FnEvalResult::Failure(_f)) => Err(()),
2358 Err(_panic) => Err(()),
2359 }
2360}
2361
2362fn make_oracle_args<F>(params: &[PirParam], label: &str, mut make_value: F) -> Result<Vec<IrValue>>
2363where
2364 F: FnMut(&PirType) -> Result<IrValue>,
2365{
2366 params
2367 .iter()
2368 .map(|p| make_value(&p.ty))
2369 .collect::<Result<Vec<_>>>()
2370 .map_err(|e| anyhow::anyhow!("failed to construct {} oracle sample args: {}", label, e))
2371}
2372
2373fn pir_equiv_oracle<R: Rng>(
2374 lhs: &IrFn,
2375 rhs: &IrFn,
2376 rng: &mut R,
2377 random_samples: usize,
2378 enable_formal_oracle: bool,
2379 baseline_cache: &mut EvalFnBaselineResults,
2380) -> bool {
2381 if lhs.params.len() != rhs.params.len() || lhs.ret_ty != rhs.ret_ty {
2382 return false;
2383 }
2384 for (lp, rp) in lhs.params.iter().zip(rhs.params.iter()) {
2385 if lp.ty != rp.ty {
2386 return false;
2387 }
2388 }
2389
2390 if let Err(e) = baseline_cache.ensure_populated(lhs, rng, random_samples) {
2395 log::debug!(
2396 "[pir-mcmc] failed to populate oracle baseline cache: {}; rejecting candidate",
2397 e
2398 );
2399 return false;
2400 }
2401 for (args, expected_value) in baseline_cache
2402 .samples
2403 .iter()
2404 .zip(baseline_cache.expected_values.iter())
2405 {
2406 let Ok(expected_value) = expected_value else {
2407 return false;
2408 };
2409 let Ok(rhs_value) = eval_fn_safe(rhs, args) else {
2410 return false;
2411 };
2412 if expected_value != &rhs_value {
2413 return false;
2414 }
2415 }
2416
2417 if enable_formal_oracle {
2418 {
2419 match prove_ir_fn_equiv(lhs, rhs) {
2420 EquivResult::Proved => true,
2421 EquivResult::Disproved { .. } | EquivResult::ToolchainDisproved(_) => false,
2422 EquivResult::Inconclusive(msg) => {
2423 log::warn!(
2424 "[pir-mcmc] formal oracle inconclusive for '{}' vs '{}': {}; rejecting candidate",
2425 lhs.name,
2426 rhs.name,
2427 msg
2428 );
2429 false
2430 }
2431 EquivResult::Error(msg) => {
2432 log::warn!(
2433 "[pir-mcmc] formal oracle error for '{}' vs '{}': {}; rejecting candidate",
2434 lhs.name,
2435 rhs.name,
2436 msg
2437 );
2438 false
2439 }
2440 }
2441 }
2442 } else {
2443 true
2444 }
2445}
2446
2447fn get_pir_transforms_for_run(enable_formal_oracle: bool) -> Vec<Box<dyn PirTransform>> {
2448 let mut all_transforms = get_all_pir_transforms();
2449 if !enable_formal_oracle {
2450 all_transforms.retain(|t| t.can_emit_always_equivalent_candidates());
2451 }
2452 all_transforms
2453}
2454
/// Validated inputs for a PIR MCMC run, as produced by `prepare_run_start`.
struct PreparedRun {
    /// Starting function after `compact_and_toposort_in_place` has been applied.
    start_fn: IrFn,
    /// Toggle stimulus lowered for `start_fn`. It is `Some` only when the
    /// objective needs toggle stimulus.
    prepared_toggle_stimulus: Option<Arc<Vec<Vec<IrBits>>>>,
    /// Cost of `start_fn` under the run's objective and costing options.
    initial_cost: Cost,
    /// Constraint limits derived from the requested limits and `initial_cost`
    /// via `effective_constraint_limits`.
    effective_constraints: ConstraintLimits,
}
2461
2462fn prepare_run_start(mut start_fn: IrFn, options: &RunOptions) -> Result<PreparedRun> {
2465 if !options.objective.needs_toggle_stimulus() && options.toggle_stimulus.is_some() {
2466 return Err(anyhow::anyhow!(
2467 "toggle stimulus is not valid with objective {}",
2468 options.objective.value_name()
2469 ));
2470 }
2471 if options.objective.uses_postprocessed_costing()
2472 && options
2473 .g8r_evaluation_mode
2474 .external_postprocess_program()
2475 .is_none()
2476 {
2477 return Err(anyhow::anyhow!(
2478 "objective {} requires --g8r-postprocess-program",
2479 options.objective.value_name()
2480 ));
2481 }
2482 validate_constraint_configuration(
2483 options.objective,
2484 ConstraintLimits {
2485 max_delay: options.max_allowed_depth,
2486 max_area: options.max_allowed_area,
2487 },
2488 )?;
2489 compact_and_toposort_in_place(&mut start_fn)
2490 .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
2491
2492 let prepared_toggle_stimulus = if options.objective.needs_toggle_stimulus() {
2493 let samples = options.toggle_stimulus.as_ref().ok_or_else(|| {
2494 anyhow::anyhow!(
2495 "objective {} requires toggle stimulus",
2496 options.objective.value_name()
2497 )
2498 })?;
2499 Some(Arc::new(lower_toggle_stimulus_for_fn(samples, &start_fn)?))
2500 } else {
2501 None
2502 };
2503
2504 let initial_cost =
2505 cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
2506 &start_fn,
2507 options.objective,
2508 prepared_toggle_stimulus.as_ref().map(|v| v.as_slice()),
2509 &options.weighted_switching_options,
2510 options.extension_costing_mode,
2511 &options.g8r_evaluation_mode,
2512 &options.canonical_g8r_options,
2513 )?;
2514 let effective_constraints = effective_constraint_limits(
2515 options.objective,
2516 ConstraintLimits {
2517 max_delay: options.max_allowed_depth,
2518 max_area: options.max_allowed_area,
2519 },
2520 &initial_cost,
2521 );
2522
2523 Ok(PreparedRun {
2524 start_fn,
2525 prepared_toggle_stimulus,
2526 initial_cost,
2527 effective_constraints,
2528 })
2529}
2530
2531pub fn validate_pir_mcmc_artifact_run_options(options: &RunOptions) -> Result<()> {
2533 if options.max_allowed_depth.is_some() || options.max_allowed_area.is_some() {
2534 return Err(anyhow::anyhow!(
2535 "run_pir_mcmc_with_artifact currently supports only unconstrained runs"
2536 ));
2537 }
2538 if options.objective.enforces_non_regressing_depth() {
2539 return Err(anyhow::anyhow!(
2540 "run_pir_mcmc_with_artifact does not yet support objectives with implicit feasibility caps; got {}",
2541 options.objective.value_name()
2542 ));
2543 }
2544 Ok(())
2545}
2546
2547fn validate_prefix_minimization_artifact(artifact: &PirMcmcArtifact) -> Result<()> {
2548 validate_pir_mcmc_artifact_run_options(&artifact.run_options)?;
2549 for (expected_index, action) in artifact.winning_provenance.iter().enumerate() {
2550 if action.action_index() != expected_index + 1 {
2551 return Err(anyhow::anyhow!(
2552 "winning provenance action indices must be contiguous from 1; expected {}, got {}",
2553 expected_index + 1,
2554 action.action_index()
2555 ));
2556 }
2557 }
2558 match artifact.winning_provenance.last() {
2559 Some(last_action) => {
2560 if last_action.cost() != artifact.raw_winner_cost
2561 || last_action.state().to_string() != artifact.raw_winner_fn.to_string()
2562 {
2563 return Err(anyhow::anyhow!(
2564 "winning provenance endpoint does not match the recorded raw winner"
2565 ));
2566 }
2567 }
2568 None => {
2569 if artifact.raw_winner_cost != artifact.origin_cost
2570 || artifact.raw_winner_fn.to_string() != artifact.origin_fn.to_string()
2571 {
2572 return Err(anyhow::anyhow!(
2573 "empty winning provenance is valid only when the raw winner is the origin"
2574 ));
2575 }
2576 }
2577 }
2578 Ok(())
2579}
2580
/// Fraction of the origin-to-winner improvement that `metric` retains.
///
/// Returns 0.0 when the winner is no better than the origin. A `metric` worse
/// than the origin retains 0.0, and a `metric` better than the winner yields a
/// value above 1.0.
fn retained_win_fraction_for_metric(origin_metric: u128, winner_metric: u128, metric: u128) -> f64 {
    match origin_metric.checked_sub(winner_metric) {
        Some(total_win) if total_win > 0 => {
            let retained = origin_metric.saturating_sub(metric);
            retained as f64 / total_win as f64
        }
        _ => 0.0,
    }
}
2589
/// Improvement of `metric` over `origin_metric`, as a percentage of the origin.
///
/// Returns 0.0 for a zero origin or when `metric` is not an improvement.
fn win_percent_vs_origin_for_metric(origin_metric: u128, metric: u128) -> f64 {
    if origin_metric > 0 {
        let win = origin_metric.saturating_sub(metric);
        100.0 * win as f64 / origin_metric as f64
    } else {
        0.0
    }
}
2596
2597fn objective_supports_budget_frontier_search(objective: Objective) -> bool {
2598 !objective.needs_toggle_stimulus()
2599 && !matches!(
2600 objective,
2601 Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
2602 | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
2603 )
2604}
2605
2606fn make_budget_witness(
2607 witness_fn: IrFn,
2608 witness_cost: Cost,
2609 provenance_action_count: usize,
2610 objective: Objective,
2611 origin_metric: u128,
2612 winner_metric: u128,
2613) -> PirMcmcBudgetWitness {
2614 let metric = objective.metric(&witness_cost);
2615 PirMcmcBudgetWitness {
2616 witness_fn,
2617 witness_cost,
2618 provenance_action_count,
2619 metric,
2620 absolute_win: origin_metric.saturating_sub(metric),
2621 win_percent_vs_origin: win_percent_vs_origin_for_metric(origin_metric, metric),
2622 retained_win_fraction: retained_win_fraction_for_metric(
2623 origin_metric,
2624 winner_metric,
2625 metric,
2626 ),
2627 }
2628}
2629
2630fn better_budget_witness(
2631 candidate: &PirMcmcBudgetWitness,
2632 incumbent: &PirMcmcBudgetWitness,
2633) -> bool {
2634 candidate.metric < incumbent.metric
2635 || (candidate.metric == incumbent.metric
2636 && candidate.provenance_action_count < incumbent.provenance_action_count)
2637}
2638
2639fn frontier_budgets(options: PirMcmcBudgetFrontierOptions) -> Result<Vec<usize>> {
2640 if options.budget_step == 0 {
2641 return Err(anyhow::anyhow!("budget_step must be > 0"));
2642 }
2643 if options.max_actions == 0 {
2644 return Err(anyhow::anyhow!("max_actions must be > 0"));
2645 }
2646 if options.budget_step > options.max_actions {
2647 return Err(anyhow::anyhow!(
2648 "budget_step must be <= max_actions; got step={} max={}",
2649 options.budget_step,
2650 options.max_actions
2651 ));
2652 }
2653 if options.rollouts_per_budget == 0 {
2654 return Err(anyhow::anyhow!("rollouts_per_budget must be > 0"));
2655 }
2656 if options.proposal_attempts_per_rewrite == 0 {
2657 return Err(anyhow::anyhow!("proposal_attempts_per_rewrite must be > 0"));
2658 }
2659 if !options.witness_kind_boost.is_finite() || options.witness_kind_boost < 0.0 {
2660 return Err(anyhow::anyhow!(
2661 "witness_kind_boost must be finite and >= 0; got {}",
2662 options.witness_kind_boost
2663 ));
2664 }
2665
2666 let mut budgets = Vec::new();
2667 let mut budget = options.budget_step;
2668 while budget <= options.max_actions {
2669 budgets.push(budget);
2670 match budget.checked_add(options.budget_step) {
2671 Some(next) => budget = next,
2672 None => break,
2673 }
2674 }
2675 if budgets.last().copied() != Some(options.max_actions) {
2676 budgets.push(options.max_actions);
2677 }
2678 Ok(budgets)
2679}
2680
2681fn build_witness_guided_transform_weights(
2682 transforms: &[Box<dyn PirTransform>],
2683 artifact: &PirMcmcArtifact,
2684 witness_kind_boost: f64,
2685) -> Vec<f64> {
2686 let mut counts = BTreeMap::<PirTransformKind, usize>::new();
2687 for action in artifact.winning_provenance.iter() {
2688 if let Some(kind) = action.transform_kind() {
2689 *counts.entry(kind).or_insert(0) += 1;
2690 }
2691 }
2692 transforms
2693 .iter()
2694 .map(|transform| {
2695 1.0 + witness_kind_boost * counts.get(&transform.kind()).copied().unwrap_or(0) as f64
2696 })
2697 .collect()
2698}
2699
2700fn prefix_baseline_for_budget(
2701 artifact: &PirMcmcArtifact,
2702 action_budget: usize,
2703 origin_metric: u128,
2704 winner_metric: u128,
2705) -> PirMcmcBudgetWitness {
2706 let objective = artifact.run_options.objective;
2707 let mut best = make_budget_witness(
2708 artifact.origin_fn.clone(),
2709 artifact.origin_cost,
2710 0,
2711 objective,
2712 origin_metric,
2713 winner_metric,
2714 );
2715 for action in artifact
2716 .winning_provenance
2717 .iter()
2718 .take_while(|action| action.action_index() <= action_budget)
2719 {
2720 let candidate = make_budget_witness(
2721 action.state().clone(),
2722 action.cost(),
2723 action.action_index(),
2724 objective,
2725 origin_metric,
2726 winner_metric,
2727 );
2728 if better_budget_witness(&candidate, &best) {
2729 best = candidate;
2730 }
2731 }
2732 best
2733}
2734
2735fn chain_strategy_value_name(strategy: ChainStrategy) -> &'static str {
2736 match strategy {
2737 ChainStrategy::Independent => "independent",
2738 ChainStrategy::ExploreExploit => "explore-exploit",
2739 }
2740}
2741
2742fn chain_strategy_from_value_name(value: &str) -> Result<ChainStrategy> {
2743 match value {
2744 "independent" => Ok(ChainStrategy::Independent),
2745 "explore-exploit" => Ok(ChainStrategy::ExploreExploit),
2746 _ => Err(anyhow::anyhow!(
2747 "unknown chain strategy in artifact: {}",
2748 value
2749 )),
2750 }
2751}
2752
2753fn emit_pkg_text_toposorted(pkg: &PirPackage) -> Result<String> {
2754 let mut pkg = pkg.clone();
2755 for member in pkg.members.iter_mut() {
2756 match member {
2757 PirPackageMember::Function(f) => {
2758 compact_and_toposort_in_place(f)
2759 .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
2760 }
2761 PirPackageMember::Block { func, .. } => {
2762 compact_and_toposort_in_place(func)
2763 .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
2764 }
2765 }
2766 }
2767 Ok(pkg.to_string())
2768}
2769
2770fn package_with_replaced_fn(
2771 package_template: &PirPackage,
2772 top_fn_name: &str,
2773 replacement: &IrFn,
2774) -> Result<PirPackage> {
2775 let mut pkg = package_template.clone();
2776 let top_mut = pkg.get_fn_mut(top_fn_name).ok_or_else(|| {
2777 anyhow::anyhow!(
2778 "top function '{}' not found in artifact package template",
2779 top_fn_name
2780 )
2781 })?;
2782 *top_mut = replacement.clone();
2783 Ok(pkg)
2784}
2785
2786fn write_artifact_state_package(
2787 artifact_dir: &Path,
2788 package_template: &PirPackage,
2789 top_fn_name: &str,
2790 relative_file: &str,
2791 state_fn: &IrFn,
2792) -> Result<()> {
2793 let pkg = package_with_replaced_fn(package_template, top_fn_name, state_fn)?;
2794 let text = emit_pkg_text_toposorted(&pkg)?;
2795 let path = artifact_dir.join(relative_file);
2796 if let Some(parent) = path.parent() {
2797 std::fs::create_dir_all(parent).map_err(|e| {
2798 anyhow::anyhow!(
2799 "failed to create artifact state directory {}: {}",
2800 parent.display(),
2801 e
2802 )
2803 })?;
2804 }
2805 std::fs::write(&path, text.as_bytes())
2806 .map_err(|e| anyhow::anyhow!("failed to write {}: {}", path.display(), e))
2807}
2808
2809fn load_artifact_state_fn(
2810 artifact_dir: &Path,
2811 relative_file: &str,
2812 top_fn_name: &str,
2813) -> Result<(PirPackage, IrFn)> {
2814 let path = artifact_dir.join(relative_file);
2815 let pkg = ir_parser::parse_and_validate_path_to_package(&path).map_err(|e| {
2816 anyhow::anyhow!(
2817 "failed to parse artifact state package {}: {:?}",
2818 path.display(),
2819 e
2820 )
2821 })?;
2822 let state_fn = pkg.get_fn(top_fn_name).cloned().ok_or_else(|| {
2823 anyhow::anyhow!(
2824 "artifact state package {} does not contain top function '{}'",
2825 path.display(),
2826 top_fn_name
2827 )
2828 })?;
2829 Ok((pkg, state_fn))
2830}
2831
2832impl PersistedRunOptions {
2833 fn from_run_options(options: &RunOptions) -> Result<Self> {
2834 Ok(Self {
2835 max_iters: options.max_iters,
2836 threads: options.threads,
2837 chain_strategy: chain_strategy_value_name(options.chain_strategy).to_string(),
2838 checkpoint_iters: options.checkpoint_iters,
2839 progress_iters: options.progress_iters,
2840 seed: options.seed,
2841 initial_temperature: options.initial_temperature,
2842 objective: options.objective.value_name().to_string(),
2843 extension_costing_mode: options.extension_costing_mode.value_name().to_string(),
2844 g8r_evaluation_mode: options
2845 .g8r_evaluation_mode
2846 .canonicalized_for_persistence()?,
2847 canonical_g8r_options: options.canonical_g8r_options.clone(),
2848 max_allowed_depth: options.max_allowed_depth,
2849 max_allowed_area: options.max_allowed_area,
2850 switching_beta1: options.weighted_switching_options.beta1,
2851 switching_beta2: options.weighted_switching_options.beta2,
2852 switching_primary_output_load: options.weighted_switching_options.primary_output_load,
2853 enable_formal_oracle: options.enable_formal_oracle,
2854 })
2855 }
2856
2857 fn into_run_options(self) -> Result<RunOptions> {
2858 Ok(RunOptions {
2859 max_iters: self.max_iters,
2860 threads: self.threads,
2861 chain_strategy: chain_strategy_from_value_name(&self.chain_strategy)?,
2862 checkpoint_iters: self.checkpoint_iters,
2863 progress_iters: self.progress_iters,
2864 seed: self.seed,
2865 initial_temperature: self.initial_temperature,
2866 objective: Objective::from_value_name(&self.objective)?,
2867 extension_costing_mode: ExtensionCostingMode::from_value_name(
2868 &self.extension_costing_mode,
2869 )?,
2870 g8r_evaluation_mode: self.g8r_evaluation_mode,
2871 canonical_g8r_options: self.canonical_g8r_options,
2872 max_allowed_depth: self.max_allowed_depth,
2873 max_allowed_area: self.max_allowed_area,
2874 weighted_switching_options: count_toggles::WeightedSwitchingOptions {
2875 beta1: self.switching_beta1,
2876 beta2: self.switching_beta2,
2877 primary_output_load: self.switching_primary_output_load,
2878 },
2879 enable_formal_oracle: self.enable_formal_oracle,
2880 trajectory_dir: None,
2881 toggle_stimulus: None,
2882 })
2883 }
2884}
2885
2886pub fn write_pir_mcmc_artifact_dir(
2889 artifact: &PirMcmcArtifact,
2890 package_template: &PirPackage,
2891 run_dir: &Path,
2892) -> Result<PathBuf> {
2893 validate_prefix_minimization_artifact(artifact)?;
2894 let artifact_dir = run_dir.join(PIR_MCMC_ARTIFACT_DIR_NAME);
2895 std::fs::create_dir_all(&artifact_dir).map_err(|e| {
2896 anyhow::anyhow!(
2897 "failed to create artifact directory {}: {}",
2898 artifact_dir.display(),
2899 e
2900 )
2901 })?;
2902
2903 let top_fn_name = artifact.origin_fn.name.clone();
2904 let origin_file = format!("{}/origin.ir", PIR_MCMC_ARTIFACT_STATES_DIR_NAME);
2905 let raw_winner_file = format!("{}/raw-winner.ir", PIR_MCMC_ARTIFACT_STATES_DIR_NAME);
2906 write_artifact_state_package(
2907 &artifact_dir,
2908 package_template,
2909 &top_fn_name,
2910 &origin_file,
2911 &artifact.origin_fn,
2912 )?;
2913 write_artifact_state_package(
2914 &artifact_dir,
2915 package_template,
2916 &top_fn_name,
2917 &raw_winner_file,
2918 &artifact.raw_winner_fn,
2919 )?;
2920
2921 let mut winning_provenance = Vec::with_capacity(artifact.winning_provenance.len());
2922 for action in artifact.winning_provenance.iter() {
2923 let state_file = format!(
2924 "{}/action-{:06}.ir",
2925 PIR_MCMC_ARTIFACT_STATES_DIR_NAME,
2926 action.action_index()
2927 );
2928 write_artifact_state_package(
2929 &artifact_dir,
2930 package_template,
2931 &top_fn_name,
2932 &state_file,
2933 action.state(),
2934 )?;
2935 let state = PersistedArtifactState {
2936 file: state_file,
2937 cost: action.cost(),
2938 };
2939 winning_provenance.push(match action {
2940 PirMcmcProvenanceAction::AcceptedRewrite {
2941 action_index,
2942 chain_no,
2943 global_iter,
2944 transform_kind,
2945 ..
2946 } => PersistedPirMcmcProvenanceAction {
2947 kind: PersistedPirMcmcProvenanceActionKind::AcceptedRewrite,
2948 action_index: *action_index,
2949 chain_no: *chain_no,
2950 global_iter: *global_iter,
2951 transform_kind: Some(*transform_kind),
2952 state,
2953 },
2954 PirMcmcProvenanceAction::XlsOptimizedHandoff {
2955 action_index,
2956 chain_no,
2957 global_iter,
2958 ..
2959 } => PersistedPirMcmcProvenanceAction {
2960 kind: PersistedPirMcmcProvenanceActionKind::XlsOptimizedHandoff,
2961 action_index: *action_index,
2962 chain_no: *chain_no,
2963 global_iter: *global_iter,
2964 transform_kind: None,
2965 state,
2966 },
2967 });
2968 }
2969
2970 let manifest = PersistedPirMcmcArtifactManifest {
2971 schema_version: PIR_MCMC_ARTIFACT_SCHEMA_VERSION,
2972 top_fn_name,
2973 run_options: PersistedRunOptions::from_run_options(&artifact.run_options)?,
2974 origin: PersistedArtifactState {
2975 file: origin_file,
2976 cost: artifact.origin_cost,
2977 },
2978 raw_winner: PersistedArtifactState {
2979 file: raw_winner_file,
2980 cost: artifact.raw_winner_cost,
2981 },
2982 winning_provenance,
2983 };
2984 let manifest_json = serde_json::to_string_pretty(&manifest)
2985 .map_err(|e| anyhow::anyhow!("failed to serialize artifact manifest: {}", e))?;
2986 let manifest_path = artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE);
2987 std::fs::write(&manifest_path, manifest_json.as_bytes())
2988 .map_err(|e| anyhow::anyhow!("failed to write {}: {}", manifest_path.display(), e))?;
2989 Ok(artifact_dir)
2990}
2991
2992pub fn read_pir_mcmc_artifact_dir(run_dir: &Path) -> Result<LoadedPirMcmcArtifact> {
2994 let artifact_dir = run_dir.join(PIR_MCMC_ARTIFACT_DIR_NAME);
2995 let manifest_path = artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE);
2996 let manifest_text = std::fs::read_to_string(&manifest_path)
2997 .map_err(|e| anyhow::anyhow!("failed to read {}: {}", manifest_path.display(), e))?;
2998 let manifest_value: serde_json::Value = serde_json::from_str(&manifest_text)
2999 .map_err(|e| anyhow::anyhow!("failed to parse {}: {}", manifest_path.display(), e))?;
3000 let schema_version = manifest_value
3001 .get("schema_version")
3002 .and_then(serde_json::Value::as_u64)
3003 .ok_or_else(|| {
3004 anyhow::anyhow!(
3005 "artifact manifest {} is missing integer schema_version",
3006 manifest_path.display()
3007 )
3008 })? as u32;
3009 if schema_version != PIR_MCMC_ARTIFACT_SCHEMA_VERSION {
3010 return Err(anyhow::anyhow!(
3011 "unsupported PIR MCMC artifact schema version {}; expected {}",
3012 schema_version,
3013 PIR_MCMC_ARTIFACT_SCHEMA_VERSION
3014 ));
3015 }
3016 let manifest: PersistedPirMcmcArtifactManifest = serde_json::from_str(&manifest_text)
3017 .map_err(|e| anyhow::anyhow!("failed to parse {}: {}", manifest_path.display(), e))?;
3018
3019 let top_fn_name = manifest.top_fn_name.clone();
3020 let (package_template, origin_fn) =
3021 load_artifact_state_fn(&artifact_dir, &manifest.origin.file, &top_fn_name)?;
3022 let (_, raw_winner_fn) =
3023 load_artifact_state_fn(&artifact_dir, &manifest.raw_winner.file, &top_fn_name)?;
3024
3025 let mut winning_provenance = Vec::with_capacity(manifest.winning_provenance.len());
3026 for action in manifest.winning_provenance.into_iter() {
3027 let (_, state) = load_artifact_state_fn(&artifact_dir, &action.state.file, &top_fn_name)?;
3028 winning_provenance.push(match action.kind {
3029 PersistedPirMcmcProvenanceActionKind::AcceptedRewrite => {
3030 let transform_kind = action.transform_kind.ok_or_else(|| {
3031 anyhow::anyhow!(
3032 "accepted_rewrite action {} is missing transform_kind",
3033 action.action_index
3034 )
3035 })?;
3036 PirMcmcProvenanceAction::AcceptedRewrite {
3037 action_index: action.action_index,
3038 chain_no: action.chain_no,
3039 global_iter: action.global_iter,
3040 transform_kind,
3041 state,
3042 cost: action.state.cost,
3043 }
3044 }
3045 PersistedPirMcmcProvenanceActionKind::XlsOptimizedHandoff => {
3046 if action.transform_kind.is_some() {
3047 return Err(anyhow::anyhow!(
3048 "xls_optimized_handoff action {} must not include transform_kind",
3049 action.action_index
3050 ));
3051 }
3052 PirMcmcProvenanceAction::XlsOptimizedHandoff {
3053 action_index: action.action_index,
3054 chain_no: action.chain_no,
3055 global_iter: action.global_iter,
3056 state,
3057 cost: action.state.cost,
3058 }
3059 }
3060 });
3061 }
3062
3063 let artifact = PirMcmcArtifact {
3064 origin_fn,
3065 origin_cost: manifest.origin.cost,
3066 run_options: manifest.run_options.into_run_options()?,
3067 raw_winner_fn,
3068 raw_winner_cost: manifest.raw_winner.cost,
3069 winning_provenance,
3070 };
3071 validate_prefix_minimization_artifact(&artifact)?;
3072 Ok(LoadedPirMcmcArtifact {
3073 artifact,
3074 package_template,
3075 top_fn_name,
3076 })
3077}
3078
3079impl SegmentRunner<IrFn, Cost, PirTransformKind> for PirSegmentRunner {
3080 type Error = anyhow::Error;
3081
3082 fn run_segment(
3083 &self,
3084 start_state: IrFn,
3085 params: SegmentRunParams,
3086 ) -> Result<SegmentOutcome<IrFn, Cost, PirTransformKind>, Self::Error> {
3087 let mut trajectory_writer: Option<std::io::BufWriter<std::fs::File>> =
3088 if let Some(dir) = &self.trajectory_dir {
3089 std::fs::create_dir_all(dir).map_err(|e| {
3090 anyhow::anyhow!("failed to create trajectory dir {}: {}", dir.display(), e)
3091 })?;
3092 let path = dir.join(format!("trajectory.c{:03}.jsonl", params.chain_no));
3093 let f = std::fs::OpenOptions::new()
3094 .create(true)
3095 .append(true)
3096 .open(&path)
3097 .map_err(|e| anyhow::anyhow!("failed to open {}: {}", path.display(), e))?;
3098 Some(std::io::BufWriter::new(f))
3099 } else {
3100 None
3101 };
3102
3103 let mut iteration_rng = Pcg64Mcg::seed_from_u64(params.seed);
3104 let all_transforms = get_pir_transforms_for_run(self.enable_formal_oracle);
3105 let weights = build_transform_weights(&all_transforms);
3106
3107 let mut context = PirMcmcContext {
3108 rng: &mut iteration_rng,
3109 all_transforms,
3110 weights,
3111 enable_formal_oracle: self.enable_formal_oracle,
3112 oracle_baseline_cache: EvalFnBaselineResults::default(),
3113 };
3114
3115 let toggle_stimulus = self.prepared_toggle_stimulus.as_ref().map(|v| v.as_slice());
3116
3117 let mut current_fn = start_state.clone();
3118 let mut current_cost =
3119 cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
3120 ¤t_fn,
3121 self.objective,
3122 toggle_stimulus,
3123 &self.weighted_switching_options,
3124 self.extension_costing_mode,
3125 &self.g8r_evaluation_mode,
3126 &self.canonical_g8r_options,
3127 )
3128 .map_err(|e| {
3129 anyhow::anyhow!(
3130 "failed to evaluate initial cost for '{}' under {:?}: {}",
3131 current_fn.name,
3132 self.objective,
3133 e
3134 )
3135 })?;
3136 let mut best_fn = start_state;
3137 let mut best_cost = current_cost;
3138 let mut best_score = search_score(&best_cost, self.objective, self.constraints);
3139 let mut stats = McmcStats::default();
3140
3141 let seg_start_time = Instant::now();
3142
3143 let mut iterations_count: u64 = 0;
3144 while iterations_count < params.segment_iters {
3145 iterations_count += 1;
3146 let global_iter = params.iter_offset + iterations_count;
3147
3148 let temp = match params.role {
3149 ChainRole::Explorer => self.initial_temperature * 10.0,
3150 ChainRole::Exploit => {
3151 let progress_ratio = if params.total_iters > 0 {
3152 (global_iter as f64) / (params.total_iters as f64)
3153 } else {
3154 0.0
3155 };
3156 let progress_ratio = progress_ratio.min(1.0);
3157 self.initial_temperature * (1.0 - progress_ratio).max(MIN_TEMPERATURE_RATIO)
3158 }
3159 };
3160
3161 let iteration_output = mcmc_iteration(
3162 current_fn,
3163 current_cost,
3164 &mut best_fn,
3165 &mut best_cost,
3166 &mut best_score,
3167 &mut context,
3168 temp,
3169 self.objective,
3170 self.extension_costing_mode,
3171 &self.g8r_evaluation_mode,
3172 &self.canonical_g8r_options,
3173 toggle_stimulus,
3174 &self.weighted_switching_options,
3175 self.constraints,
3176 );
3177
3178 let mut accepted_digest: Option<[u8; 32]> = None;
3179 let mut accepted_sample_sent = false;
3180
3181 if let IterationOutcomeDetails::Accepted { .. } = iteration_output.outcome {
3182 if let Some(ref tx) = self.accepted_sample_tx {
3183 match canonicalize_fn_for_sample(&iteration_output.output_state) {
3184 Ok(canon) => match optimize_pir_fn_via_xls_with_extension_mode(
3185 &canon,
3186 self.extension_costing_mode,
3187 ) {
3188 Ok(mut opt) => {
3189 let _ = compact_and_toposort_in_place(&mut opt);
3190 match compute_fn_structural_digest(&opt) {
3191 Some(digest) => {
3192 accepted_digest = Some(digest);
3193 accepted_sample_sent = tx
3194 .send(AcceptedSampleMsg {
3195 chain_no: params.chain_no,
3196 global_iter,
3197 digest,
3198 cost: iteration_output.output_cost,
3199 func: opt,
3200 })
3201 .is_ok();
3202 }
3203 None => {
3204 log::warn!(
3205 "[pir-mcmc] failed to compute structural digest for accepted sample '{}' after XLS optimize (c{:03}:i{:06}); skipping sample emission",
3206 iteration_output.output_state.name,
3207 params.chain_no,
3208 global_iter
3209 );
3210 }
3211 }
3212 }
3213 Err(e) => {
3214 log::warn!(
3218 "[pir-mcmc] failed to XLS-optimize accepted sample '{}' (c{:03}:i{:06}): {}; skipping sample emission",
3219 iteration_output.output_state.name,
3220 params.chain_no,
3221 global_iter,
3222 e
3223 );
3224 }
3225 },
3226 Err(e) => {
3227 log::warn!(
3228 "[pir-mcmc] failed to canonicalize accepted sample '{}' (c{:03}:i{:06}): {}; skipping sample emission",
3229 iteration_output.output_state.name,
3230 params.chain_no,
3231 global_iter,
3232 e
3233 );
3234 }
3235 }
3236 }
3237 }
3238
3239 if let Some(w) = trajectory_writer.as_mut() {
3240 let metric_u128 = self.objective.metric(&iteration_output.output_cost);
3241 let iter_score = search_score(
3242 &iteration_output.output_cost,
3243 self.objective,
3244 self.constraints,
3245 );
3246 let iter_violation = iter_score.violation;
3247 let rec = json!({
3248 "chain_no": params.chain_no,
3249 "role": format!("{:?}", params.role),
3250 "global_iter": global_iter,
3251 "temp": temp,
3252 "outcome": iteration_outcome_tag(&iteration_output.outcome),
3253 "best_updated": iteration_output.best_updated,
3254 "objective": format!("{:?}", self.objective),
3255 "extension_costing_mode": self.extension_costing_mode.value_name(),
3256 "metric": metric_u128,
3257 "pir_nodes": iteration_output.output_cost.pir_nodes,
3258 "g8r_nodes": iteration_output.output_cost.g8r_nodes,
3259 "g8r_depth": iteration_output.output_cost.g8r_depth,
3260 "g8r_le_graph_milli": iteration_output.output_cost.g8r_le_graph_milli,
3261 "g8r_gate_output_toggles": iteration_output.output_cost.g8r_gate_output_toggles,
3262 "g8r_weighted_switching_milli": iteration_output.output_cost.g8r_weighted_switching_milli,
3263 "g8r_post_and_nodes": iteration_output.output_cost.g8r_post_and_nodes,
3264 "g8r_post_depth": iteration_output.output_cost.g8r_post_depth,
3265 "g8r_post_le_graph_milli": iteration_output.output_cost.g8r_post_le_graph_milli,
3266 "g8r_post_gate_output_toggles": iteration_output.output_cost.g8r_post_gate_output_toggles,
3267 "g8r_post_weighted_switching_milli": iteration_output.output_cost.g8r_post_weighted_switching_milli,
3268 "feasible": iter_score.feasible(),
3269 "delay_over": iter_violation.and_then(|v| v.delay_over),
3270 "area_over": iter_violation.and_then(|v| v.area_over),
3271 "oracle_time_micros": iteration_output.oracle_time_micros,
3272 "transform": iteration_output.transform.map(|k| format!("{:?}", k)),
3273 "transform_mechanism": iteration_output.transform.map(|k| k.mechanism_hint()),
3274 "transform_always_equivalent": iteration_output.transform_always_equivalent,
3275 "accepted_digest": accepted_digest.map(|d| hash_to_hex(&d)),
3276 "accepted_sample_sent": accepted_sample_sent,
3277 });
3278 writeln!(w, "{}", rec.to_string())?;
3282 if global_iter % 1000 == 0 {
3283 w.flush()?;
3284 }
3285 }
3286
3287 current_fn = iteration_output.output_state.clone();
3288 current_cost = iteration_output.output_cost;
3289
3290 if iteration_output.best_updated {
3291 if let Some(ref shared_best) = self.shared_best {
3292 let before = shared_best.score();
3293 let _ = shared_best.try_update(best_score, best_fn.clone());
3294 let after = shared_best.score();
3295 if after < before {
3296 log::info!(
3297 "[pir-mcmc] GLOBAL BEST UPDATE c{:03}:i{:06} | {} -> {}",
3298 params.chain_no,
3299 global_iter,
3300 format_search_score(before),
3301 format_search_score(after),
3302 );
3303 if let Some(ref tx) = self.checkpoint_tx {
3306 let _ = tx.send(CheckpointMsg {
3307 chain_no: params.chain_no,
3308 global_iter,
3309 kind: CheckpointKind::GlobalBestUpdate,
3310 });
3311 }
3312 }
3313 }
3314 }
3315
3316 if self.checkpoint_iters > 0 && global_iter % self.checkpoint_iters == 0 {
3317 if let Some(ref tx) = self.checkpoint_tx {
3318 let _ = tx.send(CheckpointMsg {
3320 chain_no: params.chain_no,
3321 global_iter,
3322 kind: CheckpointKind::Periodic,
3323 });
3324 }
3325 }
3326
3327 stats.update_for_iteration(&iteration_output, false, global_iter);
3328
3329 if self.progress_iters > 0
3330 && (global_iter % self.progress_iters == 0
3331 || global_iter == params.total_iters
3332 || iterations_count == params.segment_iters)
3333 {
3334 let elapsed_secs = seg_start_time.elapsed().as_secs_f64();
3335 let samples_per_sec = if elapsed_secs > 0.0 {
3336 iterations_count as f64 / elapsed_secs
3337 } else {
3338 0.0
3339 };
3340 log::info!(
3341 "PIR MCMC c{:03}:i{:06} | GBest={} | LBest (pir={}, g8r_n={}, g8r_d={}, score={}) | Cur (pir={}, g8r_n={}, g8r_d={}, score={}) | Temp={:.2e} | Samples/s={:.2}",
3342 params.chain_no,
3343 global_iter,
3344 self.shared_best
3345 .as_ref()
3346 .map(|b| format_search_score(b.score()))
3347 .unwrap_or_else(|| "none".to_string()),
3348 best_cost.pir_nodes,
3349 best_cost.g8r_nodes,
3350 best_cost.g8r_depth,
3351 format_search_score(best_score),
3352 current_cost.pir_nodes,
3353 current_cost.g8r_nodes,
3354 current_cost.g8r_depth,
3355 format_search_score(search_score(
3356 ¤t_cost,
3357 self.objective,
3358 self.constraints
3359 )),
3360 temp,
3361 samples_per_sec,
3362 );
3363 }
3364 }
3365
3366 if let Some(mut w) = trajectory_writer {
3367 w.flush()?;
3368 }
3369
3370 Ok(SegmentOutcome {
3371 end_state: current_fn,
3372 end_cost: current_cost,
3373 best_state: best_fn,
3374 best_cost,
3375 stats,
3376 })
3377 }
3378}
3379
3380impl SegmentRunner<ProvenancedChainState, Cost, PirTransformKind> for PirArtifactSegmentRunner {
3381 type Error = anyhow::Error;
3382
3383 fn run_segment(
3384 &self,
3385 start_state: ProvenancedChainState,
3386 params: SegmentRunParams,
3387 ) -> Result<SegmentOutcome<ProvenancedChainState, Cost, PirTransformKind>, Self::Error> {
3388 let mut trajectory_writer: Option<std::io::BufWriter<std::fs::File>> =
3389 if let Some(dir) = &self.trajectory_dir {
3390 std::fs::create_dir_all(dir).map_err(|e| {
3391 anyhow::anyhow!("failed to create trajectory dir {}: {}", dir.display(), e)
3392 })?;
3393 let path = dir.join(format!("trajectory.c{:03}.jsonl", params.chain_no));
3394 let f = std::fs::OpenOptions::new()
3395 .create(true)
3396 .append(true)
3397 .open(&path)
3398 .map_err(|e| anyhow::anyhow!("failed to open {}: {}", path.display(), e))?;
3399 Some(std::io::BufWriter::new(f))
3400 } else {
3401 None
3402 };
3403
3404 let mut iteration_rng = Pcg64Mcg::seed_from_u64(params.seed);
3405 let all_transforms = (self.transform_factory)();
3406 let weights = build_transform_weights(&all_transforms);
3407 let mut context = PirMcmcContext {
3408 rng: &mut iteration_rng,
3409 all_transforms,
3410 weights,
3411 enable_formal_oracle: self.enable_formal_oracle,
3412 oracle_baseline_cache: EvalFnBaselineResults::default(),
3413 };
3414
3415 let toggle_stimulus = self.prepared_toggle_stimulus.as_ref().map(|v| v.as_slice());
3416 let mut current_fn = start_state.search_fn.clone();
3417 let mut current_provenance = start_state.search_provenance.clone();
3418 let mut current_cost =
3419 cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
3420 ¤t_fn,
3421 self.objective,
3422 toggle_stimulus,
3423 &self.weighted_switching_options,
3424 self.extension_costing_mode,
3425 &self.g8r_evaluation_mode,
3426 &self.canonical_g8r_options,
3427 )
3428 .map_err(|e| {
3429 anyhow::anyhow!(
3430 "failed to evaluate initial cost for '{}' under {:?}: {}",
3431 current_fn.name,
3432 self.objective,
3433 e
3434 )
3435 })?;
3436 if let Some((chain_no, global_iter)) = start_state.pending_handoff {
3437 current_provenance.push(PirMcmcProvenanceAction::XlsOptimizedHandoff {
3438 action_index: current_provenance.len() + 1,
3439 chain_no,
3440 global_iter,
3441 state: current_fn.clone(),
3442 cost: current_cost,
3443 });
3444 }
3445 let mut raw_winner_fn = start_state.raw_winner_fn;
3446 let mut raw_winner_cost = start_state.raw_winner_cost;
3447 let mut raw_winner_provenance = start_state.raw_winner_provenance;
3448 let current_score = search_score(¤t_cost, self.objective, self.constraints);
3449 if current_score < search_score(&raw_winner_cost, self.objective, self.constraints) {
3450 raw_winner_fn = current_fn.clone();
3451 raw_winner_cost = current_cost;
3452 raw_winner_provenance = current_provenance.clone();
3453 }
3454 let mut best_fn_for_iteration = start_state.search_fn.clone();
3455 let mut best_cost_for_iteration = current_cost;
3456 let mut best_score =
3457 search_score(&best_cost_for_iteration, self.objective, self.constraints);
3458 let mut best_state = ProvenancedChainState {
3459 search_fn: start_state.search_fn,
3460 search_provenance: current_provenance.clone(),
3461 raw_winner_fn,
3462 raw_winner_cost,
3463 raw_winner_provenance,
3464 pending_handoff: None,
3465 };
3466 let mut stats = McmcStats::default();
3467 let seg_start_time = Instant::now();
3468
3469 let mut iterations_count: u64 = 0;
3470 while iterations_count < params.segment_iters {
3471 iterations_count += 1;
3472 let global_iter = params.iter_offset + iterations_count;
3473 let temp = match params.role {
3474 ChainRole::Explorer => self.initial_temperature * 10.0,
3475 ChainRole::Exploit => {
3476 let progress_ratio = if params.total_iters > 0 {
3477 (global_iter as f64) / (params.total_iters as f64)
3478 } else {
3479 0.0
3480 };
3481 let progress_ratio = progress_ratio.min(1.0);
3482 self.initial_temperature * (1.0 - progress_ratio).max(MIN_TEMPERATURE_RATIO)
3483 }
3484 };
3485
3486 let iteration_output = mcmc_iteration(
3487 current_fn,
3488 current_cost,
3489 &mut best_fn_for_iteration,
3490 &mut best_cost_for_iteration,
3491 &mut best_score,
3492 &mut context,
3493 temp,
3494 self.objective,
3495 self.extension_costing_mode,
3496 &self.g8r_evaluation_mode,
3497 &self.canonical_g8r_options,
3498 toggle_stimulus,
3499 &self.weighted_switching_options,
3500 self.constraints,
3501 );
3502
3503 if let IterationOutcomeDetails::Accepted { kind } = &iteration_output.outcome {
3504 current_provenance.push(PirMcmcProvenanceAction::AcceptedRewrite {
3505 action_index: current_provenance.len() + 1,
3506 chain_no: params.chain_no,
3507 global_iter,
3508 transform_kind: *kind,
3509 state: iteration_output.output_state.clone(),
3510 cost: iteration_output.output_cost,
3511 });
3512 }
3513
3514 if let Some(w) = trajectory_writer.as_mut() {
3515 let metric_u128 = self.objective.metric(&iteration_output.output_cost);
3516 let iter_score = search_score(
3517 &iteration_output.output_cost,
3518 self.objective,
3519 self.constraints,
3520 );
3521 let iter_violation = iter_score.violation;
3522 let rec = json!({
3523 "chain_no": params.chain_no,
3524 "role": format!("{:?}", params.role),
3525 "global_iter": global_iter,
3526 "temp": temp,
3527 "outcome": iteration_outcome_tag(&iteration_output.outcome),
3528 "best_updated": iteration_output.best_updated,
3529 "objective": format!("{:?}", self.objective),
3530 "extension_costing_mode": self.extension_costing_mode.value_name(),
3531 "metric": metric_u128,
3532 "pir_nodes": iteration_output.output_cost.pir_nodes,
3533 "g8r_nodes": iteration_output.output_cost.g8r_nodes,
3534 "g8r_depth": iteration_output.output_cost.g8r_depth,
3535 "g8r_le_graph_milli": iteration_output.output_cost.g8r_le_graph_milli,
3536 "g8r_gate_output_toggles": iteration_output.output_cost.g8r_gate_output_toggles,
3537 "g8r_weighted_switching_milli": iteration_output.output_cost.g8r_weighted_switching_milli,
3538 "g8r_post_and_nodes": iteration_output.output_cost.g8r_post_and_nodes,
3539 "g8r_post_depth": iteration_output.output_cost.g8r_post_depth,
3540 "g8r_post_le_graph_milli": iteration_output.output_cost.g8r_post_le_graph_milli,
3541 "g8r_post_gate_output_toggles": iteration_output.output_cost.g8r_post_gate_output_toggles,
3542 "g8r_post_weighted_switching_milli": iteration_output.output_cost.g8r_post_weighted_switching_milli,
3543 "feasible": iter_score.feasible(),
3544 "delay_over": iter_violation.and_then(|v| v.delay_over),
3545 "area_over": iter_violation.and_then(|v| v.area_over),
3546 "oracle_time_micros": iteration_output.oracle_time_micros,
3547 "transform": iteration_output.transform.map(|k| format!("{:?}", k)),
3548 "transform_mechanism": iteration_output.transform.map(|k| k.mechanism_hint()),
3549 "transform_always_equivalent": iteration_output.transform_always_equivalent,
3550 "accepted_digest": Option::<String>::None,
3551 "accepted_sample_sent": false,
3552 });
3553 writeln!(w, "{}", rec.to_string())?;
3554 if global_iter % 1000 == 0 {
3555 w.flush()?;
3556 }
3557 }
3558
3559 current_fn = iteration_output.output_state.clone();
3560 current_cost = iteration_output.output_cost;
3561
3562 if iteration_output.best_updated {
3563 best_state = ProvenancedChainState {
3564 search_fn: best_fn_for_iteration.clone(),
3565 search_provenance: current_provenance.clone(),
3566 raw_winner_fn: current_fn.clone(),
3567 raw_winner_cost: current_cost,
3568 raw_winner_provenance: current_provenance.clone(),
3569 pending_handoff: None,
3570 };
3571
3572 if let Some(ref shared_best) = self.shared_best {
3573 let before = shared_best.score();
3574 let _ = shared_best.try_update(best_score, best_fn_for_iteration.clone());
3575 let after = shared_best.score();
3576 if after < before {
3577 log::info!(
3578 "[pir-mcmc] GLOBAL BEST UPDATE c{:03}:i{:06} | {} -> {}",
3579 params.chain_no,
3580 global_iter,
3581 format_search_score(before),
3582 format_search_score(after),
3583 );
3584 if let Some(ref tx) = self.checkpoint_tx {
3585 let _ = tx.send(CheckpointMsg {
3586 chain_no: params.chain_no,
3587 global_iter,
3588 kind: CheckpointKind::GlobalBestUpdate,
3589 });
3590 }
3591 }
3592 }
3593 }
3594
3595 if self.checkpoint_iters > 0 && global_iter % self.checkpoint_iters == 0 {
3596 if let Some(ref tx) = self.checkpoint_tx {
3597 let _ = tx.send(CheckpointMsg {
3598 chain_no: params.chain_no,
3599 global_iter,
3600 kind: CheckpointKind::Periodic,
3601 });
3602 }
3603 }
3604
3605 stats.update_for_iteration(&iteration_output, false, global_iter);
3606
3607 if self.progress_iters > 0
3608 && (global_iter % self.progress_iters == 0
3609 || global_iter == params.total_iters
3610 || iterations_count == params.segment_iters)
3611 {
3612 let elapsed_secs = seg_start_time.elapsed().as_secs_f64();
3613 let samples_per_sec = if elapsed_secs > 0.0 {
3614 iterations_count as f64 / elapsed_secs
3615 } else {
3616 0.0
3617 };
3618 log::info!(
3619 "PIR MCMC c{:03}:i{:06} | GBest={} | LBest (pir={}, g8r_n={}, g8r_d={}, score={}) | Cur (pir={}, g8r_n={}, g8r_d={}, score={}) | Temp={:.2e} | Samples/s={:.2}",
3620 params.chain_no,
3621 global_iter,
3622 self.shared_best
3623 .as_ref()
3624 .map(|b| format_search_score(b.score()))
3625 .unwrap_or_else(|| "none".to_string()),
3626 best_cost_for_iteration.pir_nodes,
3627 best_cost_for_iteration.g8r_nodes,
3628 best_cost_for_iteration.g8r_depth,
3629 format_search_score(best_score),
3630 current_cost.pir_nodes,
3631 current_cost.g8r_nodes,
3632 current_cost.g8r_depth,
3633 format_search_score(search_score(
3634 ¤t_cost,
3635 self.objective,
3636 self.constraints
3637 )),
3638 temp,
3639 samples_per_sec,
3640 );
3641 }
3642 }
3643
3644 if let Some(mut w) = trajectory_writer {
3645 w.flush()?;
3646 }
3647
3648 Ok(SegmentOutcome {
3649 end_state: ProvenancedChainState {
3650 search_fn: current_fn,
3651 search_provenance: current_provenance,
3652 raw_winner_fn: best_state.raw_winner_fn.clone(),
3653 raw_winner_cost: best_state.raw_winner_cost,
3654 raw_winner_provenance: best_state.raw_winner_provenance.clone(),
3655 pending_handoff: None,
3656 },
3657 end_cost: current_cost,
3658 best_state,
3659 best_cost: best_cost_for_iteration,
3660 stats,
3661 })
3662 }
3663}
3664
3665pub fn run_pir_mcmc(start_fn: IrFn, options: RunOptions) -> Result<PirMcmcResult> {
3669 run_pir_mcmc_with_shared_best(start_fn, options, None, None, None)
3670}
3671
3672pub fn run_pir_mcmc_with_artifact(start_fn: IrFn, options: RunOptions) -> Result<PirMcmcArtifact> {
3675 validate_pir_mcmc_artifact_run_options(&options)?;
3676 let enable_formal_oracle = options.enable_formal_oracle;
3677 Ok(run_pir_mcmc_with_artifact_using_transform_factory(
3678 start_fn,
3679 options,
3680 Arc::new(move || get_pir_transforms_for_run(enable_formal_oracle)),
3681 )?
3682 .artifact)
3683}
3684
/// Test helper: runs an artifact search with a fixed, pre-built transform set.
///
/// The transforms are handed out exactly once. Every `run_segment` call
/// invokes the factory, so a second invocation panics. Fixtures using this
/// helper must therefore be configured so the factory is called only once.
#[cfg(test)]
fn run_pir_mcmc_with_artifact_using_transforms(
    start_fn: IrFn,
    options: RunOptions,
    all_transforms: Vec<Box<dyn PirTransform>>,
) -> Result<PirMcmcArtifactRunOutput> {
    let fixture_slot = Arc::new(Mutex::new(Some(all_transforms)));
    let take_fixture = move || {
        fixture_slot
            .lock()
            .expect("artifact transform fixture lock poisoned")
            .take()
            .expect("artifact transform fixture can only be consumed once")
    };
    run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
        start_fn,
        options,
        Arc::new(take_fixture),
        None,
        None,
    )
}
3706
3707fn run_pir_mcmc_with_artifact_using_transform_factory(
3708 start_fn: IrFn,
3709 options: RunOptions,
3710 transform_factory: PirTransformFactory,
3711) -> Result<PirMcmcArtifactRunOutput> {
3712 run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
3713 start_fn,
3714 options,
3715 transform_factory,
3716 None,
3717 None,
3718 )
3719}
3720
3721fn run_pir_mcmc_with_artifact_and_observers(
3722 start_fn: IrFn,
3723 options: RunOptions,
3724 shared_best: Option<Arc<Best>>,
3725 checkpoint_tx: Option<Sender<CheckpointMsg>>,
3726) -> Result<PirMcmcArtifactRunOutput> {
3727 validate_pir_mcmc_artifact_run_options(&options)?;
3728 let enable_formal_oracle = options.enable_formal_oracle;
3729 run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
3730 start_fn,
3731 options,
3732 Arc::new(move || get_pir_transforms_for_run(enable_formal_oracle)),
3733 shared_best,
3734 checkpoint_tx,
3735 )
3736}
3737
3738fn run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
3739 start_fn: IrFn,
3740 options: RunOptions,
3741 transform_factory: PirTransformFactory,
3742 shared_best: Option<Arc<Best>>,
3743 checkpoint_tx: Option<Sender<CheckpointMsg>>,
3744) -> Result<PirMcmcArtifactRunOutput> {
3745 let prepared = prepare_run_start(start_fn, &options)?;
3746 let origin_fn = prepared.start_fn.clone();
3747 let origin_cost = prepared.initial_cost;
3748 let runner = PirArtifactSegmentRunner {
3749 objective: options.objective,
3750 extension_costing_mode: options.extension_costing_mode,
3751 g8r_evaluation_mode: options.g8r_evaluation_mode.clone(),
3752 canonical_g8r_options: options.canonical_g8r_options.clone(),
3753 weighted_switching_options: options.weighted_switching_options,
3754 initial_temperature: options.initial_temperature,
3755 constraints: prepared.effective_constraints,
3756 enable_formal_oracle: options.enable_formal_oracle,
3757 progress_iters: options.progress_iters,
3758 checkpoint_iters: options.checkpoint_iters,
3759 checkpoint_tx,
3760 shared_best,
3761 trajectory_dir: options.trajectory_dir.clone(),
3762 prepared_toggle_stimulus: prepared.prepared_toggle_stimulus,
3763 transform_factory,
3764 };
3765 let objective = options.objective;
3766 let threshold = options.initial_temperature as u128;
3767 let constraints = prepared.effective_constraints;
3768 let (best_state, best_cost, stats) = run_multichain(
3769 ProvenancedChainState::origin(prepared.start_fn, prepared.initial_cost),
3770 options.max_iters,
3771 options.seed,
3772 options.threads.max(1) as usize,
3773 options.chain_strategy,
3774 options.checkpoint_iters,
3775 Arc::new(runner),
3776 move |c: &Cost| search_score(c, objective, constraints),
3777 |s: &ProvenancedChainState| s.search_fn.to_string(),
3778 move |cur_cost: &Cost, global_best_cost: &Cost| {
3779 let cur_score = search_score(cur_cost, objective, constraints);
3780 let global_best_score = search_score(global_best_cost, objective, constraints);
3781 match (cur_score.violation, global_best_score.violation) {
3782 (Some(_), None) => true,
3783 (None, Some(_)) => false,
3784 (Some(cur_violation), Some(best_violation)) => {
3785 repair_energy(cur_violation)
3786 > repair_energy(best_violation).saturating_add(threshold)
3787 }
3788 (None, None) => {
3789 objective.metric(cur_cost)
3790 > objective.metric(global_best_cost).saturating_add(threshold)
3791 }
3792 }
3793 },
3794 |best_state: &ProvenancedChainState, receiving_chain_no, global_iter| {
3795 best_state.with_xls_optimized_handoff(receiving_chain_no, global_iter)
3796 },
3797 )?;
3798
3799 Ok(PirMcmcArtifactRunOutput {
3800 result: PirMcmcResult {
3801 best_fn: best_state.search_fn.clone(),
3802 best_cost,
3803 stats,
3804 },
3805 artifact: PirMcmcArtifact {
3806 origin_fn,
3807 origin_cost,
3808 run_options: options,
3809 raw_winner_fn: best_state.raw_winner_fn,
3810 raw_winner_cost: best_state.raw_winner_cost,
3811 winning_provenance: best_state.raw_winner_provenance,
3812 },
3813 })
3814}
3815
3816pub fn minimize_winning_prefix(
3819 artifact: &PirMcmcArtifact,
3820 options: PirMcmcPrefixMinimizeOptions,
3821) -> Result<PirMcmcPrefixMinimizeResult> {
3822 validate_prefix_minimization_artifact(artifact)?;
3823 if !options.retained_win_fraction.is_finite()
3824 || !(0.0..=1.0).contains(&options.retained_win_fraction)
3825 {
3826 return Err(anyhow::anyhow!(
3827 "retained_win_fraction must be finite and in [0, 1]; got {}",
3828 options.retained_win_fraction
3829 ));
3830 }
3831
3832 let objective = artifact.run_options.objective;
3833 let origin_metric = objective.metric(&artifact.origin_cost);
3834 let winner_metric = objective.metric(&artifact.raw_winner_cost);
3835 if winner_metric >= origin_metric {
3836 return Err(anyhow::anyhow!(
3837 "artifact does not contain a positive objective win: origin_metric={}, winner_metric={}",
3838 origin_metric,
3839 winner_metric
3840 ));
3841 }
3842
3843 if options.retained_win_fraction == 0.0 {
3844 return Ok(PirMcmcPrefixMinimizeResult {
3845 witness_fn: artifact.origin_fn.clone(),
3846 witness_cost: artifact.origin_cost,
3847 provenance_action_count: 0,
3848 original_winning_provenance_len: artifact.winning_provenance.len(),
3849 requested_retained_win_fraction: options.retained_win_fraction,
3850 actual_retained_win_fraction: 0.0,
3851 origin_metric,
3852 winner_metric,
3853 witness_metric: origin_metric,
3854 });
3855 }
3856
3857 for action in artifact.winning_provenance.iter() {
3858 let witness_metric = objective.metric(&action.cost());
3859 let actual_retained_win_fraction =
3860 retained_win_fraction_for_metric(origin_metric, winner_metric, witness_metric);
3861 if actual_retained_win_fraction >= options.retained_win_fraction {
3862 return Ok(PirMcmcPrefixMinimizeResult {
3863 witness_fn: action.state().clone(),
3864 witness_cost: action.cost(),
3865 provenance_action_count: action.action_index(),
3866 original_winning_provenance_len: artifact.winning_provenance.len(),
3867 requested_retained_win_fraction: options.retained_win_fraction,
3868 actual_retained_win_fraction,
3869 origin_metric,
3870 winner_metric,
3871 witness_metric,
3872 });
3873 }
3874 }
3875
3876 Err(anyhow::anyhow!(
3877 "winning provenance did not contain a prefix retaining requested win fraction {}",
3878 options.retained_win_fraction
3879 ))
3880}
3881
/// Outcome of a single witness-guided rollout (`run_witness_guided_rollout`).
struct GuidedRolloutResult {
    // Best witness seen during the rollout, as ranked by `better_budget_witness`;
    // the origin (0 actions) when no accepted rewrite beat it.
    best_witness: PirMcmcBudgetWitness,
}
3885
3886fn run_witness_guided_rollout(
3887 artifact: &PirMcmcArtifact,
3888 action_budget: usize,
3889 rollout_seed: u64,
3890 weights: Vec<f64>,
3891 proposal_attempts_per_rewrite: usize,
3892 transforms: Vec<Box<dyn PirTransform>>,
3893 origin_metric: u128,
3894 winner_metric: u128,
3895) -> Result<GuidedRolloutResult> {
3896 let objective = artifact.run_options.objective;
3897 let mut iteration_rng = Pcg64Mcg::seed_from_u64(rollout_seed);
3898 let mut context = PirMcmcContext {
3899 rng: &mut iteration_rng,
3900 all_transforms: transforms,
3901 weights,
3902 enable_formal_oracle: artifact.run_options.enable_formal_oracle,
3903 oracle_baseline_cache: EvalFnBaselineResults::default(),
3904 };
3905 let mut current_fn = artifact.origin_fn.clone();
3906 let mut current_cost = artifact.origin_cost;
3907 let mut best_fn_for_iteration = artifact.origin_fn.clone();
3908 let mut best_cost_for_iteration = artifact.origin_cost;
3909 let mut best_score = search_score(
3910 &best_cost_for_iteration,
3911 objective,
3912 ConstraintLimits::default(),
3913 );
3914 let mut best_witness = make_budget_witness(
3915 artifact.origin_fn.clone(),
3916 artifact.origin_cost,
3917 0,
3918 objective,
3919 origin_metric,
3920 winner_metric,
3921 );
3922 let max_proposals = action_budget.saturating_mul(proposal_attempts_per_rewrite);
3923 let mut accepted_rewrites = 0usize;
3924 let mut proposal_attempts = 0usize;
3925
3926 while accepted_rewrites < action_budget && proposal_attempts < max_proposals {
3927 proposal_attempts += 1;
3928 let progress_ratio = accepted_rewrites as f64 / action_budget.max(1) as f64;
3929 let temp = artifact.run_options.initial_temperature
3930 * (1.0 - progress_ratio.min(1.0)).max(MIN_TEMPERATURE_RATIO);
3931 let iteration_output = mcmc_iteration(
3932 current_fn,
3933 current_cost,
3934 &mut best_fn_for_iteration,
3935 &mut best_cost_for_iteration,
3936 &mut best_score,
3937 &mut context,
3938 temp,
3939 objective,
3940 artifact.run_options.extension_costing_mode,
3941 &artifact.run_options.g8r_evaluation_mode,
3942 &artifact.run_options.canonical_g8r_options,
3943 None,
3944 &artifact.run_options.weighted_switching_options,
3945 ConstraintLimits::default(),
3946 );
3947 current_fn = iteration_output.output_state.clone();
3948 current_cost = iteration_output.output_cost;
3949 if matches!(
3950 iteration_output.outcome,
3951 IterationOutcomeDetails::Accepted { .. }
3952 ) {
3953 accepted_rewrites += 1;
3954 let candidate = make_budget_witness(
3955 current_fn.clone(),
3956 current_cost,
3957 accepted_rewrites,
3958 objective,
3959 origin_metric,
3960 winner_metric,
3961 );
3962 if better_budget_witness(&candidate, &best_witness) {
3963 best_witness = candidate;
3964 }
3965 }
3966 }
3967
3968 Ok(GuidedRolloutResult { best_witness })
3969}
3970
3971pub fn search_winning_budget_frontier(
3974 artifact: &PirMcmcArtifact,
3975 options: PirMcmcBudgetFrontierOptions,
3976) -> Result<PirMcmcBudgetFrontierResult> {
3977 search_winning_budget_frontier_with_rollout(
3978 artifact,
3979 options,
3980 |action_budget, _rollout_idx, rollout_seed, origin_metric, winner_metric| {
3981 let transforms = get_pir_transforms_for_run(artifact.run_options.enable_formal_oracle);
3982 let weights = build_witness_guided_transform_weights(
3983 &transforms,
3984 artifact,
3985 options.witness_kind_boost,
3986 );
3987 let rollout = run_witness_guided_rollout(
3988 artifact,
3989 action_budget,
3990 rollout_seed,
3991 weights,
3992 options.proposal_attempts_per_rewrite,
3993 transforms,
3994 origin_metric,
3995 winner_metric,
3996 )?;
3997 Ok(rollout.best_witness)
3998 },
3999 )
4000}
4001
4002fn search_winning_budget_frontier_with_rollout<F>(
4003 artifact: &PirMcmcArtifact,
4004 options: PirMcmcBudgetFrontierOptions,
4005 mut run_rollout: F,
4006) -> Result<PirMcmcBudgetFrontierResult>
4007where
4008 F: FnMut(usize, usize, u64, u128, u128) -> Result<PirMcmcBudgetWitness>,
4009{
4010 validate_prefix_minimization_artifact(artifact)?;
4011 if !objective_supports_budget_frontier_search(artifact.run_options.objective) {
4012 return Err(anyhow::anyhow!(
4013 "budget frontier search currently supports only objectives that can be recomputed without stored toggle stimulus; got {}",
4014 artifact.run_options.objective.value_name()
4015 ));
4016 }
4017 let budgets = frontier_budgets(options)?;
4018 let objective = artifact.run_options.objective;
4019 let origin_metric = objective.metric(&artifact.origin_cost);
4020 let winner_metric = objective.metric(&artifact.raw_winner_cost);
4021 if winner_metric >= origin_metric {
4022 return Err(anyhow::anyhow!(
4023 "artifact does not contain a positive objective win: origin_metric={}, winner_metric={}",
4024 origin_metric,
4025 winner_metric
4026 ));
4027 }
4028
4029 let mut carried_guided = make_budget_witness(
4030 artifact.origin_fn.clone(),
4031 artifact.origin_cost,
4032 0,
4033 objective,
4034 origin_metric,
4035 winner_metric,
4036 );
4037 let mut points = Vec::with_capacity(budgets.len());
4038 for action_budget in budgets {
4039 let prefix_baseline =
4040 prefix_baseline_for_budget(artifact, action_budget, origin_metric, winner_metric);
4041 let mut best_for_budget = carried_guided.clone();
4042 for rollout_idx in 0..options.rollouts_per_budget {
4043 let rollout_seed = options
4044 .seed
4045 .wrapping_add((action_budget as u64).wrapping_mul(1_000_003))
4046 .wrapping_add(rollout_idx as u64);
4047 let rollout_witness = run_rollout(
4048 action_budget,
4049 rollout_idx,
4050 rollout_seed,
4051 origin_metric,
4052 winner_metric,
4053 )?;
4054 if better_budget_witness(&rollout_witness, &best_for_budget) {
4055 best_for_budget = rollout_witness;
4056 }
4057 }
4058 carried_guided = best_for_budget.clone();
4059 points.push(PirMcmcBudgetFrontierPoint {
4060 action_budget,
4061 guided: best_for_budget,
4062 prefix_baseline,
4063 });
4064 }
4065
4066 Ok(PirMcmcBudgetFrontierResult {
4067 origin_metric,
4068 winner_metric,
4069 original_winning_provenance_len: artifact.winning_provenance.len(),
4070 points,
4071 })
4072}
4073
4074pub fn run_pir_mcmc_with_shared_best(
4075 start_fn: IrFn,
4076 options: RunOptions,
4077 shared_best: Option<Arc<Best>>,
4078 checkpoint_tx: Option<Sender<CheckpointMsg>>,
4079 accepted_sample_tx: Option<Sender<AcceptedSampleMsg>>,
4080) -> Result<PirMcmcResult> {
4081 let prepared = prepare_run_start(start_fn, &options)?;
4082 let runner = PirSegmentRunner {
4083 objective: options.objective,
4084 extension_costing_mode: options.extension_costing_mode,
4085 g8r_evaluation_mode: options.g8r_evaluation_mode.clone(),
4086 canonical_g8r_options: options.canonical_g8r_options.clone(),
4087 weighted_switching_options: options.weighted_switching_options,
4088 initial_temperature: options.initial_temperature,
4089 constraints: prepared.effective_constraints,
4090 enable_formal_oracle: options.enable_formal_oracle,
4091 progress_iters: options.progress_iters,
4092 checkpoint_iters: options.checkpoint_iters,
4093 checkpoint_tx,
4094 accepted_sample_tx,
4095 shared_best,
4096 trajectory_dir: options.trajectory_dir.clone(),
4097 prepared_toggle_stimulus: prepared.prepared_toggle_stimulus,
4098 };
4099
4100 let objective = options.objective;
4101 let threshold = options.initial_temperature as u128;
4102 let constraints = prepared.effective_constraints;
4103
4104 let (best_fn, best_cost, stats) = run_multichain(
4105 prepared.start_fn,
4106 options.max_iters,
4107 options.seed,
4108 options.threads.max(1) as usize,
4109 options.chain_strategy,
4110 options.checkpoint_iters,
4111 Arc::new(runner),
4112 move |c: &Cost| search_score(c, objective, constraints),
4113 |f: &IrFn| f.to_string(),
4114 move |cur_cost: &Cost, global_best_cost: &Cost| {
4115 let cur_score = search_score(cur_cost, objective, constraints);
4116 let global_best_score = search_score(global_best_cost, objective, constraints);
4117 match (cur_score.violation, global_best_score.violation) {
4118 (Some(_), None) => true,
4119 (None, Some(_)) => false,
4120 (Some(cur_violation), Some(best_violation)) => {
4121 repair_energy(cur_violation)
4122 > repair_energy(best_violation).saturating_add(threshold)
4123 }
4124 (None, None) => {
4125 objective.metric(cur_cost)
4126 > objective.metric(global_best_cost).saturating_add(threshold)
4127 }
4128 }
4129 },
4130 |best_fn: &IrFn, _, _| best_fn.clone(),
4131 )?;
4132
4133 Ok(PirMcmcResult {
4134 best_fn,
4135 best_cost,
4136 stats,
4137 })
4138}
4139
4140#[cfg(test)]
4141mod tests {
4142 use std::collections::HashSet;
4143 use std::fs;
4144 use std::os::unix::fs::PermissionsExt;
4145
4146 use super::*;
4147 use count_toggles::WeightedSwitchingOptions;
4148 use tempfile::tempdir;
4149 use xlsynth_g8r::gatify::ir2gate::{self, GatifyOptions};
4150 use xlsynth_pir::ir::{ExtNaryAddArchitecture, NodePayload};
4151 use xlsynth_pir::ir_parser;
4152 use xlsynth_pir::ir_utils::remap_payload_with;
4153
4154 fn parse_fn(ir_text: &str) -> IrFn {
4155 let mut parser = ir_parser::Parser::new(ir_text);
4156 parser.parse_fn().unwrap()
4157 }
4158
4159 fn parse_pkg(ir_text: &str) -> PirPackage {
4160 let mut parser = ir_parser::Parser::new(ir_text);
4161 parser.parse_and_validate_package().unwrap()
4162 }
4163
4164 fn write_executable_script(dir: &Path, name: &str, body: &str) -> PathBuf {
4165 let path = dir.join(name);
4166 fs::write(&path, body).unwrap();
4167 let mut permissions = fs::metadata(&path).unwrap().permissions();
4168 permissions.set_mode(0o755);
4169 fs::set_permissions(&path, permissions).unwrap();
4170 path
4171 }
4172
4173 fn test_run_options(objective: Objective) -> RunOptions {
4174 RunOptions {
4175 max_iters: 1,
4176 threads: 1,
4177 chain_strategy: ChainStrategy::Independent,
4178 checkpoint_iters: 100,
4179 progress_iters: 0,
4180 seed: 1,
4181 initial_temperature: 1.0,
4182 objective,
4183 extension_costing_mode: ExtensionCostingMode::Preserve,
4184 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
4185 canonical_g8r_options: CanonicalG8rOptions::default(),
4186 max_allowed_depth: None,
4187 max_allowed_area: None,
4188 weighted_switching_options: WeightedSwitchingOptions::default(),
4189 enable_formal_oracle: false,
4190 trajectory_dir: None,
4191 toggle_stimulus: None,
4192 }
4193 }
4194
4195 fn cost_with_pir_nodes(pir_nodes: usize) -> Cost {
4196 Cost {
4197 pir_nodes,
4198 g8r_nodes: pir_nodes,
4199 g8r_depth: pir_nodes,
4200 g8r_le_graph_milli: 0,
4201 g8r_gate_output_toggles: 0,
4202 g8r_weighted_switching_milli: 0,
4203 g8r_post_and_nodes: 0,
4204 g8r_post_depth: 0,
4205 g8r_post_le_graph_milli: 0,
4206 g8r_post_gate_output_toggles: 0,
4207 g8r_post_weighted_switching_milli: 0,
4208 }
4209 }
4210
4211 fn renamed_state(base: &IrFn, name: &str) -> IrFn {
4212 let mut f = base.clone();
4213 f.name = name.to_string();
4214 f
4215 }
4216
4217 fn manual_prefix_artifact() -> PirMcmcArtifact {
4218 let origin = parse_fn(
4219 r#"fn origin(x: bits[8] id=1) -> bits[8] {
4220 ret identity.2: bits[8] = identity(x, id=2)
4221}"#,
4222 );
4223 let step1_state = renamed_state(&origin, "step1");
4224 let step2_state = renamed_state(&origin, "step2");
4225 let step3_state = renamed_state(&origin, "step3");
4226 PirMcmcArtifact {
4227 origin_fn: origin.clone(),
4228 origin_cost: cost_with_pir_nodes(100),
4229 run_options: test_run_options(Objective::Nodes),
4230 raw_winner_fn: step3_state.clone(),
4231 raw_winner_cost: cost_with_pir_nodes(50),
4232 winning_provenance: vec![
4233 PirMcmcProvenanceAction::AcceptedRewrite {
4234 action_index: 1,
4235 chain_no: 0,
4236 global_iter: 2,
4237 transform_kind: PirTransformKind::NotNotCancel,
4238 state: step1_state,
4239 cost: cost_with_pir_nodes(90),
4240 },
4241 PirMcmcProvenanceAction::AcceptedRewrite {
4242 action_index: 2,
4243 chain_no: 0,
4244 global_iter: 4,
4245 transform_kind: PirTransformKind::NegNegCancel,
4246 state: step2_state,
4247 cost: cost_with_pir_nodes(70),
4248 },
4249 PirMcmcProvenanceAction::AcceptedRewrite {
4250 action_index: 3,
4251 chain_no: 0,
4252 global_iter: 7,
4253 transform_kind: PirTransformKind::SelSameArmsFold,
4254 state: step3_state,
4255 cost: cost_with_pir_nodes(50),
4256 },
4257 ],
4258 }
4259 }
4260
4261 #[derive(Debug)]
4262 struct RemoveDeadNodeTestTransform;
4263
4264 impl PirTransform for RemoveDeadNodeTestTransform {
4265 fn kind(&self) -> PirTransformKind {
4266 PirTransformKind::NotNotCancel
4267 }
4268
4269 fn find_candidates(&mut self, f: &IrFn) -> Vec<crate::transforms::TransformCandidate> {
4270 f.node_refs()
4271 .into_iter()
4272 .filter(|nr| {
4273 f.get_node(*nr)
4274 .name
4275 .as_deref()
4276 .map(|name| name.starts_with("dead"))
4277 .unwrap_or(false)
4278 })
4279 .map(|nr| crate::transforms::TransformCandidate {
4280 location: crate::transforms::TransformLocation::Node(nr),
4281 always_equivalent: true,
4282 })
4283 .collect()
4284 }
4285
4286 fn apply(
4287 &self,
4288 f: &mut IrFn,
4289 loc: &crate::transforms::TransformLocation,
4290 ) -> Result<(), String> {
4291 let crate::transforms::TransformLocation::Node(nr) = loc else {
4292 return Err("RemoveDeadNodeTestTransform expects a node location".to_string());
4293 };
4294 f.get_node_mut(*nr).payload = NodePayload::Nil;
4295 Ok(())
4296 }
4297 }
4298
4299 #[test]
4300 fn pir_mcmc_runs_and_is_deterministic_on_simple_add() {
4301 let ir_text = r#"fn add(x: bits[8] id=10, y: bits[8] id=20) -> bits[8] {
4302 ret add.42: bits[8] = add(x, y, id=42)
4303}"#;
4304 let mut parser = ir_parser::Parser::new(ir_text);
4305 let ir_fn = parser.parse_fn().unwrap();
4306
4307 let opts = RunOptions {
4308 max_iters: 10,
4309 threads: 1,
4310 chain_strategy: ChainStrategy::Independent,
4311 checkpoint_iters: 5000,
4312 progress_iters: 0,
4313 seed: 1,
4314 initial_temperature: 5.0,
4315 objective: Objective::Nodes,
4316 extension_costing_mode: ExtensionCostingMode::Preserve,
4317 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
4318 canonical_g8r_options: CanonicalG8rOptions::default(),
4319 max_allowed_depth: None,
4320 max_allowed_area: None,
4321 weighted_switching_options: WeightedSwitchingOptions::default(),
4322 enable_formal_oracle: false,
4323 trajectory_dir: None,
4324 toggle_stimulus: None,
4325 };
4326
4327 let res1 = run_pir_mcmc(ir_fn.clone(), opts.clone()).unwrap();
4328 let res2 = run_pir_mcmc(ir_fn.clone(), opts).unwrap();
4329
4330 assert_eq!(res1.best_cost.pir_nodes, ir_fn.nodes.len());
4331 assert_eq!(res2.best_cost.pir_nodes, ir_fn.nodes.len());
4332
4333 assert_eq!(res1.best_fn.to_string(), res2.best_fn.to_string());
4336 }
4337
4338 #[test]
4339 fn prefix_minimization_selects_earliest_prefix_meeting_requested_win() {
4340 let artifact = manual_prefix_artifact();
4341
4342 let retain_all = minimize_winning_prefix(
4343 &artifact,
4344 PirMcmcPrefixMinimizeOptions {
4345 retained_win_fraction: 1.0,
4346 },
4347 )
4348 .unwrap();
4349 assert_eq!(retain_all.provenance_action_count, 3);
4350 assert_eq!(retain_all.witness_metric, 50);
4351
4352 let retain_most = minimize_winning_prefix(
4353 &artifact,
4354 PirMcmcPrefixMinimizeOptions {
4355 retained_win_fraction: 0.6,
4356 },
4357 )
4358 .unwrap();
4359 assert_eq!(retain_most.provenance_action_count, 2);
4360 assert_eq!(retain_most.witness_fn.name, "step2");
4361 assert_eq!(retain_most.witness_metric, 70);
4362 assert_eq!(retain_most.original_winning_provenance_len, 3);
4363 assert!((retain_most.actual_retained_win_fraction - 0.6).abs() < 1e-12);
4364
4365 let retain_none = minimize_winning_prefix(
4366 &artifact,
4367 PirMcmcPrefixMinimizeOptions {
4368 retained_win_fraction: 0.0,
4369 },
4370 )
4371 .unwrap();
4372 assert_eq!(retain_none.provenance_action_count, 0);
4373 assert_eq!(retain_none.witness_fn.name, "origin");
4374 assert_eq!(retain_none.witness_metric, 100);
4375 }
4376
4377 #[test]
4378 fn prefix_minimization_rejects_invalid_fraction_and_non_winning_artifacts() {
4379 let artifact = manual_prefix_artifact();
4380 for retained_win_fraction in [f64::NAN, -0.1, 1.1] {
4381 let err = minimize_winning_prefix(
4382 &artifact,
4383 PirMcmcPrefixMinimizeOptions {
4384 retained_win_fraction,
4385 },
4386 )
4387 .unwrap_err();
4388 assert!(
4389 err.to_string().contains("retained_win_fraction"),
4390 "unexpected error: {err}"
4391 );
4392 }
4393
4394 let mut no_win = artifact.clone();
4395 no_win.origin_cost = no_win.raw_winner_cost;
4396 let err = minimize_winning_prefix(
4397 &no_win,
4398 PirMcmcPrefixMinimizeOptions {
4399 retained_win_fraction: 0.5,
4400 },
4401 )
4402 .unwrap_err();
4403 assert!(
4404 err.to_string().contains("positive objective win"),
4405 "unexpected error: {err}"
4406 );
4407 }
4408
4409 #[test]
4410 fn frontier_schedule_validates_step_and_max() {
4411 let opts = PirMcmcBudgetFrontierOptions {
4412 budget_step: 4,
4413 max_actions: 16,
4414 rollouts_per_budget: 1,
4415 seed: 1,
4416 witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
4417 proposal_attempts_per_rewrite:
4418 PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
4419 };
4420 assert_eq!(frontier_budgets(opts).unwrap(), vec![4, 8, 12, 16]);
4421 assert_eq!(
4422 frontier_budgets(PirMcmcBudgetFrontierOptions {
4423 max_actions: 10,
4424 ..opts
4425 })
4426 .unwrap(),
4427 vec![4, 8, 10]
4428 );
4429 assert!(
4430 frontier_budgets(PirMcmcBudgetFrontierOptions {
4431 budget_step: 0,
4432 ..opts
4433 })
4434 .is_err()
4435 );
4436 assert!(
4437 frontier_budgets(PirMcmcBudgetFrontierOptions {
4438 max_actions: 0,
4439 ..opts
4440 })
4441 .is_err()
4442 );
4443 assert!(
4444 frontier_budgets(PirMcmcBudgetFrontierOptions {
4445 budget_step: 8,
4446 max_actions: 4,
4447 ..opts
4448 })
4449 .is_err()
4450 );
4451 }
4452
4453 #[test]
4454 fn win_percent_vs_origin_reports_percentage_points() {
4455 assert!(
4456 (win_percent_vs_origin_for_metric(1205, 1173) - 2.655_601_659_751_037).abs() < 1e-12
4457 );
4458 }
4459
4460 #[test]
4461 fn witness_guided_weights_favor_lineage_kinds_without_excluding_others() {
4462 #[derive(Debug)]
4463 struct KindOnlyTransform(PirTransformKind);
4464 impl PirTransform for KindOnlyTransform {
4465 fn kind(&self) -> PirTransformKind {
4466 self.0
4467 }
4468
4469 fn find_candidates(&mut self, _f: &IrFn) -> Vec<crate::transforms::TransformCandidate> {
4470 Vec::new()
4471 }
4472
4473 fn apply(
4474 &self,
4475 _f: &mut IrFn,
4476 _loc: &crate::transforms::TransformLocation,
4477 ) -> Result<(), String> {
4478 Ok(())
4479 }
4480 }
4481
4482 let artifact = manual_prefix_artifact();
4483 let transforms: Vec<Box<dyn PirTransform>> = vec![
4484 Box::new(KindOnlyTransform(PirTransformKind::NotNotCancel)),
4485 Box::new(KindOnlyTransform(PirTransformKind::CmpSwap)),
4486 ];
4487 let weights = build_witness_guided_transform_weights(&transforms, &artifact, 4.0);
4488 assert_eq!(weights, vec![5.0, 1.0]);
4489 }
4490
4491 #[test]
4492 fn frontier_reports_prefix_baseline_and_carries_guided_points_forward() {
4493 let artifact = manual_prefix_artifact();
4494 let opts = PirMcmcBudgetFrontierOptions {
4495 budget_step: 1,
4496 max_actions: 3,
4497 rollouts_per_budget: 1,
4498 seed: 7,
4499 witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
4500 proposal_attempts_per_rewrite:
4501 PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
4502 };
4503 let result = search_winning_budget_frontier_with_rollout(
4504 &artifact,
4505 opts,
4506 |budget, _, _, origin_metric, winner_metric| {
4507 let (state_name, cost) = match budget {
4508 1 => ("guided1", cost_with_pir_nodes(95)),
4509 2 => ("guided2", cost_with_pir_nodes(80)),
4510 3 => ("guided3", cost_with_pir_nodes(85)),
4511 _ => unreachable!(),
4512 };
4513 Ok(make_budget_witness(
4514 renamed_state(&artifact.origin_fn, state_name),
4515 cost,
4516 budget,
4517 Objective::Nodes,
4518 origin_metric,
4519 winner_metric,
4520 ))
4521 },
4522 )
4523 .unwrap();
4524
4525 assert_eq!(result.points.len(), 3);
4526 assert_eq!(result.points[0].prefix_baseline.metric, 90);
4527 assert_eq!(result.points[1].prefix_baseline.metric, 70);
4528 assert_eq!(result.points[2].prefix_baseline.metric, 50);
4529 assert_eq!(result.points[0].guided.metric, 95);
4530 assert_eq!(result.points[1].guided.metric, 80);
4531 assert_eq!(
4532 result.points[2].guided.metric, 80,
4533 "worse later searches must carry forward the prior frontier point"
4534 );
4535 }
4536
4537 #[test]
4538 fn frontier_uses_origin_fallback_when_rollouts_do_not_improve() {
4539 let artifact = manual_prefix_artifact();
4540 let opts = PirMcmcBudgetFrontierOptions {
4541 budget_step: 2,
4542 max_actions: 4,
4543 rollouts_per_budget: 1,
4544 seed: 1,
4545 witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
4546 proposal_attempts_per_rewrite:
4547 PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
4548 };
4549 let result = search_winning_budget_frontier_with_rollout(
4550 &artifact,
4551 opts,
4552 |budget, _, _, origin_metric, winner_metric| {
4553 Ok(make_budget_witness(
4554 renamed_state(&artifact.origin_fn, &format!("noop{budget}")),
4555 artifact.origin_cost,
4556 budget,
4557 Objective::Nodes,
4558 origin_metric,
4559 winner_metric,
4560 ))
4561 },
4562 )
4563 .unwrap();
4564 assert_eq!(result.points[0].guided.provenance_action_count, 0);
4565 assert_eq!(result.points[0].guided.metric, 100);
4566 assert_eq!(result.points[1].guided.provenance_action_count, 0);
4567 assert_eq!(result.points[1].guided.metric, 100);
4568 }
4569
4570 #[test]
4571 fn frontier_rejects_no_win_and_toggle_objectives() {
4572 let opts = PirMcmcBudgetFrontierOptions {
4573 budget_step: 1,
4574 max_actions: 1,
4575 rollouts_per_budget: 1,
4576 seed: 1,
4577 witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
4578 proposal_attempts_per_rewrite:
4579 PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
4580 };
4581 let mut no_win = manual_prefix_artifact();
4582 no_win.raw_winner_cost = no_win.origin_cost;
4583 assert!(search_winning_budget_frontier(&no_win, opts).is_err());
4584
4585 let mut toggle_artifact = manual_prefix_artifact();
4586 toggle_artifact.run_options.objective = Objective::G8rNodesTimesDepthTimesToggles;
4587 toggle_artifact.raw_winner_cost.g8r_gate_output_toggles = 1;
4588 assert!(search_winning_budget_frontier(&toggle_artifact, opts).is_err());
4589 }
4590
4591 #[test]
4592 fn artifact_api_rejects_unsupported_run_shapes() {
4593 let f = parse_fn(
4594 r#"fn f(x: bits[1] id=1) -> bits[1] {
4595 ret identity.2: bits[1] = identity(x, id=2)
4596}"#,
4597 );
4598
4599 let mut constrained = test_run_options(Objective::G8rNodes);
4600 constrained.max_allowed_depth = Some(4);
4601 assert!(run_pir_mcmc_with_artifact(f.clone(), constrained).is_err());
4602
4603 let mut implicit_constraint =
4604 test_run_options(Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress);
4605 implicit_constraint.toggle_stimulus = Some(vec![
4606 IrValue::parse_typed("(bits[1]:0)").unwrap(),
4607 IrValue::parse_typed("(bits[1]:1)").unwrap(),
4608 ]);
4609 assert!(run_pir_mcmc_with_artifact(f, implicit_constraint).is_err());
4610 }
4611
4612 #[test]
4613 fn artifact_api_supports_independent_multichain_runs() {
4614 let f = parse_fn(
4615 r#"fn f(x: bits[8] id=1) -> bits[8] {
4616 dead: bits[8] = identity(x, id=2)
4617 ret live: bits[8] = identity(x, id=3)
4618}"#,
4619 );
4620 let mut opts = test_run_options(Objective::Nodes);
4621 opts.max_iters = 1;
4622 opts.threads = 2;
4623 let artifact = run_pir_mcmc_with_artifact_using_transform_factory(
4624 f,
4625 opts,
4626 Arc::new(|| vec![Box::new(RemoveDeadNodeTestTransform)]),
4627 )
4628 .unwrap()
4629 .artifact;
4630
4631 assert_eq!(artifact.winning_provenance.len(), 1);
4632 let PirMcmcProvenanceAction::AcceptedRewrite {
4633 action_index,
4634 chain_no,
4635 ..
4636 } = artifact.winning_provenance.last().unwrap()
4637 else {
4638 panic!("expected accepted rewrite provenance");
4639 };
4640 assert_eq!(*action_index, 1);
4641 assert_eq!(*chain_no, 0);
4642 assert_eq!(
4643 artifact
4644 .winning_provenance
4645 .last()
4646 .unwrap()
4647 .state()
4648 .to_string(),
4649 artifact.raw_winner_fn.to_string()
4650 );
4651 }
4652
4653 #[test]
4654 fn artifact_api_supports_explore_exploit_multichain_runs() {
4655 let f = parse_fn(
4656 r#"fn f(x: bits[1] id=1) -> bits[1] {
4657 ret identity.2: bits[1] = identity(x, id=2)
4658}"#,
4659 );
4660 let mut opts = test_run_options(Objective::Nodes);
4661 opts.max_iters = 0;
4662 opts.threads = 2;
4663 opts.chain_strategy = ChainStrategy::ExploreExploit;
4664 let artifact = run_pir_mcmc_with_artifact(f, opts).unwrap();
4665 assert!(artifact.winning_provenance.is_empty());
4666 assert_eq!(artifact.raw_winner_cost, artifact.origin_cost);
4667 }
4668
4669 #[test]
4670 fn artifact_segment_records_handoff_before_later_rewrite() {
4671 let f = parse_fn(
4672 r#"fn f(x: bits[8] id=1) -> bits[8] {
4673 dead: bits[8] = identity(x, id=2)
4674 ret live: bits[8] = identity(x, id=3)
4675}"#,
4676 );
4677 let origin_cost = cost_with_effort_options_toggle_stimulus_and_extension_mode(
4678 &f,
4679 Objective::Nodes,
4680 None,
4681 &WeightedSwitchingOptions::default(),
4682 ExtensionCostingMode::Preserve,
4683 )
4684 .unwrap();
4685 let runner = PirArtifactSegmentRunner {
4686 objective: Objective::Nodes,
4687 extension_costing_mode: ExtensionCostingMode::Preserve,
4688 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
4689 canonical_g8r_options: CanonicalG8rOptions::default(),
4690 weighted_switching_options: WeightedSwitchingOptions::default(),
4691 initial_temperature: 1.0,
4692 constraints: ConstraintLimits::default(),
4693 enable_formal_oracle: false,
4694 progress_iters: 0,
4695 checkpoint_iters: 0,
4696 checkpoint_tx: None,
4697 shared_best: None,
4698 trajectory_dir: None,
4699 prepared_toggle_stimulus: None,
4700 transform_factory: Arc::new(|| vec![Box::new(RemoveDeadNodeTestTransform)]),
4701 };
4702 let out = runner
4703 .run_segment(
4704 ProvenancedChainState::origin(f, origin_cost).with_xls_optimized_handoff(1, 7),
4705 SegmentRunParams {
4706 chain_no: 1,
4707 role: ChainRole::Exploit,
4708 iter_offset: 7,
4709 segment_iters: 1,
4710 total_iters: 8,
4711 seed: 1,
4712 },
4713 )
4714 .unwrap();
4715
4716 assert_eq!(out.best_state.raw_winner_provenance.len(), 2);
4717 assert!(matches!(
4718 out.best_state.raw_winner_provenance[0],
4719 PirMcmcProvenanceAction::XlsOptimizedHandoff {
4720 action_index: 1,
4721 chain_no: 1,
4722 global_iter: 7,
4723 ..
4724 }
4725 ));
4726 assert!(matches!(
4727 out.best_state.raw_winner_provenance[1],
4728 PirMcmcProvenanceAction::AcceptedRewrite {
4729 action_index: 2,
4730 chain_no: 1,
4731 global_iter: 8,
4732 ..
4733 }
4734 ));
4735 }
4736
4737 #[test]
4738 fn artifact_run_is_deterministic_and_captures_raw_winning_provenance() {
4739 let f = parse_fn(
4740 r#"fn f(x: bits[8] id=1) -> bits[8] {
4741 dead: bits[8] = identity(x, id=2)
4742 ret live: bits[8] = identity(x, id=3)
4743}"#,
4744 );
4745 let mut opts = test_run_options(Objective::Nodes);
4746 opts.max_iters = 1;
4747 opts.seed = 7;
4748 let transforms1: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4749 let transforms2: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4750
4751 let artifact1 =
4752 run_pir_mcmc_with_artifact_using_transforms(f.clone(), opts.clone(), transforms1)
4753 .unwrap()
4754 .artifact;
4755 let artifact2 = run_pir_mcmc_with_artifact_using_transforms(f, opts, transforms2)
4756 .unwrap()
4757 .artifact;
4758
4759 assert!(
4760 artifact1.raw_winner_cost.pir_nodes < artifact1.origin_cost.pir_nodes,
4761 "expected a real node-count win in the deterministic fixture"
4762 );
4763 assert_eq!(
4764 artifact1.origin_fn.to_string(),
4765 artifact2.origin_fn.to_string()
4766 );
4767 assert_eq!(artifact1.origin_cost, artifact2.origin_cost);
4768 assert_eq!(artifact1.raw_winner_cost, artifact2.raw_winner_cost);
4769 assert_eq!(
4770 artifact1.raw_winner_fn.to_string(),
4771 artifact2.raw_winner_fn.to_string()
4772 );
4773 assert_eq!(
4774 artifact1.winning_provenance.len(),
4775 artifact2.winning_provenance.len()
4776 );
4777 assert!(
4778 !artifact1.winning_provenance.is_empty(),
4779 "expected a non-empty winning provenance"
4780 );
4781 let last1 = artifact1.winning_provenance.last().unwrap();
4782 let last2 = artifact2.winning_provenance.last().unwrap();
4783 assert_eq!(last1.cost(), artifact1.raw_winner_cost);
4784 assert_eq!(
4785 last1.state().to_string(),
4786 artifact1.raw_winner_fn.to_string()
4787 );
4788 assert_eq!(last1.cost(), last2.cost());
4789 assert_eq!(last1.state().to_string(), last2.state().to_string());
4790 assert_eq!(last1.transform_kind(), last2.transform_kind());
4791 }
4792
4793 #[test]
4794 fn artifact_provenance_can_be_minimized_end_to_end() {
4795 let f = parse_fn(
4796 r#"fn f(x: bits[8] id=1) -> bits[8] {
4797 dead_a: bits[8] = identity(x, id=2)
4798 dead_b: bits[8] = identity(x, id=3)
4799 ret live: bits[8] = identity(x, id=4)
4800}"#,
4801 );
4802 let mut opts = test_run_options(Objective::Nodes);
4803 opts.max_iters = 2;
4804 let transforms: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4805
4806 let artifact = run_pir_mcmc_with_artifact_using_transforms(f, opts, transforms)
4807 .unwrap()
4808 .artifact;
4809 assert_eq!(artifact.origin_cost.pir_nodes, 5);
4810 assert_eq!(artifact.raw_winner_cost.pir_nodes, 3);
4811 assert_eq!(artifact.winning_provenance.len(), 2);
4812
4813 let minimized = minimize_winning_prefix(
4814 &artifact,
4815 PirMcmcPrefixMinimizeOptions {
4816 retained_win_fraction: 0.5,
4817 },
4818 )
4819 .unwrap();
4820 assert_eq!(minimized.provenance_action_count, 1);
4821 assert_eq!(minimized.original_winning_provenance_len, 2);
4822 assert_eq!(minimized.origin_metric, 5);
4823 assert_eq!(minimized.winner_metric, 3);
4824 assert_eq!(minimized.witness_metric, 4);
4825 assert!((minimized.actual_retained_win_fraction - 0.5).abs() < 1e-12);
4826 }
4827
4828 #[test]
4829 fn durable_artifact_round_trips_and_minimizes_identically() {
4830 let pkg = parse_pkg(
4831 r#"package sample
4832
4833top fn f(x: bits[8] id=1) -> bits[8] {
4834 dead_a: bits[8] = identity(x, id=2)
4835 dead_b: bits[8] = identity(x, id=3)
4836 ret live: bits[8] = identity(x, id=4)
4837}
4838"#,
4839 );
4840 let f = pkg.get_fn("f").unwrap().clone();
4841 let mut opts = test_run_options(Objective::Nodes);
4842 opts.max_iters = 2;
4843 let transforms: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4844 let artifact = run_pir_mcmc_with_artifact_using_transforms(f, opts, transforms)
4845 .unwrap()
4846 .artifact;
4847 let before = minimize_winning_prefix(
4848 &artifact,
4849 PirMcmcPrefixMinimizeOptions {
4850 retained_win_fraction: 0.5,
4851 },
4852 )
4853 .unwrap();
4854
4855 let run_dir = tempdir().unwrap();
4856 write_pir_mcmc_artifact_dir(&artifact, &pkg, run_dir.path()).unwrap();
4857 let loaded = read_pir_mcmc_artifact_dir(run_dir.path()).unwrap();
4858 let after = minimize_winning_prefix(
4859 &loaded.artifact,
4860 PirMcmcPrefixMinimizeOptions {
4861 retained_win_fraction: 0.5,
4862 },
4863 )
4864 .unwrap();
4865
4866 assert_eq!(loaded.top_fn_name, "f");
4867 assert_eq!(loaded.artifact.origin_cost, artifact.origin_cost);
4868 assert_eq!(
4869 loaded.artifact.origin_fn.to_string(),
4870 artifact.origin_fn.to_string()
4871 );
4872 assert_eq!(
4873 loaded.artifact.raw_winner_fn.to_string(),
4874 artifact.raw_winner_fn.to_string()
4875 );
4876 assert_eq!(
4877 loaded.artifact.winning_provenance.len(),
4878 artifact.winning_provenance.len()
4879 );
4880 assert_eq!(
4881 loaded.artifact.run_options.g8r_evaluation_mode,
4882 artifact.run_options.g8r_evaluation_mode
4883 );
4884 assert_eq!(
4885 loaded.artifact.winning_provenance[0].transform_kind(),
4886 artifact.winning_provenance[0].transform_kind()
4887 );
4888 assert_eq!(
4889 after.provenance_action_count,
4890 before.provenance_action_count
4891 );
4892 assert_eq!(after.witness_metric, before.witness_metric);
4893 assert_eq!(after.witness_fn.to_string(), before.witness_fn.to_string());
4894 }
4895
4896 #[test]
4897 fn durable_artifact_manifest_is_deterministic_for_fixed_run() {
4898 let pkg = parse_pkg(
4899 r#"package sample
4900
4901top fn f(x: bits[8] id=1) -> bits[8] {
4902 dead_a: bits[8] = identity(x, id=2)
4903 dead_b: bits[8] = identity(x, id=3)
4904 ret live: bits[8] = identity(x, id=4)
4905}
4906"#,
4907 );
4908 let f = pkg.get_fn("f").unwrap().clone();
4909 let mut opts = test_run_options(Objective::Nodes);
4910 opts.max_iters = 2;
4911 let artifact1 = run_pir_mcmc_with_artifact_using_transforms(
4912 f.clone(),
4913 opts.clone(),
4914 vec![Box::new(RemoveDeadNodeTestTransform)],
4915 )
4916 .unwrap()
4917 .artifact;
4918 let artifact2 = run_pir_mcmc_with_artifact_using_transforms(
4919 f,
4920 opts,
4921 vec![Box::new(RemoveDeadNodeTestTransform)],
4922 )
4923 .unwrap()
4924 .artifact;
4925
4926 let run_dir1 = tempdir().unwrap();
4927 let run_dir2 = tempdir().unwrap();
4928 let artifact_dir1 = write_pir_mcmc_artifact_dir(&artifact1, &pkg, run_dir1.path()).unwrap();
4929 let artifact_dir2 = write_pir_mcmc_artifact_dir(&artifact2, &pkg, run_dir2.path()).unwrap();
4930 let manifest1 =
4931 fs::read_to_string(artifact_dir1.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE)).unwrap();
4932 let manifest2 =
4933 fs::read_to_string(artifact_dir2.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE)).unwrap();
4934 assert_eq!(manifest1, manifest2);
4935 }
4936
4937 #[test]
4938 fn durable_artifact_manifest_canonicalizes_relative_postprocessor_path() {
4939 let pkg = parse_pkg(
4940 r#"package sample
4941
4942top fn f(x: bits[8] id=1) -> bits[8] {
4943 dead: bits[8] = identity(x, id=2)
4944 ret live: bits[8] = identity(x, id=3)
4945}
4946"#,
4947 );
4948 let f = pkg.get_fn("f").unwrap().clone();
4949 let cwd = std::env::current_dir().unwrap();
4950 let hook_dir = tempfile::tempdir_in(&cwd).unwrap();
4951 let hook = hook_dir.path().join("identity.sh");
4952 fs::write(&hook, "#!/bin/sh\n").unwrap();
4953 let relative_hook = hook.strip_prefix(&cwd).unwrap().display().to_string();
4954 let mut artifact = run_pir_mcmc_with_artifact_using_transforms(
4955 f,
4956 test_run_options(Objective::Nodes),
4957 vec![Box::new(RemoveDeadNodeTestTransform)],
4958 )
4959 .unwrap()
4960 .artifact;
4961 artifact.run_options.g8r_evaluation_mode = G8rEvaluationMode::ExternalPostprocess {
4962 program: relative_hook,
4963 };
4964 let run_dir = tempdir().unwrap();
4965 let artifact_dir = write_pir_mcmc_artifact_dir(&artifact, &pkg, run_dir.path()).unwrap();
4966 let manifest: serde_json::Value = serde_json::from_str(
4967 &fs::read_to_string(artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE)).unwrap(),
4968 )
4969 .unwrap();
4970 assert_eq!(
4971 manifest["run_options"]["g8r_evaluation_mode"]["program"],
4972 std::fs::canonicalize(&hook).unwrap().display().to_string()
4973 );
4974 }
4975
4976 #[test]
4977 fn durable_artifact_rejects_malformed_action_records() {
4978 let pkg = parse_pkg(
4979 r#"package sample
4980
4981top fn f(x: bits[8] id=1) -> bits[8] {
4982 dead: bits[8] = identity(x, id=2)
4983 ret live: bits[8] = identity(x, id=3)
4984}
4985"#,
4986 );
4987 let f = pkg.get_fn("f").unwrap().clone();
4988 let artifact = run_pir_mcmc_with_artifact_using_transforms(
4989 f,
4990 test_run_options(Objective::Nodes),
4991 vec![Box::new(RemoveDeadNodeTestTransform)],
4992 )
4993 .unwrap()
4994 .artifact;
4995 let run_dir = tempdir().unwrap();
4996 let artifact_dir = write_pir_mcmc_artifact_dir(&artifact, &pkg, run_dir.path()).unwrap();
4997 let manifest_path = artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE);
4998 let mut manifest: serde_json::Value =
4999 serde_json::from_str(&fs::read_to_string(&manifest_path).unwrap()).unwrap();
5000 manifest["winning_provenance"][0]["transform_kind"] = serde_json::Value::Null;
5001 fs::write(
5002 &manifest_path,
5003 serde_json::to_string_pretty(&manifest).unwrap(),
5004 )
5005 .unwrap();
5006
5007 let err = read_pir_mcmc_artifact_dir(run_dir.path())
5008 .err()
5009 .expect("malformed action record must be rejected");
5010 assert!(
5011 err.to_string().contains("missing transform_kind"),
5012 "unexpected error: {err}"
5013 );
5014 }
5015
5016 #[test]
5017 fn durable_artifact_rejects_malformed_or_old_schema_manifests() {
5018 let malformed_dir = tempdir().unwrap();
5019 let malformed_artifact_dir = malformed_dir.path().join(PIR_MCMC_ARTIFACT_DIR_NAME);
5020 fs::create_dir_all(&malformed_artifact_dir).unwrap();
5021 fs::write(
5022 malformed_artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE),
5023 "{not json",
5024 )
5025 .unwrap();
5026 assert!(read_pir_mcmc_artifact_dir(malformed_dir.path()).is_err());
5027
5028 let incomplete_dir = tempdir().unwrap();
5029 let incomplete_artifact_dir = incomplete_dir.path().join(PIR_MCMC_ARTIFACT_DIR_NAME);
5030 fs::create_dir_all(&incomplete_artifact_dir).unwrap();
5031 fs::write(
5032 incomplete_artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE),
5033 r#"{
5034 "schema_version": 1,
5035 "top_fn_name": "f",
5036 "run_options": {
5037 "max_iters": 1,
5038 "threads": 1,
5039 "chain_strategy": "independent",
5040 "checkpoint_iters": 1,
5041 "progress_iters": 0,
5042 "seed": 1,
5043 "initial_temperature": 1.0,
5044 "objective": "nodes",
5045 "extension_costing_mode": "preserve",
5046 "max_allowed_depth": null,
5047 "max_allowed_area": null,
5048 "switching_beta1": 1.0,
5049 "switching_beta2": 0.0,
5050 "switching_primary_output_load": 1.0,
5051 "enable_formal_oracle": false
5052 },
5053 "origin": {"file": "states/origin.ir", "cost": {"pir_nodes": 2, "g8r_nodes": 2, "g8r_depth": 2, "g8r_le_graph_milli": 0, "g8r_gate_output_toggles": 0, "g8r_weighted_switching_milli": 0}},
5054 "raw_winner": {"file": "states/raw-winner.ir", "cost": {"pir_nodes": 1, "g8r_nodes": 1, "g8r_depth": 1, "g8r_le_graph_milli": 0, "g8r_gate_output_toggles": 0, "g8r_weighted_switching_milli": 0}},
5055 "winning_lineage": []
5056}"#,
5057 )
5058 .unwrap();
5059 let err = read_pir_mcmc_artifact_dir(incomplete_dir.path())
5060 .err()
5061 .expect("old-schema artifact must be rejected");
5062 assert!(
5063 err.to_string().contains("schema version 1"),
5064 "unexpected error: {err}"
5065 );
5066 }
5067
5068 #[test]
5069 fn pir_equiv_oracle_rejects_obviously_non_equivalent_rewire() {
5070 let ir_text = r#"fn add_lit(x: bits[8] id=10, y: bits[8] id=20) -> bits[8] {
5073 literal.30: bits[8] = literal(value=0, id=30)
5074 ret add.42: bits[8] = add(x, y, id=42)
5075}"#;
5076 let mut parser1 = ir_parser::Parser::new(ir_text);
5077 let orig_fn = parser1.parse_fn().unwrap();
5078 let mut parser2 = ir_parser::Parser::new(ir_text);
5079 let mut rewired_fn = parser2.parse_fn().unwrap();
5080
5081 let mut add_ref = None;
5084 let mut lit_ref = None;
5085 for nr in rewired_fn.node_refs() {
5086 let node = rewired_fn.get_node(nr);
5087 match &node.payload {
5088 xlsynth_pir::ir::NodePayload::Literal(_) => {
5089 lit_ref = Some(nr);
5090 }
5091 xlsynth_pir::ir::NodePayload::Binop(xlsynth_pir::ir::Binop::Add, _, _) => {
5092 add_ref = Some(nr);
5093 }
5094 _ => {}
5095 }
5096 }
5097 let add_ref = add_ref.expect("expected add node");
5098 let lit_ref = lit_ref.expect("expected literal node");
5099
5100 let old_add_payload = rewired_fn.get_node(add_ref).payload.clone();
5101 rewired_fn.get_node_mut(add_ref).payload = remap_payload_with(
5102 &old_add_payload,
5103 |(slot, dep)| {
5104 if slot == 1 { lit_ref } else { dep }
5105 },
5106 );
5107
5108 let mut rng = Pcg64Mcg::seed_from_u64(1);
5109 let mut baseline_cache = EvalFnBaselineResults::default();
5110 assert!(!pir_equiv_oracle(
5111 &orig_fn,
5112 &rewired_fn,
5113 &mut rng,
5114 4,
5115 false,
5116 &mut baseline_cache,
5117 ));
5118 }
5119
5120 #[test]
5121 fn pir_equiv_oracle_rejects_zero_length_array_params_without_panic() {
5122 let ir_text = r#"fn zero_len(a: bits[8][0] id=10) -> bits[1] {
5123 ret literal.20: bits[1] = literal(value=0, id=20)
5124}"#;
5125 let mut parser1 = ir_parser::Parser::new(ir_text);
5126 let lhs = parser1.parse_fn().unwrap();
5127 let mut parser2 = ir_parser::Parser::new(ir_text);
5128 let rhs = parser2.parse_fn().unwrap();
5129
5130 let mut rng = Pcg64Mcg::seed_from_u64(1);
5131 let mut baseline_cache = EvalFnBaselineResults::default();
5132 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
5133 pir_equiv_oracle(
5134 &lhs,
5135 &rhs,
5136 &mut rng,
5137 4,
5138 false,
5139 &mut baseline_cache,
5140 )
5141 }));
5142
5143 assert!(result.is_ok());
5144 assert!(!result.unwrap());
5145 }
5146
5147 #[test]
5148 fn optimize_pir_fn_via_xls_preserves_ext_nary_add_arch_via_ffi_wrappers() {
5149 let mut parser = ir_parser::Parser::new(
5150 r#"fn f(a: bits[8] id=1, b: bits[8] id=2, c: bits[8] id=3) -> bits[8] {
5151 ret sum: bits[8] = ext_nary_add(a, b, c, signed=[false, false, false], negated=[false, false, false], arch=brent_kung, id=4)
5152}"#,
5153 );
5154 let f = parser.parse_fn().unwrap();
5155
5156 let optimized =
5157 optimize_pir_fn_via_xls_with_extension_mode(&f, ExtensionCostingMode::Preserve)
5158 .unwrap();
5159 let ext_nodes = optimized
5160 .nodes
5161 .iter()
5162 .filter(|node| {
5163 matches!(
5164 &node.payload,
5165 NodePayload::ExtNaryAdd {
5166 arch: Some(ExtNaryAddArchitecture::BrentKung),
5167 ..
5168 }
5169 )
5170 })
5171 .count();
5172 assert_eq!(
5173 ext_nodes, 1,
5174 "expected optimized PIR to reconstruct brent_kung ext_nary_add:\n{}",
5175 optimized
5176 );
5177 }
5178
5179 #[test]
5180 fn canonical_g8r_scoring_input_materializes_the_optimized_top() {
5181 let f = parse_fn(
5182 r#"fn f(x: bits[8] id=1) -> bits[8] {
5183 dead: bits[8] = identity(x, id=2)
5184 ret live: bits[8] = identity(x, id=3)
5185}"#,
5186 );
5187
5188 let scoring_input =
5189 canonical_g8r_scoring_input_for_pir_fn(&f, ExtensionCostingMode::Preserve).unwrap();
5190 let mut parser = ir_parser::Parser::new(&scoring_input.ir_text);
5191 let pkg = parser.parse_and_validate_package().unwrap();
5192 let reparsed_top = pkg.get_top_fn().unwrap();
5193
5194 assert_eq!(reparsed_top.to_string(), scoring_input.top_fn.to_string());
5195 assert!(!scoring_input.ir_text.contains("dead:"));
5196 }
5197
5198 #[test]
5199 fn canonical_g8r_scoring_lowering_options_skip_summary_graph_le() {
5200 let source = CanonicalG8rOptions {
5201 compute_graph_logical_effort: true,
5202 graph_logical_effort_beta1: 2.0,
5203 ..CanonicalG8rOptions::default()
5204 };
5205
5206 let scoring = canonical_g8r_scoring_lowering_options(&source);
5207
5208 assert!(!scoring.compute_graph_logical_effort);
5209 assert_eq!(scoring.graph_logical_effort_beta1, 2.0);
5210 }
5211
5212 #[test]
5213 fn optimize_pir_fn_via_xls_can_desugar_ext_nary_add_to_standard_ir() {
5214 let mut parser = ir_parser::Parser::new(
5215 r#"fn f(a: bits[8] id=1, b: bits[8] id=2, c: bits[8] id=3) -> bits[8] {
5216 ret sum: bits[8] = ext_nary_add(a, b, c, signed=[false, false, false], negated=[false, false, false], arch=brent_kung, id=4)
5217}"#,
5218 );
5219 let f = parser.parse_fn().unwrap();
5220
5221 let optimized =
5222 optimize_pir_fn_via_xls_with_extension_mode(&f, ExtensionCostingMode::Desugar).unwrap();
5223 let ext_nodes = optimized
5224 .nodes
5225 .iter()
5226 .filter(|node| matches!(&node.payload, NodePayload::ExtNaryAdd { .. }))
5227 .count();
5228 assert_eq!(
5229 ext_nodes, 0,
5230 "expected desugared optimized PIR to contain no ext_nary_add nodes:\n{}",
5231 optimized
5232 );
5233 assert!(
5234 !optimized.to_string().contains("ext_nary_add"),
5235 "expected desugared optimized PIR text to contain no extension op spelling:\n{}",
5236 optimized
5237 );
5238
5239 let cost = cost_with_effort_options_toggle_stimulus_and_extension_mode(
5240 &f,
5241 Objective::G8rNodes,
5242 None,
5243 &WeightedSwitchingOptions::default(),
5244 ExtensionCostingMode::Desugar,
5245 )
5246 .unwrap();
5247 assert!(cost.g8r_nodes > 0);
5248 }
5249
5250 #[test]
5251 fn get_pir_transforms_for_run_prunes_unsafe_only_classes_without_formal_oracle() {
5252 let no_oracle_kinds: HashSet<PirTransformKind> =
5253 get_pir_transforms_for_run(false)
5254 .into_iter()
5255 .map(|t| t.kind())
5256 .collect();
5257 let oracle_kinds: HashSet<PirTransformKind> =
5258 get_pir_transforms_for_run(true)
5259 .into_iter()
5260 .map(|t| t.kind())
5261 .collect();
5262
5263 assert!(!no_oracle_kinds.contains(&PirTransformKind::ShiftHoist));
5264 assert!(!no_oracle_kinds.contains(&PirTransformKind::MaskOperandHighBit));
5265 assert!(!no_oracle_kinds.contains(&PirTransformKind::RewireOperandToSameType));
5266 assert!(!no_oracle_kinds.contains(&PirTransformKind::GuardedPredicateRewire));
5267 assert!(no_oracle_kinds.contains(&PirTransformKind::AbsorbAddOperandIntoExtNaryAdd));
5268 assert!(no_oracle_kinds.contains(&PirTransformKind::AddToExtNaryAdd));
5269 assert!(no_oracle_kinds.contains(&PirTransformKind::ReduceSelDistribute));
5270
5271 assert!(oracle_kinds.contains(&PirTransformKind::ShiftHoist));
5272 assert!(oracle_kinds.contains(&PirTransformKind::MaskOperandHighBit));
5273 assert!(oracle_kinds.contains(&PirTransformKind::RewireOperandToSameType));
5274 assert!(oracle_kinds.contains(&PirTransformKind::GuardedPredicateRewire));
5275 }
5276
5277 #[test]
5278 fn ext_nary_add_arch_reaches_g8r_tags_after_xls_round_trip() {
5279 let mut parser = ir_parser::Parser::new(
5280 r#"fn f(a: bits[8] id=1, b: bits[8] id=2, c: bits[8] id=3) -> bits[8] {
5281 ret sum: bits[8] = ext_nary_add(a, b, c, signed=[false, false, false], negated=[false, false, false], arch=brent_kung, id=4)
5282}"#,
5283 );
5284 let f = parser.parse_fn().unwrap();
5285
5286 let optimized =
5287 optimize_pir_fn_via_xls_with_extension_mode(&f, ExtensionCostingMode::Preserve)
5288 .unwrap();
5289 let ext_text_id = optimized
5290 .nodes
5291 .iter()
5292 .find_map(|node| match &node.payload {
5293 NodePayload::ExtNaryAdd {
5294 arch: Some(ExtNaryAddArchitecture::BrentKung),
5295 ..
5296 } => Some(node.text_id),
5297 _ => None,
5298 })
5299 .expect("expected reconstructed brent_kung ext_nary_add");
5300
5301 let gatify_output =
5302 ir2gate::gatify(&optimized, GatifyOptions::all_opts_disabled()).unwrap();
5303 let gate_fn_text = gatify_output.gate_fn.to_string();
5304 assert!(
5305 gate_fn_text.contains(&format!(
5306 "ext_nary_add_{}_brent_kung_output_bit_",
5307 ext_text_id
5308 )),
5309 "expected gatify tags to reflect the ext_nary_add arch after XLS round-trip:\n{}",
5310 gate_fn_text
5311 );
5312 }
5313
5314 #[test]
5315 fn objective_metric_toggles_product_saturates() {
5316 let c = Cost {
5317 pir_nodes: 0,
5318 g8r_nodes: usize::MAX,
5319 g8r_depth: usize::MAX,
5320 g8r_le_graph_milli: 0,
5321 g8r_gate_output_toggles: 2,
5322 g8r_weighted_switching_milli: 0,
5323 g8r_post_and_nodes: 0,
5324 g8r_post_depth: 0,
5325 g8r_post_le_graph_milli: 0,
5326 g8r_post_gate_output_toggles: 0,
5327 g8r_post_weighted_switching_milli: 0,
5328 };
5329 assert_eq!(
5330 Objective::G8rNodesTimesDepthTimesToggles.metric(&c),
5331 u128::MAX
5332 );
5333 }
5334
5335 #[test]
5336 fn objective_metric_graph_logical_effort_times_nodes_multiplies_nodes() {
5337 let c = Cost {
5338 pir_nodes: 0,
5339 g8r_nodes: 7,
5340 g8r_depth: 11,
5341 g8r_le_graph_milli: 13,
5342 g8r_gate_output_toggles: 0,
5343 g8r_weighted_switching_milli: 0,
5344 g8r_post_and_nodes: 0,
5345 g8r_post_depth: 0,
5346 g8r_post_le_graph_milli: 0,
5347 g8r_post_gate_output_toggles: 0,
5348 g8r_post_weighted_switching_milli: 0,
5349 };
5350 assert_eq!(Objective::G8rLeGraphTimesNodes.metric(&c), 91);
5351 }
5352
5353 #[test]
5354 fn objective_metric_graph_logical_effort_times_nodes_handles_large_values() {
5355 let c = Cost {
5356 pir_nodes: 0,
5357 g8r_nodes: usize::MAX,
5358 g8r_depth: 0,
5359 g8r_le_graph_milli: usize::MAX,
5360 g8r_gate_output_toggles: 0,
5361 g8r_weighted_switching_milli: 0,
5362 g8r_post_and_nodes: 0,
5363 g8r_post_depth: 0,
5364 g8r_post_le_graph_milli: 0,
5365 g8r_post_gate_output_toggles: 0,
5366 g8r_post_weighted_switching_milli: 0,
5367 };
5368 assert_eq!(
5369 Objective::G8rLeGraphTimesNodes.metric(&c),
5370 (usize::MAX as u128) * (usize::MAX as u128)
5371 );
5372 }
5373
5374 #[test]
5375 fn graph_logical_effort_times_nodes_uses_graph_effort_without_toggles() {
5376 let objective = Objective::G8rLeGraphTimesNodes;
5377 assert!(objective.uses_g8r_costing());
5378 assert!(objective.needs_graph_logical_effort());
5379 assert!(!objective.needs_toggle_stimulus());
5380 }
5381
5382 #[test]
5383 fn objective_metric_nodes_times_weighted_switching_saturates() {
5384 let c = Cost {
5385 pir_nodes: 0,
5386 g8r_nodes: usize::MAX,
5387 g8r_depth: 0,
5388 g8r_le_graph_milli: 0,
5389 g8r_gate_output_toggles: 0,
5390 g8r_weighted_switching_milli: u128::MAX,
5391 g8r_post_and_nodes: 0,
5392 g8r_post_depth: 0,
5393 g8r_post_le_graph_milli: 0,
5394 g8r_post_gate_output_toggles: 0,
5395 g8r_post_weighted_switching_milli: 0,
5396 };
5397 assert_eq!(
5398 Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress.metric(&c),
5399 u128::MAX
5400 );
5401 }
5402
5403 #[test]
5404 fn g8r_post_objectives_use_postprocessed_cost_fields() {
5405 let c = Cost {
5406 pir_nodes: 1,
5407 g8r_nodes: 2,
5408 g8r_depth: 3,
5409 g8r_le_graph_milli: 5,
5410 g8r_gate_output_toggles: 7,
5411 g8r_weighted_switching_milli: 11,
5412 g8r_post_and_nodes: 13,
5413 g8r_post_depth: 17,
5414 g8r_post_le_graph_milli: 19,
5415 g8r_post_gate_output_toggles: 23,
5416 g8r_post_weighted_switching_milli: 29,
5417 };
5418 assert_eq!(Objective::G8rPostAndNodes.metric(&c), 13);
5419 assert_eq!(Objective::G8rPostAndNodesTimesDepth.metric(&c), 221);
5420 assert_eq!(
5421 Objective::G8rPostAndNodesTimesDepthTimesToggles.metric(&c),
5422 5083
5423 );
5424 assert_eq!(Objective::G8rPostLeGraph.metric(&c), 19);
5425 assert_eq!(Objective::G8rPostLeGraphTimesAndNodes.metric(&c), 247);
5426 assert_eq!(Objective::G8rPostLeGraphTimesProduct.metric(&c), 4199);
5427 assert_eq!(Objective::G8rPostWeightedSwitching.metric(&c), 29);
5428 assert_eq!(
5429 Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress.metric(&c),
5430 377
5431 );
5432 for objective in [
5433 Objective::G8rPostAndNodes,
5434 Objective::G8rPostAndNodesTimesDepth,
5435 Objective::G8rPostAndNodesTimesDepthTimesToggles,
5436 Objective::G8rPostLeGraph,
5437 Objective::G8rPostLeGraphTimesAndNodes,
5438 Objective::G8rPostLeGraphTimesProduct,
5439 Objective::G8rPostWeightedSwitching,
5440 Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress,
5441 ] {
5442 assert_eq!(
5443 Objective::from_value_name(objective.value_name()).unwrap(),
5444 objective
5445 );
5446 }
5447 }
5448
5449 #[test]
5450 fn search_score_prefers_feasible_then_lower_violation_then_objective() {
5451 let feasible = search_score(
5452 &Cost {
5453 pir_nodes: 0,
5454 g8r_nodes: 10,
5455 g8r_depth: 10,
5456 g8r_le_graph_milli: 0,
5457 g8r_gate_output_toggles: 0,
5458 g8r_weighted_switching_milli: 0,
5459 g8r_post_and_nodes: 0,
5460 g8r_post_depth: 0,
5461 g8r_post_le_graph_milli: 0,
5462 g8r_post_gate_output_toggles: 0,
5463 g8r_post_weighted_switching_milli: 0,
5464 },
5465 Objective::G8rNodes,
5466 ConstraintLimits {
5467 max_delay: Some(12),
5468 max_area: None,
5469 },
5470 );
5471 let mildly_infeasible = search_score(
5472 &Cost {
5473 pir_nodes: 0,
5474 g8r_nodes: 10,
5475 g8r_depth: 13,
5476 g8r_le_graph_milli: 0,
5477 g8r_gate_output_toggles: 0,
5478 g8r_weighted_switching_milli: 0,
5479 g8r_post_and_nodes: 0,
5480 g8r_post_depth: 0,
5481 g8r_post_le_graph_milli: 0,
5482 g8r_post_gate_output_toggles: 0,
5483 g8r_post_weighted_switching_milli: 0,
5484 },
5485 Objective::G8rNodes,
5486 ConstraintLimits {
5487 max_delay: Some(12),
5488 max_area: None,
5489 },
5490 );
5491 let badly_infeasible = search_score(
5492 &Cost {
5493 pir_nodes: 0,
5494 g8r_depth: 16,
5495 g8r_nodes: 10,
5496 g8r_le_graph_milli: 0,
5497 g8r_gate_output_toggles: 0,
5498 g8r_weighted_switching_milli: 0,
5499 g8r_post_and_nodes: 0,
5500 g8r_post_depth: 0,
5501 g8r_post_le_graph_milli: 0,
5502 g8r_post_gate_output_toggles: 0,
5503 g8r_post_weighted_switching_milli: 0,
5504 },
5505 Objective::G8rNodes,
5506 ConstraintLimits {
5507 max_delay: Some(12),
5508 max_area: None,
5509 },
5510 );
5511
5512 assert!(feasible < mildly_infeasible);
5513 assert!(mildly_infeasible < badly_infeasible);
5514 }
5515
5516 #[test]
5517 fn effective_constraint_limits_respects_non_regressing_depth() {
5518 let initial = Cost {
5519 pir_nodes: 0,
5520 g8r_nodes: 10,
5521 g8r_depth: 17,
5522 g8r_le_graph_milli: 0,
5523 g8r_gate_output_toggles: 0,
5524 g8r_weighted_switching_milli: 0,
5525 g8r_post_and_nodes: 0,
5526 g8r_post_depth: 0,
5527 g8r_post_le_graph_milli: 0,
5528 g8r_post_gate_output_toggles: 0,
5529 g8r_post_weighted_switching_milli: 0,
5530 };
5531 let got = effective_constraint_limits(
5532 Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress,
5533 ConstraintLimits {
5534 max_delay: Some(20),
5535 max_area: None,
5536 },
5537 &initial,
5538 );
5539 assert_eq!(got.max_delay, Some(17));
5540 assert_eq!(got.max_area, None);
5541 }
5542
5543 #[test]
5544 fn postprocessed_constraints_use_post_area_and_depth() {
5545 let c = Cost {
5546 pir_nodes: 0,
5547 g8r_nodes: 1,
5548 g8r_depth: 1,
5549 g8r_le_graph_milli: 0,
5550 g8r_gate_output_toggles: 0,
5551 g8r_weighted_switching_milli: 0,
5552 g8r_post_and_nodes: 11,
5553 g8r_post_depth: 17,
5554 g8r_post_le_graph_milli: 0,
5555 g8r_post_gate_output_toggles: 0,
5556 g8r_post_weighted_switching_milli: 0,
5557 };
5558 let score = search_score(
5559 &c,
5560 Objective::G8rPostAndNodes,
5561 ConstraintLimits {
5562 max_delay: Some(12),
5563 max_area: None,
5564 },
5565 );
5566 assert_eq!(
5567 score.violation,
5568 Some(ConstraintViolationScore {
5569 delay_over: Some(5),
5570 area_over: None,
5571 })
5572 );
5573 let got = effective_constraint_limits(
5574 Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress,
5575 ConstraintLimits {
5576 max_delay: Some(20),
5577 max_area: None,
5578 },
5579 &c,
5580 );
5581 assert_eq!(got.max_delay, Some(17));
5582 }
5583
5584 #[test]
5585 fn g8r_post_costing_requires_external_postprocessor() {
5586 let f = parse_fn(
5587 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5588 ret and.3: bits[1] = and(a, b, id=3)
5589}"#,
5590 );
5591 let err = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5592 &f,
5593 Objective::G8rPostAndNodes,
5594 None,
5595 &WeightedSwitchingOptions::default(),
5596 ExtensionCostingMode::Preserve,
5597 &G8rEvaluationMode::Builtin,
5598 )
5599 .unwrap_err();
5600 assert!(
5601 err.to_string()
5602 .contains("requires an external g8r postprocessor"),
5603 "unexpected error: {err}"
5604 );
5605 }
5606
5607 #[test]
5608 fn external_postprocessor_identity_populates_post_stats() {
5609 let temp = tempdir().unwrap();
5610 let hook =
5611 write_executable_script(temp.path(), "identity.sh", "#!/bin/sh\ncp \"$1\" \"$3\"\n");
5612 let f = parse_fn(
5613 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5614 ret and.3: bits[1] = and(a, b, id=3)
5615}"#,
5616 );
5617 let c = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5618 &f,
5619 Objective::G8rPostAndNodes,
5620 None,
5621 &WeightedSwitchingOptions::default(),
5622 ExtensionCostingMode::Preserve,
5623 &G8rEvaluationMode::ExternalPostprocess {
5624 program: hook.display().to_string(),
5625 },
5626 )
5627 .unwrap();
5628 assert!(c.g8r_post_and_nodes > 0);
5629 assert!(c.g8r_post_depth > 0);
5630 }
5631
5632 #[test]
5633 fn external_postprocessor_repacked_interface_supports_toggle_metrics() {
5634 let temp = tempdir().unwrap();
5635 let hook =
5636 write_executable_script(temp.path(), "identity.sh", "#!/bin/sh\ncp \"$1\" \"$3\"\n");
5637 let f = parse_fn(
5638 r#"fn f(a: bits[2] id=1, b: bits[2] id=2) -> bits[2] {
5639 ret and.3: bits[2] = and(a, b, id=3)
5640}"#,
5641 );
5642 let samples = vec![
5643 IrValue::parse_typed("(bits[2]:0b00, bits[2]:0b11)").unwrap(),
5644 IrValue::parse_typed("(bits[2]:0b11, bits[2]:0b11)").unwrap(),
5645 IrValue::parse_typed("(bits[2]:0b01, bits[2]:0b11)").unwrap(),
5646 ];
5647 let toggle_stimulus = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5648 let post_mode = G8rEvaluationMode::ExternalPostprocess {
5649 program: hook.display().to_string(),
5650 };
5651 let raw_toggles = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5652 &f,
5653 Objective::G8rNodesTimesDepthTimesToggles,
5654 Some(&toggle_stimulus),
5655 &WeightedSwitchingOptions::default(),
5656 ExtensionCostingMode::Preserve,
5657 &G8rEvaluationMode::Builtin,
5658 )
5659 .unwrap();
5660 let post_toggles = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5661 &f,
5662 Objective::G8rPostAndNodesTimesDepthTimesToggles,
5663 Some(&toggle_stimulus),
5664 &WeightedSwitchingOptions::default(),
5665 ExtensionCostingMode::Preserve,
5666 &post_mode,
5667 )
5668 .unwrap();
5669 assert!(post_toggles.g8r_post_gate_output_toggles > 0);
5670 assert_eq!(
5671 post_toggles.g8r_post_gate_output_toggles,
5672 raw_toggles.g8r_gate_output_toggles
5673 );
5674
5675 let raw_weighted = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5676 &f,
5677 Objective::G8rWeightedSwitching,
5678 Some(&toggle_stimulus),
5679 &WeightedSwitchingOptions::default(),
5680 ExtensionCostingMode::Preserve,
5681 &G8rEvaluationMode::Builtin,
5682 )
5683 .unwrap();
5684 let post_weighted = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5685 &f,
5686 Objective::G8rPostWeightedSwitching,
5687 Some(&toggle_stimulus),
5688 &WeightedSwitchingOptions::default(),
5689 ExtensionCostingMode::Preserve,
5690 &post_mode,
5691 )
5692 .unwrap();
5693 assert!(post_weighted.g8r_post_weighted_switching_milli > 0);
5694 assert_eq!(
5695 post_weighted.g8r_post_weighted_switching_milli,
5696 raw_weighted.g8r_weighted_switching_milli
5697 );
5698 }
5699
5700 #[test]
5701 fn external_postprocessor_reports_failures() {
5702 let temp = tempdir().unwrap();
5703 let fail = write_executable_script(
5704 temp.path(),
5705 "fail.sh",
5706 "#!/bin/sh\necho broken >&2\nexit 7\n",
5707 );
5708 let missing = write_executable_script(temp.path(), "missing.sh", "#!/bin/sh\nexit 0\n");
5709 let malformed = write_executable_script(
5710 temp.path(),
5711 "malformed.sh",
5712 "#!/bin/sh\nprintf nope > \"$3\"\n",
5713 );
5714 let f = parse_fn(
5715 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5716 ret and.3: bits[1] = and(a, b, id=3)
5717}"#,
5718 );
5719 let err = |program: &Path| {
5720 cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5721 &f,
5722 Objective::G8rPostAndNodes,
5723 None,
5724 &WeightedSwitchingOptions::default(),
5725 ExtensionCostingMode::Preserve,
5726 &G8rEvaluationMode::ExternalPostprocess {
5727 program: program.display().to_string(),
5728 },
5729 )
5730 .unwrap_err()
5731 .to_string()
5732 };
5733 assert!(err(&fail).contains("broken"));
5734 assert!(err(&missing).contains("did not create"));
5735 assert!(err(&malformed).contains("failed to load"));
5736 }
5737
5738 #[test]
5739 fn parse_irvals_tuple_lines_accepts_valid_tuples() {
5740 let text = "(bits[1]:0, bits[1]:1)\n(bits[1]:1, bits[1]:1)\n";
5741 let got = parse_irvals_tuple_lines(text).unwrap();
5742 assert_eq!(got.len(), 2);
5743 }
5744
5745 #[test]
5746 fn parse_irvals_tuple_lines_rejects_invalid_or_non_tuple_lines() {
5747 let bad_parse = parse_irvals_tuple_lines("not_a_value\n").unwrap_err();
5748 assert!(
5749 bad_parse.to_string().contains("line 1"),
5750 "expected line number in parse error"
5751 );
5752
5753 let non_tuple = parse_irvals_tuple_lines("bits[1]:1\n").unwrap_err();
5754 assert!(
5755 non_tuple.to_string().contains("not a tuple"),
5756 "expected tuple-specific error"
5757 );
5758 }
5759
5760 #[test]
5761 fn lower_toggle_stimulus_rejects_arity_and_type_mismatch() {
5762 let mut parser = ir_parser::Parser::new(
5763 r#"fn f(a: bits[1] id=1, b: bits[2] id=2) -> bits[1] {
5764 ret identity.3: bits[1] = identity(a, id=3)
5765}"#,
5766 );
5767 let f = parser.parse_fn().unwrap();
5768
5769 let arity_bad = vec![
5770 IrValue::parse_typed("(bits[1]:0, bits[2]:0, bits[1]:0)").unwrap(),
5771 IrValue::parse_typed("(bits[1]:1, bits[2]:1, bits[1]:1)").unwrap(),
5772 ];
5773 assert!(lower_toggle_stimulus_for_fn(&arity_bad, &f).is_err());
5774
5775 let type_bad = vec![
5776 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5777 IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5778 ];
5779 assert!(lower_toggle_stimulus_for_fn(&type_bad, &f).is_err());
5780 }
5781
5782 #[test]
5783 fn run_pir_mcmc_rejects_invalid_toggle_stimulus_usage() {
5784 let mut parser = ir_parser::Parser::new(
5785 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5786 ret and.3: bits[1] = and(a, b, id=3)
5787}"#,
5788 );
5789 let f = parser.parse_fn().unwrap();
5790
5791 let opts_missing = RunOptions {
5792 max_iters: 1,
5793 threads: 1,
5794 chain_strategy: ChainStrategy::Independent,
5795 checkpoint_iters: 1,
5796 progress_iters: 0,
5797 seed: 1,
5798 initial_temperature: 1.0,
5799 objective: Objective::G8rNodesTimesDepthTimesToggles,
5800 extension_costing_mode: ExtensionCostingMode::Preserve,
5801 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5802 canonical_g8r_options: CanonicalG8rOptions::default(),
5803 max_allowed_depth: None,
5804 max_allowed_area: None,
5805 weighted_switching_options: WeightedSwitchingOptions::default(),
5806 enable_formal_oracle: false,
5807 trajectory_dir: None,
5808 toggle_stimulus: None,
5809 };
5810 assert!(run_pir_mcmc(f.clone(), opts_missing).is_err());
5811
5812 let opts_wrong_objective = RunOptions {
5813 max_iters: 1,
5814 threads: 1,
5815 chain_strategy: ChainStrategy::Independent,
5816 checkpoint_iters: 1,
5817 progress_iters: 0,
5818 seed: 1,
5819 initial_temperature: 1.0,
5820 objective: Objective::Nodes,
5821 extension_costing_mode: ExtensionCostingMode::Preserve,
5822 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5823 canonical_g8r_options: CanonicalG8rOptions::default(),
5824 max_allowed_depth: None,
5825 max_allowed_area: None,
5826 weighted_switching_options: WeightedSwitchingOptions::default(),
5827 enable_formal_oracle: false,
5828 trajectory_dir: None,
5829 toggle_stimulus: Some(vec![
5830 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5831 IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5832 ]),
5833 };
5834 assert!(run_pir_mcmc(f, opts_wrong_objective).is_err());
5835 }
5836
5837 #[test]
5838 fn run_pir_mcmc_rejects_caps_with_nodes_objective() {
5839 let mut parser = ir_parser::Parser::new(
5840 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5841 ret and.3: bits[1] = and(a, b, id=3)
5842}"#,
5843 );
5844 let f = parser.parse_fn().unwrap();
5845
5846 let opts = RunOptions {
5847 max_iters: 1,
5848 threads: 1,
5849 chain_strategy: ChainStrategy::Independent,
5850 checkpoint_iters: 1,
5851 progress_iters: 0,
5852 seed: 1,
5853 initial_temperature: 1.0,
5854 objective: Objective::Nodes,
5855 extension_costing_mode: ExtensionCostingMode::Preserve,
5856 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5857 canonical_g8r_options: CanonicalG8rOptions::default(),
5858 max_allowed_depth: Some(10),
5859 max_allowed_area: None,
5860 weighted_switching_options: WeightedSwitchingOptions::default(),
5861 enable_formal_oracle: false,
5862 trajectory_dir: None,
5863 toggle_stimulus: None,
5864 };
5865 assert!(run_pir_mcmc(f, opts).is_err());
5866 }
5867
5868 #[test]
5869 fn run_pir_mcmc_rejects_dual_caps() {
5870 let mut parser = ir_parser::Parser::new(
5871 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5872 ret and.3: bits[1] = and(a, b, id=3)
5873}"#,
5874 );
5875 let f = parser.parse_fn().unwrap();
5876
5877 let opts = RunOptions {
5878 max_iters: 1,
5879 threads: 1,
5880 chain_strategy: ChainStrategy::Independent,
5881 checkpoint_iters: 1,
5882 progress_iters: 0,
5883 seed: 1,
5884 initial_temperature: 1.0,
5885 objective: Objective::G8rNodes,
5886 extension_costing_mode: ExtensionCostingMode::Preserve,
5887 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5888 canonical_g8r_options: CanonicalG8rOptions::default(),
5889 max_allowed_depth: Some(10),
5890 max_allowed_area: Some(10),
5891 weighted_switching_options: WeightedSwitchingOptions::default(),
5892 enable_formal_oracle: false,
5893 trajectory_dir: None,
5894 toggle_stimulus: None,
5895 };
5896 assert!(run_pir_mcmc(f, opts).is_err());
5897 }
5898
5899 #[test]
5900 fn run_pir_mcmc_rejects_area_cap_with_non_regressing_depth_objective() {
5901 let mut parser = ir_parser::Parser::new(
5902 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5903 ret and.3: bits[1] = and(a, b, id=3)
5904}"#,
5905 );
5906 let f = parser.parse_fn().unwrap();
5907
5908 let opts = RunOptions {
5909 max_iters: 1,
5910 threads: 1,
5911 chain_strategy: ChainStrategy::Independent,
5912 checkpoint_iters: 1,
5913 progress_iters: 0,
5914 seed: 1,
5915 initial_temperature: 1.0,
5916 objective: Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress,
5917 extension_costing_mode: ExtensionCostingMode::Preserve,
5918 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5919 canonical_g8r_options: CanonicalG8rOptions::default(),
5920 max_allowed_depth: None,
5921 max_allowed_area: Some(10),
5922 weighted_switching_options: WeightedSwitchingOptions::default(),
5923 enable_formal_oracle: false,
5924 trajectory_dir: None,
5925 toggle_stimulus: Some(vec![
5926 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5927 IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5928 ]),
5929 };
5930 assert!(run_pir_mcmc(f, opts).is_err());
5931 }
5932
5933 #[test]
5934 fn cost_with_toggle_objective_populates_toggle_count() {
5935 let mut parser = ir_parser::Parser::new(
5936 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5937 ret and.3: bits[1] = and(a, b, id=3)
5938}"#,
5939 );
5940 let f = parser.parse_fn().unwrap();
5941 let samples = vec![
5942 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5943 IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5944 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5945 ];
5946 let lowered = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5947 let c = cost_with_toggle_stimulus(
5948 &f,
5949 Objective::G8rNodesTimesDepthTimesToggles,
5950 Some(&lowered),
5951 )
5952 .unwrap();
5953 assert!(
5954 c.g8r_gate_output_toggles > 0,
5955 "expected positive interior toggle count"
5956 );
5957 }
5958
5959 #[test]
5960 fn cost_with_weighted_switching_objective_populates_weighted_metric() {
5961 let mut parser = ir_parser::Parser::new(
5962 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5963 ret and.3: bits[1] = and(a, b, id=3)
5964}"#,
5965 );
5966 let f = parser.parse_fn().unwrap();
5967 let samples = vec![
5968 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5969 IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5970 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5971 ];
5972 let lowered = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5973 let c = cost_with_effort_options_and_toggle_stimulus(
5974 &f,
5975 Objective::G8rWeightedSwitching,
5976 Some(&lowered),
5977 &WeightedSwitchingOptions::default(),
5978 )
5979 .unwrap();
5980 assert!(
5981 c.g8r_weighted_switching_milli > 0,
5982 "expected positive weighted switching estimate"
5983 );
5984 }
5985
5986 #[test]
5987 fn cost_with_weighted_switching_rejects_non_finite_options() {
5988 let mut parser = ir_parser::Parser::new(
5989 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5990 ret and.3: bits[1] = and(a, b, id=3)
5991}"#,
5992 );
5993 let f = parser.parse_fn().unwrap();
5994 let samples = vec![
5995 IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5996 IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5997 ];
5998 let lowered = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5999
6000 let err = cost_with_effort_options_and_toggle_stimulus(
6001 &f,
6002 Objective::G8rWeightedSwitching,
6003 Some(&lowered),
6004 &WeightedSwitchingOptions {
6005 beta1: f64::NAN,
6006 beta2: 0.0,
6007 primary_output_load: 1.0,
6008 },
6009 )
6010 .unwrap_err();
6011 assert!(
6012 err.to_string().contains("must be finite"),
6013 "expected finite-coefficient validation error, got: {}",
6014 err
6015 );
6016 }
6017
6018 #[test]
6019 fn trajectory_json_preserves_large_u128_metrics() {
6020 let rec = serde_json::json!({
6021 "metric": u128::MAX,
6022 "g8r_weighted_switching_milli": u128::MAX,
6023 });
6024 let s = serde_json::to_string(&rec).unwrap();
6025 assert!(
6026 s.contains(&u128::MAX.to_string()),
6027 "expected full u128 JSON number in serialized output"
6028 );
6029 }
6030
6031 #[test]
6032 fn trajectory_json_emits_transform_mechanism() {
6033 let mut parser = ir_parser::Parser::new(
6034 r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
6035 ret and.3: bits[1] = and(a, b, id=3)
6036}"#,
6037 );
6038 let f = parser.parse_fn().unwrap();
6039 let temp_dir = tempfile::tempdir().unwrap();
6040 let opts = RunOptions {
6041 max_iters: 1,
6042 threads: 1,
6043 chain_strategy: ChainStrategy::Independent,
6044 checkpoint_iters: 100,
6045 progress_iters: 0,
6046 seed: 1,
6047 initial_temperature: 1.0,
6048 objective: Objective::Nodes,
6049 extension_costing_mode: ExtensionCostingMode::Preserve,
6050 g8r_evaluation_mode: G8rEvaluationMode::Builtin,
6051 canonical_g8r_options: CanonicalG8rOptions::default(),
6052 max_allowed_depth: None,
6053 max_allowed_area: None,
6054 weighted_switching_options: WeightedSwitchingOptions::default(),
6055 enable_formal_oracle: false,
6056 trajectory_dir: Some(temp_dir.path().to_path_buf()),
6057 toggle_stimulus: None,
6058 };
6059
6060 let _ = run_pir_mcmc(f, opts).unwrap();
6061 let path = temp_dir.path().join("trajectory.c000.jsonl");
6062 let text = std::fs::read_to_string(path).unwrap();
6063 let first_line = text.lines().next().expect("trajectory line");
6064 let value: serde_json::Value = serde_json::from_str(first_line).unwrap();
6065 assert!(
6066 value.get("transform_mechanism").is_some(),
6067 "expected transform_mechanism in trajectory JSON: {first_line}"
6068 );
6069 }
6070}