Skip to main content

xlsynth_mcmc_pir/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! PIR-based MCMC optimization using the shared `xlsynth-mcmc` engine.
4//!
5//! This crate wires the XLSynth PIR IR (`xlsynth_pir::ir`) into the generic
6//! MCMC statistics and Metropolis helpers in `xlsynth-mcmc`.  It provides a
7//! small library API (`run_pir_mcmc`) that runs a single-chain MCMC over a
8//! single PIR function.
9
10use anyhow::Result;
11use rand::Rng;
12use rand::SeedableRng;
13use rand::distributions::Distribution;
14use rand::distributions::WeightedIndex;
15use rand::prelude::SliceRandom;
16use rand_pcg::Pcg64Mcg;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use std::collections::BTreeMap;
20use std::io::Write as IoWrite;
21use std::panic::{AssertUnwindSafe, catch_unwind};
22use std::path::{Path, PathBuf};
23use std::process::Command;
24use std::sync::Arc;
25use std::sync::Mutex;
26use std::sync::atomic::AtomicUsize;
27use std::sync::atomic::Ordering;
28use std::sync::mpsc::Sender;
29use std::time::Instant;
30
31use clap::ValueEnum;
32
33use xlsynth::IrBits;
34use xlsynth::IrPackage;
35use xlsynth::IrValue;
36use xlsynth_g8r::aig::get_summary_stats;
37use xlsynth_g8r::aig::get_summary_stats::AigStats;
38use xlsynth_g8r::aig::graph_logical_effort::GraphLogicalEffortOptions;
39use xlsynth_g8r::aig::graph_logical_effort::analyze_graph_logical_effort;
40use xlsynth_g8r::aig_serdes::emit_aiger_binary::emit_aiger_binary;
41use xlsynth_g8r::aig_serdes::gate2ir::{
42    GateFnInterfaceSchema, repack_gate_fn_interface_with_schema,
43};
44use xlsynth_g8r::aig_serdes::load_aiger_auto::load_aiger_auto_from_path;
45use xlsynth_g8r::aig_sim::count_toggles;
46use xlsynth_g8r::gate_builder::GateBuilderOptions;
47use xlsynth_g8r::process_ir_path::{
48    CanonicalG8rOptions, canonical_ir_text_to_g8r_lowering_artifacts,
49};
50use xlsynth_mcmc::MIN_TEMPERATURE_RATIO;
51use xlsynth_mcmc::McmcIterationOutput as SharedMcmcIterationOutput;
52use xlsynth_mcmc::McmcOptions as SharedMcmcOptions;
53use xlsynth_mcmc::McmcStats as SharedMcmcStats;
54use xlsynth_mcmc::metropolis_accept;
55use xlsynth_mcmc::multichain::{ChainRole, ChainStrategy, SegmentOutcome, SegmentRunParams};
56use xlsynth_mcmc::multichain::{SegmentRunner, run_multichain};
57use xlsynth_pir::desugar_extensions::{self, ExtensionEmitMode};
58use xlsynth_pir::ir::FileTable as PirFileTable;
59use xlsynth_pir::ir::Fn as IrFn;
60use xlsynth_pir::ir::Package as PirPackage;
61use xlsynth_pir::ir::PackageMember as PirPackageMember;
62use xlsynth_pir::ir::Param as PirParam;
63use xlsynth_pir::ir::Type as PirType;
64use xlsynth_pir::ir_eval::{FnEvalResult, eval_fn_assuming_node_index_topological};
65use xlsynth_pir::ir_parser;
66use xlsynth_pir::ir_utils::compact_and_toposort_in_place;
67use xlsynth_pir::ir_value_utils::flatten_ir_value_to_lsb0_bits_for_type;
68use xlsynth_pir::random_inputs::generate_biased_irbits_with_rng;
69use xlsynth_pir::structural_similarity::collect_structural_entries;
70
71pub mod driver_cli;
72pub mod transforms;
73
74use crate::transforms::{
75    PirTransform, PirTransformKind, build_transform_weights, get_all_pir_transforms,
76};
77
/// Default number of random input samples, presumably drawn by the
/// equivalence oracle when no explicit count is configured (its users are not
/// visible in this part of the file — confirm before relying on this).
const DEFAULT_ORACLE_RANDOM_SAMPLES: usize = 32;
79
80pub fn parse_irvals_tuple_lines(irvals_text: &str) -> Result<Vec<IrValue>> {
81    xlsynth_pir::irvals::parse_irvals_tuple_lines(irvals_text).map_err(anyhow::Error::msg)
82}
83
84pub fn parse_irvals_tuple_file(path: &Path) -> Result<Vec<IrValue>> {
85    xlsynth_pir::irvals::parse_irvals_tuple_file(path).map_err(anyhow::Error::msg)
86}
87
// Invalid-IR candidates (bit_slice bounds violations in particular) usually
// point at a buggy transform, so we want them to be visible. They can also be
// frequent during exploration and bury more important warnings, so the policy
// is: warn for the first few occurrences, then downgrade.

/// Number of warnings to emit before downgrading further reports.
const INVALID_BIT_SLICE_WARN_LIMIT: usize = 8;

/// Process-wide count of invalid bit_slice candidates reported so far
/// (presumably incremented by the reporting code outside this view).
static INVALID_BIT_SLICE_WARN_COUNT: AtomicUsize = AtomicUsize::new(0_usize);
94
95use xlsynth_prover::prover::prove_ir_fn_equiv;
96use xlsynth_prover::prover::types::EquivResult;
97
98/// Simple cost model for PIR MCMC.
99#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
100pub struct Cost {
101    /// Number of PIR nodes in the function.
102    pub pir_nodes: usize,
103    /// Number of gates in the corresponding `GateFn` after running the XLS
104    /// optimizer and gatifying.
105    pub g8r_nodes: usize,
106    /// Depth of the corresponding `GateFn` (deepest path) after running the XLS
107    /// optimizer and gatifying.
108    pub g8r_depth: usize,
109    /// Graph logical-effort worst-case delay, scaled by 1e3 and rounded.
110    ///
111    /// This is populated for graph-logical-effort objectives; otherwise it is
112    /// `0`.
113    pub g8r_le_graph_milli: usize,
114    /// Number of interior AIG gate output toggles across a provided input
115    /// stimulus sequence.
116    ///
117    /// This is populated when objective=`g8r-nodes-times-depth-times-toggles`;
118    /// otherwise it is `0`.
119    pub g8r_gate_output_toggles: usize,
120    /// Load-weighted switching activity (`alpha*C`) proxy, scaled by 1e3.
121    ///
122    /// This is populated for weighted-switching objectives; otherwise it is
123    /// `0`.
124    pub g8r_weighted_switching_milli: u128,
125    /// Number of live AND nodes after running the configured external g8r
126    /// postprocessor over the gatified graph.
127    pub g8r_post_and_nodes: usize,
128    /// Maximum AND depth after running the configured external g8r
129    /// postprocessor over the gatified graph.
130    pub g8r_post_depth: usize,
131    /// Graph logical-effort worst-case delay after g8r postprocessing, scaled
132    /// by 1e3 and rounded.
133    pub g8r_post_le_graph_milli: usize,
134    /// Number of interior AIG gate-output toggles after g8r postprocessing.
135    pub g8r_post_gate_output_toggles: usize,
136    /// Load-weighted switching activity after g8r postprocessing, scaled by
137    /// 1e3.
138    pub g8r_post_weighted_switching_milli: u128,
139}
140
141/// How PIR MCMC obtains gate-level cost data.
142#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
143#[serde(tag = "kind", rename_all = "snake_case")]
144pub enum G8rEvaluationMode {
145    /// Use the in-process optimized-PIR-to-g8r path only.
146    Builtin,
147    /// Emit binary AIGER to an external postprocessor and score the returned
148    /// AIGER graph for `g8r-post-*` objectives.
149    ExternalPostprocess { program: String },
150}
151
152impl Default for G8rEvaluationMode {
153    fn default() -> Self {
154        Self::Builtin
155    }
156}
157
158impl G8rEvaluationMode {
159    pub(crate) fn external_postprocess_program(&self) -> Option<&str> {
160        match self {
161            G8rEvaluationMode::Builtin => None,
162            G8rEvaluationMode::ExternalPostprocess { program } => Some(program.as_str()),
163        }
164    }
165
166    /// Rewrites external postprocessor paths into durable absolute paths.
167    pub fn canonicalized_for_persistence(&self) -> Result<Self> {
168        match self {
169            G8rEvaluationMode::Builtin => Ok(Self::Builtin),
170            G8rEvaluationMode::ExternalPostprocess { program } => {
171                let path = std::fs::canonicalize(program).map_err(|e| {
172                    anyhow::anyhow!(
173                        "failed to canonicalize g8r postprocess program '{}': {}",
174                        program,
175                        e
176                    )
177                })?;
178                Ok(Self::ExternalPostprocess {
179                    program: path.display().to_string(),
180                })
181            }
182        }
183    }
184}
185
186/// How PIR extension ops are projected before XLS optimization and g8r costing.
187#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
188pub enum ExtensionCostingMode {
189    /// Preserve extension ops through XLS optimization using FFI wrappers, then
190    /// reconstruct extension ops when reparsing the optimized IR.
191    #[value(name = "preserve")]
192    Preserve,
193    /// Desugar extension ops to ordinary XLS IR before optimization, so costs
194    /// and best artifacts are grounded in standard non-extension IR.
195    #[value(name = "desugar")]
196    Desugar,
197}
198
199impl ExtensionCostingMode {
200    pub fn value_name(self) -> &'static str {
201        match self {
202            ExtensionCostingMode::Preserve => "preserve",
203            ExtensionCostingMode::Desugar => "desugar",
204        }
205    }
206
207    fn from_value_name(value: &str) -> Result<Self> {
208        match value {
209            "preserve" => Ok(ExtensionCostingMode::Preserve),
210            "desugar" => Ok(ExtensionCostingMode::Desugar),
211            _ => Err(anyhow::anyhow!(
212                "unknown extension costing mode in artifact: {}",
213                value
214            )),
215        }
216    }
217
218    fn extension_emit_mode(self) -> ExtensionEmitMode {
219        match self {
220            ExtensionCostingMode::Preserve => ExtensionEmitMode::AsFfiFunction,
221            ExtensionCostingMode::Desugar => ExtensionEmitMode::Desugared,
222        }
223    }
224}
225
/// Optional hard caps on gate-level cost components for a PIR MCMC run.
///
/// `None` disables the corresponding cap; the `Default` value disables both.
/// At most one cap may be active in a run (`validate_constraint_configuration`
/// rejects configurations that set both).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ConstraintLimits {
    /// Cap on the objective's depth measure (used as a delay proxy).
    pub max_delay: Option<usize>,
    /// Cap on the objective's area measure (gate / AND-node count).
    pub max_area: Option<usize>,
}
234
/// Overage details for a candidate that breaks an active hard cap.
///
/// A run enables at most one cap, so at most one of these fields is `Some`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConstraintViolationScore {
    pub delay_over: Option<usize>,
    pub area_over: Option<usize>,
}

/// Feasibility-first score used to rank best-so-far states under optional caps.
///
/// `violation == None` marks a feasible candidate. Candidates are ranked as
/// follows:
/// - Feasible candidates always rank below (better than) infeasible ones.
/// - Feasible candidates are ordered by `objective`.
/// - Infeasible candidates are ordered by delay overage, then area overage,
///   with `objective` only as the final tiebreak.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SearchScore {
    pub objective: u128,
    pub violation: Option<ConstraintViolationScore>,
}

impl SearchScore {
    /// True when no active hard cap is violated.
    pub fn feasible(self) -> bool {
        self.violation.is_none()
    }
}

impl Ord for SearchScore {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (self.violation, other.violation) {
            // Both infeasible: lexicographic on (delay overage, area overage, objective).
            (Some(mine), Some(theirs)) => (mine.delay_over, mine.area_over, self.objective)
                .cmp(&(theirs.delay_over, theirs.area_over, other.objective)),
            // Exactly one infeasible: the feasible side wins.
            (Some(_), None) => Ordering::Greater,
            (None, Some(_)) => Ordering::Less,
            // Both feasible: only the objective matters.
            (None, None) => self.objective.cmp(&other.objective),
        }
    }
}

impl PartialOrd for SearchScore {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
284
285#[derive(Clone)]
286struct BestState {
287    score: SearchScore,
288    value: IrFn,
289}
290
291/// Shared best-so-far PIR function using structured feasibility-first scoring.
292pub struct Best {
293    inner: Mutex<BestState>,
294}
295
296impl Best {
297    pub fn new(initial_score: SearchScore, value: IrFn) -> Self {
298        Self {
299            inner: Mutex::new(BestState {
300                score: initial_score,
301                value,
302            }),
303        }
304    }
305
306    pub fn try_update(&self, new_score: SearchScore, new_value: IrFn) -> bool {
307        let mut guard = self.inner.lock().unwrap();
308        if new_score < guard.score {
309            *guard = BestState {
310                score: new_score,
311                value: new_value,
312            };
313            true
314        } else {
315            false
316        }
317    }
318
319    pub fn get(&self) -> IrFn {
320        self.inner.lock().unwrap().value.clone()
321    }
322
323    pub fn score(&self) -> SearchScore {
324        self.inner.lock().unwrap().score
325    }
326}
327
328/// Calculates the cost of a PIR function.
329///
330/// When the objective is g8r-based, this runs the XLS optimizer and gatify
331/// pipeline to obtain live gate count and depth. Failures are returned as an
332/// error (callers can choose to reject the candidate).
333pub fn cost(f: &IrFn, objective: Objective) -> Result<Cost> {
334    cost_with_effort_options_and_toggle_stimulus(
335        f,
336        objective,
337        None,
338        &count_toggles::WeightedSwitchingOptions::default(),
339    )
340}
341
342/// Calculates cost and, for toggle-based objectives, evaluates the candidate on
343/// a fixed gate-level toggle stimulus.
344pub fn cost_with_toggle_stimulus(
345    f: &IrFn,
346    objective: Objective,
347    toggle_stimulus: Option<&[Vec<IrBits>]>,
348) -> Result<Cost> {
349    cost_with_effort_options_and_toggle_stimulus(
350        f,
351        objective,
352        toggle_stimulus,
353        &count_toggles::WeightedSwitchingOptions::default(),
354    )
355}
356
357/// Calculates cost with explicit load-weighting options.
358pub fn cost_with_effort_options_and_toggle_stimulus(
359    f: &IrFn,
360    objective: Objective,
361    toggle_stimulus: Option<&[Vec<IrBits>]>,
362    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
363) -> Result<Cost> {
364    cost_with_effort_options_toggle_stimulus_and_extension_mode(
365        f,
366        objective,
367        toggle_stimulus,
368        weighted_switching_options,
369        ExtensionCostingMode::Preserve,
370    )
371}
372
373/// Calculates cost with explicit load-weighting and extension projection
374/// options.
375pub fn cost_with_effort_options_toggle_stimulus_and_extension_mode(
376    f: &IrFn,
377    objective: Objective,
378    toggle_stimulus: Option<&[Vec<IrBits>]>,
379    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
380    extension_costing_mode: ExtensionCostingMode,
381) -> Result<Cost> {
382    cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
383        f,
384        objective,
385        toggle_stimulus,
386        weighted_switching_options,
387        extension_costing_mode,
388        &G8rEvaluationMode::Builtin,
389    )
390}
391
392/// Calculates cost with explicit load-weighting, extension projection, and
393/// gate-level evaluator configuration.
394pub fn cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
395    f: &IrFn,
396    objective: Objective,
397    toggle_stimulus: Option<&[Vec<IrBits>]>,
398    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
399    extension_costing_mode: ExtensionCostingMode,
400    g8r_evaluation_mode: &G8rEvaluationMode,
401) -> Result<Cost> {
402    cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
403        f,
404        objective,
405        toggle_stimulus,
406        weighted_switching_options,
407        extension_costing_mode,
408        g8r_evaluation_mode,
409        &CanonicalG8rOptions::default(),
410    )
411}
412
413/// Calculates cost with explicit evaluator and canonical g8r lowering options.
414pub fn cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
415    f: &IrFn,
416    objective: Objective,
417    toggle_stimulus: Option<&[Vec<IrBits>]>,
418    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
419    extension_costing_mode: ExtensionCostingMode,
420    g8r_evaluation_mode: &G8rEvaluationMode,
421    canonical_g8r_options: &CanonicalG8rOptions,
422) -> Result<Cost> {
423    let pir_nodes = f.nodes.len();
424
425    if objective.needs_toggle_stimulus() && toggle_stimulus.is_none() {
426        return Err(anyhow::anyhow!(
427            "objective {} requires toggle stimulus",
428            objective.value_name()
429        ));
430    }
431    if !objective.needs_toggle_stimulus() && toggle_stimulus.is_some() {
432        return Err(anyhow::anyhow!(
433            "toggle stimulus provided but objective {} does not use toggles",
434            objective.value_name()
435        ));
436    }
437    if objective.needs_weighted_switching() {
438        validate_weighted_switching_options(weighted_switching_options)?;
439    }
440    if objective.uses_postprocessed_costing()
441        && g8r_evaluation_mode.external_postprocess_program().is_none()
442    {
443        return Err(anyhow::anyhow!(
444            "objective {} requires an external g8r postprocessor",
445            objective.value_name()
446        ));
447    }
448
449    let gate_stats = if objective.uses_gate_costing() {
450        compute_g8r_stats_for_pir_fn(
451            f,
452            objective.needs_graph_logical_effort(),
453            objective.needs_gate_output_toggles(),
454            objective.needs_weighted_switching(),
455            toggle_stimulus,
456            weighted_switching_options,
457            extension_costing_mode,
458            g8r_evaluation_mode,
459            canonical_g8r_options,
460            objective.uses_postprocessed_costing(),
461        )?
462    } else {
463        GateCostStats {
464            raw: RawG8rStats {
465                nodes: pir_nodes,
466                depth: pir_nodes,
467                le_graph_milli: 0,
468                gate_output_toggles: 0,
469                weighted_switching_milli: 0,
470            },
471            post: None,
472        }
473    };
474    let post = gate_stats.post.unwrap_or_default();
475
476    Ok(Cost {
477        pir_nodes,
478        g8r_nodes: gate_stats.raw.nodes,
479        g8r_depth: gate_stats.raw.depth,
480        g8r_le_graph_milli: gate_stats.raw.le_graph_milli,
481        g8r_gate_output_toggles: gate_stats.raw.gate_output_toggles,
482        g8r_weighted_switching_milli: gate_stats.raw.weighted_switching_milli,
483        g8r_post_and_nodes: post.and_nodes,
484        g8r_post_depth: post.depth,
485        g8r_post_le_graph_milli: post.le_graph_milli,
486        g8r_post_gate_output_toggles: post.gate_output_toggles,
487        g8r_post_weighted_switching_milli: post.weighted_switching_milli,
488    })
489}
490
491fn validate_weighted_switching_options(
492    options: &count_toggles::WeightedSwitchingOptions,
493) -> Result<()> {
494    let checks = [
495        ("beta1", options.beta1),
496        ("beta2", options.beta2),
497        ("primary_output_load", options.primary_output_load),
498    ];
499    for (name, value) in checks {
500        if !value.is_finite() {
501            return Err(anyhow::anyhow!(
502                "weighted switching option '{}' must be finite; got {}",
503                name,
504                value
505            ));
506        }
507    }
508    Ok(())
509}
510
511/// Optimizes a PIR package through XLS using the selected extension projection.
512pub(crate) fn optimize_pir_package_via_xls_with_extension_mode(
513    pkg: &PirPackage,
514    top: &str,
515    extension_costing_mode: ExtensionCostingMode,
516) -> Result<PirPackage> {
517    let wrapped_ir_text = desugar_extensions::emit_package_with_extension_mode(
518        pkg,
519        extension_costing_mode.extension_emit_mode(),
520    )
521    .map_err(|e| anyhow::anyhow!("emit_package_with_extension_mode failed: {}", e))?;
522
523    let ir_pkg = IrPackage::parse_ir(&wrapped_ir_text, None)
524        .map_err(|e| anyhow::anyhow!("IrPackage::parse_ir failed: {:?}", e))?;
525    let optimized_ir_pkg = xlsynth::optimize_ir(&ir_pkg, top)
526        .map_err(|e| anyhow::anyhow!("optimize_ir failed: {:?}", e))?;
527    let optimized_ir_text = optimized_ir_pkg.to_string();
528
529    let mut parser = ir_parser::Parser::new(&optimized_ir_text);
530    parser
531        .parse_and_validate_package()
532        .map_err(|e| anyhow::anyhow!("PIR parse_and_validate_package failed: {:?}", e))
533}
534
535/// Produces the XLS-optimized PIR function for `f` using the selected extension
536/// projection mode.
537pub(crate) fn optimize_pir_fn_via_xls_with_extension_mode(
538    f: &IrFn,
539    extension_costing_mode: ExtensionCostingMode,
540) -> Result<IrFn> {
541    // The pipeline assumes the IR is a DAG and that textual IR references only
542    // previously-defined names. MCMC exploration can transiently violate that;
543    // callers can choose how to handle errors.
544    let mut fn_for_text = f.clone();
545    compact_and_toposort_in_place(&mut fn_for_text)
546        .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
547
548    let mut pir_pkg = PirPackage {
549        name: "pir_mcmc".to_string(),
550        file_table: PirFileTable::new(),
551        members: vec![PirPackageMember::Function(fn_for_text)],
552        top: None,
553    };
554    pir_pkg
555        .set_top_fn(&f.name)
556        .map_err(|e| anyhow::anyhow!("set_top_fn failed: {}", e))?;
557
558    let pir_pkg = optimize_pir_package_via_xls_with_extension_mode(
559        &pir_pkg,
560        &f.name,
561        extension_costing_mode,
562    )?;
563    let top_fn = pir_pkg
564        .get_top_fn()
565        .ok_or_else(|| anyhow::anyhow!("No top function found in optimized PIR package"))?;
566    Ok(top_fn.clone())
567}
568
569/// Objective used to evaluate cost improvements for PIR MCMC.
570#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
571pub enum Objective {
572    Nodes,
573    G8rNodes,
574    G8rNodesTimesDepth,
575    #[value(name = "g8r-nodes-times-depth-times-toggles")]
576    G8rNodesTimesDepthTimesToggles,
577    #[value(
578        name = "g8r-le-graph",
579        alias = "g8r-graph-le",
580        alias = "g8r-graph-logical-effort"
581    )]
582    G8rLeGraph,
583    #[value(name = "g8r-le-graph-times-nodes")]
584    G8rLeGraphTimesNodes,
585    #[value(name = "g8r-le-graph-times-product")]
586    G8rLeGraphTimesProduct,
587    #[value(name = "g8r-weighted-switching")]
588    G8rWeightedSwitching,
589    #[value(name = "g8r-nodes-times-weighted-switching-no-depth-regress")]
590    G8rNodesTimesWeightedSwitchingNoDepthRegress,
591    #[value(name = "g8r-post-and-nodes")]
592    G8rPostAndNodes,
593    #[value(name = "g8r-post-and-nodes-times-depth")]
594    G8rPostAndNodesTimesDepth,
595    #[value(name = "g8r-post-and-nodes-times-depth-times-toggles")]
596    G8rPostAndNodesTimesDepthTimesToggles,
597    #[value(name = "g8r-post-le-graph")]
598    G8rPostLeGraph,
599    #[value(name = "g8r-post-le-graph-times-and-nodes")]
600    G8rPostLeGraphTimesAndNodes,
601    #[value(name = "g8r-post-le-graph-times-product")]
602    G8rPostLeGraphTimesProduct,
603    #[value(name = "g8r-post-weighted-switching")]
604    G8rPostWeightedSwitching,
605    #[value(name = "g8r-post-and-nodes-times-weighted-switching-no-depth-regress")]
606    G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress,
607}
608
609impl Objective {
610    pub fn uses_g8r_costing(self) -> bool {
611        matches!(
612            self,
613            Objective::G8rNodes
614                | Objective::G8rNodesTimesDepth
615                | Objective::G8rNodesTimesDepthTimesToggles
616                | Objective::G8rLeGraph
617                | Objective::G8rLeGraphTimesNodes
618                | Objective::G8rLeGraphTimesProduct
619                | Objective::G8rWeightedSwitching
620                | Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
621        )
622    }
623
624    pub fn uses_postprocessed_costing(self) -> bool {
625        matches!(
626            self,
627            Objective::G8rPostAndNodes
628                | Objective::G8rPostAndNodesTimesDepth
629                | Objective::G8rPostAndNodesTimesDepthTimesToggles
630                | Objective::G8rPostLeGraph
631                | Objective::G8rPostLeGraphTimesAndNodes
632                | Objective::G8rPostLeGraphTimesProduct
633                | Objective::G8rPostWeightedSwitching
634                | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
635        )
636    }
637
638    pub fn uses_gate_costing(self) -> bool {
639        self.uses_g8r_costing() || self.uses_postprocessed_costing()
640    }
641
642    pub fn needs_graph_logical_effort(self) -> bool {
643        matches!(
644            self,
645            Objective::G8rLeGraph
646                | Objective::G8rLeGraphTimesNodes
647                | Objective::G8rLeGraphTimesProduct
648                | Objective::G8rPostLeGraph
649                | Objective::G8rPostLeGraphTimesAndNodes
650                | Objective::G8rPostLeGraphTimesProduct
651        )
652    }
653
654    pub fn needs_toggle_stimulus(self) -> bool {
655        self.needs_gate_output_toggles() || self.needs_weighted_switching()
656    }
657
658    pub fn needs_gate_output_toggles(self) -> bool {
659        matches!(
660            self,
661            Objective::G8rNodesTimesDepthTimesToggles
662                | Objective::G8rPostAndNodesTimesDepthTimesToggles
663        )
664    }
665
666    pub fn needs_weighted_switching(self) -> bool {
667        matches!(
668            self,
669            Objective::G8rWeightedSwitching
670                | Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
671                | Objective::G8rPostWeightedSwitching
672                | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
673        )
674    }
675
676    pub fn enforces_non_regressing_depth(self) -> bool {
677        matches!(
678            self,
679            Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
680                | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
681        )
682    }
683
684    pub fn value_name(self) -> &'static str {
685        match self {
686            Objective::Nodes => "nodes",
687            Objective::G8rNodes => "g8r-nodes",
688            Objective::G8rNodesTimesDepth => "g8r-nodes-times-depth",
689            Objective::G8rNodesTimesDepthTimesToggles => "g8r-nodes-times-depth-times-toggles",
690            Objective::G8rLeGraph => "g8r-le-graph",
691            Objective::G8rLeGraphTimesNodes => "g8r-le-graph-times-nodes",
692            Objective::G8rLeGraphTimesProduct => "g8r-le-graph-times-product",
693            Objective::G8rWeightedSwitching => "g8r-weighted-switching",
694            Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress => {
695                "g8r-nodes-times-weighted-switching-no-depth-regress"
696            }
697            Objective::G8rPostAndNodes => "g8r-post-and-nodes",
698            Objective::G8rPostAndNodesTimesDepth => "g8r-post-and-nodes-times-depth",
699            Objective::G8rPostAndNodesTimesDepthTimesToggles => {
700                "g8r-post-and-nodes-times-depth-times-toggles"
701            }
702            Objective::G8rPostLeGraph => "g8r-post-le-graph",
703            Objective::G8rPostLeGraphTimesAndNodes => "g8r-post-le-graph-times-and-nodes",
704            Objective::G8rPostLeGraphTimesProduct => "g8r-post-le-graph-times-product",
705            Objective::G8rPostWeightedSwitching => "g8r-post-weighted-switching",
706            Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress => {
707                "g8r-post-and-nodes-times-weighted-switching-no-depth-regress"
708            }
709        }
710    }
711
712    fn from_value_name(value: &str) -> Result<Self> {
713        match value {
714            "nodes" => Ok(Objective::Nodes),
715            "g8r-nodes" => Ok(Objective::G8rNodes),
716            "g8r-nodes-times-depth" => Ok(Objective::G8rNodesTimesDepth),
717            "g8r-nodes-times-depth-times-toggles" => Ok(Objective::G8rNodesTimesDepthTimesToggles),
718            "g8r-le-graph" => Ok(Objective::G8rLeGraph),
719            "g8r-le-graph-times-nodes" => Ok(Objective::G8rLeGraphTimesNodes),
720            "g8r-le-graph-times-product" => Ok(Objective::G8rLeGraphTimesProduct),
721            "g8r-weighted-switching" => Ok(Objective::G8rWeightedSwitching),
722            "g8r-nodes-times-weighted-switching-no-depth-regress" => {
723                Ok(Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress)
724            }
725            "g8r-post-and-nodes" => Ok(Objective::G8rPostAndNodes),
726            "g8r-post-and-nodes-times-depth" => Ok(Objective::G8rPostAndNodesTimesDepth),
727            "g8r-post-and-nodes-times-depth-times-toggles" => {
728                Ok(Objective::G8rPostAndNodesTimesDepthTimesToggles)
729            }
730            "g8r-post-le-graph" => Ok(Objective::G8rPostLeGraph),
731            "g8r-post-le-graph-times-and-nodes" => Ok(Objective::G8rPostLeGraphTimesAndNodes),
732            "g8r-post-le-graph-times-product" => Ok(Objective::G8rPostLeGraphTimesProduct),
733            "g8r-post-weighted-switching" => Ok(Objective::G8rPostWeightedSwitching),
734            "g8r-post-and-nodes-times-weighted-switching-no-depth-regress" => {
735                Ok(Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress)
736            }
737            _ => Err(anyhow::anyhow!("unknown objective in artifact: {}", value)),
738        }
739    }
740
741    pub fn metric(self, c: &Cost) -> u128 {
742        match self {
743            Objective::Nodes => c.pir_nodes as u128,
744            Objective::G8rNodes => c.g8r_nodes as u128,
745            Objective::G8rNodesTimesDepth => {
746                (c.g8r_nodes as u128).saturating_mul(c.g8r_depth as u128)
747            }
748            Objective::G8rNodesTimesDepthTimesToggles => (c.g8r_nodes as u128)
749                .saturating_mul(c.g8r_depth as u128)
750                .saturating_mul(c.g8r_gate_output_toggles as u128),
751            Objective::G8rLeGraph => c.g8r_le_graph_milli as u128,
752            Objective::G8rLeGraphTimesNodes => {
753                (c.g8r_le_graph_milli as u128).saturating_mul(c.g8r_nodes as u128)
754            }
755            Objective::G8rLeGraphTimesProduct => {
756                let product = (c.g8r_nodes as u128).saturating_mul(c.g8r_depth as u128);
757                (c.g8r_le_graph_milli as u128).saturating_mul(product)
758            }
759            Objective::G8rWeightedSwitching => c.g8r_weighted_switching_milli,
760            Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress => {
761                (c.g8r_nodes as u128).saturating_mul(c.g8r_weighted_switching_milli)
762            }
763            Objective::G8rPostAndNodes => c.g8r_post_and_nodes as u128,
764            Objective::G8rPostAndNodesTimesDepth => {
765                (c.g8r_post_and_nodes as u128).saturating_mul(c.g8r_post_depth as u128)
766            }
767            Objective::G8rPostAndNodesTimesDepthTimesToggles => (c.g8r_post_and_nodes as u128)
768                .saturating_mul(c.g8r_post_depth as u128)
769                .saturating_mul(c.g8r_post_gate_output_toggles as u128),
770            Objective::G8rPostLeGraph => c.g8r_post_le_graph_milli as u128,
771            Objective::G8rPostLeGraphTimesAndNodes => {
772                (c.g8r_post_le_graph_milli as u128).saturating_mul(c.g8r_post_and_nodes as u128)
773            }
774            Objective::G8rPostLeGraphTimesProduct => {
775                let product =
776                    (c.g8r_post_and_nodes as u128).saturating_mul(c.g8r_post_depth as u128);
777                (c.g8r_post_le_graph_milli as u128).saturating_mul(product)
778            }
779            Objective::G8rPostWeightedSwitching => c.g8r_post_weighted_switching_milli,
780            Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress => {
781                (c.g8r_post_and_nodes as u128).saturating_mul(c.g8r_post_weighted_switching_milli)
782            }
783        }
784    }
785
786    fn area_for_constraints(self, c: &Cost) -> usize {
787        if self.uses_postprocessed_costing() {
788            c.g8r_post_and_nodes
789        } else {
790            c.g8r_nodes
791        }
792    }
793
794    fn depth_for_constraints(self, c: &Cost) -> usize {
795        if self.uses_postprocessed_costing() {
796            c.g8r_post_depth
797        } else {
798            c.g8r_depth
799        }
800    }
801}
802
803pub(crate) fn validate_constraint_configuration(
804    objective: Objective,
805    limits: ConstraintLimits,
806) -> Result<()> {
807    if limits.max_delay.is_some() && limits.max_area.is_some() {
808        return Err(anyhow::anyhow!(
809            "at most one of --max-delay and --max-area may be specified"
810        ));
811    }
812    if !objective.uses_gate_costing() && (limits.max_delay.is_some() || limits.max_area.is_some()) {
813        return Err(anyhow::anyhow!(
814            "area/delay caps require a gate-based objective; got {}",
815            objective.value_name()
816        ));
817    }
818    if objective.enforces_non_regressing_depth() && limits.max_area.is_some() {
819        return Err(anyhow::anyhow!(
820            "--max-area is not compatible with objective {} because it already enforces a non-regressing depth cap",
821            objective.value_name()
822        ));
823    }
824    Ok(())
825}
826
827/// Computes the active constraint violation details for the given cost.
828pub fn constraint_violation(
829    c: &Cost,
830    objective: Objective,
831    limits: ConstraintLimits,
832) -> Option<ConstraintViolationScore> {
833    let delay_over = limits
834        .max_delay
835        .map(|max_delay| objective.depth_for_constraints(c).saturating_sub(max_delay))
836        .filter(|over| *over > 0);
837    let area_over = limits
838        .max_area
839        .map(|max_area| objective.area_for_constraints(c).saturating_sub(max_area))
840        .filter(|over| *over > 0);
841
842    if delay_over.is_none() && area_over.is_none() {
843        return None;
844    }
845
846    Some(ConstraintViolationScore {
847        delay_over,
848        area_over,
849    })
850}
851
852/// Computes the ordered search score for best-state tracking and multichain
853/// synchronization under optional area/delay caps.
854pub fn search_score(c: &Cost, objective: Objective, limits: ConstraintLimits) -> SearchScore {
855    SearchScore {
856        objective: objective.metric(c),
857        violation: constraint_violation(c, objective, limits),
858    }
859}
860
861pub(crate) fn effective_constraint_limits(
862    objective: Objective,
863    user_limits: ConstraintLimits,
864    initial_cost: &Cost,
865) -> ConstraintLimits {
866    let limits = ConstraintLimits {
867        max_delay: match (
868            user_limits.max_delay,
869            objective.enforces_non_regressing_depth(),
870        ) {
871            (Some(user_cap), true) => {
872                Some(user_cap.min(objective.depth_for_constraints(initial_cost)))
873            }
874            (Some(user_cap), false) => Some(user_cap),
875            (None, true) => Some(objective.depth_for_constraints(initial_cost)),
876            (None, false) => None,
877        },
878        max_area: user_limits.max_area,
879    };
880    debug_assert!(
881        limits.max_delay.is_none() || limits.max_area.is_none(),
882        "effective constraints must keep at most one active cap"
883    );
884    limits
885}
886
887fn repair_energy(v: ConstraintViolationScore) -> u128 {
888    match (v.delay_over, v.area_over) {
889        (Some(over), None) => over as u128,
890        (None, Some(over)) => over as u128,
891        (Some(_), Some(_)) => unreachable!("constraint configuration validation rejects dual caps"),
892        (None, None) => 0,
893    }
894}
895
896pub(crate) fn format_search_score(score: SearchScore) -> String {
897    match score.violation {
898        None => format!("feasible(obj={})", score.objective),
899        Some(v) => format!(
900            "infeasible(obj={}, delay_over={:?}, area_over={:?})",
901            score.objective, v.delay_over, v.area_over,
902        ),
903    }
904}
905
/// Gate-level statistics of the gate function produced by canonical g8r
/// lowering, before any external postprocessing.
///
/// Metrics that were not requested by the caller are left at 0.
#[derive(Clone, Copy, Debug)]
struct RawG8rStats {
    /// Live node count from the lowering summary stats (`live_nodes`).
    nodes: usize,
    /// Deepest path from the lowering summary stats (`deepest_path`).
    depth: usize,
    /// Graph logical-effort delay converted by `graph_le_delay_to_milli`
    /// (x1000, rounded); 0 when graph logical effort was not requested.
    le_graph_milli: usize,
    /// Gate-output toggle count over the toggle stimulus; 0 when not requested.
    gate_output_toggles: usize,
    /// Weighted switching as reported by `count_weighted_switching`; 0 when
    /// not requested.
    weighted_switching_milli: u128,
}
914
/// Gate-level statistics of the gate function returned by the external
/// postprocessor (after interface repacking).
///
/// Metrics that were not requested are left at 0 (also the `Default`).
#[derive(Clone, Copy, Debug, Default)]
struct G8rPostStats {
    /// AND-node count from `get_aig_stats` (`and_nodes`).
    and_nodes: usize,
    /// Maximum depth from `get_aig_stats` (`max_depth`).
    depth: usize,
    /// Graph logical-effort delay converted by `graph_le_delay_to_milli`.
    le_graph_milli: usize,
    /// Gate-output toggle count over the toggle stimulus.
    gate_output_toggles: usize,
    /// Weighted switching as reported by `count_weighted_switching`.
    weighted_switching_milli: u128,
}
923
/// Combined gate-cost statistics for one candidate function.
#[derive(Clone, Copy, Debug)]
struct GateCostStats {
    /// Statistics of the directly lowered gate function.
    raw: RawG8rStats,
    /// Statistics after external postprocessing; `None` when postprocessed
    /// stats were not requested.
    post: Option<G8rPostStats>,
}
929
/// Canonical IR package text plus selected top used for gate-level scoring.
pub struct CanonicalG8rScoringInput {
    /// XLS-optimized function that is the package's `top`.
    pub top_fn: IrFn,
    /// Full package text (`package pir_mcmc` with `top_fn` as its top) fed to
    /// canonical g8r lowering.
    pub ir_text: String,
}
935
936/// Materializes the exact IR package that MCMC gate-level scoring lowers.
937pub fn canonical_g8r_scoring_input_for_pir_fn(
938    f: &IrFn,
939    extension_costing_mode: ExtensionCostingMode,
940) -> Result<CanonicalG8rScoringInput> {
941    let top_fn = optimize_pir_fn_via_xls_with_extension_mode(f, extension_costing_mode)?;
942    let ir_text = format!("package pir_mcmc\n\ntop {}", top_fn);
943    Ok(CanonicalG8rScoringInput { top_fn, ir_text })
944}
945
946/// Returns the canonical lowering knobs used inside gate-cost scoring.
947fn canonical_g8r_scoring_lowering_options(
948    canonical_g8r_options: &CanonicalG8rOptions,
949) -> CanonicalG8rOptions {
950    let mut scoring_g8r_options = canonical_g8r_options.clone();
951    // Gate-cost scoring computes graph LE separately only when the objective
952    // needs it, so the summary-stats lowering pass should not duplicate it.
953    scoring_g8r_options.compute_graph_logical_effort = false;
954    scoring_g8r_options
955}
956
/// Postprocessed AIG payload plus summary stats suitable for durable artifacts.
pub(crate) struct PostprocessedAigArtifact {
    /// Raw bytes of the AIGER file written by the external postprocessor.
    pub bytes: Vec<u8>,
    /// Summary stats of the loaded (and interface-repacked) postprocessed
    /// gate function.
    pub stats: AigStats,
    /// Unscaled graph logical-effort delay of the postprocessed gate function;
    /// `None` when graph logical effort was not requested.
    pub graph_logical_effort_worst_case_delay: Option<f64>,
}
963
964/// Computes gate-level cost data for a PIR function by optimizing it, gatifying
965/// it, and optionally running the configured external postprocessor.
966fn compute_g8r_stats_for_pir_fn(
967    f: &IrFn,
968    compute_graph_logical_effort: bool,
969    compute_gate_output_toggles: bool,
970    compute_weighted_switching: bool,
971    toggle_stimulus: Option<&[Vec<IrBits>]>,
972    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
973    extension_costing_mode: ExtensionCostingMode,
974    g8r_evaluation_mode: &G8rEvaluationMode,
975    canonical_g8r_options: &CanonicalG8rOptions,
976    compute_postprocessed_stats: bool,
977) -> Result<GateCostStats> {
978    // The PIR → text → XLS → optimize → PIR → gatify pipeline assumes a DAG.
979    // Random rewiring transforms can (transiently) create cycles; if that happens
980    // we treat it as a candidate failure and fall back to PIR node count.
981    //
982    // Note: we intentionally catch panics here because some PIR utilities
983    // currently panic on cycle detection.
984    let result = catch_unwind(AssertUnwindSafe(|| {
985        compute_g8r_stats_for_pir_fn_impl(
986            f,
987            compute_graph_logical_effort,
988            compute_gate_output_toggles,
989            compute_weighted_switching,
990            toggle_stimulus,
991            weighted_switching_options,
992            extension_costing_mode,
993            g8r_evaluation_mode,
994            canonical_g8r_options,
995            compute_postprocessed_stats,
996        )
997    }));
998    match result {
999        Ok(r) => r,
1000        Err(_panic) => Err(anyhow::anyhow!(
1001            "panic during g8r-stats pipeline (likely a cycle)"
1002        )),
1003    }
1004}
1005
/// Converts a graph logical-effort delay into integer milli-units.
///
/// * Non-finite inputs (NaN, +infinity, -infinity) map to `usize::MAX`.
/// * Zero and negative delays clamp to 0.
/// * Otherwise the delay is multiplied by 1000 and rounded to the nearest
///   integer, saturating at `usize::MAX`.
fn graph_le_delay_to_milli(delay: f64) -> usize {
    match delay {
        d if d.is_nan() || d.is_infinite() => usize::MAX,
        d if d <= 0.0 => 0,
        d => {
            let milli = d * 1000.0;
            if milli < usize::MAX as f64 {
                milli.round() as usize
            } else {
                usize::MAX
            }
        }
    }
}
1020
1021fn compute_g8r_stats_for_pir_fn_impl(
1022    f: &IrFn,
1023    compute_graph_logical_effort: bool,
1024    compute_gate_output_toggles: bool,
1025    compute_weighted_switching: bool,
1026    toggle_stimulus: Option<&[Vec<IrBits>]>,
1027    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1028    extension_costing_mode: ExtensionCostingMode,
1029    g8r_evaluation_mode: &G8rEvaluationMode,
1030    canonical_g8r_options: &CanonicalG8rOptions,
1031    compute_postprocessed_stats: bool,
1032) -> Result<GateCostStats> {
1033    // 1-4) Materialize the exact package text used for canonical lowering.
1034    let scoring_input = canonical_g8r_scoring_input_for_pir_fn(f, extension_costing_mode)?;
1035    let top_fn = scoring_input.top_fn;
1036    let scoring_g8r_options = canonical_g8r_scoring_lowering_options(canonical_g8r_options);
1037    let artifacts = canonical_ir_text_to_g8r_lowering_artifacts(
1038        &scoring_input.ir_text,
1039        Some(&top_fn.name),
1040        &scoring_g8r_options,
1041    )
1042    .map_err(|e| anyhow::anyhow!("canonical g8r lowering failed: {}", e))?;
1043    let gate_fn = artifacts.gate_fn;
1044    let stats = artifacts.stats;
1045    let g8r_le_graph_milli = if compute_graph_logical_effort {
1046        let graph_le = analyze_graph_logical_effort(
1047            &gate_fn,
1048            &GraphLogicalEffortOptions {
1049                beta1: canonical_g8r_options.graph_logical_effort_beta1,
1050                beta2: canonical_g8r_options.graph_logical_effort_beta2,
1051            },
1052        );
1053        graph_le_delay_to_milli(graph_le.delay)
1054    } else {
1055        0
1056    };
1057    let (g8r_gate_output_toggles, g8r_weighted_switching_milli) = if compute_gate_output_toggles
1058        || compute_weighted_switching
1059    {
1060        let batch = toggle_stimulus.ok_or_else(|| {
1061            anyhow::anyhow!("toggle-based objective requires prepared toggle stimulus")
1062        })?;
1063        if batch.len() < 2 {
1064            return Err(anyhow::anyhow!(
1065                "toggle stimulus must contain at least two samples; got {}",
1066                batch.len()
1067            ));
1068        }
1069        let expected_input_count = gate_fn.inputs.len();
1070        for (sample_idx, sample) in batch.iter().enumerate() {
1071            if sample.len() != expected_input_count {
1072                return Err(anyhow::anyhow!(
1073                    "toggle sample {} has {} inputs, expected {}",
1074                    sample_idx + 1,
1075                    sample.len(),
1076                    expected_input_count
1077                ));
1078            }
1079            for (input_idx, (bits, gate_input)) in
1080                sample.iter().zip(gate_fn.inputs.iter()).enumerate()
1081            {
1082                let expected_width = gate_input.get_bit_count();
1083                if bits.get_bit_count() != expected_width {
1084                    return Err(anyhow::anyhow!(
1085                        "toggle sample {} input {} has width {}, expected {}",
1086                        sample_idx + 1,
1087                        input_idx,
1088                        bits.get_bit_count(),
1089                        expected_width
1090                    ));
1091                }
1092            }
1093        }
1094        let mut gate_output_toggles = 0usize;
1095        let mut weighted_switching_milli = 0u128;
1096        if compute_weighted_switching {
1097            let weighted_stats = count_toggles::count_weighted_switching(
1098                &gate_fn,
1099                batch,
1100                weighted_switching_options,
1101            );
1102            weighted_switching_milli = weighted_stats.weighted_switching_milli;
1103            if compute_gate_output_toggles {
1104                gate_output_toggles = weighted_stats.gate_output_toggles;
1105            }
1106        }
1107        if compute_gate_output_toggles && !compute_weighted_switching {
1108            gate_output_toggles = count_toggles::count_toggles(&gate_fn, batch).gate_output_toggles;
1109        }
1110        (gate_output_toggles, weighted_switching_milli)
1111    } else {
1112        (0, 0u128)
1113    };
1114    let post = if compute_postprocessed_stats {
1115        let schema = GateFnInterfaceSchema::from_pir_fn(&top_fn)
1116            .map_err(|e| anyhow::anyhow!("failed to derive gate interface schema: {}", e))?;
1117        let post_gate_fn =
1118            postprocess_gate_fn_with_external_program(&gate_fn, &schema, g8r_evaluation_mode)?
1119                .gate_fn;
1120        Some(compute_post_stats_for_gate_fn(
1121            &post_gate_fn,
1122            compute_graph_logical_effort,
1123            compute_gate_output_toggles,
1124            compute_weighted_switching,
1125            toggle_stimulus,
1126            weighted_switching_options,
1127            canonical_g8r_options,
1128        )?)
1129    } else {
1130        None
1131    };
1132
1133    Ok(GateCostStats {
1134        raw: RawG8rStats {
1135            nodes: stats.live_nodes,
1136            depth: stats.deepest_path,
1137            le_graph_milli: g8r_le_graph_milli,
1138            gate_output_toggles: g8r_gate_output_toggles,
1139            weighted_switching_milli: g8r_weighted_switching_milli,
1140        },
1141        post,
1142    })
1143}
1144
/// Output of one external-postprocessor round-trip.
struct LoadedPostprocessedGateFn {
    /// Postprocessed gate function, repacked to the original interface schema.
    gate_fn: xlsynth_g8r::aig::gate::GateFn,
    /// Exact bytes of the AIGER file the postprocessor wrote.
    output_bytes: Vec<u8>,
}
1149
1150/// Runs the configured external postprocessor and loads the returned AIGER as a
1151/// `GateFn` with the original interface shape restored.
1152fn postprocess_gate_fn_with_external_program(
1153    gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1154    schema: &GateFnInterfaceSchema,
1155    g8r_evaluation_mode: &G8rEvaluationMode,
1156) -> Result<LoadedPostprocessedGateFn> {
1157    let program = g8r_evaluation_mode
1158        .external_postprocess_program()
1159        .ok_or_else(|| {
1160            anyhow::anyhow!("g8r postprocessing requested without an external postprocessor")
1161        })?;
1162    let temp_dir = tempfile::Builder::new()
1163        .prefix("pir_mcmc_g8r_postprocess_")
1164        .tempdir()
1165        .map_err(|e| anyhow::anyhow!("failed to create g8r postprocess tempdir: {}", e))?;
1166    let input_path = temp_dir.path().join("input.aig");
1167    let output_path = temp_dir.path().join("output.aig");
1168    let input_bytes = emit_aiger_binary(gate_fn, true)
1169        .map_err(|e| anyhow::anyhow!("emit AIGER failed: {}", e))?;
1170    std::fs::write(&input_path, input_bytes).map_err(|e| {
1171        anyhow::anyhow!(
1172            "failed to write g8r postprocess input {}: {}",
1173            input_path.display(),
1174            e
1175        )
1176    })?;
1177
1178    let output = Command::new(program)
1179        .arg(&input_path)
1180        .arg("--output-path")
1181        .arg(&output_path)
1182        .output()
1183        .map_err(|e| anyhow::anyhow!("failed to run g8r postprocessor '{}': {}", program, e))?;
1184    if !output.status.success() {
1185        return Err(anyhow::anyhow!(
1186            "g8r postprocessor '{}' failed with status {}: {}",
1187            program,
1188            output.status,
1189            String::from_utf8_lossy(&output.stderr).trim()
1190        ));
1191    }
1192    if !output_path.exists() {
1193        return Err(anyhow::anyhow!(
1194            "g8r postprocessor '{}' did not create {}",
1195            program,
1196            output_path.display()
1197        ));
1198    }
1199    let output_bytes = std::fs::read(&output_path).map_err(|e| {
1200        anyhow::anyhow!(
1201            "failed to read g8r postprocess output {}: {}",
1202            output_path.display(),
1203            e
1204        )
1205    })?;
1206    let loaded =
1207        load_aiger_auto_from_path(&output_path, GateBuilderOptions::no_opt()).map_err(|e| {
1208            anyhow::anyhow!(
1209                "failed to load g8r postprocess output {}: {}",
1210                output_path.display(),
1211                e
1212            )
1213        })?;
1214    let gate_fn = repack_gate_fn_interface_with_schema(loaded.gate_fn, schema)
1215        .map_err(|e| anyhow::anyhow!("failed to repack postprocessed AIGER interface: {}", e))?;
1216    Ok(LoadedPostprocessedGateFn {
1217        gate_fn,
1218        output_bytes,
1219    })
1220}
1221
1222/// Runs the external postprocessor for a gate function and returns durable
1223/// bytes plus structural stats for artifact emission.
1224pub(crate) fn postprocess_gate_fn_for_artifact(
1225    gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1226    schema: &GateFnInterfaceSchema,
1227    g8r_evaluation_mode: &G8rEvaluationMode,
1228    canonical_g8r_options: &CanonicalG8rOptions,
1229    compute_graph_logical_effort: bool,
1230) -> Result<PostprocessedAigArtifact> {
1231    let loaded = postprocess_gate_fn_with_external_program(gate_fn, schema, g8r_evaluation_mode)?;
1232    let stats = get_summary_stats::get_aig_stats(&loaded.gate_fn);
1233    let graph_logical_effort_worst_case_delay = compute_graph_logical_effort.then(|| {
1234        analyze_graph_logical_effort(
1235            &loaded.gate_fn,
1236            &GraphLogicalEffortOptions {
1237                beta1: canonical_g8r_options.graph_logical_effort_beta1,
1238                beta2: canonical_g8r_options.graph_logical_effort_beta2,
1239            },
1240        )
1241        .delay
1242    });
1243    Ok(PostprocessedAigArtifact {
1244        bytes: loaded.output_bytes,
1245        stats,
1246        graph_logical_effort_worst_case_delay,
1247    })
1248}
1249
1250/// Computes postprocessed AIG stats from a loaded gate function.
1251fn compute_post_stats_for_gate_fn(
1252    gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1253    compute_graph_logical_effort: bool,
1254    compute_gate_output_toggles: bool,
1255    compute_weighted_switching: bool,
1256    toggle_stimulus: Option<&[Vec<IrBits>]>,
1257    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1258    canonical_g8r_options: &CanonicalG8rOptions,
1259) -> Result<G8rPostStats> {
1260    let stats = get_summary_stats::get_aig_stats(gate_fn);
1261    let le_graph_milli = if compute_graph_logical_effort {
1262        let graph_le = analyze_graph_logical_effort(
1263            gate_fn,
1264            &GraphLogicalEffortOptions {
1265                beta1: canonical_g8r_options.graph_logical_effort_beta1,
1266                beta2: canonical_g8r_options.graph_logical_effort_beta2,
1267            },
1268        );
1269        graph_le_delay_to_milli(graph_le.delay)
1270    } else {
1271        0
1272    };
1273    let (gate_output_toggles, weighted_switching_milli) = compute_toggle_stats_for_gate_fn(
1274        gate_fn,
1275        compute_gate_output_toggles,
1276        compute_weighted_switching,
1277        toggle_stimulus,
1278        weighted_switching_options,
1279    )?;
1280    Ok(G8rPostStats {
1281        and_nodes: stats.and_nodes,
1282        depth: stats.max_depth,
1283        le_graph_milli,
1284        gate_output_toggles,
1285        weighted_switching_milli,
1286    })
1287}
1288
1289fn compute_toggle_stats_for_gate_fn(
1290    gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1291    compute_gate_output_toggles: bool,
1292    compute_weighted_switching: bool,
1293    toggle_stimulus: Option<&[Vec<IrBits>]>,
1294    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
1295) -> Result<(usize, u128)> {
1296    if !compute_gate_output_toggles && !compute_weighted_switching {
1297        return Ok((0, 0));
1298    }
1299    let batch = validate_toggle_stimulus_for_gate_fn(gate_fn, toggle_stimulus)?;
1300    let mut gate_output_toggles = 0usize;
1301    let mut weighted_switching_milli = 0u128;
1302    if compute_weighted_switching {
1303        let weighted_stats =
1304            count_toggles::count_weighted_switching(gate_fn, batch, weighted_switching_options);
1305        weighted_switching_milli = weighted_stats.weighted_switching_milli;
1306        if compute_gate_output_toggles {
1307            gate_output_toggles = weighted_stats.gate_output_toggles;
1308        }
1309    }
1310    if compute_gate_output_toggles && !compute_weighted_switching {
1311        gate_output_toggles = count_toggles::count_toggles(gate_fn, batch).gate_output_toggles;
1312    }
1313    Ok((gate_output_toggles, weighted_switching_milli))
1314}
1315
1316fn validate_toggle_stimulus_for_gate_fn<'a>(
1317    gate_fn: &xlsynth_g8r::aig::gate::GateFn,
1318    toggle_stimulus: Option<&'a [Vec<IrBits>]>,
1319) -> Result<&'a [Vec<IrBits>]> {
1320    let batch = toggle_stimulus.ok_or_else(|| {
1321        anyhow::anyhow!("toggle-based objective requires prepared toggle stimulus")
1322    })?;
1323    if batch.len() < 2 {
1324        return Err(anyhow::anyhow!(
1325            "toggle stimulus must contain at least two samples; got {}",
1326            batch.len()
1327        ));
1328    }
1329    xlsynth_g8r::aig_sim::gate_simd::validate_ordered_batch_inputs(gate_fn, batch)
1330        .map_err(anyhow::Error::msg)?;
1331    Ok(batch)
1332}
1333
1334/// Validates tuple-valued stimulus samples against `f`'s parameter signature
1335/// and lowers each sample to GateFn input vectors (`Vec<IrBits>`).
1336pub fn lower_toggle_stimulus_for_fn(samples: &[IrValue], f: &IrFn) -> Result<Vec<Vec<IrBits>>> {
1337    if samples.len() < 2 {
1338        return Err(anyhow::anyhow!(
1339            "toggle stimulus must contain at least two samples; got {}",
1340            samples.len()
1341        ));
1342    }
1343
1344    let mut lowered: Vec<Vec<IrBits>> = Vec::with_capacity(samples.len());
1345    for (sample_idx, tuple_val) in samples.iter().enumerate() {
1346        let elems = tuple_val.get_elements().map_err(|e| {
1347            anyhow::anyhow!("sample {} is not a tuple value: {}", sample_idx + 1, e)
1348        })?;
1349        if elems.len() != f.params.len() {
1350            return Err(anyhow::anyhow!(
1351                "sample {} tuple arity mismatch: expected {}, got {}",
1352                sample_idx + 1,
1353                f.params.len(),
1354                elems.len()
1355            ));
1356        }
1357
1358        let mut sample_bits = Vec::with_capacity(f.params.len());
1359        for (param_idx, (elem, param)) in elems.iter().zip(f.params.iter()).enumerate() {
1360            let mut flat_bits: Vec<bool> = Vec::with_capacity(param.ty.bit_count());
1361            flatten_ir_value_to_lsb0_bits_for_type(elem, &param.ty, &mut flat_bits).map_err(
1362                |e| {
1363                    anyhow::anyhow!(
1364                        "sample {} param {} ('{}') incompatible with {}: {}",
1365                        sample_idx + 1,
1366                        param_idx,
1367                        param.name,
1368                        param.ty.to_string(),
1369                        e
1370                    )
1371                },
1372            )?;
1373            if flat_bits.len() != param.ty.bit_count() {
1374                return Err(anyhow::anyhow!(
1375                    "sample {} param {} ('{}') flattened width mismatch: expected {}, got {}",
1376                    sample_idx + 1,
1377                    param_idx,
1378                    param.name,
1379                    param.ty.bit_count(),
1380                    flat_bits.len()
1381                ));
1382            }
1383            sample_bits.push(IrBits::from_lsb_is_0(&flat_bits));
1384        }
1385        lowered.push(sample_bits);
1386    }
1387    Ok(lowered)
1388}
1389
// Type aliases specializing the generic MCMC helpers from `xlsynth-mcmc` to
// the PIR world.

/// MCMC run statistics keyed by PIR transform kind.
pub type McmcStats = SharedMcmcStats<PirTransformKind>;
/// Per-iteration outcome details keyed by PIR transform kind.
pub type IterationOutcomeDetails = xlsynth_mcmc::IterationOutcomeDetails<PirTransformKind>;
/// Per-iteration output over PIR functions with `Cost` as the cost type.
pub type McmcIterationOutput = SharedMcmcIterationOutput<IrFn, Cost, PirTransformKind>;
/// Shared MCMC options, re-exported unchanged.
pub type McmcOptions = SharedMcmcOptions;
1396
/// Cached oracle inputs and expected baseline `eval_fn` results.
///
/// MCMC oracle checks compare many candidate functions against the same
/// accepted baseline. This cache stores the deterministic/random sample
/// arguments and the baseline return value for each sample so candidates only
/// need to evaluate their rewritten graph.
///
/// The cache is keyed only on the baseline's parameter types and the
/// random-sample count (see `matches_signature`), not on the baseline's body.
#[derive(Default)]
pub struct EvalFnBaselineResults {
    /// Argument vectors: all-zeros first, then all-ones, then `random_samples`
    /// random argument vectors.
    samples: Vec<Vec<IrValue>>,
    /// `eval_fn_safe` result for each entry of `samples`, index-aligned.
    expected_values: Vec<Result<IrValue, ()>>,
    /// Number of random samples the cache was populated with.
    random_samples: usize,
    /// Parameter types of the function the cache was populated from.
    param_types: Vec<PirType>,
}
1410
1411impl EvalFnBaselineResults {
1412    fn clear(&mut self) {
1413        self.samples.clear();
1414        self.expected_values.clear();
1415        self.random_samples = 0;
1416        self.param_types.clear();
1417    }
1418
1419    fn matches_signature(&self, baseline: &IrFn, random_samples: usize) -> bool {
1420        self.random_samples == random_samples
1421            && self.param_types.len() == baseline.params.len()
1422            && self
1423                .param_types
1424                .iter()
1425                .zip(baseline.params.iter())
1426                .all(|(cached_ty, param)| cached_ty == &param.ty)
1427    }
1428
1429    fn populate_from_baseline<R: Rng>(
1430        &mut self,
1431        baseline: &IrFn,
1432        rng: &mut R,
1433        random_samples: usize,
1434    ) -> Result<()> {
1435        self.clear();
1436        self.random_samples = random_samples;
1437        self.param_types = baseline.params.iter().map(|p| p.ty.clone()).collect();
1438
1439        // Deterministic corner cases first: all-zeros and all-ones.
1440        self.samples.push(make_oracle_args(
1441            &baseline.params,
1442            "all-zeros",
1443            make_all_zeros_value,
1444        )?);
1445        self.samples.push(make_oracle_args(
1446            &baseline.params,
1447            "all-ones",
1448            make_all_ones_value,
1449        )?);
1450
1451        for _ in 0..random_samples {
1452            self.samples
1453                .push(make_oracle_args(&baseline.params, "random", |ty| {
1454                    arbitrary_value_for_type(rng, ty)
1455                })?);
1456        }
1457
1458        self.expected_values = self
1459            .samples
1460            .iter()
1461            .map(|args| eval_fn_safe(baseline, args))
1462            .collect();
1463        Ok(())
1464    }
1465
1466    fn ensure_populated<R: Rng>(
1467        &mut self,
1468        baseline_if_empty: &IrFn,
1469        rng: &mut R,
1470        random_samples: usize,
1471    ) -> Result<()> {
1472        if !self.matches_signature(baseline_if_empty, random_samples) || self.samples.is_empty() {
1473            self.populate_from_baseline(baseline_if_empty, rng, random_samples)?;
1474        }
1475        Ok(())
1476    }
1477}
1478
/// Context for a PIR MCMC iteration, holding shared resources.
pub struct PirMcmcContext<'a> {
    /// Random number generator driving proposals (borrowed from the caller).
    pub rng: &'a mut Pcg64Mcg,
    /// Candidate rewrite transforms.
    pub all_transforms: Vec<Box<dyn PirTransform>>,
    /// NOTE(review): presumably per-transform selection weights, index-aligned
    /// with `all_transforms`; confirm against the proposal code.
    pub weights: Vec<f64>,
    /// Whether to additionally run the formal equivalence oracle.
    pub enable_formal_oracle: bool,
    /// Cached baseline samples/results for the interpreter-based oracle.
    pub oracle_baseline_cache: EvalFnBaselineResults,
}
1487
/// Options controlling a PIR MCMC run.
#[derive(Clone, Debug)]
pub struct RunOptions {
    /// Maximum number of iterations to perform.
    pub max_iters: u64,
    /// Number of parallel chains to run.
    pub threads: u64,
    /// Strategy for running multiple chains.
    pub chain_strategy: ChainStrategy,
    /// Segment size (iterations) for explore/exploit synchronization.
    ///
    /// Only used when `chain_strategy=ExploreExploit`.
    pub checkpoint_iters: u64,
    /// Progress logging interval in iterations (0 disables progress logs).
    pub progress_iters: u64,
    /// RNG seed for the Markov chain.
    pub seed: u64,
    /// Initial temperature for MCMC.
    pub initial_temperature: f64,
    /// Objective to optimize.
    pub objective: Objective,
    /// How extension ops are projected before XLS optimization and g8r costing.
    pub extension_costing_mode: ExtensionCostingMode,
    /// How gate-level costs are obtained for `g8r-post-*` objectives.
    pub g8r_evaluation_mode: G8rEvaluationMode,
    /// Canonical g8r lowering knobs shared with `ir2g8r`.
    pub canonical_g8r_options: CanonicalG8rOptions,
    /// Optional hard cap on gate depth for g8r-based objectives.
    ///
    /// The constrained depth is `g8r_depth`, or `g8r_post_depth` for
    /// postprocessed (`g8r-post-*`) objectives. NOTE(review): assumes this is
    /// forwarded as `ConstraintLimits::max_delay`; confirm at the call site.
    pub max_allowed_depth: Option<usize>,
    /// Optional hard cap on gate count for g8r-based objectives.
    ///
    /// The constrained count is `g8r_nodes`, or `g8r_post_and_nodes` for
    /// postprocessed (`g8r-post-*`) objectives. NOTE(review): assumes this is
    /// forwarded as `ConstraintLimits::max_area`; confirm at the call site.
    pub max_allowed_area: Option<usize>,
    /// Parameters used to convert per-node fanout to load weighting when
    /// computing weighted-switching objectives.
    pub weighted_switching_options: count_toggles::WeightedSwitchingOptions,
    /// When true, and when the crate is built with a formal solver feature
    /// (e.g. `--features with-boolector-built`), run a formal equivalence
    /// oracle after the fast interpreter-based oracle for
    /// non-always-equivalent transforms.
    pub enable_formal_oracle: bool,

    /// Optional directory for writing per-chain trajectory logs as JSONL.
    ///
    /// When set, each chain appends one JSON record per iteration to:
    ///   `trajectory.c{chain_no:03}.jsonl`
    pub trajectory_dir: Option<PathBuf>,

    /// Optional toggle stimulus samples in `.irvals` tuple form (one tuple per
    /// sample) used by toggle-based objectives.
    ///
    /// NOTE(review): presumably lowered with `lower_toggle_stimulus_for_fn`,
    /// which rejects fewer than two samples; confirm at the call site.
    pub toggle_stimulus: Option<Vec<IrValue>>,
}
1538
/// Reason a checkpoint notification (`CheckpointMsg`) was sent from the PIR
/// MCMC engine to an optional checkpoint writer.
///
/// This is used by CLI checkpoint writers to keep on-disk best artifacts
/// up-to-date during long runs and (optionally) to snapshot the improvement
/// trajectory.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CheckpointKind {
    /// A periodic checkpoint tick (e.g. every N iterations).
    Periodic,
    /// A new global best was found.
    GlobalBestUpdate,
}
1551
/// A checkpoint writer notification, including the chain and iteration that
/// triggered the event.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CheckpointMsg {
    /// Chain that triggered the notification.
    pub chain_no: usize,
    /// Global MCMC iteration at which the notification was triggered.
    pub global_iter: u64,
    /// Why the notification was sent.
    pub kind: CheckpointKind,
}
1560
/// Result of a PIR MCMC run.
pub struct PirMcmcResult {
    /// Best function found by the run.
    pub best_fn: IrFn,
    /// Cost recorded for `best_fn`.
    pub best_cost: Cost,
    /// Aggregated MCMC statistics for the run.
    pub stats: McmcStats,
}
1567
/// One exact provenance action on the path that led to an MCMC winner.
///
/// Every action records its one-based position on the path, the chain and
/// iteration where it happened, and the resulting state together with that
/// state's cost.
#[derive(Clone, Debug)]
pub enum PirMcmcProvenanceAction {
    /// A raw PIR rewrite accepted by the MCMC loop.
    AcceptedRewrite {
        /// One-based count of provenance actions from the origin to this state.
        action_index: usize,
        /// Chain that performed the action.
        chain_no: usize,
        /// One-based global MCMC iteration that accepted this rewrite.
        global_iter: u64,
        /// Transform accepted at this action.
        transform_kind: PirTransformKind,
        /// Raw accepted PIR state after the rewrite.
        state: IrFn,
        /// Cost of `state` under the run's objective semantics.
        cost: Cost,
    },
    /// A chain switched from the raw winner path to the XLS-optimized public
    /// best state used for explore/exploit handoff.
    XlsOptimizedHandoff {
        /// One-based count of provenance actions from the origin to this state.
        action_index: usize,
        /// Chain that received the handoff.
        chain_no: usize,
        /// Global MCMC iteration at the synchronization barrier.
        global_iter: u64,
        /// XLS-optimized state handed to the receiving chain.
        state: IrFn,
        /// Cost of `state` under the run's objective semantics.
        cost: Cost,
    },
}
1601
1602impl PirMcmcProvenanceAction {
1603    fn action_index(&self) -> usize {
1604        match self {
1605            Self::AcceptedRewrite { action_index, .. }
1606            | Self::XlsOptimizedHandoff { action_index, .. } => *action_index,
1607        }
1608    }
1609
1610    fn state(&self) -> &IrFn {
1611        match self {
1612            Self::AcceptedRewrite { state, .. } | Self::XlsOptimizedHandoff { state, .. } => state,
1613        }
1614    }
1615
1616    fn cost(&self) -> Cost {
1617        match self {
1618            Self::AcceptedRewrite { cost, .. } | Self::XlsOptimizedHandoff { cost, .. } => *cost,
1619        }
1620    }
1621
1622    fn transform_kind(&self) -> Option<PirTransformKind> {
1623        match self {
1624            Self::AcceptedRewrite { transform_kind, .. } => Some(*transform_kind),
1625            Self::XlsOptimizedHandoff { .. } => None,
1626        }
1627    }
1628}
1629
/// In-memory provenance artifact for minimizing a discovered MCMC witness.
///
/// Presumably persisted in the form of `PersistedPirMcmcArtifactManifest`,
/// which mirrors these fields; confirm at the writer/loader.
#[derive(Clone, Debug)]
pub struct PirMcmcArtifact {
    /// Canonicalized function used as the exact rollout origin.
    pub origin_fn: IrFn,
    /// Cost of `origin_fn`.
    pub origin_cost: Cost,
    /// Options used to produce this artifact.
    pub run_options: RunOptions,
    /// Final winning state at the end of `winning_provenance`.
    pub raw_winner_fn: IrFn,
    /// Cost of `raw_winner_fn`.
    pub raw_winner_cost: Cost,
    /// Exact provenance action sequence from `origin_fn` to `raw_winner_fn`.
    pub winning_provenance: Vec<PirMcmcProvenanceAction>,
}
1646
/// Options for reducing winning provenance to an earliest useful prefix.
#[derive(Clone, Copy, Debug)]
pub struct PirMcmcPrefixMinimizeOptions {
    /// Fraction of the discovered objective improvement that must be retained.
    ///
    /// NOTE(review): presumably expected in `[0.0, 1.0]`; range validation is
    /// not visible here, so confirm it at the consumer.
    pub retained_win_fraction: f64,
}
1653
/// Result of reducing winning provenance to an earliest useful prefix.
///
/// The `*_metric` fields are presumably values of the run objective's
/// `Objective::metric` (a `u128`), as used by `mcmc_iteration`; confirm at
/// the producer.
#[derive(Clone, Debug)]
pub struct PirMcmcPrefixMinimizeResult {
    /// Earliest provenance-prefix state satisfying the requested retained win.
    pub witness_fn: IrFn,
    /// Cost of `witness_fn`.
    pub witness_cost: Cost,
    /// Number of provenance actions needed to reach `witness_fn`.
    pub provenance_action_count: usize,
    /// Number of provenance actions in the original winning path.
    pub original_winning_provenance_len: usize,
    /// Fraction requested by the caller.
    pub requested_retained_win_fraction: f64,
    /// Fraction actually retained by the selected witness.
    pub actual_retained_win_fraction: f64,
    /// Origin objective metric.
    pub origin_metric: u128,
    /// Final winner objective metric.
    pub winner_metric: u128,
    /// Selected witness objective metric.
    pub witness_metric: u128,
}
1676
/// Options for witness-guided short-witness frontier search.
#[derive(Clone, Copy, Debug)]
pub struct PirMcmcBudgetFrontierOptions {
    /// Requested budget spacing; budgets are `step, 2*step, ... <= max`.
    pub budget_step: usize,
    /// Largest provenance-action budget to evaluate.
    pub max_actions: usize,
    /// Number of independent short rollouts attempted per requested budget.
    pub rollouts_per_budget: usize,
    /// Search seed. Use the source artifact's seed by default.
    pub seed: u64,
    /// Extra proposal weight per winning-provenance occurrence of a transform
    /// kind. See [`Self::DEFAULT_WITNESS_KIND_BOOST`].
    pub witness_kind_boost: f64,
    /// Proposal-attempt cap per accepted rewrite in each rollout. See
    /// [`Self::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE`].
    pub proposal_attempts_per_rewrite: usize,
}
1694
impl PirMcmcBudgetFrontierOptions {
    /// Default value for the `witness_kind_boost` field (per naming; callers
    /// decide whether to use it).
    pub const DEFAULT_WITNESS_KIND_BOOST: f64 = 4.0;
    /// Default value for the `proposal_attempts_per_rewrite` field (per
    /// naming; callers decide whether to use it).
    pub const DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE: usize = 64;
}
1699
/// One witness on a frontier, either searched or historical-prefix baseline.
#[derive(Clone, Debug)]
pub struct PirMcmcBudgetWitness {
    /// Witness function selected for this frontier point.
    pub witness_fn: IrFn,
    /// Cost of `witness_fn`.
    pub witness_cost: Cost,
    /// Number of provenance actions needed to reach `witness_fn`.
    pub provenance_action_count: usize,
    /// Objective metric of `witness_fn`.
    pub metric: u128,
    /// Improvement over the origin in metric units; presumably
    /// `origin_metric - metric`. Confirm at the producer.
    pub absolute_win: u128,
    /// `absolute_win` as a percentage of the origin metric (presumably;
    /// confirm at the producer).
    pub win_percent_vs_origin: f64,
    /// Share of the original winner's improvement that this witness retains
    /// (presumably `absolute_win / (origin_metric - winner_metric)`; confirm).
    pub retained_win_fraction: f64,
}
1711
/// One requested short-witness budget point.
#[derive(Clone, Debug)]
pub struct PirMcmcBudgetFrontierPoint {
    /// Provenance-action budget this point was evaluated at.
    pub action_budget: usize,
    /// Witness found by the witness-guided search for this budget.
    pub guided: PirMcmcBudgetWitness,
    /// Historical-prefix baseline witness for the same budget.
    pub prefix_baseline: PirMcmcBudgetWitness,
}
1719
/// Result of witness-guided short-witness frontier search.
#[derive(Clone, Debug)]
pub struct PirMcmcBudgetFrontierResult {
    /// Objective metric of the artifact's origin function.
    pub origin_metric: u128,
    /// Objective metric of the artifact's raw winner.
    pub winner_metric: u128,
    /// Length of the artifact's original winning provenance.
    pub original_winning_provenance_len: usize,
    /// One entry per evaluated budget (presumably in increasing budget
    /// order; confirm at the producer).
    pub points: Vec<PirMcmcBudgetFrontierPoint>,
}
1728
/// Internal pairing of a run's ordinary result with its provenance artifact.
struct PirMcmcArtifactRunOutput {
    /// Ordinary run result (best function, cost, stats).
    result: PirMcmcResult,
    /// Provenance artifact captured alongside `result`.
    artifact: PirMcmcArtifact,
}
1733
/// Directory name used for the persisted winning-lineage artifact.
const PIR_MCMC_ARTIFACT_DIR_NAME: &str = "winning-lineage";
/// File name of the artifact's JSON manifest.
const PIR_MCMC_ARTIFACT_MANIFEST_FILE: &str = "manifest.json";
/// Subdirectory name presumably holding the persisted provenance state files
/// (confirm at the writer).
const PIR_MCMC_ARTIFACT_STATES_DIR_NAME: &str = "states";
/// Version of the persisted manifest format; presumably compared with
/// `PersistedPirMcmcArtifactManifest::schema_version` on load.
const PIR_MCMC_ARTIFACT_SCHEMA_VERSION: u32 = 4;
1738
/// Durable PIR MCMC artifact loaded from a run directory.
pub struct LoadedPirMcmcArtifact {
    /// Reconstructed in-memory artifact.
    pub artifact: PirMcmcArtifact,
    /// Package the artifact's functions came from; presumably used as a
    /// template when re-emitting witness functions (confirm at callers).
    pub package_template: PirPackage,
    /// Name of the top function the artifact was recorded for.
    pub top_fn_name: String,
}
1745
/// Serialized manifest describing a persisted `PirMcmcArtifact`.
///
/// Presumably written as `PIR_MCMC_ARTIFACT_MANIFEST_FILE` inside
/// `PIR_MCMC_ARTIFACT_DIR_NAME`; confirm at the writer.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedPirMcmcArtifactManifest {
    /// Manifest format version.
    schema_version: u32,
    /// Name of the top function the artifact was recorded for.
    top_fn_name: String,
    /// Options the originating run used.
    run_options: PersistedRunOptions,
    /// Reference to the persisted origin state and its cost.
    origin: PersistedArtifactState,
    /// Reference to the persisted raw winner state and its cost.
    raw_winner: PersistedArtifactState,
    /// Provenance actions from origin to raw winner.
    winning_provenance: Vec<PersistedPirMcmcProvenanceAction>,
}
1755
/// Serializable snapshot of the run options stored in an artifact manifest.
///
/// `chain_strategy`, `objective`, and `extension_costing_mode` are stored as
/// plain strings. NOTE(review): presumably their `ValueEnum`/CLI names;
/// confirm against the conversion code.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedRunOptions {
    max_iters: u64,
    threads: u64,
    chain_strategy: String,
    checkpoint_iters: u64,
    progress_iters: u64,
    seed: u64,
    initial_temperature: f64,
    objective: String,
    extension_costing_mode: String,
    g8r_evaluation_mode: G8rEvaluationMode,
    canonical_g8r_options: CanonicalG8rOptions,
    max_allowed_depth: Option<usize>,
    max_allowed_area: Option<usize>,
    switching_beta1: f64,
    switching_beta2: f64,
    switching_primary_output_load: f64,
    enable_formal_oracle: bool,
}
1776
/// Reference to one persisted PIR state plus its recorded cost.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedArtifactState {
    /// File holding the serialized state; presumably relative to the
    /// artifact's states directory (confirm at the writer).
    file: String,
    /// Cost recorded for the state.
    cost: Cost,
}
1782
/// Serialized discriminant of `PirMcmcProvenanceAction`; serde writes it as
/// `"accepted_rewrite"` / `"xls_optimized_handoff"`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum PersistedPirMcmcProvenanceActionKind {
    AcceptedRewrite,
    XlsOptimizedHandoff,
}
1789
/// Serialized form of one `PirMcmcProvenanceAction`.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedPirMcmcProvenanceAction {
    /// Which provenance variant this record represents.
    kind: PersistedPirMcmcProvenanceActionKind,
    /// One-based action position along the provenance path.
    action_index: usize,
    /// Chain that performed or received the action.
    chain_no: usize,
    /// Global MCMC iteration of the action.
    global_iter: u64,
    /// Applied transform; presumably `Some` only for accepted rewrites,
    /// mirroring `PirMcmcProvenanceAction::transform_kind` (confirm on load).
    transform_kind: Option<PirTransformKind>,
    /// Reference to the persisted state reached by the action.
    state: PersistedArtifactState,
}
1799
/// Message sent from the PIR MCMC engine to an optional accepted-sample writer.
///
/// This is used by the `xlsynth-mcmc-pir-sampler` binary to build a
/// deduplicated corpus of accepted equivalent samples.
#[derive(Clone, Debug)]
pub struct AcceptedSampleMsg {
    /// Chain that accepted the sample.
    pub chain_no: usize,
    /// Global MCMC iteration at which the sample was accepted.
    pub global_iter: u64,
    /// 32-byte digest of `func`; presumably produced by
    /// `compute_fn_structural_digest` and used as the dedup key (confirm).
    pub digest: [u8; 32],
    /// Cost of `func`.
    pub cost: Cost,
    /// The accepted function.
    pub func: IrFn,
}
1812
1813fn compute_fn_structural_digest(f: &IrFn) -> Option<[u8; 32]> {
1814    let result = catch_unwind(AssertUnwindSafe(|| {
1815        let ret = f
1816            .ret_node_ref
1817            .expect("PIR functions must have a return node");
1818        let (entries, _depths) = collect_structural_entries(f);
1819        let h = entries[ret.index].hash.as_bytes();
1820        let mut out = [0u8; 32];
1821        out.copy_from_slice(h);
1822        out
1823    }));
1824    match result {
1825        Ok(v) => Some(v),
1826        Err(_panic) => None,
1827    }
1828}
1829
1830fn canonicalize_fn_for_sample(f: &IrFn) -> Result<IrFn> {
1831    let mut f = f.clone();
1832    compact_and_toposort_in_place(&mut f)
1833        .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
1834    Ok(f)
1835}
1836
1837fn iteration_outcome_tag<K>(o: &xlsynth_mcmc::IterationOutcomeDetails<K>) -> &'static str {
1838    match o {
1839        xlsynth_mcmc::IterationOutcomeDetails::CandidateFailure => "CandidateFailure",
1840        xlsynth_mcmc::IterationOutcomeDetails::ApplyFailure => "ApplyFailure",
1841        xlsynth_mcmc::IterationOutcomeDetails::SimFailure => "SimFailure",
1842        xlsynth_mcmc::IterationOutcomeDetails::OracleFailure => "OracleFailure",
1843        xlsynth_mcmc::IterationOutcomeDetails::MetropolisReject => "MetropolisReject",
1844        xlsynth_mcmc::IterationOutcomeDetails::Accepted { .. } => "Accepted",
1845    }
1846}
1847
/// Renders a 32-byte digest as 64 lowercase hexadecimal characters, most
/// significant nibble of each byte first.
fn hash_to_hex(bytes: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(bytes.len() * 2);
    // Push nibbles directly. `format!("{:02x}", b)` allocated a temporary
    // String for every byte.
    for &b in bytes {
        s.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
        s.push(char::from(HEX_DIGITS[usize::from(b & 0x0f)]));
    }
    s
}
1855
/// Per-chain search state plus the exact provenance needed to build a
/// `PirMcmcArtifact`.
///
/// The search state and the raw winner are tracked separately. After an
/// XLS-optimized handoff, a chain continues searching from a state that is
/// not on its raw winning path.
#[derive(Clone, Debug)]
struct ProvenancedChainState {
    /// State the next MCMC segment should continue searching from.
    search_fn: IrFn,
    /// Exact provenance to `search_fn`.
    search_provenance: Vec<PirMcmcProvenanceAction>,
    /// Winning state at the end of `raw_winner_provenance`.
    raw_winner_fn: IrFn,
    /// Cost of `raw_winner_fn`.
    raw_winner_cost: Cost,
    /// Exact provenance to `raw_winner_fn`.
    raw_winner_provenance: Vec<PirMcmcProvenanceAction>,
    /// Handoff metadata to materialize once the next segment recomputes the
    /// optimized search state's true cost.
    /// Stored as `(receiving_chain_no, global_iter)`.
    pending_handoff: Option<(usize, u64)>,
}
1872
1873impl ProvenancedChainState {
1874    fn origin(origin_fn: IrFn, origin_cost: Cost) -> Self {
1875        Self {
1876            search_fn: origin_fn.clone(),
1877            search_provenance: Vec::new(),
1878            raw_winner_fn: origin_fn,
1879            raw_winner_cost: origin_cost,
1880            raw_winner_provenance: Vec::new(),
1881            pending_handoff: None,
1882        }
1883    }
1884
1885    fn with_xls_optimized_handoff(&self, receiving_chain_no: usize, global_iter: u64) -> Self {
1886        let mut next = self.clone();
1887        next.pending_handoff = Some((receiving_chain_no, global_iter));
1888        next
1889    }
1890}
1891
/// Configuration bundle for running MCMC segments over PIR functions.
///
/// It holds:
/// - objective and costing settings;
/// - the initial temperature and constraint limits;
/// - progress/checkpoint cadences;
/// - optional side channels: checkpoint and accepted-sample senders, shared
///   best, and trajectory directory.
///
/// NOTE(review): the code that consumes these fields is outside this chunk;
/// confirm semantics there.
struct PirSegmentRunner {
    objective: Objective,
    extension_costing_mode: ExtensionCostingMode,
    g8r_evaluation_mode: G8rEvaluationMode,
    canonical_g8r_options: CanonicalG8rOptions,
    weighted_switching_options: count_toggles::WeightedSwitchingOptions,
    initial_temperature: f64,
    constraints: ConstraintLimits,
    enable_formal_oracle: bool,
    progress_iters: u64,
    checkpoint_iters: u64,
    checkpoint_tx: Option<Sender<CheckpointMsg>>,
    accepted_sample_tx: Option<Sender<AcceptedSampleMsg>>,
    shared_best: Option<Arc<Best>>,
    trajectory_dir: Option<PathBuf>,
    // Toggle stimulus shared behind an `Arc`; presumably pre-converted from
    // `IrValue` tuples into per-sample bit vectors (confirm at construction).
    prepared_toggle_stimulus: Option<Arc<Vec<Vec<IrBits>>>>,
}
1909
/// Factory that builds a fresh set of boxed PIR transforms. It is
/// `Send + Sync` and held in an `Arc` so it can be shared across threads
/// (presumably one transform set per chain; confirm at callers).
type PirTransformFactory = Arc<dyn Fn() -> Vec<Box<dyn PirTransform>> + Send + Sync>;
1911
/// Segment-runner configuration used, presumably, for provenance-recording
/// (artifact) runs.
///
/// It mirrors `PirSegmentRunner` with two differences: it has no
/// accepted-sample sender, and it builds its transforms through
/// `transform_factory`.
struct PirArtifactSegmentRunner {
    objective: Objective,
    extension_costing_mode: ExtensionCostingMode,
    g8r_evaluation_mode: G8rEvaluationMode,
    canonical_g8r_options: CanonicalG8rOptions,
    weighted_switching_options: count_toggles::WeightedSwitchingOptions,
    initial_temperature: f64,
    constraints: ConstraintLimits,
    enable_formal_oracle: bool,
    progress_iters: u64,
    checkpoint_iters: u64,
    checkpoint_tx: Option<Sender<CheckpointMsg>>,
    shared_best: Option<Arc<Best>>,
    trajectory_dir: Option<PathBuf>,
    prepared_toggle_stimulus: Option<Arc<Vec<Vec<IrBits>>>>,
    // Builds the transform set the segment proposes from.
    transform_factory: PirTransformFactory,
}
1929
/// Performs a single iteration of the PIR MCMC process.
///
/// One iteration:
/// 1. samples a transform from `context.all_transforms`, using
///    `context.weights` as a `WeightedIndex` distribution;
/// 2. collects that transform's candidates on `current_fn`, keeping only
///    `always_equivalent` candidates unless `context.enable_formal_oracle` is
///    set, and picks one uniformly at random;
/// 3. applies it to a clone of `current_fn`, then compacts and toposorts the
///    result;
/// 4. for candidates that are not `always_equivalent`, checks equivalence
///    against `current_fn` with `pir_equiv_oracle`;
/// 5. costs the candidate and decides acceptance:
///    - leaving a constraint violation is always accepted;
///    - entering one is always rejected;
///    - two violating states compare `repair_energy` via `metropolis_accept`;
///    - two feasible states compare the objective metric via
///      `metropolis_accept`, or the PIR node count when the metric ties and
///      the node count grew.
///
/// When an accepted candidate's `search_score` beats `*best_score`, `best_fn`
/// receives the XLS-optimized form of the candidate. If optimization fails,
/// it receives the raw candidate instead. In both cases `best_cost` receives
/// the raw candidate's cost.
///
/// Returns the state and cost the chain should continue from (the candidate
/// on acceptance, otherwise `current_fn`/`current_cost`), together with the
/// outcome classification and the transform kind that was tried, if any.
///
/// # Panics
/// Panics if `context.weights` cannot form a `WeightedIndex` (e.g. empty or
/// all-zero weights).
///
/// NOTE(review): this assumes `context.weights` is parallel to
/// `context.all_transforms`. A longer weight vector would index out of
/// bounds; confirm at construction sites.
pub fn mcmc_iteration(
    current_fn: IrFn, /* Takes ownership, becomes the basis for candidate or returned if no
                       * change */
    current_cost: Cost,
    best_fn: &mut IrFn,   // Mutated if new best is found
    best_cost: &mut Cost, // Mutated if new best is found
    best_score: &mut SearchScore,
    context: &mut PirMcmcContext,
    temp: f64,
    objective: Objective,
    extension_costing_mode: ExtensionCostingMode,
    g8r_evaluation_mode: &G8rEvaluationMode,
    canonical_g8r_options: &CanonicalG8rOptions,
    toggle_stimulus: Option<&[Vec<IrBits>]>,
    weighted_switching_options: &count_toggles::WeightedSwitchingOptions,
    constraints: ConstraintLimits,
) -> McmcIterationOutput {
    let mut iteration_best_updated = false;

    if context.all_transforms.is_empty() {
        // No transforms available to apply.
        return McmcIterationOutput {
            output_state: current_fn,
            output_cost: current_cost,
            best_updated: false,
            outcome: IterationOutcomeDetails::CandidateFailure,
            oracle_time_micros: 0,
            transform_always_equivalent: true,
            transform: None,
        };
    }

    // Sample which transform to try, weighted by `context.weights`.
    let dist = WeightedIndex::new(&context.weights).expect("non-empty weights");
    let chosen_transform_idx = dist.sample(context.rng);
    let chosen_transform = &mut context.all_transforms[chosen_transform_idx];
    let current_transform_kind = chosen_transform.kind();

    // Without the formal oracle, only semantics-preserving candidates are
    // considered.
    let mut candidates = chosen_transform.find_candidates(&current_fn);
    if !context.enable_formal_oracle {
        candidates.retain(|c| c.always_equivalent);
    }

    log::trace!(
        "Found {} PIR candidates for {:?}",
        candidates.len(),
        current_transform_kind,
    );

    if candidates.is_empty() {
        return McmcIterationOutput {
            output_state: current_fn,
            output_cost: current_cost,
            best_updated: false,
            outcome: IterationOutcomeDetails::CandidateFailure,
            oracle_time_micros: 0,
            transform_always_equivalent: true,
            transform: Some(current_transform_kind),
        };
    }

    // Non-empty was checked just above, so `choose` cannot return `None`.
    let chosen_candidate = candidates.choose(context.rng).unwrap();

    log::trace!("Chosen PIR candidate: {:?}", chosen_candidate);

    // The transform mutates a copy; `current_fn` stays intact so it can be
    // returned on any rejection path.
    let mut candidate_fn = current_fn.clone();

    log::trace!(
        "Applying PIR transform {:?} at {:?}",
        current_transform_kind,
        chosen_candidate.location
    );

    match chosen_transform.apply(&mut candidate_fn, &chosen_candidate.location) {
        Ok(()) => {
            // Transform implementations may emit acyclic def-after-use graphs.
            // Canonicalize centrally here so the oracle and cost paths see the
            // same normalized IR without each transform paying a local
            // topo-sort/compaction cost.
            if let Err(e) = compact_and_toposort_in_place(&mut candidate_fn) {
                log::debug!(
                    "[pir-mcmc] compact/toposort failed for '{}' after {:?} at {:?}: {}; \
                     rejecting candidate",
                    candidate_fn.name,
                    current_transform_kind,
                    chosen_candidate.location,
                    e
                );
                return McmcIterationOutput {
                    output_state: current_fn,
                    output_cost: current_cost,
                    best_updated: false,
                    outcome: IterationOutcomeDetails::CandidateFailure,
                    oracle_time_micros: 0,
                    transform_always_equivalent: chosen_candidate.always_equivalent,
                    transform: Some(current_transform_kind),
                };
            }

            log::trace!("PIR transform applied successfully; determining cost...");
            let (is_equiv, oracle_time_micros) = if chosen_candidate.always_equivalent {
                // Always-equivalent transforms can skip equivalence checks.
                (true, 0u128)
            } else {
                let oracle_start = Instant::now();
                // For transforms that are not guaranteed to preserve semantics, we run a
                // lightweight equivalence oracle:
                // - Deterministic corner cases (all-zeros, all-ones)
                // - A small number of randomized samples (seeded by the run's RNG)
                //
                // If evaluation fails (e.g. due to a cycle or unsupported node kinds),
                // we treat that as non-equivalence for this candidate and reject it.
                let ok = pir_equiv_oracle(
                    &current_fn,
                    &candidate_fn,
                    context.rng,
                    DEFAULT_ORACLE_RANDOM_SAMPLES,
                    context.enable_formal_oracle,
                    &mut context.oracle_baseline_cache,
                );
                let micros = oracle_start.elapsed().as_micros();
                (ok, micros)
            };

            if !is_equiv {
                McmcIterationOutput {
                    output_state: current_fn,
                    output_cost: current_cost,
                    best_updated: false,
                    outcome: IterationOutcomeDetails::OracleFailure,
                    oracle_time_micros,
                    transform_always_equivalent: chosen_candidate.always_equivalent,
                    transform: Some(current_transform_kind),
                }
            } else {
                let cost_start = Instant::now();
                let new_candidate_cost =
                    match cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
                        &candidate_fn,
                        objective,
                        toggle_stimulus,
                        weighted_switching_options,
                        extension_costing_mode,
                        g8r_evaluation_mode,
                        canonical_g8r_options,
                    ) {
                        Ok(c) => c,
                        Err(e) => {
                            let sim_micros = cost_start.elapsed().as_micros();
                            let msg = e.to_string();
                            // Error-text matching: these messages identify
                            // structurally invalid bit_slice candidates.
                            let is_invalid_bit_slice = msg
                                .contains("Expected operand 0 of bit_slice")
                                || msg.contains("invalid bit_slice");

                            if is_invalid_bit_slice {
                                // Not a sample failure: we sometimes propose structurally invalid
                                // candidates (e.g. bit_slice bounds violations) while exploring.
                                // These are rejected.
                                //
                                // Still, keep this loud for a bit to catch regressions in
                                // transforms.
                                let n = INVALID_BIT_SLICE_WARN_COUNT
                                    .fetch_add(1, Ordering::Relaxed)
                                    .saturating_add(1);
                                if n <= INVALID_BIT_SLICE_WARN_LIMIT {
                                    log::warn!(
                                        "[pir-mcmc] cost evaluation failed for '{}' under {:?}: {}; rejecting candidate (invalid bit_slice; warning {}/{})",
                                        candidate_fn.name,
                                        objective,
                                        e,
                                        n,
                                        INVALID_BIT_SLICE_WARN_LIMIT
                                    );
                                } else {
                                    log::debug!(
                                        "[pir-mcmc] cost evaluation failed for '{}' under {:?}: {}; rejecting candidate (invalid bit_slice; further occurrences silenced to debug)",
                                        candidate_fn.name,
                                        objective,
                                        e
                                    );
                                }
                            } else {
                                log::warn!(
                                    "[pir-mcmc] cost evaluation failed for '{}' under {:?}: {}; rejecting candidate",
                                    candidate_fn.name,
                                    objective,
                                    e
                                );
                            }
                            // NOTE(review): this reports the cost-evaluation
                            // time in `oracle_time_micros`; any oracle time
                            // measured above is not included. Confirm this is
                            // the intended SimFailure accounting.
                            return McmcIterationOutput {
                                output_state: current_fn,
                                output_cost: current_cost,
                                best_updated: false,
                                outcome: IterationOutcomeDetails::SimFailure,
                                oracle_time_micros: sim_micros,
                                transform_always_equivalent: chosen_candidate.always_equivalent,
                                transform: Some(current_transform_kind),
                            };
                        }
                    };

                let current_score = search_score(&current_cost, objective, constraints);
                let new_score = search_score(&new_candidate_cost, objective, constraints);
                let curr_metric_u128 = objective.metric(&current_cost);
                let new_metric_u128 = objective.metric(&new_candidate_cost);
                // Acceptance rules: feasibility first (leave violations, never
                // enter them), then Metropolis on repair energy or metric.
                let accept = match (current_score.violation, new_score.violation) {
                    (Some(_), None) => true,
                    (None, Some(_)) => false,
                    (Some(curr_violation), Some(new_violation)) => metropolis_accept(
                        repair_energy(curr_violation) as f64,
                        repair_energy(new_violation) as f64,
                        temp,
                        context.rng,
                    ),
                    (None, None) => {
                        if new_metric_u128 == curr_metric_u128
                            && new_candidate_cost.pir_nodes > current_cost.pir_nodes
                        {
                            // Equal objective metric but PIR nodes grew: only accept if
                            // the temperature still allows it.
                            metropolis_accept(
                                current_cost.pir_nodes as f64,
                                new_candidate_cost.pir_nodes as f64,
                                temp,
                                context.rng,
                            )
                        } else {
                            metropolis_accept(
                                curr_metric_u128 as f64,
                                new_metric_u128 as f64,
                                temp,
                                context.rng,
                            )
                        }
                    }
                };

                if accept {
                    if new_score < *best_score {
                        // When storing a new global best, prefer the optimized IR form so
                        // artifacts (and subsequent segments via shared best) are based on
                        // the canonical optimized representation, not the raw exploration
                        // state.
                        *best_fn = match optimize_pir_fn_via_xls_with_extension_mode(
                            &candidate_fn,
                            extension_costing_mode,
                        ) {
                            Ok(opt) => opt,
                            Err(e) => {
                                log::warn!(
                                    "[pir-mcmc] failed to optimize new best candidate '{}': {}; storing unoptimized function",
                                    candidate_fn.name,
                                    e
                                );
                                candidate_fn.clone()
                            }
                        };
                        // Note: this is the raw candidate's cost; it is not
                        // re-measured for the optimized `best_fn`.
                        *best_cost = new_candidate_cost;
                        *best_score = new_score;
                        iteration_best_updated = true;
                    }
                    McmcIterationOutput {
                        output_state: candidate_fn,
                        output_cost: new_candidate_cost,
                        best_updated: iteration_best_updated,
                        outcome: IterationOutcomeDetails::Accepted {
                            kind: current_transform_kind,
                        },
                        oracle_time_micros,
                        transform_always_equivalent: chosen_candidate.always_equivalent,
                        transform: Some(current_transform_kind),
                    }
                } else {
                    McmcIterationOutput {
                        output_state: current_fn,
                        output_cost: current_cost,
                        best_updated: false,
                        outcome: IterationOutcomeDetails::MetropolisReject,
                        oracle_time_micros,
                        transform_always_equivalent: chosen_candidate.always_equivalent,
                        transform: Some(current_transform_kind),
                    }
                }
            }
        }
        Err(e) => {
            log::debug!(
                "Error applying PIR transform {:?}: {:?}",
                current_transform_kind,
                e
            );
            McmcIterationOutput {
                output_state: current_fn,
                output_cost: current_cost,
                best_updated: false,
                outcome: IterationOutcomeDetails::ApplyFailure,
                oracle_time_micros: 0,
                transform_always_equivalent: chosen_candidate.always_equivalent,
                transform: Some(current_transform_kind),
            }
        }
    }
}
2233
2234fn make_all_zeros_value(ty: &PirType) -> Result<IrValue> {
2235    match ty {
2236        PirType::Token => Ok(IrValue::make_token()),
2237        PirType::Bits(width) => {
2238            if *width == 0 {
2239                Ok(IrValue::from_bits(&IrBits::make_ubits(0, 0).unwrap()))
2240            } else {
2241                Ok(IrValue::from_bits(&IrBits::make_ubits(*width, 0).unwrap()))
2242            }
2243        }
2244        PirType::Tuple(elem_types) => {
2245            let elems: Result<Vec<IrValue>> =
2246                elem_types.iter().map(|t| make_all_zeros_value(t)).collect();
2247            Ok(IrValue::make_tuple(&elems?))
2248        }
2249        PirType::Array(arr) => {
2250            if arr.element_count == 0 {
2251                return Err(anyhow::anyhow!(
2252                    "cannot construct all-zeros oracle sample for zero-length array type {}",
2253                    ty
2254                ));
2255            }
2256            let mut elems: Vec<IrValue> = Vec::with_capacity(arr.element_count);
2257            for _ in 0..arr.element_count {
2258                elems.push(make_all_zeros_value(&arr.element_type)?);
2259            }
2260            IrValue::make_array(&elems).map_err(|e| {
2261                anyhow::anyhow!("failed to construct all-zeros array oracle sample: {}", e)
2262            })
2263        }
2264    }
2265}
2266
2267fn make_all_ones_value(ty: &PirType) -> Result<IrValue> {
2268    match ty {
2269        PirType::Token => Ok(IrValue::make_token()),
2270        PirType::Bits(width) => {
2271            if *width == 0 {
2272                Ok(IrValue::from_bits(&IrBits::make_ubits(0, 0).unwrap()))
2273            } else if *width <= 64 {
2274                let mask = if *width == 64 {
2275                    u64::MAX
2276                } else {
2277                    (1u64 << *width) - 1
2278                };
2279                Ok(IrValue::from_bits(
2280                    &IrBits::make_ubits(*width, mask).unwrap(),
2281                ))
2282            } else {
2283                // Build the bit vector directly to avoid fixed-width integer limits.
2284                let ones: Vec<bool> = vec![true; *width];
2285                Ok(IrValue::from_bits(&IrBits::from_lsb_is_0(&ones)))
2286            }
2287        }
2288        PirType::Tuple(elem_types) => {
2289            let elems: Result<Vec<IrValue>> =
2290                elem_types.iter().map(|t| make_all_ones_value(t)).collect();
2291            Ok(IrValue::make_tuple(&elems?))
2292        }
2293        PirType::Array(arr) => {
2294            if arr.element_count == 0 {
2295                return Err(anyhow::anyhow!(
2296                    "cannot construct all-ones oracle sample for zero-length array type {}",
2297                    ty
2298                ));
2299            }
2300            let mut elems: Vec<IrValue> = Vec::with_capacity(arr.element_count);
2301            for _ in 0..arr.element_count {
2302                elems.push(make_all_ones_value(&arr.element_type)?);
2303            }
2304            IrValue::make_array(&elems).map_err(|e| {
2305                anyhow::anyhow!("failed to construct all-ones array oracle sample: {}", e)
2306            })
2307        }
2308    }
2309}
2310
2311fn arbitrary_value_for_type<R: Rng>(rng: &mut R, ty: &PirType) -> Result<IrValue> {
2312    match ty {
2313        PirType::Token => Ok(IrValue::make_token()),
2314        PirType::Bits(width) => {
2315            let bits = generate_biased_irbits_with_rng(rng, *width);
2316            Ok(IrValue::from_bits(&bits))
2317        }
2318        PirType::Tuple(elem_types) => {
2319            let elems: Result<Vec<IrValue>> = elem_types
2320                .iter()
2321                .map(|t| arbitrary_value_for_type(rng, t))
2322                .collect();
2323            Ok(IrValue::make_tuple(&elems?))
2324        }
2325        PirType::Array(arr) => {
2326            if arr.element_count == 0 {
2327                return Err(anyhow::anyhow!(
2328                    "cannot construct random oracle sample for zero-length array type {}",
2329                    ty
2330                ));
2331            }
2332            let mut elems: Vec<IrValue> = Vec::with_capacity(arr.element_count);
2333            for _ in 0..arr.element_count {
2334                elems.push(arbitrary_value_for_type(rng, &arr.element_type)?);
2335            }
2336            IrValue::make_array(&elems).map_err(|e| {
2337                anyhow::anyhow!("failed to construct random array oracle sample: {}", e)
2338            })
2339        }
2340    }
2341}
2342
2343fn eval_fn_safe(f: &IrFn, args: &[IrValue]) -> Result<IrValue, ()> {
2344    // Note: `xlsynth_pir::ir_eval` uses internal `expect` / `unwrap` paths for
2345    // invariants; rewiring transforms may temporarily violate those (cycles,
2346    // missing package context for invoke, etc.). We treat any such failure as a
2347    // rejection signal for the candidate, not a crash.
2348    //
2349    // The MCMC state is compacted/toposorted at initialization and after each
2350    // successful candidate application, so the oracle can skip the evaluator's
2351    // per-call node-order check.
2352    let result = catch_unwind(AssertUnwindSafe(|| {
2353        eval_fn_assuming_node_index_topological(f, args)
2354    }));
2355    match result {
2356        Ok(FnEvalResult::Success(s)) => Ok(s.value),
2357        Ok(FnEvalResult::Failure(_f)) => Err(()),
2358        Err(_panic) => Err(()),
2359    }
2360}
2361
2362fn make_oracle_args<F>(params: &[PirParam], label: &str, mut make_value: F) -> Result<Vec<IrValue>>
2363where
2364    F: FnMut(&PirType) -> Result<IrValue>,
2365{
2366    params
2367        .iter()
2368        .map(|p| make_value(&p.ty))
2369        .collect::<Result<Vec<_>>>()
2370        .map_err(|e| anyhow::anyhow!("failed to construct {} oracle sample args: {}", label, e))
2371}
2372
2373fn pir_equiv_oracle<R: Rng>(
2374    lhs: &IrFn,
2375    rhs: &IrFn,
2376    rng: &mut R,
2377    random_samples: usize,
2378    enable_formal_oracle: bool,
2379    baseline_cache: &mut EvalFnBaselineResults,
2380) -> bool {
2381    if lhs.params.len() != rhs.params.len() || lhs.ret_ty != rhs.ret_ty {
2382        return false;
2383    }
2384    for (lp, rp) in lhs.params.iter().zip(rhs.params.iter()) {
2385        if lp.ty != rp.ty {
2386            return false;
2387        }
2388    }
2389
2390    // The accepted-state invariant says each current `lhs` is equivalent to the
2391    // initial baseline. Populate this once for the first oracle check in a
2392    // chain/segment, then keep comparing candidates to those expected return
2393    // values instead of re-evaluating `lhs`.
2394    if let Err(e) = baseline_cache.ensure_populated(lhs, rng, random_samples) {
2395        log::debug!(
2396            "[pir-mcmc] failed to populate oracle baseline cache: {}; rejecting candidate",
2397            e
2398        );
2399        return false;
2400    }
2401    for (args, expected_value) in baseline_cache
2402        .samples
2403        .iter()
2404        .zip(baseline_cache.expected_values.iter())
2405    {
2406        let Ok(expected_value) = expected_value else {
2407            return false;
2408        };
2409        let Ok(rhs_value) = eval_fn_safe(rhs, args) else {
2410            return false;
2411        };
2412        if expected_value != &rhs_value {
2413            return false;
2414        }
2415    }
2416
2417    if enable_formal_oracle {
2418        {
2419            match prove_ir_fn_equiv(lhs, rhs) {
2420                EquivResult::Proved => true,
2421                EquivResult::Disproved { .. } | EquivResult::ToolchainDisproved(_) => false,
2422                EquivResult::Inconclusive(msg) => {
2423                    log::warn!(
2424                        "[pir-mcmc] formal oracle inconclusive for '{}' vs '{}': {}; rejecting candidate",
2425                        lhs.name,
2426                        rhs.name,
2427                        msg
2428                    );
2429                    false
2430                }
2431                EquivResult::Error(msg) => {
2432                    log::warn!(
2433                        "[pir-mcmc] formal oracle error for '{}' vs '{}': {}; rejecting candidate",
2434                        lhs.name,
2435                        rhs.name,
2436                        msg
2437                    );
2438                    false
2439                }
2440            }
2441        }
2442    } else {
2443        true
2444    }
2445}
2446
2447fn get_pir_transforms_for_run(enable_formal_oracle: bool) -> Vec<Box<dyn PirTransform>> {
2448    let mut all_transforms = get_all_pir_transforms();
2449    if !enable_formal_oracle {
2450        all_transforms.retain(|t| t.can_emit_always_equivalent_candidates());
2451    }
2452    all_transforms
2453}
2454
2455struct PreparedRun {
2456    start_fn: IrFn,
2457    prepared_toggle_stimulus: Option<Arc<Vec<Vec<IrBits>>>>,
2458    initial_cost: Cost,
2459    effective_constraints: ConstraintLimits,
2460}
2461
2462/// Validates a run and computes the canonicalized origin artifacts shared by
2463/// ordinary MCMC runs and provenance-producing runs.
2464fn prepare_run_start(mut start_fn: IrFn, options: &RunOptions) -> Result<PreparedRun> {
2465    if !options.objective.needs_toggle_stimulus() && options.toggle_stimulus.is_some() {
2466        return Err(anyhow::anyhow!(
2467            "toggle stimulus is not valid with objective {}",
2468            options.objective.value_name()
2469        ));
2470    }
2471    if options.objective.uses_postprocessed_costing()
2472        && options
2473            .g8r_evaluation_mode
2474            .external_postprocess_program()
2475            .is_none()
2476    {
2477        return Err(anyhow::anyhow!(
2478            "objective {} requires --g8r-postprocess-program",
2479            options.objective.value_name()
2480        ));
2481    }
2482    validate_constraint_configuration(
2483        options.objective,
2484        ConstraintLimits {
2485            max_delay: options.max_allowed_depth,
2486            max_area: options.max_allowed_area,
2487        },
2488    )?;
2489    compact_and_toposort_in_place(&mut start_fn)
2490        .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
2491
2492    let prepared_toggle_stimulus = if options.objective.needs_toggle_stimulus() {
2493        let samples = options.toggle_stimulus.as_ref().ok_or_else(|| {
2494            anyhow::anyhow!(
2495                "objective {} requires toggle stimulus",
2496                options.objective.value_name()
2497            )
2498        })?;
2499        Some(Arc::new(lower_toggle_stimulus_for_fn(samples, &start_fn)?))
2500    } else {
2501        None
2502    };
2503
2504    let initial_cost =
2505        cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
2506            &start_fn,
2507            options.objective,
2508            prepared_toggle_stimulus.as_ref().map(|v| v.as_slice()),
2509            &options.weighted_switching_options,
2510            options.extension_costing_mode,
2511            &options.g8r_evaluation_mode,
2512            &options.canonical_g8r_options,
2513        )?;
2514    let effective_constraints = effective_constraint_limits(
2515        options.objective,
2516        ConstraintLimits {
2517            max_delay: options.max_allowed_depth,
2518            max_area: options.max_allowed_area,
2519        },
2520        &initial_cost,
2521    );
2522
2523    Ok(PreparedRun {
2524        start_fn,
2525        prepared_toggle_stimulus,
2526        initial_cost,
2527        effective_constraints,
2528    })
2529}
2530
2531/// Validates whether a run shape can produce a provenance artifact.
2532pub fn validate_pir_mcmc_artifact_run_options(options: &RunOptions) -> Result<()> {
2533    if options.max_allowed_depth.is_some() || options.max_allowed_area.is_some() {
2534        return Err(anyhow::anyhow!(
2535            "run_pir_mcmc_with_artifact currently supports only unconstrained runs"
2536        ));
2537    }
2538    if options.objective.enforces_non_regressing_depth() {
2539        return Err(anyhow::anyhow!(
2540            "run_pir_mcmc_with_artifact does not yet support objectives with implicit feasibility caps; got {}",
2541            options.objective.value_name()
2542        ));
2543    }
2544    Ok(())
2545}
2546
2547fn validate_prefix_minimization_artifact(artifact: &PirMcmcArtifact) -> Result<()> {
2548    validate_pir_mcmc_artifact_run_options(&artifact.run_options)?;
2549    for (expected_index, action) in artifact.winning_provenance.iter().enumerate() {
2550        if action.action_index() != expected_index + 1 {
2551            return Err(anyhow::anyhow!(
2552                "winning provenance action indices must be contiguous from 1; expected {}, got {}",
2553                expected_index + 1,
2554                action.action_index()
2555            ));
2556        }
2557    }
2558    match artifact.winning_provenance.last() {
2559        Some(last_action) => {
2560            if last_action.cost() != artifact.raw_winner_cost
2561                || last_action.state().to_string() != artifact.raw_winner_fn.to_string()
2562            {
2563                return Err(anyhow::anyhow!(
2564                    "winning provenance endpoint does not match the recorded raw winner"
2565                ));
2566            }
2567        }
2568        None => {
2569            if artifact.raw_winner_cost != artifact.origin_cost
2570                || artifact.raw_winner_fn.to_string() != artifact.origin_fn.to_string()
2571            {
2572                return Err(anyhow::anyhow!(
2573                    "empty winning provenance is valid only when the raw winner is the origin"
2574                ));
2575            }
2576        }
2577    }
2578    Ok(())
2579}
2580
/// Fraction of the origin-to-winner improvement that `metric` retains.
///
/// Returns `0.0` when the winner does not improve on the origin. A `metric`
/// worse than the origin contributes zero retained win, because the
/// subtraction saturates. A `metric` better than the winner yields a value
/// above `1.0`.
fn retained_win_fraction_for_metric(origin_metric: u128, winner_metric: u128, metric: u128) -> f64 {
    match origin_metric.saturating_sub(winner_metric) {
        0 => 0.0,
        total_win => origin_metric.saturating_sub(metric) as f64 / total_win as f64,
    }
}
2589
/// Percentage improvement of `metric` relative to `origin_metric`.
///
/// Returns `0.0` when `origin_metric` is zero, or when `metric` is no better
/// than the origin (the subtraction saturates).
fn win_percent_vs_origin_for_metric(origin_metric: u128, metric: u128) -> f64 {
    if origin_metric != 0 {
        let improvement = origin_metric.saturating_sub(metric) as f64;
        100.0 * improvement / origin_metric as f64
    } else {
        0.0
    }
}
2596
2597fn objective_supports_budget_frontier_search(objective: Objective) -> bool {
2598    !objective.needs_toggle_stimulus()
2599        && !matches!(
2600            objective,
2601            Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress
2602                | Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress
2603        )
2604}
2605
2606fn make_budget_witness(
2607    witness_fn: IrFn,
2608    witness_cost: Cost,
2609    provenance_action_count: usize,
2610    objective: Objective,
2611    origin_metric: u128,
2612    winner_metric: u128,
2613) -> PirMcmcBudgetWitness {
2614    let metric = objective.metric(&witness_cost);
2615    PirMcmcBudgetWitness {
2616        witness_fn,
2617        witness_cost,
2618        provenance_action_count,
2619        metric,
2620        absolute_win: origin_metric.saturating_sub(metric),
2621        win_percent_vs_origin: win_percent_vs_origin_for_metric(origin_metric, metric),
2622        retained_win_fraction: retained_win_fraction_for_metric(
2623            origin_metric,
2624            winner_metric,
2625            metric,
2626        ),
2627    }
2628}
2629
2630fn better_budget_witness(
2631    candidate: &PirMcmcBudgetWitness,
2632    incumbent: &PirMcmcBudgetWitness,
2633) -> bool {
2634    candidate.metric < incumbent.metric
2635        || (candidate.metric == incumbent.metric
2636            && candidate.provenance_action_count < incumbent.provenance_action_count)
2637}
2638
2639fn frontier_budgets(options: PirMcmcBudgetFrontierOptions) -> Result<Vec<usize>> {
2640    if options.budget_step == 0 {
2641        return Err(anyhow::anyhow!("budget_step must be > 0"));
2642    }
2643    if options.max_actions == 0 {
2644        return Err(anyhow::anyhow!("max_actions must be > 0"));
2645    }
2646    if options.budget_step > options.max_actions {
2647        return Err(anyhow::anyhow!(
2648            "budget_step must be <= max_actions; got step={} max={}",
2649            options.budget_step,
2650            options.max_actions
2651        ));
2652    }
2653    if options.rollouts_per_budget == 0 {
2654        return Err(anyhow::anyhow!("rollouts_per_budget must be > 0"));
2655    }
2656    if options.proposal_attempts_per_rewrite == 0 {
2657        return Err(anyhow::anyhow!("proposal_attempts_per_rewrite must be > 0"));
2658    }
2659    if !options.witness_kind_boost.is_finite() || options.witness_kind_boost < 0.0 {
2660        return Err(anyhow::anyhow!(
2661            "witness_kind_boost must be finite and >= 0; got {}",
2662            options.witness_kind_boost
2663        ));
2664    }
2665
2666    let mut budgets = Vec::new();
2667    let mut budget = options.budget_step;
2668    while budget <= options.max_actions {
2669        budgets.push(budget);
2670        match budget.checked_add(options.budget_step) {
2671            Some(next) => budget = next,
2672            None => break,
2673        }
2674    }
2675    if budgets.last().copied() != Some(options.max_actions) {
2676        budgets.push(options.max_actions);
2677    }
2678    Ok(budgets)
2679}
2680
2681fn build_witness_guided_transform_weights(
2682    transforms: &[Box<dyn PirTransform>],
2683    artifact: &PirMcmcArtifact,
2684    witness_kind_boost: f64,
2685) -> Vec<f64> {
2686    let mut counts = BTreeMap::<PirTransformKind, usize>::new();
2687    for action in artifact.winning_provenance.iter() {
2688        if let Some(kind) = action.transform_kind() {
2689            *counts.entry(kind).or_insert(0) += 1;
2690        }
2691    }
2692    transforms
2693        .iter()
2694        .map(|transform| {
2695            1.0 + witness_kind_boost * counts.get(&transform.kind()).copied().unwrap_or(0) as f64
2696        })
2697        .collect()
2698}
2699
2700fn prefix_baseline_for_budget(
2701    artifact: &PirMcmcArtifact,
2702    action_budget: usize,
2703    origin_metric: u128,
2704    winner_metric: u128,
2705) -> PirMcmcBudgetWitness {
2706    let objective = artifact.run_options.objective;
2707    let mut best = make_budget_witness(
2708        artifact.origin_fn.clone(),
2709        artifact.origin_cost,
2710        0,
2711        objective,
2712        origin_metric,
2713        winner_metric,
2714    );
2715    for action in artifact
2716        .winning_provenance
2717        .iter()
2718        .take_while(|action| action.action_index() <= action_budget)
2719    {
2720        let candidate = make_budget_witness(
2721            action.state().clone(),
2722            action.cost(),
2723            action.action_index(),
2724            objective,
2725            origin_metric,
2726            winner_metric,
2727        );
2728        if better_budget_witness(&candidate, &best) {
2729            best = candidate;
2730        }
2731    }
2732    best
2733}
2734
2735fn chain_strategy_value_name(strategy: ChainStrategy) -> &'static str {
2736    match strategy {
2737        ChainStrategy::Independent => "independent",
2738        ChainStrategy::ExploreExploit => "explore-exploit",
2739    }
2740}
2741
2742fn chain_strategy_from_value_name(value: &str) -> Result<ChainStrategy> {
2743    match value {
2744        "independent" => Ok(ChainStrategy::Independent),
2745        "explore-exploit" => Ok(ChainStrategy::ExploreExploit),
2746        _ => Err(anyhow::anyhow!(
2747            "unknown chain strategy in artifact: {}",
2748            value
2749        )),
2750    }
2751}
2752
2753fn emit_pkg_text_toposorted(pkg: &PirPackage) -> Result<String> {
2754    let mut pkg = pkg.clone();
2755    for member in pkg.members.iter_mut() {
2756        match member {
2757            PirPackageMember::Function(f) => {
2758                compact_and_toposort_in_place(f)
2759                    .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
2760            }
2761            PirPackageMember::Block { func, .. } => {
2762                compact_and_toposort_in_place(func)
2763                    .map_err(|e| anyhow::anyhow!("compact_and_toposort_in_place failed: {}", e))?;
2764            }
2765        }
2766    }
2767    Ok(pkg.to_string())
2768}
2769
2770fn package_with_replaced_fn(
2771    package_template: &PirPackage,
2772    top_fn_name: &str,
2773    replacement: &IrFn,
2774) -> Result<PirPackage> {
2775    let mut pkg = package_template.clone();
2776    let top_mut = pkg.get_fn_mut(top_fn_name).ok_or_else(|| {
2777        anyhow::anyhow!(
2778            "top function '{}' not found in artifact package template",
2779            top_fn_name
2780        )
2781    })?;
2782    *top_mut = replacement.clone();
2783    Ok(pkg)
2784}
2785
2786fn write_artifact_state_package(
2787    artifact_dir: &Path,
2788    package_template: &PirPackage,
2789    top_fn_name: &str,
2790    relative_file: &str,
2791    state_fn: &IrFn,
2792) -> Result<()> {
2793    let pkg = package_with_replaced_fn(package_template, top_fn_name, state_fn)?;
2794    let text = emit_pkg_text_toposorted(&pkg)?;
2795    let path = artifact_dir.join(relative_file);
2796    if let Some(parent) = path.parent() {
2797        std::fs::create_dir_all(parent).map_err(|e| {
2798            anyhow::anyhow!(
2799                "failed to create artifact state directory {}: {}",
2800                parent.display(),
2801                e
2802            )
2803        })?;
2804    }
2805    std::fs::write(&path, text.as_bytes())
2806        .map_err(|e| anyhow::anyhow!("failed to write {}: {}", path.display(), e))
2807}
2808
2809fn load_artifact_state_fn(
2810    artifact_dir: &Path,
2811    relative_file: &str,
2812    top_fn_name: &str,
2813) -> Result<(PirPackage, IrFn)> {
2814    let path = artifact_dir.join(relative_file);
2815    let pkg = ir_parser::parse_and_validate_path_to_package(&path).map_err(|e| {
2816        anyhow::anyhow!(
2817            "failed to parse artifact state package {}: {:?}",
2818            path.display(),
2819            e
2820        )
2821    })?;
2822    let state_fn = pkg.get_fn(top_fn_name).cloned().ok_or_else(|| {
2823        anyhow::anyhow!(
2824            "artifact state package {} does not contain top function '{}'",
2825            path.display(),
2826            top_fn_name
2827        )
2828    })?;
2829    Ok((pkg, state_fn))
2830}
2831
2832impl PersistedRunOptions {
2833    fn from_run_options(options: &RunOptions) -> Result<Self> {
2834        Ok(Self {
2835            max_iters: options.max_iters,
2836            threads: options.threads,
2837            chain_strategy: chain_strategy_value_name(options.chain_strategy).to_string(),
2838            checkpoint_iters: options.checkpoint_iters,
2839            progress_iters: options.progress_iters,
2840            seed: options.seed,
2841            initial_temperature: options.initial_temperature,
2842            objective: options.objective.value_name().to_string(),
2843            extension_costing_mode: options.extension_costing_mode.value_name().to_string(),
2844            g8r_evaluation_mode: options
2845                .g8r_evaluation_mode
2846                .canonicalized_for_persistence()?,
2847            canonical_g8r_options: options.canonical_g8r_options.clone(),
2848            max_allowed_depth: options.max_allowed_depth,
2849            max_allowed_area: options.max_allowed_area,
2850            switching_beta1: options.weighted_switching_options.beta1,
2851            switching_beta2: options.weighted_switching_options.beta2,
2852            switching_primary_output_load: options.weighted_switching_options.primary_output_load,
2853            enable_formal_oracle: options.enable_formal_oracle,
2854        })
2855    }
2856
2857    fn into_run_options(self) -> Result<RunOptions> {
2858        Ok(RunOptions {
2859            max_iters: self.max_iters,
2860            threads: self.threads,
2861            chain_strategy: chain_strategy_from_value_name(&self.chain_strategy)?,
2862            checkpoint_iters: self.checkpoint_iters,
2863            progress_iters: self.progress_iters,
2864            seed: self.seed,
2865            initial_temperature: self.initial_temperature,
2866            objective: Objective::from_value_name(&self.objective)?,
2867            extension_costing_mode: ExtensionCostingMode::from_value_name(
2868                &self.extension_costing_mode,
2869            )?,
2870            g8r_evaluation_mode: self.g8r_evaluation_mode,
2871            canonical_g8r_options: self.canonical_g8r_options,
2872            max_allowed_depth: self.max_allowed_depth,
2873            max_allowed_area: self.max_allowed_area,
2874            weighted_switching_options: count_toggles::WeightedSwitchingOptions {
2875                beta1: self.switching_beta1,
2876                beta2: self.switching_beta2,
2877                primary_output_load: self.switching_primary_output_load,
2878            },
2879            enable_formal_oracle: self.enable_formal_oracle,
2880            trajectory_dir: None,
2881            toggle_stimulus: None,
2882        })
2883    }
2884}
2885
2886/// Writes a durable winning-provenance artifact under
2887/// `run_dir/winning-lineage`.
2888pub fn write_pir_mcmc_artifact_dir(
2889    artifact: &PirMcmcArtifact,
2890    package_template: &PirPackage,
2891    run_dir: &Path,
2892) -> Result<PathBuf> {
2893    validate_prefix_minimization_artifact(artifact)?;
2894    let artifact_dir = run_dir.join(PIR_MCMC_ARTIFACT_DIR_NAME);
2895    std::fs::create_dir_all(&artifact_dir).map_err(|e| {
2896        anyhow::anyhow!(
2897            "failed to create artifact directory {}: {}",
2898            artifact_dir.display(),
2899            e
2900        )
2901    })?;
2902
2903    let top_fn_name = artifact.origin_fn.name.clone();
2904    let origin_file = format!("{}/origin.ir", PIR_MCMC_ARTIFACT_STATES_DIR_NAME);
2905    let raw_winner_file = format!("{}/raw-winner.ir", PIR_MCMC_ARTIFACT_STATES_DIR_NAME);
2906    write_artifact_state_package(
2907        &artifact_dir,
2908        package_template,
2909        &top_fn_name,
2910        &origin_file,
2911        &artifact.origin_fn,
2912    )?;
2913    write_artifact_state_package(
2914        &artifact_dir,
2915        package_template,
2916        &top_fn_name,
2917        &raw_winner_file,
2918        &artifact.raw_winner_fn,
2919    )?;
2920
2921    let mut winning_provenance = Vec::with_capacity(artifact.winning_provenance.len());
2922    for action in artifact.winning_provenance.iter() {
2923        let state_file = format!(
2924            "{}/action-{:06}.ir",
2925            PIR_MCMC_ARTIFACT_STATES_DIR_NAME,
2926            action.action_index()
2927        );
2928        write_artifact_state_package(
2929            &artifact_dir,
2930            package_template,
2931            &top_fn_name,
2932            &state_file,
2933            action.state(),
2934        )?;
2935        let state = PersistedArtifactState {
2936            file: state_file,
2937            cost: action.cost(),
2938        };
2939        winning_provenance.push(match action {
2940            PirMcmcProvenanceAction::AcceptedRewrite {
2941                action_index,
2942                chain_no,
2943                global_iter,
2944                transform_kind,
2945                ..
2946            } => PersistedPirMcmcProvenanceAction {
2947                kind: PersistedPirMcmcProvenanceActionKind::AcceptedRewrite,
2948                action_index: *action_index,
2949                chain_no: *chain_no,
2950                global_iter: *global_iter,
2951                transform_kind: Some(*transform_kind),
2952                state,
2953            },
2954            PirMcmcProvenanceAction::XlsOptimizedHandoff {
2955                action_index,
2956                chain_no,
2957                global_iter,
2958                ..
2959            } => PersistedPirMcmcProvenanceAction {
2960                kind: PersistedPirMcmcProvenanceActionKind::XlsOptimizedHandoff,
2961                action_index: *action_index,
2962                chain_no: *chain_no,
2963                global_iter: *global_iter,
2964                transform_kind: None,
2965                state,
2966            },
2967        });
2968    }
2969
2970    let manifest = PersistedPirMcmcArtifactManifest {
2971        schema_version: PIR_MCMC_ARTIFACT_SCHEMA_VERSION,
2972        top_fn_name,
2973        run_options: PersistedRunOptions::from_run_options(&artifact.run_options)?,
2974        origin: PersistedArtifactState {
2975            file: origin_file,
2976            cost: artifact.origin_cost,
2977        },
2978        raw_winner: PersistedArtifactState {
2979            file: raw_winner_file,
2980            cost: artifact.raw_winner_cost,
2981        },
2982        winning_provenance,
2983    };
2984    let manifest_json = serde_json::to_string_pretty(&manifest)
2985        .map_err(|e| anyhow::anyhow!("failed to serialize artifact manifest: {}", e))?;
2986    let manifest_path = artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE);
2987    std::fs::write(&manifest_path, manifest_json.as_bytes())
2988        .map_err(|e| anyhow::anyhow!("failed to write {}: {}", manifest_path.display(), e))?;
2989    Ok(artifact_dir)
2990}
2991
2992/// Loads a durable winning-provenance artifact from `run_dir/winning-lineage`.
2993pub fn read_pir_mcmc_artifact_dir(run_dir: &Path) -> Result<LoadedPirMcmcArtifact> {
2994    let artifact_dir = run_dir.join(PIR_MCMC_ARTIFACT_DIR_NAME);
2995    let manifest_path = artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE);
2996    let manifest_text = std::fs::read_to_string(&manifest_path)
2997        .map_err(|e| anyhow::anyhow!("failed to read {}: {}", manifest_path.display(), e))?;
2998    let manifest_value: serde_json::Value = serde_json::from_str(&manifest_text)
2999        .map_err(|e| anyhow::anyhow!("failed to parse {}: {}", manifest_path.display(), e))?;
3000    let schema_version = manifest_value
3001        .get("schema_version")
3002        .and_then(serde_json::Value::as_u64)
3003        .ok_or_else(|| {
3004            anyhow::anyhow!(
3005                "artifact manifest {} is missing integer schema_version",
3006                manifest_path.display()
3007            )
3008        })? as u32;
3009    if schema_version != PIR_MCMC_ARTIFACT_SCHEMA_VERSION {
3010        return Err(anyhow::anyhow!(
3011            "unsupported PIR MCMC artifact schema version {}; expected {}",
3012            schema_version,
3013            PIR_MCMC_ARTIFACT_SCHEMA_VERSION
3014        ));
3015    }
3016    let manifest: PersistedPirMcmcArtifactManifest = serde_json::from_str(&manifest_text)
3017        .map_err(|e| anyhow::anyhow!("failed to parse {}: {}", manifest_path.display(), e))?;
3018
3019    let top_fn_name = manifest.top_fn_name.clone();
3020    let (package_template, origin_fn) =
3021        load_artifact_state_fn(&artifact_dir, &manifest.origin.file, &top_fn_name)?;
3022    let (_, raw_winner_fn) =
3023        load_artifact_state_fn(&artifact_dir, &manifest.raw_winner.file, &top_fn_name)?;
3024
3025    let mut winning_provenance = Vec::with_capacity(manifest.winning_provenance.len());
3026    for action in manifest.winning_provenance.into_iter() {
3027        let (_, state) = load_artifact_state_fn(&artifact_dir, &action.state.file, &top_fn_name)?;
3028        winning_provenance.push(match action.kind {
3029            PersistedPirMcmcProvenanceActionKind::AcceptedRewrite => {
3030                let transform_kind = action.transform_kind.ok_or_else(|| {
3031                    anyhow::anyhow!(
3032                        "accepted_rewrite action {} is missing transform_kind",
3033                        action.action_index
3034                    )
3035                })?;
3036                PirMcmcProvenanceAction::AcceptedRewrite {
3037                    action_index: action.action_index,
3038                    chain_no: action.chain_no,
3039                    global_iter: action.global_iter,
3040                    transform_kind,
3041                    state,
3042                    cost: action.state.cost,
3043                }
3044            }
3045            PersistedPirMcmcProvenanceActionKind::XlsOptimizedHandoff => {
3046                if action.transform_kind.is_some() {
3047                    return Err(anyhow::anyhow!(
3048                        "xls_optimized_handoff action {} must not include transform_kind",
3049                        action.action_index
3050                    ));
3051                }
3052                PirMcmcProvenanceAction::XlsOptimizedHandoff {
3053                    action_index: action.action_index,
3054                    chain_no: action.chain_no,
3055                    global_iter: action.global_iter,
3056                    state,
3057                    cost: action.state.cost,
3058                }
3059            }
3060        });
3061    }
3062
3063    let artifact = PirMcmcArtifact {
3064        origin_fn,
3065        origin_cost: manifest.origin.cost,
3066        run_options: manifest.run_options.into_run_options()?,
3067        raw_winner_fn,
3068        raw_winner_cost: manifest.raw_winner.cost,
3069        winning_provenance,
3070    };
3071    validate_prefix_minimization_artifact(&artifact)?;
3072    Ok(LoadedPirMcmcArtifact {
3073        artifact,
3074        package_template,
3075        top_fn_name,
3076    })
3077}
3078
3079impl SegmentRunner<IrFn, Cost, PirTransformKind> for PirSegmentRunner {
3080    type Error = anyhow::Error;
3081
3082    fn run_segment(
3083        &self,
3084        start_state: IrFn,
3085        params: SegmentRunParams,
3086    ) -> Result<SegmentOutcome<IrFn, Cost, PirTransformKind>, Self::Error> {
3087        let mut trajectory_writer: Option<std::io::BufWriter<std::fs::File>> =
3088            if let Some(dir) = &self.trajectory_dir {
3089                std::fs::create_dir_all(dir).map_err(|e| {
3090                    anyhow::anyhow!("failed to create trajectory dir {}: {}", dir.display(), e)
3091                })?;
3092                let path = dir.join(format!("trajectory.c{:03}.jsonl", params.chain_no));
3093                let f = std::fs::OpenOptions::new()
3094                    .create(true)
3095                    .append(true)
3096                    .open(&path)
3097                    .map_err(|e| anyhow::anyhow!("failed to open {}: {}", path.display(), e))?;
3098                Some(std::io::BufWriter::new(f))
3099            } else {
3100                None
3101            };
3102
3103        let mut iteration_rng = Pcg64Mcg::seed_from_u64(params.seed);
3104        let all_transforms = get_pir_transforms_for_run(self.enable_formal_oracle);
3105        let weights = build_transform_weights(&all_transforms);
3106
3107        let mut context = PirMcmcContext {
3108            rng: &mut iteration_rng,
3109            all_transforms,
3110            weights,
3111            enable_formal_oracle: self.enable_formal_oracle,
3112            oracle_baseline_cache: EvalFnBaselineResults::default(),
3113        };
3114
3115        let toggle_stimulus = self.prepared_toggle_stimulus.as_ref().map(|v| v.as_slice());
3116
3117        let mut current_fn = start_state.clone();
3118        let mut current_cost =
3119            cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
3120                &current_fn,
3121                self.objective,
3122                toggle_stimulus,
3123                &self.weighted_switching_options,
3124                self.extension_costing_mode,
3125                &self.g8r_evaluation_mode,
3126                &self.canonical_g8r_options,
3127            )
3128            .map_err(|e| {
3129                anyhow::anyhow!(
3130                    "failed to evaluate initial cost for '{}' under {:?}: {}",
3131                    current_fn.name,
3132                    self.objective,
3133                    e
3134                )
3135            })?;
3136        let mut best_fn = start_state;
3137        let mut best_cost = current_cost;
3138        let mut best_score = search_score(&best_cost, self.objective, self.constraints);
3139        let mut stats = McmcStats::default();
3140
3141        let seg_start_time = Instant::now();
3142
3143        let mut iterations_count: u64 = 0;
3144        while iterations_count < params.segment_iters {
3145            iterations_count += 1;
3146            let global_iter = params.iter_offset + iterations_count;
3147
3148            let temp = match params.role {
3149                ChainRole::Explorer => self.initial_temperature * 10.0,
3150                ChainRole::Exploit => {
3151                    let progress_ratio = if params.total_iters > 0 {
3152                        (global_iter as f64) / (params.total_iters as f64)
3153                    } else {
3154                        0.0
3155                    };
3156                    let progress_ratio = progress_ratio.min(1.0);
3157                    self.initial_temperature * (1.0 - progress_ratio).max(MIN_TEMPERATURE_RATIO)
3158                }
3159            };
3160
3161            let iteration_output = mcmc_iteration(
3162                current_fn,
3163                current_cost,
3164                &mut best_fn,
3165                &mut best_cost,
3166                &mut best_score,
3167                &mut context,
3168                temp,
3169                self.objective,
3170                self.extension_costing_mode,
3171                &self.g8r_evaluation_mode,
3172                &self.canonical_g8r_options,
3173                toggle_stimulus,
3174                &self.weighted_switching_options,
3175                self.constraints,
3176            );
3177
3178            let mut accepted_digest: Option<[u8; 32]> = None;
3179            let mut accepted_sample_sent = false;
3180
3181            if let IterationOutcomeDetails::Accepted { .. } = iteration_output.outcome {
3182                if let Some(ref tx) = self.accepted_sample_tx {
3183                    match canonicalize_fn_for_sample(&iteration_output.output_state) {
3184                        Ok(canon) => match optimize_pir_fn_via_xls_with_extension_mode(
3185                            &canon,
3186                            self.extension_costing_mode,
3187                        ) {
3188                            Ok(mut opt) => {
3189                                let _ = compact_and_toposort_in_place(&mut opt);
3190                                match compute_fn_structural_digest(&opt) {
3191                                    Some(digest) => {
3192                                        accepted_digest = Some(digest);
3193                                        accepted_sample_sent = tx
3194                                            .send(AcceptedSampleMsg {
3195                                                chain_no: params.chain_no,
3196                                                global_iter,
3197                                                digest,
3198                                                cost: iteration_output.output_cost,
3199                                                func: opt,
3200                                            })
3201                                            .is_ok();
3202                                    }
3203                                    None => {
3204                                        log::warn!(
3205                                            "[pir-mcmc] failed to compute structural digest for accepted sample '{}' after XLS optimize (c{:03}:i{:06}); skipping sample emission",
3206                                            iteration_output.output_state.name,
3207                                            params.chain_no,
3208                                            global_iter
3209                                        );
3210                                    }
3211                                }
3212                            }
3213                            Err(e) => {
3214                                // The sampler wants uniqueness defined by the XLS-optimized form.
3215                                // If we cannot obtain it, skip emission rather than falling back
3216                                // to the pre-optimized state.
3217                                log::warn!(
3218                                    "[pir-mcmc] failed to XLS-optimize accepted sample '{}' (c{:03}:i{:06}): {}; skipping sample emission",
3219                                    iteration_output.output_state.name,
3220                                    params.chain_no,
3221                                    global_iter,
3222                                    e
3223                                );
3224                            }
3225                        },
3226                        Err(e) => {
3227                            log::warn!(
3228                                "[pir-mcmc] failed to canonicalize accepted sample '{}' (c{:03}:i{:06}): {}; skipping sample emission",
3229                                iteration_output.output_state.name,
3230                                params.chain_no,
3231                                global_iter,
3232                                e
3233                            );
3234                        }
3235                    }
3236                }
3237            }
3238
3239            if let Some(w) = trajectory_writer.as_mut() {
3240                let metric_u128 = self.objective.metric(&iteration_output.output_cost);
3241                let iter_score = search_score(
3242                    &iteration_output.output_cost,
3243                    self.objective,
3244                    self.constraints,
3245                );
3246                let iter_violation = iter_score.violation;
3247                let rec = json!({
3248                    "chain_no": params.chain_no,
3249                    "role": format!("{:?}", params.role),
3250                    "global_iter": global_iter,
3251                    "temp": temp,
3252                    "outcome": iteration_outcome_tag(&iteration_output.outcome),
3253                    "best_updated": iteration_output.best_updated,
3254                    "objective": format!("{:?}", self.objective),
3255                    "extension_costing_mode": self.extension_costing_mode.value_name(),
3256                    "metric": metric_u128,
3257                    "pir_nodes": iteration_output.output_cost.pir_nodes,
3258                    "g8r_nodes": iteration_output.output_cost.g8r_nodes,
3259                    "g8r_depth": iteration_output.output_cost.g8r_depth,
3260                    "g8r_le_graph_milli": iteration_output.output_cost.g8r_le_graph_milli,
3261                    "g8r_gate_output_toggles": iteration_output.output_cost.g8r_gate_output_toggles,
3262                    "g8r_weighted_switching_milli": iteration_output.output_cost.g8r_weighted_switching_milli,
3263                    "g8r_post_and_nodes": iteration_output.output_cost.g8r_post_and_nodes,
3264                    "g8r_post_depth": iteration_output.output_cost.g8r_post_depth,
3265                    "g8r_post_le_graph_milli": iteration_output.output_cost.g8r_post_le_graph_milli,
3266                    "g8r_post_gate_output_toggles": iteration_output.output_cost.g8r_post_gate_output_toggles,
3267                    "g8r_post_weighted_switching_milli": iteration_output.output_cost.g8r_post_weighted_switching_milli,
3268                    "feasible": iter_score.feasible(),
3269                    "delay_over": iter_violation.and_then(|v| v.delay_over),
3270                    "area_over": iter_violation.and_then(|v| v.area_over),
3271                    "oracle_time_micros": iteration_output.oracle_time_micros,
3272                    "transform": iteration_output.transform.map(|k| format!("{:?}", k)),
3273                    "transform_mechanism": iteration_output.transform.map(|k| k.mechanism_hint()),
3274                    "transform_always_equivalent": iteration_output.transform_always_equivalent,
3275                    "accepted_digest": accepted_digest.map(|d| hash_to_hex(&d)),
3276                    "accepted_sample_sent": accepted_sample_sent,
3277                });
3278                // Best-effort: if trajectory logging fails, abort the segment. This should
3279                // never happen and indicates an infrastructure issue (disk full, permissions,
3280                // etc.).
3281                writeln!(w, "{}", rec.to_string())?;
3282                if global_iter % 1000 == 0 {
3283                    w.flush()?;
3284                }
3285            }
3286
3287            current_fn = iteration_output.output_state.clone();
3288            current_cost = iteration_output.output_cost;
3289
3290            if iteration_output.best_updated {
3291                if let Some(ref shared_best) = self.shared_best {
3292                    let before = shared_best.score();
3293                    let _ = shared_best.try_update(best_score, best_fn.clone());
3294                    let after = shared_best.score();
3295                    if after < before {
3296                        log::info!(
3297                            "[pir-mcmc] GLOBAL BEST UPDATE c{:03}:i{:06} | {} -> {}",
3298                            params.chain_no,
3299                            global_iter,
3300                            format_search_score(before),
3301                            format_search_score(after),
3302                        );
3303                        // Best-effort: if a checkpoint writer is active, trigger an
3304                        // immediate update so the monotone global best is visible on disk.
3305                        if let Some(ref tx) = self.checkpoint_tx {
3306                            let _ = tx.send(CheckpointMsg {
3307                                chain_no: params.chain_no,
3308                                global_iter,
3309                                kind: CheckpointKind::GlobalBestUpdate,
3310                            });
3311                        }
3312                    }
3313                }
3314            }
3315
3316            if self.checkpoint_iters > 0 && global_iter % self.checkpoint_iters == 0 {
3317                if let Some(ref tx) = self.checkpoint_tx {
3318                    // Best-effort: if the receiver is gone, stop sending.
3319                    let _ = tx.send(CheckpointMsg {
3320                        chain_no: params.chain_no,
3321                        global_iter,
3322                        kind: CheckpointKind::Periodic,
3323                    });
3324                }
3325            }
3326
3327            stats.update_for_iteration(&iteration_output, /* paranoid= */ false, global_iter);
3328
3329            if self.progress_iters > 0
3330                && (global_iter % self.progress_iters == 0
3331                    || global_iter == params.total_iters
3332                    || iterations_count == params.segment_iters)
3333            {
3334                let elapsed_secs = seg_start_time.elapsed().as_secs_f64();
3335                let samples_per_sec = if elapsed_secs > 0.0 {
3336                    iterations_count as f64 / elapsed_secs
3337                } else {
3338                    0.0
3339                };
3340                log::info!(
3341                    "PIR MCMC c{:03}:i{:06} | GBest={} | LBest (pir={}, g8r_n={}, g8r_d={}, score={}) | Cur (pir={}, g8r_n={}, g8r_d={}, score={}) | Temp={:.2e} | Samples/s={:.2}",
3342                    params.chain_no,
3343                    global_iter,
3344                    self.shared_best
3345                        .as_ref()
3346                        .map(|b| format_search_score(b.score()))
3347                        .unwrap_or_else(|| "none".to_string()),
3348                    best_cost.pir_nodes,
3349                    best_cost.g8r_nodes,
3350                    best_cost.g8r_depth,
3351                    format_search_score(best_score),
3352                    current_cost.pir_nodes,
3353                    current_cost.g8r_nodes,
3354                    current_cost.g8r_depth,
3355                    format_search_score(search_score(
3356                        &current_cost,
3357                        self.objective,
3358                        self.constraints
3359                    )),
3360                    temp,
3361                    samples_per_sec,
3362                );
3363            }
3364        }
3365
3366        if let Some(mut w) = trajectory_writer {
3367            w.flush()?;
3368        }
3369
3370        Ok(SegmentOutcome {
3371            end_state: current_fn,
3372            end_cost: current_cost,
3373            best_state: best_fn,
3374            best_cost,
3375            stats,
3376        })
3377    }
3378}
3379
3380impl SegmentRunner<ProvenancedChainState, Cost, PirTransformKind> for PirArtifactSegmentRunner {
3381    type Error = anyhow::Error;
3382
3383    fn run_segment(
3384        &self,
3385        start_state: ProvenancedChainState,
3386        params: SegmentRunParams,
3387    ) -> Result<SegmentOutcome<ProvenancedChainState, Cost, PirTransformKind>, Self::Error> {
3388        let mut trajectory_writer: Option<std::io::BufWriter<std::fs::File>> =
3389            if let Some(dir) = &self.trajectory_dir {
3390                std::fs::create_dir_all(dir).map_err(|e| {
3391                    anyhow::anyhow!("failed to create trajectory dir {}: {}", dir.display(), e)
3392                })?;
3393                let path = dir.join(format!("trajectory.c{:03}.jsonl", params.chain_no));
3394                let f = std::fs::OpenOptions::new()
3395                    .create(true)
3396                    .append(true)
3397                    .open(&path)
3398                    .map_err(|e| anyhow::anyhow!("failed to open {}: {}", path.display(), e))?;
3399                Some(std::io::BufWriter::new(f))
3400            } else {
3401                None
3402            };
3403
3404        let mut iteration_rng = Pcg64Mcg::seed_from_u64(params.seed);
3405        let all_transforms = (self.transform_factory)();
3406        let weights = build_transform_weights(&all_transforms);
3407        let mut context = PirMcmcContext {
3408            rng: &mut iteration_rng,
3409            all_transforms,
3410            weights,
3411            enable_formal_oracle: self.enable_formal_oracle,
3412            oracle_baseline_cache: EvalFnBaselineResults::default(),
3413        };
3414
3415        let toggle_stimulus = self.prepared_toggle_stimulus.as_ref().map(|v| v.as_slice());
3416        let mut current_fn = start_state.search_fn.clone();
3417        let mut current_provenance = start_state.search_provenance.clone();
3418        let mut current_cost =
3419            cost_with_effort_options_toggle_stimulus_extension_mode_evaluator_and_g8r_options(
3420                &current_fn,
3421                self.objective,
3422                toggle_stimulus,
3423                &self.weighted_switching_options,
3424                self.extension_costing_mode,
3425                &self.g8r_evaluation_mode,
3426                &self.canonical_g8r_options,
3427            )
3428            .map_err(|e| {
3429                anyhow::anyhow!(
3430                    "failed to evaluate initial cost for '{}' under {:?}: {}",
3431                    current_fn.name,
3432                    self.objective,
3433                    e
3434                )
3435            })?;
3436        if let Some((chain_no, global_iter)) = start_state.pending_handoff {
3437            current_provenance.push(PirMcmcProvenanceAction::XlsOptimizedHandoff {
3438                action_index: current_provenance.len() + 1,
3439                chain_no,
3440                global_iter,
3441                state: current_fn.clone(),
3442                cost: current_cost,
3443            });
3444        }
3445        let mut raw_winner_fn = start_state.raw_winner_fn;
3446        let mut raw_winner_cost = start_state.raw_winner_cost;
3447        let mut raw_winner_provenance = start_state.raw_winner_provenance;
3448        let current_score = search_score(&current_cost, self.objective, self.constraints);
3449        if current_score < search_score(&raw_winner_cost, self.objective, self.constraints) {
3450            raw_winner_fn = current_fn.clone();
3451            raw_winner_cost = current_cost;
3452            raw_winner_provenance = current_provenance.clone();
3453        }
3454        let mut best_fn_for_iteration = start_state.search_fn.clone();
3455        let mut best_cost_for_iteration = current_cost;
3456        let mut best_score =
3457            search_score(&best_cost_for_iteration, self.objective, self.constraints);
3458        let mut best_state = ProvenancedChainState {
3459            search_fn: start_state.search_fn,
3460            search_provenance: current_provenance.clone(),
3461            raw_winner_fn,
3462            raw_winner_cost,
3463            raw_winner_provenance,
3464            pending_handoff: None,
3465        };
3466        let mut stats = McmcStats::default();
3467        let seg_start_time = Instant::now();
3468
3469        let mut iterations_count: u64 = 0;
3470        while iterations_count < params.segment_iters {
3471            iterations_count += 1;
3472            let global_iter = params.iter_offset + iterations_count;
3473            let temp = match params.role {
3474                ChainRole::Explorer => self.initial_temperature * 10.0,
3475                ChainRole::Exploit => {
3476                    let progress_ratio = if params.total_iters > 0 {
3477                        (global_iter as f64) / (params.total_iters as f64)
3478                    } else {
3479                        0.0
3480                    };
3481                    let progress_ratio = progress_ratio.min(1.0);
3482                    self.initial_temperature * (1.0 - progress_ratio).max(MIN_TEMPERATURE_RATIO)
3483                }
3484            };
3485
3486            let iteration_output = mcmc_iteration(
3487                current_fn,
3488                current_cost,
3489                &mut best_fn_for_iteration,
3490                &mut best_cost_for_iteration,
3491                &mut best_score,
3492                &mut context,
3493                temp,
3494                self.objective,
3495                self.extension_costing_mode,
3496                &self.g8r_evaluation_mode,
3497                &self.canonical_g8r_options,
3498                toggle_stimulus,
3499                &self.weighted_switching_options,
3500                self.constraints,
3501            );
3502
3503            if let IterationOutcomeDetails::Accepted { kind } = &iteration_output.outcome {
3504                current_provenance.push(PirMcmcProvenanceAction::AcceptedRewrite {
3505                    action_index: current_provenance.len() + 1,
3506                    chain_no: params.chain_no,
3507                    global_iter,
3508                    transform_kind: *kind,
3509                    state: iteration_output.output_state.clone(),
3510                    cost: iteration_output.output_cost,
3511                });
3512            }
3513
3514            if let Some(w) = trajectory_writer.as_mut() {
3515                let metric_u128 = self.objective.metric(&iteration_output.output_cost);
3516                let iter_score = search_score(
3517                    &iteration_output.output_cost,
3518                    self.objective,
3519                    self.constraints,
3520                );
3521                let iter_violation = iter_score.violation;
3522                let rec = json!({
3523                    "chain_no": params.chain_no,
3524                    "role": format!("{:?}", params.role),
3525                    "global_iter": global_iter,
3526                    "temp": temp,
3527                    "outcome": iteration_outcome_tag(&iteration_output.outcome),
3528                    "best_updated": iteration_output.best_updated,
3529                    "objective": format!("{:?}", self.objective),
3530                    "extension_costing_mode": self.extension_costing_mode.value_name(),
3531                    "metric": metric_u128,
3532                    "pir_nodes": iteration_output.output_cost.pir_nodes,
3533                    "g8r_nodes": iteration_output.output_cost.g8r_nodes,
3534                    "g8r_depth": iteration_output.output_cost.g8r_depth,
3535                    "g8r_le_graph_milli": iteration_output.output_cost.g8r_le_graph_milli,
3536                    "g8r_gate_output_toggles": iteration_output.output_cost.g8r_gate_output_toggles,
3537                    "g8r_weighted_switching_milli": iteration_output.output_cost.g8r_weighted_switching_milli,
3538                    "g8r_post_and_nodes": iteration_output.output_cost.g8r_post_and_nodes,
3539                    "g8r_post_depth": iteration_output.output_cost.g8r_post_depth,
3540                    "g8r_post_le_graph_milli": iteration_output.output_cost.g8r_post_le_graph_milli,
3541                    "g8r_post_gate_output_toggles": iteration_output.output_cost.g8r_post_gate_output_toggles,
3542                    "g8r_post_weighted_switching_milli": iteration_output.output_cost.g8r_post_weighted_switching_milli,
3543                    "feasible": iter_score.feasible(),
3544                    "delay_over": iter_violation.and_then(|v| v.delay_over),
3545                    "area_over": iter_violation.and_then(|v| v.area_over),
3546                    "oracle_time_micros": iteration_output.oracle_time_micros,
3547                    "transform": iteration_output.transform.map(|k| format!("{:?}", k)),
3548                    "transform_mechanism": iteration_output.transform.map(|k| k.mechanism_hint()),
3549                    "transform_always_equivalent": iteration_output.transform_always_equivalent,
3550                    "accepted_digest": Option::<String>::None,
3551                    "accepted_sample_sent": false,
3552                });
3553                writeln!(w, "{}", rec.to_string())?;
3554                if global_iter % 1000 == 0 {
3555                    w.flush()?;
3556                }
3557            }
3558
3559            current_fn = iteration_output.output_state.clone();
3560            current_cost = iteration_output.output_cost;
3561
3562            if iteration_output.best_updated {
3563                best_state = ProvenancedChainState {
3564                    search_fn: best_fn_for_iteration.clone(),
3565                    search_provenance: current_provenance.clone(),
3566                    raw_winner_fn: current_fn.clone(),
3567                    raw_winner_cost: current_cost,
3568                    raw_winner_provenance: current_provenance.clone(),
3569                    pending_handoff: None,
3570                };
3571
3572                if let Some(ref shared_best) = self.shared_best {
3573                    let before = shared_best.score();
3574                    let _ = shared_best.try_update(best_score, best_fn_for_iteration.clone());
3575                    let after = shared_best.score();
3576                    if after < before {
3577                        log::info!(
3578                            "[pir-mcmc] GLOBAL BEST UPDATE c{:03}:i{:06} | {} -> {}",
3579                            params.chain_no,
3580                            global_iter,
3581                            format_search_score(before),
3582                            format_search_score(after),
3583                        );
3584                        if let Some(ref tx) = self.checkpoint_tx {
3585                            let _ = tx.send(CheckpointMsg {
3586                                chain_no: params.chain_no,
3587                                global_iter,
3588                                kind: CheckpointKind::GlobalBestUpdate,
3589                            });
3590                        }
3591                    }
3592                }
3593            }
3594
3595            if self.checkpoint_iters > 0 && global_iter % self.checkpoint_iters == 0 {
3596                if let Some(ref tx) = self.checkpoint_tx {
3597                    let _ = tx.send(CheckpointMsg {
3598                        chain_no: params.chain_no,
3599                        global_iter,
3600                        kind: CheckpointKind::Periodic,
3601                    });
3602                }
3603            }
3604
3605            stats.update_for_iteration(&iteration_output, /* paranoid= */ false, global_iter);
3606
3607            if self.progress_iters > 0
3608                && (global_iter % self.progress_iters == 0
3609                    || global_iter == params.total_iters
3610                    || iterations_count == params.segment_iters)
3611            {
3612                let elapsed_secs = seg_start_time.elapsed().as_secs_f64();
3613                let samples_per_sec = if elapsed_secs > 0.0 {
3614                    iterations_count as f64 / elapsed_secs
3615                } else {
3616                    0.0
3617                };
3618                log::info!(
3619                    "PIR MCMC c{:03}:i{:06} | GBest={} | LBest (pir={}, g8r_n={}, g8r_d={}, score={}) | Cur (pir={}, g8r_n={}, g8r_d={}, score={}) | Temp={:.2e} | Samples/s={:.2}",
3620                    params.chain_no,
3621                    global_iter,
3622                    self.shared_best
3623                        .as_ref()
3624                        .map(|b| format_search_score(b.score()))
3625                        .unwrap_or_else(|| "none".to_string()),
3626                    best_cost_for_iteration.pir_nodes,
3627                    best_cost_for_iteration.g8r_nodes,
3628                    best_cost_for_iteration.g8r_depth,
3629                    format_search_score(best_score),
3630                    current_cost.pir_nodes,
3631                    current_cost.g8r_nodes,
3632                    current_cost.g8r_depth,
3633                    format_search_score(search_score(
3634                        &current_cost,
3635                        self.objective,
3636                        self.constraints
3637                    )),
3638                    temp,
3639                    samples_per_sec,
3640                );
3641            }
3642        }
3643
3644        if let Some(mut w) = trajectory_writer {
3645            w.flush()?;
3646        }
3647
3648        Ok(SegmentOutcome {
3649            end_state: ProvenancedChainState {
3650                search_fn: current_fn,
3651                search_provenance: current_provenance,
3652                raw_winner_fn: best_state.raw_winner_fn.clone(),
3653                raw_winner_cost: best_state.raw_winner_cost,
3654                raw_winner_provenance: best_state.raw_winner_provenance.clone(),
3655                pending_handoff: None,
3656            },
3657            end_cost: current_cost,
3658            best_state,
3659            best_cost: best_cost_for_iteration,
3660            stats,
3661        })
3662    }
3663}
3664
3665/// Runs a single-chain MCMC optimization over a PIR function.
3666///
3667/// This function is deterministic for fixed `start_fn` and `options`.
3668pub fn run_pir_mcmc(start_fn: IrFn, options: RunOptions) -> Result<PirMcmcResult> {
3669    run_pir_mcmc_with_shared_best(start_fn, options, None, None, None)
3670}
3671
3672/// Runs MCMC while retaining the exact provenance that led to the final raw
3673/// winning state.
3674pub fn run_pir_mcmc_with_artifact(start_fn: IrFn, options: RunOptions) -> Result<PirMcmcArtifact> {
3675    validate_pir_mcmc_artifact_run_options(&options)?;
3676    let enable_formal_oracle = options.enable_formal_oracle;
3677    Ok(run_pir_mcmc_with_artifact_using_transform_factory(
3678        start_fn,
3679        options,
3680        Arc::new(move || get_pir_transforms_for_run(enable_formal_oracle)),
3681    )?
3682    .artifact)
3683}
3684
/// Test-only artifact run that uses a caller-supplied transform list instead
/// of the standard transform set. It does not validate `options`.
///
/// The runner expects a reusable factory, but the fixture transforms are owned
/// values. They are therefore parked in a mutex-guarded `Option` and moved out
/// on the first factory call; any second call panics.
#[cfg(test)]
fn run_pir_mcmc_with_artifact_using_transforms(
    start_fn: IrFn,
    options: RunOptions,
    all_transforms: Vec<Box<dyn PirTransform>>,
) -> Result<PirMcmcArtifactRunOutput> {
    let fixture_slot = Arc::new(Mutex::new(Some(all_transforms)));
    let transform_factory = Arc::new(move || {
        let mut slot = fixture_slot
            .lock()
            .expect("artifact transform fixture lock poisoned");
        slot.take()
            .expect("artifact transform fixture can only be consumed once")
    });
    run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
        start_fn,
        options,
        transform_factory,
        None,
        None,
    )
}
3706
3707fn run_pir_mcmc_with_artifact_using_transform_factory(
3708    start_fn: IrFn,
3709    options: RunOptions,
3710    transform_factory: PirTransformFactory,
3711) -> Result<PirMcmcArtifactRunOutput> {
3712    run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
3713        start_fn,
3714        options,
3715        transform_factory,
3716        None,
3717        None,
3718    )
3719}
3720
3721fn run_pir_mcmc_with_artifact_and_observers(
3722    start_fn: IrFn,
3723    options: RunOptions,
3724    shared_best: Option<Arc<Best>>,
3725    checkpoint_tx: Option<Sender<CheckpointMsg>>,
3726) -> Result<PirMcmcArtifactRunOutput> {
3727    validate_pir_mcmc_artifact_run_options(&options)?;
3728    let enable_formal_oracle = options.enable_formal_oracle;
3729    run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
3730        start_fn,
3731        options,
3732        Arc::new(move || get_pir_transforms_for_run(enable_formal_oracle)),
3733        shared_best,
3734        checkpoint_tx,
3735    )
3736}
3737
3738fn run_pir_mcmc_with_artifact_using_transform_factory_and_observers(
3739    start_fn: IrFn,
3740    options: RunOptions,
3741    transform_factory: PirTransformFactory,
3742    shared_best: Option<Arc<Best>>,
3743    checkpoint_tx: Option<Sender<CheckpointMsg>>,
3744) -> Result<PirMcmcArtifactRunOutput> {
3745    let prepared = prepare_run_start(start_fn, &options)?;
3746    let origin_fn = prepared.start_fn.clone();
3747    let origin_cost = prepared.initial_cost;
3748    let runner = PirArtifactSegmentRunner {
3749        objective: options.objective,
3750        extension_costing_mode: options.extension_costing_mode,
3751        g8r_evaluation_mode: options.g8r_evaluation_mode.clone(),
3752        canonical_g8r_options: options.canonical_g8r_options.clone(),
3753        weighted_switching_options: options.weighted_switching_options,
3754        initial_temperature: options.initial_temperature,
3755        constraints: prepared.effective_constraints,
3756        enable_formal_oracle: options.enable_formal_oracle,
3757        progress_iters: options.progress_iters,
3758        checkpoint_iters: options.checkpoint_iters,
3759        checkpoint_tx,
3760        shared_best,
3761        trajectory_dir: options.trajectory_dir.clone(),
3762        prepared_toggle_stimulus: prepared.prepared_toggle_stimulus,
3763        transform_factory,
3764    };
3765    let objective = options.objective;
3766    let threshold = options.initial_temperature as u128;
3767    let constraints = prepared.effective_constraints;
3768    let (best_state, best_cost, stats) = run_multichain(
3769        ProvenancedChainState::origin(prepared.start_fn, prepared.initial_cost),
3770        options.max_iters,
3771        options.seed,
3772        options.threads.max(1) as usize,
3773        options.chain_strategy,
3774        options.checkpoint_iters,
3775        Arc::new(runner),
3776        move |c: &Cost| search_score(c, objective, constraints),
3777        |s: &ProvenancedChainState| s.search_fn.to_string(),
3778        move |cur_cost: &Cost, global_best_cost: &Cost| {
3779            let cur_score = search_score(cur_cost, objective, constraints);
3780            let global_best_score = search_score(global_best_cost, objective, constraints);
3781            match (cur_score.violation, global_best_score.violation) {
3782                (Some(_), None) => true,
3783                (None, Some(_)) => false,
3784                (Some(cur_violation), Some(best_violation)) => {
3785                    repair_energy(cur_violation)
3786                        > repair_energy(best_violation).saturating_add(threshold)
3787                }
3788                (None, None) => {
3789                    objective.metric(cur_cost)
3790                        > objective.metric(global_best_cost).saturating_add(threshold)
3791                }
3792            }
3793        },
3794        |best_state: &ProvenancedChainState, receiving_chain_no, global_iter| {
3795            best_state.with_xls_optimized_handoff(receiving_chain_no, global_iter)
3796        },
3797    )?;
3798
3799    Ok(PirMcmcArtifactRunOutput {
3800        result: PirMcmcResult {
3801            best_fn: best_state.search_fn.clone(),
3802            best_cost,
3803            stats,
3804        },
3805        artifact: PirMcmcArtifact {
3806            origin_fn,
3807            origin_cost,
3808            run_options: options,
3809            raw_winner_fn: best_state.raw_winner_fn,
3810            raw_winner_cost: best_state.raw_winner_cost,
3811            winning_provenance: best_state.raw_winner_provenance,
3812        },
3813    })
3814}
3815
3816/// Selects the earliest winning-provenance prefix that retains the
3817/// requested fraction of the discovered objective improvement.
3818pub fn minimize_winning_prefix(
3819    artifact: &PirMcmcArtifact,
3820    options: PirMcmcPrefixMinimizeOptions,
3821) -> Result<PirMcmcPrefixMinimizeResult> {
3822    validate_prefix_minimization_artifact(artifact)?;
3823    if !options.retained_win_fraction.is_finite()
3824        || !(0.0..=1.0).contains(&options.retained_win_fraction)
3825    {
3826        return Err(anyhow::anyhow!(
3827            "retained_win_fraction must be finite and in [0, 1]; got {}",
3828            options.retained_win_fraction
3829        ));
3830    }
3831
3832    let objective = artifact.run_options.objective;
3833    let origin_metric = objective.metric(&artifact.origin_cost);
3834    let winner_metric = objective.metric(&artifact.raw_winner_cost);
3835    if winner_metric >= origin_metric {
3836        return Err(anyhow::anyhow!(
3837            "artifact does not contain a positive objective win: origin_metric={}, winner_metric={}",
3838            origin_metric,
3839            winner_metric
3840        ));
3841    }
3842
3843    if options.retained_win_fraction == 0.0 {
3844        return Ok(PirMcmcPrefixMinimizeResult {
3845            witness_fn: artifact.origin_fn.clone(),
3846            witness_cost: artifact.origin_cost,
3847            provenance_action_count: 0,
3848            original_winning_provenance_len: artifact.winning_provenance.len(),
3849            requested_retained_win_fraction: options.retained_win_fraction,
3850            actual_retained_win_fraction: 0.0,
3851            origin_metric,
3852            winner_metric,
3853            witness_metric: origin_metric,
3854        });
3855    }
3856
3857    for action in artifact.winning_provenance.iter() {
3858        let witness_metric = objective.metric(&action.cost());
3859        let actual_retained_win_fraction =
3860            retained_win_fraction_for_metric(origin_metric, winner_metric, witness_metric);
3861        if actual_retained_win_fraction >= options.retained_win_fraction {
3862            return Ok(PirMcmcPrefixMinimizeResult {
3863                witness_fn: action.state().clone(),
3864                witness_cost: action.cost(),
3865                provenance_action_count: action.action_index(),
3866                original_winning_provenance_len: artifact.winning_provenance.len(),
3867                requested_retained_win_fraction: options.retained_win_fraction,
3868                actual_retained_win_fraction,
3869                origin_metric,
3870                winner_metric,
3871                witness_metric,
3872            });
3873        }
3874    }
3875
3876    Err(anyhow::anyhow!(
3877        "winning provenance did not contain a prefix retaining requested win fraction {}",
3878        options.retained_win_fraction
3879    ))
3880}
3881
/// Outcome of a single witness-guided rollout.
struct GuidedRolloutResult {
    /// Best budget witness found by the rollout, per `better_budget_witness`.
    best_witness: PirMcmcBudgetWitness,
}
3885
3886fn run_witness_guided_rollout(
3887    artifact: &PirMcmcArtifact,
3888    action_budget: usize,
3889    rollout_seed: u64,
3890    weights: Vec<f64>,
3891    proposal_attempts_per_rewrite: usize,
3892    transforms: Vec<Box<dyn PirTransform>>,
3893    origin_metric: u128,
3894    winner_metric: u128,
3895) -> Result<GuidedRolloutResult> {
3896    let objective = artifact.run_options.objective;
3897    let mut iteration_rng = Pcg64Mcg::seed_from_u64(rollout_seed);
3898    let mut context = PirMcmcContext {
3899        rng: &mut iteration_rng,
3900        all_transforms: transforms,
3901        weights,
3902        enable_formal_oracle: artifact.run_options.enable_formal_oracle,
3903        oracle_baseline_cache: EvalFnBaselineResults::default(),
3904    };
3905    let mut current_fn = artifact.origin_fn.clone();
3906    let mut current_cost = artifact.origin_cost;
3907    let mut best_fn_for_iteration = artifact.origin_fn.clone();
3908    let mut best_cost_for_iteration = artifact.origin_cost;
3909    let mut best_score = search_score(
3910        &best_cost_for_iteration,
3911        objective,
3912        ConstraintLimits::default(),
3913    );
3914    let mut best_witness = make_budget_witness(
3915        artifact.origin_fn.clone(),
3916        artifact.origin_cost,
3917        0,
3918        objective,
3919        origin_metric,
3920        winner_metric,
3921    );
3922    let max_proposals = action_budget.saturating_mul(proposal_attempts_per_rewrite);
3923    let mut accepted_rewrites = 0usize;
3924    let mut proposal_attempts = 0usize;
3925
3926    while accepted_rewrites < action_budget && proposal_attempts < max_proposals {
3927        proposal_attempts += 1;
3928        let progress_ratio = accepted_rewrites as f64 / action_budget.max(1) as f64;
3929        let temp = artifact.run_options.initial_temperature
3930            * (1.0 - progress_ratio.min(1.0)).max(MIN_TEMPERATURE_RATIO);
3931        let iteration_output = mcmc_iteration(
3932            current_fn,
3933            current_cost,
3934            &mut best_fn_for_iteration,
3935            &mut best_cost_for_iteration,
3936            &mut best_score,
3937            &mut context,
3938            temp,
3939            objective,
3940            artifact.run_options.extension_costing_mode,
3941            &artifact.run_options.g8r_evaluation_mode,
3942            &artifact.run_options.canonical_g8r_options,
3943            None,
3944            &artifact.run_options.weighted_switching_options,
3945            ConstraintLimits::default(),
3946        );
3947        current_fn = iteration_output.output_state.clone();
3948        current_cost = iteration_output.output_cost;
3949        if matches!(
3950            iteration_output.outcome,
3951            IterationOutcomeDetails::Accepted { .. }
3952        ) {
3953            accepted_rewrites += 1;
3954            let candidate = make_budget_witness(
3955                current_fn.clone(),
3956                current_cost,
3957                accepted_rewrites,
3958                objective,
3959                origin_metric,
3960                winner_metric,
3961            );
3962            if better_budget_witness(&candidate, &best_witness) {
3963                best_witness = candidate;
3964            }
3965        }
3966    }
3967
3968    Ok(GuidedRolloutResult { best_witness })
3969}
3970
3971/// Searches for best-found short witnesses at a schedule of provenance-action
3972/// budgets, using the long witness to bias transform proposals.
3973pub fn search_winning_budget_frontier(
3974    artifact: &PirMcmcArtifact,
3975    options: PirMcmcBudgetFrontierOptions,
3976) -> Result<PirMcmcBudgetFrontierResult> {
3977    search_winning_budget_frontier_with_rollout(
3978        artifact,
3979        options,
3980        |action_budget, _rollout_idx, rollout_seed, origin_metric, winner_metric| {
3981            let transforms = get_pir_transforms_for_run(artifact.run_options.enable_formal_oracle);
3982            let weights = build_witness_guided_transform_weights(
3983                &transforms,
3984                artifact,
3985                options.witness_kind_boost,
3986            );
3987            let rollout = run_witness_guided_rollout(
3988                artifact,
3989                action_budget,
3990                rollout_seed,
3991                weights,
3992                options.proposal_attempts_per_rewrite,
3993                transforms,
3994                origin_metric,
3995                winner_metric,
3996            )?;
3997            Ok(rollout.best_witness)
3998        },
3999    )
4000}
4001
4002fn search_winning_budget_frontier_with_rollout<F>(
4003    artifact: &PirMcmcArtifact,
4004    options: PirMcmcBudgetFrontierOptions,
4005    mut run_rollout: F,
4006) -> Result<PirMcmcBudgetFrontierResult>
4007where
4008    F: FnMut(usize, usize, u64, u128, u128) -> Result<PirMcmcBudgetWitness>,
4009{
4010    validate_prefix_minimization_artifact(artifact)?;
4011    if !objective_supports_budget_frontier_search(artifact.run_options.objective) {
4012        return Err(anyhow::anyhow!(
4013            "budget frontier search currently supports only objectives that can be recomputed without stored toggle stimulus; got {}",
4014            artifact.run_options.objective.value_name()
4015        ));
4016    }
4017    let budgets = frontier_budgets(options)?;
4018    let objective = artifact.run_options.objective;
4019    let origin_metric = objective.metric(&artifact.origin_cost);
4020    let winner_metric = objective.metric(&artifact.raw_winner_cost);
4021    if winner_metric >= origin_metric {
4022        return Err(anyhow::anyhow!(
4023            "artifact does not contain a positive objective win: origin_metric={}, winner_metric={}",
4024            origin_metric,
4025            winner_metric
4026        ));
4027    }
4028
4029    let mut carried_guided = make_budget_witness(
4030        artifact.origin_fn.clone(),
4031        artifact.origin_cost,
4032        0,
4033        objective,
4034        origin_metric,
4035        winner_metric,
4036    );
4037    let mut points = Vec::with_capacity(budgets.len());
4038    for action_budget in budgets {
4039        let prefix_baseline =
4040            prefix_baseline_for_budget(artifact, action_budget, origin_metric, winner_metric);
4041        let mut best_for_budget = carried_guided.clone();
4042        for rollout_idx in 0..options.rollouts_per_budget {
4043            let rollout_seed = options
4044                .seed
4045                .wrapping_add((action_budget as u64).wrapping_mul(1_000_003))
4046                .wrapping_add(rollout_idx as u64);
4047            let rollout_witness = run_rollout(
4048                action_budget,
4049                rollout_idx,
4050                rollout_seed,
4051                origin_metric,
4052                winner_metric,
4053            )?;
4054            if better_budget_witness(&rollout_witness, &best_for_budget) {
4055                best_for_budget = rollout_witness;
4056            }
4057        }
4058        carried_guided = best_for_budget.clone();
4059        points.push(PirMcmcBudgetFrontierPoint {
4060            action_budget,
4061            guided: best_for_budget,
4062            prefix_baseline,
4063        });
4064    }
4065
4066    Ok(PirMcmcBudgetFrontierResult {
4067        origin_metric,
4068        winner_metric,
4069        original_winning_provenance_len: artifact.winning_provenance.len(),
4070        points,
4071    })
4072}
4073
4074pub fn run_pir_mcmc_with_shared_best(
4075    start_fn: IrFn,
4076    options: RunOptions,
4077    shared_best: Option<Arc<Best>>,
4078    checkpoint_tx: Option<Sender<CheckpointMsg>>,
4079    accepted_sample_tx: Option<Sender<AcceptedSampleMsg>>,
4080) -> Result<PirMcmcResult> {
4081    let prepared = prepare_run_start(start_fn, &options)?;
4082    let runner = PirSegmentRunner {
4083        objective: options.objective,
4084        extension_costing_mode: options.extension_costing_mode,
4085        g8r_evaluation_mode: options.g8r_evaluation_mode.clone(),
4086        canonical_g8r_options: options.canonical_g8r_options.clone(),
4087        weighted_switching_options: options.weighted_switching_options,
4088        initial_temperature: options.initial_temperature,
4089        constraints: prepared.effective_constraints,
4090        enable_formal_oracle: options.enable_formal_oracle,
4091        progress_iters: options.progress_iters,
4092        checkpoint_iters: options.checkpoint_iters,
4093        checkpoint_tx,
4094        accepted_sample_tx,
4095        shared_best,
4096        trajectory_dir: options.trajectory_dir.clone(),
4097        prepared_toggle_stimulus: prepared.prepared_toggle_stimulus,
4098    };
4099
4100    let objective = options.objective;
4101    let threshold = options.initial_temperature as u128;
4102    let constraints = prepared.effective_constraints;
4103
4104    let (best_fn, best_cost, stats) = run_multichain(
4105        prepared.start_fn,
4106        options.max_iters,
4107        options.seed,
4108        options.threads.max(1) as usize,
4109        options.chain_strategy,
4110        options.checkpoint_iters,
4111        Arc::new(runner),
4112        move |c: &Cost| search_score(c, objective, constraints),
4113        |f: &IrFn| f.to_string(),
4114        move |cur_cost: &Cost, global_best_cost: &Cost| {
4115            let cur_score = search_score(cur_cost, objective, constraints);
4116            let global_best_score = search_score(global_best_cost, objective, constraints);
4117            match (cur_score.violation, global_best_score.violation) {
4118                (Some(_), None) => true,
4119                (None, Some(_)) => false,
4120                (Some(cur_violation), Some(best_violation)) => {
4121                    repair_energy(cur_violation)
4122                        > repair_energy(best_violation).saturating_add(threshold)
4123                }
4124                (None, None) => {
4125                    objective.metric(cur_cost)
4126                        > objective.metric(global_best_cost).saturating_add(threshold)
4127                }
4128            }
4129        },
4130        |best_fn: &IrFn, _, _| best_fn.clone(),
4131    )?;
4132
4133    Ok(PirMcmcResult {
4134        best_fn,
4135        best_cost,
4136        stats,
4137    })
4138}
4139
4140#[cfg(test)]
4141mod tests {
4142    use std::collections::HashSet;
4143    use std::fs;
4144    use std::os::unix::fs::PermissionsExt;
4145
4146    use super::*;
4147    use count_toggles::WeightedSwitchingOptions;
4148    use tempfile::tempdir;
4149    use xlsynth_g8r::gatify::ir2gate::{self, GatifyOptions};
4150    use xlsynth_pir::ir::{ExtNaryAddArchitecture, NodePayload};
4151    use xlsynth_pir::ir_parser;
4152    use xlsynth_pir::ir_utils::remap_payload_with;
4153
4154    fn parse_fn(ir_text: &str) -> IrFn {
4155        let mut parser = ir_parser::Parser::new(ir_text);
4156        parser.parse_fn().unwrap()
4157    }
4158
4159    fn parse_pkg(ir_text: &str) -> PirPackage {
4160        let mut parser = ir_parser::Parser::new(ir_text);
4161        parser.parse_and_validate_package().unwrap()
4162    }
4163
4164    fn write_executable_script(dir: &Path, name: &str, body: &str) -> PathBuf {
4165        let path = dir.join(name);
4166        fs::write(&path, body).unwrap();
4167        let mut permissions = fs::metadata(&path).unwrap().permissions();
4168        permissions.set_mode(0o755);
4169        fs::set_permissions(&path, permissions).unwrap();
4170        path
4171    }
4172
    /// Minimal `RunOptions` for `objective`: one iteration on one thread with
    /// independent chains, seed 1, temperature 1.0, and builtin g8r evaluation.
    /// No depth/area limits, formal oracle, trajectory dir, or toggle stimulus
    /// are set.
    fn test_run_options(objective: Objective) -> RunOptions {
        RunOptions {
            max_iters: 1,
            threads: 1,
            chain_strategy: ChainStrategy::Independent,
            checkpoint_iters: 100,
            progress_iters: 0,
            seed: 1,
            initial_temperature: 1.0,
            objective,
            extension_costing_mode: ExtensionCostingMode::Preserve,
            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
            canonical_g8r_options: CanonicalG8rOptions::default(),
            max_allowed_depth: None,
            max_allowed_area: None,
            weighted_switching_options: WeightedSwitchingOptions::default(),
            enable_formal_oracle: false,
            trajectory_dir: None,
            toggle_stimulus: None,
        }
    }
4194
    /// Builds a `Cost` whose PIR node count, g8r node count, and g8r depth all
    /// equal `pir_nodes`. Every other metric is zero.
    fn cost_with_pir_nodes(pir_nodes: usize) -> Cost {
        Cost {
            pir_nodes,
            g8r_nodes: pir_nodes,
            g8r_depth: pir_nodes,
            g8r_le_graph_milli: 0,
            g8r_gate_output_toggles: 0,
            g8r_weighted_switching_milli: 0,
            g8r_post_and_nodes: 0,
            g8r_post_depth: 0,
            g8r_post_le_graph_milli: 0,
            g8r_post_gate_output_toggles: 0,
            g8r_post_weighted_switching_milli: 0,
        }
    }
4210
4211    fn renamed_state(base: &IrFn, name: &str) -> IrFn {
4212        let mut f = base.clone();
4213        f.name = name.to_string();
4214        f
4215    }
4216
    /// Hand-built artifact for the prefix/frontier tests.
    ///
    /// The origin costs 100 PIR nodes and the raw winner costs 50. The
    /// three-step winning provenance passes through 90, 70, and 50. Every state
    /// shares the origin's body and differs only by name ("step1".."step3"),
    /// so tests can tell which state was selected.
    fn manual_prefix_artifact() -> PirMcmcArtifact {
        let origin = parse_fn(
            r#"fn origin(x: bits[8] id=1) -> bits[8] {
  ret identity.2: bits[8] = identity(x, id=2)
}"#,
        );
        let step1_state = renamed_state(&origin, "step1");
        let step2_state = renamed_state(&origin, "step2");
        let step3_state = renamed_state(&origin, "step3");
        PirMcmcArtifact {
            origin_fn: origin.clone(),
            origin_cost: cost_with_pir_nodes(100),
            run_options: test_run_options(Objective::Nodes),
            raw_winner_fn: step3_state.clone(),
            raw_winner_cost: cost_with_pir_nodes(50),
            winning_provenance: vec![
                PirMcmcProvenanceAction::AcceptedRewrite {
                    action_index: 1,
                    chain_no: 0,
                    global_iter: 2,
                    transform_kind: PirTransformKind::NotNotCancel,
                    state: step1_state,
                    cost: cost_with_pir_nodes(90),
                },
                PirMcmcProvenanceAction::AcceptedRewrite {
                    action_index: 2,
                    chain_no: 0,
                    global_iter: 4,
                    transform_kind: PirTransformKind::NegNegCancel,
                    state: step2_state,
                    cost: cost_with_pir_nodes(70),
                },
                PirMcmcProvenanceAction::AcceptedRewrite {
                    action_index: 3,
                    chain_no: 0,
                    global_iter: 7,
                    transform_kind: PirTransformKind::SelSameArmsFold,
                    state: step3_state,
                    cost: cost_with_pir_nodes(50),
                },
            ],
        }
    }
4260
4261    #[derive(Debug)]
4262    struct RemoveDeadNodeTestTransform;
4263
4264    impl PirTransform for RemoveDeadNodeTestTransform {
4265        fn kind(&self) -> PirTransformKind {
4266            PirTransformKind::NotNotCancel
4267        }
4268
4269        fn find_candidates(&mut self, f: &IrFn) -> Vec<crate::transforms::TransformCandidate> {
4270            f.node_refs()
4271                .into_iter()
4272                .filter(|nr| {
4273                    f.get_node(*nr)
4274                        .name
4275                        .as_deref()
4276                        .map(|name| name.starts_with("dead"))
4277                        .unwrap_or(false)
4278                })
4279                .map(|nr| crate::transforms::TransformCandidate {
4280                    location: crate::transforms::TransformLocation::Node(nr),
4281                    always_equivalent: true,
4282                })
4283                .collect()
4284        }
4285
4286        fn apply(
4287            &self,
4288            f: &mut IrFn,
4289            loc: &crate::transforms::TransformLocation,
4290        ) -> Result<(), String> {
4291            let crate::transforms::TransformLocation::Node(nr) = loc else {
4292                return Err("RemoveDeadNodeTestTransform expects a node location".to_string());
4293            };
4294            f.get_node_mut(*nr).payload = NodePayload::Nil;
4295            Ok(())
4296        }
4297    }
4298
4299    #[test]
4300    fn pir_mcmc_runs_and_is_deterministic_on_simple_add() {
4301        let ir_text = r#"fn add(x: bits[8] id=10, y: bits[8] id=20) -> bits[8] {
4302  ret add.42: bits[8] = add(x, y, id=42)
4303}"#;
4304        let mut parser = ir_parser::Parser::new(ir_text);
4305        let ir_fn = parser.parse_fn().unwrap();
4306
4307        let opts = RunOptions {
4308            max_iters: 10,
4309            threads: 1,
4310            chain_strategy: ChainStrategy::Independent,
4311            checkpoint_iters: 5000,
4312            progress_iters: 0,
4313            seed: 1,
4314            initial_temperature: 5.0,
4315            objective: Objective::Nodes,
4316            extension_costing_mode: ExtensionCostingMode::Preserve,
4317            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
4318            canonical_g8r_options: CanonicalG8rOptions::default(),
4319            max_allowed_depth: None,
4320            max_allowed_area: None,
4321            weighted_switching_options: WeightedSwitchingOptions::default(),
4322            enable_formal_oracle: false,
4323            trajectory_dir: None,
4324            toggle_stimulus: None,
4325        };
4326
4327        let res1 = run_pir_mcmc(ir_fn.clone(), opts.clone()).unwrap();
4328        let res2 = run_pir_mcmc(ir_fn.clone(), opts).unwrap();
4329
4330        assert_eq!(res1.best_cost.pir_nodes, ir_fn.nodes.len());
4331        assert_eq!(res2.best_cost.pir_nodes, ir_fn.nodes.len());
4332
4333        // Determinism: for fixed seed and options, we should get the same best
4334        // function text.
4335        assert_eq!(res1.best_fn.to_string(), res2.best_fn.to_string());
4336    }
4337
    /// Against `manual_prefix_artifact` (origin 100, steps 90/70/50, winner 50):
    /// * retaining 100% of the win needs all three actions;
    /// * retaining 60% is first met by the second action (metric 70);
    /// * retaining 0% returns the origin with no actions.
    #[test]
    fn prefix_minimization_selects_earliest_prefix_meeting_requested_win() {
        let artifact = manual_prefix_artifact();

        let retain_all = minimize_winning_prefix(
            &artifact,
            PirMcmcPrefixMinimizeOptions {
                retained_win_fraction: 1.0,
            },
        )
        .unwrap();
        assert_eq!(retain_all.provenance_action_count, 3);
        assert_eq!(retain_all.witness_metric, 50);

        // (100 - 70) / (100 - 50) == 0.6, so step 2 is the earliest match.
        let retain_most = minimize_winning_prefix(
            &artifact,
            PirMcmcPrefixMinimizeOptions {
                retained_win_fraction: 0.6,
            },
        )
        .unwrap();
        assert_eq!(retain_most.provenance_action_count, 2);
        assert_eq!(retain_most.witness_fn.name, "step2");
        assert_eq!(retain_most.witness_metric, 70);
        assert_eq!(retain_most.original_winning_provenance_len, 3);
        assert!((retain_most.actual_retained_win_fraction - 0.6).abs() < 1e-12);

        let retain_none = minimize_winning_prefix(
            &artifact,
            PirMcmcPrefixMinimizeOptions {
                retained_win_fraction: 0.0,
            },
        )
        .unwrap();
        assert_eq!(retain_none.provenance_action_count, 0);
        assert_eq!(retain_none.witness_fn.name, "origin");
        assert_eq!(retain_none.witness_metric, 100);
    }
4376
    /// NaN and out-of-range fractions are rejected with an error naming
    /// `retained_win_fraction`. An artifact whose origin already matches the
    /// winner is rejected as having no positive objective win.
    #[test]
    fn prefix_minimization_rejects_invalid_fraction_and_non_winning_artifacts() {
        let artifact = manual_prefix_artifact();
        for retained_win_fraction in [f64::NAN, -0.1, 1.1] {
            let err = minimize_winning_prefix(
                &artifact,
                PirMcmcPrefixMinimizeOptions {
                    retained_win_fraction,
                },
            )
            .unwrap_err();
            assert!(
                err.to_string().contains("retained_win_fraction"),
                "unexpected error: {err}"
            );
        }

        // Origin cost == winner cost means there is no win to minimize.
        let mut no_win = artifact.clone();
        no_win.origin_cost = no_win.raw_winner_cost;
        let err = minimize_winning_prefix(
            &no_win,
            PirMcmcPrefixMinimizeOptions {
                retained_win_fraction: 0.5,
            },
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("positive objective win"),
            "unexpected error: {err}"
        );
    }
4408
    /// `frontier_budgets` steps by `budget_step` up to `max_actions`. It ends
    /// with `max_actions` even when that is not a multiple of the step. It
    /// rejects a zero step, a zero maximum, and a step larger than the maximum.
    #[test]
    fn frontier_schedule_validates_step_and_max() {
        let opts = PirMcmcBudgetFrontierOptions {
            budget_step: 4,
            max_actions: 16,
            rollouts_per_budget: 1,
            seed: 1,
            witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
            proposal_attempts_per_rewrite:
                PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
        };
        assert_eq!(frontier_budgets(opts).unwrap(), vec![4, 8, 12, 16]);
        // A non-multiple maximum is still included as the final budget.
        assert_eq!(
            frontier_budgets(PirMcmcBudgetFrontierOptions {
                max_actions: 10,
                ..opts
            })
            .unwrap(),
            vec![4, 8, 10]
        );
        assert!(
            frontier_budgets(PirMcmcBudgetFrontierOptions {
                budget_step: 0,
                ..opts
            })
            .is_err()
        );
        assert!(
            frontier_budgets(PirMcmcBudgetFrontierOptions {
                max_actions: 0,
                ..opts
            })
            .is_err()
        );
        assert!(
            frontier_budgets(PirMcmcBudgetFrontierOptions {
                budget_step: 8,
                max_actions: 4,
                ..opts
            })
            .is_err()
        );
    }
4452
4453    #[test]
4454    fn win_percent_vs_origin_reports_percentage_points() {
4455        assert!(
4456            (win_percent_vs_origin_for_metric(1205, 1173) - 2.655_601_659_751_037).abs() < 1e-12
4457        );
4458    }
4459
    /// With a boost of 4.0, a transform whose kind appears in the artifact's
    /// winning provenance (`NotNotCancel`) gets weight 5.0. A kind absent from
    /// it (`CmpSwap`) keeps weight 1.0, so it stays selectable.
    #[test]
    fn witness_guided_weights_favor_lineage_kinds_without_excluding_others() {
        // Transform that only reports a kind; it never proposes or applies
        // anything, so only its kind can influence the weights.
        #[derive(Debug)]
        struct KindOnlyTransform(PirTransformKind);
        impl PirTransform for KindOnlyTransform {
            fn kind(&self) -> PirTransformKind {
                self.0
            }

            fn find_candidates(&mut self, _f: &IrFn) -> Vec<crate::transforms::TransformCandidate> {
                Vec::new()
            }

            fn apply(
                &self,
                _f: &mut IrFn,
                _loc: &crate::transforms::TransformLocation,
            ) -> Result<(), String> {
                Ok(())
            }
        }

        let artifact = manual_prefix_artifact();
        let transforms: Vec<Box<dyn PirTransform>> = vec![
            Box::new(KindOnlyTransform(PirTransformKind::NotNotCancel)),
            Box::new(KindOnlyTransform(PirTransformKind::CmpSwap)),
        ];
        let weights = build_witness_guided_transform_weights(&transforms, &artifact, 4.0);
        assert_eq!(weights, vec![5.0, 1.0]);
    }
4490
    /// Stub rollouts return metrics 95, 80, and 85 for budgets 1, 2, and 3.
    /// The guided frontier reports 95 and 80, then carries 80 forward instead
    /// of taking the worse 85. Prefix baselines follow the stored provenance
    /// (90, 70, 50).
    #[test]
    fn frontier_reports_prefix_baseline_and_carries_guided_points_forward() {
        let artifact = manual_prefix_artifact();
        let opts = PirMcmcBudgetFrontierOptions {
            budget_step: 1,
            max_actions: 3,
            rollouts_per_budget: 1,
            seed: 7,
            witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
            proposal_attempts_per_rewrite:
                PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
        };
        let result = search_winning_budget_frontier_with_rollout(
            &artifact,
            opts,
            |budget, _, _, origin_metric, winner_metric| {
                // Budget 3 deliberately regresses relative to budget 2.
                let (state_name, cost) = match budget {
                    1 => ("guided1", cost_with_pir_nodes(95)),
                    2 => ("guided2", cost_with_pir_nodes(80)),
                    3 => ("guided3", cost_with_pir_nodes(85)),
                    _ => unreachable!(),
                };
                Ok(make_budget_witness(
                    renamed_state(&artifact.origin_fn, state_name),
                    cost,
                    budget,
                    Objective::Nodes,
                    origin_metric,
                    winner_metric,
                ))
            },
        )
        .unwrap();

        assert_eq!(result.points.len(), 3);
        assert_eq!(result.points[0].prefix_baseline.metric, 90);
        assert_eq!(result.points[1].prefix_baseline.metric, 70);
        assert_eq!(result.points[2].prefix_baseline.metric, 50);
        assert_eq!(result.points[0].guided.metric, 95);
        assert_eq!(result.points[1].guided.metric, 80);
        assert_eq!(
            result.points[2].guided.metric, 80,
            "worse later searches must carry forward the prior frontier point"
        );
    }
4536
    /// When every rollout merely reproduces the origin cost, each frontier point
    /// keeps the zero-action origin witness (metric 100) rather than the
    /// rollout's equally-good but longer witness.
    #[test]
    fn frontier_uses_origin_fallback_when_rollouts_do_not_improve() {
        let artifact = manual_prefix_artifact();
        let opts = PirMcmcBudgetFrontierOptions {
            budget_step: 2,
            max_actions: 4,
            rollouts_per_budget: 1,
            seed: 1,
            witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
            proposal_attempts_per_rewrite:
                PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
        };
        let result = search_winning_budget_frontier_with_rollout(
            &artifact,
            opts,
            |budget, _, _, origin_metric, winner_metric| {
                Ok(make_budget_witness(
                    renamed_state(&artifact.origin_fn, &format!("noop{budget}")),
                    artifact.origin_cost,
                    budget,
                    Objective::Nodes,
                    origin_metric,
                    winner_metric,
                ))
            },
        )
        .unwrap();
        assert_eq!(result.points[0].guided.provenance_action_count, 0);
        assert_eq!(result.points[0].guided.metric, 100);
        assert_eq!(result.points[1].guided.provenance_action_count, 0);
        assert_eq!(result.points[1].guided.metric, 100);
    }
4569
4570    #[test]
4571    fn frontier_rejects_no_win_and_toggle_objectives() {
4572        let opts = PirMcmcBudgetFrontierOptions {
4573            budget_step: 1,
4574            max_actions: 1,
4575            rollouts_per_budget: 1,
4576            seed: 1,
4577            witness_kind_boost: PirMcmcBudgetFrontierOptions::DEFAULT_WITNESS_KIND_BOOST,
4578            proposal_attempts_per_rewrite:
4579                PirMcmcBudgetFrontierOptions::DEFAULT_PROPOSAL_ATTEMPTS_PER_REWRITE,
4580        };
4581        let mut no_win = manual_prefix_artifact();
4582        no_win.raw_winner_cost = no_win.origin_cost;
4583        assert!(search_winning_budget_frontier(&no_win, opts).is_err());
4584
4585        let mut toggle_artifact = manual_prefix_artifact();
4586        toggle_artifact.run_options.objective = Objective::G8rNodesTimesDepthTimesToggles;
4587        toggle_artifact.raw_winner_cost.g8r_gate_output_toggles = 1;
4588        assert!(search_winning_budget_frontier(&toggle_artifact, opts).is_err());
4589    }
4590
4591    #[test]
4592    fn artifact_api_rejects_unsupported_run_shapes() {
4593        let f = parse_fn(
4594            r#"fn f(x: bits[1] id=1) -> bits[1] {
4595  ret identity.2: bits[1] = identity(x, id=2)
4596}"#,
4597        );
4598
4599        let mut constrained = test_run_options(Objective::G8rNodes);
4600        constrained.max_allowed_depth = Some(4);
4601        assert!(run_pir_mcmc_with_artifact(f.clone(), constrained).is_err());
4602
4603        let mut implicit_constraint =
4604            test_run_options(Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress);
4605        implicit_constraint.toggle_stimulus = Some(vec![
4606            IrValue::parse_typed("(bits[1]:0)").unwrap(),
4607            IrValue::parse_typed("(bits[1]:1)").unwrap(),
4608        ]);
4609        assert!(run_pir_mcmc_with_artifact(f, implicit_constraint).is_err());
4610    }
4611
4612    #[test]
4613    fn artifact_api_supports_independent_multichain_runs() {
4614        let f = parse_fn(
4615            r#"fn f(x: bits[8] id=1) -> bits[8] {
4616  dead: bits[8] = identity(x, id=2)
4617  ret live: bits[8] = identity(x, id=3)
4618}"#,
4619        );
4620        let mut opts = test_run_options(Objective::Nodes);
4621        opts.max_iters = 1;
4622        opts.threads = 2;
4623        let artifact = run_pir_mcmc_with_artifact_using_transform_factory(
4624            f,
4625            opts,
4626            Arc::new(|| vec![Box::new(RemoveDeadNodeTestTransform)]),
4627        )
4628        .unwrap()
4629        .artifact;
4630
4631        assert_eq!(artifact.winning_provenance.len(), 1);
4632        let PirMcmcProvenanceAction::AcceptedRewrite {
4633            action_index,
4634            chain_no,
4635            ..
4636        } = artifact.winning_provenance.last().unwrap()
4637        else {
4638            panic!("expected accepted rewrite provenance");
4639        };
4640        assert_eq!(*action_index, 1);
4641        assert_eq!(*chain_no, 0);
4642        assert_eq!(
4643            artifact
4644                .winning_provenance
4645                .last()
4646                .unwrap()
4647                .state()
4648                .to_string(),
4649            artifact.raw_winner_fn.to_string()
4650        );
4651    }
4652
4653    #[test]
4654    fn artifact_api_supports_explore_exploit_multichain_runs() {
4655        let f = parse_fn(
4656            r#"fn f(x: bits[1] id=1) -> bits[1] {
4657  ret identity.2: bits[1] = identity(x, id=2)
4658}"#,
4659        );
4660        let mut opts = test_run_options(Objective::Nodes);
4661        opts.max_iters = 0;
4662        opts.threads = 2;
4663        opts.chain_strategy = ChainStrategy::ExploreExploit;
4664        let artifact = run_pir_mcmc_with_artifact(f, opts).unwrap();
4665        assert!(artifact.winning_provenance.is_empty());
4666        assert_eq!(artifact.raw_winner_cost, artifact.origin_cost);
4667    }
4668
4669    #[test]
4670    fn artifact_segment_records_handoff_before_later_rewrite() {
4671        let f = parse_fn(
4672            r#"fn f(x: bits[8] id=1) -> bits[8] {
4673  dead: bits[8] = identity(x, id=2)
4674  ret live: bits[8] = identity(x, id=3)
4675}"#,
4676        );
4677        let origin_cost = cost_with_effort_options_toggle_stimulus_and_extension_mode(
4678            &f,
4679            Objective::Nodes,
4680            None,
4681            &WeightedSwitchingOptions::default(),
4682            ExtensionCostingMode::Preserve,
4683        )
4684        .unwrap();
4685        let runner = PirArtifactSegmentRunner {
4686            objective: Objective::Nodes,
4687            extension_costing_mode: ExtensionCostingMode::Preserve,
4688            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
4689            canonical_g8r_options: CanonicalG8rOptions::default(),
4690            weighted_switching_options: WeightedSwitchingOptions::default(),
4691            initial_temperature: 1.0,
4692            constraints: ConstraintLimits::default(),
4693            enable_formal_oracle: false,
4694            progress_iters: 0,
4695            checkpoint_iters: 0,
4696            checkpoint_tx: None,
4697            shared_best: None,
4698            trajectory_dir: None,
4699            prepared_toggle_stimulus: None,
4700            transform_factory: Arc::new(|| vec![Box::new(RemoveDeadNodeTestTransform)]),
4701        };
4702        let out = runner
4703            .run_segment(
4704                ProvenancedChainState::origin(f, origin_cost).with_xls_optimized_handoff(1, 7),
4705                SegmentRunParams {
4706                    chain_no: 1,
4707                    role: ChainRole::Exploit,
4708                    iter_offset: 7,
4709                    segment_iters: 1,
4710                    total_iters: 8,
4711                    seed: 1,
4712                },
4713            )
4714            .unwrap();
4715
4716        assert_eq!(out.best_state.raw_winner_provenance.len(), 2);
4717        assert!(matches!(
4718            out.best_state.raw_winner_provenance[0],
4719            PirMcmcProvenanceAction::XlsOptimizedHandoff {
4720                action_index: 1,
4721                chain_no: 1,
4722                global_iter: 7,
4723                ..
4724            }
4725        ));
4726        assert!(matches!(
4727            out.best_state.raw_winner_provenance[1],
4728            PirMcmcProvenanceAction::AcceptedRewrite {
4729                action_index: 2,
4730                chain_no: 1,
4731                global_iter: 8,
4732                ..
4733            }
4734        ));
4735    }
4736
4737    #[test]
4738    fn artifact_run_is_deterministic_and_captures_raw_winning_provenance() {
4739        let f = parse_fn(
4740            r#"fn f(x: bits[8] id=1) -> bits[8] {
4741  dead: bits[8] = identity(x, id=2)
4742  ret live: bits[8] = identity(x, id=3)
4743}"#,
4744        );
4745        let mut opts = test_run_options(Objective::Nodes);
4746        opts.max_iters = 1;
4747        opts.seed = 7;
4748        let transforms1: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4749        let transforms2: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4750
4751        let artifact1 =
4752            run_pir_mcmc_with_artifact_using_transforms(f.clone(), opts.clone(), transforms1)
4753                .unwrap()
4754                .artifact;
4755        let artifact2 = run_pir_mcmc_with_artifact_using_transforms(f, opts, transforms2)
4756            .unwrap()
4757            .artifact;
4758
4759        assert!(
4760            artifact1.raw_winner_cost.pir_nodes < artifact1.origin_cost.pir_nodes,
4761            "expected a real node-count win in the deterministic fixture"
4762        );
4763        assert_eq!(
4764            artifact1.origin_fn.to_string(),
4765            artifact2.origin_fn.to_string()
4766        );
4767        assert_eq!(artifact1.origin_cost, artifact2.origin_cost);
4768        assert_eq!(artifact1.raw_winner_cost, artifact2.raw_winner_cost);
4769        assert_eq!(
4770            artifact1.raw_winner_fn.to_string(),
4771            artifact2.raw_winner_fn.to_string()
4772        );
4773        assert_eq!(
4774            artifact1.winning_provenance.len(),
4775            artifact2.winning_provenance.len()
4776        );
4777        assert!(
4778            !artifact1.winning_provenance.is_empty(),
4779            "expected a non-empty winning provenance"
4780        );
4781        let last1 = artifact1.winning_provenance.last().unwrap();
4782        let last2 = artifact2.winning_provenance.last().unwrap();
4783        assert_eq!(last1.cost(), artifact1.raw_winner_cost);
4784        assert_eq!(
4785            last1.state().to_string(),
4786            artifact1.raw_winner_fn.to_string()
4787        );
4788        assert_eq!(last1.cost(), last2.cost());
4789        assert_eq!(last1.state().to_string(), last2.state().to_string());
4790        assert_eq!(last1.transform_kind(), last2.transform_kind());
4791    }
4792
4793    #[test]
4794    fn artifact_provenance_can_be_minimized_end_to_end() {
4795        let f = parse_fn(
4796            r#"fn f(x: bits[8] id=1) -> bits[8] {
4797  dead_a: bits[8] = identity(x, id=2)
4798  dead_b: bits[8] = identity(x, id=3)
4799  ret live: bits[8] = identity(x, id=4)
4800}"#,
4801        );
4802        let mut opts = test_run_options(Objective::Nodes);
4803        opts.max_iters = 2;
4804        let transforms: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4805
4806        let artifact = run_pir_mcmc_with_artifact_using_transforms(f, opts, transforms)
4807            .unwrap()
4808            .artifact;
4809        assert_eq!(artifact.origin_cost.pir_nodes, 5);
4810        assert_eq!(artifact.raw_winner_cost.pir_nodes, 3);
4811        assert_eq!(artifact.winning_provenance.len(), 2);
4812
4813        let minimized = minimize_winning_prefix(
4814            &artifact,
4815            PirMcmcPrefixMinimizeOptions {
4816                retained_win_fraction: 0.5,
4817            },
4818        )
4819        .unwrap();
4820        assert_eq!(minimized.provenance_action_count, 1);
4821        assert_eq!(minimized.original_winning_provenance_len, 2);
4822        assert_eq!(minimized.origin_metric, 5);
4823        assert_eq!(minimized.winner_metric, 3);
4824        assert_eq!(minimized.witness_metric, 4);
4825        assert!((minimized.actual_retained_win_fraction - 0.5).abs() < 1e-12);
4826    }
4827
4828    #[test]
4829    fn durable_artifact_round_trips_and_minimizes_identically() {
4830        let pkg = parse_pkg(
4831            r#"package sample
4832
4833top fn f(x: bits[8] id=1) -> bits[8] {
4834  dead_a: bits[8] = identity(x, id=2)
4835  dead_b: bits[8] = identity(x, id=3)
4836  ret live: bits[8] = identity(x, id=4)
4837}
4838"#,
4839        );
4840        let f = pkg.get_fn("f").unwrap().clone();
4841        let mut opts = test_run_options(Objective::Nodes);
4842        opts.max_iters = 2;
4843        let transforms: Vec<Box<dyn PirTransform>> = vec![Box::new(RemoveDeadNodeTestTransform)];
4844        let artifact = run_pir_mcmc_with_artifact_using_transforms(f, opts, transforms)
4845            .unwrap()
4846            .artifact;
4847        let before = minimize_winning_prefix(
4848            &artifact,
4849            PirMcmcPrefixMinimizeOptions {
4850                retained_win_fraction: 0.5,
4851            },
4852        )
4853        .unwrap();
4854
4855        let run_dir = tempdir().unwrap();
4856        write_pir_mcmc_artifact_dir(&artifact, &pkg, run_dir.path()).unwrap();
4857        let loaded = read_pir_mcmc_artifact_dir(run_dir.path()).unwrap();
4858        let after = minimize_winning_prefix(
4859            &loaded.artifact,
4860            PirMcmcPrefixMinimizeOptions {
4861                retained_win_fraction: 0.5,
4862            },
4863        )
4864        .unwrap();
4865
4866        assert_eq!(loaded.top_fn_name, "f");
4867        assert_eq!(loaded.artifact.origin_cost, artifact.origin_cost);
4868        assert_eq!(
4869            loaded.artifact.origin_fn.to_string(),
4870            artifact.origin_fn.to_string()
4871        );
4872        assert_eq!(
4873            loaded.artifact.raw_winner_fn.to_string(),
4874            artifact.raw_winner_fn.to_string()
4875        );
4876        assert_eq!(
4877            loaded.artifact.winning_provenance.len(),
4878            artifact.winning_provenance.len()
4879        );
4880        assert_eq!(
4881            loaded.artifact.run_options.g8r_evaluation_mode,
4882            artifact.run_options.g8r_evaluation_mode
4883        );
4884        assert_eq!(
4885            loaded.artifact.winning_provenance[0].transform_kind(),
4886            artifact.winning_provenance[0].transform_kind()
4887        );
4888        assert_eq!(
4889            after.provenance_action_count,
4890            before.provenance_action_count
4891        );
4892        assert_eq!(after.witness_metric, before.witness_metric);
4893        assert_eq!(after.witness_fn.to_string(), before.witness_fn.to_string());
4894    }
4895
4896    #[test]
4897    fn durable_artifact_manifest_is_deterministic_for_fixed_run() {
4898        let pkg = parse_pkg(
4899            r#"package sample
4900
4901top fn f(x: bits[8] id=1) -> bits[8] {
4902  dead_a: bits[8] = identity(x, id=2)
4903  dead_b: bits[8] = identity(x, id=3)
4904  ret live: bits[8] = identity(x, id=4)
4905}
4906"#,
4907        );
4908        let f = pkg.get_fn("f").unwrap().clone();
4909        let mut opts = test_run_options(Objective::Nodes);
4910        opts.max_iters = 2;
4911        let artifact1 = run_pir_mcmc_with_artifact_using_transforms(
4912            f.clone(),
4913            opts.clone(),
4914            vec![Box::new(RemoveDeadNodeTestTransform)],
4915        )
4916        .unwrap()
4917        .artifact;
4918        let artifact2 = run_pir_mcmc_with_artifact_using_transforms(
4919            f,
4920            opts,
4921            vec![Box::new(RemoveDeadNodeTestTransform)],
4922        )
4923        .unwrap()
4924        .artifact;
4925
4926        let run_dir1 = tempdir().unwrap();
4927        let run_dir2 = tempdir().unwrap();
4928        let artifact_dir1 = write_pir_mcmc_artifact_dir(&artifact1, &pkg, run_dir1.path()).unwrap();
4929        let artifact_dir2 = write_pir_mcmc_artifact_dir(&artifact2, &pkg, run_dir2.path()).unwrap();
4930        let manifest1 =
4931            fs::read_to_string(artifact_dir1.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE)).unwrap();
4932        let manifest2 =
4933            fs::read_to_string(artifact_dir2.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE)).unwrap();
4934        assert_eq!(manifest1, manifest2);
4935    }
4936
4937    #[test]
4938    fn durable_artifact_manifest_canonicalizes_relative_postprocessor_path() {
4939        let pkg = parse_pkg(
4940            r#"package sample
4941
4942top fn f(x: bits[8] id=1) -> bits[8] {
4943  dead: bits[8] = identity(x, id=2)
4944  ret live: bits[8] = identity(x, id=3)
4945}
4946"#,
4947        );
4948        let f = pkg.get_fn("f").unwrap().clone();
4949        let cwd = std::env::current_dir().unwrap();
4950        let hook_dir = tempfile::tempdir_in(&cwd).unwrap();
4951        let hook = hook_dir.path().join("identity.sh");
4952        fs::write(&hook, "#!/bin/sh\n").unwrap();
4953        let relative_hook = hook.strip_prefix(&cwd).unwrap().display().to_string();
4954        let mut artifact = run_pir_mcmc_with_artifact_using_transforms(
4955            f,
4956            test_run_options(Objective::Nodes),
4957            vec![Box::new(RemoveDeadNodeTestTransform)],
4958        )
4959        .unwrap()
4960        .artifact;
4961        artifact.run_options.g8r_evaluation_mode = G8rEvaluationMode::ExternalPostprocess {
4962            program: relative_hook,
4963        };
4964        let run_dir = tempdir().unwrap();
4965        let artifact_dir = write_pir_mcmc_artifact_dir(&artifact, &pkg, run_dir.path()).unwrap();
4966        let manifest: serde_json::Value = serde_json::from_str(
4967            &fs::read_to_string(artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE)).unwrap(),
4968        )
4969        .unwrap();
4970        assert_eq!(
4971            manifest["run_options"]["g8r_evaluation_mode"]["program"],
4972            std::fs::canonicalize(&hook).unwrap().display().to_string()
4973        );
4974    }
4975
4976    #[test]
4977    fn durable_artifact_rejects_malformed_action_records() {
4978        let pkg = parse_pkg(
4979            r#"package sample
4980
4981top fn f(x: bits[8] id=1) -> bits[8] {
4982  dead: bits[8] = identity(x, id=2)
4983  ret live: bits[8] = identity(x, id=3)
4984}
4985"#,
4986        );
4987        let f = pkg.get_fn("f").unwrap().clone();
4988        let artifact = run_pir_mcmc_with_artifact_using_transforms(
4989            f,
4990            test_run_options(Objective::Nodes),
4991            vec![Box::new(RemoveDeadNodeTestTransform)],
4992        )
4993        .unwrap()
4994        .artifact;
4995        let run_dir = tempdir().unwrap();
4996        let artifact_dir = write_pir_mcmc_artifact_dir(&artifact, &pkg, run_dir.path()).unwrap();
4997        let manifest_path = artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE);
4998        let mut manifest: serde_json::Value =
4999            serde_json::from_str(&fs::read_to_string(&manifest_path).unwrap()).unwrap();
5000        manifest["winning_provenance"][0]["transform_kind"] = serde_json::Value::Null;
5001        fs::write(
5002            &manifest_path,
5003            serde_json::to_string_pretty(&manifest).unwrap(),
5004        )
5005        .unwrap();
5006
5007        let err = read_pir_mcmc_artifact_dir(run_dir.path())
5008            .err()
5009            .expect("malformed action record must be rejected");
5010        assert!(
5011            err.to_string().contains("missing transform_kind"),
5012            "unexpected error: {err}"
5013        );
5014    }
5015
5016    #[test]
5017    fn durable_artifact_rejects_malformed_or_old_schema_manifests() {
5018        let malformed_dir = tempdir().unwrap();
5019        let malformed_artifact_dir = malformed_dir.path().join(PIR_MCMC_ARTIFACT_DIR_NAME);
5020        fs::create_dir_all(&malformed_artifact_dir).unwrap();
5021        fs::write(
5022            malformed_artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE),
5023            "{not json",
5024        )
5025        .unwrap();
5026        assert!(read_pir_mcmc_artifact_dir(malformed_dir.path()).is_err());
5027
5028        let incomplete_dir = tempdir().unwrap();
5029        let incomplete_artifact_dir = incomplete_dir.path().join(PIR_MCMC_ARTIFACT_DIR_NAME);
5030        fs::create_dir_all(&incomplete_artifact_dir).unwrap();
5031        fs::write(
5032            incomplete_artifact_dir.join(PIR_MCMC_ARTIFACT_MANIFEST_FILE),
5033            r#"{
5034  "schema_version": 1,
5035  "top_fn_name": "f",
5036  "run_options": {
5037    "max_iters": 1,
5038    "threads": 1,
5039    "chain_strategy": "independent",
5040    "checkpoint_iters": 1,
5041    "progress_iters": 0,
5042    "seed": 1,
5043    "initial_temperature": 1.0,
5044    "objective": "nodes",
5045    "extension_costing_mode": "preserve",
5046    "max_allowed_depth": null,
5047    "max_allowed_area": null,
5048    "switching_beta1": 1.0,
5049    "switching_beta2": 0.0,
5050    "switching_primary_output_load": 1.0,
5051    "enable_formal_oracle": false
5052  },
5053  "origin": {"file": "states/origin.ir", "cost": {"pir_nodes": 2, "g8r_nodes": 2, "g8r_depth": 2, "g8r_le_graph_milli": 0, "g8r_gate_output_toggles": 0, "g8r_weighted_switching_milli": 0}},
5054  "raw_winner": {"file": "states/raw-winner.ir", "cost": {"pir_nodes": 1, "g8r_nodes": 1, "g8r_depth": 1, "g8r_le_graph_milli": 0, "g8r_gate_output_toggles": 0, "g8r_weighted_switching_milli": 0}},
5055  "winning_lineage": []
5056}"#,
5057        )
5058        .unwrap();
5059        let err = read_pir_mcmc_artifact_dir(incomplete_dir.path())
5060            .err()
5061            .expect("old-schema artifact must be rejected");
5062        assert!(
5063            err.to_string().contains("schema version 1"),
5064            "unexpected error: {err}"
5065        );
5066    }
5067
5068    #[test]
5069    fn pir_equiv_oracle_rejects_obviously_non_equivalent_rewire() {
5070        // Build a tiny function with a literal so the rewire can substitute a
5071        // same-typed node but change semantics.
5072        let ir_text = r#"fn add_lit(x: bits[8] id=10, y: bits[8] id=20) -> bits[8] {
5073  literal.30: bits[8] = literal(value=0, id=30)
5074  ret add.42: bits[8] = add(x, y, id=42)
5075}"#;
5076        let mut parser1 = ir_parser::Parser::new(ir_text);
5077        let orig_fn = parser1.parse_fn().unwrap();
5078        let mut parser2 = ir_parser::Parser::new(ir_text);
5079        let mut rewired_fn = parser2.parse_fn().unwrap();
5080
5081        // Rewire the RHS operand of add.42 from y to literal.30 (same type).
5082        // This should change semantics for the all-ones test vector.
5083        let mut add_ref = None;
5084        let mut lit_ref = None;
5085        for nr in rewired_fn.node_refs() {
5086            let node = rewired_fn.get_node(nr);
5087            match &node.payload {
5088                xlsynth_pir::ir::NodePayload::Literal(_) => {
5089                    lit_ref = Some(nr);
5090                }
5091                xlsynth_pir::ir::NodePayload::Binop(xlsynth_pir::ir::Binop::Add, _, _) => {
5092                    add_ref = Some(nr);
5093                }
5094                _ => {}
5095            }
5096        }
5097        let add_ref = add_ref.expect("expected add node");
5098        let lit_ref = lit_ref.expect("expected literal node");
5099
5100        let old_add_payload = rewired_fn.get_node(add_ref).payload.clone();
5101        rewired_fn.get_node_mut(add_ref).payload = remap_payload_with(
5102            &old_add_payload,
5103            |(slot, dep)| {
5104                if slot == 1 { lit_ref } else { dep }
5105            },
5106        );
5107
5108        let mut rng = Pcg64Mcg::seed_from_u64(1);
5109        let mut baseline_cache = EvalFnBaselineResults::default();
5110        assert!(!pir_equiv_oracle(
5111            &orig_fn,
5112            &rewired_fn,
5113            &mut rng,
5114            4,
5115            /* enable_formal_oracle= */ false,
5116            &mut baseline_cache,
5117        ));
5118    }
5119
5120    #[test]
5121    fn pir_equiv_oracle_rejects_zero_length_array_params_without_panic() {
5122        let ir_text = r#"fn zero_len(a: bits[8][0] id=10) -> bits[1] {
5123  ret literal.20: bits[1] = literal(value=0, id=20)
5124}"#;
5125        let mut parser1 = ir_parser::Parser::new(ir_text);
5126        let lhs = parser1.parse_fn().unwrap();
5127        let mut parser2 = ir_parser::Parser::new(ir_text);
5128        let rhs = parser2.parse_fn().unwrap();
5129
5130        let mut rng = Pcg64Mcg::seed_from_u64(1);
5131        let mut baseline_cache = EvalFnBaselineResults::default();
5132        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
5133            pir_equiv_oracle(
5134                &lhs,
5135                &rhs,
5136                &mut rng,
5137                4,
5138                /* enable_formal_oracle= */ false,
5139                &mut baseline_cache,
5140            )
5141        }));
5142
5143        assert!(result.is_ok());
5144        assert!(!result.unwrap());
5145    }
5146
5147    #[test]
5148    fn optimize_pir_fn_via_xls_preserves_ext_nary_add_arch_via_ffi_wrappers() {
5149        let mut parser = ir_parser::Parser::new(
5150            r#"fn f(a: bits[8] id=1, b: bits[8] id=2, c: bits[8] id=3) -> bits[8] {
5151  ret sum: bits[8] = ext_nary_add(a, b, c, signed=[false, false, false], negated=[false, false, false], arch=brent_kung, id=4)
5152}"#,
5153        );
5154        let f = parser.parse_fn().unwrap();
5155
5156        let optimized =
5157            optimize_pir_fn_via_xls_with_extension_mode(&f, ExtensionCostingMode::Preserve)
5158                .unwrap();
5159        let ext_nodes = optimized
5160            .nodes
5161            .iter()
5162            .filter(|node| {
5163                matches!(
5164                    &node.payload,
5165                    NodePayload::ExtNaryAdd {
5166                        arch: Some(ExtNaryAddArchitecture::BrentKung),
5167                        ..
5168                    }
5169                )
5170            })
5171            .count();
5172        assert_eq!(
5173            ext_nodes, 1,
5174            "expected optimized PIR to reconstruct brent_kung ext_nary_add:\n{}",
5175            optimized
5176        );
5177    }
5178
5179    #[test]
5180    fn canonical_g8r_scoring_input_materializes_the_optimized_top() {
5181        let f = parse_fn(
5182            r#"fn f(x: bits[8] id=1) -> bits[8] {
5183  dead: bits[8] = identity(x, id=2)
5184  ret live: bits[8] = identity(x, id=3)
5185}"#,
5186        );
5187
5188        let scoring_input =
5189            canonical_g8r_scoring_input_for_pir_fn(&f, ExtensionCostingMode::Preserve).unwrap();
5190        let mut parser = ir_parser::Parser::new(&scoring_input.ir_text);
5191        let pkg = parser.parse_and_validate_package().unwrap();
5192        let reparsed_top = pkg.get_top_fn().unwrap();
5193
5194        assert_eq!(reparsed_top.to_string(), scoring_input.top_fn.to_string());
5195        assert!(!scoring_input.ir_text.contains("dead:"));
5196    }
5197
5198    #[test]
5199    fn canonical_g8r_scoring_lowering_options_skip_summary_graph_le() {
5200        let source = CanonicalG8rOptions {
5201            compute_graph_logical_effort: true,
5202            graph_logical_effort_beta1: 2.0,
5203            ..CanonicalG8rOptions::default()
5204        };
5205
5206        let scoring = canonical_g8r_scoring_lowering_options(&source);
5207
5208        assert!(!scoring.compute_graph_logical_effort);
5209        assert_eq!(scoring.graph_logical_effort_beta1, 2.0);
5210    }
5211
5212    #[test]
5213    fn optimize_pir_fn_via_xls_can_desugar_ext_nary_add_to_standard_ir() {
5214        let mut parser = ir_parser::Parser::new(
5215            r#"fn f(a: bits[8] id=1, b: bits[8] id=2, c: bits[8] id=3) -> bits[8] {
5216  ret sum: bits[8] = ext_nary_add(a, b, c, signed=[false, false, false], negated=[false, false, false], arch=brent_kung, id=4)
5217}"#,
5218        );
5219        let f = parser.parse_fn().unwrap();
5220
5221        let optimized =
5222            optimize_pir_fn_via_xls_with_extension_mode(&f, ExtensionCostingMode::Desugar).unwrap();
5223        let ext_nodes = optimized
5224            .nodes
5225            .iter()
5226            .filter(|node| matches!(&node.payload, NodePayload::ExtNaryAdd { .. }))
5227            .count();
5228        assert_eq!(
5229            ext_nodes, 0,
5230            "expected desugared optimized PIR to contain no ext_nary_add nodes:\n{}",
5231            optimized
5232        );
5233        assert!(
5234            !optimized.to_string().contains("ext_nary_add"),
5235            "expected desugared optimized PIR text to contain no extension op spelling:\n{}",
5236            optimized
5237        );
5238
5239        let cost = cost_with_effort_options_toggle_stimulus_and_extension_mode(
5240            &f,
5241            Objective::G8rNodes,
5242            None,
5243            &WeightedSwitchingOptions::default(),
5244            ExtensionCostingMode::Desugar,
5245        )
5246        .unwrap();
5247        assert!(cost.g8r_nodes > 0);
5248    }
5249
5250    #[test]
5251    fn get_pir_transforms_for_run_prunes_unsafe_only_classes_without_formal_oracle() {
5252        let no_oracle_kinds: HashSet<PirTransformKind> =
5253            get_pir_transforms_for_run(/* enable_formal_oracle= */ false)
5254                .into_iter()
5255                .map(|t| t.kind())
5256                .collect();
5257        let oracle_kinds: HashSet<PirTransformKind> =
5258            get_pir_transforms_for_run(/* enable_formal_oracle= */ true)
5259                .into_iter()
5260                .map(|t| t.kind())
5261                .collect();
5262
5263        assert!(!no_oracle_kinds.contains(&PirTransformKind::ShiftHoist));
5264        assert!(!no_oracle_kinds.contains(&PirTransformKind::MaskOperandHighBit));
5265        assert!(!no_oracle_kinds.contains(&PirTransformKind::RewireOperandToSameType));
5266        assert!(!no_oracle_kinds.contains(&PirTransformKind::GuardedPredicateRewire));
5267        assert!(no_oracle_kinds.contains(&PirTransformKind::AbsorbAddOperandIntoExtNaryAdd));
5268        assert!(no_oracle_kinds.contains(&PirTransformKind::AddToExtNaryAdd));
5269        assert!(no_oracle_kinds.contains(&PirTransformKind::ReduceSelDistribute));
5270
5271        assert!(oracle_kinds.contains(&PirTransformKind::ShiftHoist));
5272        assert!(oracle_kinds.contains(&PirTransformKind::MaskOperandHighBit));
5273        assert!(oracle_kinds.contains(&PirTransformKind::RewireOperandToSameType));
5274        assert!(oracle_kinds.contains(&PirTransformKind::GuardedPredicateRewire));
5275    }
5276
5277    #[test]
5278    fn ext_nary_add_arch_reaches_g8r_tags_after_xls_round_trip() {
5279        let mut parser = ir_parser::Parser::new(
5280            r#"fn f(a: bits[8] id=1, b: bits[8] id=2, c: bits[8] id=3) -> bits[8] {
5281  ret sum: bits[8] = ext_nary_add(a, b, c, signed=[false, false, false], negated=[false, false, false], arch=brent_kung, id=4)
5282}"#,
5283        );
5284        let f = parser.parse_fn().unwrap();
5285
5286        let optimized =
5287            optimize_pir_fn_via_xls_with_extension_mode(&f, ExtensionCostingMode::Preserve)
5288                .unwrap();
5289        let ext_text_id = optimized
5290            .nodes
5291            .iter()
5292            .find_map(|node| match &node.payload {
5293                NodePayload::ExtNaryAdd {
5294                    arch: Some(ExtNaryAddArchitecture::BrentKung),
5295                    ..
5296                } => Some(node.text_id),
5297                _ => None,
5298            })
5299            .expect("expected reconstructed brent_kung ext_nary_add");
5300
5301        let gatify_output =
5302            ir2gate::gatify(&optimized, GatifyOptions::all_opts_disabled()).unwrap();
5303        let gate_fn_text = gatify_output.gate_fn.to_string();
5304        assert!(
5305            gate_fn_text.contains(&format!(
5306                "ext_nary_add_{}_brent_kung_output_bit_",
5307                ext_text_id
5308            )),
5309            "expected gatify tags to reflect the ext_nary_add arch after XLS round-trip:\n{}",
5310            gate_fn_text
5311        );
5312    }
5313
5314    #[test]
5315    fn objective_metric_toggles_product_saturates() {
5316        let c = Cost {
5317            pir_nodes: 0,
5318            g8r_nodes: usize::MAX,
5319            g8r_depth: usize::MAX,
5320            g8r_le_graph_milli: 0,
5321            g8r_gate_output_toggles: 2,
5322            g8r_weighted_switching_milli: 0,
5323            g8r_post_and_nodes: 0,
5324            g8r_post_depth: 0,
5325            g8r_post_le_graph_milli: 0,
5326            g8r_post_gate_output_toggles: 0,
5327            g8r_post_weighted_switching_milli: 0,
5328        };
5329        assert_eq!(
5330            Objective::G8rNodesTimesDepthTimesToggles.metric(&c),
5331            u128::MAX
5332        );
5333    }
5334
5335    #[test]
5336    fn objective_metric_graph_logical_effort_times_nodes_multiplies_nodes() {
5337        let c = Cost {
5338            pir_nodes: 0,
5339            g8r_nodes: 7,
5340            g8r_depth: 11,
5341            g8r_le_graph_milli: 13,
5342            g8r_gate_output_toggles: 0,
5343            g8r_weighted_switching_milli: 0,
5344            g8r_post_and_nodes: 0,
5345            g8r_post_depth: 0,
5346            g8r_post_le_graph_milli: 0,
5347            g8r_post_gate_output_toggles: 0,
5348            g8r_post_weighted_switching_milli: 0,
5349        };
5350        assert_eq!(Objective::G8rLeGraphTimesNodes.metric(&c), 91);
5351    }
5352
5353    #[test]
5354    fn objective_metric_graph_logical_effort_times_nodes_handles_large_values() {
5355        let c = Cost {
5356            pir_nodes: 0,
5357            g8r_nodes: usize::MAX,
5358            g8r_depth: 0,
5359            g8r_le_graph_milli: usize::MAX,
5360            g8r_gate_output_toggles: 0,
5361            g8r_weighted_switching_milli: 0,
5362            g8r_post_and_nodes: 0,
5363            g8r_post_depth: 0,
5364            g8r_post_le_graph_milli: 0,
5365            g8r_post_gate_output_toggles: 0,
5366            g8r_post_weighted_switching_milli: 0,
5367        };
5368        assert_eq!(
5369            Objective::G8rLeGraphTimesNodes.metric(&c),
5370            (usize::MAX as u128) * (usize::MAX as u128)
5371        );
5372    }
5373
5374    #[test]
5375    fn graph_logical_effort_times_nodes_uses_graph_effort_without_toggles() {
5376        let objective = Objective::G8rLeGraphTimesNodes;
5377        assert!(objective.uses_g8r_costing());
5378        assert!(objective.needs_graph_logical_effort());
5379        assert!(!objective.needs_toggle_stimulus());
5380    }
5381
5382    #[test]
5383    fn objective_metric_nodes_times_weighted_switching_saturates() {
5384        let c = Cost {
5385            pir_nodes: 0,
5386            g8r_nodes: usize::MAX,
5387            g8r_depth: 0,
5388            g8r_le_graph_milli: 0,
5389            g8r_gate_output_toggles: 0,
5390            g8r_weighted_switching_milli: u128::MAX,
5391            g8r_post_and_nodes: 0,
5392            g8r_post_depth: 0,
5393            g8r_post_le_graph_milli: 0,
5394            g8r_post_gate_output_toggles: 0,
5395            g8r_post_weighted_switching_milli: 0,
5396        };
5397        assert_eq!(
5398            Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress.metric(&c),
5399            u128::MAX
5400        );
5401    }
5402
5403    #[test]
5404    fn g8r_post_objectives_use_postprocessed_cost_fields() {
5405        let c = Cost {
5406            pir_nodes: 1,
5407            g8r_nodes: 2,
5408            g8r_depth: 3,
5409            g8r_le_graph_milli: 5,
5410            g8r_gate_output_toggles: 7,
5411            g8r_weighted_switching_milli: 11,
5412            g8r_post_and_nodes: 13,
5413            g8r_post_depth: 17,
5414            g8r_post_le_graph_milli: 19,
5415            g8r_post_gate_output_toggles: 23,
5416            g8r_post_weighted_switching_milli: 29,
5417        };
5418        assert_eq!(Objective::G8rPostAndNodes.metric(&c), 13);
5419        assert_eq!(Objective::G8rPostAndNodesTimesDepth.metric(&c), 221);
5420        assert_eq!(
5421            Objective::G8rPostAndNodesTimesDepthTimesToggles.metric(&c),
5422            5083
5423        );
5424        assert_eq!(Objective::G8rPostLeGraph.metric(&c), 19);
5425        assert_eq!(Objective::G8rPostLeGraphTimesAndNodes.metric(&c), 247);
5426        assert_eq!(Objective::G8rPostLeGraphTimesProduct.metric(&c), 4199);
5427        assert_eq!(Objective::G8rPostWeightedSwitching.metric(&c), 29);
5428        assert_eq!(
5429            Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress.metric(&c),
5430            377
5431        );
5432        for objective in [
5433            Objective::G8rPostAndNodes,
5434            Objective::G8rPostAndNodesTimesDepth,
5435            Objective::G8rPostAndNodesTimesDepthTimesToggles,
5436            Objective::G8rPostLeGraph,
5437            Objective::G8rPostLeGraphTimesAndNodes,
5438            Objective::G8rPostLeGraphTimesProduct,
5439            Objective::G8rPostWeightedSwitching,
5440            Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress,
5441        ] {
5442            assert_eq!(
5443                Objective::from_value_name(objective.value_name()).unwrap(),
5444                objective
5445            );
5446        }
5447    }
5448
5449    #[test]
5450    fn search_score_prefers_feasible_then_lower_violation_then_objective() {
5451        let feasible = search_score(
5452            &Cost {
5453                pir_nodes: 0,
5454                g8r_nodes: 10,
5455                g8r_depth: 10,
5456                g8r_le_graph_milli: 0,
5457                g8r_gate_output_toggles: 0,
5458                g8r_weighted_switching_milli: 0,
5459                g8r_post_and_nodes: 0,
5460                g8r_post_depth: 0,
5461                g8r_post_le_graph_milli: 0,
5462                g8r_post_gate_output_toggles: 0,
5463                g8r_post_weighted_switching_milli: 0,
5464            },
5465            Objective::G8rNodes,
5466            ConstraintLimits {
5467                max_delay: Some(12),
5468                max_area: None,
5469            },
5470        );
5471        let mildly_infeasible = search_score(
5472            &Cost {
5473                pir_nodes: 0,
5474                g8r_nodes: 10,
5475                g8r_depth: 13,
5476                g8r_le_graph_milli: 0,
5477                g8r_gate_output_toggles: 0,
5478                g8r_weighted_switching_milli: 0,
5479                g8r_post_and_nodes: 0,
5480                g8r_post_depth: 0,
5481                g8r_post_le_graph_milli: 0,
5482                g8r_post_gate_output_toggles: 0,
5483                g8r_post_weighted_switching_milli: 0,
5484            },
5485            Objective::G8rNodes,
5486            ConstraintLimits {
5487                max_delay: Some(12),
5488                max_area: None,
5489            },
5490        );
5491        let badly_infeasible = search_score(
5492            &Cost {
5493                pir_nodes: 0,
5494                g8r_depth: 16,
5495                g8r_nodes: 10,
5496                g8r_le_graph_milli: 0,
5497                g8r_gate_output_toggles: 0,
5498                g8r_weighted_switching_milli: 0,
5499                g8r_post_and_nodes: 0,
5500                g8r_post_depth: 0,
5501                g8r_post_le_graph_milli: 0,
5502                g8r_post_gate_output_toggles: 0,
5503                g8r_post_weighted_switching_milli: 0,
5504            },
5505            Objective::G8rNodes,
5506            ConstraintLimits {
5507                max_delay: Some(12),
5508                max_area: None,
5509            },
5510        );
5511
5512        assert!(feasible < mildly_infeasible);
5513        assert!(mildly_infeasible < badly_infeasible);
5514    }
5515
5516    #[test]
5517    fn effective_constraint_limits_respects_non_regressing_depth() {
5518        let initial = Cost {
5519            pir_nodes: 0,
5520            g8r_nodes: 10,
5521            g8r_depth: 17,
5522            g8r_le_graph_milli: 0,
5523            g8r_gate_output_toggles: 0,
5524            g8r_weighted_switching_milli: 0,
5525            g8r_post_and_nodes: 0,
5526            g8r_post_depth: 0,
5527            g8r_post_le_graph_milli: 0,
5528            g8r_post_gate_output_toggles: 0,
5529            g8r_post_weighted_switching_milli: 0,
5530        };
5531        let got = effective_constraint_limits(
5532            Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress,
5533            ConstraintLimits {
5534                max_delay: Some(20),
5535                max_area: None,
5536            },
5537            &initial,
5538        );
5539        assert_eq!(got.max_delay, Some(17));
5540        assert_eq!(got.max_area, None);
5541    }
5542
5543    #[test]
5544    fn postprocessed_constraints_use_post_area_and_depth() {
5545        let c = Cost {
5546            pir_nodes: 0,
5547            g8r_nodes: 1,
5548            g8r_depth: 1,
5549            g8r_le_graph_milli: 0,
5550            g8r_gate_output_toggles: 0,
5551            g8r_weighted_switching_milli: 0,
5552            g8r_post_and_nodes: 11,
5553            g8r_post_depth: 17,
5554            g8r_post_le_graph_milli: 0,
5555            g8r_post_gate_output_toggles: 0,
5556            g8r_post_weighted_switching_milli: 0,
5557        };
5558        let score = search_score(
5559            &c,
5560            Objective::G8rPostAndNodes,
5561            ConstraintLimits {
5562                max_delay: Some(12),
5563                max_area: None,
5564            },
5565        );
5566        assert_eq!(
5567            score.violation,
5568            Some(ConstraintViolationScore {
5569                delay_over: Some(5),
5570                area_over: None,
5571            })
5572        );
5573        let got = effective_constraint_limits(
5574            Objective::G8rPostAndNodesTimesWeightedSwitchingNoDepthRegress,
5575            ConstraintLimits {
5576                max_delay: Some(20),
5577                max_area: None,
5578            },
5579            &c,
5580        );
5581        assert_eq!(got.max_delay, Some(17));
5582    }
5583
5584    #[test]
5585    fn g8r_post_costing_requires_external_postprocessor() {
5586        let f = parse_fn(
5587            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5588  ret and.3: bits[1] = and(a, b, id=3)
5589}"#,
5590        );
5591        let err = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5592            &f,
5593            Objective::G8rPostAndNodes,
5594            None,
5595            &WeightedSwitchingOptions::default(),
5596            ExtensionCostingMode::Preserve,
5597            &G8rEvaluationMode::Builtin,
5598        )
5599        .unwrap_err();
5600        assert!(
5601            err.to_string()
5602                .contains("requires an external g8r postprocessor"),
5603            "unexpected error: {err}"
5604        );
5605    }
5606
5607    #[test]
5608    fn external_postprocessor_identity_populates_post_stats() {
5609        let temp = tempdir().unwrap();
5610        let hook =
5611            write_executable_script(temp.path(), "identity.sh", "#!/bin/sh\ncp \"$1\" \"$3\"\n");
5612        let f = parse_fn(
5613            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5614  ret and.3: bits[1] = and(a, b, id=3)
5615}"#,
5616        );
5617        let c = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5618            &f,
5619            Objective::G8rPostAndNodes,
5620            None,
5621            &WeightedSwitchingOptions::default(),
5622            ExtensionCostingMode::Preserve,
5623            &G8rEvaluationMode::ExternalPostprocess {
5624                program: hook.display().to_string(),
5625            },
5626        )
5627        .unwrap();
5628        assert!(c.g8r_post_and_nodes > 0);
5629        assert!(c.g8r_post_depth > 0);
5630    }
5631
5632    #[test]
5633    fn external_postprocessor_repacked_interface_supports_toggle_metrics() {
5634        let temp = tempdir().unwrap();
5635        let hook =
5636            write_executable_script(temp.path(), "identity.sh", "#!/bin/sh\ncp \"$1\" \"$3\"\n");
5637        let f = parse_fn(
5638            r#"fn f(a: bits[2] id=1, b: bits[2] id=2) -> bits[2] {
5639  ret and.3: bits[2] = and(a, b, id=3)
5640}"#,
5641        );
5642        let samples = vec![
5643            IrValue::parse_typed("(bits[2]:0b00, bits[2]:0b11)").unwrap(),
5644            IrValue::parse_typed("(bits[2]:0b11, bits[2]:0b11)").unwrap(),
5645            IrValue::parse_typed("(bits[2]:0b01, bits[2]:0b11)").unwrap(),
5646        ];
5647        let toggle_stimulus = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5648        let post_mode = G8rEvaluationMode::ExternalPostprocess {
5649            program: hook.display().to_string(),
5650        };
5651        let raw_toggles = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5652            &f,
5653            Objective::G8rNodesTimesDepthTimesToggles,
5654            Some(&toggle_stimulus),
5655            &WeightedSwitchingOptions::default(),
5656            ExtensionCostingMode::Preserve,
5657            &G8rEvaluationMode::Builtin,
5658        )
5659        .unwrap();
5660        let post_toggles = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5661            &f,
5662            Objective::G8rPostAndNodesTimesDepthTimesToggles,
5663            Some(&toggle_stimulus),
5664            &WeightedSwitchingOptions::default(),
5665            ExtensionCostingMode::Preserve,
5666            &post_mode,
5667        )
5668        .unwrap();
5669        assert!(post_toggles.g8r_post_gate_output_toggles > 0);
5670        assert_eq!(
5671            post_toggles.g8r_post_gate_output_toggles,
5672            raw_toggles.g8r_gate_output_toggles
5673        );
5674
5675        let raw_weighted = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5676            &f,
5677            Objective::G8rWeightedSwitching,
5678            Some(&toggle_stimulus),
5679            &WeightedSwitchingOptions::default(),
5680            ExtensionCostingMode::Preserve,
5681            &G8rEvaluationMode::Builtin,
5682        )
5683        .unwrap();
5684        let post_weighted = cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5685            &f,
5686            Objective::G8rPostWeightedSwitching,
5687            Some(&toggle_stimulus),
5688            &WeightedSwitchingOptions::default(),
5689            ExtensionCostingMode::Preserve,
5690            &post_mode,
5691        )
5692        .unwrap();
5693        assert!(post_weighted.g8r_post_weighted_switching_milli > 0);
5694        assert_eq!(
5695            post_weighted.g8r_post_weighted_switching_milli,
5696            raw_weighted.g8r_weighted_switching_milli
5697        );
5698    }
5699
5700    #[test]
5701    fn external_postprocessor_reports_failures() {
5702        let temp = tempdir().unwrap();
5703        let fail = write_executable_script(
5704            temp.path(),
5705            "fail.sh",
5706            "#!/bin/sh\necho broken >&2\nexit 7\n",
5707        );
5708        let missing = write_executable_script(temp.path(), "missing.sh", "#!/bin/sh\nexit 0\n");
5709        let malformed = write_executable_script(
5710            temp.path(),
5711            "malformed.sh",
5712            "#!/bin/sh\nprintf nope > \"$3\"\n",
5713        );
5714        let f = parse_fn(
5715            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5716  ret and.3: bits[1] = and(a, b, id=3)
5717}"#,
5718        );
5719        let err = |program: &Path| {
5720            cost_with_effort_options_toggle_stimulus_extension_mode_and_evaluator(
5721                &f,
5722                Objective::G8rPostAndNodes,
5723                None,
5724                &WeightedSwitchingOptions::default(),
5725                ExtensionCostingMode::Preserve,
5726                &G8rEvaluationMode::ExternalPostprocess {
5727                    program: program.display().to_string(),
5728                },
5729            )
5730            .unwrap_err()
5731            .to_string()
5732        };
5733        assert!(err(&fail).contains("broken"));
5734        assert!(err(&missing).contains("did not create"));
5735        assert!(err(&malformed).contains("failed to load"));
5736    }
5737
5738    #[test]
5739    fn parse_irvals_tuple_lines_accepts_valid_tuples() {
5740        let text = "(bits[1]:0, bits[1]:1)\n(bits[1]:1, bits[1]:1)\n";
5741        let got = parse_irvals_tuple_lines(text).unwrap();
5742        assert_eq!(got.len(), 2);
5743    }
5744
5745    #[test]
5746    fn parse_irvals_tuple_lines_rejects_invalid_or_non_tuple_lines() {
5747        let bad_parse = parse_irvals_tuple_lines("not_a_value\n").unwrap_err();
5748        assert!(
5749            bad_parse.to_string().contains("line 1"),
5750            "expected line number in parse error"
5751        );
5752
5753        let non_tuple = parse_irvals_tuple_lines("bits[1]:1\n").unwrap_err();
5754        assert!(
5755            non_tuple.to_string().contains("not a tuple"),
5756            "expected tuple-specific error"
5757        );
5758    }
5759
5760    #[test]
5761    fn lower_toggle_stimulus_rejects_arity_and_type_mismatch() {
5762        let mut parser = ir_parser::Parser::new(
5763            r#"fn f(a: bits[1] id=1, b: bits[2] id=2) -> bits[1] {
5764  ret identity.3: bits[1] = identity(a, id=3)
5765}"#,
5766        );
5767        let f = parser.parse_fn().unwrap();
5768
5769        let arity_bad = vec![
5770            IrValue::parse_typed("(bits[1]:0, bits[2]:0, bits[1]:0)").unwrap(),
5771            IrValue::parse_typed("(bits[1]:1, bits[2]:1, bits[1]:1)").unwrap(),
5772        ];
5773        assert!(lower_toggle_stimulus_for_fn(&arity_bad, &f).is_err());
5774
5775        let type_bad = vec![
5776            IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5777            IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5778        ];
5779        assert!(lower_toggle_stimulus_for_fn(&type_bad, &f).is_err());
5780    }
5781
5782    #[test]
5783    fn run_pir_mcmc_rejects_invalid_toggle_stimulus_usage() {
5784        let mut parser = ir_parser::Parser::new(
5785            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5786  ret and.3: bits[1] = and(a, b, id=3)
5787}"#,
5788        );
5789        let f = parser.parse_fn().unwrap();
5790
5791        let opts_missing = RunOptions {
5792            max_iters: 1,
5793            threads: 1,
5794            chain_strategy: ChainStrategy::Independent,
5795            checkpoint_iters: 1,
5796            progress_iters: 0,
5797            seed: 1,
5798            initial_temperature: 1.0,
5799            objective: Objective::G8rNodesTimesDepthTimesToggles,
5800            extension_costing_mode: ExtensionCostingMode::Preserve,
5801            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5802            canonical_g8r_options: CanonicalG8rOptions::default(),
5803            max_allowed_depth: None,
5804            max_allowed_area: None,
5805            weighted_switching_options: WeightedSwitchingOptions::default(),
5806            enable_formal_oracle: false,
5807            trajectory_dir: None,
5808            toggle_stimulus: None,
5809        };
5810        assert!(run_pir_mcmc(f.clone(), opts_missing).is_err());
5811
5812        let opts_wrong_objective = RunOptions {
5813            max_iters: 1,
5814            threads: 1,
5815            chain_strategy: ChainStrategy::Independent,
5816            checkpoint_iters: 1,
5817            progress_iters: 0,
5818            seed: 1,
5819            initial_temperature: 1.0,
5820            objective: Objective::Nodes,
5821            extension_costing_mode: ExtensionCostingMode::Preserve,
5822            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5823            canonical_g8r_options: CanonicalG8rOptions::default(),
5824            max_allowed_depth: None,
5825            max_allowed_area: None,
5826            weighted_switching_options: WeightedSwitchingOptions::default(),
5827            enable_formal_oracle: false,
5828            trajectory_dir: None,
5829            toggle_stimulus: Some(vec![
5830                IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5831                IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5832            ]),
5833        };
5834        assert!(run_pir_mcmc(f, opts_wrong_objective).is_err());
5835    }
5836
5837    #[test]
5838    fn run_pir_mcmc_rejects_caps_with_nodes_objective() {
5839        let mut parser = ir_parser::Parser::new(
5840            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5841  ret and.3: bits[1] = and(a, b, id=3)
5842}"#,
5843        );
5844        let f = parser.parse_fn().unwrap();
5845
5846        let opts = RunOptions {
5847            max_iters: 1,
5848            threads: 1,
5849            chain_strategy: ChainStrategy::Independent,
5850            checkpoint_iters: 1,
5851            progress_iters: 0,
5852            seed: 1,
5853            initial_temperature: 1.0,
5854            objective: Objective::Nodes,
5855            extension_costing_mode: ExtensionCostingMode::Preserve,
5856            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5857            canonical_g8r_options: CanonicalG8rOptions::default(),
5858            max_allowed_depth: Some(10),
5859            max_allowed_area: None,
5860            weighted_switching_options: WeightedSwitchingOptions::default(),
5861            enable_formal_oracle: false,
5862            trajectory_dir: None,
5863            toggle_stimulus: None,
5864        };
5865        assert!(run_pir_mcmc(f, opts).is_err());
5866    }
5867
5868    #[test]
5869    fn run_pir_mcmc_rejects_dual_caps() {
5870        let mut parser = ir_parser::Parser::new(
5871            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5872  ret and.3: bits[1] = and(a, b, id=3)
5873}"#,
5874        );
5875        let f = parser.parse_fn().unwrap();
5876
5877        let opts = RunOptions {
5878            max_iters: 1,
5879            threads: 1,
5880            chain_strategy: ChainStrategy::Independent,
5881            checkpoint_iters: 1,
5882            progress_iters: 0,
5883            seed: 1,
5884            initial_temperature: 1.0,
5885            objective: Objective::G8rNodes,
5886            extension_costing_mode: ExtensionCostingMode::Preserve,
5887            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5888            canonical_g8r_options: CanonicalG8rOptions::default(),
5889            max_allowed_depth: Some(10),
5890            max_allowed_area: Some(10),
5891            weighted_switching_options: WeightedSwitchingOptions::default(),
5892            enable_formal_oracle: false,
5893            trajectory_dir: None,
5894            toggle_stimulus: None,
5895        };
5896        assert!(run_pir_mcmc(f, opts).is_err());
5897    }
5898
5899    #[test]
5900    fn run_pir_mcmc_rejects_area_cap_with_non_regressing_depth_objective() {
5901        let mut parser = ir_parser::Parser::new(
5902            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5903  ret and.3: bits[1] = and(a, b, id=3)
5904}"#,
5905        );
5906        let f = parser.parse_fn().unwrap();
5907
5908        let opts = RunOptions {
5909            max_iters: 1,
5910            threads: 1,
5911            chain_strategy: ChainStrategy::Independent,
5912            checkpoint_iters: 1,
5913            progress_iters: 0,
5914            seed: 1,
5915            initial_temperature: 1.0,
5916            objective: Objective::G8rNodesTimesWeightedSwitchingNoDepthRegress,
5917            extension_costing_mode: ExtensionCostingMode::Preserve,
5918            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
5919            canonical_g8r_options: CanonicalG8rOptions::default(),
5920            max_allowed_depth: None,
5921            max_allowed_area: Some(10),
5922            weighted_switching_options: WeightedSwitchingOptions::default(),
5923            enable_formal_oracle: false,
5924            trajectory_dir: None,
5925            toggle_stimulus: Some(vec![
5926                IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5927                IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5928            ]),
5929        };
5930        assert!(run_pir_mcmc(f, opts).is_err());
5931    }
5932
5933    #[test]
5934    fn cost_with_toggle_objective_populates_toggle_count() {
5935        let mut parser = ir_parser::Parser::new(
5936            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5937  ret and.3: bits[1] = and(a, b, id=3)
5938}"#,
5939        );
5940        let f = parser.parse_fn().unwrap();
5941        let samples = vec![
5942            IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5943            IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5944            IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5945        ];
5946        let lowered = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5947        let c = cost_with_toggle_stimulus(
5948            &f,
5949            Objective::G8rNodesTimesDepthTimesToggles,
5950            Some(&lowered),
5951        )
5952        .unwrap();
5953        assert!(
5954            c.g8r_gate_output_toggles > 0,
5955            "expected positive interior toggle count"
5956        );
5957    }
5958
5959    #[test]
5960    fn cost_with_weighted_switching_objective_populates_weighted_metric() {
5961        let mut parser = ir_parser::Parser::new(
5962            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5963  ret and.3: bits[1] = and(a, b, id=3)
5964}"#,
5965        );
5966        let f = parser.parse_fn().unwrap();
5967        let samples = vec![
5968            IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5969            IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5970            IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5971        ];
5972        let lowered = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5973        let c = cost_with_effort_options_and_toggle_stimulus(
5974            &f,
5975            Objective::G8rWeightedSwitching,
5976            Some(&lowered),
5977            &WeightedSwitchingOptions::default(),
5978        )
5979        .unwrap();
5980        assert!(
5981            c.g8r_weighted_switching_milli > 0,
5982            "expected positive weighted switching estimate"
5983        );
5984    }
5985
5986    #[test]
5987    fn cost_with_weighted_switching_rejects_non_finite_options() {
5988        let mut parser = ir_parser::Parser::new(
5989            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
5990  ret and.3: bits[1] = and(a, b, id=3)
5991}"#,
5992        );
5993        let f = parser.parse_fn().unwrap();
5994        let samples = vec![
5995            IrValue::parse_typed("(bits[1]:0, bits[1]:0)").unwrap(),
5996            IrValue::parse_typed("(bits[1]:1, bits[1]:1)").unwrap(),
5997        ];
5998        let lowered = lower_toggle_stimulus_for_fn(&samples, &f).unwrap();
5999
6000        let err = cost_with_effort_options_and_toggle_stimulus(
6001            &f,
6002            Objective::G8rWeightedSwitching,
6003            Some(&lowered),
6004            &WeightedSwitchingOptions {
6005                beta1: f64::NAN,
6006                beta2: 0.0,
6007                primary_output_load: 1.0,
6008            },
6009        )
6010        .unwrap_err();
6011        assert!(
6012            err.to_string().contains("must be finite"),
6013            "expected finite-coefficient validation error, got: {}",
6014            err
6015        );
6016    }
6017
6018    #[test]
6019    fn trajectory_json_preserves_large_u128_metrics() {
6020        let rec = serde_json::json!({
6021            "metric": u128::MAX,
6022            "g8r_weighted_switching_milli": u128::MAX,
6023        });
6024        let s = serde_json::to_string(&rec).unwrap();
6025        assert!(
6026            s.contains(&u128::MAX.to_string()),
6027            "expected full u128 JSON number in serialized output"
6028        );
6029    }
6030
6031    #[test]
6032    fn trajectory_json_emits_transform_mechanism() {
6033        let mut parser = ir_parser::Parser::new(
6034            r#"fn f(a: bits[1] id=1, b: bits[1] id=2) -> bits[1] {
6035  ret and.3: bits[1] = and(a, b, id=3)
6036}"#,
6037        );
6038        let f = parser.parse_fn().unwrap();
6039        let temp_dir = tempfile::tempdir().unwrap();
6040        let opts = RunOptions {
6041            max_iters: 1,
6042            threads: 1,
6043            chain_strategy: ChainStrategy::Independent,
6044            checkpoint_iters: 100,
6045            progress_iters: 0,
6046            seed: 1,
6047            initial_temperature: 1.0,
6048            objective: Objective::Nodes,
6049            extension_costing_mode: ExtensionCostingMode::Preserve,
6050            g8r_evaluation_mode: G8rEvaluationMode::Builtin,
6051            canonical_g8r_options: CanonicalG8rOptions::default(),
6052            max_allowed_depth: None,
6053            max_allowed_area: None,
6054            weighted_switching_options: WeightedSwitchingOptions::default(),
6055            enable_formal_oracle: false,
6056            trajectory_dir: Some(temp_dir.path().to_path_buf()),
6057            toggle_stimulus: None,
6058        };
6059
6060        let _ = run_pir_mcmc(f, opts).unwrap();
6061        let path = temp_dir.path().join("trajectory.c000.jsonl");
6062        let text = std::fs::read_to_string(path).unwrap();
6063        let first_line = text.lines().next().expect("trajectory line");
6064        let value: serde_json::Value = serde_json::from_str(first_line).unwrap();
6065        assert!(
6066            value.get("transform_mechanism").is_some(),
6067            "expected transform_mechanism in trajectory JSON: {first_line}"
6068        );
6069    }
6070}