sp1_core_machine/operations/field/
field_op.rs

1use std::fmt::Debug;
2
3use crate::air::WordAirBuilder;
4use num::{BigUint, Zero};
5
6use p3_air::AirBuilder;
7use p3_field::PrimeField32;
8
9use sp1_core_executor::events::{ByteRecord, FieldOperation};
10use sp1_derive::AlignedBorrow;
11use sp1_stark::air::{Polynomial, SP1AirBuilder};
12
13use super::{
14    util::{compute_root_quotient_and_shift, split_u16_limbs_to_u8_limbs},
15    util_air::eval_field_operation,
16};
17use sp1_curves::params::{FieldParameters, Limbs};
18
19use typenum::Unsigned;
20
/// A set of columns to compute an emulated modular arithmetic operation.
///
/// *Safety* The input operands (a, b) (not included in the operation columns) are assumed to be
/// elements within the range `[0, 2^{P::nb_bits()})`. The result is also assumed to be within the
/// same range. Let `M = P::modulus()`. The constraints of the function [`FieldOpCols::eval`] assert
/// that:
/// * When `op` is `FieldOperation::Add`, then `result = a + b mod M`.
/// * When `op` is `FieldOperation::Mul`, then `result = a * b mod M`.
/// * When `op` is `FieldOperation::Sub`, then `result = a - b mod M`.
/// * When `op` is `FieldOperation::Div`, then `result * b = a mod M`.
///
/// **Warning**: The constraints do not check for division by zero. The caller is responsible for
/// ensuring that the division operation is valid.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct FieldOpCols<T, P: FieldParameters> {
    /// The result of `a op b`, where a, b are field elements
    pub result: Limbs<T, P::Limbs>,
    /// Limbs of the quotient `carry` in the identity `lhs = rhs + carry * modulus` that the
    /// `eval_*` functions enforce. For `Add`/`Mul` the right-hand side is `result`; for
    /// `Sub`/`Div` the roles of `a` and `result` are swapped (see `eval_with_modulus`).
    pub carry: Limbs<T, P::Limbs>,
    /// First half of the witness produced by `compute_root_quotient_and_shift` and split by
    /// `split_u16_limbs_to_u8_limbs` (presumably the low bytes of the 16-bit witness limbs).
    pub(crate) witness_low: Limbs<T, P::Witness>,
    /// Second half of the same witness (presumably the high bytes of the 16-bit witness limbs).
    pub(crate) witness_high: Limbs<T, P::Witness>,
}
43
44impl<F: PrimeField32, P: FieldParameters> FieldOpCols<F, P> {
45    #[allow(clippy::too_many_arguments)]
46    /// Populate result and carry columns from the equation (a*b + c) % modulus
47    pub fn populate_mul_and_carry(
48        &mut self,
49        record: &mut impl ByteRecord,
50        a: &BigUint,
51        b: &BigUint,
52        c: &BigUint,
53        modulus: &BigUint,
54    ) -> (BigUint, BigUint) {
55        let p_a: Polynomial<F> = P::to_limbs_field::<F, _>(a).into();
56        let p_b: Polynomial<F> = P::to_limbs_field::<F, _>(b).into();
57        let p_c: Polynomial<F> = P::to_limbs_field::<F, _>(c).into();
58
59        let mul_add = a * b + c;
60        let result = &mul_add % modulus;
61        let carry = (mul_add - &result) / modulus;
62        debug_assert!(&result < modulus);
63        debug_assert!(&carry < modulus);
64        debug_assert_eq!(&carry * modulus, a * b + c - &result);
65
66        let p_modulus_limbs =
67            modulus.to_bytes_le().iter().map(|x| F::from_canonical_u8(*x)).collect::<Vec<F>>();
68        let p_modulus: Polynomial<F> = p_modulus_limbs.iter().into();
69        let p_result: Polynomial<F> = P::to_limbs_field::<F, _>(&result).into();
70        let p_carry: Polynomial<F> = P::to_limbs_field::<F, _>(&carry).into();
71
72        let p_op = &p_a * &p_b + &p_c;
73        let p_vanishing = &p_op - &p_result - &p_carry * &p_modulus;
74
75        let p_witness = compute_root_quotient_and_shift(
76            &p_vanishing,
77            P::WITNESS_OFFSET,
78            P::NB_BITS_PER_LIMB as u32,
79            P::NB_WITNESS_LIMBS,
80        );
81
82        let (mut p_witness_low, mut p_witness_high) = split_u16_limbs_to_u8_limbs(&p_witness);
83
84        self.result = p_result.into();
85        self.carry = p_carry.into();
86
87        p_witness_low.resize(P::Witness::USIZE, F::zero());
88        p_witness_high.resize(P::Witness::USIZE, F::zero());
89        self.witness_low = Limbs(p_witness_low.try_into().unwrap());
90        self.witness_high = Limbs(p_witness_high.try_into().unwrap());
91
92        record.add_u8_range_checks_field(&self.result.0);
93        record.add_u8_range_checks_field(&self.carry.0);
94        record.add_u8_range_checks_field(&self.witness_low.0);
95        record.add_u8_range_checks_field(&self.witness_high.0);
96
97        (result, carry)
98    }
99
100    pub fn populate_carry_and_witness(
101        &mut self,
102        a: &BigUint,
103        b: &BigUint,
104        op: FieldOperation,
105        modulus: &BigUint,
106    ) -> BigUint {
107        let p_a: Polynomial<F> = P::to_limbs_field::<F, _>(a).into();
108        let p_b: Polynomial<F> = P::to_limbs_field::<F, _>(b).into();
109        let (result, carry) = match op {
110            FieldOperation::Add => ((a + b) % modulus, (a + b - (a + b) % modulus) / modulus),
111            FieldOperation::Mul => ((a * b) % modulus, (a * b - (a * b) % modulus) / modulus),
112            FieldOperation::Sub | FieldOperation::Div => unreachable!(),
113        };
114        debug_assert!(&result < modulus);
115        debug_assert!(&carry < modulus);
116        match op {
117            FieldOperation::Add => debug_assert_eq!(&carry * modulus, a + b - &result),
118            FieldOperation::Mul => debug_assert_eq!(&carry * modulus, a * b - &result),
119            FieldOperation::Sub | FieldOperation::Div => unreachable!(),
120        }
121
122        // Here we have special logic for p_modulus because to_limbs_field only works for numbers in
123        // the field, but modulus can == the field modulus so it can have 1 extra limb (ex.
124        // uint256).
125        let p_modulus_limbs =
126            modulus.to_bytes_le().iter().map(|x| F::from_canonical_u8(*x)).collect::<Vec<F>>();
127        let p_modulus: Polynomial<F> = p_modulus_limbs.iter().into();
128        let p_result: Polynomial<F> = P::to_limbs_field::<F, _>(&result).into();
129        let p_carry: Polynomial<F> = P::to_limbs_field::<F, _>(&carry).into();
130
131        // Compute the vanishing polynomial.
132        let p_op = match op {
133            FieldOperation::Add => &p_a + &p_b,
134            FieldOperation::Mul => &p_a * &p_b,
135            FieldOperation::Sub | FieldOperation::Div => unreachable!(),
136        };
137        let p_vanishing: Polynomial<F> = &p_op - &p_result - &p_carry * &p_modulus;
138
139        let p_witness = compute_root_quotient_and_shift(
140            &p_vanishing,
141            P::WITNESS_OFFSET,
142            P::NB_BITS_PER_LIMB as u32,
143            P::NB_WITNESS_LIMBS,
144        );
145        let (mut p_witness_low, mut p_witness_high) = split_u16_limbs_to_u8_limbs(&p_witness);
146
147        self.result = p_result.into();
148        self.carry = p_carry.into();
149
150        p_witness_low.resize(P::Witness::USIZE, F::zero());
151        p_witness_high.resize(P::Witness::USIZE, F::zero());
152        self.witness_low = Limbs(p_witness_low.try_into().unwrap());
153        self.witness_high = Limbs(p_witness_high.try_into().unwrap());
154
155        result
156    }
157
158    /// Populate these columns with a specified modulus. This is useful in the `mulmod` precompile
159    /// as an example.
160    #[allow(clippy::too_many_arguments)]
161    pub fn populate_with_modulus(
162        &mut self,
163        record: &mut impl ByteRecord,
164        a: &BigUint,
165        b: &BigUint,
166        modulus: &BigUint,
167        op: FieldOperation,
168    ) -> BigUint {
169        if op == FieldOperation::Div {
170            assert_ne!(*b, BigUint::zero(), "division by zero is not allowed");
171            assert_ne!(*b, *modulus, "division by zero is not allowed");
172        }
173
174        let result = match op {
175            // If doing the subtraction operation, a - b = result, equivalent to a = result + b.
176            FieldOperation::Sub => {
177                let result = (modulus.clone() + a - b) % modulus;
178                // We populate the carry, witness_low, witness_high as if we were doing an addition
179                // with result + b. But we populate `result` with the actual result
180                // of the subtraction because those columns are expected to contain
181                // the result by the user. Note that this reversal means we have to
182                // flip result, a correspondingly in the `eval` function.
183                self.populate_carry_and_witness(&result, b, FieldOperation::Add, modulus);
184                self.result = P::to_limbs_field::<F, _>(&result);
185                result
186            }
187            // a / b = result is equivalent to a = result * b.
188            FieldOperation::Div => {
189                // As modulus is prime, we can use Fermat's little theorem to compute the
190                // inverse.
191                cfg_if::cfg_if! {
192                    if #[cfg(feature = "bigint-rug")] {
193                        use sp1_curves::utils::{biguint_to_rug, rug_to_biguint};
194                        let rug_a = biguint_to_rug(a);
195                        let rug_b = biguint_to_rug(b);
196                        let rug_modulus = biguint_to_rug(modulus);
197                        let rug_result = (rug_a
198                            * rug_b.pow_mod(&(rug_modulus.clone() - 2u32), &rug_modulus.clone()).unwrap())
199                            % rug_modulus.clone();
200                        let result = rug_to_biguint(&rug_result);
201                    } else {
202                        let result =
203                            (a * b.modpow(&(modulus.clone() - 2u32), &modulus.clone())) % modulus.clone();
204                    }
205                }
206                // We populate the carry, witness_low, witness_high as if we were doing a
207                // multiplication with result * b. But we populate `result` with the
208                // actual result of the multiplication because those columns are
209                // expected to contain the result by the user. Note that this
210                // reversal means we have to flip result, a correspondingly in the `eval`
211                // function.
212                self.populate_carry_and_witness(&result, b, FieldOperation::Mul, modulus);
213                self.result = P::to_limbs_field::<F, _>(&result);
214                result
215            }
216            _ => self.populate_carry_and_witness(a, b, op, modulus),
217        };
218
219        // Range checks
220        record.add_u8_range_checks_field(&self.result.0);
221        record.add_u8_range_checks_field(&self.carry.0);
222        record.add_u8_range_checks_field(&self.witness_low.0);
223        record.add_u8_range_checks_field(&self.witness_high.0);
224
225        result
226    }
227
228    /// Populate these columns without a specified modulus (will use the modulus of the field
229    /// parameters).
230    pub fn populate(
231        &mut self,
232        record: &mut impl ByteRecord,
233        a: &BigUint,
234        b: &BigUint,
235        op: FieldOperation,
236    ) -> BigUint {
237        self.populate_with_modulus(record, a, b, &P::modulus(), op)
238    }
239}
240
241impl<V: Copy, P: FieldParameters> FieldOpCols<V, P> {
242    /// Allows an evaluation over opetations specified by boolean flags.
243    #[allow(clippy::too_many_arguments)]
244    pub fn eval_variable<AB: SP1AirBuilder<Var = V>>(
245        &self,
246        builder: &mut AB,
247        a: &(impl Into<Polynomial<AB::Expr>> + Clone),
248        b: &(impl Into<Polynomial<AB::Expr>> + Clone),
249        modulus: &(impl Into<Polynomial<AB::Expr>> + Clone),
250        is_add: impl Into<AB::Expr> + Clone,
251        is_sub: impl Into<AB::Expr> + Clone,
252        is_mul: impl Into<AB::Expr> + Clone,
253        is_div: impl Into<AB::Expr> + Clone,
254        is_real: impl Into<AB::Expr> + Clone,
255    ) where
256        V: Into<AB::Expr>,
257        Limbs<V, P::Limbs>: Copy,
258    {
259        let p_a_param: Polynomial<AB::Expr> = (a).clone().into();
260        let p_b: Polynomial<AB::Expr> = (b).clone().into();
261        let p_res_param: Polynomial<AB::Expr> = self.result.into();
262
263        let is_add: AB::Expr = is_add.into();
264        let is_sub: AB::Expr = is_sub.into();
265        let is_mul: AB::Expr = is_mul.into();
266        let is_div: AB::Expr = is_div.into();
267
268        let p_result = p_res_param.clone() * (is_add.clone() + is_mul.clone()) +
269            p_a_param.clone() * (is_sub.clone() + is_div.clone());
270
271        let p_add = p_a_param.clone() + p_b.clone();
272        let p_sub = p_res_param.clone() + p_b.clone();
273        let p_mul = p_a_param.clone() * p_b.clone();
274        let p_div = p_res_param * p_b.clone();
275        let p_op = p_add * is_add + p_sub * is_sub + p_mul * is_mul + p_div * is_div;
276
277        self.eval_with_polynomials(builder, p_op, modulus.clone(), p_result, is_real);
278    }
279
280    #[allow(clippy::too_many_arguments)]
281    pub fn eval_mul_and_carry<AB: SP1AirBuilder<Var = V>>(
282        &self,
283        builder: &mut AB,
284        a: &(impl Into<Polynomial<AB::Expr>> + Clone),
285        b: &(impl Into<Polynomial<AB::Expr>> + Clone),
286        c: &(impl Into<Polynomial<AB::Expr>> + Clone),
287        modulus: &(impl Into<Polynomial<AB::Expr>> + Clone),
288        is_real: impl Into<AB::Expr> + Clone,
289    ) where
290        V: Into<AB::Expr>,
291        Limbs<V, P::Limbs>: Copy,
292    {
293        let p_a: Polynomial<AB::Expr> = (a).clone().into();
294        let p_b: Polynomial<AB::Expr> = (b).clone().into();
295        let p_c: Polynomial<AB::Expr> = (c).clone().into();
296
297        let p_result: Polynomial<_> = self.result.into();
298        let p_op = p_a * p_b + p_c;
299
300        self.eval_with_polynomials(builder, p_op, modulus.clone(), p_result, is_real);
301    }
302
303    #[allow(clippy::too_many_arguments)]
304    pub fn eval_with_modulus<AB: SP1AirBuilder<Var = V>>(
305        &self,
306        builder: &mut AB,
307        a: &(impl Into<Polynomial<AB::Expr>> + Clone),
308        b: &(impl Into<Polynomial<AB::Expr>> + Clone),
309        modulus: &(impl Into<Polynomial<AB::Expr>> + Clone),
310        op: FieldOperation,
311        is_real: impl Into<AB::Expr> + Clone,
312    ) where
313        V: Into<AB::Expr>,
314        Limbs<V, P::Limbs>: Copy,
315    {
316        let p_a_param: Polynomial<AB::Expr> = (a).clone().into();
317        let p_b: Polynomial<AB::Expr> = (b).clone().into();
318
319        let (p_a, p_result): (Polynomial<_>, Polynomial<_>) = match op {
320            FieldOperation::Add | FieldOperation::Mul => (p_a_param, self.result.into()),
321            FieldOperation::Sub | FieldOperation::Div => (self.result.into(), p_a_param),
322        };
323        let p_op: Polynomial<<AB as AirBuilder>::Expr> = match op {
324            FieldOperation::Add | FieldOperation::Sub => p_a + p_b,
325            FieldOperation::Mul | FieldOperation::Div => p_a * p_b,
326        };
327        self.eval_with_polynomials(builder, p_op, modulus.clone(), p_result, is_real);
328    }
329
330    #[allow(clippy::too_many_arguments)]
331    pub fn eval_with_polynomials<AB: SP1AirBuilder<Var = V>>(
332        &self,
333        builder: &mut AB,
334        op: impl Into<Polynomial<AB::Expr>>,
335        modulus: impl Into<Polynomial<AB::Expr>>,
336        result: impl Into<Polynomial<AB::Expr>>,
337        is_real: impl Into<AB::Expr> + Clone,
338    ) where
339        V: Into<AB::Expr>,
340        Limbs<V, P::Limbs>: Copy,
341    {
342        let p_op: Polynomial<AB::Expr> = op.into();
343        let p_result: Polynomial<AB::Expr> = result.into();
344        let p_modulus: Polynomial<AB::Expr> = modulus.into();
345        let p_carry: Polynomial<<AB as AirBuilder>::Expr> = self.carry.into();
346        let p_op_minus_result: Polynomial<AB::Expr> = p_op - &p_result;
347        let p_vanishing = p_op_minus_result - &(&p_carry * &p_modulus);
348        let p_witness_low = self.witness_low.0.iter().into();
349        let p_witness_high = self.witness_high.0.iter().into();
350        eval_field_operation::<AB, P>(builder, &p_vanishing, &p_witness_low, &p_witness_high);
351
352        // Range checks for the result, carry, and witness columns.
353        builder.slice_range_check_u8(&self.result.0, is_real.clone());
354        builder.slice_range_check_u8(&self.carry.0, is_real.clone());
355        builder.slice_range_check_u8(p_witness_low.coefficients(), is_real.clone());
356        builder.slice_range_check_u8(p_witness_high.coefficients(), is_real);
357    }
358
359    #[allow(clippy::too_many_arguments)]
360    pub fn eval<AB: SP1AirBuilder<Var = V>>(
361        &self,
362        builder: &mut AB,
363        a: &(impl Into<Polynomial<AB::Expr>> + Clone),
364        b: &(impl Into<Polynomial<AB::Expr>> + Clone),
365        op: FieldOperation,
366        is_real: impl Into<AB::Expr> + Clone,
367    ) where
368        V: Into<AB::Expr>,
369        Limbs<V, P::Limbs>: Copy,
370    {
371        let p_limbs = Polynomial::from_iter(P::modulus_field_iter::<AB::F>().map(AB::Expr::from));
372        self.eval_with_modulus::<AB>(builder, a, b, &p_limbs, op, is_real);
373    }
374}
375
376#[cfg(test)]
377mod tests {
378    #![allow(clippy::print_stdout)]
379
380    use num::BigUint;
381    use p3_air::BaseAir;
382    use p3_field::{Field, PrimeField32};
383    use sp1_core_executor::{ExecutionRecord, Program};
384    use sp1_curves::params::FieldParameters;
385    use sp1_stark::{
386        air::{MachineAir, SP1AirBuilder},
387        StarkGenericConfig,
388    };
389
390    use super::{FieldOpCols, FieldOperation, Limbs};
391
392    use crate::utils::{
393        pad_to_power_of_two,
394        uni_stark::{uni_stark_prove, uni_stark_verify},
395    };
396    use core::borrow::{Borrow, BorrowMut};
397    use num::bigint::RandBigInt;
398    use p3_air::Air;
399    use p3_baby_bear::BabyBear;
400    use p3_field::AbstractField;
401    use p3_matrix::{dense::RowMajorMatrix, Matrix};
402    use rand::thread_rng;
403    use sp1_core_executor::events::ByteRecord;
404    use sp1_curves::{
405        edwards::ed25519::Ed25519BaseField, weierstrass::secp256k1::Secp256k1BaseField,
406    };
407    use sp1_derive::AlignedBorrow;
408    use sp1_stark::baby_bear_poseidon2::BabyBearPoseidon2;
409    use std::mem::size_of;
410
411    #[derive(AlignedBorrow, Debug, Clone)]
412    pub struct TestCols<T, P: FieldParameters> {
413        pub a: Limbs<T, P::Limbs>,
414        pub b: Limbs<T, P::Limbs>,
415        pub a_op_b: FieldOpCols<T, P>,
416    }
417
    /// Width of one trace row of the test chip, taken as the byte size of
    /// `TestCols<u8, Secp256k1BaseField>` (presumably one byte per column when `T = u8`).
    ///
    /// NOTE(review): the chip is generic over `P` but always uses this Secp256k1-derived width,
    /// while the tests instantiate it with `Ed25519BaseField`. This presumably relies on both
    /// parameter sets having the same limb and witness counts — confirm before using other
    /// fields.
    pub const NUM_TEST_COLS: usize = size_of::<TestCols<u8, Secp256k1BaseField>>();
419
    /// Minimal test chip that applies a single fixed field operation to every row of its trace.
    struct FieldOpChip<P: FieldParameters> {
        /// The operation applied to each `(a, b)` operand pair.
        pub operation: FieldOperation,
        /// Ties the chip to the field parameters `P` without storing a value of that type.
        pub _phantom: std::marker::PhantomData<P>,
    }
424
425    impl<P: FieldParameters> FieldOpChip<P> {
426        pub const fn new(operation: FieldOperation) -> Self {
427            Self { operation, _phantom: std::marker::PhantomData }
428        }
429    }
430
431    impl<F: PrimeField32, P: FieldParameters> MachineAir<F> for FieldOpChip<P> {
432        type Record = ExecutionRecord;
433
434        type Program = Program;
435
436        fn name(&self) -> String {
437            format!("FieldOp{:?}", self.operation)
438        }
439
440        fn generate_trace(
441            &self,
442            _: &ExecutionRecord,
443            output: &mut ExecutionRecord,
444        ) -> RowMajorMatrix<F> {
445            let mut rng = thread_rng();
446            let num_rows = 1 << 8;
447            let mut operands: Vec<(BigUint, BigUint)> = (0..num_rows - 4)
448                .map(|_| {
449                    let a = rng.gen_biguint(256) % &P::modulus();
450                    let b = rng.gen_biguint(256) % &P::modulus();
451                    (a, b)
452                })
453                .collect();
454
455            // Hardcoded edge cases.
456            operands.extend(vec![
457                (BigUint::from(0u32), BigUint::from(1u32)),
458                (BigUint::from(1u32), BigUint::from(2u32)),
459                (BigUint::from(4u32), BigUint::from(5u32)),
460                (BigUint::from(10u32), BigUint::from(19u32)),
461            ]);
462
463            let rows = operands
464                .iter()
465                .map(|(a, b)| {
466                    let mut blu_events = Vec::new();
467                    let mut row = [F::zero(); NUM_TEST_COLS];
468                    let cols: &mut TestCols<F, P> = row.as_mut_slice().borrow_mut();
469                    cols.a = P::to_limbs_field::<F, _>(a);
470                    cols.b = P::to_limbs_field::<F, _>(b);
471                    cols.a_op_b.populate(&mut blu_events, a, b, self.operation);
472                    output.add_byte_lookup_events(blu_events);
473                    row
474                })
475                .collect::<Vec<_>>();
476            // Convert the trace to a row major matrix.
477            let mut trace =
478                RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_TEST_COLS);
479
480            // Pad the trace to a power of two.
481            pad_to_power_of_two::<NUM_TEST_COLS, F>(&mut trace.values);
482
483            trace
484        }
485
486        fn included(&self, _: &Self::Record) -> bool {
487            true
488        }
489    }
490
    impl<F: Field, P: FieldParameters> BaseAir<F> for FieldOpChip<P> {
        /// Reports the trace width, which is the fixed `NUM_TEST_COLS` regardless of `P`.
        fn width(&self) -> usize {
            NUM_TEST_COLS
        }
    }
496
497    impl<AB, P: FieldParameters> Air<AB> for FieldOpChip<P>
498    where
499        AB: SP1AirBuilder,
500        Limbs<AB::Var, P::Limbs>: Copy,
501    {
502        fn eval(&self, builder: &mut AB) {
503            let main = builder.main();
504            let local = main.row_slice(0);
505            let local: &TestCols<AB::Var, P> = (*local).borrow();
506            local.a_op_b.eval(builder, &local.a, &local.b, self.operation, AB::F::one());
507        }
508    }
509
510    #[test]
511    fn generate_trace() {
512        for op in [FieldOperation::Add, FieldOperation::Mul, FieldOperation::Sub].iter() {
513            println!("op: {:?}", op);
514            let chip: FieldOpChip<Ed25519BaseField> = FieldOpChip::new(*op);
515            let shard = ExecutionRecord::default();
516            let _: RowMajorMatrix<BabyBear> =
517                chip.generate_trace(&shard, &mut ExecutionRecord::default());
518            // println!("{:?}", trace.values)
519        }
520    }
521
522    #[test]
523    fn prove_babybear() {
524        let config = BabyBearPoseidon2::new();
525
526        for op in
527            [FieldOperation::Add, FieldOperation::Sub, FieldOperation::Mul, FieldOperation::Div]
528                .iter()
529        {
530            println!("op: {:?}", op);
531
532            let mut challenger = config.challenger();
533
534            let chip: FieldOpChip<Ed25519BaseField> = FieldOpChip::new(*op);
535            let shard = ExecutionRecord::default();
536            let trace: RowMajorMatrix<BabyBear> =
537                chip.generate_trace(&shard, &mut ExecutionRecord::default());
538            let proof =
539                uni_stark_prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);
540
541            let mut challenger = config.challenger();
542            uni_stark_verify(&config, &chip, &mut challenger, &proof).unwrap();
543        }
544    }
545}