Skip to main content

sigma_proofs/linear_relation/
canonical.rs

1use alloc::format;
2use alloc::vec::Vec;
3use core::iter;
4use core::marker::PhantomData;
5
6use ff::Field;
7use group::prime::PrimeGroup;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::{GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarTerm, ScalarVar};
11use crate::errors::{Error, InvalidInstance};
12use crate::group::msm::MultiScalarMul;
13
/// A [`LinearRelation`] in canonical form, compatible with the IETF spec.
///
/// Each constraint has the form `image_i = Σ (scalar_j * group_element_k)`: there are no
/// per-term weights (they are folded into the group elements) and no constant terms (they are
/// moved into the image).
///
/// When built through `TryFrom<&LinearRelation<G>>`, the conversion checks that:
/// - the number of image elements equals the number of equations;
/// - every group element it reads is assigned;
/// - no equation containing scalar variables is satisfied by the all-zero witness.
///
/// Equations that always hold are dropped. Equations with no scalar variables and a
/// non-identity image ("trivially false") are kept; proving such a relation will fail.
///
/// Because the fields are public, values constructed by hand are not checked by the type
/// itself.
#[derive(Clone, Debug, Default)]
pub struct CanonicalLinearRelation<G: PrimeGroup> {
    /// The image group elements (left-hand side of equations), in the same order as
    /// `linear_combinations`.
    pub image: Vec<GroupVar<G>>,
    /// The constraints, where each constraint is a vector of (scalar_var, group_var) pairs
    /// representing the right-hand side of the equation
    pub linear_combinations: Vec<Vec<(ScalarVar<G>, GroupVar<G>)>>,
    /// The group elements map (holds both image and right-hand-side elements)
    pub group_elements: GroupMap<G>,
    /// Number of scalar variables
    pub num_scalars: usize,
}
35
/// Private type alias used to simplify function signatures below.
///
/// The outer vector is indexed by the original `GroupVar` index. Each entry lists the
/// `(weight, canonical_group_var)` pairs already created for that group variable, so that a
/// repeated `(group_var, weight)` combination reuses the same canonical group element instead
/// of pushing a duplicate.
type WeightedGroupCache<G> = Vec<Vec<(<G as group::Group>::Scalar, GroupVar<G>)>>;
40
41impl<G: PrimeGroup> CanonicalLinearRelation<G> {
42    /// Create a new empty canonical linear relation.
43    ///
44    /// This function is not meant to be publicly exposed. It is internally used to build a type-safe linear relation,
45    /// so that all instances guaranteed to be "good" relations over which the prover will want to make a proof.
46    fn new() -> Self {
47        Self {
48            image: Vec::new(),
49            linear_combinations: Vec::new(),
50            group_elements: GroupMap::default(),
51            num_scalars: 0,
52        }
53    }
54
55    /// Evaluate the canonical linear relation with the provided scalars
56    ///
57    /// This returns a list of image points produced by evaluating each linear combination in the
58    /// relation. The order of the returned list matches the order of [`Self::linear_combinations`].
59    ///
60    /// # Panic
61    ///
62    /// Panics if the number of scalars given is less than the number of scalar variables in this
63    /// linear relation.
64    /// If the vector of scalars if longer than the number of terms in each linear combinations, the extra terms are ignored.
65    pub fn evaluate(&self, scalars: &[G::Scalar]) -> Vec<G>
66    where
67        G: MultiScalarMul,
68    {
69        self.linear_combinations
70            .iter()
71            .map(|lc| {
72                let scalars = lc
73                    .iter()
74                    .map(|(scalar_var, _)| scalars[scalar_var.index()])
75                    .collect::<Vec<_>>();
76                let bases = lc
77                    .iter()
78                    .map(|(_, group_var)| self.group_elements.get(*group_var).unwrap())
79                    .collect::<Vec<_>>();
80                G::msm(&scalars, &bases)
81            })
82            .collect()
83    }
84
85    /// Get or create a GroupVar for a weighted group element, with deduplication
86    fn get_or_create_weighted_group_var(
87        &mut self,
88        group_var: GroupVar<G>,
89        weight: &G::Scalar,
90        original_group_elements: &GroupMap<G>,
91        weighted_group_cache: &mut WeightedGroupCache<G>,
92    ) -> Result<GroupVar<G>, InvalidInstance> {
93        // Check if we already have this (weight, group_var) combination.
94        let index = group_var.index();
95        if weighted_group_cache.len() <= index {
96            weighted_group_cache.resize_with(index + 1, Vec::new);
97        }
98        let entry = &mut weighted_group_cache[index];
99
100        // Find if we already have this weight for this group_var
101        if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) {
102            return Ok(*existing_var);
103        }
104
105        // Create new weighted group element
106        // Use a special case for one, as this is the most common weight.
107        let original_group_val = original_group_elements.get(group_var)?;
108        let weighted_group = match *weight == G::Scalar::ONE {
109            true => original_group_val,
110            false => original_group_val * weight,
111        };
112
113        // Add to our group elements with new index (length)
114        let new_var = self.group_elements.push(weighted_group);
115
116        // Cache the mapping for this group_var and weight
117        entry.push((*weight, new_var));
118
119        Ok(new_var)
120    }
121
122    /// Process a single constraint equation and add it to the canonical relation.
123    fn process_constraint(
124        &mut self,
125        &image_var: &GroupVar<G>,
126        equation: &LinearCombination<G>,
127        original_relation: &LinearRelation<G>,
128        weighted_group_cache: &mut WeightedGroupCache<G>,
129    ) -> Result<(), InvalidInstance> {
130        let mut rhs_terms = Vec::new();
131
132        // Collect RHS terms that have scalar variables and apply weights
133        for weighted_term in equation.terms() {
134            if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar {
135                let group_var = weighted_term.term.elem;
136                let weight = &weighted_term.weight;
137
138                if weight.is_zero_vartime() {
139                    continue; // Skip zero weights
140                }
141
142                let canonical_group_var = self.get_or_create_weighted_group_var(
143                    group_var,
144                    weight,
145                    &original_relation.linear_map.group_elements,
146                    weighted_group_cache,
147                )?;
148
149                rhs_terms.push((scalar_var, canonical_group_var));
150            }
151        }
152
153        // Compute the canonical image by subtracting constant terms from the original image
154        let mut canonical_image = original_relation.linear_map.group_elements.get(image_var)?;
155        for weighted_term in equation.terms() {
156            if let ScalarTerm::Unit = weighted_term.term.scalar {
157                let group_val = original_relation
158                    .linear_map
159                    .group_elements
160                    .get(weighted_term.term.elem)?;
161                canonical_image -= group_val * weighted_term.weight;
162            }
163        }
164
165        // Only include constraints that are non-trivial (not zero constraints).
166        #[expect(clippy::collapsible_if)]
167        if rhs_terms.is_empty() {
168            if canonical_image.is_identity().into() {
169                return Ok(());
170            }
171            // In this location, we have determined that the constraint is trivially false.
172            // If the constraint is added to the relation, proving will always fail for this
173            // constraint. A composed relation containing a trivially false constraint in an OR
174            // branch may still be provable.
175            //
176            // TODO: In this case, we can optimize and improve error reporting by having this
177            // library special-case trvially false statements.
178            // One approach would be to return an error here and handle it in the OR composition.
179        }
180
181        let canonical_image_group_var = self.group_elements.push(canonical_image);
182        self.image.push(canonical_image_group_var);
183        self.linear_combinations.push(rhs_terms);
184
185        Ok(())
186    }
187
188    /// Serialize the linear relation to bytes.
189    ///
190    /// The output format is:
191    ///
192    /// - `[Ne: u32]` number of equations
193    /// - `Ne × equations`:
194    ///   - `[lhs_index: u32]` output group element index
195    ///   - `[Nt: u32]` number of terms
196    ///   - `Nt × [scalar_index: u32, group_index: u32]` term entries
197    /// - All group elements in serialized form.
198    pub fn label(&self) -> Vec<u8> {
199        let mut out = Vec::new();
200
201        // Build constraint data in the same order as original, as a nested list of group and
202        // scalar indices. Note that the group indices are into group_elements_ordered.
203        let mut constraint_data = Vec::<(u32, Vec<(u32, u32)>)>::new();
204
205        for (image_var, constraint_terms) in iter::zip(&self.image, &self.linear_combinations) {
206            // Build the RHS terms
207            let mut rhs_terms = Vec::new();
208            for (scalar_var, group_var) in constraint_terms {
209                rhs_terms.push((scalar_var.0 as u32, group_var.0 as u32));
210            }
211
212            constraint_data.push((image_var.0 as u32, rhs_terms));
213        }
214
215        // 1. Number of equations
216        let ne = constraint_data.len();
217        out.extend_from_slice(&(ne as u32).to_le_bytes());
218
219        // 2. Encode each equation
220        for (lhs_index, rhs_terms) in constraint_data {
221            // a. Output point index (LHS)
222            out.extend_from_slice(&lhs_index.to_le_bytes());
223
224            // b. Number of terms in the RHS linear combination
225            out.extend_from_slice(&(rhs_terms.len() as u32).to_le_bytes());
226
227            // c. Each term: scalar index and point index
228            for (scalar_index, group_index) in rhs_terms {
229                out.extend_from_slice(&scalar_index.to_le_bytes());
230                out.extend_from_slice(&group_index.to_le_bytes());
231            }
232        }
233
234        // Dump the group elements.
235        for (_, elem) in self.group_elements.iter() {
236            out.extend_from_slice(
237                elem.expect("expected group variable to be assigned")
238                    .to_bytes()
239                    .as_ref(),
240            );
241        }
242
243        out
244    }
245
246    /// Parse a canonical linear relation from its label representation.
247    ///
248    /// Returns an [`InvalidInstance`] error if the label is malformed.
249    ///
250    /// # Examples
251    ///
252    /// ```
253    /// use hex_literal::hex;
254    /// use sigma_proofs::linear_relation::CanonicalLinearRelation;
255    /// type G = bls12_381::G1Projective;
256    ///
257    /// let dlog_instance_label = hex!("01000000000000000100000000000000010000009823a3def60a6e07fb25feb35f211ee2cbc9c130c1959514f5df6b5021a2b21a4c973630ec2090c733c1fe791834ce1197f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb");
258    /// let instance = CanonicalLinearRelation::<G>::from_label(&dlog_instance_label).unwrap();
259    /// assert_eq!(&dlog_instance_label[..], &instance.label()[..]);
260    /// ```
261    pub fn from_label(data: &[u8]) -> Result<Self, Error> {
262        use crate::errors::InvalidInstance;
263
264        let mut offset = 0;
265
266        // Read number of equations (4 bytes, little endian)
267        if data.len() < 4 {
268            return Err(InvalidInstance::new("Invalid label: too short for equation count").into());
269        }
270        let num_equations = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
271        offset += 4;
272
273        // Parse constraints and collect unique group element indices
274        let mut constraint_data = Vec::new();
275        let mut max_scalar_index = 0u32;
276        let mut max_group_index = 0u32;
277
278        for _ in 0..num_equations {
279            // Read LHS index (4 bytes)
280            if offset + 4 > data.len() {
281                return Err(InvalidInstance::new("Invalid label: truncated LHS index").into());
282            }
283            let lhs_index = u32::from_le_bytes([
284                data[offset],
285                data[offset + 1],
286                data[offset + 2],
287                data[offset + 3],
288            ]);
289            offset += 4;
290            max_group_index = max_group_index.max(lhs_index);
291
292            // Read number of RHS terms (4 bytes)
293            if offset + 4 > data.len() {
294                return Err(InvalidInstance::new("Invalid label: truncated RHS count").into());
295            }
296            let num_rhs_terms = u32::from_le_bytes([
297                data[offset],
298                data[offset + 1],
299                data[offset + 2],
300                data[offset + 3],
301            ]) as usize;
302            offset += 4;
303
304            // Read RHS terms
305            let mut rhs_terms = Vec::new();
306            for _ in 0..num_rhs_terms {
307                // Read scalar index (4 bytes)
308                if offset + 4 > data.len() {
309                    return Err(
310                        InvalidInstance::new("Invalid label: truncated scalar index").into(),
311                    );
312                }
313                let scalar_index = u32::from_le_bytes([
314                    data[offset],
315                    data[offset + 1],
316                    data[offset + 2],
317                    data[offset + 3],
318                ]);
319                offset += 4;
320                max_scalar_index = max_scalar_index.max(scalar_index);
321
322                // Read group index (4 bytes)
323                if offset + 4 > data.len() {
324                    return Err(InvalidInstance::new("Invalid label: truncated group index").into());
325                }
326                let group_index = u32::from_le_bytes([
327                    data[offset],
328                    data[offset + 1],
329                    data[offset + 2],
330                    data[offset + 3],
331                ]);
332                offset += 4;
333                max_group_index = max_group_index.max(group_index);
334
335                rhs_terms.push((scalar_index, group_index));
336            }
337
338            constraint_data.push((lhs_index, rhs_terms));
339        }
340
341        // Calculate expected number of group elements
342        let num_group_elements = (max_group_index + 1) as usize;
343        let group_element_size = G::Repr::default().as_ref().len();
344        let expected_remaining = num_group_elements * group_element_size;
345
346        if data.len() - offset != expected_remaining {
347            return Err(InvalidInstance::new(format!(
348                "Invalid label: expected {} bytes for {} group elements, got {}",
349                expected_remaining,
350                num_group_elements,
351                data.len() - offset
352            ))
353            .into());
354        }
355
356        // Parse group elements
357        let mut group_elements_ordered = Vec::new();
358        for i in 0..num_group_elements {
359            let start = offset + i * group_element_size;
360            let end = start + group_element_size;
361            let elem_bytes = &data[start..end];
362
363            let mut repr = G::Repr::default();
364            repr.as_mut().copy_from_slice(elem_bytes);
365
366            let elem = Option::<G>::from(G::from_bytes(&repr)).ok_or_else(|| {
367                Error::from(InvalidInstance::new(format!(
368                    "Invalid group element at index {i}"
369                )))
370            })?;
371
372            group_elements_ordered.push(elem);
373        }
374
375        // Build the canonical relation
376        let mut canonical = Self::new();
377        canonical.num_scalars = (max_scalar_index + 1) as usize;
378
379        // Add all group elements to the map
380        let mut group_var_map = Vec::new();
381        for elem in &group_elements_ordered {
382            let var = canonical.group_elements.push(*elem);
383            group_var_map.push(var);
384        }
385
386        // Build constraints
387        for (lhs_index, rhs_terms) in constraint_data {
388            // Add image element
389            canonical.image.push(group_var_map[lhs_index as usize]);
390
391            // Build linear combination
392            let mut linear_combination = Vec::new();
393            for (scalar_index, group_index) in rhs_terms {
394                let scalar_var = ScalarVar(scalar_index as usize, PhantomData);
395                let group_var = group_var_map[group_index as usize];
396                linear_combination.push((scalar_var, group_var));
397            }
398            canonical.linear_combinations.push(linear_combination);
399        }
400
401        Ok(canonical)
402    }
403
404    /// Access the group elements associated with the image (i.e. left-hand side), panicking if any
405    /// of the image variables are unassigned in the group mkap.
406    pub(crate) fn image_elements(&self) -> impl Iterator<Item = G> + use<'_, G> {
407        self.image.iter().map(|var| {
408            self.group_elements
409                .get(*var)
410                .expect("expected group variable to be assigned")
411        })
412    }
413}
414
415impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for CanonicalLinearRelation<G> {
416    type Error = InvalidInstance;
417
418    fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
419        Self::try_from(&value)
420    }
421}
422
423impl<G: PrimeGroup + MultiScalarMul> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
424    type Error = InvalidInstance;
425
426    fn try_from(relation: &LinearRelation<G>) -> Result<Self, Self::Error> {
427        if relation.image.len() != relation.linear_map.linear_combinations.len() {
428            return Err(InvalidInstance::new(
429                "Number of equations must be equal to number of image elements.",
430            ));
431        }
432
433        let mut canonical = CanonicalLinearRelation::new();
434        canonical.num_scalars = relation.linear_map.num_scalars;
435
436        // Cache for deduplicating weighted group elements.
437        let mut weighted_group_cache = Vec::new();
438
439        // Process each constraint using the modular helper method
440        for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_map.linear_combinations) {
441            // If any group element in the image is not assigned, return `InvalidInstance`.
442            let lhs_value = relation.linear_map.group_elements.get(*lhs)?;
443
444            // Compute the constant terms on the right-hand side of the equation.
445            // If any group element in the linear constraints is not assigned, return `InvalidInstance`.
446            let rhs_constant_terms = rhs
447                .0
448                .iter()
449                .filter(|term| matches!(term.term.scalar, ScalarTerm::Unit))
450                .map(|term| {
451                    let elem = relation.linear_map.group_elements.get(term.term.elem)?;
452                    let scalar = term.weight;
453                    Ok((elem, scalar))
454                })
455                .collect::<Result<(Vec<G>, Vec<G::Scalar>), _>>()?;
456
457            let rhs_constant_term = G::msm(&rhs_constant_terms.1, &rhs_constant_terms.0);
458
459            // We say that an equation is trivial if it contains no scalar variables.
460            // To "contain no scalar variables" means that each term in the right-hand side is a unit or its weight is zero.
461            let is_trivial = rhs.0.iter().all(|term| {
462                matches!(term.term.scalar, ScalarTerm::Unit) || term.weight.is_zero_vartime()
463            });
464
465            // We say that an equation is homogenous if the constant term is zero.
466            let is_homogenous = rhs_constant_term == lhs_value;
467
468            // Skip processing trivial equations that are always true.
469            // There's nothing to prove here.
470            if is_trivial && is_homogenous {
471                continue;
472            }
473
474            // Disallow non-trivial equations with trivial solutions.
475            if !is_trivial && is_homogenous {
476                return Err(InvalidInstance::new("Trivial kernel in this relation"));
477            }
478
479            canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?;
480        }
481
482        Ok(canonical)
483    }
484}
485
486impl<G: PrimeGroup + ConstantTimeEq + MultiScalarMul> CanonicalLinearRelation<G> {
487    /// Tests is the witness is valid.
488    ///
489    /// Returns a [`Choice`] indicating if the witness is valid for the instance constructed.
490    ///
491    /// # Panic
492    ///
493    /// Panics if the number of scalars given is less than the number of scalar variables.
494    /// If the number of scalars is more than the number of scalar variables, the extra elements are ignored.
495    pub fn is_witness_valid(&self, witness: &[G::Scalar]) -> Choice {
496        let got = self.evaluate(witness);
497        self.image_elements()
498            .zip(got)
499            .fold(Choice::from(1), |acc, (lhs, rhs)| acc & lhs.ct_eq(&rhs))
500    }
501}