Skip to main content

prism_numerics/
bigint.rs

1//! `BigIntAxis` declaration + parametric modular-arithmetic impls + shape.
2//!
3//! Per [Wiki ADR-031][09-adr-031] the numerics sub-crate exposes
4//! `BigIntAxis` as the canonical Layer-3 vocabulary for fixed-width
5//! integer arithmetic. The reference impl [`BigIntModularNumeric`] is
6//! generic over operand byte-width per ADR-031's `BigInt<MaxBits>`
7//! shape commitment — every instantiation up to [`MAX_BIG_INT_BYTES`]
8//! (512 bits) is a distinct sealed `AxisExtension` that the
9//! application's `AxisTuple` can select.
10//!
11//! [`BigIntShape`] is the matching `ConstrainedTypeShape` so
12//! application authors can declare `BigInt<N>`-typed inputs and outputs
13//! to their `prism_model!` invocations without re-rolling the shape.
14//!
15//! [09-adr-031]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
16
17#![allow(missing_docs)]
18
19use uor_foundation::enforcement::{GroundedShape, ShapeViolation};
20use uor_foundation::pipeline::{ConstrainedTypeShape, ConstraintRef, IntoBindingValue, TermValue};
21use uor_foundation_sdk::axis;
22
23use crate::{check_output, split_pair};
24
axis! {
    /// Wiki ADR-031 fixed-width integer arithmetic axis.
    ///
    /// Kernels take input `a || b` (big-endian-encoded equal-width
    /// operands) and emit modular arithmetic results. The reference
    /// impl `BigIntModularNumeric<BYTES>` is generic in `BYTES` for
    /// the full range `[1, MAX_BIG_INT_BYTES]`.
    pub trait BigIntAxis: AxisExtension {
        /// ADR-017 content address.
        const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/BigIntAxis";
        /// Operand byte-width `N`. Defaults to 32 and is overridden per
        /// impl (the reference impl sets it to its `BYTES` parameter).
        const MAX_OUTPUT_BYTES: usize = 32;
        /// `(a + b) mod 2^(8*N)` — input is `a || b` (`2N` bytes).
        ///
        /// The reference impl writes the `N`-byte big-endian result to
        /// the front of `out` and returns `N`.
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output arity mismatch.
        fn add(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
        /// `(a - b) mod 2^(8*N)` — input is `a || b` (`2N` bytes).
        ///
        /// The reference impl writes the `N`-byte big-endian result to
        /// the front of `out` and returns `N`.
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output arity mismatch.
        fn sub(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
        /// `(a * b) mod 2^(8*N)` — input is `a || b` (`2N` bytes).
        ///
        /// The reference impl writes the `N`-byte big-endian result to
        /// the front of `out` and returns `N`.
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output arity mismatch.
        fn mul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
    }
}
57
/// Maximum operand byte-width any `BigIntModularNumeric<BYTES>`
/// instantiation supports. Driven by the on-stack accumulator size
/// used by the multiplication kernel (a `2*MAX_BIG_INT_BYTES`-element
/// `u32` array — 128 elements, i.e. 512 bytes, at 64 bytes / 512 bits).
pub const MAX_BIG_INT_BYTES: usize = 64;

// Element count of the multiplication kernel's accumulator: one `u32`
// slot per byte position of a double-width (`2 * MAX_BIG_INT_BYTES`)
// product.
const ACC_CAP: usize = 2 * MAX_BIG_INT_BYTES;
65
66fn width_violation() -> ShapeViolation {
67    ShapeViolation {
68        shape_iri: "https://uor.foundation/axis/BigIntAxis",
69        constraint_iri: "https://uor.foundation/axis/BigIntAxis/widthInRange",
70        property_iri: "https://uor.foundation/axis/operandByteWidth",
71        expected_range: "https://uor.foundation/axis/BigIntAxis/MaxBigIntBytes",
72        min_count: 1,
73        #[allow(clippy::cast_possible_truncation)]
74        max_count: MAX_BIG_INT_BYTES as u32,
75        kind: uor_foundation::ViolationKind::ValueCheck,
76    }
77}
78
/// Parametric `N`-byte modular-arithmetic impl of [`BigIntAxis`].
///
/// `BYTES` is the operand width in bytes (`8 * BYTES` bits). Arithmetic
/// is mod `2^(8 * BYTES)` (wrapping). The supported range is
/// `[1, MAX_BIG_INT_BYTES]` (512 bits at the upper bound); the kernels
/// reject any other width with a `ShapeViolation` when called.
///
/// This is a zero-sized marker type, so `Default` is derived rather
/// than hand-written: the derive emits the same
/// `impl<const BYTES: usize> Default` (const parameters add no bounds).
#[derive(Debug, Clone, Copy, Default)]
pub struct BigIntModularNumeric<const BYTES: usize>;
92
93impl<const BYTES: usize> BigIntAxis for BigIntModularNumeric<BYTES> {
94    const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/BigIntAxis/Modular";
95    const MAX_OUTPUT_BYTES: usize = BYTES;
96
97    fn add(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
98        if BYTES == 0 || BYTES > MAX_BIG_INT_BYTES {
99            return Err(width_violation());
100        }
101        let (a, b) = split_pair(input, BYTES)?;
102        check_output(out, BYTES)?;
103        let mut carry: u16 = 0;
104        for i in (0..BYTES).rev() {
105            let sum = u16::from(a[i]) + u16::from(b[i]) + carry;
106            #[allow(clippy::cast_possible_truncation)]
107            {
108                out[i] = (sum & 0xff) as u8;
109            }
110            carry = sum >> 8;
111        }
112        Ok(BYTES)
113    }
114
115    fn sub(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
116        if BYTES == 0 || BYTES > MAX_BIG_INT_BYTES {
117            return Err(width_violation());
118        }
119        let (a, b) = split_pair(input, BYTES)?;
120        check_output(out, BYTES)?;
121        let mut borrow: i16 = 0;
122        for i in (0..BYTES).rev() {
123            let diff = i16::from(a[i]) - i16::from(b[i]) - borrow;
124            if diff < 0 {
125                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
126                {
127                    out[i] = (diff + 256) as u8;
128                }
129                borrow = 1;
130            } else {
131                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
132                {
133                    out[i] = diff as u8;
134                }
135                borrow = 0;
136            }
137        }
138        Ok(BYTES)
139    }
140
141    fn mul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
142        if BYTES == 0 || BYTES > MAX_BIG_INT_BYTES {
143            return Err(width_violation());
144        }
145        let (a, b) = split_pair(input, BYTES)?;
146        check_output(out, BYTES)?;
147        // Schoolbook product into a fixed-size accumulator sized for
148        // MAX_BIG_INT_BYTES; only the first 2*BYTES positions are used.
149        let mut acc = [0u32; ACC_CAP];
150        for i in (0..BYTES).rev() {
151            for j in (0..BYTES).rev() {
152                let prod = u32::from(a[i]) * u32::from(b[j]);
153                let pos = i + j + 1;
154                let sum = acc[pos] + (prod & 0xff);
155                acc[pos] = sum & 0xff;
156                let mut carry = (sum >> 8) + (prod >> 8);
157                let mut k = pos;
158                while carry > 0 && k > 0 {
159                    k -= 1;
160                    let next = acc[k] + carry;
161                    acc[k] = next & 0xff;
162                    carry = next >> 8;
163                }
164            }
165        }
166        for i in 0..BYTES {
167            #[allow(clippy::cast_possible_truncation)]
168            {
169                out[i] = (acc[i + BYTES] & 0xff) as u8;
170            }
171        }
172        Ok(BYTES)
173    }
174}
175
// ADR-052 generic-form companion: replaces the hand-written
// `AxisExtension` impl for `BigIntModularNumeric<BYTES>`. The macro's
// `@generic` arm accepts a `:ty` plus a generic parameter list so
// parametric Layer-3 axes inherit the dispatch body from the `axis!`
// emission.
// NOTE(review): the macro is not defined in this file — presumably it
// is generated by the `axis!` invocation for `BigIntAxis`; confirm.
axis_extension_impl_for_big_int_axis!(@generic BigIntModularNumeric<BYTES>, [const BYTES: usize]);
181
182/// 256-bit modular arithmetic (mod `2^256`).
183pub type BigInt256Numeric = BigIntModularNumeric<32>;
184/// 512-bit modular arithmetic (mod `2^512`).
185pub type BigInt512Numeric = BigIntModularNumeric<64>;
186/// 128-bit modular arithmetic (mod `2^128`).
187pub type BigInt128Numeric = BigIntModularNumeric<16>;
188/// 64-bit modular arithmetic (mod `2^64`) — matches `u64` wrapping.
189pub type BigInt64Numeric = BigIntModularNumeric<8>;
190
191// ---- BigIntShape: ConstrainedTypeShape carrier for BigInt<N> -----------
192
/// Parametric ConstrainedTypeShape: an `N`-byte big-endian integer.
///
/// Per ADR-031 this is the canonical Layer-3 shape downstream
/// `prism_model!` invocations use to type their `Input` / `Output` as
/// big-integer values. The shape carries `BYTES` sites with no
/// admission constraints; admission discipline (range bounds, modulus,
/// etc.) is the consumer's responsibility through additional
/// constraint refs.
///
/// Per ADR-017's closure rule the IRI is the foundation's shared
/// `ConstrainedType` class; instance identity flows through
/// `(SITE_COUNT, CONSTRAINTS)`.
///
/// This is a zero-sized marker type, so `Default` is derived rather
/// than hand-written: the derive emits the same
/// `impl<const BYTES: usize> Default` (const parameters add no bounds).
#[derive(Debug, Clone, Copy, Default)]
pub struct BigIntShape<const BYTES: usize>;
213
214impl<const BYTES: usize> ConstrainedTypeShape for BigIntShape<BYTES> {
215    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
216    const SITE_COUNT: usize = BYTES;
217    const CONSTRAINTS: &'static [ConstraintRef] = &[];
218    #[allow(clippy::cast_possible_truncation)]
219    const CYCLE_SIZE: u64 = 256u64.saturating_pow(BYTES as u32);
220}
221
// Marker impls that opt every `BigIntShape<BYTES>` into the
// foundation's sealed shape traits; none of them carries any items.
// NOTE(review): presumably `Sealed` is a prerequisite of the traits
// implemented below — confirm against uor_foundation.
impl<const BYTES: usize> uor_foundation::pipeline::__sdk_seal::Sealed for BigIntShape<BYTES> {}
impl<const BYTES: usize> GroundedShape for BigIntShape<BYTES> {}
impl<'a, const BYTES: usize> IntoBindingValue<'a> for BigIntShape<BYTES> {
    // Always returns the empty term value, whatever `INLINE_BYTES` is:
    // this type holds no integer data to bind.
    fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
        // The shape is a phantom carrier; downstream impls that want to
        // bind an actual N-byte big-int value wrap this shape in a
        // newtype carrying the data + a bespoke carrier.
        TermValue::empty()
    }
}
232
// ADR-033 G20 leaf-shape `PartitionProductFields` impl per
// foundation-sdk 0.4.11's depth-2 verb!-macro projection chain.
// Foundation-sdk 0.4.11 requires `PartitionProductFields` on every
// type used as a partition-product factor (including leaves) for
// the depth-2 chained-field-access trait-bound check to resolve.
// Empty FIELDS signals "atomic byte-sequence carrier — no further
// projection possible"; the macro respects the termination marker
// without indexing into the empty array (the 0.4.10 const-eval
// panic on empty FIELDS is fixed in 0.4.11).
impl<const BYTES: usize> uor_foundation::pipeline::PartitionProductFields for BigIntShape<BYTES> {
    // Both tables are deliberately empty and therefore the same length:
    // a leaf shape exposes no named sub-fields.
    const FIELDS: &'static [(u32, u32)] = &[];
    const FIELD_NAMES: &'static [&'static str] = &[];
}