Skip to main content

xlsynth_pir_compiler_runtime/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Runtime ABI and observable-event collection for compiled PIR functions.
4
5use std::ffi::c_void;
6use std::fmt;
7use std::marker::PhantomData;
8use std::ptr;
9
10use num_bigint::{BigInt, BigUint, Sign};
11
/// Native compiled-function entrypoint shared by in-memory and AOT execution.
///
/// Parameters, by name:
/// - `inputs`: pointer to an array of data pointers (presumably one per
///   function parameter — confirm against the code generator).
/// - `output`: buffer the compiled function writes its result into.
/// - `scratch`: additional writable storage handed to the compiled code
///   (NOTE(review): sizing requirements are not visible here — confirm).
/// - `context`: opaque context that compiled code forwards to the runtime
///   callbacks such as `xlsynth_pir_record_assert`.
///
/// NOTE(review): the meaning of the `i32` status code is not established in
/// this file; confirm against the code generator.
pub type CompiledEntrypoint = unsafe extern "C" fn(
    inputs: *const *const u8,
    output: *mut u8,
    scratch: *mut u8,
    context: *mut RawExecutionContext,
) -> i32;

/// Opaque execution context forwarded by compiled code to runtime callbacks.
///
/// `#[repr(C)]` with a single pointer field. `private_state` is set by
/// `ExecutionContext::raw_context` and turned back into collector state by
/// the runtime callbacks (via `state_from_raw`).
#[repr(C)]
pub struct RawExecutionContext {
    private_state: *mut c_void,
}
25
/// Error returned by generated native-compiler entrypoint wrappers.
///
/// The wrapped string is the complete human-readable diagnostic, and it is
/// exactly what the `Display` implementation prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunError(pub String);

impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let RunError(message) = self;
        f.write_str(message.as_str())
    }
}

impl std::error::Error for RunError {}
37
38macro_rules! define_native_bits {
39    ($name:ident, $carrier:ty, $carrier_bits:expr) => {
40        #[repr(transparent)]
41        #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
42        pub struct $name<const BIT_COUNT: usize>($carrier);
43
44        impl<const BIT_COUNT: usize> $name<BIT_COUNT> {
45            fn validate_width() -> Result<(), RunError> {
46                if BIT_COUNT == 0 || BIT_COUNT > $carrier_bits {
47                    Err(RunError(format!(
48                        "bits[{BIT_COUNT}] cannot use a {}-bit native carrier",
49                        $carrier_bits
50                    )))
51                } else {
52                    Ok(())
53                }
54            }
55
56            const fn mask() -> $carrier {
57                if BIT_COUNT == $carrier_bits {
58                    <$carrier>::MAX
59                } else {
60                    ((1 as $carrier) << BIT_COUNT) - 1
61                }
62            }
63
64            /// Constructs a canonical bitvector value, rejecting excess high bits.
65            pub fn new(value: $carrier) -> Result<Self, RunError> {
66                Self::validate_width()?;
67                if value & !Self::mask() != 0 {
68                    Err(RunError(format!(
69                        "value {value} does not fit in bits[{BIT_COUNT}]"
70                    )))
71                } else {
72                    Ok(Self(value))
73                }
74            }
75
76            /// Constructs a canonical bitvector by truncating high bits.
77            pub const fn wrapping(value: $carrier) -> Self {
78                assert!(
79                    BIT_COUNT > 0 && BIT_COUNT <= $carrier_bits,
80                    "invalid native bits carrier width"
81                );
82                Self(value & Self::mask())
83            }
84
85            /// Returns the native carrier value.
86            pub const fn get(self) -> $carrier {
87                self.0
88            }
89
90            /// Returns the value widened to `u64`.
91            pub const fn to_u64(self) -> u64 {
92                self.0 as u64
93            }
94        }
95    };
96}
97
98define_native_bits!(BitsInU8, u8, 8);
99define_native_bits!(BitsInU16, u16, 16);
100define_native_bits!(BitsInU32, u32, 32);
101define_native_bits!(BitsInU64, u64, 64);
102
103/// Zero-sized native representation of a `bits[0]` value.
104#[repr(C)]
105#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
106pub struct Bits0;
107
108/// Public unsigned DSLX-style wrapper for a `bits[0]` value.
109#[repr(C)]
110#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
111pub struct UnsignedBits0;
112
113impl UnsignedBits0 {
114    /// Constructs the sole canonical raw `bits[0]` representation.
115    pub const fn from_raw_bits(value: u64) -> Self {
116        assert!(value == 0, "raw bits do not fit target width");
117        Self
118    }
119
120    /// Returns the sole unsigned `bits[0]` value widened to `u64`.
121    pub const fn to_u64(self) -> u64 {
122        0
123    }
124
125    /// Returns the raw ABI bits widened to `u64`.
126    pub const fn raw_bits(self) -> u64 {
127        0
128    }
129}
130
131impl TryFrom<u64> for UnsignedBits0 {
132    type Error = RunError;
133
134    fn try_from(value: u64) -> Result<Self, Self::Error> {
135        if value == 0 {
136            Ok(Self)
137        } else {
138            Err(RunError(format!("value {value} does not fit in bits[0]")))
139        }
140    }
141}
142
143/// Public signed DSLX-style wrapper for an `sbits[0]` value.
144#[repr(C)]
145#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
146pub struct SignedBits0;
147
148impl SignedBits0 {
149    /// Constructs the sole canonical raw `sbits[0]` representation.
150    pub const fn from_raw_bits(value: u64) -> Self {
151        assert!(value == 0, "raw bits do not fit target width");
152        Self
153    }
154
155    /// Returns the sole signed `sbits[0]` value widened to `i64`.
156    pub const fn to_i64(self) -> i64 {
157        0
158    }
159
160    /// Returns the raw ABI bits widened to `u64`.
161    pub const fn raw_bits(self) -> u64 {
162        0
163    }
164}
165
166impl TryFrom<i64> for SignedBits0 {
167    type Error = RunError;
168
169    fn try_from(value: i64) -> Result<Self, Self::Error> {
170        if value == 0 {
171            Ok(Self)
172        } else {
173            Err(RunError(format!("value {value} does not fit in s0")))
174        }
175    }
176}
177
178macro_rules! define_public_bits_wrappers {
179    (
180        $unsigned_name:ident,
181        $signed_name:ident,
182        $raw_name:ident,
183        $unsigned_carrier:ty,
184        $signed_carrier:ty,
185        $carrier_bits:expr
186    ) => {
187        #[repr(transparent)]
188        #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
189        pub struct $unsigned_name<const BIT_COUNT: usize>($raw_name<BIT_COUNT>);
190
191        impl<const BIT_COUNT: usize> $unsigned_name<BIT_COUNT> {
192            const fn mask() -> $unsigned_carrier {
193                if BIT_COUNT == $carrier_bits {
194                    <$unsigned_carrier>::MAX
195                } else {
196                    ((1 as $unsigned_carrier) << BIT_COUNT) - 1
197                }
198            }
199
200            /// Constructs an unsigned DSLX-style bit value, rejecting excess high bits.
201            pub fn new(value: $unsigned_carrier) -> Result<Self, RunError> {
202                Ok(Self($raw_name::<BIT_COUNT>::new(value)?))
203            }
204
205            /// Constructs an unsigned DSLX-style bit value by truncating high bits.
206            pub const fn wrapping(value: $unsigned_carrier) -> Self {
207                Self($raw_name::<BIT_COUNT>::wrapping(value))
208            }
209
210            /// Constructs an unsigned DSLX-style bit value from raw ABI bits.
211            pub const fn from_raw_bits(value: $unsigned_carrier) -> Self {
212                assert!(
213                    BIT_COUNT > 0 && BIT_COUNT <= $carrier_bits,
214                    "invalid raw bit carrier width"
215                );
216                assert!(
217                    value & !Self::mask() == 0,
218                    "raw bits do not fit target width"
219                );
220                Self($raw_name::<BIT_COUNT>::wrapping(value))
221            }
222
223            /// Returns the unsigned native carrier value.
224            /// Returns the unsigned value widened to `u64`.
225            pub const fn to_u64(self) -> u64 {
226                self.0.to_u64()
227            }
228
229            /// Returns the raw ABI bits in the native carrier.
230            pub const fn raw_bits(self) -> $unsigned_carrier {
231                self.0.get()
232            }
233        }
234
235        impl<const BIT_COUNT: usize> TryFrom<u64> for $unsigned_name<BIT_COUNT> {
236            type Error = RunError;
237
238            fn try_from(value: u64) -> Result<Self, Self::Error> {
239                let carrier = <$unsigned_carrier>::try_from(value).map_err(|_| {
240                    RunError(format!(
241                        "value {value} does not fit in bits[{BIT_COUNT}]"
242                    ))
243                })?;
244                Self::new(carrier)
245            }
246        }
247
248        #[repr(transparent)]
249        #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
250        pub struct $signed_name<const BIT_COUNT: usize>($raw_name<BIT_COUNT>);
251
252        impl<const BIT_COUNT: usize> $signed_name<BIT_COUNT> {
253            fn validate_signed_value(value: $signed_carrier) -> Result<(), RunError> {
254                if BIT_COUNT == 0 || BIT_COUNT > $carrier_bits {
255                    return Err(RunError(format!(
256                        "s{BIT_COUNT} cannot use a {}-bit native carrier",
257                        $carrier_bits
258                    )));
259                }
260                let min = -(1i128 << (BIT_COUNT - 1));
261                let max = (1i128 << (BIT_COUNT - 1)) - 1;
262                let value = value as i128;
263                if value < min || value > max {
264                    Err(RunError(format!(
265                        "value {value} does not fit in s{BIT_COUNT}"
266                    )))
267                } else {
268                    Ok(())
269                }
270            }
271
272            const fn mask() -> $unsigned_carrier {
273                if BIT_COUNT == $carrier_bits {
274                    <$unsigned_carrier>::MAX
275                } else {
276                    ((1 as $unsigned_carrier) << BIT_COUNT) - 1
277                }
278            }
279
280            /// Constructs a signed DSLX-style bit value, rejecting out-of-range values.
281            pub fn new(value: $signed_carrier) -> Result<Self, RunError> {
282                Self::validate_signed_value(value)?;
283                Ok(Self($raw_name::<BIT_COUNT>::wrapping(
284                    value as $unsigned_carrier,
285                )))
286            }
287
288            /// Constructs a signed DSLX-style bit value by truncating to the target width.
289            pub const fn wrapping(value: $signed_carrier) -> Self {
290                Self($raw_name::<BIT_COUNT>::wrapping(
291                    value as $unsigned_carrier,
292                ))
293            }
294
295            /// Constructs a signed DSLX-style bit value from raw ABI bits.
296            pub const fn from_raw_bits(value: $unsigned_carrier) -> Self {
297                assert!(
298                    BIT_COUNT > 0 && BIT_COUNT <= $carrier_bits,
299                    "invalid raw bit carrier width"
300                );
301                assert!(
302                    value & !Self::mask() == 0,
303                    "raw bits do not fit target width"
304                );
305                Self($raw_name::<BIT_COUNT>::wrapping(value))
306            }
307
308            /// Returns the sign-extended native signed carrier value.
309            fn to_signed_carrier(self) -> $signed_carrier {
310                let raw = self.0.get();
311                let sign_bit = 1 as $unsigned_carrier << (BIT_COUNT - 1);
312                if raw & sign_bit == 0 {
313                    raw as $signed_carrier
314                } else {
315                    (raw | !Self::mask()) as $signed_carrier
316                }
317            }
318
319            /// Returns the sign-extended value widened to `i64`.
320            pub fn to_i64(self) -> i64 {
321                self.to_signed_carrier() as i64
322            }
323
324            /// Returns the raw ABI bits in the native carrier.
325            pub const fn raw_bits(self) -> $unsigned_carrier {
326                self.0.get()
327            }
328        }
329
330        impl<const BIT_COUNT: usize> TryFrom<i64> for $signed_name<BIT_COUNT> {
331            type Error = RunError;
332
333            fn try_from(value: i64) -> Result<Self, Self::Error> {
334                let carrier = <$signed_carrier>::try_from(value).map_err(|_| {
335                    RunError(format!("value {value} does not fit in s{BIT_COUNT}"))
336                })?;
337                Self::new(carrier)
338            }
339        }
340    };
341}
342
343define_public_bits_wrappers!(UnsignedBitsInU8, SignedBitsInU8, BitsInU8, u8, i8, 8);
344define_public_bits_wrappers!(UnsignedBitsInU16, SignedBitsInU16, BitsInU16, u16, i16, 16);
345define_public_bits_wrappers!(UnsignedBitsInU32, SignedBitsInU32, BitsInU32, u32, i32, 32);
346define_public_bits_wrappers!(UnsignedBitsInU64, SignedBitsInU64, BitsInU64, u64, i64, 64);
347
348/// Native least-significant-first limb storage for a bitvector wider than 64
349/// bits.
350#[repr(transparent)]
351#[derive(Debug, Clone, Copy, PartialEq, Eq)]
352pub struct WideBits<const BIT_COUNT: usize, const LIMB_COUNT: usize>([u64; LIMB_COUNT]);
353
354impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> WideBits<BIT_COUNT, LIMB_COUNT> {
355    fn validate_layout() -> Result<(), RunError> {
356        if BIT_COUNT <= 64 || LIMB_COUNT != BIT_COUNT.div_ceil(64) {
357            Err(RunError(format!(
358                "bits[{BIT_COUNT}] cannot use {LIMB_COUNT} native wide limb(s)"
359            )))
360        } else {
361            Ok(())
362        }
363    }
364
365    fn high_mask() -> u64 {
366        let high_width = BIT_COUNT % 64;
367        if high_width == 0 {
368            u64::MAX
369        } else {
370            (1u64 << high_width) - 1
371        }
372    }
373
374    /// Constructs a canonical wide bitvector, rejecting excess high bits.
375    pub fn from_limbs(limbs: [u64; LIMB_COUNT]) -> Result<Self, RunError> {
376        Self::validate_layout()?;
377        if limbs[LIMB_COUNT - 1] & !Self::high_mask() != 0 {
378            Err(RunError(format!(
379                "high limb does not fit in bits[{BIT_COUNT}]"
380            )))
381        } else {
382            Ok(Self(limbs))
383        }
384    }
385
386    /// Constructs a canonical wide bitvector by truncating excess high bits.
387    pub fn wrapping_limbs(mut limbs: [u64; LIMB_COUNT]) -> Self {
388        assert!(
389            BIT_COUNT > 64 && LIMB_COUNT == BIT_COUNT.div_ceil(64),
390            "invalid native wide bits layout"
391        );
392        limbs[LIMB_COUNT - 1] &= Self::high_mask();
393        Self(limbs)
394    }
395
396    /// Returns the least-significant-first limb representation.
397    pub const fn limbs(&self) -> &[u64; LIMB_COUNT] {
398        &self.0
399    }
400}
401
402impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> Default for WideBits<BIT_COUNT, LIMB_COUNT> {
403    fn default() -> Self {
404        Self([0; LIMB_COUNT])
405    }
406}
407
408/// Public unsigned DSLX-style wrapper for a wide native bitvector.
409#[repr(transparent)]
410#[derive(Debug, Clone, Copy, PartialEq, Eq)]
411pub struct UnsignedWideBits<const BIT_COUNT: usize, const LIMB_COUNT: usize>(
412    WideBits<BIT_COUNT, LIMB_COUNT>,
413);
414
415impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> UnsignedWideBits<BIT_COUNT, LIMB_COUNT> {
416    /// Constructs a canonical wide unsigned bitvector, rejecting excess high
417    /// bits.
418    pub fn from_limbs(limbs: [u64; LIMB_COUNT]) -> Result<Self, RunError> {
419        Ok(Self(WideBits::<BIT_COUNT, LIMB_COUNT>::from_limbs(limbs)?))
420    }
421
422    /// Constructs a canonical wide unsigned bitvector by truncating excess high
423    /// bits.
424    pub fn wrapping_limbs(limbs: [u64; LIMB_COUNT]) -> Self {
425        Self(WideBits::<BIT_COUNT, LIMB_COUNT>::wrapping_limbs(limbs))
426    }
427
428    /// Returns the least-significant-first raw ABI limb representation.
429    pub const fn limbs(&self) -> &[u64; LIMB_COUNT] {
430        self.0.limbs()
431    }
432
433    /// Returns the unsigned value as a big integer.
434    pub fn to_biguint(&self) -> BigUint {
435        let mut bytes = Vec::with_capacity(LIMB_COUNT * std::mem::size_of::<u64>());
436        for limb in self.limbs() {
437            bytes.extend_from_slice(&limb.to_le_bytes());
438        }
439        BigUint::from_bytes_le(&bytes)
440    }
441}
442
443impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> Default
444    for UnsignedWideBits<BIT_COUNT, LIMB_COUNT>
445{
446    fn default() -> Self {
447        Self(WideBits::default())
448    }
449}
450
451/// Public signed DSLX-style wrapper for a wide native bitvector.
452#[repr(transparent)]
453#[derive(Debug, Clone, Copy, PartialEq, Eq)]
454pub struct SignedWideBits<const BIT_COUNT: usize, const LIMB_COUNT: usize>(
455    WideBits<BIT_COUNT, LIMB_COUNT>,
456);
457
458impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> SignedWideBits<BIT_COUNT, LIMB_COUNT> {
459    /// Constructs a canonical wide signed bitvector from raw ABI limbs.
460    pub fn from_limbs(limbs: [u64; LIMB_COUNT]) -> Result<Self, RunError> {
461        Ok(Self(WideBits::<BIT_COUNT, LIMB_COUNT>::from_limbs(limbs)?))
462    }
463
464    /// Constructs a canonical wide signed bitvector by truncating excess high
465    /// bits.
466    pub fn wrapping_limbs(limbs: [u64; LIMB_COUNT]) -> Self {
467        Self(WideBits::<BIT_COUNT, LIMB_COUNT>::wrapping_limbs(limbs))
468    }
469
470    /// Returns the least-significant-first raw ABI limb representation.
471    pub const fn limbs(&self) -> &[u64; LIMB_COUNT] {
472        self.0.limbs()
473    }
474
475    /// Returns the sign-extended value as a big integer.
476    pub fn to_bigint(&self) -> BigInt {
477        let unsigned = UnsignedWideBits::<BIT_COUNT, LIMB_COUNT>(self.0).to_biguint();
478        if BIT_COUNT == 0 || !unsigned.bit((BIT_COUNT - 1) as u64) {
479            BigInt::from_biguint(Sign::Plus, unsigned)
480        } else {
481            BigInt::from_biguint(Sign::Plus, unsigned) - (BigInt::from(1u8) << BIT_COUNT)
482        }
483    }
484}
485
486impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> Default
487    for SignedWideBits<BIT_COUNT, LIMB_COUNT>
488{
489    fn default() -> Self {
490        Self(WideBits::default())
491    }
492}
493
/// Zero-sized native representation of a PIR token value.
///
/// A unit struct carrying no data; the derives make it freely copyable and
/// comparable.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Token;
498
/// Kind of observable PIR event described by an event site.
///
/// Runtime callbacks only record a site whose kind matches the callback. For
/// example, `xlsynth_pir_record_cover` ignores sites that are not `Cover`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventKind {
    /// An `assert` site.
    Assert,
    /// An `assumed_in_bounds` contract site, tagged with the kind of failure
    /// it reports.
    Assumption(AssumptionFailureKind),
    /// A `cover` site.
    Cover,
    /// A `trace` site.
    Trace,
}
507
/// Native data description sufficient for immediate trace-value decoding.
///
/// Each trace operand of an event site is described by one of these (see
/// `EventSiteMetadata::operand_layouts`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraceValueLayout {
    /// A bitvector held in a native integer carrier.
    Bits {
        bit_count: usize,
        // NOTE(review): presumably the byte size of the native carrier — confirm.
        byte_count: usize,
    },
    /// A bitvector stored as 64-bit limbs (presumably mirroring `WideBits`).
    WideBits {
        bit_count: usize,
        limb_count: usize,
    },
    /// A fixed-length array whose elements all share one layout.
    Array {
        element: Box<TraceValueLayout>,
        element_count: usize,
    },
    /// A tuple whose fields each carry their own layout and offset.
    Tuple {
        fields: Vec<TraceTupleFieldLayout>,
        // NOTE(review): presumably the total native size of the tuple in bytes — confirm.
        byte_count: usize,
    },
    /// A token value (no data).
    Token,
}
529
/// Operation implemented by [`xlsynth_pir_runtime_wide_binop`].
///
/// The explicit `u32` discriminants are the values passed across the ABI.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum WideBinaryOp {
    Umul = 0,
    Smul = 1,
    Udiv = 2,
    Sdiv = 3,
    Umod = 4,
    Smod = 5,
    Shll = 6,
    Shrl = 7,
    Shra = 8,
}

impl WideBinaryOp {
    /// Decodes an ABI discriminant, returning `None` for unknown values.
    fn from_abi(value: u32) -> Option<Self> {
        const EVERY_OP: [WideBinaryOp; 9] = [
            WideBinaryOp::Umul,
            WideBinaryOp::Smul,
            WideBinaryOp::Udiv,
            WideBinaryOp::Sdiv,
            WideBinaryOp::Umod,
            WideBinaryOp::Smod,
            WideBinaryOp::Shll,
            WideBinaryOp::Shrl,
            WideBinaryOp::Shra,
        ];
        // Match on the declared discriminant, so the table's order is irrelevant.
        EVERY_OP.into_iter().find(|op| *op as u32 == value)
    }
}
561
/// Operation implemented by [`xlsynth_pir_runtime_wide_unary_op`].
///
/// The explicit `u32` discriminants are the values passed across the ABI.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum WideUnaryOp {
    OneHot = 0,
    Encode = 1,
    Decode = 2,
    ExtPrioEncode = 3,
    ExtClz = 4,
    ExtNormalizeLeft = 5,
    ExtMaskLow = 6,
}

impl WideUnaryOp {
    /// Decodes an ABI discriminant, returning `None` for unknown values.
    fn from_abi(value: u32) -> Option<Self> {
        const EVERY_OP: [WideUnaryOp; 7] = [
            WideUnaryOp::OneHot,
            WideUnaryOp::Encode,
            WideUnaryOp::Decode,
            WideUnaryOp::ExtPrioEncode,
            WideUnaryOp::ExtClz,
            WideUnaryOp::ExtNormalizeLeft,
            WideUnaryOp::ExtMaskLow,
        ];
        // Match on the declared discriminant, so the table's order is irrelevant.
        EVERY_OP.into_iter().find(|op| *op as u32 == value)
    }
}
589
/// Description of one tuple field supplied as a trace operand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceTupleFieldLayout {
    /// Layout of the field's value.
    pub layout: Box<TraceValueLayout>,
    /// Position of the field within the tuple's native storage.
    // NOTE(review): presumably a byte offset from the tuple's start — confirm.
    pub offset: usize,
}
596
/// Static information attached to one observable node in compiled code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EventSiteMetadata {
    /// Identifier of the originating node, copied into every recorded event.
    // NOTE(review): presumably the node's id in the PIR text — confirm.
    pub node_text_id: usize,
    /// Kind of event; callbacks for a different kind ignore this site.
    pub kind: EventKind,
    /// Optional label, reported (empty when absent) for assertions and covers.
    pub label: Option<String>,
    /// Optional message, reported (empty when absent) for assertion failures.
    pub message: Option<String>,
    /// Optional format string used for trace sites.
    // NOTE(review): presumably uses the placeholders in
    // `TRACE_FORMAT_SPECIFIERS` — confirm.
    pub format: Option<String>,
    /// Verbosity of a trace site, compared against
    /// `ExecutionOptions::trace_verbosity`.
    pub verbosity: i64,
    /// Layouts of the trace operands, matching the operand pointers passed to
    /// `xlsynth_pir_record_trace` in order.
    pub operand_layouts: Vec<TraceValueLayout>,
}

/// Runtime metadata for all observable sites in one compiled function.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CompiledFunctionMetadata {
    /// All event sites. The `site_id` passed to runtime callbacks indexes this
    /// vector, and out-of-range ids are ignored.
    pub event_sites: Vec<EventSiteMetadata>,
}
614
/// A failed compiled assertion, resolved to its source-site metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssertionFailure {
    /// `node_text_id` of the failing assertion site.
    pub node_text_id: usize,
    /// The site's message, or an empty string when it has none.
    pub message: String,
    /// The site's label, or an empty string when it has none.
    pub label: String,
}

/// A failed contract asserted by an `assumed_in_bounds` array operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AssumptionFailureKind {
    /// Reported by an array index operation (out-of-bounds, per the name).
    ArrayIndexOutOfBounds,
    /// Reported by an array update operation (out-of-bounds, per the name).
    ArrayUpdateOutOfBounds,
}

/// One failed assumption observed while executing compiled code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssumptionFailure {
    /// `node_text_id` of the failing site.
    pub node_text_id: usize,
    /// The kind carried by the site's `EventKind::Assumption`.
    pub kind: AssumptionFailureKind,
}

/// One emitted compiled trace statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceMessage {
    /// `node_text_id` of the trace site.
    pub node_text_id: usize,
    /// Formatted trace text.
    pub message: String,
    // NOTE(review): presumably the emitting site's `verbosity` — confirm in
    // `xlsynth_pir_record_trace`.
    pub verbosity: i64,
}

/// Accumulated execution count for a compiled `cover` site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoverCount {
    /// `node_text_id` of the cover site.
    pub node_text_id: usize,
    /// The site's label, or an empty string when it has none.
    pub label: String,
    /// Occurrences recorded since the context was created or last cleared.
    /// Saturates at `u64::MAX`.
    pub count: u64,
}

/// Rust-owned observable results recorded while executing compiled code.
///
/// Produced by `ExecutionContext::result`. `cover_counts` is empty when cover
/// collection is disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecutionResult {
    pub assertion_failures: Vec<AssertionFailure>,
    pub assumption_failures: Vec<AssumptionFailure>,
    pub trace_messages: Vec<TraceMessage>,
    pub cover_counts: Vec<CoverCount>,
}
661
/// Runtime options controlling which observable events are collected.
///
/// Assertion and assumption failures are always recorded. These options only
/// gate trace and cover collection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExecutionOptions {
    /// Maximum trace-site verbosity to collect, or `None` to collect no traces.
    pub trace_verbosity: Option<i64>,
    /// Whether `cover` occurrences are counted.
    pub collect_covers: bool,
}

impl ExecutionOptions {
    /// Disables trace and cover collection while still recording failures.
    pub const NO_EVENTS: Self = Self::new(None, false);

    /// Creates options from explicit settings.
    ///
    /// With `trace_verbosity == Some(max)`, traces whose site verbosity is at
    /// most `max` are collected; `None` disables trace collection entirely.
    /// `collect_covers` enables per-site cover counting.
    pub const fn new(trace_verbosity: Option<i64>, collect_covers: bool) -> Self {
        Self {
            trace_verbosity,
            collect_covers,
        }
    }

    /// Collects all traces (any verbosity is `<= i64::MAX`) and all covers.
    pub const fn collect_all() -> Self {
        Self::new(Some(i64::MAX), true)
    }
}

impl Default for ExecutionOptions {
    /// Returns [`ExecutionOptions::NO_EVENTS`].
    fn default() -> Self {
        Self::NO_EVENTS
    }
}
698
/// Formatting selected by one placeholder in a trace format string.
///
/// NOTE(review): the rendering of each preference is implemented outside this
/// view. The comments below only record which placeholder selects each
/// variant (per `TRACE_FORMAT_SPECIFIERS`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TraceFormatPreference {
    /// Selected by `{}`.
    Default,
    /// Selected by `{:u}`.
    UnsignedDecimal,
    /// Selected by `{:d}`.
    SignedDecimal,
    /// Selected by `{:x}`.
    PlainHex,
    /// Selected by `{:0x}`.
    ZeroPaddedHex,
    /// Selected by `{:#x}`.
    Hex,
    /// Selected by `{:b}`.
    PlainBinary,
    /// Selected by `{:0b}`.
    ZeroPaddedBinary,
    /// Selected by `{:#b}`.
    Binary,
}

/// Every placeholder spelling paired with the preference it selects.
// NOTE(review): presumably matched against `EventSiteMetadata::format` when
// formatting traces — confirm.
const TRACE_FORMAT_SPECIFIERS: [(&str, TraceFormatPreference); 9] = [
    ("{}", TraceFormatPreference::Default),
    ("{:u}", TraceFormatPreference::UnsignedDecimal),
    ("{:d}", TraceFormatPreference::SignedDecimal),
    ("{:x}", TraceFormatPreference::PlainHex),
    ("{:0x}", TraceFormatPreference::ZeroPaddedHex),
    ("{:#x}", TraceFormatPreference::Hex),
    ("{:b}", TraceFormatPreference::PlainBinary),
    ("{:0b}", TraceFormatPreference::ZeroPaddedBinary),
    ("{:#b}", TraceFormatPreference::Binary),
];
723
/// Mutable collector state reached by compiled code through
/// `RawExecutionContext::private_state`.
///
/// Owned (boxed) by `ExecutionContext`. Its address is what `raw_context`
/// hands out.
struct ContextState {
    /// Derived from the `&'metadata` reference held by the owning
    /// `ExecutionContext<'metadata>`.
    metadata: *const CompiledFunctionMetadata,
    /// Active collection options.
    options: ExecutionOptions,
    assertion_failures: Vec<AssertionFailure>,
    assumption_failures: Vec<AssumptionFailure>,
    trace_messages: Vec<TraceMessage>,
    /// Per-site counters indexed by site id. `Some` exactly when
    /// `options.collect_covers` is set.
    event_counts: Option<Vec<u64>>,
}
732
733/// Rust-owned event collector used for one or more compiled executions.
734///
735/// Cover counts accumulate until [`Self::clear`] is called. Assertion,
736/// assumption, and trace results also accumulate, permitting callers to
737/// consume a batch of invocations through one context.
738pub struct ExecutionContext<'metadata> {
739    state: Box<ContextState>,
740    marker: PhantomData<&'metadata CompiledFunctionMetadata>,
741}
742
743impl<'metadata> ExecutionContext<'metadata> {
744    /// Creates an empty collector for the supplied function metadata.
745    pub fn new(metadata: &'metadata CompiledFunctionMetadata) -> Self {
746        Self::new_with_options(metadata, ExecutionOptions::default())
747    }
748
749    /// Creates an empty collector with explicit event collection options.
750    pub fn new_with_options(
751        metadata: &'metadata CompiledFunctionMetadata,
752        options: ExecutionOptions,
753    ) -> Self {
754        let event_counts = options
755            .collect_covers
756            .then(|| vec![0; metadata.event_sites.len()]);
757        Self {
758            state: Box::new(ContextState {
759                metadata,
760                options,
761                assertion_failures: Vec::new(),
762                assumption_failures: Vec::new(),
763                trace_messages: Vec::new(),
764                event_counts,
765            }),
766            marker: PhantomData,
767        }
768    }
769
770    /// Returns an opaque ABI object valid while this context is mutably
771    /// borrowed.
772    pub fn raw_context(&mut self) -> RawExecutionContext {
773        RawExecutionContext {
774            private_state: ptr::from_mut(self.state.as_mut()).cast(),
775        }
776    }
777
778    /// Resolves all currently recorded events into ordinary Rust values.
779    pub fn result(&self) -> ExecutionResult {
780        let metadata = self.metadata();
781        let cover_counts = self
782            .state
783            .event_counts
784            .as_ref()
785            .map(|event_counts| {
786                metadata
787                    .event_sites
788                    .iter()
789                    .zip(event_counts)
790                    .filter(|(site, _)| site.kind == EventKind::Cover)
791                    .map(|(site, count)| CoverCount {
792                        node_text_id: site.node_text_id,
793                        label: site.label.clone().unwrap_or_default(),
794                        count: *count,
795                    })
796                    .collect()
797            })
798            .unwrap_or_default();
799        ExecutionResult {
800            assertion_failures: self.state.assertion_failures.clone(),
801            assumption_failures: self.state.assumption_failures.clone(),
802            trace_messages: self.state.trace_messages.clone(),
803            cover_counts,
804        }
805    }
806
807    /// Returns currently recorded assertion failures without cloning.
808    pub fn assertion_failures(&self) -> &[AssertionFailure] {
809        &self.state.assertion_failures
810    }
811
812    /// Returns currently recorded assumption failures without cloning.
813    pub fn assumption_failures(&self) -> &[AssumptionFailure] {
814        &self.state.assumption_failures
815    }
816
817    /// Clears all event records and accumulated cover counters.
818    pub fn clear(&mut self) {
819        self.clear_with_options(self.state.options);
820    }
821
822    /// Clears all event records and switches to the supplied collection
823    /// options.
824    pub fn clear_with_options(&mut self, options: ExecutionOptions) {
825        self.state.assertion_failures.clear();
826        self.state.assumption_failures.clear();
827        self.state.trace_messages.clear();
828        self.state.options = options;
829        if options.collect_covers {
830            match &mut self.state.event_counts {
831                Some(event_counts) => event_counts.fill(0),
832                None => {
833                    let site_count = self.metadata().event_sites.len();
834                    self.state.event_counts = Some(vec![0; site_count]);
835                }
836            }
837        } else {
838            self.state.event_counts = None;
839        }
840    }
841
842    fn metadata(&self) -> &CompiledFunctionMetadata {
843        // SAFETY: the context's lifetime guarantees metadata remains alive.
844        unsafe { &*self.state.metadata }
845    }
846}
847
848unsafe fn state_from_raw<'a>(context: *mut RawExecutionContext) -> &'a mut ContextState {
849    assert!(
850        !context.is_null(),
851        "compiled PIR callback requires an execution context"
852    );
853    // SAFETY: generated entrypoints receive a `RawExecutionContext` produced
854    // by `ExecutionContext::raw_context` for the duration of the call.
855    unsafe {
856        (*context)
857            .private_state
858            .cast::<ContextState>()
859            .as_mut()
860            .expect("compiled PIR callback received an invalid execution context")
861    }
862}
863
864fn site(state: &ContextState, site_id: u32, kind: EventKind) -> Option<&EventSiteMetadata> {
865    // SAFETY: the owning `ExecutionContext` keeps metadata alive.
866    let metadata = unsafe { state.metadata.as_ref()? };
867    let site = metadata.event_sites.get(site_id as usize)?;
868    (site.kind == kind).then_some(site)
869}
870
871/// Records a failed assertion from generated code.
872///
873/// # Safety
874///
875/// `context` must point to an active raw context created by
876/// [`ExecutionContext::raw_context`].
877#[unsafe(no_mangle)]
878pub unsafe extern "C" fn xlsynth_pir_record_assert(
879    context: *mut RawExecutionContext,
880    site_id: u32,
881) {
882    // SAFETY: forwarded from the caller's ABI contract.
883    let state = unsafe { state_from_raw(context) };
884    let Some(site) = site(state, site_id, EventKind::Assert).cloned() else {
885        return;
886    };
887    state.assertion_failures.push(AssertionFailure {
888        node_text_id: site.node_text_id,
889        message: site.message.unwrap_or_default(),
890        label: site.label.unwrap_or_default(),
891    });
892}
893
894/// Records a failed `assumed_in_bounds` contract from generated code.
895///
896/// # Safety
897///
898/// `context` must point to an active raw context created by
899/// [`ExecutionContext::raw_context`].
900#[unsafe(no_mangle)]
901pub unsafe extern "C" fn xlsynth_pir_record_assumption_failure(
902    context: *mut RawExecutionContext,
903    site_id: u32,
904) {
905    // SAFETY: forwarded from the caller's ABI contract.
906    let state = unsafe { state_from_raw(context) };
907    // SAFETY: the owning `ExecutionContext` keeps metadata alive.
908    let Some(site) = (unsafe { state.metadata.as_ref() })
909        .and_then(|metadata| metadata.event_sites.get(site_id as usize))
910    else {
911        return;
912    };
913    let EventKind::Assumption(kind) = site.kind else {
914        return;
915    };
916    state.assumption_failures.push(AssumptionFailure {
917        node_text_id: site.node_text_id,
918        kind,
919    });
920}
921
922/// Records one active cover occurrence from generated code.
923///
924/// # Safety
925///
926/// `context` must point to an active raw context created by
927/// [`ExecutionContext::raw_context`].
928#[unsafe(no_mangle)]
929pub unsafe extern "C" fn xlsynth_pir_record_cover(context: *mut RawExecutionContext, site_id: u32) {
930    // SAFETY: forwarded from the caller's ABI contract.
931    let state = unsafe { state_from_raw(context) };
932    if state.event_counts.is_none() || site(state, site_id, EventKind::Cover).is_none() {
933        return;
934    }
935    if let Some(count) = state
936        .event_counts
937        .as_mut()
938        .and_then(|event_counts| event_counts.get_mut(site_id as usize))
939    {
940        *count = count.saturating_add(1);
941    }
942}
943
944/// Records and formats one active trace occurrence from generated code.
945///
946/// # Safety
947///
948/// `context` must point to an active raw context created by
949/// [`ExecutionContext::raw_context`]. Each operand pointer must describe
950/// readable native storage matching the corresponding site's operand layout
951/// for the duration of this callback.
952#[unsafe(no_mangle)]
953pub unsafe extern "C" fn xlsynth_pir_record_trace(
954    context: *mut RawExecutionContext,
955    site_id: u32,
956    operand_ptrs: *const *const u8,
957) {
958    // SAFETY: forwarded from the caller's ABI contract.
959    let state = unsafe { state_from_raw(context) };
960    let Some(max_verbosity) = state.options.trace_verbosity else {
961        return;
962    };
963    let Some(site) = site(state, site_id, EventKind::Trace) else {
964        return;
965    };
966    if site.verbosity > max_verbosity {
967        return;
968    }
969    if !site.operand_layouts.is_empty() && operand_ptrs.is_null() {
970        return;
971    }
972    state.trace_messages.push(TraceMessage {
973        node_text_id: site.node_text_id,
974        // SAFETY: the generated caller supplies one pointer per metadata operand.
975        message: unsafe {
976            format_trace_message(
977                site.format.as_deref().unwrap_or(""),
978                &site.operand_layouts,
979                operand_ptrs,
980            )
981        },
982        verbosity: site.verbosity,
983    });
984}
985
986fn bit_mask(bit_count: usize) -> BigUint {
987    if bit_count == 0 {
988        BigUint::from(0u8)
989    } else {
990        (BigUint::from(1u8) << bit_count) - BigUint::from(1u8)
991    }
992}
993
994fn truncate_unsigned(value: BigUint, bit_count: usize) -> BigUint {
995    value & bit_mask(bit_count)
996}
997
998fn as_signed(value: BigUint, bit_count: usize) -> BigInt {
999    if bit_count != 0 && (&value & (BigUint::from(1u8) << (bit_count - 1))) != BigUint::from(0u8) {
1000        BigInt::from_biguint(Sign::Plus, value)
1001            - BigInt::from_biguint(Sign::Plus, BigUint::from(1u8) << bit_count)
1002    } else {
1003        BigInt::from_biguint(Sign::Plus, value)
1004    }
1005}
1006
1007fn truncate_signed(value: BigInt, bit_count: usize) -> BigUint {
1008    if bit_count == 0 {
1009        return BigUint::from(0u8);
1010    }
1011    let modulus = BigInt::from_biguint(Sign::Plus, BigUint::from(1u8) << bit_count);
1012    let mut reduced = value % &modulus;
1013    if reduced.sign() == Sign::Minus {
1014        reduced += &modulus;
1015    }
1016    let (_, bytes) = reduced.to_bytes_le();
1017    BigUint::from_bytes_le(&bytes)
1018}
1019
1020fn bounded_shift_amount(value: &BigUint, bit_count: usize) -> Option<usize> {
1021    let digits = value.to_u64_digits();
1022    if digits.len() > 1 {
1023        return None;
1024    }
1025    let amount = digits.first().copied().unwrap_or(0);
1026    usize::try_from(amount)
1027        .ok()
1028        .filter(|amount| *amount < bit_count)
1029}
1030
1031/// Reads a fixed-width value from least-significant-first native `u64` limbs.
1032///
1033/// # Safety
1034///
1035/// `limbs` must be readable for `bit_count.div_ceil(64)` native `u64` values.
1036unsafe fn read_wide_bits(limbs: *const u64, bit_count: usize) -> BigUint {
1037    let limb_count = bit_count.div_ceil(64);
1038    let mut bytes = Vec::with_capacity(limb_count * std::mem::size_of::<u64>());
1039    for index in 0..limb_count {
1040        // SAFETY: forwarded from this function's pointer contract.
1041        let limb = unsafe { limbs.add(index).read() };
1042        bytes.extend_from_slice(&limb.to_le_bytes());
1043    }
1044    truncate_unsigned(BigUint::from_bytes_le(&bytes), bit_count)
1045}
1046
1047/// Writes a fixed-width value to least-significant-first native `u64` limbs.
1048///
1049/// # Safety
1050///
1051/// `limbs` must be writable for `bit_count.div_ceil(64)` native `u64` values.
1052unsafe fn write_wide_bits(limbs: *mut u64, bit_count: usize, value: BigUint) {
1053    let limb_count = bit_count.div_ceil(64);
1054    let bytes = truncate_unsigned(value, bit_count).to_bytes_le();
1055    for index in 0..limb_count {
1056        let mut limb_bytes = [0u8; std::mem::size_of::<u64>()];
1057        let start = index * std::mem::size_of::<u64>();
1058        if start < bytes.len() {
1059            let end = bytes.len().min(start + std::mem::size_of::<u64>());
1060            limb_bytes[..end - start].copy_from_slice(&bytes[start..end]);
1061        }
1062        // SAFETY: forwarded from this function's pointer contract.
1063        unsafe { limbs.add(index).write(u64::from_le_bytes(limb_bytes)) };
1064    }
1065}
1066
1067fn get_bit(value: &BigUint, index: usize) -> bool {
1068    value
1069        .to_u64_digits()
1070        .get(index / u64::BITS as usize)
1071        .is_some_and(|limb| limb & (1u64 << (index % u64::BITS as usize)) != 0)
1072}
1073
1074fn prioritized_set_bit(value: &BigUint, bit_count: usize, lsb_prio: bool) -> Option<usize> {
1075    if lsb_prio {
1076        (0..bit_count).find(|index| get_bit(value, *index))
1077    } else {
1078        (0..bit_count).rev().find(|index| get_bit(value, *index))
1079    }
1080}
1081
1082fn leading_zero_count(value: &BigUint, bit_count: usize) -> usize {
1083    prioritized_set_bit(value, bit_count, /* lsb_prio= */ false)
1084        .map(|index| bit_count - index - 1)
1085        .unwrap_or(bit_count)
1086}
1087
1088fn mulp_offset(result_width: usize) -> BigUint {
1089    let low_width = result_width.saturating_sub(2);
1090    let high_width = result_width - low_width;
1091    let low_shift = low_width.saturating_sub(1).min(3);
1092    let low = if low_width == 0 {
1093        BigUint::from(0u8)
1094    } else {
1095        bit_mask(low_width) >> low_shift
1096    };
1097    let high = if high_width == 0 {
1098        BigUint::from(0u8)
1099    } else {
1100        bit_mask(high_width.saturating_sub(1)) << low_width
1101    };
1102    low | high
1103}
1104
1105/// Computes a complex arbitrary-width binary operation over native limb
1106/// storage.
1107///
1108/// Limb arrays are ordered from least- to most-significant limb. The result is
1109/// truncated to `dst_bit_count`. `dst` may not alias either source.
1110///
1111/// # Safety
1112///
1113/// Each pointer must be valid for the number of `u64` limbs implied by its
1114/// supplied bit count and obey the non-aliasing rule above.
1115#[unsafe(no_mangle)]
1116pub unsafe extern "C" fn xlsynth_pir_runtime_wide_binop(
1117    dst: *mut u64,
1118    dst_bit_count: usize,
1119    lhs: *const u64,
1120    lhs_bit_count: usize,
1121    rhs: *const u64,
1122    rhs_bit_count: usize,
1123    operation: u32,
1124) {
1125    let Some(operation) = WideBinaryOp::from_abi(operation) else {
1126        return;
1127    };
1128    // SAFETY: forwarded from this callback's pointer contract.
1129    let lhs_unsigned = unsafe { read_wide_bits(lhs, lhs_bit_count) };
1130    // SAFETY: forwarded from this callback's pointer contract.
1131    let rhs_unsigned = unsafe { read_wide_bits(rhs, rhs_bit_count) };
1132    let result = match operation {
1133        WideBinaryOp::Umul => truncate_unsigned(lhs_unsigned * rhs_unsigned, dst_bit_count),
1134        WideBinaryOp::Smul => truncate_signed(
1135            as_signed(lhs_unsigned, lhs_bit_count) * as_signed(rhs_unsigned, rhs_bit_count),
1136            dst_bit_count,
1137        ),
1138        WideBinaryOp::Udiv => {
1139            if rhs_unsigned == BigUint::from(0u8) {
1140                bit_mask(dst_bit_count)
1141            } else {
1142                truncate_unsigned(lhs_unsigned / rhs_unsigned, dst_bit_count)
1143            }
1144        }
1145        WideBinaryOp::Umod => {
1146            if rhs_unsigned == BigUint::from(0u8) {
1147                BigUint::from(0u8)
1148            } else {
1149                truncate_unsigned(lhs_unsigned % rhs_unsigned, dst_bit_count)
1150            }
1151        }
1152        WideBinaryOp::Sdiv | WideBinaryOp::Smod => {
1153            let lhs_signed = as_signed(lhs_unsigned, lhs_bit_count);
1154            let rhs_signed = as_signed(rhs_unsigned, rhs_bit_count);
1155            if dst_bit_count == 0 {
1156                BigUint::from(0u8)
1157            } else if rhs_signed == BigInt::from(0u8) {
1158                if operation == WideBinaryOp::Smod {
1159                    BigUint::from(0u8)
1160                } else if lhs_signed.sign() == Sign::Minus {
1161                    BigUint::from(1u8) << (dst_bit_count - 1)
1162                } else {
1163                    (BigUint::from(1u8) << (dst_bit_count - 1)) - BigUint::from(1u8)
1164                }
1165            } else if operation == WideBinaryOp::Sdiv {
1166                truncate_signed(lhs_signed / rhs_signed, dst_bit_count)
1167            } else {
1168                truncate_signed(lhs_signed % rhs_signed, dst_bit_count)
1169            }
1170        }
1171        WideBinaryOp::Shll | WideBinaryOp::Shrl | WideBinaryOp::Shra => {
1172            match bounded_shift_amount(&rhs_unsigned, lhs_bit_count) {
1173                None if operation == WideBinaryOp::Shra => {
1174                    if as_signed(lhs_unsigned, lhs_bit_count).sign() == Sign::Minus {
1175                        bit_mask(dst_bit_count)
1176                    } else {
1177                        BigUint::from(0u8)
1178                    }
1179                }
1180                None => BigUint::from(0u8),
1181                Some(amount) if operation == WideBinaryOp::Shll => {
1182                    truncate_unsigned(lhs_unsigned << amount, dst_bit_count)
1183                }
1184                Some(amount) if operation == WideBinaryOp::Shrl => {
1185                    truncate_unsigned(lhs_unsigned >> amount, dst_bit_count)
1186                }
1187                Some(amount) => truncate_signed(
1188                    as_signed(lhs_unsigned, lhs_bit_count) >> amount,
1189                    dst_bit_count,
1190                ),
1191            }
1192        }
1193    };
1194    // SAFETY: forwarded from this callback's pointer contract.
1195    unsafe { write_wide_bits(dst, dst_bit_count, result) };
1196}
1197
1198/// Computes a zero-filled dynamic slice into native limb storage.
1199///
1200/// # Safety
1201///
1202/// Pointer requirements match [`xlsynth_pir_runtime_wide_binop`].
1203#[unsafe(no_mangle)]
1204pub unsafe extern "C" fn xlsynth_pir_runtime_wide_dynamic_bit_slice(
1205    dst: *mut u64,
1206    dst_bit_count: usize,
1207    arg: *const u64,
1208    arg_bit_count: usize,
1209    start: *const u64,
1210    start_bit_count: usize,
1211) {
1212    // SAFETY: forwarded from this callback's pointer contract.
1213    let arg = unsafe { read_wide_bits(arg, arg_bit_count) };
1214    // SAFETY: forwarded from this callback's pointer contract.
1215    let start = unsafe { read_wide_bits(start, start_bit_count) };
1216    let result = bounded_shift_amount(&start, arg_bit_count)
1217        .map(|amount| truncate_unsigned(arg >> amount, dst_bit_count))
1218        .unwrap_or_else(|| BigUint::from(0u8));
1219    // SAFETY: forwarded from this callback's pointer contract.
1220    unsafe { write_wide_bits(dst, dst_bit_count, result) };
1221}
1222
1223/// Inserts a dynamically positioned low-to-high slice into native limb storage.
1224///
1225/// # Safety
1226///
1227/// Pointer requirements match [`xlsynth_pir_runtime_wide_binop`].
1228#[unsafe(no_mangle)]
1229pub unsafe extern "C" fn xlsynth_pir_runtime_wide_bit_slice_update(
1230    dst: *mut u64,
1231    dst_bit_count: usize,
1232    arg: *const u64,
1233    arg_bit_count: usize,
1234    start: *const u64,
1235    start_bit_count: usize,
1236    update: *const u64,
1237    update_bit_count: usize,
1238) {
1239    // SAFETY: forwarded from this callback's pointer contract.
1240    let arg = unsafe { read_wide_bits(arg, arg_bit_count) };
1241    // SAFETY: forwarded from this callback's pointer contract.
1242    let start = unsafe { read_wide_bits(start, start_bit_count) };
1243    // SAFETY: forwarded from this callback's pointer contract.
1244    let update = unsafe { read_wide_bits(update, update_bit_count) };
1245    let result = if let Some(start) = bounded_shift_amount(&start, arg_bit_count) {
1246        let written_width = update_bit_count.min(arg_bit_count - start);
1247        let written_mask = bit_mask(written_width) << start;
1248        let retained = &arg & (&bit_mask(arg_bit_count) ^ &written_mask);
1249        retained | ((update & bit_mask(written_width)) << start)
1250    } else {
1251        arg
1252    };
1253    // SAFETY: forwarded from this callback's pointer contract.
1254    unsafe { write_wide_bits(dst, dst_bit_count, result) };
1255}
1256
1257/// Computes an arbitrary-width single-operand PIR transform.
1258///
1259/// `attribute` is interpreted as `lsb_prio` for `one_hot` and
1260/// `ext_prio_encode`, and as the static shift/count offset for `ext_clz` and
1261/// `ext_normalize_left`.
1262///
1263/// # Safety
1264///
1265/// Pointer requirements match [`xlsynth_pir_runtime_wide_binop`].
1266#[unsafe(no_mangle)]
1267pub unsafe extern "C" fn xlsynth_pir_runtime_wide_unary_op(
1268    dst: *mut u64,
1269    dst_bit_count: usize,
1270    arg: *const u64,
1271    arg_bit_count: usize,
1272    operation: u32,
1273    attribute: usize,
1274) {
1275    let Some(operation) = WideUnaryOp::from_abi(operation) else {
1276        return;
1277    };
1278    // SAFETY: forwarded from this callback's pointer contract.
1279    let arg = unsafe { read_wide_bits(arg, arg_bit_count) };
1280    let result = match operation {
1281        WideUnaryOp::OneHot => {
1282            let selected =
1283                prioritized_set_bit(&arg, arg_bit_count, attribute != 0).unwrap_or(arg_bit_count);
1284            BigUint::from(1u8) << selected
1285        }
1286        WideUnaryOp::Encode => {
1287            let mut result = 0usize;
1288            for index in 0..arg_bit_count {
1289                if get_bit(&arg, index) {
1290                    result |= index;
1291                }
1292            }
1293            BigUint::from(result)
1294        }
1295        WideUnaryOp::Decode => bounded_shift_amount(&arg, dst_bit_count)
1296            .map(|amount| BigUint::from(1u8) << amount)
1297            .unwrap_or_else(|| BigUint::from(0u8)),
1298        WideUnaryOp::ExtPrioEncode => BigUint::from(
1299            prioritized_set_bit(&arg, arg_bit_count, attribute != 0).unwrap_or(arg_bit_count),
1300        ),
1301        WideUnaryOp::ExtClz => BigUint::from(leading_zero_count(&arg, arg_bit_count) + attribute),
1302        WideUnaryOp::ExtNormalizeLeft => {
1303            let shift = leading_zero_count(&arg, arg_bit_count).saturating_add(attribute);
1304            if shift >= dst_bit_count {
1305                BigUint::from(0u8)
1306            } else {
1307                truncate_unsigned(arg << shift, dst_bit_count)
1308            }
1309        }
1310        WideUnaryOp::ExtMaskLow => {
1311            if arg >= BigUint::from(dst_bit_count) {
1312                bit_mask(dst_bit_count)
1313            } else {
1314                let count = arg.to_u64_digits().first().copied().unwrap_or(0) as usize;
1315                bit_mask(count)
1316            }
1317        }
1318    };
1319    // SAFETY: forwarded from this callback's pointer contract.
1320    unsafe { write_wide_bits(dst, dst_bit_count, result) };
1321}
1322
1323/// Computes the deterministic pair used for arbitrary-width `umulp`/`smulp`.
1324///
1325/// # Safety
1326///
1327/// Pointer requirements match [`xlsynth_pir_runtime_wide_binop`].
1328#[unsafe(no_mangle)]
1329pub unsafe extern "C" fn xlsynth_pir_runtime_wide_mulp(
1330    offset_dst: *mut u64,
1331    residual_dst: *mut u64,
1332    dst_bit_count: usize,
1333    lhs: *const u64,
1334    lhs_bit_count: usize,
1335    rhs: *const u64,
1336    rhs_bit_count: usize,
1337    signed: u32,
1338) {
1339    // SAFETY: forwarded from this callback's pointer contract.
1340    let lhs = unsafe { read_wide_bits(lhs, lhs_bit_count) };
1341    // SAFETY: forwarded from this callback's pointer contract.
1342    let rhs = unsafe { read_wide_bits(rhs, rhs_bit_count) };
1343    let product = if signed != 0 {
1344        truncate_signed(
1345            as_signed(lhs, lhs_bit_count) * as_signed(rhs, rhs_bit_count),
1346            dst_bit_count,
1347        )
1348    } else {
1349        truncate_unsigned(lhs * rhs, dst_bit_count)
1350    };
1351    let offset = mulp_offset(dst_bit_count);
1352    let residual = truncate_signed(
1353        BigInt::from_biguint(Sign::Plus, product)
1354            - BigInt::from_biguint(Sign::Plus, offset.clone()),
1355        dst_bit_count,
1356    );
1357    // SAFETY: forwarded from this callback's pointer contract.
1358    unsafe { write_wide_bits(offset_dst, dst_bit_count, offset) };
1359    // SAFETY: forwarded from this callback's pointer contract.
1360    unsafe { write_wide_bits(residual_dst, dst_bit_count, residual) };
1361}
1362
1363/// Formats a trace message according to the XLS trace-format string syntax.
1364unsafe fn format_trace_message(
1365    format: &str,
1366    layouts: &[TraceValueLayout],
1367    operand_ptrs: *const *const u8,
1368) -> String {
1369    let mut output = String::new();
1370    let mut offset = 0usize;
1371    let mut operand_index = 0usize;
1372    while offset < format.len() {
1373        let remainder = &format[offset..];
1374        if remainder.starts_with("{{") || remainder.starts_with("}}") {
1375            // XLS preserves escaped braces in trace output; Verilog emission
1376            // performs the collapse to single braces separately.
1377            output.push_str(&remainder[..2]);
1378            offset += 2;
1379            continue;
1380        }
1381        if let Some((specifier, preference)) = TRACE_FORMAT_SPECIFIERS
1382            .iter()
1383            .find(|(specifier, _)| remainder.starts_with(specifier))
1384        {
1385            if let Some(layout) = layouts.get(operand_index) {
1386                // SAFETY: callback ABI provides one matching pointer per layout.
1387                let pointer = unsafe { *operand_ptrs.add(operand_index) };
1388                // SAFETY: callback ABI specifies readable storage matching `layout`.
1389                output.push_str(&unsafe { format_native_value(pointer, layout, *preference) });
1390            }
1391            operand_index += 1;
1392            offset += specifier.len();
1393            continue;
1394        }
1395        let character = remainder
1396            .chars()
1397            .next()
1398            .expect("offset is within trace format string");
1399        output.push(character);
1400        offset += character.len_utf8();
1401    }
1402    output
1403}
1404
1405/// Formats one caller-owned native value without constructing an XLS value.
1406unsafe fn format_native_value(
1407    pointer: *const u8,
1408    layout: &TraceValueLayout,
1409    preference: TraceFormatPreference,
1410) -> String {
1411    match layout {
1412        TraceValueLayout::Bits {
1413            bit_count,
1414            byte_count,
1415        } => {
1416            let mut bytes = vec![0u8; *byte_count];
1417            if *byte_count != 0 {
1418                // SAFETY: callback ABI provides native scalar storage of this size.
1419                unsafe { ptr::copy_nonoverlapping(pointer, bytes.as_mut_ptr(), *byte_count) };
1420            }
1421            let value = if cfg!(target_endian = "little") {
1422                BigUint::from_bytes_le(&bytes)
1423            } else {
1424                BigUint::from_bytes_be(&bytes)
1425            };
1426            format_trace_bits(value, *bit_count, preference)
1427        }
1428        TraceValueLayout::WideBits {
1429            bit_count,
1430            limb_count: _,
1431        } => {
1432            // SAFETY: callback ABI provides the number of native limbs
1433            // prescribed by this layout.
1434            let value = unsafe { read_wide_bits(pointer.cast::<u64>(), *bit_count) };
1435            format_trace_bits(value, *bit_count, preference)
1436        }
1437        TraceValueLayout::Array {
1438            element,
1439            element_count,
1440        } => {
1441            let elements = (0..*element_count)
1442                .map(|index| {
1443                    // SAFETY: each element is within the caller-provided array region.
1444                    unsafe {
1445                        format_native_value(
1446                            pointer.wrapping_add(index * element.byte_count()),
1447                            element,
1448                            preference,
1449                        )
1450                    }
1451                })
1452                .collect::<Vec<_>>();
1453            format!("[{}]", elements.join(", "))
1454        }
1455        TraceValueLayout::Tuple { fields, .. } => {
1456            let fields = fields
1457                .iter()
1458                .map(|field| {
1459                    // SAFETY: each field offset is prescribed by native tuple metadata.
1460                    unsafe {
1461                        format_native_value(
1462                            pointer.wrapping_add(field.offset),
1463                            &field.layout,
1464                            preference,
1465                        )
1466                    }
1467                })
1468                .collect::<Vec<_>>();
1469            format!("({})", fields.join(", "))
1470        }
1471        TraceValueLayout::Token => "token".to_string(),
1472    }
1473}
1474
1475fn format_trace_bits(
1476    mut value: BigUint,
1477    bit_count: usize,
1478    preference: TraceFormatPreference,
1479) -> String {
1480    if bit_count == 0 {
1481        value = BigUint::from(0u8);
1482    } else {
1483        value &= (BigUint::from(1u8) << bit_count) - BigUint::from(1u8);
1484    }
1485    match preference {
1486        TraceFormatPreference::Default => {
1487            if bit_count <= 64 {
1488                value.to_str_radix(10)
1489            } else {
1490                format_trace_bits(value, bit_count, TraceFormatPreference::Hex)
1491            }
1492        }
1493        TraceFormatPreference::UnsignedDecimal => value.to_str_radix(10),
1494        TraceFormatPreference::SignedDecimal => {
1495            if bit_count != 0
1496                && (&value & (BigUint::from(1u8) << (bit_count - 1))) != BigUint::from(0u8)
1497            {
1498                (BigInt::from_biguint(Sign::Plus, value)
1499                    - BigInt::from_biguint(Sign::Plus, BigUint::from(1u8) << bit_count))
1500                .to_string()
1501            } else {
1502                value.to_str_radix(10)
1503            }
1504        }
1505        TraceFormatPreference::PlainHex => value.to_str_radix(16),
1506        TraceFormatPreference::ZeroPaddedHex => {
1507            zero_padded_grouped_digits(&value, bit_count, 4, 16)
1508        }
1509        TraceFormatPreference::Hex => {
1510            format!("0x{}", grouped_digits(&value.to_str_radix(16)))
1511        }
1512        TraceFormatPreference::PlainBinary => value.to_str_radix(2),
1513        TraceFormatPreference::ZeroPaddedBinary => {
1514            zero_padded_grouped_digits(&value, bit_count, 1, 2)
1515        }
1516        TraceFormatPreference::Binary => {
1517            format!("0b{}", grouped_digits(&value.to_str_radix(2)))
1518        }
1519    }
1520}
1521
1522fn zero_padded_grouped_digits(
1523    value: &BigUint,
1524    bit_count: usize,
1525    bits_per_digit: usize,
1526    radix: u32,
1527) -> String {
1528    let digit_count = bit_count.div_ceil(bits_per_digit).max(1);
1529    let digits = format!(
1530        "{:0>width$}",
1531        value.to_str_radix(radix),
1532        width = digit_count
1533    );
1534    grouped_digits(&digits)
1535}
1536
/// Inserts `_` between groups of four digits, aligned to the
/// least-significant (rightmost) end: `"123456789"` becomes `"1_2345_6789"`.
///
/// Inputs of four or fewer digits, including the empty string, are returned
/// unchanged. `digits` is expected to be ASCII (radix output).
fn grouped_digits(digits: &str) -> String {
    const GROUP_WIDTH: usize = 4;
    // Short inputs need no separators. This also guards the empty string,
    // which previously panicked on `digits[..4]`.
    if digits.len() <= GROUP_WIDTH {
        return digits.to_string();
    }
    let first_group_width = match digits.len() % GROUP_WIDTH {
        0 => GROUP_WIDTH,
        remainder => remainder,
    };
    let separator_count = (digits.len() - first_group_width) / GROUP_WIDTH;
    let mut result = String::with_capacity(digits.len() + separator_count);
    result.push_str(&digits[..first_group_width]);
    for group_start in (first_group_width..digits.len()).step_by(GROUP_WIDTH) {
        result.push('_');
        result.push_str(&digits[group_start..group_start + GROUP_WIDTH]);
    }
    result
}
1549
1550impl TraceValueLayout {
1551    fn byte_count(&self) -> usize {
1552        match self {
1553            Self::Bits { byte_count, .. } => *byte_count,
1554            Self::WideBits { limb_count, .. } => limb_count * std::mem::size_of::<u64>(),
1555            Self::Array {
1556                element,
1557                element_count,
1558            } => element.byte_count() * element_count,
1559            Self::Tuple { byte_count, .. } => *byte_count,
1560            Self::Token => 0,
1561        }
1562    }
1563}
1564
1565#[cfg(test)]
1566mod tests {
1567    use super::*;
1568
1569    #[test]
1570    fn native_bits_wrappers_enforce_semantic_widths() {
1571        let value = BitsInU64::<42>::new((1u64 << 41) | 7).expect("value fits in bits[42]");
1572        assert_eq!(value.to_u64(), (1u64 << 41) | 7);
1573        assert!(BitsInU64::<42>::new(1u64 << 42).is_err());
1574        assert_eq!(BitsInU16::<9>::wrapping(0xffff).get(), 0x1ff);
1575        assert!(BitsInU8::<9>::new(0).is_err());
1576    }
1577
1578    #[test]
1579    fn public_signed_and_unsigned_bits_wrappers_preserve_raw_abi_bits() {
1580        let unsigned = UnsignedBitsInU8::<4>::new(15).expect("u4 max");
1581        assert_eq!(unsigned.to_u64(), 15);
1582        assert_eq!(unsigned.raw_bits(), 15);
1583        assert!(UnsignedBitsInU8::<4>::new(16).is_err());
1584        assert!(std::panic::catch_unwind(|| UnsignedBitsInU8::<4>::from_raw_bits(16)).is_err());
1585
1586        let signed = SignedBitsInU8::<4>::new(-1).expect("s4 -1");
1587        assert_eq!(signed.to_i64(), -1);
1588        assert_eq!(signed.raw_bits(), 15);
1589        assert!(SignedBitsInU8::<4>::new(8).is_err());
1590        assert!(SignedBitsInU8::<4>::new(-9).is_err());
1591        assert!(std::panic::catch_unwind(|| SignedBitsInU8::<4>::from_raw_bits(16)).is_err());
1592        assert_eq!(SignedBitsInU16::<9>::from_raw_bits(0x101).to_i64(), -255);
1593
1594        let wide = SignedWideBits::<65, 2>::from_limbs([u64::MAX, 1]).expect("s65 -1");
1595        assert_eq!(wide.to_bigint(), BigInt::from(-1));
1596        assert_eq!(wide.limbs(), &[u64::MAX, 1]);
1597    }
1598
1599    #[test]
1600    fn public_bits_wrappers_try_from_widened_integers() {
1601        assert_eq!(std::mem::size_of::<Bits0>(), 0);
1602        assert_eq!(std::mem::size_of::<UnsignedBits0>(), 0);
1603        assert_eq!(std::mem::size_of::<SignedBits0>(), 0);
1604
1605        let unsigned_zero = UnsignedBits0::try_from(0_u64).expect("0 fits in u0");
1606        assert_eq!(unsigned_zero.to_u64(), 0);
1607        assert_eq!(unsigned_zero.raw_bits(), 0);
1608        assert!(UnsignedBits0::try_from(1_u64).is_err());
1609        assert!(std::panic::catch_unwind(|| UnsignedBits0::from_raw_bits(1)).is_err());
1610
1611        let signed_zero = SignedBits0::try_from(0_i64).expect("0 fits in s0");
1612        assert_eq!(signed_zero.to_i64(), 0);
1613        assert_eq!(signed_zero.raw_bits(), 0);
1614        assert!(SignedBits0::try_from(-1_i64).is_err());
1615        assert!(SignedBits0::try_from(1_i64).is_err());
1616        assert!(std::panic::catch_unwind(|| SignedBits0::from_raw_bits(1)).is_err());
1617
1618        assert_eq!(
1619            UnsignedBitsInU8::<4>::try_from(15_u64)
1620                .expect("15 fits in u4")
1621                .to_u64(),
1622            15
1623        );
1624        assert!(UnsignedBitsInU8::<4>::try_from(16_u64).is_err());
1625        assert!(UnsignedBitsInU8::<8>::try_from(256_u64).is_err());
1626        assert_eq!(
1627            UnsignedBitsInU16::<9>::try_from(0x1ff_u64)
1628                .expect("0x1ff fits in u9")
1629                .to_u64(),
1630            0x1ff
1631        );
1632        assert_eq!(
1633            UnsignedBitsInU32::<17>::try_from(0x1ffff_u64)
1634                .expect("0x1ffff fits in u17")
1635                .to_u64(),
1636            0x1ffff
1637        );
1638        assert_eq!(
1639            UnsignedBitsInU64::<33>::try_from(0x1ffffffff_u64)
1640                .expect("0x1ffffffff fits in u33")
1641                .to_u64(),
1642            0x1ffffffff
1643        );
1644
1645        assert_eq!(
1646            SignedBitsInU8::<4>::try_from(-8_i64)
1647                .expect("-8 fits in s4")
1648                .to_i64(),
1649            -8
1650        );
1651        assert!(SignedBitsInU8::<4>::try_from(-9_i64).is_err());
1652        assert!(SignedBitsInU8::<8>::try_from(128_i64).is_err());
1653        assert_eq!(
1654            SignedBitsInU16::<9>::try_from(-256_i64)
1655                .expect("-256 fits in s9")
1656                .to_i64(),
1657            -256
1658        );
1659        assert_eq!(
1660            SignedBitsInU32::<17>::try_from(-65_536_i64)
1661                .expect("-65536 fits in s17")
1662                .to_i64(),
1663            -65_536
1664        );
1665        assert_eq!(
1666            SignedBitsInU64::<33>::try_from(-4_294_967_296_i64)
1667                .expect("-4294967296 fits in s33")
1668                .to_i64(),
1669            -4_294_967_296
1670        );
1671    }
1672
1673    #[test]
1674    fn wide_bits_wrappers_use_lsb_first_limbs_and_mask_high_bits() {
1675        let value =
1676            WideBits::<65, 2>::from_limbs([0x0123_4567_89ab_cdef, 1]).expect("canonical value");
1677        assert_eq!(value.limbs(), &[0x0123_4567_89ab_cdef, 1]);
1678        assert!(WideBits::<65, 2>::from_limbs([0, 2]).is_err());
1679        assert_eq!(WideBits::<65, 2>::wrapping_limbs([7, 3]).limbs(), &[7, 1]);
1680        assert!(WideBits::<65, 3>::from_limbs([0, 0, 0]).is_err());
1681        assert_eq!(std::mem::size_of::<Token>(), 0);
1682    }
1683
1684    fn metadata() -> CompiledFunctionMetadata {
1685        CompiledFunctionMetadata {
1686            event_sites: vec![
1687                EventSiteMetadata {
1688                    node_text_id: 10,
1689                    kind: EventKind::Cover,
1690                    label: Some("covered".to_string()),
1691                    message: None,
1692                    format: None,
1693                    verbosity: 0,
1694                    operand_layouts: Vec::new(),
1695                },
1696                EventSiteMetadata {
1697                    node_text_id: 11,
1698                    kind: EventKind::Assert,
1699                    label: Some("assert_label".to_string()),
1700                    message: Some("failed".to_string()),
1701                    format: None,
1702                    verbosity: 0,
1703                    operand_layouts: Vec::new(),
1704                },
1705                EventSiteMetadata {
1706                    node_text_id: 12,
1707                    kind: EventKind::Trace,
1708                    label: None,
1709                    message: None,
1710                    format: Some("x={} arr={}".to_string()),
1711                    verbosity: 1,
1712                    operand_layouts: vec![
1713                        TraceValueLayout::Bits {
1714                            bit_count: 8,
1715                            byte_count: 1,
1716                        },
1717                        TraceValueLayout::Array {
1718                            element: Box::new(TraceValueLayout::Bits {
1719                                bit_count: 8,
1720                                byte_count: 1,
1721                            }),
1722                            element_count: 2,
1723                        },
1724                    ],
1725                },
1726                EventSiteMetadata {
1727                    node_text_id: 13,
1728                    kind: EventKind::Assumption(AssumptionFailureKind::ArrayIndexOutOfBounds),
1729                    label: None,
1730                    message: None,
1731                    format: None,
1732                    verbosity: 0,
1733                    operand_layouts: Vec::new(),
1734                },
1735            ],
1736        }
1737    }
1738
1739    #[test]
1740    fn cover_and_assert_callbacks_collect_rust_owned_results() {
1741        let metadata = metadata();
1742        let mut context =
1743            ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1744        let mut raw = context.raw_context();
1745        // SAFETY: `raw` points into `context` for these immediate calls.
1746        unsafe {
1747            xlsynth_pir_record_cover(&mut raw, 0);
1748            xlsynth_pir_record_cover(&mut raw, 0);
1749            xlsynth_pir_record_assert(&mut raw, 1);
1750            xlsynth_pir_record_assumption_failure(&mut raw, 3);
1751        }
1752        let result = context.result();
1753        assert_eq!(result.cover_counts[0].count, 2);
1754        assert_eq!(result.cover_counts[0].label, "covered");
1755        assert_eq!(result.assertion_failures[0].message, "failed");
1756        assert_eq!(result.assertion_failures[0].label, "assert_label");
1757        assert_eq!(
1758            result.assumption_failures,
1759            vec![AssumptionFailure {
1760                node_text_id: 13,
1761                kind: AssumptionFailureKind::ArrayIndexOutOfBounds,
1762            }]
1763        );
1764    }
1765
1766    #[test]
1767    fn trace_callback_decodes_values_before_native_storage_changes() {
1768        let metadata = metadata();
1769        let mut context =
1770            ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1771        let mut raw = context.raw_context();
1772        let mut scalar = 7u8;
1773        let mut array = [2u8, 3u8];
1774        let operands = [
1775            ptr::from_ref(&scalar).cast::<u8>(),
1776            ptr::from_ref(&array).cast::<u8>(),
1777        ];
1778        // SAFETY: operands use the native layouts specified by trace metadata.
1779        unsafe { xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr()) };
1780        scalar = 90;
1781        array[0] = 91;
1782        assert_eq!(scalar, 90);
1783        assert_eq!(array[0], 91);
1784        assert_eq!(context.result().trace_messages[0].message, "x=7 arr=[2, 3]");
1785    }
1786
1787    #[test]
1788    fn trace_callback_formats_all_specifiers_and_wide_decimal_without_xls() {
1789        let twelve_bits = TraceValueLayout::Bits {
1790            bit_count: 12,
1791            byte_count: 2,
1792        };
1793        let metadata = CompiledFunctionMetadata {
1794            event_sites: vec![EventSiteMetadata {
1795                node_text_id: 20,
1796                kind: EventKind::Trace,
1797                label: None,
1798                message: None,
1799                format: Some(
1800                    "literal={{ default={} u={:u} d={:d} x={:x} 0x={:0x} #x={:#x} b={:b} 0b={:0b} #b={:#b} wide={} wide_u={:u}".to_string(),
1801                    ),
1802                verbosity: 0,
1803                operand_layouts: vec![
1804                    twelve_bits.clone(),
1805                    twelve_bits.clone(),
1806                    TraceValueLayout::Bits {
1807                        bit_count: 8,
1808                        byte_count: 1,
1809                    },
1810                    twelve_bits.clone(),
1811                    twelve_bits.clone(),
1812                    twelve_bits.clone(),
1813                    twelve_bits.clone(),
1814                    twelve_bits.clone(),
1815                    twelve_bits,
1816                    TraceValueLayout::Bits {
1817                        bit_count: 72,
1818                        byte_count: 9,
1819                    },
1820                    TraceValueLayout::Bits {
1821                        bit_count: 72,
1822                        byte_count: 9,
1823                    },
1824                ],
1825            }],
1826        };
1827        let mut context =
1828            ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1829        let mut raw = context.raw_context();
1830        let twelve = 43u16;
1831        let negative = 251u8;
1832        let wide = [1u8, 0, 0, 0, 0, 0, 0, 0, 1];
1833        let operands = [
1834            ptr::from_ref(&twelve).cast::<u8>(),
1835            ptr::from_ref(&twelve).cast::<u8>(),
1836            ptr::from_ref(&negative).cast::<u8>(),
1837            ptr::from_ref(&twelve).cast::<u8>(),
1838            ptr::from_ref(&twelve).cast::<u8>(),
1839            ptr::from_ref(&twelve).cast::<u8>(),
1840            ptr::from_ref(&twelve).cast::<u8>(),
1841            ptr::from_ref(&twelve).cast::<u8>(),
1842            ptr::from_ref(&twelve).cast::<u8>(),
1843            ptr::from_ref(&wide).cast::<u8>(),
1844            ptr::from_ref(&wide).cast::<u8>(),
1845        ];
1846        // SAFETY: operands use the native layouts specified by trace metadata.
1847        unsafe { xlsynth_pir_record_trace(&mut raw, 0, operands.as_ptr()) };
1848        assert_eq!(
1849            context.result().trace_messages[0].message,
1850            "literal={{ default=43 u=43 d=-5 x=2b 0x=02b #x=0x2b b=101011 0b=0000_0010_1011 #b=0b10_1011 wide=0x1_0000_0000_0000_0001 wide_u=18446744073709551617"
1851        );
1852    }
1853
1854    #[test]
1855    fn clear_resets_accumulated_event_results() {
1856        let metadata = metadata();
1857        let mut context =
1858            ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1859        let mut raw = context.raw_context();
1860        // SAFETY: `raw` points into `context` for this immediate call.
1861        unsafe { xlsynth_pir_record_cover(&mut raw, 0) };
1862        context.clear();
1863        let result = context.result();
1864        assert!(result.assertion_failures.is_empty());
1865        assert!(result.assumption_failures.is_empty());
1866        assert!(result.trace_messages.is_empty());
1867        assert_eq!(result.cover_counts[0].count, 0);
1868    }
1869
1870    #[test]
1871    fn default_context_does_not_collect_traces_or_covers() {
1872        let metadata = metadata();
1873        let mut context = ExecutionContext::new(&metadata);
1874        let mut raw = context.raw_context();
1875        let scalar = 7u8;
1876        let array = [2u8, 3u8];
1877        let operands = [
1878            ptr::from_ref(&scalar).cast::<u8>(),
1879            ptr::from_ref(&array).cast::<u8>(),
1880        ];
1881        // SAFETY: `raw` and operands point to valid test storage.
1882        unsafe {
1883            xlsynth_pir_record_cover(&mut raw, 0);
1884            xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr());
1885        }
1886        let result = context.result();
1887        assert!(result.cover_counts.is_empty());
1888        assert!(result.trace_messages.is_empty());
1889    }
1890
1891    #[test]
1892    fn trace_callback_respects_runtime_verbosity() {
1893        let metadata = metadata();
1894        let scalar = 7u8;
1895        let array = [2u8, 3u8];
1896        let operands = [
1897            ptr::from_ref(&scalar).cast::<u8>(),
1898            ptr::from_ref(&array).cast::<u8>(),
1899        ];
1900        let mut context = ExecutionContext::new_with_options(
1901            &metadata,
1902            ExecutionOptions::new(Some(0), /* collect_covers= */ false),
1903        );
1904        let mut raw = context.raw_context();
1905        // SAFETY: `raw` and operands point to valid test storage.
1906        unsafe { xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr()) };
1907        assert!(context.result().trace_messages.is_empty());
1908
1909        context.clear_with_options(ExecutionOptions::new(
1910            Some(1),
1911            /* collect_covers= */ false,
1912        ));
1913        let mut raw = context.raw_context();
1914        // SAFETY: `raw` and operands point to valid test storage.
1915        unsafe { xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr()) };
1916        assert_eq!(context.result().trace_messages[0].message, "x=7 arr=[2, 3]");
1917    }
1918
1919    #[test]
1920    fn wide_trace_values_use_lsb_first_native_limbs() {
1921        let value = [1u64, 1u64];
1922        // SAFETY: `value` contains the two limbs required for bits[72].
1923        let formatted = unsafe {
1924            format_native_value(
1925                value.as_ptr().cast(),
1926                &TraceValueLayout::WideBits {
1927                    bit_count: 72,
1928                    limb_count: 2,
1929                },
1930                TraceFormatPreference::Hex,
1931            )
1932        };
1933        assert_eq!(formatted, "0x1_0000_0000_0000_0001");
1934    }
1935
1936    #[test]
1937    fn wide_binary_runtime_helpers_cover_arithmetic_shifts_and_slices() {
1938        let lhs = [u64::MAX, 1];
1939        let rhs = [2u64, 0];
1940        let mut output = [0u64; 2];
1941        // SAFETY: all arrays contain the required two native limbs.
1942        unsafe {
1943            xlsynth_pir_runtime_wide_binop(
1944                output.as_mut_ptr(),
1945                65,
1946                lhs.as_ptr(),
1947                65,
1948                rhs.as_ptr(),
1949                65,
1950                WideBinaryOp::Umul as u32,
1951            );
1952        }
1953        assert_eq!(output, [u64::MAX - 1, 1]);
1954
1955        let negative = [0u64, 1];
1956        let shift = [1u64, 0];
1957        // SAFETY: all arrays contain the required two native limbs.
1958        unsafe {
1959            xlsynth_pir_runtime_wide_binop(
1960                output.as_mut_ptr(),
1961                65,
1962                negative.as_ptr(),
1963                65,
1964                shift.as_ptr(),
1965                65,
1966                WideBinaryOp::Shra as u32,
1967            );
1968        }
1969        assert_eq!(output, [1u64 << 63, 1]);
1970
1971        let start = [63u64, 0];
1972        // SAFETY: all arrays contain the limbs prescribed by their widths.
1973        unsafe {
1974            xlsynth_pir_runtime_wide_dynamic_bit_slice(
1975                output.as_mut_ptr(),
1976                65,
1977                lhs.as_ptr(),
1978                65,
1979                start.as_ptr(),
1980                65,
1981            );
1982        }
1983        assert_eq!(output, [3, 0]);
1984
1985        let zero = [0u64, 0];
1986        let update = [3u64, 0];
1987        // SAFETY: all arrays contain the limbs prescribed by their widths.
1988        unsafe {
1989            xlsynth_pir_runtime_wide_bit_slice_update(
1990                output.as_mut_ptr(),
1991                65,
1992                zero.as_ptr(),
1993                65,
1994                start.as_ptr(),
1995                65,
1996                update.as_ptr(),
1997                65,
1998            );
1999        }
2000        assert_eq!(output, [1u64 << 63, 1]);
2001    }
2002
2003    #[test]
2004    fn wide_runtime_helpers_accept_zero_width_storage() {
2005        // SAFETY: zero-width values contain no limbs, so null pointers satisfy
2006        // the callback storage contract.
2007        unsafe {
2008            xlsynth_pir_runtime_wide_binop(
2009                ptr::null_mut(),
2010                0,
2011                ptr::null(),
2012                0,
2013                ptr::null(),
2014                0,
2015                WideBinaryOp::Sdiv as u32,
2016            );
2017            xlsynth_pir_runtime_wide_mulp(
2018                ptr::null_mut(),
2019                ptr::null_mut(),
2020                0,
2021                ptr::null(),
2022                0,
2023                ptr::null(),
2024                0,
2025                0,
2026            );
2027        }
2028    }
2029
2030    #[test]
2031    fn wide_unary_runtime_helpers_cover_encoding_and_extensions() {
2032        let input = [1u64 << 63, 1];
2033        let mut output = [0u64; 3];
2034        // SAFETY: all arrays contain sufficient native limbs for their widths.
2035        unsafe {
2036            xlsynth_pir_runtime_wide_unary_op(
2037                output.as_mut_ptr(),
2038                66,
2039                input.as_ptr(),
2040                65,
2041                WideUnaryOp::OneHot as u32,
2042                1,
2043            );
2044        }
2045        assert_eq!(output[..2], [1u64 << 63, 0]);
2046
2047        // SAFETY: all arrays contain sufficient native limbs for their widths.
2048        unsafe {
2049            xlsynth_pir_runtime_wide_unary_op(
2050                output.as_mut_ptr(),
2051                7,
2052                input.as_ptr(),
2053                65,
2054                WideUnaryOp::ExtPrioEncode as u32,
2055                0,
2056            );
2057        }
2058        assert_eq!(output[0], 64);
2059
2060        let zeros = [0u64; 2];
2061        // SAFETY: all arrays contain sufficient native limbs for their widths.
2062        unsafe {
2063            xlsynth_pir_runtime_wide_unary_op(
2064                output.as_mut_ptr(),
2065                129,
2066                zeros.as_ptr(),
2067                65,
2068                WideUnaryOp::ExtMaskLow as u32,
2069                0,
2070            );
2071        }
2072        assert_eq!(output, [0, 0, 0]);
2073
2074        let count = [80u64, 0];
2075        // SAFETY: all arrays contain sufficient native limbs for their widths.
2076        unsafe {
2077            xlsynth_pir_runtime_wide_unary_op(
2078                output.as_mut_ptr(),
2079                129,
2080                count.as_ptr(),
2081                65,
2082                WideUnaryOp::ExtMaskLow as u32,
2083                0,
2084            );
2085        }
2086        assert_eq!(output, [u64::MAX, 0xffff, 0]);
2087    }
2088
2089    #[test]
2090    fn wide_mulp_runtime_helper_returns_components_summing_to_product() {
2091        let lhs = [u64::MAX, 1];
2092        let rhs = [3u64, 0];
2093        let mut offset = [0u64; 3];
2094        let mut residual = [0u64; 3];
2095        // SAFETY: all arrays contain sufficient native limbs for their widths.
2096        unsafe {
2097            xlsynth_pir_runtime_wide_mulp(
2098                offset.as_mut_ptr(),
2099                residual.as_mut_ptr(),
2100                129,
2101                lhs.as_ptr(),
2102                65,
2103                rhs.as_ptr(),
2104                65,
2105                0,
2106            );
2107        }
2108        let offset = unsafe { read_wide_bits(offset.as_ptr(), 129) };
2109        let residual = unsafe { read_wide_bits(residual.as_ptr(), 129) };
2110        assert_eq!(
2111            truncate_unsigned(offset + residual, 129),
2112            truncate_unsigned(
2113                unsafe { read_wide_bits(lhs.as_ptr(), 65) }
2114                    * unsafe { read_wide_bits(rhs.as_ptr(), 65) },
2115                129,
2116            )
2117        );
2118    }
2119}