// sigma_proofs/linear_relation/canonical.rs

1use alloc::format;
2use alloc::vec::Vec;
3use core::iter;
4use core::marker::PhantomData;
5
6use ff::Field;
7use group::prime::PrimeGroup;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::{GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarTerm, ScalarVar};
11use crate::errors::{Error, InvalidInstance};
12use crate::group::msm::MultiScalarMul;
13use crate::serialization::serialize_elements;
14
15/// A [`LinearRelation`] in canonical form, compatible with the IETF spec.
16///
17/// This relation is type-safe:
18/// it can be instantiated only if all group vars are assigned,
19/// size match, and the relation is not trivially false.
20///
21/// This struct represents a normalized form of a linear relation where each
22/// constraint is of the form: image_i = Σ (scalar_j * group_element_k)
23/// without weights or extra scalars.
24#[derive(Clone, Debug, Default)]
25pub struct CanonicalLinearRelation<G: PrimeGroup> {
26    /// The image group elements (left-hand side of equations)
27    pub image: Vec<GroupVar<G>>,
28    /// The constraints, where each constraint is a vector of (scalar_var, group_var) pairs
29    /// representing the right-hand side of the equation
30    pub linear_combinations: Vec<Vec<(ScalarVar<G>, GroupVar<G>)>>,
31    /// The group elements map
32    pub group_elements: GroupMap<G>,
33    /// Number of scalar variables
34    pub num_scalars: usize,
35}
36
37/// Private type alias used to simplify function signatures below.
38///
39/// The cache maps each `GroupVar` index to a list of `(weight, canonical_group_var)` pairs.
40type WeightedGroupCache<G> = Vec<Vec<(<G as group::Group>::Scalar, GroupVar<G>)>>;
41
42impl<G: PrimeGroup> CanonicalLinearRelation<G> {
43    /// Create a new empty canonical linear relation.
44    ///
45    /// This function is not meant to be publicly exposed. It is internally used to build a type-safe linear relation,
46    /// so that all instances guaranteed to be "good" relations over which the prover will want to make a proof.
47    fn new() -> Self {
48        Self {
49            image: Vec::new(),
50            linear_combinations: Vec::new(),
51            group_elements: GroupMap::default(),
52            num_scalars: 0,
53        }
54    }
55
56    /// Evaluate the canonical linear relation with the provided scalars
57    ///
58    /// This returns a list of image points produced by evaluating each linear combination in the
59    /// relation. The order of the returned list matches the order of [`Self::linear_combinations`].
60    ///
61    /// # Panic
62    ///
63    /// Panics if the number of scalars given is less than the number of scalar variables in this
64    /// linear relation.
65    /// If the vector of scalars if longer than the number of terms in each linear combinations, the extra terms are ignored.
66    pub fn evaluate(&self, scalars: &[G::Scalar]) -> Vec<G>
67    where
68        G: MultiScalarMul,
69    {
70        self.linear_combinations
71            .iter()
72            .map(|lc| {
73                let scalars = lc
74                    .iter()
75                    .map(|(scalar_var, _)| scalars[scalar_var.index()])
76                    .collect::<Vec<_>>();
77                let bases = lc
78                    .iter()
79                    .map(|(_, group_var)| self.group_elements.get(*group_var).unwrap())
80                    .collect::<Vec<_>>();
81                G::msm(&scalars, &bases)
82            })
83            .collect()
84    }
85
86    /// Get or create a GroupVar for a weighted group element, with deduplication
87    fn get_or_create_weighted_group_var(
88        &mut self,
89        group_var: GroupVar<G>,
90        weight: &G::Scalar,
91        original_group_elements: &GroupMap<G>,
92        weighted_group_cache: &mut WeightedGroupCache<G>,
93    ) -> Result<GroupVar<G>, InvalidInstance> {
94        // Check if we already have this (weight, group_var) combination.
95        let index = group_var.index();
96        if weighted_group_cache.len() <= index {
97            weighted_group_cache.resize_with(index + 1, Vec::new);
98        }
99        let entry = &mut weighted_group_cache[index];
100
101        // Find if we already have this weight for this group_var
102        if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) {
103            return Ok(*existing_var);
104        }
105
106        // Create new weighted group element
107        // Use a special case for one, as this is the most common weight.
108        let original_group_val = original_group_elements.get(group_var)?;
109        let weighted_group = match *weight == G::Scalar::ONE {
110            true => original_group_val,
111            false => original_group_val * weight,
112        };
113
114        // Add to our group elements with new index (length)
115        let new_var = self.group_elements.push(weighted_group);
116
117        // Cache the mapping for this group_var and weight
118        entry.push((*weight, new_var));
119
120        Ok(new_var)
121    }
122
123    /// Process a single constraint equation and add it to the canonical relation.
124    fn process_constraint(
125        &mut self,
126        &image_var: &GroupVar<G>,
127        equation: &LinearCombination<G>,
128        original_relation: &LinearRelation<G>,
129        weighted_group_cache: &mut WeightedGroupCache<G>,
130    ) -> Result<(), InvalidInstance> {
131        let mut rhs_terms = Vec::new();
132
133        // Collect RHS terms that have scalar variables and apply weights
134        for weighted_term in equation.terms() {
135            if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar {
136                let group_var = weighted_term.term.elem;
137                let weight = &weighted_term.weight;
138
139                if weight.is_zero_vartime() {
140                    continue; // Skip zero weights
141                }
142
143                let canonical_group_var = self.get_or_create_weighted_group_var(
144                    group_var,
145                    weight,
146                    &original_relation.linear_map.group_elements,
147                    weighted_group_cache,
148                )?;
149
150                rhs_terms.push((scalar_var, canonical_group_var));
151            }
152        }
153
154        // Compute the canonical image by subtracting constant terms from the original image
155        let mut canonical_image = original_relation.linear_map.group_elements.get(image_var)?;
156        for weighted_term in equation.terms() {
157            if let ScalarTerm::Unit = weighted_term.term.scalar {
158                let group_val = original_relation
159                    .linear_map
160                    .group_elements
161                    .get(weighted_term.term.elem)?;
162                canonical_image -= group_val * weighted_term.weight;
163            }
164        }
165
166        // Only include constraints that are non-trivial (not zero constraints).
167        #[expect(clippy::collapsible_if)]
168        if rhs_terms.is_empty() {
169            if canonical_image.is_identity().into() {
170                return Ok(());
171            }
172            // In this location, we have determined that the constraint is trivially false.
173            // If the constraint is added to the relation, proving will always fail for this
174            // constraint. A composed relation containing a trivially false constraint in an OR
175            // branch may still be provable.
176            //
177            // TODO: In this case, we can optimize and improve error reporting by having this
178            // library special-case trvially false statements.
179            // One approach would be to return an error here and handle it in the OR composition.
180        }
181
182        let canonical_image_group_var = self.group_elements.push(canonical_image);
183        self.image.push(canonical_image_group_var);
184        self.linear_combinations.push(rhs_terms);
185
186        Ok(())
187    }
188
189    /// Serialize the linear relation to bytes.
190    ///
191    /// The output format is:
192    ///
193    /// - `[Ne: u32]` number of equations
194    /// - `Ne × equations`:
195    ///   - `[lhs_index: u32]` output group element index
196    ///   - `[Nt: u32]` number of terms
197    ///   - `Nt × [scalar_index: u32, group_index: u32]` term entries
198    /// - All group elements in serialized form.
199    pub fn label(&self) -> Vec<u8> {
200        let mut out = Vec::new();
201
202        // Build constraint data in the same order as original, as a nested list of group and
203        // scalar indices. Note that the group indices are into group_elements_ordered.
204        let mut constraint_data = Vec::<(u32, Vec<(u32, u32)>)>::new();
205
206        for (image_var, constraint_terms) in iter::zip(&self.image, &self.linear_combinations) {
207            // Build the RHS terms
208            let mut rhs_terms = Vec::new();
209            for (scalar_var, group_var) in constraint_terms {
210                rhs_terms.push((scalar_var.0 as u32, group_var.0 as u32));
211            }
212
213            constraint_data.push((image_var.0 as u32, rhs_terms));
214        }
215
216        // 1. Number of equations
217        let ne = constraint_data.len();
218        out.extend_from_slice(&(ne as u32).to_le_bytes());
219
220        // 2. Encode each equation
221        for (lhs_index, rhs_terms) in constraint_data {
222            // a. Output point index (LHS)
223            out.extend_from_slice(&lhs_index.to_le_bytes());
224
225            // b. Number of terms in the RHS linear combination
226            out.extend_from_slice(&(rhs_terms.len() as u32).to_le_bytes());
227
228            // c. Each term: scalar index and point index
229            for (scalar_index, group_index) in rhs_terms {
230                out.extend_from_slice(&scalar_index.to_le_bytes());
231                out.extend_from_slice(&group_index.to_le_bytes());
232            }
233        }
234
235        // Dump the group elements.
236        let group_reprs = serialize_elements(
237            self.group_elements
238                .iter()
239                .map(|(_, elem)| elem.expect("expected group variable to be assigned")),
240        );
241        out.extend_from_slice(&group_reprs);
242
243        out
244    }
245
246    /// Parse a canonical linear relation from its label representation.
247    ///
248    /// Returns an [`InvalidInstance`] error if the label is malformed.
249    ///
250    /// # Examples
251    ///
252    /// ```
253    /// use hex_literal::hex;
254    /// use sigma_proofs::linear_relation::CanonicalLinearRelation;
255    /// type G = bls12_381::G1Projective;
256    ///
257    /// let dlog_instance_label = hex!("01000000000000000100000000000000010000009823a3def60a6e07fb25feb35f211ee2cbc9c130c1959514f5df6b5021a2b21a4c973630ec2090c733c1fe791834ce1197f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb");
258    /// let instance = CanonicalLinearRelation::<G>::from_label(&dlog_instance_label).unwrap();
259    /// assert_eq!(&dlog_instance_label[..], &instance.label()[..]);
260    /// ```
261    pub fn from_label(data: &[u8]) -> Result<Self, Error> {
262        use crate::errors::InvalidInstance;
263        use crate::group::serialization::group_elt_serialized_len;
264
265        let mut offset = 0;
266
267        // Read number of equations (4 bytes, little endian)
268        if data.len() < 4 {
269            return Err(InvalidInstance::new("Invalid label: too short for equation count").into());
270        }
271        let num_equations = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
272        offset += 4;
273
274        // Parse constraints and collect unique group element indices
275        let mut constraint_data = Vec::new();
276        let mut max_scalar_index = 0u32;
277        let mut max_group_index = 0u32;
278
279        for _ in 0..num_equations {
280            // Read LHS index (4 bytes)
281            if offset + 4 > data.len() {
282                return Err(InvalidInstance::new("Invalid label: truncated LHS index").into());
283            }
284            let lhs_index = u32::from_le_bytes([
285                data[offset],
286                data[offset + 1],
287                data[offset + 2],
288                data[offset + 3],
289            ]);
290            offset += 4;
291            max_group_index = max_group_index.max(lhs_index);
292
293            // Read number of RHS terms (4 bytes)
294            if offset + 4 > data.len() {
295                return Err(InvalidInstance::new("Invalid label: truncated RHS count").into());
296            }
297            let num_rhs_terms = u32::from_le_bytes([
298                data[offset],
299                data[offset + 1],
300                data[offset + 2],
301                data[offset + 3],
302            ]) as usize;
303            offset += 4;
304
305            // Read RHS terms
306            let mut rhs_terms = Vec::new();
307            for _ in 0..num_rhs_terms {
308                // Read scalar index (4 bytes)
309                if offset + 4 > data.len() {
310                    return Err(
311                        InvalidInstance::new("Invalid label: truncated scalar index").into(),
312                    );
313                }
314                let scalar_index = u32::from_le_bytes([
315                    data[offset],
316                    data[offset + 1],
317                    data[offset + 2],
318                    data[offset + 3],
319                ]);
320                offset += 4;
321                max_scalar_index = max_scalar_index.max(scalar_index);
322
323                // Read group index (4 bytes)
324                if offset + 4 > data.len() {
325                    return Err(InvalidInstance::new("Invalid label: truncated group index").into());
326                }
327                let group_index = u32::from_le_bytes([
328                    data[offset],
329                    data[offset + 1],
330                    data[offset + 2],
331                    data[offset + 3],
332                ]);
333                offset += 4;
334                max_group_index = max_group_index.max(group_index);
335
336                rhs_terms.push((scalar_index, group_index));
337            }
338
339            constraint_data.push((lhs_index, rhs_terms));
340        }
341
342        // Calculate expected number of group elements
343        let num_group_elements = (max_group_index + 1) as usize;
344        let group_element_size = group_elt_serialized_len::<G>();
345        let expected_remaining = num_group_elements * group_element_size;
346
347        if data.len() - offset != expected_remaining {
348            return Err(InvalidInstance::new(format!(
349                "Invalid label: expected {} bytes for {} group elements, got {}",
350                expected_remaining,
351                num_group_elements,
352                data.len() - offset
353            ))
354            .into());
355        }
356
357        // Parse group elements
358        let mut group_elements_ordered = Vec::new();
359        for i in 0..num_group_elements {
360            let start = offset + i * group_element_size;
361            let end = start + group_element_size;
362            let elem_bytes = &data[start..end];
363
364            let mut repr = G::Repr::default();
365            repr.as_mut().copy_from_slice(elem_bytes);
366
367            let elem = Option::<G>::from(G::from_bytes(&repr)).ok_or_else(|| {
368                Error::from(InvalidInstance::new(format!(
369                    "Invalid group element at index {i}"
370                )))
371            })?;
372
373            group_elements_ordered.push(elem);
374        }
375
376        // Build the canonical relation
377        let mut canonical = Self::new();
378        canonical.num_scalars = (max_scalar_index + 1) as usize;
379
380        // Add all group elements to the map
381        let mut group_var_map = Vec::new();
382        for elem in &group_elements_ordered {
383            let var = canonical.group_elements.push(*elem);
384            group_var_map.push(var);
385        }
386
387        // Build constraints
388        for (lhs_index, rhs_terms) in constraint_data {
389            // Add image element
390            canonical.image.push(group_var_map[lhs_index as usize]);
391
392            // Build linear combination
393            let mut linear_combination = Vec::new();
394            for (scalar_index, group_index) in rhs_terms {
395                let scalar_var = ScalarVar(scalar_index as usize, PhantomData);
396                let group_var = group_var_map[group_index as usize];
397                linear_combination.push((scalar_var, group_var));
398            }
399            canonical.linear_combinations.push(linear_combination);
400        }
401
402        Ok(canonical)
403    }
404
405    /// Access the group elements associated with the image (i.e. left-hand side), panicking if any
406    /// of the image variables are unassigned in the group mkap.
407    pub(crate) fn image_elements(&self) -> impl Iterator<Item = G> + use<'_, G> {
408        self.image.iter().map(|var| {
409            self.group_elements
410                .get(*var)
411                .expect("expected group variable to be assigned")
412        })
413    }
414}
415
416impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for CanonicalLinearRelation<G> {
417    type Error = InvalidInstance;
418
419    fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
420        Self::try_from(&value)
421    }
422}
423
424impl<G: PrimeGroup + MultiScalarMul> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
425    type Error = InvalidInstance;
426
427    fn try_from(relation: &LinearRelation<G>) -> Result<Self, Self::Error> {
428        if relation.image.len() != relation.linear_map.linear_combinations.len() {
429            return Err(InvalidInstance::new(
430                "Number of equations must be equal to number of image elements.",
431            ));
432        }
433
434        let mut canonical = CanonicalLinearRelation::new();
435        canonical.num_scalars = relation.linear_map.num_scalars;
436
437        // Cache for deduplicating weighted group elements.
438        let mut weighted_group_cache = Vec::new();
439
440        // Process each constraint using the modular helper method
441        for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_map.linear_combinations) {
442            // If any group element in the image is not assigned, return `InvalidInstance`.
443            let lhs_value = relation.linear_map.group_elements.get(*lhs)?;
444
445            // Compute the constant terms on the right-hand side of the equation.
446            // If any group element in the linear constraints is not assigned, return `InvalidInstance`.
447            let rhs_constant_terms = rhs
448                .0
449                .iter()
450                .filter(|term| matches!(term.term.scalar, ScalarTerm::Unit))
451                .map(|term| {
452                    let elem = relation.linear_map.group_elements.get(term.term.elem)?;
453                    let scalar = term.weight;
454                    Ok((elem, scalar))
455                })
456                .collect::<Result<(Vec<G>, Vec<G::Scalar>), _>>()?;
457
458            let rhs_constant_term = G::msm(&rhs_constant_terms.1, &rhs_constant_terms.0);
459
460            // We say that an equation is trivial if it contains no scalar variables.
461            // To "contain no scalar variables" means that each term in the right-hand side is a unit or its weight is zero.
462            let is_trivial = rhs.0.iter().all(|term| {
463                matches!(term.term.scalar, ScalarTerm::Unit) || term.weight.is_zero_vartime()
464            });
465
466            // We say that an equation is homogenous if the constant term is zero.
467            let is_homogenous = rhs_constant_term == lhs_value;
468
469            // Skip processing trivial equations that are always true.
470            // There's nothing to prove here.
471            if is_trivial && is_homogenous {
472                continue;
473            }
474
475            // Disallow non-trivial equations with trivial solutions.
476            if !is_trivial && is_homogenous {
477                return Err(InvalidInstance::new("Trivial kernel in this relation"));
478            }
479
480            canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?;
481        }
482
483        Ok(canonical)
484    }
485}
486
487impl<G: PrimeGroup + ConstantTimeEq + MultiScalarMul> CanonicalLinearRelation<G> {
488    /// Tests is the witness is valid.
489    ///
490    /// Returns a [`Choice`] indicating if the witness is valid for the instance constructed.
491    ///
492    /// # Panic
493    ///
494    /// Panics if the number of scalars given is less than the number of scalar variables.
495    /// If the number of scalars is more than the number of scalar variables, the extra elements are ignored.
496    pub fn is_witness_valid(&self, witness: &[G::Scalar]) -> Choice {
497        let got = self.evaluate(witness);
498        self.image_elements()
499            .zip(got)
500            .fold(Choice::from(1), |acc, (lhs, rhs)| acc & lhs.ct_eq(&rhs))
501    }
502}