1use std::ffi::c_void;
6use std::fmt;
7use std::marker::PhantomData;
8use std::ptr;
9
10use num_bigint::{BigInt, BigUint, Sign};
11
/// Signature of a compiled entry point callable through the C ABI.
///
/// The callee receives an array of input pointers, an output pointer, a scratch pointer and the
/// raw execution context, and returns an `i32`.
// NOTE(review): the buffer layouts and the meaning of the returned `i32` are defined by the code
// generator, which is not visible in this file — confirm before relying on them.
pub type CompiledEntrypoint = unsafe extern "C" fn(
    inputs: *const *const u8,
    output: *mut u8,
    scratch: *mut u8,
    context: *mut RawExecutionContext,
) -> i32;
19
/// C-layout handle that carries the runtime state pointer through compiled code.
///
/// Within this file it is produced by `ExecutionContext::raw_context` and decoded by
/// `state_from_raw`, which casts `private_state` back to `ContextState`.
#[repr(C)]
pub struct RawExecutionContext {
    // Untyped pointer to the owning context's `ContextState`.
    private_state: *mut c_void,
}
25
/// Error type used by the value wrappers in this file; wraps a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunError(pub String);

impl fmt::Display for RunError {
    /// Writes the wrapped message verbatim, ignoring any width/padding flags.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let RunError(message) = self;
        formatter.write_str(message)
    }
}

impl std::error::Error for RunError {}
37
/// Generates a `repr(transparent)` wrapper that stores a `BIT_COUNT`-bit value in one native
/// unsigned carrier.
///
/// Arguments: the wrapper name, the carrier integer type, and the carrier width in bits.
macro_rules! define_native_bits {
    ($name:ident, $carrier:ty, $carrier_bits:expr) => {
        #[repr(transparent)]
        #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
        pub struct $name<const BIT_COUNT: usize>($carrier);

        impl<const BIT_COUNT: usize> $name<BIT_COUNT> {
            /// Accepts widths from 1 up to the carrier width; anything else is an error.
            fn validate_width() -> Result<(), RunError> {
                if (1..=$carrier_bits).contains(&BIT_COUNT) {
                    return Ok(());
                }
                Err(RunError(format!(
                    "bits[{BIT_COUNT}] cannot use a {}-bit native carrier",
                    $carrier_bits
                )))
            }

            /// Mask with the low `BIT_COUNT` bits set.
            ///
            /// The full-width case is special-cased because shifting by the carrier width
            /// would overflow.
            const fn mask() -> $carrier {
                if BIT_COUNT != $carrier_bits {
                    ((1 as $carrier) << BIT_COUNT) - 1
                } else {
                    <$carrier>::MAX
                }
            }

            /// Wraps `value`, failing when the width is unsupported or `value` has bits set
            /// above `BIT_COUNT`.
            pub fn new(value: $carrier) -> Result<Self, RunError> {
                Self::validate_width()?;
                let excess = value & !Self::mask();
                if excess == 0 {
                    return Ok(Self(value));
                }
                Err(RunError(format!(
                    "value {value} does not fit in bits[{BIT_COUNT}]"
                )))
            }

            /// Wraps `value` after discarding every bit above `BIT_COUNT`.
            ///
            /// # Panics
            /// Panics when `BIT_COUNT` is zero or wider than the carrier.
            pub const fn wrapping(value: $carrier) -> Self {
                assert!(
                    BIT_COUNT > 0 && BIT_COUNT <= $carrier_bits,
                    "invalid native bits carrier width"
                );
                let truncated = value & Self::mask();
                Self(truncated)
            }

            /// Returns the stored carrier value.
            pub const fn get(self) -> $carrier {
                let Self(raw) = self;
                raw
            }

            /// Returns the stored value widened to `u64`.
            pub const fn to_u64(self) -> u64 {
                self.get() as u64
            }
        }
    };
}
97
// One raw wrapper per native carrier width (8, 16, 32 and 64 bits).
define_native_bits!(BitsInU8, u8, 8);
define_native_bits!(BitsInU16, u16, 16);
define_native_bits!(BitsInU32, u32, 32);
define_native_bits!(BitsInU64, u64, 64);
102
/// Unit type with C layout.
// NOTE(review): by analogy with `UnsignedBits0`/`SignedBits0` this presumably stands for a raw
// zero-width bits value — confirm against its users.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Bits0;
107
/// Zero-width unsigned value; the only raw bit pattern it accepts is zero.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UnsignedBits0;

impl UnsignedBits0 {
    /// Builds the value from raw bits.
    ///
    /// # Panics
    /// Panics when `value` is non-zero.
    pub const fn from_raw_bits(value: u64) -> Self {
        if value != 0 {
            panic!("raw bits do not fit target width");
        }
        UnsignedBits0
    }

    /// Numeric value; always zero.
    pub const fn to_u64(self) -> u64 {
        0u64
    }

    /// Raw bit pattern; always zero.
    pub const fn raw_bits(self) -> u64 {
        0u64
    }
}
130
131impl TryFrom<u64> for UnsignedBits0 {
132 type Error = RunError;
133
134 fn try_from(value: u64) -> Result<Self, Self::Error> {
135 if value == 0 {
136 Ok(Self)
137 } else {
138 Err(RunError(format!("value {value} does not fit in bits[0]")))
139 }
140 }
141}
142
/// Zero-width signed value; the only raw bit pattern it accepts is zero.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SignedBits0;

impl SignedBits0 {
    /// Builds the value from raw bits.
    ///
    /// # Panics
    /// Panics when `value` is non-zero.
    pub const fn from_raw_bits(value: u64) -> Self {
        if value != 0 {
            panic!("raw bits do not fit target width");
        }
        SignedBits0
    }

    /// Numeric value; always zero.
    pub const fn to_i64(self) -> i64 {
        0i64
    }

    /// Raw bit pattern; always zero.
    pub const fn raw_bits(self) -> u64 {
        0u64
    }
}
165
166impl TryFrom<i64> for SignedBits0 {
167 type Error = RunError;
168
169 fn try_from(value: i64) -> Result<Self, Self::Error> {
170 if value == 0 {
171 Ok(Self)
172 } else {
173 Err(RunError(format!("value {value} does not fit in s0")))
174 }
175 }
176}
177
/// Generates the public unsigned and signed wrappers that sit on top of one raw native-bits
/// type.
///
/// Arguments: unsigned wrapper name, signed wrapper name, raw wrapper name, the unsigned and
/// signed carrier integer types, and the carrier width in bits.
macro_rules! define_public_bits_wrappers {
    (
        $unsigned_name:ident,
        $signed_name:ident,
        $raw_name:ident,
        $unsigned_carrier:ty,
        $signed_carrier:ty,
        $carrier_bits:expr
    ) => {
        #[repr(transparent)]
        #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
        pub struct $unsigned_name<const BIT_COUNT: usize>($raw_name<BIT_COUNT>);

        impl<const BIT_COUNT: usize> $unsigned_name<BIT_COUNT> {
            /// Mask with the low `BIT_COUNT` bits set (all bits for a full-width value).
            const fn mask() -> $unsigned_carrier {
                if BIT_COUNT != $carrier_bits {
                    ((1 as $unsigned_carrier) << BIT_COUNT) - 1
                } else {
                    <$unsigned_carrier>::MAX
                }
            }

            /// Wraps `value`, failing when it does not fit or the width is unsupported.
            pub fn new(value: $unsigned_carrier) -> Result<Self, RunError> {
                $raw_name::<BIT_COUNT>::new(value).map(Self)
            }

            /// Wraps `value` after discarding bits above `BIT_COUNT`.
            ///
            /// # Panics
            /// Panics when the width is unsupported by the carrier.
            pub const fn wrapping(value: $unsigned_carrier) -> Self {
                let raw = $raw_name::<BIT_COUNT>::wrapping(value);
                Self(raw)
            }

            /// Wraps an already-truncated bit pattern.
            ///
            /// # Panics
            /// Panics when the width is unsupported or `value` has bits set above `BIT_COUNT`.
            pub const fn from_raw_bits(value: $unsigned_carrier) -> Self {
                assert!(
                    BIT_COUNT > 0 && BIT_COUNT <= $carrier_bits,
                    "invalid raw bit carrier width"
                );
                let excess = value & !Self::mask();
                assert!(excess == 0, "raw bits do not fit target width");
                Self($raw_name::<BIT_COUNT>::wrapping(value))
            }

            /// Returns the value widened to `u64`.
            pub const fn to_u64(self) -> u64 {
                let Self(raw) = self;
                raw.to_u64()
            }

            /// Returns the stored bit pattern in the native carrier.
            pub const fn raw_bits(self) -> $unsigned_carrier {
                let Self(raw) = self;
                raw.get()
            }
        }

        impl<const BIT_COUNT: usize> TryFrom<u64> for $unsigned_name<BIT_COUNT> {
            type Error = RunError;

            /// Narrows to the carrier first, then applies the `BIT_COUNT` range check.
            fn try_from(value: u64) -> Result<Self, Self::Error> {
                <$unsigned_carrier>::try_from(value)
                    .map_err(|_| {
                        RunError(format!(
                            "value {value} does not fit in bits[{BIT_COUNT}]"
                        ))
                    })
                    .and_then(Self::new)
            }
        }

        #[repr(transparent)]
        #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
        pub struct $signed_name<const BIT_COUNT: usize>($raw_name<BIT_COUNT>);

        impl<const BIT_COUNT: usize> $signed_name<BIT_COUNT> {
            /// Checks the width and that `value` lies in the two's-complement range of
            /// `BIT_COUNT` bits.
            fn validate_signed_value(value: $signed_carrier) -> Result<(), RunError> {
                if !(1..=$carrier_bits).contains(&BIT_COUNT) {
                    return Err(RunError(format!(
                        "s{BIT_COUNT} cannot use a {}-bit native carrier",
                        $carrier_bits
                    )));
                }
                // i128 comfortably holds both bounds for every supported width.
                let half_range = 1i128 << (BIT_COUNT - 1);
                let widened = value as i128;
                if (-half_range..half_range).contains(&widened) {
                    Ok(())
                } else {
                    Err(RunError(format!(
                        "value {widened} does not fit in s{BIT_COUNT}"
                    )))
                }
            }

            /// Mask with the low `BIT_COUNT` bits set (all bits for a full-width value).
            const fn mask() -> $unsigned_carrier {
                if BIT_COUNT != $carrier_bits {
                    ((1 as $unsigned_carrier) << BIT_COUNT) - 1
                } else {
                    <$unsigned_carrier>::MAX
                }
            }

            /// Wraps `value`, failing when it is out of range or the width is unsupported.
            pub fn new(value: $signed_carrier) -> Result<Self, RunError> {
                Self::validate_signed_value(value)?;
                Ok(Self::wrapping(value))
            }

            /// Wraps `value`, keeping only its low `BIT_COUNT` bits.
            ///
            /// # Panics
            /// Panics when the width is unsupported by the carrier.
            pub const fn wrapping(value: $signed_carrier) -> Self {
                let bits = value as $unsigned_carrier;
                Self($raw_name::<BIT_COUNT>::wrapping(bits))
            }

            /// Wraps an already-truncated bit pattern.
            ///
            /// # Panics
            /// Panics when the width is unsupported or `value` has bits set above `BIT_COUNT`.
            pub const fn from_raw_bits(value: $unsigned_carrier) -> Self {
                assert!(
                    BIT_COUNT > 0 && BIT_COUNT <= $carrier_bits,
                    "invalid raw bit carrier width"
                );
                let excess = value & !Self::mask();
                assert!(excess == 0, "raw bits do not fit target width");
                Self($raw_name::<BIT_COUNT>::wrapping(value))
            }

            /// Sign-extends the stored bits into the signed carrier.
            fn to_signed_carrier(self) -> $signed_carrier {
                let raw = self.0.get();
                let sign_bit = (1 as $unsigned_carrier) << (BIT_COUNT - 1);
                // A set sign bit means every bit above `BIT_COUNT` must be filled with ones.
                let extended = if raw & sign_bit != 0 {
                    raw | !Self::mask()
                } else {
                    raw
                };
                extended as $signed_carrier
            }

            /// Returns the sign-extended value as `i64`.
            pub fn to_i64(self) -> i64 {
                let narrow = self.to_signed_carrier();
                narrow as i64
            }

            /// Returns the stored bit pattern in the unsigned carrier.
            pub const fn raw_bits(self) -> $unsigned_carrier {
                let Self(raw) = self;
                raw.get()
            }
        }

        impl<const BIT_COUNT: usize> TryFrom<i64> for $signed_name<BIT_COUNT> {
            type Error = RunError;

            /// Narrows to the carrier first, then applies the `BIT_COUNT` range check.
            fn try_from(value: i64) -> Result<Self, Self::Error> {
                <$signed_carrier>::try_from(value)
                    .map_err(|_| RunError(format!("value {value} does not fit in s{BIT_COUNT}")))
                    .and_then(Self::new)
            }
        }
    };
}
342
// Public unsigned/signed wrappers for each native carrier width, layered on the raw
// `BitsInU*` types.
define_public_bits_wrappers!(UnsignedBitsInU8, SignedBitsInU8, BitsInU8, u8, i8, 8);
define_public_bits_wrappers!(UnsignedBitsInU16, SignedBitsInU16, BitsInU16, u16, i16, 16);
define_public_bits_wrappers!(UnsignedBitsInU32, SignedBitsInU32, BitsInU32, u32, i32, 32);
define_public_bits_wrappers!(UnsignedBitsInU64, SignedBitsInU64, BitsInU64, u64, i64, 64);
347
348#[repr(transparent)]
351#[derive(Debug, Clone, Copy, PartialEq, Eq)]
352pub struct WideBits<const BIT_COUNT: usize, const LIMB_COUNT: usize>([u64; LIMB_COUNT]);
353
354impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> WideBits<BIT_COUNT, LIMB_COUNT> {
355 fn validate_layout() -> Result<(), RunError> {
356 if BIT_COUNT <= 64 || LIMB_COUNT != BIT_COUNT.div_ceil(64) {
357 Err(RunError(format!(
358 "bits[{BIT_COUNT}] cannot use {LIMB_COUNT} native wide limb(s)"
359 )))
360 } else {
361 Ok(())
362 }
363 }
364
365 fn high_mask() -> u64 {
366 let high_width = BIT_COUNT % 64;
367 if high_width == 0 {
368 u64::MAX
369 } else {
370 (1u64 << high_width) - 1
371 }
372 }
373
374 pub fn from_limbs(limbs: [u64; LIMB_COUNT]) -> Result<Self, RunError> {
376 Self::validate_layout()?;
377 if limbs[LIMB_COUNT - 1] & !Self::high_mask() != 0 {
378 Err(RunError(format!(
379 "high limb does not fit in bits[{BIT_COUNT}]"
380 )))
381 } else {
382 Ok(Self(limbs))
383 }
384 }
385
386 pub fn wrapping_limbs(mut limbs: [u64; LIMB_COUNT]) -> Self {
388 assert!(
389 BIT_COUNT > 64 && LIMB_COUNT == BIT_COUNT.div_ceil(64),
390 "invalid native wide bits layout"
391 );
392 limbs[LIMB_COUNT - 1] &= Self::high_mask();
393 Self(limbs)
394 }
395
396 pub const fn limbs(&self) -> &[u64; LIMB_COUNT] {
398 &self.0
399 }
400}
401
402impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> Default for WideBits<BIT_COUNT, LIMB_COUNT> {
403 fn default() -> Self {
404 Self([0; LIMB_COUNT])
405 }
406}
407
408#[repr(transparent)]
410#[derive(Debug, Clone, Copy, PartialEq, Eq)]
411pub struct UnsignedWideBits<const BIT_COUNT: usize, const LIMB_COUNT: usize>(
412 WideBits<BIT_COUNT, LIMB_COUNT>,
413);
414
415impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> UnsignedWideBits<BIT_COUNT, LIMB_COUNT> {
416 pub fn from_limbs(limbs: [u64; LIMB_COUNT]) -> Result<Self, RunError> {
419 Ok(Self(WideBits::<BIT_COUNT, LIMB_COUNT>::from_limbs(limbs)?))
420 }
421
422 pub fn wrapping_limbs(limbs: [u64; LIMB_COUNT]) -> Self {
425 Self(WideBits::<BIT_COUNT, LIMB_COUNT>::wrapping_limbs(limbs))
426 }
427
428 pub const fn limbs(&self) -> &[u64; LIMB_COUNT] {
430 self.0.limbs()
431 }
432
433 pub fn to_biguint(&self) -> BigUint {
435 let mut bytes = Vec::with_capacity(LIMB_COUNT * std::mem::size_of::<u64>());
436 for limb in self.limbs() {
437 bytes.extend_from_slice(&limb.to_le_bytes());
438 }
439 BigUint::from_bytes_le(&bytes)
440 }
441}
442
443impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> Default
444 for UnsignedWideBits<BIT_COUNT, LIMB_COUNT>
445{
446 fn default() -> Self {
447 Self(WideBits::default())
448 }
449}
450
451#[repr(transparent)]
453#[derive(Debug, Clone, Copy, PartialEq, Eq)]
454pub struct SignedWideBits<const BIT_COUNT: usize, const LIMB_COUNT: usize>(
455 WideBits<BIT_COUNT, LIMB_COUNT>,
456);
457
458impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> SignedWideBits<BIT_COUNT, LIMB_COUNT> {
459 pub fn from_limbs(limbs: [u64; LIMB_COUNT]) -> Result<Self, RunError> {
461 Ok(Self(WideBits::<BIT_COUNT, LIMB_COUNT>::from_limbs(limbs)?))
462 }
463
464 pub fn wrapping_limbs(limbs: [u64; LIMB_COUNT]) -> Self {
467 Self(WideBits::<BIT_COUNT, LIMB_COUNT>::wrapping_limbs(limbs))
468 }
469
470 pub const fn limbs(&self) -> &[u64; LIMB_COUNT] {
472 self.0.limbs()
473 }
474
475 pub fn to_bigint(&self) -> BigInt {
477 let unsigned = UnsignedWideBits::<BIT_COUNT, LIMB_COUNT>(self.0).to_biguint();
478 if BIT_COUNT == 0 || !unsigned.bit((BIT_COUNT - 1) as u64) {
479 BigInt::from_biguint(Sign::Plus, unsigned)
480 } else {
481 BigInt::from_biguint(Sign::Plus, unsigned) - (BigInt::from(1u8) << BIT_COUNT)
482 }
483 }
484}
485
486impl<const BIT_COUNT: usize, const LIMB_COUNT: usize> Default
487 for SignedWideBits<BIT_COUNT, LIMB_COUNT>
488{
489 fn default() -> Self {
490 Self(WideBits::default())
491 }
492}
493
/// Unit type with C layout.
// NOTE(review): presumably the runtime representation of a token value (compare
// `TraceValueLayout::Token`) — confirm against the code generator.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Token;
498
/// Category of an event site described by `EventSiteMetadata`.
///
/// The record callbacks in this file only act on a site whose kind matches the callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventKind {
    /// Site handled by `xlsynth_pir_record_assert`.
    Assert,
    /// Site handled by `xlsynth_pir_record_assumption_failure`; carries the failure kind.
    Assumption(AssumptionFailureKind),
    /// Site handled by `xlsynth_pir_record_cover`.
    Cover,
    /// Site handled by `xlsynth_pir_record_trace`.
    Trace,
}
507
/// Describes the layout of one traced operand; stored in `EventSiteMetadata::operand_layouts`
/// and handed to `format_trace_message`.
// NOTE(review): the exact memory meaning of the counts below is defined by the trace
// formatter, which is not fully visible here — confirm before relying on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TraceValueLayout {
    /// Bits value described by a bit count and a byte count.
    Bits {
        bit_count: usize,
        byte_count: usize,
    },
    /// Wide bits value described by a bit count and a limb count.
    WideBits {
        bit_count: usize,
        limb_count: usize,
    },
    /// Array of `element_count` elements sharing one element layout.
    Array {
        element: Box<TraceValueLayout>,
        element_count: usize,
    },
    /// Tuple with per-field layouts and a total byte count.
    Tuple {
        fields: Vec<TraceTupleFieldLayout>,
        byte_count: usize,
    },
    /// Token value; carries no layout data.
    Token,
}
529
/// Binary operations selectable through the `operation` argument of
/// `xlsynth_pir_runtime_wide_binop`; the discriminants are the ABI codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum WideBinaryOp {
    Umul = 0,
    Smul = 1,
    Udiv = 2,
    Sdiv = 3,
    Umod = 4,
    Smod = 5,
    Shll = 6,
    Shrl = 7,
    Shra = 8,
}

impl WideBinaryOp {
    /// Decodes an ABI operation code; returns `None` for codes without a variant.
    fn from_abi(value: u32) -> Option<Self> {
        const ALL: [WideBinaryOp; 9] = [
            WideBinaryOp::Umul,
            WideBinaryOp::Smul,
            WideBinaryOp::Udiv,
            WideBinaryOp::Sdiv,
            WideBinaryOp::Umod,
            WideBinaryOp::Smod,
            WideBinaryOp::Shll,
            WideBinaryOp::Shrl,
            WideBinaryOp::Shra,
        ];
        ALL.iter().copied().find(|op| *op as u32 == value)
    }
}
561
/// Unary operations selectable through the `operation` argument of
/// `xlsynth_pir_runtime_wide_unary_op`; the discriminants are the ABI codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum WideUnaryOp {
    OneHot = 0,
    Encode = 1,
    Decode = 2,
    ExtPrioEncode = 3,
    ExtClz = 4,
    ExtNormalizeLeft = 5,
    ExtMaskLow = 6,
}

impl WideUnaryOp {
    /// Decodes an ABI operation code; returns `None` for codes without a variant.
    fn from_abi(value: u32) -> Option<Self> {
        const ALL: [WideUnaryOp; 7] = [
            WideUnaryOp::OneHot,
            WideUnaryOp::Encode,
            WideUnaryOp::Decode,
            WideUnaryOp::ExtPrioEncode,
            WideUnaryOp::ExtClz,
            WideUnaryOp::ExtNormalizeLeft,
            WideUnaryOp::ExtMaskLow,
        ];
        ALL.iter().copied().find(|op| *op as u32 == value)
    }
}
589
/// Layout of one field inside `TraceValueLayout::Tuple`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceTupleFieldLayout {
    /// Layout of the field's value.
    pub layout: Box<TraceValueLayout>,
    /// Position of the field within the tuple.
    // NOTE(review): presumably a byte offset from the start of the tuple — confirm.
    pub offset: usize,
}
596
/// Static description of one event site; the record callbacks look sites up by index in
/// `CompiledFunctionMetadata::event_sites`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EventSiteMetadata {
    /// Identifier copied into every record produced for this site.
    pub node_text_id: usize,
    /// Kind of event; a callback ignores sites whose kind does not match it.
    pub kind: EventKind,
    /// Label copied into assertion failures and cover counts (empty string when `None`).
    pub label: Option<String>,
    /// Message copied into assertion failures (empty string when `None`).
    pub message: Option<String>,
    /// Format string used to render trace messages (empty string when `None`).
    pub format: Option<String>,
    /// Trace verbosity; a trace is dropped when this exceeds the configured maximum.
    pub verbosity: i64,
    /// Layouts of the operands passed to the trace callback.
    pub operand_layouts: Vec<TraceValueLayout>,
}
608
/// Metadata the runtime callbacks consult for a compiled function.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CompiledFunctionMetadata {
    /// Event sites, indexed by the `site_id` passed to the record callbacks.
    pub event_sites: Vec<EventSiteMetadata>,
}
614
/// Record pushed by `xlsynth_pir_record_assert` for an assert site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssertionFailure {
    /// `node_text_id` of the originating site.
    pub node_text_id: usize,
    /// Site message, or the empty string when the site has none.
    pub message: String,
    /// Site label, or the empty string when the site has none.
    pub label: String,
}
622
/// Kind of assumption attached to an `EventKind::Assumption` site and reported in
/// `AssumptionFailure`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AssumptionFailureKind {
    // NOTE(review): variant meanings are inferred from their names only — confirm.
    ArrayIndexOutOfBounds,
    ArrayUpdateOutOfBounds,
}
629
/// Record pushed by `xlsynth_pir_record_assumption_failure` for an assumption site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssumptionFailure {
    /// `node_text_id` of the originating site.
    pub node_text_id: usize,
    /// Kind taken from the site's `EventKind::Assumption` payload.
    pub kind: AssumptionFailureKind,
}
636
/// Record pushed by `xlsynth_pir_record_trace` for a trace site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceMessage {
    /// `node_text_id` of the originating site.
    pub node_text_id: usize,
    /// Text produced by `format_trace_message`.
    pub message: String,
    /// Verbosity copied from the site.
    pub verbosity: i64,
}
644
/// Per-site cover counter reported by `ExecutionContext::result`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoverCount {
    /// `node_text_id` of the cover site.
    pub node_text_id: usize,
    /// Site label, or the empty string when the site has none.
    pub label: String,
    /// Number of recorded hits; the recording callback saturates instead of wrapping.
    pub count: u64,
}
652
/// Snapshot of everything an `ExecutionContext` has recorded, as returned by
/// `ExecutionContext::result`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecutionResult {
    /// Assertion records, in the order they were pushed.
    pub assertion_failures: Vec<AssertionFailure>,
    /// Assumption records, in the order they were pushed.
    pub assumption_failures: Vec<AssumptionFailure>,
    /// Trace records, in the order they were pushed.
    pub trace_messages: Vec<TraceMessage>,
    /// One entry per cover site; empty when cover collection is disabled.
    pub cover_counts: Vec<CoverCount>,
}
661
/// Selects which events an execution context records.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExecutionOptions {
    /// Maximum trace verbosity to record; `None` disables trace recording.
    pub trace_verbosity: Option<i64>,
    /// Whether cover hit counters are allocated and updated.
    pub collect_covers: bool,
}

impl ExecutionOptions {
    /// Options that record neither traces nor covers.
    pub const NO_EVENTS: Self = Self::new(None, false);

    /// Builds options from explicit settings.
    pub const fn new(trace_verbosity: Option<i64>, collect_covers: bool) -> Self {
        ExecutionOptions {
            trace_verbosity,
            collect_covers,
        }
    }

    /// Options that record traces at every verbosity and all covers.
    pub const fn collect_all() -> Self {
        Self::new(Some(i64::MAX), true)
    }
}

impl Default for ExecutionOptions {
    /// Defaults to [`ExecutionOptions::NO_EVENTS`].
    fn default() -> Self {
        ExecutionOptions::NO_EVENTS
    }
}
698
/// Rendering preference selected by a trace format specifier; see `TRACE_FORMAT_SPECIFIERS`
/// for the specifier text paired with each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TraceFormatPreference {
    /// `{}`
    Default,
    /// `{:u}`
    UnsignedDecimal,
    /// `{:d}`
    SignedDecimal,
    /// `{:x}`
    PlainHex,
    /// `{:0x}`
    ZeroPaddedHex,
    /// `{:#x}`
    Hex,
    /// `{:b}`
    PlainBinary,
    /// `{:0b}`
    ZeroPaddedBinary,
    /// `{:#b}`
    Binary,
}
711
/// Specifier text paired with the preference it selects.
///
/// `format_trace_message` takes the first entry that the remaining format text starts with.
const TRACE_FORMAT_SPECIFIERS: [(&str, TraceFormatPreference); 9] = [
    ("{}", TraceFormatPreference::Default),
    ("{:u}", TraceFormatPreference::UnsignedDecimal),
    ("{:d}", TraceFormatPreference::SignedDecimal),
    ("{:x}", TraceFormatPreference::PlainHex),
    ("{:0x}", TraceFormatPreference::ZeroPaddedHex),
    ("{:#x}", TraceFormatPreference::Hex),
    ("{:b}", TraceFormatPreference::PlainBinary),
    ("{:0b}", TraceFormatPreference::ZeroPaddedBinary),
    ("{:#b}", TraceFormatPreference::Binary),
];
723
/// Mutable state behind a `RawExecutionContext`; the record callbacks reach it through
/// `state_from_raw`.
struct ContextState {
    // Metadata borrowed by the owning `ExecutionContext`; kept as a raw pointer here.
    metadata: *const CompiledFunctionMetadata,
    // Options currently in effect.
    options: ExecutionOptions,
    assertion_failures: Vec<AssertionFailure>,
    assumption_failures: Vec<AssumptionFailure>,
    trace_messages: Vec<TraceMessage>,
    // One counter per event site when cover collection is enabled, otherwise `None`.
    event_counts: Option<Vec<u64>>,
}
732
733pub struct ExecutionContext<'metadata> {
739 state: Box<ContextState>,
740 marker: PhantomData<&'metadata CompiledFunctionMetadata>,
741}
742
743impl<'metadata> ExecutionContext<'metadata> {
744 pub fn new(metadata: &'metadata CompiledFunctionMetadata) -> Self {
746 Self::new_with_options(metadata, ExecutionOptions::default())
747 }
748
749 pub fn new_with_options(
751 metadata: &'metadata CompiledFunctionMetadata,
752 options: ExecutionOptions,
753 ) -> Self {
754 let event_counts = options
755 .collect_covers
756 .then(|| vec![0; metadata.event_sites.len()]);
757 Self {
758 state: Box::new(ContextState {
759 metadata,
760 options,
761 assertion_failures: Vec::new(),
762 assumption_failures: Vec::new(),
763 trace_messages: Vec::new(),
764 event_counts,
765 }),
766 marker: PhantomData,
767 }
768 }
769
770 pub fn raw_context(&mut self) -> RawExecutionContext {
773 RawExecutionContext {
774 private_state: ptr::from_mut(self.state.as_mut()).cast(),
775 }
776 }
777
778 pub fn result(&self) -> ExecutionResult {
780 let metadata = self.metadata();
781 let cover_counts = self
782 .state
783 .event_counts
784 .as_ref()
785 .map(|event_counts| {
786 metadata
787 .event_sites
788 .iter()
789 .zip(event_counts)
790 .filter(|(site, _)| site.kind == EventKind::Cover)
791 .map(|(site, count)| CoverCount {
792 node_text_id: site.node_text_id,
793 label: site.label.clone().unwrap_or_default(),
794 count: *count,
795 })
796 .collect()
797 })
798 .unwrap_or_default();
799 ExecutionResult {
800 assertion_failures: self.state.assertion_failures.clone(),
801 assumption_failures: self.state.assumption_failures.clone(),
802 trace_messages: self.state.trace_messages.clone(),
803 cover_counts,
804 }
805 }
806
807 pub fn assertion_failures(&self) -> &[AssertionFailure] {
809 &self.state.assertion_failures
810 }
811
812 pub fn assumption_failures(&self) -> &[AssumptionFailure] {
814 &self.state.assumption_failures
815 }
816
817 pub fn clear(&mut self) {
819 self.clear_with_options(self.state.options);
820 }
821
822 pub fn clear_with_options(&mut self, options: ExecutionOptions) {
825 self.state.assertion_failures.clear();
826 self.state.assumption_failures.clear();
827 self.state.trace_messages.clear();
828 self.state.options = options;
829 if options.collect_covers {
830 match &mut self.state.event_counts {
831 Some(event_counts) => event_counts.fill(0),
832 None => {
833 let site_count = self.metadata().event_sites.len();
834 self.state.event_counts = Some(vec![0; site_count]);
835 }
836 }
837 } else {
838 self.state.event_counts = None;
839 }
840 }
841
842 fn metadata(&self) -> &CompiledFunctionMetadata {
843 unsafe { &*self.state.metadata }
845 }
846}
847
848unsafe fn state_from_raw<'a>(context: *mut RawExecutionContext) -> &'a mut ContextState {
849 assert!(
850 !context.is_null(),
851 "compiled PIR callback requires an execution context"
852 );
853 unsafe {
856 (*context)
857 .private_state
858 .cast::<ContextState>()
859 .as_mut()
860 .expect("compiled PIR callback received an invalid execution context")
861 }
862}
863
864fn site(state: &ContextState, site_id: u32, kind: EventKind) -> Option<&EventSiteMetadata> {
865 let metadata = unsafe { state.metadata.as_ref()? };
867 let site = metadata.event_sites.get(site_id as usize)?;
868 (site.kind == kind).then_some(site)
869}
870
871#[unsafe(no_mangle)]
878pub unsafe extern "C" fn xlsynth_pir_record_assert(
879 context: *mut RawExecutionContext,
880 site_id: u32,
881) {
882 let state = unsafe { state_from_raw(context) };
884 let Some(site) = site(state, site_id, EventKind::Assert).cloned() else {
885 return;
886 };
887 state.assertion_failures.push(AssertionFailure {
888 node_text_id: site.node_text_id,
889 message: site.message.unwrap_or_default(),
890 label: site.label.unwrap_or_default(),
891 });
892}
893
894#[unsafe(no_mangle)]
901pub unsafe extern "C" fn xlsynth_pir_record_assumption_failure(
902 context: *mut RawExecutionContext,
903 site_id: u32,
904) {
905 let state = unsafe { state_from_raw(context) };
907 let Some(site) = (unsafe { state.metadata.as_ref() })
909 .and_then(|metadata| metadata.event_sites.get(site_id as usize))
910 else {
911 return;
912 };
913 let EventKind::Assumption(kind) = site.kind else {
914 return;
915 };
916 state.assumption_failures.push(AssumptionFailure {
917 node_text_id: site.node_text_id,
918 kind,
919 });
920}
921
922#[unsafe(no_mangle)]
929pub unsafe extern "C" fn xlsynth_pir_record_cover(context: *mut RawExecutionContext, site_id: u32) {
930 let state = unsafe { state_from_raw(context) };
932 if state.event_counts.is_none() || site(state, site_id, EventKind::Cover).is_none() {
933 return;
934 }
935 if let Some(count) = state
936 .event_counts
937 .as_mut()
938 .and_then(|event_counts| event_counts.get_mut(site_id as usize))
939 {
940 *count = count.saturating_add(1);
941 }
942}
943
944#[unsafe(no_mangle)]
953pub unsafe extern "C" fn xlsynth_pir_record_trace(
954 context: *mut RawExecutionContext,
955 site_id: u32,
956 operand_ptrs: *const *const u8,
957) {
958 let state = unsafe { state_from_raw(context) };
960 let Some(max_verbosity) = state.options.trace_verbosity else {
961 return;
962 };
963 let Some(site) = site(state, site_id, EventKind::Trace) else {
964 return;
965 };
966 if site.verbosity > max_verbosity {
967 return;
968 }
969 if !site.operand_layouts.is_empty() && operand_ptrs.is_null() {
970 return;
971 }
972 state.trace_messages.push(TraceMessage {
973 node_text_id: site.node_text_id,
974 message: unsafe {
976 format_trace_message(
977 site.format.as_deref().unwrap_or(""),
978 &site.operand_layouts,
979 operand_ptrs,
980 )
981 },
982 verbosity: site.verbosity,
983 });
984}
985
986fn bit_mask(bit_count: usize) -> BigUint {
987 if bit_count == 0 {
988 BigUint::from(0u8)
989 } else {
990 (BigUint::from(1u8) << bit_count) - BigUint::from(1u8)
991 }
992}
993
994fn truncate_unsigned(value: BigUint, bit_count: usize) -> BigUint {
995 value & bit_mask(bit_count)
996}
997
998fn as_signed(value: BigUint, bit_count: usize) -> BigInt {
999 if bit_count != 0 && (&value & (BigUint::from(1u8) << (bit_count - 1))) != BigUint::from(0u8) {
1000 BigInt::from_biguint(Sign::Plus, value)
1001 - BigInt::from_biguint(Sign::Plus, BigUint::from(1u8) << bit_count)
1002 } else {
1003 BigInt::from_biguint(Sign::Plus, value)
1004 }
1005}
1006
1007fn truncate_signed(value: BigInt, bit_count: usize) -> BigUint {
1008 if bit_count == 0 {
1009 return BigUint::from(0u8);
1010 }
1011 let modulus = BigInt::from_biguint(Sign::Plus, BigUint::from(1u8) << bit_count);
1012 let mut reduced = value % &modulus;
1013 if reduced.sign() == Sign::Minus {
1014 reduced += &modulus;
1015 }
1016 let (_, bytes) = reduced.to_bytes_le();
1017 BigUint::from_bytes_le(&bytes)
1018}
1019
1020fn bounded_shift_amount(value: &BigUint, bit_count: usize) -> Option<usize> {
1021 let digits = value.to_u64_digits();
1022 if digits.len() > 1 {
1023 return None;
1024 }
1025 let amount = digits.first().copied().unwrap_or(0);
1026 usize::try_from(amount)
1027 .ok()
1028 .filter(|amount| *amount < bit_count)
1029}
1030
1031unsafe fn read_wide_bits(limbs: *const u64, bit_count: usize) -> BigUint {
1037 let limb_count = bit_count.div_ceil(64);
1038 let mut bytes = Vec::with_capacity(limb_count * std::mem::size_of::<u64>());
1039 for index in 0..limb_count {
1040 let limb = unsafe { limbs.add(index).read() };
1042 bytes.extend_from_slice(&limb.to_le_bytes());
1043 }
1044 truncate_unsigned(BigUint::from_bytes_le(&bytes), bit_count)
1045}
1046
1047unsafe fn write_wide_bits(limbs: *mut u64, bit_count: usize, value: BigUint) {
1053 let limb_count = bit_count.div_ceil(64);
1054 let bytes = truncate_unsigned(value, bit_count).to_bytes_le();
1055 for index in 0..limb_count {
1056 let mut limb_bytes = [0u8; std::mem::size_of::<u64>()];
1057 let start = index * std::mem::size_of::<u64>();
1058 if start < bytes.len() {
1059 let end = bytes.len().min(start + std::mem::size_of::<u64>());
1060 limb_bytes[..end - start].copy_from_slice(&bytes[start..end]);
1061 }
1062 unsafe { limbs.add(index).write(u64::from_le_bytes(limb_bytes)) };
1064 }
1065}
1066
1067fn get_bit(value: &BigUint, index: usize) -> bool {
1068 value
1069 .to_u64_digits()
1070 .get(index / u64::BITS as usize)
1071 .is_some_and(|limb| limb & (1u64 << (index % u64::BITS as usize)) != 0)
1072}
1073
1074fn prioritized_set_bit(value: &BigUint, bit_count: usize, lsb_prio: bool) -> Option<usize> {
1075 if lsb_prio {
1076 (0..bit_count).find(|index| get_bit(value, *index))
1077 } else {
1078 (0..bit_count).rev().find(|index| get_bit(value, *index))
1079 }
1080}
1081
1082fn leading_zero_count(value: &BigUint, bit_count: usize) -> usize {
1083 prioritized_set_bit(value, bit_count, false)
1084 .map(|index| bit_count - index - 1)
1085 .unwrap_or(bit_count)
1086}
1087
1088fn mulp_offset(result_width: usize) -> BigUint {
1089 let low_width = result_width.saturating_sub(2);
1090 let high_width = result_width - low_width;
1091 let low_shift = low_width.saturating_sub(1).min(3);
1092 let low = if low_width == 0 {
1093 BigUint::from(0u8)
1094 } else {
1095 bit_mask(low_width) >> low_shift
1096 };
1097 let high = if high_width == 0 {
1098 BigUint::from(0u8)
1099 } else {
1100 bit_mask(high_width.saturating_sub(1)) << low_width
1101 };
1102 low | high
1103}
1104
/// Evaluates a wide binary operation and writes the result, truncated to `dst_bit_count`
/// bits, into `dst`.
///
/// `operation` is decoded with `WideBinaryOp::from_abi`; an unknown code returns without
/// writing `dst`.
///
/// # Safety
/// `lhs` and `rhs` must be readable, and `dst` writable, for `ceil(bit_count / 64)` aligned
/// `u64` limbs of their respective bit counts.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn xlsynth_pir_runtime_wide_binop(
    dst: *mut u64,
    dst_bit_count: usize,
    lhs: *const u64,
    lhs_bit_count: usize,
    rhs: *const u64,
    rhs_bit_count: usize,
    operation: u32,
) {
    let Some(operation) = WideBinaryOp::from_abi(operation) else {
        return;
    };
    // SAFETY: the caller guarantees both operand buffers cover their declared widths.
    let lhs_unsigned = unsafe { read_wide_bits(lhs, lhs_bit_count) };
    let rhs_unsigned = unsafe { read_wide_bits(rhs, rhs_bit_count) };
    let result = match operation {
        WideBinaryOp::Umul => truncate_unsigned(lhs_unsigned * rhs_unsigned, dst_bit_count),
        WideBinaryOp::Smul => truncate_signed(
            as_signed(lhs_unsigned, lhs_bit_count) * as_signed(rhs_unsigned, rhs_bit_count),
            dst_bit_count,
        ),
        WideBinaryOp::Udiv => {
            // Unsigned division by zero yields the all-ones value.
            if rhs_unsigned == BigUint::from(0u8) {
                bit_mask(dst_bit_count)
            } else {
                truncate_unsigned(lhs_unsigned / rhs_unsigned, dst_bit_count)
            }
        }
        WideBinaryOp::Umod => {
            // Unsigned modulo by zero yields zero.
            if rhs_unsigned == BigUint::from(0u8) {
                BigUint::from(0u8)
            } else {
                truncate_unsigned(lhs_unsigned % rhs_unsigned, dst_bit_count)
            }
        }
        WideBinaryOp::Sdiv | WideBinaryOp::Smod => {
            let lhs_signed = as_signed(lhs_unsigned, lhs_bit_count);
            let rhs_signed = as_signed(rhs_unsigned, rhs_bit_count);
            if dst_bit_count == 0 {
                BigUint::from(0u8)
            } else if rhs_signed == BigInt::from(0u8) {
                // Division by zero: modulo yields zero; division yields the most negative
                // value for a negative dividend and the most positive value otherwise.
                if operation == WideBinaryOp::Smod {
                    BigUint::from(0u8)
                } else if lhs_signed.sign() == Sign::Minus {
                    BigUint::from(1u8) << (dst_bit_count - 1)
                } else {
                    (BigUint::from(1u8) << (dst_bit_count - 1)) - BigUint::from(1u8)
                }
            } else if operation == WideBinaryOp::Sdiv {
                truncate_signed(lhs_signed / rhs_signed, dst_bit_count)
            } else {
                truncate_signed(lhs_signed % rhs_signed, dst_bit_count)
            }
        }
        WideBinaryOp::Shll | WideBinaryOp::Shrl | WideBinaryOp::Shra => {
            // `None` means the shift amount is at least `lhs_bit_count`.
            match bounded_shift_amount(&rhs_unsigned, lhs_bit_count) {
                None if operation == WideBinaryOp::Shra => {
                    // Over-shifting an arithmetic right shift fills with the sign bit.
                    if as_signed(lhs_unsigned, lhs_bit_count).sign() == Sign::Minus {
                        bit_mask(dst_bit_count)
                    } else {
                        BigUint::from(0u8)
                    }
                }
                None => BigUint::from(0u8),
                Some(amount) if operation == WideBinaryOp::Shll => {
                    truncate_unsigned(lhs_unsigned << amount, dst_bit_count)
                }
                Some(amount) if operation == WideBinaryOp::Shrl => {
                    truncate_unsigned(lhs_unsigned >> amount, dst_bit_count)
                }
                // Remaining case: in-range arithmetic right shift.
                Some(amount) => truncate_signed(
                    as_signed(lhs_unsigned, lhs_bit_count) >> amount,
                    dst_bit_count,
                ),
            }
        }
    };
    // SAFETY: the caller guarantees `dst` covers `dst_bit_count` bits.
    unsafe { write_wide_bits(dst, dst_bit_count, result) };
}
1197
1198#[unsafe(no_mangle)]
1204pub unsafe extern "C" fn xlsynth_pir_runtime_wide_dynamic_bit_slice(
1205 dst: *mut u64,
1206 dst_bit_count: usize,
1207 arg: *const u64,
1208 arg_bit_count: usize,
1209 start: *const u64,
1210 start_bit_count: usize,
1211) {
1212 let arg = unsafe { read_wide_bits(arg, arg_bit_count) };
1214 let start = unsafe { read_wide_bits(start, start_bit_count) };
1216 let result = bounded_shift_amount(&start, arg_bit_count)
1217 .map(|amount| truncate_unsigned(arg >> amount, dst_bit_count))
1218 .unwrap_or_else(|| BigUint::from(0u8));
1219 unsafe { write_wide_bits(dst, dst_bit_count, result) };
1221}
1222
1223#[unsafe(no_mangle)]
1229pub unsafe extern "C" fn xlsynth_pir_runtime_wide_bit_slice_update(
1230 dst: *mut u64,
1231 dst_bit_count: usize,
1232 arg: *const u64,
1233 arg_bit_count: usize,
1234 start: *const u64,
1235 start_bit_count: usize,
1236 update: *const u64,
1237 update_bit_count: usize,
1238) {
1239 let arg = unsafe { read_wide_bits(arg, arg_bit_count) };
1241 let start = unsafe { read_wide_bits(start, start_bit_count) };
1243 let update = unsafe { read_wide_bits(update, update_bit_count) };
1245 let result = if let Some(start) = bounded_shift_amount(&start, arg_bit_count) {
1246 let written_width = update_bit_count.min(arg_bit_count - start);
1247 let written_mask = bit_mask(written_width) << start;
1248 let retained = &arg & (&bit_mask(arg_bit_count) ^ &written_mask);
1249 retained | ((update & bit_mask(written_width)) << start)
1250 } else {
1251 arg
1252 };
1253 unsafe { write_wide_bits(dst, dst_bit_count, result) };
1255}
1256
1257#[unsafe(no_mangle)]
1267pub unsafe extern "C" fn xlsynth_pir_runtime_wide_unary_op(
1268 dst: *mut u64,
1269 dst_bit_count: usize,
1270 arg: *const u64,
1271 arg_bit_count: usize,
1272 operation: u32,
1273 attribute: usize,
1274) {
1275 let Some(operation) = WideUnaryOp::from_abi(operation) else {
1276 return;
1277 };
1278 let arg = unsafe { read_wide_bits(arg, arg_bit_count) };
1280 let result = match operation {
1281 WideUnaryOp::OneHot => {
1282 let selected =
1283 prioritized_set_bit(&arg, arg_bit_count, attribute != 0).unwrap_or(arg_bit_count);
1284 BigUint::from(1u8) << selected
1285 }
1286 WideUnaryOp::Encode => {
1287 let mut result = 0usize;
1288 for index in 0..arg_bit_count {
1289 if get_bit(&arg, index) {
1290 result |= index;
1291 }
1292 }
1293 BigUint::from(result)
1294 }
1295 WideUnaryOp::Decode => bounded_shift_amount(&arg, dst_bit_count)
1296 .map(|amount| BigUint::from(1u8) << amount)
1297 .unwrap_or_else(|| BigUint::from(0u8)),
1298 WideUnaryOp::ExtPrioEncode => BigUint::from(
1299 prioritized_set_bit(&arg, arg_bit_count, attribute != 0).unwrap_or(arg_bit_count),
1300 ),
1301 WideUnaryOp::ExtClz => BigUint::from(leading_zero_count(&arg, arg_bit_count) + attribute),
1302 WideUnaryOp::ExtNormalizeLeft => {
1303 let shift = leading_zero_count(&arg, arg_bit_count).saturating_add(attribute);
1304 if shift >= dst_bit_count {
1305 BigUint::from(0u8)
1306 } else {
1307 truncate_unsigned(arg << shift, dst_bit_count)
1308 }
1309 }
1310 WideUnaryOp::ExtMaskLow => {
1311 if arg >= BigUint::from(dst_bit_count) {
1312 bit_mask(dst_bit_count)
1313 } else {
1314 let count = arg.to_u64_digits().first().copied().unwrap_or(0) as usize;
1315 bit_mask(count)
1316 }
1317 }
1318 };
1319 unsafe { write_wide_bits(dst, dst_bit_count, result) };
1321}
1322
1323#[unsafe(no_mangle)]
1329pub unsafe extern "C" fn xlsynth_pir_runtime_wide_mulp(
1330 offset_dst: *mut u64,
1331 residual_dst: *mut u64,
1332 dst_bit_count: usize,
1333 lhs: *const u64,
1334 lhs_bit_count: usize,
1335 rhs: *const u64,
1336 rhs_bit_count: usize,
1337 signed: u32,
1338) {
1339 let lhs = unsafe { read_wide_bits(lhs, lhs_bit_count) };
1341 let rhs = unsafe { read_wide_bits(rhs, rhs_bit_count) };
1343 let product = if signed != 0 {
1344 truncate_signed(
1345 as_signed(lhs, lhs_bit_count) * as_signed(rhs, rhs_bit_count),
1346 dst_bit_count,
1347 )
1348 } else {
1349 truncate_unsigned(lhs * rhs, dst_bit_count)
1350 };
1351 let offset = mulp_offset(dst_bit_count);
1352 let residual = truncate_signed(
1353 BigInt::from_biguint(Sign::Plus, product)
1354 - BigInt::from_biguint(Sign::Plus, offset.clone()),
1355 dst_bit_count,
1356 );
1357 unsafe { write_wide_bits(offset_dst, dst_bit_count, offset) };
1359 unsafe { write_wide_bits(residual_dst, dst_bit_count, residual) };
1361}
1362
1363unsafe fn format_trace_message(
1365 format: &str,
1366 layouts: &[TraceValueLayout],
1367 operand_ptrs: *const *const u8,
1368) -> String {
1369 let mut output = String::new();
1370 let mut offset = 0usize;
1371 let mut operand_index = 0usize;
1372 while offset < format.len() {
1373 let remainder = &format[offset..];
1374 if remainder.starts_with("{{") || remainder.starts_with("}}") {
1375 output.push_str(&remainder[..2]);
1378 offset += 2;
1379 continue;
1380 }
1381 if let Some((specifier, preference)) = TRACE_FORMAT_SPECIFIERS
1382 .iter()
1383 .find(|(specifier, _)| remainder.starts_with(specifier))
1384 {
1385 if let Some(layout) = layouts.get(operand_index) {
1386 let pointer = unsafe { *operand_ptrs.add(operand_index) };
1388 output.push_str(&unsafe { format_native_value(pointer, layout, *preference) });
1390 }
1391 operand_index += 1;
1392 offset += specifier.len();
1393 continue;
1394 }
1395 let character = remainder
1396 .chars()
1397 .next()
1398 .expect("offset is within trace format string");
1399 output.push(character);
1400 offset += character.len_utf8();
1401 }
1402 output
1403}
1404
1405unsafe fn format_native_value(
1407 pointer: *const u8,
1408 layout: &TraceValueLayout,
1409 preference: TraceFormatPreference,
1410) -> String {
1411 match layout {
1412 TraceValueLayout::Bits {
1413 bit_count,
1414 byte_count,
1415 } => {
1416 let mut bytes = vec![0u8; *byte_count];
1417 if *byte_count != 0 {
1418 unsafe { ptr::copy_nonoverlapping(pointer, bytes.as_mut_ptr(), *byte_count) };
1420 }
1421 let value = if cfg!(target_endian = "little") {
1422 BigUint::from_bytes_le(&bytes)
1423 } else {
1424 BigUint::from_bytes_be(&bytes)
1425 };
1426 format_trace_bits(value, *bit_count, preference)
1427 }
1428 TraceValueLayout::WideBits {
1429 bit_count,
1430 limb_count: _,
1431 } => {
1432 let value = unsafe { read_wide_bits(pointer.cast::<u64>(), *bit_count) };
1435 format_trace_bits(value, *bit_count, preference)
1436 }
1437 TraceValueLayout::Array {
1438 element,
1439 element_count,
1440 } => {
1441 let elements = (0..*element_count)
1442 .map(|index| {
1443 unsafe {
1445 format_native_value(
1446 pointer.wrapping_add(index * element.byte_count()),
1447 element,
1448 preference,
1449 )
1450 }
1451 })
1452 .collect::<Vec<_>>();
1453 format!("[{}]", elements.join(", "))
1454 }
1455 TraceValueLayout::Tuple { fields, .. } => {
1456 let fields = fields
1457 .iter()
1458 .map(|field| {
1459 unsafe {
1461 format_native_value(
1462 pointer.wrapping_add(field.offset),
1463 &field.layout,
1464 preference,
1465 )
1466 }
1467 })
1468 .collect::<Vec<_>>();
1469 format!("({})", fields.join(", "))
1470 }
1471 TraceValueLayout::Token => "token".to_string(),
1472 }
1473}
1474
1475fn format_trace_bits(
1476 mut value: BigUint,
1477 bit_count: usize,
1478 preference: TraceFormatPreference,
1479) -> String {
1480 if bit_count == 0 {
1481 value = BigUint::from(0u8);
1482 } else {
1483 value &= (BigUint::from(1u8) << bit_count) - BigUint::from(1u8);
1484 }
1485 match preference {
1486 TraceFormatPreference::Default => {
1487 if bit_count <= 64 {
1488 value.to_str_radix(10)
1489 } else {
1490 format_trace_bits(value, bit_count, TraceFormatPreference::Hex)
1491 }
1492 }
1493 TraceFormatPreference::UnsignedDecimal => value.to_str_radix(10),
1494 TraceFormatPreference::SignedDecimal => {
1495 if bit_count != 0
1496 && (&value & (BigUint::from(1u8) << (bit_count - 1))) != BigUint::from(0u8)
1497 {
1498 (BigInt::from_biguint(Sign::Plus, value)
1499 - BigInt::from_biguint(Sign::Plus, BigUint::from(1u8) << bit_count))
1500 .to_string()
1501 } else {
1502 value.to_str_radix(10)
1503 }
1504 }
1505 TraceFormatPreference::PlainHex => value.to_str_radix(16),
1506 TraceFormatPreference::ZeroPaddedHex => {
1507 zero_padded_grouped_digits(&value, bit_count, 4, 16)
1508 }
1509 TraceFormatPreference::Hex => {
1510 format!("0x{}", grouped_digits(&value.to_str_radix(16)))
1511 }
1512 TraceFormatPreference::PlainBinary => value.to_str_radix(2),
1513 TraceFormatPreference::ZeroPaddedBinary => {
1514 zero_padded_grouped_digits(&value, bit_count, 1, 2)
1515 }
1516 TraceFormatPreference::Binary => {
1517 format!("0b{}", grouped_digits(&value.to_str_radix(2)))
1518 }
1519 }
1520}
1521
1522fn zero_padded_grouped_digits(
1523 value: &BigUint,
1524 bit_count: usize,
1525 bits_per_digit: usize,
1526 radix: u32,
1527) -> String {
1528 let digit_count = bit_count.div_ceil(bits_per_digit).max(1);
1529 let digits = format!(
1530 "{:0>width$}",
1531 value.to_str_radix(radix),
1532 width = digit_count
1533 );
1534 grouped_digits(&digits)
1535}
1536
/// Inserts `_` separators so that digits are grouped in fours from the right.
///
/// The leftmost group holds the 1 to 4 leftover digits, e.g. `"101011"` becomes
/// `"10_1011"` and `"12345678"` becomes `"1234_5678"`. An empty input yields an
/// empty string. Grouping counts characters rather than bytes, so the function
/// never slices in the middle of a multi-byte character.
fn grouped_digits(digits: &str) -> String {
    let digit_count = digits.chars().count();
    // Width of the leading group; a multiple of four starts with a full group.
    let first_group_width = match digit_count % 4 {
        0 => 4,
        remainder => remainder,
    };
    let mut result = String::with_capacity(digits.len() + digit_count / 4);
    for (index, digit) in digits.chars().enumerate() {
        // A separator precedes every group after the first one.
        if index >= first_group_width && (index - first_group_width) % 4 == 0 {
            result.push('_');
        }
        result.push(digit);
    }
    result
}
1549
1550impl TraceValueLayout {
1551 fn byte_count(&self) -> usize {
1552 match self {
1553 Self::Bits { byte_count, .. } => *byte_count,
1554 Self::WideBits { limb_count, .. } => limb_count * std::mem::size_of::<u64>(),
1555 Self::Array {
1556 element,
1557 element_count,
1558 } => element.byte_count() * element_count,
1559 Self::Tuple { byte_count, .. } => *byte_count,
1560 Self::Token => 0,
1561 }
1562 }
1563}
1564
1565#[cfg(test)]
1566mod tests {
1567 use super::*;
1568
1569 #[test]
1570 fn native_bits_wrappers_enforce_semantic_widths() {
1571 let value = BitsInU64::<42>::new((1u64 << 41) | 7).expect("value fits in bits[42]");
1572 assert_eq!(value.to_u64(), (1u64 << 41) | 7);
1573 assert!(BitsInU64::<42>::new(1u64 << 42).is_err());
1574 assert_eq!(BitsInU16::<9>::wrapping(0xffff).get(), 0x1ff);
1575 assert!(BitsInU8::<9>::new(0).is_err());
1576 }
1577
1578 #[test]
1579 fn public_signed_and_unsigned_bits_wrappers_preserve_raw_abi_bits() {
1580 let unsigned = UnsignedBitsInU8::<4>::new(15).expect("u4 max");
1581 assert_eq!(unsigned.to_u64(), 15);
1582 assert_eq!(unsigned.raw_bits(), 15);
1583 assert!(UnsignedBitsInU8::<4>::new(16).is_err());
1584 assert!(std::panic::catch_unwind(|| UnsignedBitsInU8::<4>::from_raw_bits(16)).is_err());
1585
1586 let signed = SignedBitsInU8::<4>::new(-1).expect("s4 -1");
1587 assert_eq!(signed.to_i64(), -1);
1588 assert_eq!(signed.raw_bits(), 15);
1589 assert!(SignedBitsInU8::<4>::new(8).is_err());
1590 assert!(SignedBitsInU8::<4>::new(-9).is_err());
1591 assert!(std::panic::catch_unwind(|| SignedBitsInU8::<4>::from_raw_bits(16)).is_err());
1592 assert_eq!(SignedBitsInU16::<9>::from_raw_bits(0x101).to_i64(), -255);
1593
1594 let wide = SignedWideBits::<65, 2>::from_limbs([u64::MAX, 1]).expect("s65 -1");
1595 assert_eq!(wide.to_bigint(), BigInt::from(-1));
1596 assert_eq!(wide.limbs(), &[u64::MAX, 1]);
1597 }
1598
1599 #[test]
1600 fn public_bits_wrappers_try_from_widened_integers() {
1601 assert_eq!(std::mem::size_of::<Bits0>(), 0);
1602 assert_eq!(std::mem::size_of::<UnsignedBits0>(), 0);
1603 assert_eq!(std::mem::size_of::<SignedBits0>(), 0);
1604
1605 let unsigned_zero = UnsignedBits0::try_from(0_u64).expect("0 fits in u0");
1606 assert_eq!(unsigned_zero.to_u64(), 0);
1607 assert_eq!(unsigned_zero.raw_bits(), 0);
1608 assert!(UnsignedBits0::try_from(1_u64).is_err());
1609 assert!(std::panic::catch_unwind(|| UnsignedBits0::from_raw_bits(1)).is_err());
1610
1611 let signed_zero = SignedBits0::try_from(0_i64).expect("0 fits in s0");
1612 assert_eq!(signed_zero.to_i64(), 0);
1613 assert_eq!(signed_zero.raw_bits(), 0);
1614 assert!(SignedBits0::try_from(-1_i64).is_err());
1615 assert!(SignedBits0::try_from(1_i64).is_err());
1616 assert!(std::panic::catch_unwind(|| SignedBits0::from_raw_bits(1)).is_err());
1617
1618 assert_eq!(
1619 UnsignedBitsInU8::<4>::try_from(15_u64)
1620 .expect("15 fits in u4")
1621 .to_u64(),
1622 15
1623 );
1624 assert!(UnsignedBitsInU8::<4>::try_from(16_u64).is_err());
1625 assert!(UnsignedBitsInU8::<8>::try_from(256_u64).is_err());
1626 assert_eq!(
1627 UnsignedBitsInU16::<9>::try_from(0x1ff_u64)
1628 .expect("0x1ff fits in u9")
1629 .to_u64(),
1630 0x1ff
1631 );
1632 assert_eq!(
1633 UnsignedBitsInU32::<17>::try_from(0x1ffff_u64)
1634 .expect("0x1ffff fits in u17")
1635 .to_u64(),
1636 0x1ffff
1637 );
1638 assert_eq!(
1639 UnsignedBitsInU64::<33>::try_from(0x1ffffffff_u64)
1640 .expect("0x1ffffffff fits in u33")
1641 .to_u64(),
1642 0x1ffffffff
1643 );
1644
1645 assert_eq!(
1646 SignedBitsInU8::<4>::try_from(-8_i64)
1647 .expect("-8 fits in s4")
1648 .to_i64(),
1649 -8
1650 );
1651 assert!(SignedBitsInU8::<4>::try_from(-9_i64).is_err());
1652 assert!(SignedBitsInU8::<8>::try_from(128_i64).is_err());
1653 assert_eq!(
1654 SignedBitsInU16::<9>::try_from(-256_i64)
1655 .expect("-256 fits in s9")
1656 .to_i64(),
1657 -256
1658 );
1659 assert_eq!(
1660 SignedBitsInU32::<17>::try_from(-65_536_i64)
1661 .expect("-65536 fits in s17")
1662 .to_i64(),
1663 -65_536
1664 );
1665 assert_eq!(
1666 SignedBitsInU64::<33>::try_from(-4_294_967_296_i64)
1667 .expect("-4294967296 fits in s33")
1668 .to_i64(),
1669 -4_294_967_296
1670 );
1671 }
1672
1673 #[test]
1674 fn wide_bits_wrappers_use_lsb_first_limbs_and_mask_high_bits() {
1675 let value =
1676 WideBits::<65, 2>::from_limbs([0x0123_4567_89ab_cdef, 1]).expect("canonical value");
1677 assert_eq!(value.limbs(), &[0x0123_4567_89ab_cdef, 1]);
1678 assert!(WideBits::<65, 2>::from_limbs([0, 2]).is_err());
1679 assert_eq!(WideBits::<65, 2>::wrapping_limbs([7, 3]).limbs(), &[7, 1]);
1680 assert!(WideBits::<65, 3>::from_limbs([0, 0, 0]).is_err());
1681 assert_eq!(std::mem::size_of::<Token>(), 0);
1682 }
1683
1684 fn metadata() -> CompiledFunctionMetadata {
1685 CompiledFunctionMetadata {
1686 event_sites: vec![
1687 EventSiteMetadata {
1688 node_text_id: 10,
1689 kind: EventKind::Cover,
1690 label: Some("covered".to_string()),
1691 message: None,
1692 format: None,
1693 verbosity: 0,
1694 operand_layouts: Vec::new(),
1695 },
1696 EventSiteMetadata {
1697 node_text_id: 11,
1698 kind: EventKind::Assert,
1699 label: Some("assert_label".to_string()),
1700 message: Some("failed".to_string()),
1701 format: None,
1702 verbosity: 0,
1703 operand_layouts: Vec::new(),
1704 },
1705 EventSiteMetadata {
1706 node_text_id: 12,
1707 kind: EventKind::Trace,
1708 label: None,
1709 message: None,
1710 format: Some("x={} arr={}".to_string()),
1711 verbosity: 1,
1712 operand_layouts: vec![
1713 TraceValueLayout::Bits {
1714 bit_count: 8,
1715 byte_count: 1,
1716 },
1717 TraceValueLayout::Array {
1718 element: Box::new(TraceValueLayout::Bits {
1719 bit_count: 8,
1720 byte_count: 1,
1721 }),
1722 element_count: 2,
1723 },
1724 ],
1725 },
1726 EventSiteMetadata {
1727 node_text_id: 13,
1728 kind: EventKind::Assumption(AssumptionFailureKind::ArrayIndexOutOfBounds),
1729 label: None,
1730 message: None,
1731 format: None,
1732 verbosity: 0,
1733 operand_layouts: Vec::new(),
1734 },
1735 ],
1736 }
1737 }
1738
1739 #[test]
1740 fn cover_and_assert_callbacks_collect_rust_owned_results() {
1741 let metadata = metadata();
1742 let mut context =
1743 ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1744 let mut raw = context.raw_context();
1745 unsafe {
1747 xlsynth_pir_record_cover(&mut raw, 0);
1748 xlsynth_pir_record_cover(&mut raw, 0);
1749 xlsynth_pir_record_assert(&mut raw, 1);
1750 xlsynth_pir_record_assumption_failure(&mut raw, 3);
1751 }
1752 let result = context.result();
1753 assert_eq!(result.cover_counts[0].count, 2);
1754 assert_eq!(result.cover_counts[0].label, "covered");
1755 assert_eq!(result.assertion_failures[0].message, "failed");
1756 assert_eq!(result.assertion_failures[0].label, "assert_label");
1757 assert_eq!(
1758 result.assumption_failures,
1759 vec![AssumptionFailure {
1760 node_text_id: 13,
1761 kind: AssumptionFailureKind::ArrayIndexOutOfBounds,
1762 }]
1763 );
1764 }
1765
1766 #[test]
1767 fn trace_callback_decodes_values_before_native_storage_changes() {
1768 let metadata = metadata();
1769 let mut context =
1770 ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1771 let mut raw = context.raw_context();
1772 let mut scalar = 7u8;
1773 let mut array = [2u8, 3u8];
1774 let operands = [
1775 ptr::from_ref(&scalar).cast::<u8>(),
1776 ptr::from_ref(&array).cast::<u8>(),
1777 ];
1778 unsafe { xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr()) };
1780 scalar = 90;
1781 array[0] = 91;
1782 assert_eq!(scalar, 90);
1783 assert_eq!(array[0], 91);
1784 assert_eq!(context.result().trace_messages[0].message, "x=7 arr=[2, 3]");
1785 }
1786
1787 #[test]
1788 fn trace_callback_formats_all_specifiers_and_wide_decimal_without_xls() {
1789 let twelve_bits = TraceValueLayout::Bits {
1790 bit_count: 12,
1791 byte_count: 2,
1792 };
1793 let metadata = CompiledFunctionMetadata {
1794 event_sites: vec![EventSiteMetadata {
1795 node_text_id: 20,
1796 kind: EventKind::Trace,
1797 label: None,
1798 message: None,
1799 format: Some(
1800 "literal={{ default={} u={:u} d={:d} x={:x} 0x={:0x} #x={:#x} b={:b} 0b={:0b} #b={:#b} wide={} wide_u={:u}".to_string(),
1801 ),
1802 verbosity: 0,
1803 operand_layouts: vec![
1804 twelve_bits.clone(),
1805 twelve_bits.clone(),
1806 TraceValueLayout::Bits {
1807 bit_count: 8,
1808 byte_count: 1,
1809 },
1810 twelve_bits.clone(),
1811 twelve_bits.clone(),
1812 twelve_bits.clone(),
1813 twelve_bits.clone(),
1814 twelve_bits.clone(),
1815 twelve_bits,
1816 TraceValueLayout::Bits {
1817 bit_count: 72,
1818 byte_count: 9,
1819 },
1820 TraceValueLayout::Bits {
1821 bit_count: 72,
1822 byte_count: 9,
1823 },
1824 ],
1825 }],
1826 };
1827 let mut context =
1828 ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1829 let mut raw = context.raw_context();
1830 let twelve = 43u16;
1831 let negative = 251u8;
1832 let wide = [1u8, 0, 0, 0, 0, 0, 0, 0, 1];
1833 let operands = [
1834 ptr::from_ref(&twelve).cast::<u8>(),
1835 ptr::from_ref(&twelve).cast::<u8>(),
1836 ptr::from_ref(&negative).cast::<u8>(),
1837 ptr::from_ref(&twelve).cast::<u8>(),
1838 ptr::from_ref(&twelve).cast::<u8>(),
1839 ptr::from_ref(&twelve).cast::<u8>(),
1840 ptr::from_ref(&twelve).cast::<u8>(),
1841 ptr::from_ref(&twelve).cast::<u8>(),
1842 ptr::from_ref(&twelve).cast::<u8>(),
1843 ptr::from_ref(&wide).cast::<u8>(),
1844 ptr::from_ref(&wide).cast::<u8>(),
1845 ];
1846 unsafe { xlsynth_pir_record_trace(&mut raw, 0, operands.as_ptr()) };
1848 assert_eq!(
1849 context.result().trace_messages[0].message,
1850 "literal={{ default=43 u=43 d=-5 x=2b 0x=02b #x=0x2b b=101011 0b=0000_0010_1011 #b=0b10_1011 wide=0x1_0000_0000_0000_0001 wide_u=18446744073709551617"
1851 );
1852 }
1853
1854 #[test]
1855 fn clear_resets_accumulated_event_results() {
1856 let metadata = metadata();
1857 let mut context =
1858 ExecutionContext::new_with_options(&metadata, ExecutionOptions::collect_all());
1859 let mut raw = context.raw_context();
1860 unsafe { xlsynth_pir_record_cover(&mut raw, 0) };
1862 context.clear();
1863 let result = context.result();
1864 assert!(result.assertion_failures.is_empty());
1865 assert!(result.assumption_failures.is_empty());
1866 assert!(result.trace_messages.is_empty());
1867 assert_eq!(result.cover_counts[0].count, 0);
1868 }
1869
1870 #[test]
1871 fn default_context_does_not_collect_traces_or_covers() {
1872 let metadata = metadata();
1873 let mut context = ExecutionContext::new(&metadata);
1874 let mut raw = context.raw_context();
1875 let scalar = 7u8;
1876 let array = [2u8, 3u8];
1877 let operands = [
1878 ptr::from_ref(&scalar).cast::<u8>(),
1879 ptr::from_ref(&array).cast::<u8>(),
1880 ];
1881 unsafe {
1883 xlsynth_pir_record_cover(&mut raw, 0);
1884 xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr());
1885 }
1886 let result = context.result();
1887 assert!(result.cover_counts.is_empty());
1888 assert!(result.trace_messages.is_empty());
1889 }
1890
1891 #[test]
1892 fn trace_callback_respects_runtime_verbosity() {
1893 let metadata = metadata();
1894 let scalar = 7u8;
1895 let array = [2u8, 3u8];
1896 let operands = [
1897 ptr::from_ref(&scalar).cast::<u8>(),
1898 ptr::from_ref(&array).cast::<u8>(),
1899 ];
1900 let mut context = ExecutionContext::new_with_options(
1901 &metadata,
1902 ExecutionOptions::new(Some(0), false),
1903 );
1904 let mut raw = context.raw_context();
1905 unsafe { xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr()) };
1907 assert!(context.result().trace_messages.is_empty());
1908
1909 context.clear_with_options(ExecutionOptions::new(
1910 Some(1),
1911 false,
1912 ));
1913 let mut raw = context.raw_context();
1914 unsafe { xlsynth_pir_record_trace(&mut raw, 2, operands.as_ptr()) };
1916 assert_eq!(context.result().trace_messages[0].message, "x=7 arr=[2, 3]");
1917 }
1918
1919 #[test]
1920 fn wide_trace_values_use_lsb_first_native_limbs() {
1921 let value = [1u64, 1u64];
1922 let formatted = unsafe {
1924 format_native_value(
1925 value.as_ptr().cast(),
1926 &TraceValueLayout::WideBits {
1927 bit_count: 72,
1928 limb_count: 2,
1929 },
1930 TraceFormatPreference::Hex,
1931 )
1932 };
1933 assert_eq!(formatted, "0x1_0000_0000_0000_0001");
1934 }
1935
1936 #[test]
1937 fn wide_binary_runtime_helpers_cover_arithmetic_shifts_and_slices() {
1938 let lhs = [u64::MAX, 1];
1939 let rhs = [2u64, 0];
1940 let mut output = [0u64; 2];
1941 unsafe {
1943 xlsynth_pir_runtime_wide_binop(
1944 output.as_mut_ptr(),
1945 65,
1946 lhs.as_ptr(),
1947 65,
1948 rhs.as_ptr(),
1949 65,
1950 WideBinaryOp::Umul as u32,
1951 );
1952 }
1953 assert_eq!(output, [u64::MAX - 1, 1]);
1954
1955 let negative = [0u64, 1];
1956 let shift = [1u64, 0];
1957 unsafe {
1959 xlsynth_pir_runtime_wide_binop(
1960 output.as_mut_ptr(),
1961 65,
1962 negative.as_ptr(),
1963 65,
1964 shift.as_ptr(),
1965 65,
1966 WideBinaryOp::Shra as u32,
1967 );
1968 }
1969 assert_eq!(output, [1u64 << 63, 1]);
1970
1971 let start = [63u64, 0];
1972 unsafe {
1974 xlsynth_pir_runtime_wide_dynamic_bit_slice(
1975 output.as_mut_ptr(),
1976 65,
1977 lhs.as_ptr(),
1978 65,
1979 start.as_ptr(),
1980 65,
1981 );
1982 }
1983 assert_eq!(output, [3, 0]);
1984
1985 let zero = [0u64, 0];
1986 let update = [3u64, 0];
1987 unsafe {
1989 xlsynth_pir_runtime_wide_bit_slice_update(
1990 output.as_mut_ptr(),
1991 65,
1992 zero.as_ptr(),
1993 65,
1994 start.as_ptr(),
1995 65,
1996 update.as_ptr(),
1997 65,
1998 );
1999 }
2000 assert_eq!(output, [1u64 << 63, 1]);
2001 }
2002
2003 #[test]
2004 fn wide_runtime_helpers_accept_zero_width_storage() {
2005 unsafe {
2008 xlsynth_pir_runtime_wide_binop(
2009 ptr::null_mut(),
2010 0,
2011 ptr::null(),
2012 0,
2013 ptr::null(),
2014 0,
2015 WideBinaryOp::Sdiv as u32,
2016 );
2017 xlsynth_pir_runtime_wide_mulp(
2018 ptr::null_mut(),
2019 ptr::null_mut(),
2020 0,
2021 ptr::null(),
2022 0,
2023 ptr::null(),
2024 0,
2025 0,
2026 );
2027 }
2028 }
2029
2030 #[test]
2031 fn wide_unary_runtime_helpers_cover_encoding_and_extensions() {
2032 let input = [1u64 << 63, 1];
2033 let mut output = [0u64; 3];
2034 unsafe {
2036 xlsynth_pir_runtime_wide_unary_op(
2037 output.as_mut_ptr(),
2038 66,
2039 input.as_ptr(),
2040 65,
2041 WideUnaryOp::OneHot as u32,
2042 1,
2043 );
2044 }
2045 assert_eq!(output[..2], [1u64 << 63, 0]);
2046
2047 unsafe {
2049 xlsynth_pir_runtime_wide_unary_op(
2050 output.as_mut_ptr(),
2051 7,
2052 input.as_ptr(),
2053 65,
2054 WideUnaryOp::ExtPrioEncode as u32,
2055 0,
2056 );
2057 }
2058 assert_eq!(output[0], 64);
2059
2060 let zeros = [0u64; 2];
2061 unsafe {
2063 xlsynth_pir_runtime_wide_unary_op(
2064 output.as_mut_ptr(),
2065 129,
2066 zeros.as_ptr(),
2067 65,
2068 WideUnaryOp::ExtMaskLow as u32,
2069 0,
2070 );
2071 }
2072 assert_eq!(output, [0, 0, 0]);
2073
2074 let count = [80u64, 0];
2075 unsafe {
2077 xlsynth_pir_runtime_wide_unary_op(
2078 output.as_mut_ptr(),
2079 129,
2080 count.as_ptr(),
2081 65,
2082 WideUnaryOp::ExtMaskLow as u32,
2083 0,
2084 );
2085 }
2086 assert_eq!(output, [u64::MAX, 0xffff, 0]);
2087 }
2088
2089 #[test]
2090 fn wide_mulp_runtime_helper_returns_components_summing_to_product() {
2091 let lhs = [u64::MAX, 1];
2092 let rhs = [3u64, 0];
2093 let mut offset = [0u64; 3];
2094 let mut residual = [0u64; 3];
2095 unsafe {
2097 xlsynth_pir_runtime_wide_mulp(
2098 offset.as_mut_ptr(),
2099 residual.as_mut_ptr(),
2100 129,
2101 lhs.as_ptr(),
2102 65,
2103 rhs.as_ptr(),
2104 65,
2105 0,
2106 );
2107 }
2108 let offset = unsafe { read_wide_bits(offset.as_ptr(), 129) };
2109 let residual = unsafe { read_wide_bits(residual.as_ptr(), 129) };
2110 assert_eq!(
2111 truncate_unsigned(offset + residual, 129),
2112 truncate_unsigned(
2113 unsafe { read_wide_bits(lhs.as_ptr(), 65) }
2114 * unsafe { read_wide_bits(rhs.as_ptr(), 65) },
2115 129,
2116 )
2117 );
2118 }
2119}