Skip to main content

stwo_constraint_framework/
lib.rs

1#![cfg_attr(feature = "prover", feature(portable_simd))]
2#![cfg_attr(not(feature = "std"), no_std)]
3
//! This crate contains helpers to express and use constraints for components.
5mod component;
6
7#[cfg(feature = "prover")]
8pub mod expr;
9mod info;
10pub mod logup;
11mod point;
12pub mod preprocessed_columns;
13#[cfg(all(feature = "prover", feature = "std"))]
14mod prover;
15
16use core::array;
17use core::fmt::Debug;
18use core::ops::{Add, AddAssign, Mul, Neg, Sub};
19
20pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
21pub use info::InfoEvaluator;
22use num_traits::{One, Zero};
23pub use point::PointEvaluator;
24use preprocessed_columns::PreProcessedColumnId;
25#[cfg(all(feature = "prover", feature = "std"))]
26pub use prover::{
27    assert_constraints_on_polys, assert_constraints_on_trace, relation_tracker, AssertEvaluator,
28    CpuDomainEvaluator, FractionWriter, LogupColGenerator, LogupTraceGenerator,
29    SimdDomainEvaluator,
30};
31use std_shims::Vec;
32use stwo::core::fields::m31::BaseField;
33use stwo::core::fields::qm31::{SecureField, SECURE_EXTENSION_DEGREE};
34use stwo::core::fields::FieldExpOps;
35use stwo::core::Fraction;
36
37#[rustfmt::skip]
38pub use stwo::core::verifier::PREPROCESSED_TRACE_IDX;
/// Trace index of the original (main) trace; `EvalAtRow::next_trace_mask` reads from it.
pub const ORIGINAL_TRACE_IDX: usize = 1;
/// Trace index of the interaction trace, which comes right after the original trace.
pub const INTERACTION_TRACE_IDX: usize = 2;
41
/// Assignment of logup entries to batches.
///
/// Element `i` holds the batch number into which the `i`-th logup entry is placed. The batch
/// numbers must be consecutive and start at 0, and the vector must contain exactly one element
/// per logup entry.
type Batching = Vec<usize>;
48
49/// A trait for evaluating expressions at some point or row.
50pub trait EvalAtRow {
51    // TODO(Ohad): Use a better trait for these, like 'Algebra' or something.
52    /// The field type holding values of columns for the component. These are the inputs to the
53    /// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating
54    /// the columns out of domain.
55    type F: FieldExpOps
56        + Clone
57        + Debug
58        + Zero
59        + Neg<Output = Self::F>
60        + AddAssign
61        + AddAssign<BaseField>
62        + Add<Self::F, Output = Self::F>
63        + Sub<Self::F, Output = Self::F>
64        + Mul<BaseField, Output = Self::F>
65        + Add<SecureField, Output = Self::EF>
66        + Mul<SecureField, Output = Self::EF>
67        + Neg<Output = Self::F>
68        + From<BaseField>;
69
70    /// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints
71    /// usually get multiplied by [SecureField] values for security.
72    type EF: One
73        + Clone
74        + Debug
75        + Zero
76        + Neg<Output = Self::EF>
77        + AddAssign
78        + Add<BaseField, Output = Self::EF>
79        + Mul<BaseField, Output = Self::EF>
80        + Add<SecureField, Output = Self::EF>
81        + Sub<SecureField, Output = Self::EF>
82        + Mul<SecureField, Output = Self::EF>
83        + Add<Self::F, Output = Self::EF>
84        + Mul<Self::F, Output = Self::EF>
85        + Sub<Self::EF, Output = Self::EF>
86        + Mul<Self::EF, Output = Self::EF>
87        + From<SecureField>
88        + From<Self::F>;
89
90    /// Returns the next mask value for the first interaction at offset 0.
91    fn next_trace_mask(&mut self) -> Self::F {
92        let [mask_item] = self.next_interaction_mask(ORIGINAL_TRACE_IDX, [0]);
93        mask_item
94    }
95
96    fn get_preprocessed_column(&mut self, _column: PreProcessedColumnId) -> Self::F {
97        let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
98        mask_item
99    }
100
101    /// Returns the mask values of the given offsets for the next column in the interaction.
102    fn next_interaction_mask<const N: usize>(
103        &mut self,
104        interaction: usize,
105        offsets: [isize; N],
106    ) -> [Self::F; N];
107
108    /// Returns the extension mask values of the given offsets for the next extension degree many
109    /// columns in the interaction.
110    fn next_extension_interaction_mask<const N: usize>(
111        &mut self,
112        interaction: usize,
113        offsets: [isize; N],
114    ) -> [Self::EF; N] {
115        let mut res_col_major =
116            array::from_fn(|_| self.next_interaction_mask(interaction, offsets).into_iter());
117        array::from_fn(|_| {
118            Self::combine_ef(res_col_major.each_mut().map(|iter| iter.next().unwrap()))
119        })
120    }
121
122    /// Adds a constraint to the component.
123    fn add_constraint<G>(&mut self, constraint: G)
124    where
125        Self::EF: Mul<G, Output = Self::EF> + From<G>;
126
127    /// Adds an intermediate value in the base field to the component and returns its value.
128    /// Does nothing by default.
129    fn add_intermediate(&mut self, val: Self::F) -> Self::F {
130        val
131    }
132
133    /// Adds an intermediate value in the extension field to the component and returns its value.
134    /// Does nothing by default.
135    fn add_extension_intermediate(&mut self, val: Self::EF) -> Self::EF {
136        val
137    }
138
139    /// Combines 4 base field values into a single extension field value.
140    fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;
141
142    /// Adds `entry.values` to `entry.relation` with `entry.multiplicity` for all 'entry' in
143    /// 'entries', batched together.
144    /// Constraint degree increases with number of batched constraints as the denominators are
145    /// multiplied.
146    fn add_to_relation<R: Relation<Self::F, Self::EF>>(
147        &mut self,
148        entry: RelationEntry<'_, Self::F, Self::EF, R>,
149    ) {
150        let frac = Fraction::new(
151            entry.multiplicity.clone(),
152            entry.relation.combine(entry.values),
153        );
154        self.write_logup_frac(frac);
155    }
156
157    // TODO(alont): Remove these once LogupAtRow is no longer used.
158    fn write_logup_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {
159        unimplemented!()
160    }
161    fn finalize_logup_batched(&mut self, _batching: &Batching) {
162        unimplemented!()
163    }
164
165    fn finalize_logup(&mut self) {
166        unimplemented!();
167    }
168
169    fn finalize_logup_in_pairs(&mut self) {
170        unimplemented!();
171    }
172}
173
/// Default implementation for evaluators that have an element called "logup" that works like a
/// LogupAtRow, where the logup functionality can be proxied.
///
/// The expanding type's `logup` field must expose `fracs`, `is_finalized`, `interaction` and
/// `cumsum_shift`, and `Fraction` must be in scope at the expansion site.
/// TODO(alont): Remove once LogupAtRow is no longer used.
macro_rules! logup_proxy {
    () => {
        /// Records a logup fraction to be constrained when the logup is finalized.
        ///
        /// Recording the first fraction (re)opens the logup for finalization.
        fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
            if self.logup.fracs.is_empty() {
                self.logup.is_finalized = false;
            }
            self.logup.fracs.push(fraction);
        }

        /// Finalize the logup by adding the constraints for the fractions, batched by
        /// the given `batching`.
        ///
        /// `batching[i]` is the batch into which the `i`-th logup entry is inserted; batch ids
        /// must be consecutive and start from 0. Every batch except the last is accumulated in
        /// its own new interaction column; the last batch is constrained against the previous
        /// row's cumulative sum.
        ///
        /// # Panics
        /// Panics if the logup was already finalized, if `batching` is empty, if its length
        /// differs from the number of recorded entries, or if the batch ids are not consecutive
        /// starting from 0.
        fn finalize_logup_batched(&mut self, batching: &crate::Batching) {
            assert!(!self.logup.is_finalized, "LogupAtRow was already finalized");
            assert_eq!(
                batching.len(),
                self.logup.fracs.len(),
                "Batching must be of the same length as the number of entries"
            );

            let last_batch = *batching
                .iter()
                .max()
                .expect("Batching must not be empty: there are no logup entries to finalize");

            // With `len` entries at most `len` distinct batches can be non-empty, so a larger id
            // can never be part of a consecutive batching. Checking this first also avoids
            // allocating one group per bogus id.
            assert!(
                last_batch < batching.len(),
                "Batching must contain all consecutive batches"
            );

            // Group the entries by batch id, preserving their original order within each batch.
            let mut fracs_by_batch: std_shims::Vec<std_shims::Vec<Fraction<Self::EF, Self::EF>>> =
                (0..=last_batch).map(|_| std_shims::Vec::new()).collect();
            for (&batch, frac) in batching.iter().zip(self.logup.fracs.iter()) {
                fracs_by_batch[batch].push(frac.clone());
            }
            assert!(
                fracs_by_batch.iter().all(|fracs| !fracs.is_empty()),
                "Batching must contain all consecutive batches"
            );

            // `fracs_by_batch` holds `last_batch + 1 >= 1` groups, so `pop` always succeeds.
            let last_fracs = fracs_by_batch.pop().unwrap();

            let mut prev_col_cumsum = <Self::EF as num_traits::Zero>::zero();

            // All batches except the last are cumulatively summed in new interaction columns.
            for batch_fracs in fracs_by_batch {
                let cur_frac: Fraction<_, _> = batch_fracs.into_iter().sum();
                let [cur_cumsum] =
                    self.next_extension_interaction_mask(self.logup.interaction, [0]);
                let diff = cur_cumsum.clone() - prev_col_cumsum.clone();
                prev_col_cumsum = cur_cumsum;
                self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
            }

            let frac: Fraction<_, _> = last_fracs.into_iter().sum();
            let [prev_row_cumsum, cur_cumsum] =
                self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]);

            let diff = cur_cumsum - prev_row_cumsum - prev_col_cumsum;
            // Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift.
            // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
            // uniform - apply on all rows.
            let shifted_diff = diff + self.logup.cumsum_shift.clone();

            self.add_constraint(shifted_diff * frac.denominator - frac.numerator);

            self.logup.is_finalized = true;
        }

        /// Finalizes the row's logup in the default way. Currently, this means no batching.
        fn finalize_logup(&mut self) {
            let batches: crate::Batching = (0..self.logup.fracs.len()).collect();
            self.finalize_logup_batched(&batches)
        }

        /// Finalizes the row's logup, batched in pairs.
        /// TODO(alont) Remove this once a better batching mechanism is implemented.
        fn finalize_logup_in_pairs(&mut self) {
            let batches: crate::Batching = (0..self.logup.fracs.len()).map(|n| n / 2).collect();
            self.finalize_logup_batched(&batches)
        }
    };
}
pub(crate) use logup_proxy;
259
260pub trait RelationEFTraitBound<F: Clone>:
261    Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = Self> + Sub<Self, Output = Self>
262{
263}
264
265impl<F, EF> RelationEFTraitBound<F> for EF
266where
267    F: Clone,
268    EF: Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
269{
270}
271
272/// A trait for defining a logup relation type.
273pub trait Relation<F: Clone, EF: RelationEFTraitBound<F>>: Sized {
274    fn combine(&self, values: &[F]) -> EF;
275
276    fn get_name(&self) -> &str;
277    fn get_size(&self) -> usize;
278}
279
280/// A struct representing a relation entry.
281/// `relation` is the relation into which elements are entered.
282/// `multiplicity` is the multiplicity of the elements.
283///     A positive multiplicity is used to signify a "use", while a negative multiplicity
284///     signifies a "yield".
285/// `values` are elements in the base field that are entered into the relation.
286pub struct RelationEntry<'a, F: Clone, EF: RelationEFTraitBound<F>, R: Relation<F, EF>> {
287    relation: &'a R,
288    multiplicity: EF,
289    values: &'a [F],
290}
291impl<'a, F: Clone, EF: RelationEFTraitBound<F>, R: Relation<F, EF>> RelationEntry<'a, F, EF, R> {
292    pub const fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self {
293        Self {
294            relation,
295            multiplicity,
296            values,
297        }
298    }
299}
300
/// Declares a logup relation type named `$name` of size `$size`.
///
/// The generated public tuple struct wraps a `crate::logup::LookupElements<$size>` and
/// implements [`Relation`] for every `F`/`EF` pair by delegating `combine` to the wrapped
/// lookup elements. It also gets `dummy()` and `draw(channel)` constructors that forward to
/// the corresponding `LookupElements` constructors.
///
/// ```ignore
/// relation!(MyRelation, 3);
/// let elements = MyRelation::draw(channel);
/// ```
#[macro_export]
macro_rules! relation {
    // `$name` becomes a struct identifier, so it is matched as `ident` for clear errors.
    // `$size` stays `tt` so either a literal or a const name can be used as the size.
    ($name:ident, $size:tt) => {
        #[doc = concat!(
            "Logup relation `", stringify!($name), "` of size ", stringify!($size), "."
        )]
        #[derive(Clone, Debug, PartialEq)]
        pub struct $name($crate::logup::LookupElements<$size>);

        // Both constructors are always generated, but a given relation may only use one.
        #[allow(dead_code)]
        impl $name {
            pub fn dummy() -> Self {
                Self($crate::logup::LookupElements::dummy())
            }
            pub fn draw(channel: &mut impl stwo::core::channel::Channel) -> Self {
                Self($crate::logup::LookupElements::draw(channel))
            }
        }

        impl<F: Clone, EF: $crate::RelationEFTraitBound<F>> $crate::Relation<F, EF> for $name {
            fn combine(&self, values: &[F]) -> EF {
                self.0.combine(values)
            }

            fn get_name(&self) -> &str {
                stringify!($name)
            }

            fn get_size(&self) -> usize {
                $size
            }
        }
    };
}
332
/// Test helper: builds an `M31` from a `u32` via `M31::from_u32_unchecked`.
/// Presumably the caller is responsible for passing a canonical value.
#[cfg(test)]
#[macro_export]
macro_rules! m31 {
    ($value:expr) => {
        stwo::core::fields::m31::M31::from_u32_unchecked($value)
    };
}
340
/// Test helper: builds a `QM31` from its four `u32` coordinates via
/// `QM31::from_u32_unchecked`.
#[cfg(test)]
#[macro_export]
macro_rules! qm31 {
    ($a:expr, $b:expr, $c:expr, $d:expr) => {{
        stwo::core::fields::qm31::QM31::from_u32_unchecked($a, $b, $c, $d)
    }};
}