sigma_proofs/linear_relation/
canonical.rs1use alloc::format;
2use alloc::vec::Vec;
3use core::iter;
4use core::marker::PhantomData;
5
6use ff::Field;
7use group::prime::PrimeGroup;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::{GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarTerm, ScalarVar};
11use crate::errors::{Error, InvalidInstance};
12use crate::group::msm::MultiScalarMul;
13
14#[derive(Clone, Debug, Default)]
24pub struct CanonicalLinearRelation<G: PrimeGroup> {
25 pub image: Vec<GroupVar<G>>,
27 pub linear_combinations: Vec<Vec<(ScalarVar<G>, GroupVar<G>)>>,
30 pub group_elements: GroupMap<G>,
32 pub num_scalars: usize,
34}
35
36type WeightedGroupCache<G> = Vec<Vec<(<G as group::Group>::Scalar, GroupVar<G>)>>;
40
41impl<G: PrimeGroup> CanonicalLinearRelation<G> {
42 fn new() -> Self {
47 Self {
48 image: Vec::new(),
49 linear_combinations: Vec::new(),
50 group_elements: GroupMap::default(),
51 num_scalars: 0,
52 }
53 }
54
55 pub fn evaluate(&self, scalars: &[G::Scalar]) -> Vec<G>
66 where
67 G: MultiScalarMul,
68 {
69 self.linear_combinations
70 .iter()
71 .map(|lc| {
72 let scalars = lc
73 .iter()
74 .map(|(scalar_var, _)| scalars[scalar_var.index()])
75 .collect::<Vec<_>>();
76 let bases = lc
77 .iter()
78 .map(|(_, group_var)| self.group_elements.get(*group_var).unwrap())
79 .collect::<Vec<_>>();
80 G::msm(&scalars, &bases)
81 })
82 .collect()
83 }
84
85 fn get_or_create_weighted_group_var(
87 &mut self,
88 group_var: GroupVar<G>,
89 weight: &G::Scalar,
90 original_group_elements: &GroupMap<G>,
91 weighted_group_cache: &mut WeightedGroupCache<G>,
92 ) -> Result<GroupVar<G>, InvalidInstance> {
93 let index = group_var.index();
95 if weighted_group_cache.len() <= index {
96 weighted_group_cache.resize_with(index + 1, Vec::new);
97 }
98 let entry = &mut weighted_group_cache[index];
99
100 if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) {
102 return Ok(*existing_var);
103 }
104
105 let original_group_val = original_group_elements.get(group_var)?;
108 let weighted_group = match *weight == G::Scalar::ONE {
109 true => original_group_val,
110 false => original_group_val * weight,
111 };
112
113 let new_var = self.group_elements.push(weighted_group);
115
116 entry.push((*weight, new_var));
118
119 Ok(new_var)
120 }
121
122 fn process_constraint(
124 &mut self,
125 &image_var: &GroupVar<G>,
126 equation: &LinearCombination<G>,
127 original_relation: &LinearRelation<G>,
128 weighted_group_cache: &mut WeightedGroupCache<G>,
129 ) -> Result<(), InvalidInstance> {
130 let mut rhs_terms = Vec::new();
131
132 for weighted_term in equation.terms() {
134 if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar {
135 let group_var = weighted_term.term.elem;
136 let weight = &weighted_term.weight;
137
138 if weight.is_zero_vartime() {
139 continue; }
141
142 let canonical_group_var = self.get_or_create_weighted_group_var(
143 group_var,
144 weight,
145 &original_relation.linear_map.group_elements,
146 weighted_group_cache,
147 )?;
148
149 rhs_terms.push((scalar_var, canonical_group_var));
150 }
151 }
152
153 let mut canonical_image = original_relation.linear_map.group_elements.get(image_var)?;
155 for weighted_term in equation.terms() {
156 if let ScalarTerm::Unit = weighted_term.term.scalar {
157 let group_val = original_relation
158 .linear_map
159 .group_elements
160 .get(weighted_term.term.elem)?;
161 canonical_image -= group_val * weighted_term.weight;
162 }
163 }
164
165 #[expect(clippy::collapsible_if)]
167 if rhs_terms.is_empty() {
168 if canonical_image.is_identity().into() {
169 return Ok(());
170 }
171 }
180
181 let canonical_image_group_var = self.group_elements.push(canonical_image);
182 self.image.push(canonical_image_group_var);
183 self.linear_combinations.push(rhs_terms);
184
185 Ok(())
186 }
187
188 pub fn label(&self) -> Vec<u8> {
199 let mut out = Vec::new();
200
201 let mut constraint_data = Vec::<(u32, Vec<(u32, u32)>)>::new();
204
205 for (image_var, constraint_terms) in iter::zip(&self.image, &self.linear_combinations) {
206 let mut rhs_terms = Vec::new();
208 for (scalar_var, group_var) in constraint_terms {
209 rhs_terms.push((scalar_var.0 as u32, group_var.0 as u32));
210 }
211
212 constraint_data.push((image_var.0 as u32, rhs_terms));
213 }
214
215 let ne = constraint_data.len();
217 out.extend_from_slice(&(ne as u32).to_le_bytes());
218
219 for (lhs_index, rhs_terms) in constraint_data {
221 out.extend_from_slice(&lhs_index.to_le_bytes());
223
224 out.extend_from_slice(&(rhs_terms.len() as u32).to_le_bytes());
226
227 for (scalar_index, group_index) in rhs_terms {
229 out.extend_from_slice(&scalar_index.to_le_bytes());
230 out.extend_from_slice(&group_index.to_le_bytes());
231 }
232 }
233
234 for (_, elem) in self.group_elements.iter() {
236 out.extend_from_slice(
237 elem.expect("expected group variable to be assigned")
238 .to_bytes()
239 .as_ref(),
240 );
241 }
242
243 out
244 }
245
246 pub fn from_label(data: &[u8]) -> Result<Self, Error> {
262 use crate::errors::InvalidInstance;
263
264 let mut offset = 0;
265
266 if data.len() < 4 {
268 return Err(InvalidInstance::new("Invalid label: too short for equation count").into());
269 }
270 let num_equations = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
271 offset += 4;
272
273 let mut constraint_data = Vec::new();
275 let mut max_scalar_index = 0u32;
276 let mut max_group_index = 0u32;
277
278 for _ in 0..num_equations {
279 if offset + 4 > data.len() {
281 return Err(InvalidInstance::new("Invalid label: truncated LHS index").into());
282 }
283 let lhs_index = u32::from_le_bytes([
284 data[offset],
285 data[offset + 1],
286 data[offset + 2],
287 data[offset + 3],
288 ]);
289 offset += 4;
290 max_group_index = max_group_index.max(lhs_index);
291
292 if offset + 4 > data.len() {
294 return Err(InvalidInstance::new("Invalid label: truncated RHS count").into());
295 }
296 let num_rhs_terms = u32::from_le_bytes([
297 data[offset],
298 data[offset + 1],
299 data[offset + 2],
300 data[offset + 3],
301 ]) as usize;
302 offset += 4;
303
304 let mut rhs_terms = Vec::new();
306 for _ in 0..num_rhs_terms {
307 if offset + 4 > data.len() {
309 return Err(
310 InvalidInstance::new("Invalid label: truncated scalar index").into(),
311 );
312 }
313 let scalar_index = u32::from_le_bytes([
314 data[offset],
315 data[offset + 1],
316 data[offset + 2],
317 data[offset + 3],
318 ]);
319 offset += 4;
320 max_scalar_index = max_scalar_index.max(scalar_index);
321
322 if offset + 4 > data.len() {
324 return Err(InvalidInstance::new("Invalid label: truncated group index").into());
325 }
326 let group_index = u32::from_le_bytes([
327 data[offset],
328 data[offset + 1],
329 data[offset + 2],
330 data[offset + 3],
331 ]);
332 offset += 4;
333 max_group_index = max_group_index.max(group_index);
334
335 rhs_terms.push((scalar_index, group_index));
336 }
337
338 constraint_data.push((lhs_index, rhs_terms));
339 }
340
341 let num_group_elements = (max_group_index + 1) as usize;
343 let group_element_size = G::Repr::default().as_ref().len();
344 let expected_remaining = num_group_elements * group_element_size;
345
346 if data.len() - offset != expected_remaining {
347 return Err(InvalidInstance::new(format!(
348 "Invalid label: expected {} bytes for {} group elements, got {}",
349 expected_remaining,
350 num_group_elements,
351 data.len() - offset
352 ))
353 .into());
354 }
355
356 let mut group_elements_ordered = Vec::new();
358 for i in 0..num_group_elements {
359 let start = offset + i * group_element_size;
360 let end = start + group_element_size;
361 let elem_bytes = &data[start..end];
362
363 let mut repr = G::Repr::default();
364 repr.as_mut().copy_from_slice(elem_bytes);
365
366 let elem = Option::<G>::from(G::from_bytes(&repr)).ok_or_else(|| {
367 Error::from(InvalidInstance::new(format!(
368 "Invalid group element at index {i}"
369 )))
370 })?;
371
372 group_elements_ordered.push(elem);
373 }
374
375 let mut canonical = Self::new();
377 canonical.num_scalars = (max_scalar_index + 1) as usize;
378
379 let mut group_var_map = Vec::new();
381 for elem in &group_elements_ordered {
382 let var = canonical.group_elements.push(*elem);
383 group_var_map.push(var);
384 }
385
386 for (lhs_index, rhs_terms) in constraint_data {
388 canonical.image.push(group_var_map[lhs_index as usize]);
390
391 let mut linear_combination = Vec::new();
393 for (scalar_index, group_index) in rhs_terms {
394 let scalar_var = ScalarVar(scalar_index as usize, PhantomData);
395 let group_var = group_var_map[group_index as usize];
396 linear_combination.push((scalar_var, group_var));
397 }
398 canonical.linear_combinations.push(linear_combination);
399 }
400
401 Ok(canonical)
402 }
403
404 pub(crate) fn image_elements(&self) -> impl Iterator<Item = G> + use<'_, G> {
407 self.image.iter().map(|var| {
408 self.group_elements
409 .get(*var)
410 .expect("expected group variable to be assigned")
411 })
412 }
413}
414
415impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for CanonicalLinearRelation<G> {
416 type Error = InvalidInstance;
417
418 fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
419 Self::try_from(&value)
420 }
421}
422
423impl<G: PrimeGroup + MultiScalarMul> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
424 type Error = InvalidInstance;
425
426 fn try_from(relation: &LinearRelation<G>) -> Result<Self, Self::Error> {
427 if relation.image.len() != relation.linear_map.linear_combinations.len() {
428 return Err(InvalidInstance::new(
429 "Number of equations must be equal to number of image elements.",
430 ));
431 }
432
433 let mut canonical = CanonicalLinearRelation::new();
434 canonical.num_scalars = relation.linear_map.num_scalars;
435
436 let mut weighted_group_cache = Vec::new();
438
439 for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_map.linear_combinations) {
441 let lhs_value = relation.linear_map.group_elements.get(*lhs)?;
443
444 let rhs_constant_terms = rhs
447 .0
448 .iter()
449 .filter(|term| matches!(term.term.scalar, ScalarTerm::Unit))
450 .map(|term| {
451 let elem = relation.linear_map.group_elements.get(term.term.elem)?;
452 let scalar = term.weight;
453 Ok((elem, scalar))
454 })
455 .collect::<Result<(Vec<G>, Vec<G::Scalar>), _>>()?;
456
457 let rhs_constant_term = G::msm(&rhs_constant_terms.1, &rhs_constant_terms.0);
458
459 let is_trivial = rhs.0.iter().all(|term| {
462 matches!(term.term.scalar, ScalarTerm::Unit) || term.weight.is_zero_vartime()
463 });
464
465 let is_homogenous = rhs_constant_term == lhs_value;
467
468 if is_trivial && is_homogenous {
471 continue;
472 }
473
474 if !is_trivial && is_homogenous {
476 return Err(InvalidInstance::new("Trivial kernel in this relation"));
477 }
478
479 canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?;
480 }
481
482 Ok(canonical)
483 }
484}
485
486impl<G: PrimeGroup + ConstantTimeEq + MultiScalarMul> CanonicalLinearRelation<G> {
487 pub fn is_witness_valid(&self, witness: &[G::Scalar]) -> Choice {
496 let got = self.evaluate(witness);
497 self.image_elements()
498 .zip(got)
499 .fold(Choice::from(1), |acc, (lhs, rhs)| acc & lhs.ct_eq(&rhs))
500 }
501}