stwo_constraint_framework/
component.rs1use core::fmt::{self, Display, Formatter};
2use core::iter::zip;
3use core::ops::Deref;
4
5use hashbrown::HashMap;
6use itertools::Itertools;
7use num_traits::Zero;
8use std_shims::{vec, String, Vec};
9use stwo::core::air::accumulation::PointEvaluationAccumulator;
10use stwo::core::air::Component;
11use stwo::core::circle::CirclePoint;
12use stwo::core::constraints::coset_vanishing;
13use stwo::core::fields::qm31::SecureField;
14use stwo::core::fields::FieldExpOps;
15use stwo::core::pcs::{TreeSubspan, TreeVec};
16use stwo::core::poly::circle::{CanonicCoset, MIN_CIRCLE_DOMAIN_LOG_SIZE};
17use stwo::core::utils::all_unique;
18use stwo::core::ColumnVec;
19
20use super::preprocessed_columns::PreProcessedColumnId;
21use super::{EvalAtRow, InfoEvaluator, PointEvaluator, PREPROCESSED_TRACE_IDX};
22
/// How a [`TraceLocationAllocator`] treats preprocessed columns it has not seen before.
#[derive(Debug, Default)]
enum PreprocessedColumnsAllocationMode {
    /// Unknown preprocessed columns are appended to the allocator's list the first time a
    /// component references them. This is the mode of a default-constructed allocator.
    #[default]
    Dynamic,
    /// The preprocessed column list is fixed when the allocator is built; a component that
    /// references a column outside that list causes a panic during registration.
    Static,
}
29
30#[derive(Debug, Default)]
33pub struct TraceLocationAllocator {
34 next_tree_offsets: TreeVec<usize>,
36 preprocessed_columns: Vec<PreProcessedColumnId>,
38 preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
40}
41
42impl TraceLocationAllocator {
43 pub fn next_for_structure<T>(
44 &mut self,
45 structure: &TreeVec<ColumnVec<T>>,
46 ) -> TreeVec<TreeSubspan> {
47 if structure.len() > self.next_tree_offsets.len() {
48 self.next_tree_offsets.resize(structure.len(), 0);
49 }
50
51 TreeVec::new(
52 zip(&mut *self.next_tree_offsets, &**structure)
53 .enumerate()
54 .map(|(tree_index, (offset, cols))| {
55 let col_start = *offset;
56 let col_end = col_start + cols.len();
57 *offset = col_end;
58 TreeSubspan {
59 tree_index,
60 col_start,
61 col_end,
62 }
63 })
64 .collect(),
65 )
66 }
67
68 pub fn new_with_preprocessed_columns(preprocessed_columns: &[PreProcessedColumnId]) -> Self {
70 assert!(
71 all_unique(preprocessed_columns),
72 "Duplicate preprocessed columns are not allowed!"
73 );
74 Self {
75 next_tree_offsets: Default::default(),
76 preprocessed_columns: preprocessed_columns.to_vec(),
77 preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
78 }
79 }
80
81 pub const fn preprocessed_columns(&self) -> &Vec<PreProcessedColumnId> {
82 &self.preprocessed_columns
83 }
84
85 pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreProcessedColumnId]) {
88 let mut self_columns = self.preprocessed_columns.clone();
89 let mut input_columns = preprocessed_columns.to_vec();
90 self_columns.sort_by_key(|col| col.id.clone());
91 input_columns.sort_by_key(|col| col.id.clone());
92 assert_eq!(
93 self_columns, input_columns,
94 "Preprocessed columns are not a permutation."
95 );
96 }
97}
98
99pub trait FrameworkEval {
107 fn log_size(&self) -> u32;
108
109 fn max_constraint_log_degree_bound(&self) -> u32;
110
111 fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
112}
113
114pub struct FrameworkComponent<C: FrameworkEval> {
115 pub(super) eval: C,
116 pub(super) trace_locations: TreeVec<TreeSubspan>,
117 pub(super) preprocessed_column_indices: Vec<usize>,
118 pub(super) claimed_sum: SecureField,
119 info: InfoEvaluator,
120 is_disabled: bool,
121}
122
123impl<E: FrameworkEval> FrameworkComponent<E> {
124 pub fn new(
125 location_allocator: &mut TraceLocationAllocator,
126 eval: E,
127 claimed_sum: SecureField,
128 ) -> Self {
129 let is_disabled = false;
130 Self::new_ex(location_allocator, eval, claimed_sum, is_disabled)
131 }
132
133 pub fn disabled(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
134 let claimed_sum = SecureField::zero();
135 let is_disabled = true;
136 Self::new_ex(location_allocator, eval, claimed_sum, is_disabled)
137 }
138
139 pub fn new_ex(
140 location_allocator: &mut TraceLocationAllocator,
141 eval: E,
142 claimed_sum: SecureField,
143 is_disabled: bool,
144 ) -> Self {
145 let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], claimed_sum));
146 let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);
147
148 let preprocessed_column_indices = info
149 .preprocessed_columns
150 .iter()
151 .map(|col| {
152 let next_column = location_allocator.preprocessed_columns.len();
153 if let Some(pos) = location_allocator
154 .preprocessed_columns
155 .iter()
156 .position(|x| x.id == col.id)
157 {
158 pos
159 } else {
160 if matches!(
161 location_allocator.preprocessed_columns_allocation_mode,
162 PreprocessedColumnsAllocationMode::Static
163 ) {
164 panic!("Preprocessed column {col:?} is missing from static allocation");
165 }
166 location_allocator.preprocessed_columns.push(col.clone());
167 next_column
168 }
169 })
170 .collect();
171 Self {
172 eval,
173 trace_locations,
174 info,
175 preprocessed_column_indices,
176 claimed_sum,
177 is_disabled,
178 }
179 }
180
181 pub fn trace_locations(&self) -> &[TreeSubspan] {
182 &self.trace_locations
183 }
184
185 pub fn preprocessed_column_indices(&self) -> &[usize] {
186 &self.preprocessed_column_indices
187 }
188
189 pub const fn claimed_sum(&self) -> SecureField {
190 self.claimed_sum
191 }
192
193 pub fn logup_counts(&self) -> RelationCounts {
194 let size = 1 << self.eval.log_size();
195 RelationCounts(
196 self.info
197 .logup_counts
198 .iter()
199 .map(|(k, v)| (k.clone(), v * size))
200 .collect(),
201 )
202 }
203
204 pub fn is_disabled(&self) -> bool {
205 self.is_disabled
206 }
207}
208
/// Logup usage counts of a component, keyed by a relation identifier string (presumably the
/// relation name — confirm against `InfoEvaluator`).
///
/// Produced by `FrameworkComponent::logup_counts`, which scales each collected count by
/// `2^log_size`. Dereferences to the underlying map for read-only access.
#[derive(Debug, Clone)]
pub struct RelationCounts(HashMap<String, usize>);

impl Deref for RelationCounts {
    type Target = HashMap<String, usize>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
217
218impl<E: FrameworkEval> Component for FrameworkComponent<E> {
219 fn n_constraints(&self) -> usize {
220 self.info.n_constraints
221 }
222
223 fn max_constraint_log_degree_bound(&self) -> u32 {
224 if self.is_disabled() {
225 MIN_CIRCLE_DOMAIN_LOG_SIZE
226 } else {
227 self.eval.max_constraint_log_degree_bound()
228 }
229 }
230
231 fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
232 let log_size = if self.is_disabled() {
233 MIN_CIRCLE_DOMAIN_LOG_SIZE
234 } else {
235 self.eval.log_size()
236 };
237
238 let mut log_degree_bounds = self
239 .info
240 .mask_offsets
241 .as_ref()
242 .map(|tree_offsets| vec![log_size; tree_offsets.len()]);
243
244 log_degree_bounds[0] = self
245 .preprocessed_column_indices
246 .iter()
247 .map(|_| log_size)
248 .collect();
249
250 log_degree_bounds
251 }
252
253 fn mask_points(
254 &self,
255 point: CirclePoint<SecureField>,
256 max_log_degree_bound: u32,
257 ) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
258 let trace_step = CanonicCoset::new(max_log_degree_bound).step();
259 self.info.mask_offsets.as_ref().map_cols(|col_offsets| {
260 col_offsets
261 .iter()
262 .map(|offset| point + trace_step.mul_signed(*offset).into_ef())
263 .collect()
264 })
265 }
266
267 fn preprocessed_column_indices(&self) -> ColumnVec<usize> {
268 self.preprocessed_column_indices.clone()
269 }
270
271 fn evaluate_constraint_quotients_at_point(
272 &self,
273 point: CirclePoint<SecureField>,
274 mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
275 evaluation_accumulator: &mut PointEvaluationAccumulator,
276 max_log_degree_bound: u32,
277 ) {
278 if self.is_disabled {
279 for _ in 0..self.n_constraints() {
280 evaluation_accumulator.accumulate(SecureField::zero());
281 }
282 return;
283 }
284 let preprocessed_mask = self
285 .preprocessed_column_indices
286 .iter()
287 .map(|idx| &mask[PREPROCESSED_TRACE_IDX][*idx])
288 .collect_vec();
289
290 let mut mask_points = mask.sub_tree(&self.trace_locations);
291 mask_points[PREPROCESSED_TRACE_IDX] = preprocessed_mask;
292
293 self.eval.evaluate(PointEvaluator::new(
294 mask_points,
295 evaluation_accumulator,
296 coset_vanishing(CanonicCoset::new(max_log_degree_bound).coset, point).inverse(),
297 self.eval.log_size(),
298 self.claimed_sum,
299 ));
300 }
301}
302
303impl<E: FrameworkEval> Deref for FrameworkComponent<E> {
304 type Target = E;
305
306 fn deref(&self) -> &E {
307 &self.eval
308 }
309}
310
311impl<E: FrameworkEval> Display for FrameworkComponent<E> {
312 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
313 let log_n_rows = self.log_size();
314 let mut n_cols = vec![];
315 self.trace_log_degree_bounds()
316 .0
317 .iter()
318 .for_each(|interaction| {
319 n_cols.push(interaction.len());
320 });
321 writeln!(f, "n_rows 2^{log_n_rows}")?;
322 writeln!(f, "n_constraints {}", self.n_constraints())?;
323 writeln!(
324 f,
325 "constraint_log_degree_bound {}",
326 self.max_constraint_log_degree_bound()
327 )?;
328 writeln!(
329 f,
330 "total felts: 2^{} * {}",
331 log_n_rows,
332 n_cols.iter().sum::<usize>()
333 )?;
334 for (j, n_cols) in n_cols.into_iter().enumerate() {
335 writeln!(f, "\t Interaction {j}: n_cols {n_cols}")?;
336 }
337 Ok(())
338 }
339}