// sp1_core_machine/operations/lt.rs

1use itertools::izip;
2
3use p3_air::AirBuilder;
4use p3_field::{AbstractField, PrimeField32};
5
6use sp1_core_executor::{
7    events::{ByteLookupEvent, ByteRecord},
8    ByteOpcode,
9};
10use sp1_derive::AlignedBorrow;
11use sp1_stark::air::{BaseAirBuilder, SP1AirBuilder};
12
13/// Operation columns for verifying that an element is within the range `[0, modulus)`.
14#[derive(Debug, Clone, Copy, AlignedBorrow)]
15#[repr(C)]
16pub struct AssertLtColsBytes<T, const N: usize> {
17    /// Boolean flags to indicate the first byte in which the element is smaller than the modulus.
18    pub(crate) byte_flags: [T; N],
19
20    pub(crate) a_comparison_byte: T,
21    pub(crate) b_comparison_byte: T,
22}
23
24impl<F: PrimeField32, const N: usize> AssertLtColsBytes<F, N> {
25    pub fn populate(&mut self, record: &mut impl ByteRecord, a: &[u8], b: &[u8]) {
26        let mut byte_flags = vec![0u8; N];
27
28        for (a_byte, b_byte, flag) in
29            izip!(a.iter().rev(), b.iter().rev(), byte_flags.iter_mut().rev())
30        {
31            assert!(a_byte <= b_byte);
32            if a_byte < b_byte {
33                *flag = 1;
34                self.a_comparison_byte = F::from_canonical_u8(*a_byte);
35                self.b_comparison_byte = F::from_canonical_u8(*b_byte);
36                record.add_byte_lookup_event(ByteLookupEvent {
37                    opcode: ByteOpcode::LTU,
38                    a1: 1,
39                    a2: 0,
40                    b: *a_byte,
41                    c: *b_byte,
42                });
43                break;
44            }
45        }
46
47        for (byte, flag) in izip!(byte_flags.iter(), self.byte_flags.iter_mut()) {
48            *flag = F::from_canonical_u8(*byte);
49        }
50    }
51}
52
53impl<V: Copy, const N: usize> AssertLtColsBytes<V, N> {
54    pub fn eval<
55        AB: SP1AirBuilder<Var = V>,
56        Ea: Into<AB::Expr> + Clone,
57        Eb: Into<AB::Expr> + Clone,
58    >(
59        &self,
60        builder: &mut AB,
61        a: &[Ea],
62        b: &[Eb],
63        is_real: impl Into<AB::Expr> + Clone,
64    ) where
65        V: Into<AB::Expr>,
66    {
67        // The byte flags give a specification of which byte is `first_eq`, i,e, the first most
68        // significant byte for which the element `a` is smaller than `b`. To verify the
69        // less-than claim we need to check that:
70        // * For all bytes until `first_eq` the element `a` byte is equal to the `b` byte.
71        // * For the `first_eq` byte the `a`` byte is smaller than the `b`byte.
72        // * all byte flags are boolean.
73        // * only one byte flag is set to one, and the rest are set to zero.
74
75        // Check the flags are of valid form.
76
77        // Verrify that only one flag is set to one.
78        let mut sum_flags: AB::Expr = AB::Expr::zero();
79        for &flag in self.byte_flags.iter() {
80            // Assert that the flag is boolean.
81            builder.assert_bool(flag);
82            // Add the flag to the sum.
83            sum_flags = sum_flags.clone() + flag.into();
84        }
85        // Assert that the sum is equal to one.
86        builder.when(is_real.clone()).assert_one(sum_flags);
87
88        // Check the less-than condition.
89
90        // A flag to indicate whether an equality check is necessary (this is for all bytes from
91        // most significant until the first inequality.
92        let mut is_inequality_visited = AB::Expr::zero();
93
94        // The bytes of the modulus.
95
96        let a: [AB::Expr; N] = core::array::from_fn(|i| a[i].clone().into());
97        let b: [AB::Expr; N] = core::array::from_fn(|i| b[i].clone().into());
98
99        let mut first_lt_byte = AB::Expr::zero();
100        let mut b_comparison_byte = AB::Expr::zero();
101        for (a_byte, b_byte, &flag) in
102            izip!(a.iter().rev(), b.iter().rev(), self.byte_flags.iter().rev())
103        {
104            // Once the byte flag was set to one, we turn off the quality check flag.
105            // We can do this by calculating the sum of the flags since only `1` is set to `1`.
106            is_inequality_visited = is_inequality_visited.clone() + flag.into();
107
108            first_lt_byte = first_lt_byte.clone() + a_byte.clone() * flag;
109            b_comparison_byte = b_comparison_byte.clone() + b_byte.clone() * flag;
110
111            builder
112                .when_not(is_inequality_visited.clone())
113                .when(is_real.clone())
114                .assert_eq(a_byte.clone(), b_byte.clone());
115        }
116
117        builder.when(is_real.clone()).assert_eq(self.a_comparison_byte, first_lt_byte);
118        builder.when(is_real.clone()).assert_eq(self.b_comparison_byte, b_comparison_byte);
119
120        // Send the comparison interaction.
121        builder.send_byte(
122            ByteOpcode::LTU.as_field::<AB::F>(),
123            AB::F::one(),
124            self.a_comparison_byte,
125            self.b_comparison_byte,
126            is_real,
127        )
128    }
129}
130
131/// Operation columns for verifying that an element is within the range `[0, modulus)`.
132#[derive(Debug, Clone, Copy, AlignedBorrow)]
133#[repr(C)]
134pub struct AssertLtColsBits<T, const N: usize> {
135    /// Boolean flags to indicate the first byte in which the element is smaller than the modulus.
136    pub(crate) bit_flags: [T; N],
137}
138
139impl<F: PrimeField32, const N: usize> AssertLtColsBits<F, N> {
140    pub fn populate(&mut self, a: &[u32], b: &[u32]) {
141        let mut bit_flags = vec![0u8; N];
142
143        for (a_bit, b_bit, flag) in
144            izip!(a.iter().rev(), b.iter().rev(), bit_flags.iter_mut().rev())
145        {
146            assert!(a_bit <= b_bit);
147            debug_assert!(*a_bit == 0 || *a_bit == 1);
148            debug_assert!(*b_bit == 0 || *b_bit == 1);
149            if a_bit < b_bit {
150                *flag = 1;
151                break;
152            }
153        }
154
155        for (bit, flag) in izip!(bit_flags.iter(), self.bit_flags.iter_mut()) {
156            *flag = F::from_canonical_u8(*bit);
157        }
158    }
159}
160
161impl<V: Copy, const N: usize> AssertLtColsBits<V, N> {
162    pub fn eval<
163        AB: SP1AirBuilder<Var = V>,
164        Ea: Into<AB::Expr> + Clone,
165        Eb: Into<AB::Expr> + Clone,
166    >(
167        &self,
168        builder: &mut AB,
169        a: &[Ea],
170        b: &[Eb],
171        is_real: impl Into<AB::Expr> + Clone,
172    ) where
173        V: Into<AB::Expr>,
174    {
175        // The bit flags give a specification of which bit is `first_lt`, i,e, the first most
176        // significant bit for which the element `a` is smaller than `b`. To verify the
177        // less-than claim we need to check that:
178        // * For all bytes until `first_lt` the element `a` byte is equal to the `b` byte.
179        // * For the `first_lt` bit the `a`` bit is smaller than the `b` bit.
180        // * all bit flags are boolean.
181        // * only one bit flag is set to one, and the rest are set to zero.
182
183        // Check the flags are of valid form.
184
185        // Verrify that only one flag is set to one.
186        let mut sum_flags: AB::Expr = AB::Expr::zero();
187        for &flag in self.bit_flags.iter() {
188            // Assert that the flag is boolean.
189            builder.assert_bool(flag);
190            // Add the flag to the sum.
191            sum_flags = sum_flags.clone() + flag.into();
192        }
193        // Assert that the sum is equal to one.
194        builder.when(is_real.clone()).assert_one(sum_flags);
195
196        // Check the less-than condition.
197
198        // A flag to indicate whether an equality check is necessary (this is for all bits from
199        // most significant until the first inequality.
200        let mut is_inequality_visited = AB::Expr::zero();
201
202        // The bits of the elements.
203        let a: [AB::Expr; N] = core::array::from_fn(|i| a[i].clone().into());
204        let b: [AB::Expr; N] = core::array::from_fn(|i| b[i].clone().into());
205
206        // Calculate the bit which is the first inequality.
207        let mut a_comparison_bit = AB::Expr::zero();
208        let mut b_comparison_bit = AB::Expr::zero();
209        for (a_bit, b_bit, &flag) in
210            izip!(a.iter().rev(), b.iter().rev(), self.bit_flags.iter().rev())
211        {
212            // Once the bit flag was set to one, we turn off the quality check flag.
213            // We can do this by calculating the sum of the flags since only `1` is set to `1`.
214            is_inequality_visited = is_inequality_visited.clone() + flag.into();
215
216            a_comparison_bit = a_comparison_bit.clone() + a_bit.clone() * flag;
217            b_comparison_bit = b_comparison_bit.clone() + b_bit.clone() * flag;
218
219            builder
220                .when(is_real.clone())
221                .when_not(is_inequality_visited.clone())
222                .assert_eq(a_bit.clone(), b_bit.clone());
223        }
224
225        builder.when(is_real.clone()).assert_eq(a_comparison_bit, AB::F::zero());
226        builder.when(is_real.clone()).assert_eq(b_comparison_bit, AB::F::one());
227    }
228}