Skip to main content

stwo_constraint_framework/
component.rs

1use core::fmt::{self, Display, Formatter};
2use core::iter::zip;
3use core::ops::Deref;
4
5use hashbrown::HashMap;
6use itertools::Itertools;
7use num_traits::Zero;
8use std_shims::{vec, String, Vec};
9use stwo::core::air::accumulation::PointEvaluationAccumulator;
10use stwo::core::air::Component;
11use stwo::core::circle::CirclePoint;
12use stwo::core::constraints::coset_vanishing;
13use stwo::core::fields::qm31::SecureField;
14use stwo::core::fields::FieldExpOps;
15use stwo::core::pcs::{TreeSubspan, TreeVec};
16use stwo::core::poly::circle::{CanonicCoset, MIN_CIRCLE_DOMAIN_LOG_SIZE};
17use stwo::core::utils::all_unique;
18use stwo::core::ColumnVec;
19
20use super::preprocessed_columns::PreProcessedColumnId;
21use super::{EvalAtRow, InfoEvaluator, PointEvaluator, PREPROCESSED_TRACE_IDX};
22
/// How a [`TraceLocationAllocator`] treats preprocessed columns it has not seen before.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum PreprocessedColumnsAllocationMode {
    /// Unknown preprocessed columns are appended to the allocator's list on first use.
    #[default]
    Dynamic,
    /// The preprocessed columns are fixed at construction; requesting an unknown one panics.
    Static,
}
29
30// TODO(andrew): Docs.
31// TODO(andrew): Consider better location for this.
32#[derive(Debug, Default)]
33pub struct TraceLocationAllocator {
34    /// Mapping of tree index to next available column offset.
35    next_tree_offsets: TreeVec<usize>,
36    /// Mapping of preprocessed columns to their index.
37    preprocessed_columns: Vec<PreProcessedColumnId>,
38    /// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
39    preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
40}
41
42impl TraceLocationAllocator {
43    pub fn next_for_structure<T>(
44        &mut self,
45        structure: &TreeVec<ColumnVec<T>>,
46    ) -> TreeVec<TreeSubspan> {
47        if structure.len() > self.next_tree_offsets.len() {
48            self.next_tree_offsets.resize(structure.len(), 0);
49        }
50
51        TreeVec::new(
52            zip(&mut *self.next_tree_offsets, &**structure)
53                .enumerate()
54                .map(|(tree_index, (offset, cols))| {
55                    let col_start = *offset;
56                    let col_end = col_start + cols.len();
57                    *offset = col_end;
58                    TreeSubspan {
59                        tree_index,
60                        col_start,
61                        col_end,
62                    }
63                })
64                .collect(),
65        )
66    }
67
68    /// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
69    pub fn new_with_preprocessed_columns(preprocessed_columns: &[PreProcessedColumnId]) -> Self {
70        assert!(
71            all_unique(preprocessed_columns),
72            "Duplicate preprocessed columns are not allowed!"
73        );
74        Self {
75            next_tree_offsets: Default::default(),
76            preprocessed_columns: preprocessed_columns.to_vec(),
77            preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
78        }
79    }
80
81    pub const fn preprocessed_columns(&self) -> &Vec<PreProcessedColumnId> {
82        &self.preprocessed_columns
83    }
84
85    // Validates that `self.preprocessed_columns` is a permutation of
86    // `preprocessed_columns`.
87    pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreProcessedColumnId]) {
88        let mut self_columns = self.preprocessed_columns.clone();
89        let mut input_columns = preprocessed_columns.to_vec();
90        self_columns.sort_by_key(|col| col.id.clone());
91        input_columns.sort_by_key(|col| col.id.clone());
92        assert_eq!(
93            self_columns, input_columns,
94            "Preprocessed columns are not a permutation."
95        );
96    }
97}
98
99/// A component defined solely in means of the constraints framework.
100///
101/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
102/// the SIMD backend. Note that the constraint framework only supports components with columns of
103/// the same size.
104///
105/// [`ComponentProver`]: stwo::prover::ComponentProver
106pub trait FrameworkEval {
107    fn log_size(&self) -> u32;
108
109    fn max_constraint_log_degree_bound(&self) -> u32;
110
111    fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
112}
113
114pub struct FrameworkComponent<C: FrameworkEval> {
115    pub(super) eval: C,
116    pub(super) trace_locations: TreeVec<TreeSubspan>,
117    pub(super) preprocessed_column_indices: Vec<usize>,
118    pub(super) claimed_sum: SecureField,
119    info: InfoEvaluator,
120    is_disabled: bool,
121}
122
123impl<E: FrameworkEval> FrameworkComponent<E> {
124    pub fn new(
125        location_allocator: &mut TraceLocationAllocator,
126        eval: E,
127        claimed_sum: SecureField,
128    ) -> Self {
129        let is_disabled = false;
130        Self::new_ex(location_allocator, eval, claimed_sum, is_disabled)
131    }
132
133    pub fn disabled(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
134        let claimed_sum = SecureField::zero();
135        let is_disabled = true;
136        Self::new_ex(location_allocator, eval, claimed_sum, is_disabled)
137    }
138
139    pub fn new_ex(
140        location_allocator: &mut TraceLocationAllocator,
141        eval: E,
142        claimed_sum: SecureField,
143        is_disabled: bool,
144    ) -> Self {
145        let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], claimed_sum));
146        let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);
147
148        let preprocessed_column_indices = info
149            .preprocessed_columns
150            .iter()
151            .map(|col| {
152                let next_column = location_allocator.preprocessed_columns.len();
153                if let Some(pos) = location_allocator
154                    .preprocessed_columns
155                    .iter()
156                    .position(|x| x.id == col.id)
157                {
158                    pos
159                } else {
160                    if matches!(
161                        location_allocator.preprocessed_columns_allocation_mode,
162                        PreprocessedColumnsAllocationMode::Static
163                    ) {
164                        panic!("Preprocessed column {col:?} is missing from static allocation");
165                    }
166                    location_allocator.preprocessed_columns.push(col.clone());
167                    next_column
168                }
169            })
170            .collect();
171        Self {
172            eval,
173            trace_locations,
174            info,
175            preprocessed_column_indices,
176            claimed_sum,
177            is_disabled,
178        }
179    }
180
181    pub fn trace_locations(&self) -> &[TreeSubspan] {
182        &self.trace_locations
183    }
184
185    pub fn preprocessed_column_indices(&self) -> &[usize] {
186        &self.preprocessed_column_indices
187    }
188
189    pub const fn claimed_sum(&self) -> SecureField {
190        self.claimed_sum
191    }
192
193    pub fn logup_counts(&self) -> RelationCounts {
194        let size = 1 << self.eval.log_size();
195        RelationCounts(
196            self.info
197                .logup_counts
198                .iter()
199                .map(|(k, v)| (k.clone(), v * size))
200                .collect(),
201        )
202    }
203
204    pub fn is_disabled(&self) -> bool {
205        self.is_disabled
206    }
207}
208
/// Per-key logup counts of a component across all of its rows.
///
/// Returned by `FrameworkComponent::logup_counts`. It dereferences to the underlying map for
/// read-only access.
#[derive(Debug, Clone)]
pub struct RelationCounts(HashMap<String, usize>);

impl Deref for RelationCounts {
    type Target = HashMap<String, usize>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
217
218impl<E: FrameworkEval> Component for FrameworkComponent<E> {
219    fn n_constraints(&self) -> usize {
220        self.info.n_constraints
221    }
222
223    fn max_constraint_log_degree_bound(&self) -> u32 {
224        if self.is_disabled() {
225            MIN_CIRCLE_DOMAIN_LOG_SIZE
226        } else {
227            self.eval.max_constraint_log_degree_bound()
228        }
229    }
230
231    fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
232        let log_size = if self.is_disabled() {
233            MIN_CIRCLE_DOMAIN_LOG_SIZE
234        } else {
235            self.eval.log_size()
236        };
237
238        let mut log_degree_bounds = self
239            .info
240            .mask_offsets
241            .as_ref()
242            .map(|tree_offsets| vec![log_size; tree_offsets.len()]);
243
244        log_degree_bounds[0] = self
245            .preprocessed_column_indices
246            .iter()
247            .map(|_| log_size)
248            .collect();
249
250        log_degree_bounds
251    }
252
253    fn mask_points(
254        &self,
255        point: CirclePoint<SecureField>,
256        max_log_degree_bound: u32,
257    ) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
258        let trace_step = CanonicCoset::new(max_log_degree_bound).step();
259        self.info.mask_offsets.as_ref().map_cols(|col_offsets| {
260            col_offsets
261                .iter()
262                .map(|offset| point + trace_step.mul_signed(*offset).into_ef())
263                .collect()
264        })
265    }
266
267    fn preprocessed_column_indices(&self) -> ColumnVec<usize> {
268        self.preprocessed_column_indices.clone()
269    }
270
271    fn evaluate_constraint_quotients_at_point(
272        &self,
273        point: CirclePoint<SecureField>,
274        mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
275        evaluation_accumulator: &mut PointEvaluationAccumulator,
276        max_log_degree_bound: u32,
277    ) {
278        if self.is_disabled {
279            for _ in 0..self.n_constraints() {
280                evaluation_accumulator.accumulate(SecureField::zero());
281            }
282            return;
283        }
284        let preprocessed_mask = self
285            .preprocessed_column_indices
286            .iter()
287            .map(|idx| &mask[PREPROCESSED_TRACE_IDX][*idx])
288            .collect_vec();
289
290        let mut mask_points = mask.sub_tree(&self.trace_locations);
291        mask_points[PREPROCESSED_TRACE_IDX] = preprocessed_mask;
292
293        self.eval.evaluate(PointEvaluator::new(
294            mask_points,
295            evaluation_accumulator,
296            coset_vanishing(CanonicCoset::new(max_log_degree_bound).coset, point).inverse(),
297            self.eval.log_size(),
298            self.claimed_sum,
299        ));
300    }
301}
302
303impl<E: FrameworkEval> Deref for FrameworkComponent<E> {
304    type Target = E;
305
306    fn deref(&self) -> &E {
307        &self.eval
308    }
309}
310
311impl<E: FrameworkEval> Display for FrameworkComponent<E> {
312    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
313        let log_n_rows = self.log_size();
314        let mut n_cols = vec![];
315        self.trace_log_degree_bounds()
316            .0
317            .iter()
318            .for_each(|interaction| {
319                n_cols.push(interaction.len());
320            });
321        writeln!(f, "n_rows 2^{log_n_rows}")?;
322        writeln!(f, "n_constraints {}", self.n_constraints())?;
323        writeln!(
324            f,
325            "constraint_log_degree_bound {}",
326            self.max_constraint_log_degree_bound()
327        )?;
328        writeln!(
329            f,
330            "total felts: 2^{} * {}",
331            log_n_rows,
332            n_cols.iter().sum::<usize>()
333        )?;
334        for (j, n_cols) in n_cols.into_iter().enumerate() {
335            writeln!(f, "\t Interaction {j}: n_cols {n_cols}")?;
336        }
337        Ok(())
338    }
339}