1use std::{
2 borrow::Borrow,
3 panic::{self, AssertUnwindSafe},
4 process::exit,
5};
6
7use p3_air::{
8 Air, AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder,
9 PermutationAirBuilder,
10};
11use p3_field::{AbstractField, ExtensionField, Field, PrimeField32};
12use p3_matrix::{
13 dense::{RowMajorMatrix, RowMajorMatrixView},
14 stack::VerticalPair,
15 Matrix,
16};
17use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
18
19use super::{MachineChip, StarkGenericConfig, Val};
20use crate::{
21 air::{EmptyMessageBuilder, MachineAir, MultiTableAirBuilder},
22 septic_digest::SepticDigest,
23};
24
25#[allow(clippy::too_many_arguments)]
29pub fn debug_constraints<SC, A>(
30 chip: &MachineChip<SC, A>,
31 preprocessed: Option<&RowMajorMatrix<Val<SC>>>,
32 main: &RowMajorMatrix<Val<SC>>,
33 perm: &RowMajorMatrix<SC::Challenge>,
34 perm_challenges: &[SC::Challenge],
35 public_values: &[Val<SC>],
36 local_cumulative_sum: &SC::Challenge,
37 global_cumulative_sum: &SepticDigest<Val<SC>>,
38) where
39 SC: StarkGenericConfig,
40 Val<SC>: PrimeField32,
41 A: MachineAir<Val<SC>> + for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
42{
43 assert_eq!(main.height(), perm.height());
44 let height = main.height();
45 if height == 0 {
46 return;
47 }
48
49 (0..height).par_bridge().for_each(|i| {
51 let i_next = (i + 1) % height;
52
53 let main_local = main.row_slice(i);
54 let main_local = &(*main_local);
55 let main_next = main.row_slice(i_next);
56 let main_next = &(*main_next);
57 let preprocessed_local = if let Some(preprocessed) = preprocessed {
58 let row = preprocessed.row_slice(i);
59 let row: &[_] = (*row).borrow();
60 row.to_vec()
61 } else {
62 Vec::new()
63 };
64 let preprocessed_next = if let Some(preprocessed) = preprocessed {
65 let row = preprocessed.row_slice(i_next);
66 let row: &[_] = (*row).borrow();
67 row.to_vec()
68 } else {
69 Vec::new()
70 };
71 let perm_local = perm.row_slice(i);
72 let perm_local = &(*perm_local);
73 let perm_next = perm.row_slice(i_next);
74 let perm_next = &(*perm_next);
75
76 let mut builder = DebugConstraintBuilder {
77 preprocessed: VerticalPair::new(
78 RowMajorMatrixView::new_row(&preprocessed_local),
79 RowMajorMatrixView::new_row(&preprocessed_next),
80 ),
81 main: VerticalPair::new(
82 RowMajorMatrixView::new_row(main_local),
83 RowMajorMatrixView::new_row(main_next),
84 ),
85 perm: VerticalPair::new(
86 RowMajorMatrixView::new_row(perm_local),
87 RowMajorMatrixView::new_row(perm_next),
88 ),
89 perm_challenges,
90 local_cumulative_sum,
91 global_cumulative_sum,
92 is_first_row: Val::<SC>::zero(),
93 is_last_row: Val::<SC>::zero(),
94 is_transition: Val::<SC>::one(),
95 public_values,
96 };
97 if i == 0 {
98 builder.is_first_row = Val::<SC>::one();
99 }
100 if i == height - 1 {
101 builder.is_last_row = Val::<SC>::one();
102 builder.is_transition = Val::<SC>::zero();
103 }
104 let result = catch_unwind_silent(AssertUnwindSafe(|| {
105 chip.eval(&mut builder);
106 }));
107 if result.is_err() {
108 eprintln!("local: {main_local:?}");
109 eprintln!("next: {main_next:?}");
110 eprintln!("failed at row {} of chip {}", i, chip.name());
111 exit(1);
112 }
113 });
114}
115
/// Runs `f`, turning a panic into an `Err` without letting the panic hook report it.
///
/// Silencing is tracked per thread. The first call installs, once for the whole process, a panic
/// hook that wraps whichever hook was active at that moment: a panic raised on a thread that is
/// currently inside this function is dropped silently, and every other panic is forwarded to the
/// wrapped hook. The wrapper stays installed afterwards.
///
/// Swapping the global hook in and out around every call is racy when calls overlap on several
/// threads: one thread can capture another thread's silent hook as the "previous" hook and
/// restore it last, leaving the process with a muted hook for good. Keeping the global hook fixed
/// and flipping a thread-local flag avoids that.
///
/// Returns `Ok` with the closure's result, or `Err` carrying the panic payload.
fn catch_unwind_silent<F: FnOnce() -> R + panic::UnwindSafe, R>(f: F) -> std::thread::Result<R> {
    thread_local! {
        // True while the current thread is executing inside `catch_unwind_silent`.
        static SILENCED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
    }
    static INSTALL_HOOK: std::sync::Once = std::sync::Once::new();

    INSTALL_HOOK.call_once(|| {
        let previous = panic::take_hook();
        panic::set_hook(Box::new(move |info| {
            // `try_with` keeps the hook usable while a thread's locals are being torn down.
            let silenced = SILENCED.try_with(|flag| flag.get()).unwrap_or(false);
            if !silenced {
                previous(info);
            }
        }));
    });

    // Remember the prior state so a nested call does not un-silence its caller.
    let was_silenced = SILENCED.with(|flag| flag.replace(true));
    let result = panic::catch_unwind(f);
    SILENCED.with(|flag| flag.set(was_silenced));
    result
}
123
124pub fn debug_cumulative_sums<F: Field, EF: ExtensionField<F>>(perms: &[RowMajorMatrix<EF>]) {
128 let sum: EF = perms.iter().map(|perm| *perm.row_slice(perm.height() - 1).last().unwrap()).sum();
129 assert_eq!(sum, EF::zero());
130}
131
132pub struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField<F>> {
134 pub(crate) preprocessed: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
135 pub(crate) main: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
136 pub(crate) perm: VerticalPair<RowMajorMatrixView<'a, EF>, RowMajorMatrixView<'a, EF>>,
137 pub(crate) local_cumulative_sum: &'a EF,
138 pub(crate) global_cumulative_sum: &'a SepticDigest<F>,
139 pub(crate) perm_challenges: &'a [EF],
140 pub(crate) is_first_row: F,
141 pub(crate) is_last_row: F,
142 pub(crate) is_transition: F,
143 pub(crate) public_values: &'a [F],
144}
145
146impl<F, EF> ExtensionBuilder for DebugConstraintBuilder<'_, F, EF>
147where
148 F: Field,
149 EF: ExtensionField<F>,
150{
151 type EF = EF;
152 type VarEF = EF;
153 type ExprEF = EF;
154
155 fn assert_zero_ext<I>(&mut self, x: I)
156 where
157 I: Into<Self::ExprEF>,
158 {
159 assert_eq!(x.into(), EF::zero(), "constraints must evaluate to zero");
160 }
161}
162
163impl<'a, F, EF> PermutationAirBuilder for DebugConstraintBuilder<'a, F, EF>
164where
165 F: Field,
166 EF: ExtensionField<F>,
167{
168 type MP = VerticalPair<RowMajorMatrixView<'a, EF>, RowMajorMatrixView<'a, EF>>;
169
170 type RandomVar = EF;
171
172 fn permutation(&self) -> Self::MP {
173 self.perm
174 }
175
176 fn permutation_randomness(&self) -> &[Self::EF] {
177 self.perm_challenges
178 }
179}
180
181impl<F, EF> PairBuilder for DebugConstraintBuilder<'_, F, EF>
182where
183 F: Field,
184 EF: ExtensionField<F>,
185{
186 fn preprocessed(&self) -> Self::M {
187 self.preprocessed
188 }
189}
190
191impl<F, EF> DebugConstraintBuilder<'_, F, EF>
192where
193 F: Field,
194 EF: ExtensionField<F>,
195{
196 #[allow(clippy::unused_self)]
197 #[inline]
198 fn debug_constraint(&self, x: F, y: F) {
199 if x != y {
200 let backtrace = std::backtrace::Backtrace::force_capture();
201 eprintln!("constraint failed: {x:?} != {y:?}\n{backtrace}");
202 panic!();
203 }
204 }
205}
206
207impl<'a, F, EF> AirBuilder for DebugConstraintBuilder<'a, F, EF>
208where
209 F: Field,
210 EF: ExtensionField<F>,
211{
212 type F = F;
213 type Expr = F;
214 type Var = F;
215 type M = VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>;
216
217 fn is_first_row(&self) -> Self::Expr {
218 self.is_first_row
219 }
220
221 fn is_last_row(&self) -> Self::Expr {
222 self.is_last_row
223 }
224
225 fn is_transition_window(&self, size: usize) -> Self::Expr {
226 if size == 2 {
227 self.is_transition
228 } else {
229 panic!("only supports a window size of 2")
230 }
231 }
232
233 fn main(&self) -> Self::M {
234 self.main
235 }
236
237 fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
238 self.debug_constraint(x.into(), F::zero());
239 }
240
241 fn assert_one<I: Into<Self::Expr>>(&mut self, x: I) {
242 self.debug_constraint(x.into(), F::one());
243 }
244
245 fn assert_eq<I1: Into<Self::Expr>, I2: Into<Self::Expr>>(&mut self, x: I1, y: I2) {
246 self.debug_constraint(x.into(), y.into());
247 }
248
249 fn assert_bool<I: Into<Self::Expr>>(&mut self, x: I) {
251 let x = x.into();
252 if x != F::zero() && x != F::one() {
253 let backtrace = std::backtrace::Backtrace::force_capture();
254 eprintln!("constraint failed: {x:?} is not a bool\n{backtrace}");
255 panic!();
256 }
257 }
258}
259
260impl<'a, F, EF> MultiTableAirBuilder<'a> for DebugConstraintBuilder<'a, F, EF>
261where
262 F: Field,
263 EF: ExtensionField<F>,
264{
265 type LocalSum = EF;
266 type GlobalSum = F;
267
268 fn local_cumulative_sum(&self) -> &'a Self::LocalSum {
269 self.local_cumulative_sum
270 }
271
272 fn global_cumulative_sum(&self) -> &'a SepticDigest<Self::GlobalSum> {
273 self.global_cumulative_sum
274 }
275}
276
277impl<F: Field, EF: ExtensionField<F>> EmptyMessageBuilder for DebugConstraintBuilder<'_, F, EF> {}
278
279impl<F: Field, EF: ExtensionField<F>> AirBuilderWithPublicValues
280 for DebugConstraintBuilder<'_, F, EF>
281{
282 type PublicVar = F;
283
284 fn public_values(&self) -> &[Self::PublicVar] {
285 self.public_values
286 }
287}