1use crate::{
17 fft::{
18 EvaluationDomain,
19 domain::{FFTPrecomputation, IFFTPrecomputation},
20 },
21 polycommit::sonic_pc::{LCTerm, LabeledPolynomial, LinearCombination},
22 r1cs::SynthesisError,
23 snark::varuna::{
24 SNARKMode,
25 VarunaVersion,
26 ahp::{AHPError, CircuitId, CircuitInfo, verifier},
27 prover,
28 selectors::precompute_selectors,
29 verifier::{QueryPoints, select_third_round_challenges},
30 },
31};
32use anyhow::{Result, anyhow, ensure};
33use snarkvm_fields::{Field, PrimeField};
34
35use core::{borrow::Borrow, marker::PhantomData};
36use itertools::Itertools;
37use std::{collections::BTreeMap, fmt::Write};
38
39pub struct AHPForR1CS<F: Field, SM: SNARKMode> {
43 field: PhantomData<F>,
44 mode: PhantomData<SM>,
45}
46
47pub(crate) fn witness_label(circuit_id: CircuitId, poly: &str, i: usize) -> String {
48 let mut label = String::with_capacity(82 + poly.len());
49 let _ = write!(&mut label, "circuit_{circuit_id}_{poly}_{i:0>8}");
50 label
51}
52
53pub(crate) struct NonZeroDomains<F: PrimeField> {
54 pub(crate) max_non_zero_domain: Option<EvaluationDomain<F>>,
55 pub(crate) domain_a: EvaluationDomain<F>,
56 pub(crate) domain_b: EvaluationDomain<F>,
57 pub(crate) domain_c: EvaluationDomain<F>,
58}
59
60impl<F: PrimeField, SM: SNARKMode> AHPForR1CS<F, SM> {
/// Labels of the three linear combinations that `construct_linear_combinations`
/// debug-asserts to evaluate to zero at their respective challenge points.
pub const LC_WITH_ZERO_EVAL: [&'static str; 3] = [
    "matrix_sumcheck",
    "lineval_sumcheck",
    "rowcheck_zerocheck",
];
65
66 pub fn zk_bound() -> Option<usize> {
67 SM::ZK.then_some(1)
68 }
69
70 pub fn num_formatted_public_inputs_is_admissible(num_inputs: usize) -> Result<(), AHPError> {
73 match num_inputs.count_ones() == 1 {
74 true => Ok(()),
75 false => Err(AHPError::InvalidPublicInputLength),
76 }
77 }
78
79 pub fn formatted_public_input_is_admissible(input: &[F]) -> Result<(), AHPError> {
82 Self::num_formatted_public_inputs_is_admissible(input.len())
83 }
84
85 pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result<usize> {
90 let zk_bound = Self::zk_bound().unwrap_or(0);
91 let constraint_domain_size =
92 EvaluationDomain::<F>::compute_size_of_domain(num_constraints).ok_or(AHPError::PolyTooLarge)?;
93 let variable_domain_size =
94 EvaluationDomain::<F>::compute_size_of_domain(num_variables).ok_or(AHPError::PolyTooLarge)?;
95 let non_zero_domain_size =
96 EvaluationDomain::<F>::compute_size_of_domain(num_non_zero).ok_or(AHPError::PolyTooLarge)?;
97
98 [
100 2 * constraint_domain_size + 2 * zk_bound - 2,
101 2 * variable_domain_size + 2 * zk_bound - 2,
102 if SM::ZK { variable_domain_size + 3 } else { 0 }, variable_domain_size,
104 constraint_domain_size,
105 non_zero_domain_size - 1, ]
107 .iter()
108 .max()
109 .copied()
110 .ok_or(anyhow!("Could not find max_degree"))
111 }
112
113 pub fn get_degree_bounds(info: &CircuitInfo) -> Result<[usize; 4]> {
115 let num_variables = info.num_public_and_private_variables;
116 let num_non_zero_a = info.num_non_zero_a;
117 let num_non_zero_b = info.num_non_zero_b;
118 let num_non_zero_c = info.num_non_zero_c;
119 Ok([
120 EvaluationDomain::<F>::compute_size_of_domain(num_variables).ok_or(SynthesisError::PolyTooLarge)? - 2,
121 EvaluationDomain::<F>::compute_size_of_domain(num_non_zero_a).ok_or(SynthesisError::PolyTooLarge)? - 2,
122 EvaluationDomain::<F>::compute_size_of_domain(num_non_zero_b).ok_or(SynthesisError::PolyTooLarge)? - 2,
123 EvaluationDomain::<F>::compute_size_of_domain(num_non_zero_c).ok_or(SynthesisError::PolyTooLarge)? - 2,
124 ])
125 }
126
127 pub(crate) fn cmp_non_zero_domains(
128 info: &CircuitInfo,
129 max_candidate: Option<EvaluationDomain<F>>,
130 ) -> Result<NonZeroDomains<F>> {
131 let domain_a = EvaluationDomain::new(info.num_non_zero_a).ok_or(SynthesisError::PolyTooLarge)?;
132 let domain_b = EvaluationDomain::new(info.num_non_zero_b).ok_or(SynthesisError::PolyTooLarge)?;
133 let domain_c = EvaluationDomain::new(info.num_non_zero_c).ok_or(SynthesisError::PolyTooLarge)?;
134 let new_candidate = [domain_a, domain_b, domain_c]
135 .into_iter()
136 .max_by_key(|d| d.size())
137 .ok_or(anyhow!("could not find max domain"))?;
138 let mut max_non_zero_domain = Some(new_candidate);
139 if let Some(max_candidate) = max_candidate {
140 if max_candidate.size() > new_candidate.size() {
141 max_non_zero_domain = Some(max_candidate);
142 }
143 }
144 Ok(NonZeroDomains { max_non_zero_domain, domain_a, domain_b, domain_c })
145 }
146
147 pub fn fft_precomputation(
148 constraint_domain_size: usize,
149 variable_domain_size: usize,
150 non_zero_a_domain_size: usize,
151 non_zero_b_domain_size: usize,
152 non_zero_c_domain_size: usize,
153 ) -> Option<(FFTPrecomputation<F>, IFFTPrecomputation<F>)> {
154 let largest_domain_size = [
155 2 * constraint_domain_size,
156 2 * variable_domain_size,
157 2 * non_zero_a_domain_size,
158 2 * non_zero_b_domain_size,
159 2 * non_zero_c_domain_size,
160 ]
161 .into_iter()
162 .max()?;
163 let largest_mul_domain = EvaluationDomain::new(largest_domain_size)?;
164
165 let fft_precomputation = largest_mul_domain.precompute_fft();
166 let ifft_precomputation = fft_precomputation.to_ifft_precomputation();
167 Some((fft_precomputation, ifft_precomputation))
168 }
169
170 #[allow(non_snake_case)]
179 pub fn construct_linear_combinations<E: EvaluationsProvider<F>>(
180 public_inputs: &BTreeMap<CircuitId, Vec<Vec<F>>>,
181 evals: &E,
182 prover_third_message: &prover::ThirdMessage<F>,
183 prover_fourth_message: &prover::FourthMessage<F>,
184 state: &verifier::State<F, SM>,
185 varuna_version: VarunaVersion,
186 ) -> Result<BTreeMap<String, LinearCombination<F>>> {
187 ensure!(!public_inputs.is_empty());
188 let max_constraint_domain = state.max_constraint_domain;
189 let max_variable_domain = state.max_variable_domain;
190 let max_non_zero_domain = state.max_non_zero_domain;
191 let mut formatted_public_inputs = Vec::with_capacity(state.circuit_specific_states.len());
192 for (circuit_id, circuit_state) in &state.circuit_specific_states {
193 let input_domain = circuit_state.input_domain;
194 let public_inputs_i = public_inputs[circuit_id]
195 .iter()
196 .map(|p| {
197 let public_input = prover::ConstraintSystem::format_public_input(p);
198 Self::formatted_public_input_is_admissible(&public_input)?;
199 Ok::<_, AHPError>(public_input)
200 })
201 .collect::<Result<Vec<_>, _>>()?;
202 ensure!(public_inputs_i[0].len() == input_domain.size());
203 formatted_public_inputs.push(public_inputs_i);
204 }
205
206 let verifier::FirstMessage { first_round_batch_combiners } = state.first_round_message.as_ref().unwrap();
207 let verifier::ThirdMessage { beta } = state.third_round_message.unwrap();
208
209 let (alpha, third_round_batch_combiners, eta_b, eta_c) = select_third_round_challenges(
211 state.first_round_message.as_ref().unwrap(),
212 state.second_round_message.as_ref().unwrap(),
213 state.prepare_third_round_message.as_ref(),
214 varuna_version,
215 )
216 .map_err(AHPError::AnyhowError)?;
217
218 let batch_lineval_sum =
219 prover_third_message.sum(&third_round_batch_combiners, eta_b, eta_c) * state.max_variable_domain.size_inv;
220 let verifier::FourthMessage { delta_a, delta_b, delta_c } = state.fourth_round_message.as_ref().unwrap();
221 let sums_fourth_msg = &prover_fourth_message.sums;
222 let gamma = state.gamma.unwrap();
223 let challenges = QueryPoints::new(alpha, beta, gamma);
224
225 let mut linear_combinations = BTreeMap::new();
226 let constraint_domains = state.constraint_domains();
227 let variable_domains = state.variable_domains();
228 let non_zero_domains = state.non_zero_domains();
229 let selectors = precompute_selectors(
230 max_constraint_domain,
231 constraint_domains,
232 max_variable_domain,
233 variable_domains,
234 max_non_zero_domain,
235 non_zero_domains,
236 challenges,
237 );
238
239 let rowcheck_time = start_timer!(|| "Rowcheck");
241
242 let v_R_at_alpha_time = start_timer!(|| "v_R_at_alpha");
243 let v_R_at_alpha = max_constraint_domain.evaluate_vanishing_polynomial(alpha);
244 end_timer!(v_R_at_alpha_time);
245
246 let rowcheck_zerocheck = {
247 let mut rowcheck_zerocheck = LinearCombination::empty("rowcheck_zerocheck");
248 for (i, (id, c)) in first_round_batch_combiners.iter().enumerate() {
249 let mut circuit_term = LinearCombination::empty(format!("rowcheck_zerocheck term {id}"));
250 let third_sums_i = &prover_third_message.sums[i];
251 let circuit_state = &state.circuit_specific_states[id];
252
253 for (j, instance_combiner) in c.instance_combiners.iter().enumerate() {
254 let mut rowcheck = LinearCombination::empty(format!("rowcheck term {id}"));
255 let sum_a_third = third_sums_i[j].sum_a;
256 let sum_b_third = third_sums_i[j].sum_b;
257 let sum_c_third = third_sums_i[j].sum_c;
258
259 rowcheck.add(sum_a_third * sum_b_third - sum_c_third, LCTerm::One);
260
261 circuit_term += (*instance_combiner, &rowcheck);
262 }
263 let constraint_domain = circuit_state.constraint_domain;
264 let selector = selectors
265 .get(&(max_constraint_domain.size, constraint_domain.size, alpha))
266 .ok_or(anyhow!("Could not find selector at alpha"))?;
267 circuit_term *= *selector;
268 rowcheck_zerocheck += (c.circuit_combiner, &circuit_term);
269 }
270 rowcheck_zerocheck.add(-v_R_at_alpha, "h_0");
271 rowcheck_zerocheck
272 };
273
274 debug_assert!(evals.get_lc_eval(&rowcheck_zerocheck, alpha)?.is_zero());
275 linear_combinations.insert("rowcheck_zerocheck".into(), rowcheck_zerocheck);
276 end_timer!(rowcheck_time);
277
278 let lineval_time = start_timer!(|| "Lineval");
280
281 let g_1 = LinearCombination::new("g_1", [(F::one(), "g_1")]);
282
283 let v_C_at_beta = max_variable_domain.evaluate_vanishing_polynomial(beta);
284 let v_K_at_gamma = max_non_zero_domain.evaluate_vanishing_polynomial(gamma);
285
286 let v_X_at_beta_time = start_timer!(|| "v_X_at_beta");
287 let v_X_at_beta = state
288 .circuit_specific_states
289 .iter()
290 .map(|(circuit_id, circuit_state)| {
291 let v_X_i_at_beta = circuit_state.input_domain.evaluate_vanishing_polynomial(beta);
292 (circuit_id, v_X_i_at_beta)
293 })
294 .collect::<BTreeMap<_, _>>();
295 end_timer!(v_X_at_beta_time);
296
297 let x_at_betas = state
298 .circuit_specific_states
299 .iter()
300 .enumerate()
301 .map(|(i, (circuit_id, circuit_state))| {
302 let lag_at_beta = circuit_state.input_domain.evaluate_all_lagrange_coefficients(beta);
303 let x_at_beta = formatted_public_inputs[i]
304 .iter()
305 .map(|x| x.iter().zip_eq(&lag_at_beta).map(|(x, l)| *x * l).sum::<F>())
306 .collect_vec();
307 (circuit_id, x_at_beta)
308 })
309 .collect::<BTreeMap<_, _>>();
310
311 let g_1_at_beta = evals.get_lc_eval(&g_1, beta)?;
312
313 let lineval_sumcheck = {
315 let mut lineval_sumcheck = LinearCombination::empty("lineval_sumcheck");
316 if SM::ZK {
317 lineval_sumcheck.add(F::one(), "mask_poly");
318 }
319 for (i, (id, c)) in third_round_batch_combiners.iter().enumerate() {
320 let mut circuit_term = LinearCombination::empty(format!("lineval_sumcheck term {id}"));
321 let fourth_sums_i = &sums_fourth_msg[i];
322 let circuit_state = &state.circuit_specific_states[id];
323
324 for (j, instance_combiner) in c.instance_combiners.iter().enumerate() {
325 let w_j = witness_label(*id, "w", j);
326 let mut lineval = LinearCombination::empty(format!("lineval term {j}"));
327 let sum_a_fourth = fourth_sums_i.sum_a * circuit_state.non_zero_a_domain.size_as_field_element;
328 let sum_b_fourth = fourth_sums_i.sum_b * circuit_state.non_zero_b_domain.size_as_field_element;
329 let sum_c_fourth = fourth_sums_i.sum_c * circuit_state.non_zero_c_domain.size_as_field_element;
330
331 lineval.add(sum_a_fourth * x_at_betas[id][j], LCTerm::One);
332 lineval.add(sum_a_fourth * v_X_at_beta[id], w_j.clone());
333
334 lineval.add(sum_b_fourth * eta_b * x_at_betas[id][j], LCTerm::One);
335 lineval.add(sum_b_fourth * eta_b * v_X_at_beta[id], w_j.clone());
336
337 lineval.add(sum_c_fourth * eta_c * x_at_betas[id][j], LCTerm::One);
338 lineval.add(sum_c_fourth * eta_c * v_X_at_beta[id], w_j);
339
340 circuit_term += (*instance_combiner, &lineval);
341 }
342 let variable_domain = circuit_state.variable_domain;
343 let selector = selectors
344 .get(&(max_variable_domain.size, variable_domain.size, beta))
345 .ok_or(anyhow!("Could not find selector at beta"))?;
346 circuit_term *= *selector;
347
348 lineval_sumcheck += (c.circuit_combiner, &circuit_term);
349 }
350 lineval_sumcheck
351 .add(-v_C_at_beta, "h_1")
352 .add(-beta * g_1_at_beta, LCTerm::One)
353 .add(-batch_lineval_sum, LCTerm::One);
354 lineval_sumcheck
355 };
356 debug_assert!(evals.get_lc_eval(&lineval_sumcheck, beta)?.is_zero());
357
358 linear_combinations.insert("g_1".into(), g_1);
359 linear_combinations.insert("lineval_sumcheck".into(), lineval_sumcheck);
360 end_timer!(lineval_time);
361
362 let mut matrix_sumcheck = LinearCombination::empty("matrix_sumcheck");
364
365 for (i, (&id, state_i)) in state.circuit_specific_states.iter().enumerate() {
366 let v_R_i_at_alpha = state_i.constraint_domain.evaluate_vanishing_polynomial(alpha);
367 let v_C_i_at_beta = state_i.variable_domain.evaluate_vanishing_polynomial(beta);
368 let v_rc = v_R_i_at_alpha * v_C_i_at_beta;
369 let rc = state_i.constraint_domain.size_as_field_element * state_i.variable_domain.size_as_field_element;
370
371 let matrices = ["a", "b", "c"];
372 let deltas = [delta_a[i], delta_b[i], delta_c[i]];
373 let non_zero_domains = [&state_i.non_zero_a_domain, &state_i.non_zero_b_domain, &state_i.non_zero_c_domain];
374 let sums = sums_fourth_msg[i].iter();
375
376 ensure!(matrices.len() == sums.len());
377 ensure!(matrices.len() == deltas.len());
378 ensure!(matrices.len() == non_zero_domains.len());
379 for (((m, sum), delta), non_zero_domain) in
380 matrices.into_iter().zip_eq(sums).zip_eq(deltas).zip_eq(non_zero_domains)
381 {
382 let selector = selectors
383 .get(&(max_non_zero_domain.size, non_zero_domain.size, gamma))
384 .ok_or(anyhow!("Could not find selector at gamma"))?;
385 let label = "g_".to_string() + m;
386 let g_m_label = witness_label(id, &label, 0);
387 let g_m = LinearCombination::new(g_m_label.clone(), [(F::one(), g_m_label)]);
388 let g_m_at_gamma = evals.get_lc_eval(&g_m, gamma)?;
389
390 let (a_poly, b_poly) = Self::construct_matrix_linear_combinations(evals, id, m, v_rc, challenges, rc)?;
391 let g_m_term = Self::construct_g_m_term(gamma, g_m_at_gamma, sum, *selector, a_poly, b_poly);
392
393 matrix_sumcheck += (delta, &g_m_term);
394
395 linear_combinations.insert(g_m.label.clone(), g_m);
396 }
397 }
398
399 matrix_sumcheck -= &LinearCombination::new("h_2", [(v_K_at_gamma, "h_2")]);
400 debug_assert!(evals.get_lc_eval(&matrix_sumcheck, gamma)?.is_zero());
401
402 linear_combinations.insert("matrix_sumcheck".into(), matrix_sumcheck);
403
404 Ok(linear_combinations)
405 }
406
407 fn construct_g_m_term(
408 gamma: F,
409 g_m_at_gamma: F,
410 sum: F,
411 selector_at_gamma: F,
412 a_poly: LinearCombination<F>,
413 mut b_poly: LinearCombination<F>,
414 ) -> LinearCombination<F> {
415 let b_term = gamma * g_m_at_gamma + sum; b_poly *= b_term;
417
418 let mut lhs = a_poly;
419 lhs -= &b_poly;
420 lhs *= selector_at_gamma;
421 lhs
422 }
423
424 fn construct_matrix_linear_combinations<E: EvaluationsProvider<F>>(
425 evals: &E,
426 id: CircuitId,
427 matrix: &str,
428 v_rc_at_alpha_beta: F,
429 challenges: QueryPoints<F>,
430 rc_size: F,
431 ) -> Result<(LinearCombination<F>, LinearCombination<F>)> {
432 let label_a_poly = format!("circuit_{id}_a_poly_{matrix}");
433 let label_b_poly = format!("circuit_{id}_b_poly_{matrix}");
434 let QueryPoints { alpha, beta, gamma } = challenges;
435
436 let a_poly = LinearCombination::new(label_a_poly.clone(), [(F::one(), label_a_poly.clone())]);
439 let a_poly_eval_available = evals.get_lc_eval(&a_poly, gamma).is_ok();
440 let b_poly = LinearCombination::new(label_b_poly.clone(), [(F::one(), label_b_poly.clone())]);
441 let b_poly_eval_available = evals.get_lc_eval(&b_poly, gamma).is_ok();
442 ensure!(a_poly_eval_available == b_poly_eval_available);
443 if a_poly_eval_available && b_poly_eval_available {
444 return Ok((a_poly, b_poly));
445 };
446
447 let label_col = format!("circuit_{id}_col_{matrix}");
450 let label_row = format!("circuit_{id}_row_{matrix}");
451 let label_row_col = format!("circuit_{id}_row_col_{matrix}");
452 let label_row_col_val = format!("circuit_{id}_row_col_val_{matrix}");
454 let a = LinearCombination::new(label_a_poly, [(v_rc_at_alpha_beta, label_row_col_val)]);
455 let mut b = LinearCombination::new(label_b_poly, [
456 (alpha * beta, LCTerm::One),
457 (-alpha, (label_col).into()),
458 (-beta, (label_row).into()),
459 (F::one(), (label_row_col).into()),
460 ]);
461 b *= rc_size;
462 Ok((a, b))
463 }
464}
465
466pub trait EvaluationsProvider<F: PrimeField>: core::fmt::Debug {
473 fn get_lc_eval(&self, lc: &LinearCombination<F>, point: F) -> Result<F>;
475}
476
477impl<F: PrimeField> EvaluationsProvider<F> for crate::polycommit::sonic_pc::Evaluations<F> {
479 fn get_lc_eval(&self, lc: &LinearCombination<F>, point: F) -> Result<F> {
480 let key = (lc.label.clone(), point);
481 self.get(&key).copied().ok_or_else(|| AHPError::MissingEval(lc.label.clone())).map_err(Into::into)
482 }
483}
484
485impl<F, T> EvaluationsProvider<F> for Vec<T>
487where
488 F: PrimeField,
489 T: Borrow<LabeledPolynomial<F>> + core::fmt::Debug,
490{
491 fn get_lc_eval(&self, lc: &LinearCombination<F>, point: F) -> Result<F> {
492 let mut eval = F::zero();
493 for (coeff, term) in lc.iter() {
494 let value = if let LCTerm::PolyLabel(label) = term {
495 self.iter()
496 .find(|p| (*p).borrow().label() == label)
497 .ok_or_else(|| AHPError::MissingEval(format!("Missing {} for {}", label, lc.label)))?
498 .borrow()
499 .evaluate(point)
500 } else {
501 ensure!(term.is_one());
502 F::one()
503 };
504 eval += &(*coeff * value)
505 }
506 Ok(eval)
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fft::DensePolynomial;
    use snarkvm_curves::bls12_377::fr::Fr;
    use snarkvm_fields::Zero;
    use snarkvm_utilities::rand::TestRng;

    /// Sums a random polynomial over a domain of size 16 and checks that the
    /// result equals the domain size times the sum of the polynomial's first
    /// and last coefficients.
    #[test]
    fn test_summation() {
        let rng = &mut TestRng::default();
        let size = 1 << 4;
        let domain = EvaluationDomain::<Fr>::new(size).unwrap();
        let poly = DensePolynomial::rand(size, rng);

        // Left-hand side: the sum of the polynomial's evaluations over every domain element.
        let evaluation_sum =
            domain.elements().map(|point| poly.evaluate(point)).fold(Fr::zero(), |acc, eval| acc + eval);

        // Right-hand side: only the first and last coefficients, each scaled by the domain size.
        let scale = domain.size_as_field_element;
        let first = poly.coeffs[0] * scale;
        let last = *poly.coeffs.last().unwrap() * scale;
        assert_eq!(evaluation_sum, first + last);
    }
}