sigma_proofs/linear_relation/
canonical.rs1use alloc::format;
2use alloc::vec::Vec;
3use core::iter;
4use core::marker::PhantomData;
5
6use ff::Field;
7use group::prime::PrimeGroup;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::{GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarTerm, ScalarVar};
11use crate::errors::{Error, InvalidInstance};
12use crate::group::msm::MultiScalarMul;
13use crate::serialization::serialize_elements;
14
15#[derive(Clone, Debug, Default)]
25pub struct CanonicalLinearRelation<G: PrimeGroup> {
26 pub image: Vec<GroupVar<G>>,
28 pub linear_combinations: Vec<Vec<(ScalarVar<G>, GroupVar<G>)>>,
31 pub group_elements: GroupMap<G>,
33 pub num_scalars: usize,
35}
36
37type WeightedGroupCache<G> = Vec<Vec<(<G as group::Group>::Scalar, GroupVar<G>)>>;
41
42impl<G: PrimeGroup> CanonicalLinearRelation<G> {
43 fn new() -> Self {
48 Self {
49 image: Vec::new(),
50 linear_combinations: Vec::new(),
51 group_elements: GroupMap::default(),
52 num_scalars: 0,
53 }
54 }
55
56 pub fn evaluate(&self, scalars: &[G::Scalar]) -> Vec<G>
67 where
68 G: MultiScalarMul,
69 {
70 self.linear_combinations
71 .iter()
72 .map(|lc| {
73 let scalars = lc
74 .iter()
75 .map(|(scalar_var, _)| scalars[scalar_var.index()])
76 .collect::<Vec<_>>();
77 let bases = lc
78 .iter()
79 .map(|(_, group_var)| self.group_elements.get(*group_var).unwrap())
80 .collect::<Vec<_>>();
81 G::msm(&scalars, &bases)
82 })
83 .collect()
84 }
85
86 fn get_or_create_weighted_group_var(
88 &mut self,
89 group_var: GroupVar<G>,
90 weight: &G::Scalar,
91 original_group_elements: &GroupMap<G>,
92 weighted_group_cache: &mut WeightedGroupCache<G>,
93 ) -> Result<GroupVar<G>, InvalidInstance> {
94 let index = group_var.index();
96 if weighted_group_cache.len() <= index {
97 weighted_group_cache.resize_with(index + 1, Vec::new);
98 }
99 let entry = &mut weighted_group_cache[index];
100
101 if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) {
103 return Ok(*existing_var);
104 }
105
106 let original_group_val = original_group_elements.get(group_var)?;
109 let weighted_group = match *weight == G::Scalar::ONE {
110 true => original_group_val,
111 false => original_group_val * weight,
112 };
113
114 let new_var = self.group_elements.push(weighted_group);
116
117 entry.push((*weight, new_var));
119
120 Ok(new_var)
121 }
122
123 fn process_constraint(
125 &mut self,
126 &image_var: &GroupVar<G>,
127 equation: &LinearCombination<G>,
128 original_relation: &LinearRelation<G>,
129 weighted_group_cache: &mut WeightedGroupCache<G>,
130 ) -> Result<(), InvalidInstance> {
131 let mut rhs_terms = Vec::new();
132
133 for weighted_term in equation.terms() {
135 if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar {
136 let group_var = weighted_term.term.elem;
137 let weight = &weighted_term.weight;
138
139 if weight.is_zero_vartime() {
140 continue; }
142
143 let canonical_group_var = self.get_or_create_weighted_group_var(
144 group_var,
145 weight,
146 &original_relation.linear_map.group_elements,
147 weighted_group_cache,
148 )?;
149
150 rhs_terms.push((scalar_var, canonical_group_var));
151 }
152 }
153
154 let mut canonical_image = original_relation.linear_map.group_elements.get(image_var)?;
156 for weighted_term in equation.terms() {
157 if let ScalarTerm::Unit = weighted_term.term.scalar {
158 let group_val = original_relation
159 .linear_map
160 .group_elements
161 .get(weighted_term.term.elem)?;
162 canonical_image -= group_val * weighted_term.weight;
163 }
164 }
165
166 #[expect(clippy::collapsible_if)]
168 if rhs_terms.is_empty() {
169 if canonical_image.is_identity().into() {
170 return Ok(());
171 }
172 }
181
182 let canonical_image_group_var = self.group_elements.push(canonical_image);
183 self.image.push(canonical_image_group_var);
184 self.linear_combinations.push(rhs_terms);
185
186 Ok(())
187 }
188
189 pub fn label(&self) -> Vec<u8> {
200 let mut out = Vec::new();
201
202 let mut constraint_data = Vec::<(u32, Vec<(u32, u32)>)>::new();
205
206 for (image_var, constraint_terms) in iter::zip(&self.image, &self.linear_combinations) {
207 let mut rhs_terms = Vec::new();
209 for (scalar_var, group_var) in constraint_terms {
210 rhs_terms.push((scalar_var.0 as u32, group_var.0 as u32));
211 }
212
213 constraint_data.push((image_var.0 as u32, rhs_terms));
214 }
215
216 let ne = constraint_data.len();
218 out.extend_from_slice(&(ne as u32).to_le_bytes());
219
220 for (lhs_index, rhs_terms) in constraint_data {
222 out.extend_from_slice(&lhs_index.to_le_bytes());
224
225 out.extend_from_slice(&(rhs_terms.len() as u32).to_le_bytes());
227
228 for (scalar_index, group_index) in rhs_terms {
230 out.extend_from_slice(&scalar_index.to_le_bytes());
231 out.extend_from_slice(&group_index.to_le_bytes());
232 }
233 }
234
235 let group_reprs = serialize_elements(
237 self.group_elements
238 .iter()
239 .map(|(_, elem)| elem.expect("expected group variable to be assigned")),
240 );
241 out.extend_from_slice(&group_reprs);
242
243 out
244 }
245
246 pub fn from_label(data: &[u8]) -> Result<Self, Error> {
262 use crate::errors::InvalidInstance;
263 use crate::group::serialization::group_elt_serialized_len;
264
265 let mut offset = 0;
266
267 if data.len() < 4 {
269 return Err(InvalidInstance::new("Invalid label: too short for equation count").into());
270 }
271 let num_equations = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
272 offset += 4;
273
274 let mut constraint_data = Vec::new();
276 let mut max_scalar_index = 0u32;
277 let mut max_group_index = 0u32;
278
279 for _ in 0..num_equations {
280 if offset + 4 > data.len() {
282 return Err(InvalidInstance::new("Invalid label: truncated LHS index").into());
283 }
284 let lhs_index = u32::from_le_bytes([
285 data[offset],
286 data[offset + 1],
287 data[offset + 2],
288 data[offset + 3],
289 ]);
290 offset += 4;
291 max_group_index = max_group_index.max(lhs_index);
292
293 if offset + 4 > data.len() {
295 return Err(InvalidInstance::new("Invalid label: truncated RHS count").into());
296 }
297 let num_rhs_terms = u32::from_le_bytes([
298 data[offset],
299 data[offset + 1],
300 data[offset + 2],
301 data[offset + 3],
302 ]) as usize;
303 offset += 4;
304
305 let mut rhs_terms = Vec::new();
307 for _ in 0..num_rhs_terms {
308 if offset + 4 > data.len() {
310 return Err(
311 InvalidInstance::new("Invalid label: truncated scalar index").into(),
312 );
313 }
314 let scalar_index = u32::from_le_bytes([
315 data[offset],
316 data[offset + 1],
317 data[offset + 2],
318 data[offset + 3],
319 ]);
320 offset += 4;
321 max_scalar_index = max_scalar_index.max(scalar_index);
322
323 if offset + 4 > data.len() {
325 return Err(InvalidInstance::new("Invalid label: truncated group index").into());
326 }
327 let group_index = u32::from_le_bytes([
328 data[offset],
329 data[offset + 1],
330 data[offset + 2],
331 data[offset + 3],
332 ]);
333 offset += 4;
334 max_group_index = max_group_index.max(group_index);
335
336 rhs_terms.push((scalar_index, group_index));
337 }
338
339 constraint_data.push((lhs_index, rhs_terms));
340 }
341
342 let num_group_elements = (max_group_index + 1) as usize;
344 let group_element_size = group_elt_serialized_len::<G>();
345 let expected_remaining = num_group_elements * group_element_size;
346
347 if data.len() - offset != expected_remaining {
348 return Err(InvalidInstance::new(format!(
349 "Invalid label: expected {} bytes for {} group elements, got {}",
350 expected_remaining,
351 num_group_elements,
352 data.len() - offset
353 ))
354 .into());
355 }
356
357 let mut group_elements_ordered = Vec::new();
359 for i in 0..num_group_elements {
360 let start = offset + i * group_element_size;
361 let end = start + group_element_size;
362 let elem_bytes = &data[start..end];
363
364 let mut repr = G::Repr::default();
365 repr.as_mut().copy_from_slice(elem_bytes);
366
367 let elem = Option::<G>::from(G::from_bytes(&repr)).ok_or_else(|| {
368 Error::from(InvalidInstance::new(format!(
369 "Invalid group element at index {i}"
370 )))
371 })?;
372
373 group_elements_ordered.push(elem);
374 }
375
376 let mut canonical = Self::new();
378 canonical.num_scalars = (max_scalar_index + 1) as usize;
379
380 let mut group_var_map = Vec::new();
382 for elem in &group_elements_ordered {
383 let var = canonical.group_elements.push(*elem);
384 group_var_map.push(var);
385 }
386
387 for (lhs_index, rhs_terms) in constraint_data {
389 canonical.image.push(group_var_map[lhs_index as usize]);
391
392 let mut linear_combination = Vec::new();
394 for (scalar_index, group_index) in rhs_terms {
395 let scalar_var = ScalarVar(scalar_index as usize, PhantomData);
396 let group_var = group_var_map[group_index as usize];
397 linear_combination.push((scalar_var, group_var));
398 }
399 canonical.linear_combinations.push(linear_combination);
400 }
401
402 Ok(canonical)
403 }
404
405 pub(crate) fn image_elements(&self) -> impl Iterator<Item = G> + use<'_, G> {
408 self.image.iter().map(|var| {
409 self.group_elements
410 .get(*var)
411 .expect("expected group variable to be assigned")
412 })
413 }
414}
415
416impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for CanonicalLinearRelation<G> {
417 type Error = InvalidInstance;
418
419 fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
420 Self::try_from(&value)
421 }
422}
423
424impl<G: PrimeGroup + MultiScalarMul> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
425 type Error = InvalidInstance;
426
427 fn try_from(relation: &LinearRelation<G>) -> Result<Self, Self::Error> {
428 if relation.image.len() != relation.linear_map.linear_combinations.len() {
429 return Err(InvalidInstance::new(
430 "Number of equations must be equal to number of image elements.",
431 ));
432 }
433
434 let mut canonical = CanonicalLinearRelation::new();
435 canonical.num_scalars = relation.linear_map.num_scalars;
436
437 let mut weighted_group_cache = Vec::new();
439
440 for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_map.linear_combinations) {
442 let lhs_value = relation.linear_map.group_elements.get(*lhs)?;
444
445 let rhs_constant_terms = rhs
448 .0
449 .iter()
450 .filter(|term| matches!(term.term.scalar, ScalarTerm::Unit))
451 .map(|term| {
452 let elem = relation.linear_map.group_elements.get(term.term.elem)?;
453 let scalar = term.weight;
454 Ok((elem, scalar))
455 })
456 .collect::<Result<(Vec<G>, Vec<G::Scalar>), _>>()?;
457
458 let rhs_constant_term = G::msm(&rhs_constant_terms.1, &rhs_constant_terms.0);
459
460 let is_trivial = rhs.0.iter().all(|term| {
463 matches!(term.term.scalar, ScalarTerm::Unit) || term.weight.is_zero_vartime()
464 });
465
466 let is_homogenous = rhs_constant_term == lhs_value;
468
469 if is_trivial && is_homogenous {
472 continue;
473 }
474
475 if !is_trivial && is_homogenous {
477 return Err(InvalidInstance::new("Trivial kernel in this relation"));
478 }
479
480 canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?;
481 }
482
483 Ok(canonical)
484 }
485}
486
487impl<G: PrimeGroup + ConstantTimeEq + MultiScalarMul> CanonicalLinearRelation<G> {
488 pub fn is_witness_valid(&self, witness: &[G::Scalar]) -> Choice {
497 let got = self.evaluate(witness);
498 self.image_elements()
499 .zip(got)
500 .fold(Choice::from(1), |acc, (lhs, rhs)| acc & lhs.ct_eq(&rhs))
501 }
502}