Skip to main content

sigma_proofs/
composition.rs

1//! # Protocol Composition with AND/OR Logic
2//!
3//! This module defines the [`ComposedRelation`] enum, which generalizes the [`CanonicalLinearRelation`]
4//! by enabling compositional logic between multiple proof instances.
5//!
6//! Specifically, it supports:
7//! - Simple atomic proofs (e.g., discrete logarithm, Pedersen commitments)
8//! - Conjunctions (`And`) of multiple sub-protocols
9//! - Disjunctions (`Or`) of multiple sub-protocols
10//! - Thresholds (`Threshold`) over multiple sub-protocols
11//!
12//! ## Example Composition
13//!
14//! ```ignore
15//! And(
16//!    Or(dleq, pedersen_commitment),
17//!    Simple(discrete_logarithm),
18//!    And(pedersen_commitment_dleq, bbs_blind_commitment_computation)
19//! )
20//! ```
21
22use alloc::{vec, vec::Vec};
23use ff::{Field, PrimeField};
24use group::prime::PrimeGroup;
25use itertools::Itertools;
26use sha3::{Digest, Sha3_256};
27use spongefish::{
28    Decoding, Encoding, NargDeserialize, NargSerialize, VerificationError, VerificationResult,
29};
30use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
31
32use crate::errors::InvalidInstance;
33use crate::traits::ScalarRng;
34use crate::MultiScalarMul;
35use crate::{
36    errors::Error,
37    fiat_shamir::Nizk,
38    linear_relation::{CanonicalLinearRelation, LinearRelation},
39    traits::{SigmaProtocol, SigmaProtocolSimulator},
40};
41
/// A protocol proving knowledge of a witness for a composition of linear relations.
///
/// This implementation generalizes [`CanonicalLinearRelation`] by combining relations with
/// AND, OR and threshold links, nested to any depth.
///
/// # Type Parameters
/// - `G`: the prime-order group the relations are defined over (a [`PrimeGroup`]).
#[derive(Clone)]
pub enum ComposedRelation<G: PrimeGroup> {
    /// A single (leaf) linear relation.
    Simple(CanonicalLinearRelation<G>),
    /// Satisfied when every sub-relation is satisfied.
    And(Vec<ComposedRelation<G>>),
    /// Satisfied when at least one sub-relation is satisfied.
    Or(Vec<ComposedRelation<G>>),
    /// `Threshold(t, relations)`: satisfied when at least `t` of `relations` are satisfied.
    Threshold(usize, Vec<ComposedRelation<G>>),
}
55
56impl<G: PrimeGroup + ConstantTimeEq + ConditionallySelectable> ComposedRelation<G> {
57    /// Create a [ComposedRelation] for an AND relation from the given list of relations.
58    pub fn and<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
59        Self::And(witness.into_iter().map(|x| x.into()).collect())
60    }
61
62    /// Create a [ComposedRelation] for an OR relation from the given list of relations.
63    pub fn or<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
64        Self::Or(witness.into_iter().map(|x| x.into()).collect())
65    }
66
67    /// Create a [ComposedRelation] for a threshold relation from the given list of relations.
68    pub fn threshold<T: Into<ComposedRelation<G>>>(
69        threshold: usize,
70        witness: impl IntoIterator<Item = T>,
71    ) -> Self {
72        Self::Threshold(threshold, witness.into_iter().map(|x| x.into()).collect())
73    }
74}
75
76impl<G: PrimeGroup> From<CanonicalLinearRelation<G>> for ComposedRelation<G> {
77    fn from(value: CanonicalLinearRelation<G>) -> Self {
78        ComposedRelation::Simple(value)
79    }
80}
81
82impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for ComposedRelation<G> {
83    type Error = InvalidInstance;
84
85    fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
86        Ok(Self::Simple(CanonicalLinearRelation::try_from(value)?))
87    }
88}
89
/// Prover commitment for a [`ComposedRelation`], mirroring the relation's tree shape.
///
/// `Simple` holds the group elements committed to by one linear relation. Every other
/// variant holds one nested commitment per sub-relation, in order.
#[derive(Clone)]
pub enum ComposedCommitment<G>
where
    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// Commitment group elements of a single [`CanonicalLinearRelation`].
    Simple(Vec<G>),
    /// One commitment per conjunct.
    And(Vec<ComposedCommitment<G>>),
    /// One commitment per disjunct.
    Or(Vec<ComposedCommitment<G>>),
    /// One commitment per sub-relation of a threshold composition.
    Threshold(Vec<ComposedCommitment<G>>),
}
103
104impl<G: PrimeGroup> ComposedCommitment<G>
105where
106    G: ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
107    G::Scalar:
108        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
109{
110    /// Conditionally select between two ComposedCommitment values.
111    /// This function performs constant-time selection of the commitment values.
112    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
113        match (a, b) {
114            (ComposedCommitment::Simple(a_elements), ComposedCommitment::Simple(b_elements)) => {
115                let selected: Vec<G> = a_elements
116                    .iter()
117                    .zip_eq(b_elements.iter())
118                    .map(|(a, b)| G::conditional_select(a, b, choice))
119                    .collect();
120                ComposedCommitment::Simple(selected)
121            }
122            (ComposedCommitment::And(a_commitments), ComposedCommitment::And(b_commitments)) => {
123                let selected: Vec<ComposedCommitment<G>> = a_commitments
124                    .iter()
125                    .zip_eq(b_commitments.iter())
126                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
127                    .collect();
128                ComposedCommitment::And(selected)
129            }
130            (ComposedCommitment::Or(a_commitments), ComposedCommitment::Or(b_commitments)) => {
131                let selected: Vec<ComposedCommitment<G>> = a_commitments
132                    .iter()
133                    .zip_eq(b_commitments.iter())
134                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
135                    .collect();
136                ComposedCommitment::Or(selected)
137            }
138            (
139                ComposedCommitment::Threshold(a_commitments),
140                ComposedCommitment::Threshold(b_commitments),
141            ) => {
142                let selected: Vec<ComposedCommitment<G>> = a_commitments
143                    .iter()
144                    .zip_eq(b_commitments.iter())
145                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
146                    .collect();
147                ComposedCommitment::Threshold(selected)
148            }
149            _ => {
150                unreachable!("Mismatched ComposedCommitment variants in conditional_select");
151            }
152        }
153    }
154}
155
/// Prover-side state produced by the commit phase of a [`ComposedRelation`] and consumed by
/// the response phase. There is one variant per relation kind.
pub enum ComposedProverState<G>
where
    G: PrimeGroup
        + ConstantTimeEq
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// State of a single [`CanonicalLinearRelation`] prover.
    Simple(<CanonicalLinearRelation<G> as SigmaProtocol>::ProverState),
    /// One state per conjunct, in order.
    And(Vec<ComposedProverState<G>>),
    /// Per-branch entries of an OR composition.
    Or(ComposedOrProverState<G>),
    /// Per-branch entries of a threshold composition.
    Threshold(ComposedThresholdProverState<G>),
}
174
/// Prover state of an OR composition: one entry per branch.
pub type ComposedOrProverState<G> = Vec<ComposedOrProverStateEntry<G>>;
/// State kept for one branch of an OR composition.
///
/// NOTE(review): the fields are positional. Judging by the parallel named fields of
/// `ComposedThresholdProverStateEntry`, they are presumably: (0) whether this branch is
/// simulated, (1) the branch's prover state, (2) its simulated challenge, and (3) its
/// simulated response. Confirm against the OR prover.
pub struct ComposedOrProverStateEntry<G>(
    Choice,
    ComposedProverState<G>,
    ComposedChallenge<G>,
    ComposedResponse<G>,
)
where
    G: PrimeGroup
        + ConstantTimeEq
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable;
192
/// Prover state of a threshold composition: one entry per sub-relation.
pub type ComposedThresholdProverState<G> = Vec<ComposedThresholdProverStateEntry<G>>;
/// State kept for one sub-relation of a threshold composition.
pub struct ComposedThresholdProverStateEntry<G>
where
    G: PrimeGroup
        + ConstantTimeEq
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// Set when this branch is to be answered by the simulator rather than a real witness.
    use_simulator: Choice,
    /// Prover state for this branch's commitment.
    prover_state: ComposedProverState<G>,
    /// Challenge for this branch if it is simulated (per the field name; confirm usage).
    simulated_challenge: ComposedChallenge<G>,
    /// Response for this branch if it is simulated (per the field name; confirm usage).
    simulated_response: ComposedResponse<G>,
}
211
/// Prover response for a [`ComposedRelation`], mirroring the relation's tree shape.
#[derive(Clone)]
pub enum ComposedResponse<G>
where
    G: PrimeGroup
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// Response scalars of a single [`CanonicalLinearRelation`].
    Simple(Vec<<CanonicalLinearRelation<G> as SigmaProtocol>::Response>),
    /// One response per conjunct, in order.
    And(Vec<ComposedResponse<G>>),
    /// Branch challenges, followed by one response per branch.
    Or(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>),
    /// Challenges, followed by one response per sub-relation.
    /// NOTE(review): the challenges are presumably the `total - threshold` compressed shares
    /// consumed by `expand_threshold_challenges`. Confirm against the threshold prover.
    Threshold(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>),
}
230
// One-byte node tags that open every encoded `ComposedCommitment` / `ComposedResponse`
// and identify which variant follows.
const TAG_SIMPLE: u8 = 0;
const TAG_AND: u8 = 1;
const TAG_OR: u8 = 2;
const TAG_THRESHOLD: u8 = 3;
235
236fn read_u32(buf: &mut &[u8]) -> VerificationResult<u32> {
237    if buf.len() < 4 {
238        return Err(VerificationError);
239    }
240    let (head, tail) = buf.split_at(4);
241    *buf = tail;
242    Ok(u32::from_le_bytes(head.try_into().unwrap()))
243}
244
/// Append `len` to `out` as a 4-byte little-endian `u32` length prefix.
///
/// # Panics
///
/// Panics if `len` does not fit in a `u32`. Truncating instead would emit a prefix that
/// disagrees with the number of entries that follow, silently corrupting the encoding.
fn write_len(out: &mut Vec<u8>, len: usize) {
    let len = u32::try_from(len).expect("length prefix does not fit in a u32");
    out.extend_from_slice(&len.to_le_bytes());
}
248
249impl<G> Encoding<[u8]> for ComposedCommitment<G>
250where
251    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
252    G::Scalar:
253        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
254{
255    fn encode(&self) -> impl AsRef<[u8]> {
256        let mut out = Vec::new();
257        match self {
258            ComposedCommitment::Simple(elems) => {
259                out.push(TAG_SIMPLE);
260                write_len(&mut out, elems.len());
261                for elem in elems {
262                    elem.serialize_into_narg(&mut out);
263                }
264            }
265            ComposedCommitment::And(cs) => {
266                out.push(TAG_AND);
267                write_len(&mut out, cs.len());
268                for c in cs {
269                    c.serialize_into_narg(&mut out);
270                }
271            }
272            ComposedCommitment::Or(cs) => {
273                out.push(TAG_OR);
274                write_len(&mut out, cs.len());
275                for c in cs {
276                    c.serialize_into_narg(&mut out);
277                }
278            }
279            ComposedCommitment::Threshold(cs) => {
280                out.push(TAG_THRESHOLD);
281                write_len(&mut out, cs.len());
282                for c in cs {
283                    c.serialize_into_narg(&mut out);
284                }
285            }
286        }
287        out
288    }
289}
290
291impl<G> NargDeserialize for ComposedCommitment<G>
292where
293    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
294    G::Scalar:
295        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
296{
297    fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
298        if buf.is_empty() {
299            return Err(VerificationError);
300        }
301        let (tag_bytes, rest) = buf.split_at(1);
302        *buf = rest;
303        match tag_bytes[0] {
304            TAG_SIMPLE => {
305                let len = read_u32(buf)? as usize;
306                let mut elems = Vec::new();
307                for _ in 0..len {
308                    elems.push(G::deserialize_from_narg(buf)?);
309                }
310                Ok(ComposedCommitment::Simple(elems))
311            }
312            TAG_AND => {
313                let len = read_u32(buf)? as usize;
314                let mut entries = Vec::new();
315                for _ in 0..len {
316                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
317                }
318                Ok(ComposedCommitment::And(entries))
319            }
320            TAG_OR => {
321                let len = read_u32(buf)? as usize;
322                let mut entries = Vec::new();
323                for _ in 0..len {
324                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
325                }
326                Ok(ComposedCommitment::Or(entries))
327            }
328            TAG_THRESHOLD => {
329                let len = read_u32(buf)? as usize;
330                let mut entries = Vec::new();
331                for _ in 0..len {
332                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
333                }
334                Ok(ComposedCommitment::Threshold(entries))
335            }
336            _ => Err(VerificationError),
337        }
338    }
339}
340
341impl<G> Encoding<[u8]> for ComposedResponse<G>
342where
343    G: PrimeGroup
344        + ConditionallySelectable
345        + Encoding<[u8]>
346        + NargSerialize
347        + NargDeserialize
348        + MultiScalarMul,
349    G::Scalar:
350        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
351{
352    fn encode(&self) -> impl AsRef<[u8]> {
353        let mut out = Vec::new();
354        match self {
355            ComposedResponse::Simple(responses) => {
356                out.push(TAG_SIMPLE);
357                write_len(&mut out, responses.len());
358                for r in responses {
359                    r.serialize_into_narg(&mut out);
360                }
361            }
362            ComposedResponse::And(entries) => {
363                out.push(TAG_AND);
364                write_len(&mut out, entries.len());
365                for r in entries {
366                    r.serialize_into_narg(&mut out);
367                }
368            }
369            ComposedResponse::Or(challenges, responses) => {
370                out.push(TAG_OR);
371                write_len(&mut out, challenges.len());
372                for c in challenges {
373                    c.serialize_into_narg(&mut out);
374                }
375                write_len(&mut out, responses.len());
376                for r in responses {
377                    r.serialize_into_narg(&mut out);
378                }
379            }
380            ComposedResponse::Threshold(challenges, responses) => {
381                out.push(TAG_THRESHOLD);
382                write_len(&mut out, challenges.len());
383                for c in challenges {
384                    c.serialize_into_narg(&mut out);
385                }
386                write_len(&mut out, responses.len());
387                for r in responses {
388                    r.serialize_into_narg(&mut out);
389                }
390            }
391        }
392        out
393    }
394}
395
396impl<G> NargDeserialize for ComposedResponse<G>
397where
398    G: PrimeGroup
399        + ConditionallySelectable
400        + Encoding<[u8]>
401        + NargSerialize
402        + NargDeserialize
403        + MultiScalarMul,
404    G::Scalar:
405        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
406{
407    fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
408        if buf.is_empty() {
409            return Err(VerificationError);
410        }
411        let (tag_bytes, rest) = buf.split_at(1);
412        *buf = rest;
413        match tag_bytes[0] {
414            TAG_SIMPLE => {
415                let len = read_u32(buf)? as usize;
416                let mut elems = Vec::new();
417                for _ in 0..len {
418                    elems.push(G::Scalar::deserialize_from_narg(buf)?);
419                }
420                Ok(ComposedResponse::Simple(elems))
421            }
422            TAG_AND => {
423                let len = read_u32(buf)? as usize;
424                let mut entries = Vec::new();
425                for _ in 0..len {
426                    entries.push(ComposedResponse::deserialize_from_narg(buf)?);
427                }
428                Ok(ComposedResponse::And(entries))
429            }
430            TAG_OR => {
431                let ch_len = read_u32(buf)? as usize;
432                let mut challenges = Vec::new();
433                for _ in 0..ch_len {
434                    challenges.push(G::Scalar::deserialize_from_narg(buf)?);
435                }
436                let resp_len = read_u32(buf)? as usize;
437                let mut responses = Vec::new();
438                for _ in 0..resp_len {
439                    responses.push(ComposedResponse::deserialize_from_narg(buf)?);
440                }
441                Ok(ComposedResponse::Or(challenges, responses))
442            }
443            TAG_THRESHOLD => {
444                let ch_len = read_u32(buf)? as usize;
445                let mut challenges = Vec::new();
446                for _ in 0..ch_len {
447                    challenges.push(G::Scalar::deserialize_from_narg(buf)?);
448                }
449                let resp_len = read_u32(buf)? as usize;
450                let mut responses = Vec::new();
451                for _ in 0..resp_len {
452                    responses.push(ComposedResponse::deserialize_from_narg(buf)?);
453                }
454                Ok(ComposedResponse::Threshold(challenges, responses))
455            }
456            _ => Err(VerificationError),
457        }
458    }
459}
460
461impl<G> ComposedResponse<G>
462where
463    G: PrimeGroup
464        + ConditionallySelectable
465        + Encoding<[u8]>
466        + NargSerialize
467        + NargDeserialize
468        + MultiScalarMul,
469    G::Scalar:
470        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
471{
472    /// Conditionally select between two ComposedResponse values.
473    /// This function performs constant-time selection of the response values.
474    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
475        match (a, b) {
476            (ComposedResponse::Simple(a_scalars), ComposedResponse::Simple(b_scalars)) => {
477                let selected: Vec<G::Scalar> = a_scalars
478                    .iter()
479                    .zip_eq(b_scalars.iter())
480                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
481                    .collect();
482                ComposedResponse::Simple(selected)
483            }
484            (ComposedResponse::And(a_responses), ComposedResponse::And(b_responses)) => {
485                let selected: Vec<ComposedResponse<G>> = a_responses
486                    .iter()
487                    .zip_eq(b_responses.iter())
488                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
489                    .collect();
490                ComposedResponse::And(selected)
491            }
492            (
493                ComposedResponse::Or(a_challenges, a_responses),
494                ComposedResponse::Or(b_challenges, b_responses),
495            ) => {
496                let selected_challenges: Vec<ComposedChallenge<G>> = a_challenges
497                    .iter()
498                    .zip_eq(b_challenges.iter())
499                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
500                    .collect();
501
502                let selected_responses: Vec<ComposedResponse<G>> = a_responses
503                    .iter()
504                    .zip_eq(b_responses.iter())
505                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
506                    .collect();
507
508                ComposedResponse::Or(selected_challenges, selected_responses)
509            }
510            (
511                ComposedResponse::Threshold(a_challenges, a_responses),
512                ComposedResponse::Threshold(b_challenges, b_responses),
513            ) => {
514                let selected_challenges: Vec<ComposedChallenge<G>> = a_challenges
515                    .iter()
516                    .zip_eq(b_challenges.iter())
517                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
518                    .collect();
519
520                let selected_responses: Vec<ComposedResponse<G>> = a_responses
521                    .iter()
522                    .zip_eq(b_responses.iter())
523                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
524                    .collect();
525
526                ComposedResponse::Threshold(selected_challenges, selected_responses)
527            }
528            _ => {
529                unreachable!("Mismatched ComposedResponse variants in conditional_select");
530            }
531        }
532    }
533}
534
/// Witness for a [`ComposedRelation`]. It must mirror the relation's tree shape (same
/// variants and child counts), or it is treated as invalid.
#[derive(Clone)]
pub enum ComposedWitness<G>
where
    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
    G::Scalar: Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]>,
{
    /// Witness for a single [`CanonicalLinearRelation`].
    Simple(<CanonicalLinearRelation<G> as SigmaProtocol>::Witness),
    /// One witness per conjunct.
    And(Vec<ComposedWitness<G>>),
    /// One entry per disjunct. Every branch needs an entry, since the lengths must match,
    /// but only one of them has to be valid.
    Or(Vec<ComposedWitness<G>>),
    /// One entry per sub-relation of a threshold composition.
    Threshold(Vec<ComposedWitness<G>>),
}
547
548impl<G> ComposedWitness<G>
549where
550    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
551    G::Scalar: Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]>,
552{
553    /// Create a [ComposedWitness] for an AND relation from the given list of witnesses.
554    pub fn and<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
555        Self::And(witness.into_iter().map(|x| x.into()).collect())
556    }
557
558    /// Create a [ComposedWitness] for an OR relation from the given list of witnesses.
559    pub fn or<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
560        Self::Or(witness.into_iter().map(|x| x.into()).collect())
561    }
562
563    /// Create a [ComposedWitness] for a threshold relation from the given list of witnesses.
564    pub fn threshold<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
565        Self::Threshold(witness.into_iter().map(|x| x.into()).collect())
566    }
567}
568
569impl<G> From<<CanonicalLinearRelation<G> as SigmaProtocol>::Witness> for ComposedWitness<G>
570where
571    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
572    G::Scalar:
573        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
574{
575    fn from(value: <CanonicalLinearRelation<G> as SigmaProtocol>::Witness) -> Self {
576        Self::Simple(value)
577    }
578}
579
/// Challenge type shared by every node of a composed proof: the challenge of the underlying
/// [`CanonicalLinearRelation`] protocol, which this module handles as `G::Scalar`.
type ComposedChallenge<G> = <CanonicalLinearRelation<G> as SigmaProtocol>::Challenge;
581fn threshold_x<F: PrimeField>(index: usize) -> F {
582    F::from((index + 1) as u64)
583}
584
585fn poly_mul_linear<F: Field>(coeffs: &[F], constant: F) -> Vec<F> {
586    let mut out = vec![F::ZERO; coeffs.len() + 1];
587    for (i, coeff) in coeffs.iter().enumerate() {
588        out[i] += *coeff * constant;
589        out[i + 1] += *coeff;
590    }
591    out
592}
593
594fn interpolate_polynomial<F: Field>(points: &[Evaluation<F>]) -> Result<Vec<F>, Error> {
595    if points.is_empty() {
596        return Err(Error::InvalidInstanceWitnessPair);
597    }
598
599    let mut coeffs = vec![F::ZERO; points.len()];
600
601    for (i, point_i) in points.iter().enumerate() {
602        let mut basis = vec![F::ONE];
603        let mut denom = F::ONE;
604
605        for (j, point_j) in points.iter().enumerate() {
606            if i == j {
607                continue;
608            }
609            denom *= point_i.x - point_j.x;
610            basis = poly_mul_linear::<F>(&basis, -point_j.x);
611        }
612
613        let denom_inv = denom.invert();
614        if denom_inv.is_none().into() {
615            return Err(Error::InvalidInstanceWitnessPair);
616        }
617        let scale = point_i.y * denom_inv.unwrap_or(F::ZERO);
618        for (coeff, basis_coeff) in coeffs.iter_mut().zip_eq(basis.iter()) {
619            *coeff += *basis_coeff * scale;
620        }
621    }
622
623    Ok(coeffs)
624}
625
626fn evaluate_polynomial<F: Field>(coeffs: &[F], x: F) -> F {
627    coeffs
628        .iter()
629        .rev()
630        .fold(F::ZERO, |acc, coeff| acc * x + coeff)
631}
632
633fn expand_threshold_challenges<F: PrimeField>(
634    threshold: usize,
635    total: usize,
636    challenge: F,
637    compressed_challenges: &[F],
638) -> Result<Vec<F>, Error> {
639    if threshold == 0 || threshold > total {
640        return Err(Error::InvalidInstanceWitnessPair);
641    }
642
643    let degree = total - threshold;
644    if compressed_challenges.len() != degree {
645        return Err(Error::InvalidInstanceWitnessPair);
646    }
647
648    let mut points = Vec::with_capacity(degree + 1);
649    points.push(Evaluation {
650        x: F::ZERO,
651        y: challenge,
652    });
653    for (index, share) in compressed_challenges.iter().enumerate() {
654        points.push(Evaluation {
655            x: threshold_x::<F>(index),
656            y: *share,
657        });
658    }
659
660    let coeffs = interpolate_polynomial::<F>(&points)?;
661    let mut challenges = Vec::with_capacity(total);
662    for index in 0..total {
663        challenges.push(evaluate_polynomial::<F>(&coeffs, threshold_x::<F>(index)));
664    }
665
666    Ok(challenges)
667}
668
669fn count_choices(choices: &[Choice]) -> usize {
670    let mut sum: u32 = 0;
671    for choice in choices {
672        let inc = sum.wrapping_add(1);
673        sum = u32::conditional_select(&sum, &inc, *choice);
674    }
675    sum as usize
676}
677
/// A point `(x, y)` used as an interpolation node for a polynomial over a field.
#[derive(Clone, Copy)]
struct Evaluation<T> {
    /// The evaluation point.
    x: T,
    /// The polynomial's value at `x`.
    y: T,
}
683
684impl<T: ConditionallySelectable> ConditionallySelectable for Evaluation<T> {
685    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
686        Evaluation {
687            x: T::conditional_select(&a.x, &b.x, choice),
688            y: T::conditional_select(&a.y, &b.y, choice),
689        }
690    }
691}
692
693impl<T> From<(T, T)> for Evaluation<T> {
694    fn from(value: (T, T)) -> Self {
695        Evaluation {
696            x: value.0,
697            y: value.1,
698        }
699    }
700}
701
702fn conditional_swap_point<T: ConditionallySelectable>(
703    points: &mut [T],
704    left: usize,
705    right: usize,
706    swap: Choice,
707) {
708    if left == right {
709        return;
710    }
711    if left < right {
712        let (head, tail) = points.split_at_mut(right);
713        T::conditional_swap(&mut head[left], &mut tail[0], swap);
714    } else {
715        let (head, tail) = points.split_at_mut(left);
716        T::conditional_swap(&mut tail[0], &mut head[right], swap);
717    }
718}
719
/// Offset compaction of a power-of-two-length slice using only conditional swaps.
///
/// The recursive structure appears to follow the OROffCompact routine of Sasy, Johnson and
/// Goldberg's oblivious compaction. NOTE(review): presumably the entries whose `marks` are
/// set end up contiguous, rotated by `offset` (mod `points.len()`). Confirm against the
/// paper and the callers.
///
/// Requires `points.len() == marks.len()` and a power-of-two length; both are checked only
/// with `debug_assert!`. Slices of length 0 or 1 are left unchanged.
fn oroffcompact_points<T: ConditionallySelectable>(
    points: &mut [T],
    marks: &[Choice],
    offset: usize,
) {
    let n = points.len();
    if n <= 1 {
        return;
    }
    debug_assert_eq!(n, marks.len());
    debug_assert!(n.is_power_of_two());

    // `m` = number of marked entries in the first half.
    // NOTE(review): `m` is derived from the marks via `unwrap_u8`, and the `%`/`>=` below
    // operate on it as plain integers. If the marks are secret, confirm these compile
    // branch-free. `m` is also unused on the `n == 2` path.
    let half = n / 2;
    let mut m = 0usize;
    for mark in &marks[..half] {
        m += mark.unwrap_u8() as usize;
    }

    // Base case: swap exactly when "unmarked then marked" XOR the low bit of the offset.
    if n == 2 {
        let z = Choice::from((offset & 1) as u8);
        let b = ((!marks[0]) & marks[1]) ^ z;
        conditional_swap_point(points, 0, 1, b);
        return;
    }

    // Recurse on each half. The second half's offset is shifted by the first half's count.
    let offset_mod = offset % half;
    oroffcompact_points(&mut points[..half], &marks[..half], offset_mod);
    let offset_plus_m_mod = (offset + m) % half;
    oroffcompact_points(&mut points[half..], &marks[half..], offset_plus_m_mod);

    // Merge the halves with a fixed pattern of conditional swaps between `i` and `i + half`.
    let s = Choice::from(((offset_mod + m) >= half) as u8) ^ Choice::from((offset >= half) as u8);
    for i in 0..half {
        let b = s ^ Choice::from((i >= offset_plus_m_mod) as u8);
        conditional_swap_point(points, i, i + half, b);
    }
}
756
/// Compact a slice of arbitrary length using only conditional swaps.
///
/// The structure appears to follow the ORCompact routine of Sasy, Johnson and Goldberg. The
/// slice is split into a leading part of length `n2 = n - n1` and a trailing part whose
/// length `n1` is the largest power of two `<= n`. The leading part is compacted
/// recursively, the trailing part with `oroffcompact_points`, and the two are then merged.
/// NOTE(review): presumably marked entries are moved to the front in their original
/// relative order; confirm against the paper and the callers.
///
/// Requires `points.len() == marks.len()`; this is checked only with `debug_assert!`.
fn oblivious_compact_points<T: ConditionallySelectable>(points: &mut [T], marks: &[Choice]) {
    let n = points.len();
    if n == 0 {
        return;
    }
    debug_assert_eq!(n, marks.len());

    // n1 = 2^floor(log2 n), the largest power of two not exceeding n.
    let n1 = 1usize << (usize::BITS as usize - 1 - n.leading_zeros() as usize);
    let n2 = n - n1;
    // `m` = number of marked entries in the leading `n2` positions.
    let mut m = 0usize;
    for mark in &marks[..n2] {
        m += mark.unwrap_u8() as usize;
    }

    if n2 > 0 {
        oblivious_compact_points(&mut points[..n2], &marks[..n2]);
    }
    oroffcompact_points(&mut points[n2..], &marks[n2..], (n1 - n2 + m) % n1);

    // Merge: position `i` of the leading part is conditionally exchanged with `i + n1`
    // once `i` is past the leading part's marked count.
    for i in 0..n2 {
        let b = Choice::from((i >= m) as u8);
        conditional_swap_point(points, i, i + n1, b);
    }
}
781
782impl<G> ComposedRelation<G>
783where
784    G: PrimeGroup
785        + ConstantTimeEq
786        + ConditionallySelectable
787        + Encoding<[u8]>
788        + NargSerialize
789        + NargDeserialize
790        + MultiScalarMul,
791    G::Scalar:
792        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
793{
794    fn is_witness_valid(&self, witness: &ComposedWitness<G>) -> Choice {
795        match (self, witness) {
796            (ComposedRelation::Simple(instance), ComposedWitness::Simple(witness)) => {
797                instance.is_witness_valid(witness)
798            }
799            (ComposedRelation::And(instances), ComposedWitness::And(witnesses)) => {
800                if instances.len() != witnesses.len() {
801                    return Choice::from(0);
802                }
803                instances
804                    .iter()
805                    .zip_eq(witnesses)
806                    .fold(Choice::from(1), |bit, (instance, witness)| {
807                        bit & instance.is_witness_valid(witness)
808                    })
809            }
810            (ComposedRelation::Or(instances), ComposedWitness::Or(witnesses)) => {
811                if instances.len() != witnesses.len() {
812                    return Choice::from(0);
813                }
814                instances
815                    .iter()
816                    .zip_eq(witnesses)
817                    .fold(Choice::from(0), |bit, (instance, witness)| {
818                        bit | instance.is_witness_valid(witness)
819                    })
820            }
821            (
822                ComposedRelation::Threshold(threshold, instances),
823                ComposedWitness::Threshold(witnesses),
824            ) => {
825                if *threshold == 0 || instances.len() != witnesses.len() {
826                    return Choice::from(0);
827                }
828                let mut count = 0usize;
829                for (instance, witness) in instances.iter().zip_eq(witnesses) {
830                    if instance.is_witness_valid(witness).unwrap_u8() == 1 {
831                        count += 1;
832                    }
833                }
834                Choice::from((count >= *threshold) as u8)
835            }
836            _ => Choice::from(0),
837        }
838    }
839
840    fn prover_commit_simple(
841        protocol: &CanonicalLinearRelation<G>,
842        witness: &<CanonicalLinearRelation<G> as SigmaProtocol>::Witness,
843        rng: &mut impl ScalarRng,
844    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error> {
845        protocol.prover_commit(witness, rng).map(|(c, s)| {
846            (
847                ComposedCommitment::Simple(c),
848                ComposedProverState::Simple(s),
849            )
850        })
851    }
852
853    fn prover_response_simple(
854        instance: &CanonicalLinearRelation<G>,
855        state: <CanonicalLinearRelation<G> as SigmaProtocol>::ProverState,
856        challenge: &<CanonicalLinearRelation<G> as SigmaProtocol>::Challenge,
857    ) -> Result<ComposedResponse<G>, Error> {
858        instance
859            .prover_response(state, challenge)
860            .map(ComposedResponse::Simple)
861    }
862
863    fn prover_commit_and(
864        protocols: &[ComposedRelation<G>],
865        witnesses: &[ComposedWitness<G>],
866        rng: &mut impl ScalarRng,
867    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error> {
868        if protocols.len() != witnesses.len() {
869            return Err(Error::InvalidInstanceWitnessPair);
870        }
871
872        let mut commitments = Vec::with_capacity(protocols.len());
873        let mut prover_states = Vec::with_capacity(protocols.len());
874
875        for (p, w) in protocols.iter().zip_eq(witnesses.iter()) {
876            let (mut c, s) = p.prover_commit(w, rng)?;
877            let commitment = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
878            if !c.is_empty() {
879                return Err(Error::InvalidInstanceWitnessPair);
880            }
881            commitments.push(commitment);
882            prover_states.push(s);
883        }
884
885        Ok((
886            ComposedCommitment::And(commitments),
887            ComposedProverState::And(prover_states),
888        ))
889    }
890
891    fn prover_response_and(
892        instances: &[ComposedRelation<G>],
893        prover_state: Vec<ComposedProverState<G>>,
894        challenge: &ComposedChallenge<G>,
895    ) -> Result<ComposedResponse<G>, Error> {
896        if instances.len() != prover_state.len() {
897            return Err(Error::InvalidInstanceWitnessPair);
898        }
899
900        let responses: Result<Vec<_>, _> = instances
901            .iter()
902            .zip_eq(prover_state)
903            .map(|(p, s)| {
904                let mut res = p.prover_response(s, challenge)?;
905                res.pop().ok_or(Error::InvalidInstanceWitnessPair)
906            })
907            .collect();
908
909        Ok(ComposedResponse::And(responses?))
910    }
911
912    fn prover_commit_or(
913        instances: &[ComposedRelation<G>],
914        witnesses: &[ComposedWitness<G>],
915        rng: &mut impl ScalarRng,
916    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error>
917    where
918        G: ConditionallySelectable,
919    {
920        if instances.len() != witnesses.len() {
921            return Err(Error::InvalidInstanceWitnessPair);
922        }
923
924        let mut commitments = Vec::new();
925        let mut prover_states = Vec::new();
926
927        // Selector value set when the first valid witness is found.
928        let mut valid_witness_found = Choice::from(0);
929        for (i, w) in witnesses.iter().enumerate() {
930            let (mut commitment_vec, prover_state) = instances[i].prover_commit(w, rng)?;
931            let commitment = commitment_vec
932                .pop()
933                .ok_or(Error::InvalidInstanceWitnessPair)?;
934            if !commitment_vec.is_empty() {
935                return Err(Error::InvalidInstanceWitnessPair);
936            }
937
938            let (mut simulated_commitment_vec, simulated_challenge, mut simulated_response_vec) =
939                instances[i].simulate_transcript(rng)?;
940            let simulated_commitment = simulated_commitment_vec
941                .pop()
942                .ok_or(Error::InvalidInstanceWitnessPair)?;
943            if !simulated_commitment_vec.is_empty() {
944                return Err(Error::InvalidInstanceWitnessPair);
945            }
946            let simulated_response = simulated_response_vec
947                .pop()
948                .ok_or(Error::InvalidInstanceWitnessPair)?;
949            if !simulated_response_vec.is_empty() {
950                return Err(Error::InvalidInstanceWitnessPair);
951            }
952
953            let valid_witness = instances[i].is_witness_valid(w) & !valid_witness_found;
954            let select_witness = valid_witness;
955
956            let commitment = ComposedCommitment::conditional_select(
957                &simulated_commitment,
958                &commitment,
959                select_witness,
960            );
961
962            commitments.push(commitment);
963            prover_states.push(ComposedOrProverStateEntry(
964                select_witness,
965                prover_state,
966                simulated_challenge,
967                simulated_response,
968            ));
969
970            valid_witness_found |= valid_witness;
971        }
972
973        if valid_witness_found.unwrap_u8() == 0 {
974            Err(Error::InvalidInstanceWitnessPair)
975        } else {
976            Ok((
977                ComposedCommitment::Or(commitments),
978                ComposedProverState::Or(prover_states),
979            ))
980        }
981    }
982
983    fn prover_response_or(
984        instances: &[ComposedRelation<G>],
985        prover_state: ComposedOrProverState<G>,
986        challenge: &ComposedChallenge<G>,
987    ) -> Result<ComposedResponse<G>, Error> {
988        let mut result_challenges = Vec::with_capacity(instances.len());
989        let mut result_responses = Vec::with_capacity(instances.len());
990
991        if instances.len() != prover_state.len() {
992            return Err(Error::InvalidInstanceWitnessPair);
993        }
994
995        let mut witness_challenge = *challenge;
996        for ComposedOrProverStateEntry(
997            valid_witness,
998            _prover_state,
999            simulated_challenge,
1000            _simulated_response,
1001        ) in &prover_state
1002        {
1003            let c = G::Scalar::conditional_select(
1004                simulated_challenge,
1005                &G::Scalar::ZERO,
1006                *valid_witness,
1007            );
1008            witness_challenge -= c;
1009        }
1010        for (
1011            instance,
1012            ComposedOrProverStateEntry(
1013                valid_witness,
1014                prover_state,
1015                simulated_challenge,
1016                simulated_response,
1017            ),
1018        ) in instances.iter().zip_eq(prover_state)
1019        {
1020            let challenge_i = G::Scalar::conditional_select(
1021                &simulated_challenge,
1022                &witness_challenge,
1023                valid_witness,
1024            );
1025
1026            let mut response_vec = instance.prover_response(prover_state, &challenge_i)?;
1027            let response = response_vec
1028                .pop()
1029                .ok_or(Error::InvalidInstanceWitnessPair)?;
1030            if !response_vec.is_empty() {
1031                return Err(Error::InvalidInstanceWitnessPair);
1032            }
1033            let response =
1034                ComposedResponse::conditional_select(&simulated_response, &response, valid_witness);
1035
1036            result_challenges.push(challenge_i);
1037            result_responses.push(response.clone());
1038        }
1039
1040        result_challenges.pop();
1041        Ok(ComposedResponse::Or(result_challenges, result_responses))
1042    }
1043
1044    fn prover_commit_threshold(
1045        threshold: usize,
1046        instances: &[ComposedRelation<G>],
1047        witnesses: &[ComposedWitness<G>],
1048        rng: &mut impl ScalarRng,
1049    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error>
1050    where
1051        G: ConditionallySelectable,
1052    {
1053        if instances.len() != witnesses.len() || threshold == 0 || threshold > instances.len() {
1054            return Err(Error::InvalidInstanceWitnessPair);
1055        }
1056        let degree = instances.len() - threshold;
1057
1058        let valid_witnesses = instances
1059            .iter()
1060            .zip_eq(witnesses.iter())
1061            .map(|(x, w)| x.is_witness_valid(w))
1062            .collect::<Vec<Choice>>();
1063
1064        // Degree-(t-1) interpolation can only satisfy t fixed points.
1065        let invalid_count = instances.len() - count_choices(&valid_witnesses);
1066        if invalid_count > degree {
1067            return Err(Error::InvalidInstanceWitnessPair);
1068        }
1069
1070        let mut remaining_seeds = (degree - invalid_count) as u32;
1071        let mut commitments = Vec::with_capacity(instances.len());
1072        let mut prover_states = Vec::with_capacity(instances.len());
1073        for (i, (instance, witness)) in instances.iter().zip_eq(witnesses.iter()).enumerate() {
1074            let (mut commitment_vec, prover_state) = instance.prover_commit(witness, rng)?;
1075            let commitment = commitment_vec
1076                .pop()
1077                .ok_or(Error::InvalidInstanceWitnessPair)?;
1078            if !commitment_vec.is_empty() {
1079                return Err(Error::InvalidInstanceWitnessPair);
1080            }
1081
1082            let (mut simulated_commitments, simulated_challenge, mut simulated_responses) =
1083                instance.simulate_transcript(rng)?;
1084            let simulated_commitment = simulated_commitments
1085                .pop()
1086                .ok_or(Error::InvalidInstanceWitnessPair)?;
1087            if !simulated_commitments.is_empty() {
1088                return Err(Error::InvalidInstanceWitnessPair);
1089            }
1090            let simulated_response = simulated_responses
1091                .pop()
1092                .ok_or(Error::InvalidInstanceWitnessPair)?;
1093            if !simulated_responses.is_empty() {
1094                return Err(Error::InvalidInstanceWitnessPair);
1095            }
1096
1097            let valid_witness = valid_witnesses[i];
1098            let should_seed = valid_witness & Choice::from((remaining_seeds != 0) as u8);
1099            remaining_seeds = remaining_seeds.wrapping_sub(should_seed.unwrap_u8() as u32);
1100            let use_simulator = (!valid_witness) | should_seed;
1101            let commitment = ComposedCommitment::conditional_select(
1102                &commitment,
1103                &simulated_commitment,
1104                use_simulator,
1105            );
1106            commitments.push(commitment);
1107            prover_states.push(ComposedThresholdProverStateEntry {
1108                use_simulator,
1109                prover_state,
1110                simulated_challenge,
1111                simulated_response,
1112            });
1113        }
1114
1115        Ok((
1116            ComposedCommitment::Threshold(commitments),
1117            ComposedProverState::Threshold(prover_states),
1118        ))
1119    }
1120
1121    fn prover_response_threshold(
1122        threshold: usize,
1123        instances: &[ComposedRelation<G>],
1124        prover_states: ComposedThresholdProverState<G>,
1125        challenge: &ComposedChallenge<G>,
1126    ) -> Result<ComposedResponse<G>, Error> {
1127        if threshold == 0 || threshold > instances.len() || instances.len() != prover_states.len() {
1128            return Err(Error::InvalidInstanceWitnessPair);
1129        }
1130        let degree = instances.len() - threshold;
1131
1132        let marks = prover_states
1133            .iter()
1134            .map(|entry| entry.use_simulator)
1135            .collect::<Vec<_>>();
1136        debug_assert_eq!(count_choices(&marks), degree);
1137
1138        let mut points = prover_states
1139            .iter()
1140            .enumerate()
1141            .map(|(i, entry)| Evaluation {
1142                x: threshold_x::<G::Scalar>(i),
1143                y: entry.simulated_challenge,
1144            })
1145            .collect::<Vec<Evaluation<G::Scalar>>>();
1146        oblivious_compact_points(&mut points, &marks);
1147        points.drain(degree..);
1148
1149        let mut full_points = Vec::with_capacity(degree + 1);
1150        full_points.push(Evaluation {
1151            x: G::Scalar::ZERO,
1152            y: *challenge,
1153        });
1154        full_points.extend_from_slice(&points);
1155
1156        let coeffs = interpolate_polynomial::<G::Scalar>(&full_points)?;
1157        let mut compressed_challenges = Vec::with_capacity(degree);
1158        for index in 0..degree {
1159            compressed_challenges.push(evaluate_polynomial::<G::Scalar>(
1160                &coeffs,
1161                threshold_x::<G::Scalar>(index),
1162            ));
1163        }
1164
1165        let expanded_challenges = expand_threshold_challenges::<G::Scalar>(
1166            threshold,
1167            instances.len(),
1168            *challenge,
1169            &compressed_challenges,
1170        )?;
1171
1172        let mut responses = Vec::with_capacity(instances.len());
1173
1174        for (i, (instance, prover_state)) in instances.iter().zip_eq(prover_states).enumerate() {
1175            let poly_challenge = expanded_challenges[i];
1176            let challenge = G::Scalar::conditional_select(
1177                &poly_challenge,
1178                &prover_state.simulated_challenge,
1179                prover_state.use_simulator,
1180            );
1181
1182            let mut response_vec =
1183                instance.prover_response(prover_state.prover_state, &challenge)?;
1184            let response = response_vec
1185                .pop()
1186                .ok_or(Error::InvalidInstanceWitnessPair)?;
1187            if !response_vec.is_empty() {
1188                return Err(Error::InvalidInstanceWitnessPair);
1189            }
1190            let response = ComposedResponse::conditional_select(
1191                &response,
1192                &prover_state.simulated_response,
1193                prover_state.use_simulator,
1194            );
1195
1196            responses.push(response);
1197        }
1198
1199        Ok(ComposedResponse::Threshold(
1200            compressed_challenges,
1201            responses,
1202        ))
1203    }
1204}
1205
1206impl<G> SigmaProtocol for ComposedRelation<G>
1207where
1208    G: PrimeGroup
1209        + ConstantTimeEq
1210        + ConditionallySelectable
1211        + Encoding<[u8]>
1212        + NargSerialize
1213        + NargDeserialize
1214        + MultiScalarMul,
1215    G::Scalar:
1216        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1217{
1218    type Commitment = ComposedCommitment<G>;
1219    type ProverState = ComposedProverState<G>;
1220    type Response = ComposedResponse<G>;
1221    type Witness = ComposedWitness<G>;
1222    type Challenge = ComposedChallenge<G>;
1223
1224    fn prover_commit(
1225        &self,
1226        witness: &Self::Witness,
1227        rng: &mut impl ScalarRng,
1228    ) -> Result<(Vec<Self::Commitment>, Self::ProverState), Error> {
1229        let (commitment, state) = match (self, witness) {
1230            (ComposedRelation::Simple(p), ComposedWitness::Simple(w)) => {
1231                Self::prover_commit_simple(p, w, rng)
1232            }
1233            (ComposedRelation::And(ps), ComposedWitness::And(ws)) => {
1234                Self::prover_commit_and(ps, ws, rng)
1235            }
1236            (ComposedRelation::Or(ps), ComposedWitness::Or(witnesses)) => {
1237                Self::prover_commit_or(ps, witnesses, rng)
1238            }
1239            (ComposedRelation::Threshold(threshold, ps), ComposedWitness::Threshold(witnesses)) => {
1240                Self::prover_commit_threshold(*threshold, ps, witnesses, rng)
1241            }
1242            _ => Err(Error::InvalidInstanceWitnessPair),
1243        }?;
1244        Ok((vec![commitment], state))
1245    }
1246
1247    fn prover_response(
1248        &self,
1249        state: Self::ProverState,
1250        challenge: &Self::Challenge,
1251    ) -> Result<Vec<Self::Response>, Error> {
1252        let response = match (self, state) {
1253            (ComposedRelation::Simple(instance), ComposedProverState::Simple(state)) => {
1254                Self::prover_response_simple(instance, state, challenge)
1255            }
1256            (ComposedRelation::And(instances), ComposedProverState::And(prover_state)) => {
1257                Self::prover_response_and(instances, prover_state, challenge)
1258            }
1259            (ComposedRelation::Or(instances), ComposedProverState::Or(prover_state)) => {
1260                Self::prover_response_or(instances, prover_state, challenge)
1261            }
1262            (
1263                ComposedRelation::Threshold(threshold, instances),
1264                ComposedProverState::Threshold(prover_state),
1265            ) => Self::prover_response_threshold(*threshold, instances, prover_state, challenge),
1266            _ => Err(Error::InvalidInstanceWitnessPair),
1267        }?;
1268        Ok(vec![response])
1269    }
1270
1271    fn verifier(
1272        &self,
1273        commitment: &[Self::Commitment],
1274        challenge: &Self::Challenge,
1275        response: &[Self::Response],
1276    ) -> Result<(), Error> {
1277        let (commitment, response) = match (commitment.first(), response.first()) {
1278            (Some(c), Some(r)) => (c, r),
1279            _ => return Err(Error::InvalidInstanceWitnessPair),
1280        };
1281
1282        match (self, commitment, response) {
1283            (
1284                ComposedRelation::Simple(p),
1285                ComposedCommitment::Simple(c),
1286                ComposedResponse::Simple(r),
1287            ) => p.verifier(c, challenge, r),
1288            (
1289                ComposedRelation::And(ps),
1290                ComposedCommitment::And(commitments),
1291                ComposedResponse::And(responses),
1292            ) => {
1293                if ps.len() != commitments.len() || commitments.len() != responses.len() {
1294                    return Err(Error::InvalidInstanceWitnessPair);
1295                }
1296                ps.iter()
1297                    .zip_eq(commitments)
1298                    .zip_eq(responses)
1299                    .try_for_each(|((p, c), r)| {
1300                        p.verifier(
1301                            core::slice::from_ref(c),
1302                            challenge,
1303                            core::slice::from_ref(r),
1304                        )
1305                    })
1306            }
1307            (
1308                ComposedRelation::Or(ps),
1309                ComposedCommitment::Or(commitments),
1310                ComposedResponse::Or(challenges, responses),
1311            ) => {
1312                if ps.len() != commitments.len()
1313                    || commitments.len() != responses.len()
1314                    || challenges.len() != ps.len() - 1
1315                {
1316                    return Err(Error::InvalidInstanceWitnessPair);
1317                }
1318                let last_challenge = *challenge - challenges.iter().sum::<G::Scalar>();
1319                ps.iter()
1320                    .zip_eq(commitments)
1321                    .zip_eq(challenges.iter().chain(&Some(last_challenge)))
1322                    .zip_eq(responses)
1323                    .try_for_each(|(((p, commitment), challenge), response)| {
1324                        p.verifier(
1325                            core::slice::from_ref(commitment),
1326                            challenge,
1327                            core::slice::from_ref(response),
1328                        )
1329                    })
1330            }
1331            (
1332                ComposedRelation::Threshold(threshold, ps),
1333                ComposedCommitment::Threshold(commitments),
1334                ComposedResponse::Threshold(challenges, responses),
1335            ) => {
1336                if *threshold == 0
1337                    || *threshold > ps.len()
1338                    || commitments.len() != ps.len()
1339                    || challenges.len() != ps.len() - *threshold
1340                    || responses.len() != ps.len()
1341                {
1342                    return Err(Error::InvalidInstanceWitnessPair);
1343                }
1344
1345                let full_challenges = expand_threshold_challenges::<G::Scalar>(
1346                    *threshold,
1347                    ps.len(),
1348                    *challenge,
1349                    challenges,
1350                )?;
1351
1352                ps.iter()
1353                    .zip_eq(commitments)
1354                    .zip_eq(full_challenges.iter())
1355                    .zip_eq(responses)
1356                    .try_for_each(|(((p, commitment), challenge), response)| {
1357                        p.verifier(
1358                            core::slice::from_ref(commitment),
1359                            challenge,
1360                            core::slice::from_ref(response),
1361                        )
1362                    })
1363            }
1364            _ => Err(Error::InvalidInstanceWitnessPair),
1365        }
1366    }
1367
1368    fn commitment_len(&self) -> usize {
1369        1
1370    }
1371
1372    fn response_len(&self) -> usize {
1373        1
1374    }
1375
1376    fn instance_label(&self) -> impl AsRef<[u8]> {
1377        match self {
1378            ComposedRelation::Simple(p) => {
1379                let label = p.instance_label();
1380                label.as_ref().to_vec()
1381            }
1382            ComposedRelation::And(ps) => {
1383                let mut bytes = Vec::new();
1384                for p in ps {
1385                    bytes.extend(p.instance_label().as_ref());
1386                }
1387                bytes
1388            }
1389            ComposedRelation::Or(ps) => {
1390                let mut bytes = Vec::new();
1391                for p in ps {
1392                    bytes.extend(p.instance_label().as_ref());
1393                }
1394                bytes
1395            }
1396            ComposedRelation::Threshold(threshold, ps) => {
1397                let mut bytes = Vec::new();
1398                bytes.extend_from_slice(&((*threshold as u64).to_le_bytes()));
1399                for p in ps {
1400                    bytes.extend(p.instance_label().as_ref());
1401                }
1402                bytes
1403            }
1404        }
1405    }
1406
1407    fn protocol_identifier(&self) -> [u8; 64] {
1408        let mut hasher = Sha3_256::new();
1409
1410        match self {
1411            ComposedRelation::Simple(p) => {
1412                // take the digest of the simple protocol id
1413                hasher.update([0u8; 32]);
1414                hasher.update(p.protocol_identifier());
1415            }
1416            ComposedRelation::And(protocols) => {
1417                hasher.update([1u8; 32]);
1418                for p in protocols {
1419                    hasher.update(p.protocol_identifier().as_ref());
1420                }
1421            }
1422            ComposedRelation::Or(protocols) => {
1423                hasher.update([2u8; 32]);
1424                for p in protocols {
1425                    hasher.update(p.protocol_identifier().as_ref());
1426                }
1427            }
1428            ComposedRelation::Threshold(threshold, protocols) => {
1429                hasher.update([3u8; 32]);
1430                hasher.update(((*threshold as u64).to_le_bytes()).as_ref());
1431                for p in protocols {
1432                    hasher.update(p.protocol_identifier().as_ref());
1433                }
1434            }
1435        }
1436
1437        let mut protocol_id = [0u8; 64];
1438        protocol_id[..32].clone_from_slice(&hasher.finalize());
1439        protocol_id
1440    }
1441}
1442
1443impl<G> SigmaProtocolSimulator for ComposedRelation<G>
1444where
1445    G: PrimeGroup
1446        + ConstantTimeEq
1447        + ConditionallySelectable
1448        + Encoding<[u8]>
1449        + NargSerialize
1450        + NargDeserialize
1451        + MultiScalarMul,
1452    G::Scalar:
1453        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1454{
1455    fn simulate_commitment(
1456        &self,
1457        challenge: &Self::Challenge,
1458        response: &[Self::Response],
1459    ) -> Result<Vec<Self::Commitment>, Error> {
1460        let response = response.first().ok_or(Error::InvalidInstanceWitnessPair)?;
1461        let commitment = match (self, response) {
1462            (ComposedRelation::Simple(p), ComposedResponse::Simple(r)) => {
1463                ComposedCommitment::Simple(p.simulate_commitment(challenge, r)?)
1464            }
1465            (ComposedRelation::And(ps), ComposedResponse::And(rs)) => {
1466                if ps.len() != rs.len() {
1467                    return Err(Error::InvalidInstanceWitnessPair);
1468                }
1469                let commitments = ps
1470                    .iter()
1471                    .zip_eq(rs)
1472                    .map(|(p, r)| {
1473                        p.simulate_commitment(challenge, core::slice::from_ref(r))
1474                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1475                    })
1476                    .collect::<Result<Vec<_>, _>>()?;
1477                ComposedCommitment::And(commitments)
1478            }
1479            (ComposedRelation::Or(ps), ComposedResponse::Or(challenges, rs)) => {
1480                if rs.len() != ps.len() || challenges.len() + 1 != ps.len() {
1481                    return Err(Error::InvalidInstanceWitnessPair);
1482                }
1483                let last_challenge = *challenge - challenges.iter().sum::<G::Scalar>();
1484                let commitments = ps
1485                    .iter()
1486                    .zip_eq(challenges.iter().chain(&Some(last_challenge)))
1487                    .zip_eq(rs)
1488                    .map(|((p, ch), r)| {
1489                        p.simulate_commitment(ch, core::slice::from_ref(r))
1490                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1491                    })
1492                    .collect::<Result<Vec<_>, _>>()?;
1493                ComposedCommitment::Or(commitments)
1494            }
1495            (
1496                ComposedRelation::Threshold(threshold, ps),
1497                ComposedResponse::Threshold(challenges, rs),
1498            ) => {
1499                if rs.len() != ps.len()
1500                    || ps.len() < *threshold
1501                    || challenges.len() != ps.len() - threshold
1502                {
1503                    return Err(Error::InvalidInstanceWitnessPair);
1504                }
1505
1506                let full_challenges = expand_threshold_challenges::<G::Scalar>(
1507                    *threshold,
1508                    ps.len(),
1509                    *challenge,
1510                    challenges,
1511                )?;
1512                let commitments = ps
1513                    .iter()
1514                    .zip_eq(full_challenges.iter())
1515                    .zip_eq(rs)
1516                    .map(|((p, ch), r)| {
1517                        p.simulate_commitment(ch, core::slice::from_ref(r))
1518                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1519                    })
1520                    .collect::<Result<Vec<_>, _>>()?;
1521                ComposedCommitment::Threshold(commitments)
1522            }
1523            _ => return Err(Error::InvalidInstanceWitnessPair),
1524        };
1525
1526        Ok(vec![commitment])
1527    }
1528
1529    fn simulate_response(&self, rng: &mut impl ScalarRng) -> Vec<Self::Response> {
1530        let response = match self {
1531            ComposedRelation::Simple(p) => ComposedResponse::Simple(p.simulate_response(rng)),
1532            ComposedRelation::And(ps) => {
1533                let responses = ps
1534                    .iter()
1535                    .map(|p| {
1536                        let mut r = p.simulate_response(rng);
1537                        r.pop().ok_or(Error::InvalidInstanceWitnessPair)
1538                    })
1539                    .collect::<Result<Vec<_>, _>>()
1540                    .expect("simulate_response invariant");
1541                ComposedResponse::And(responses)
1542            }
1543            ComposedRelation::Or(ps) => {
1544                let challenges = rng.random_scalars_vec::<G>(ps.len()).to_vec();
1545                let mut responses = Vec::with_capacity(ps.len());
1546                for p in ps.iter() {
1547                    let mut r = p.simulate_response(&mut *rng);
1548                    let resp = r
1549                        .pop()
1550                        .expect("simulate_response should return at least one element");
1551                    responses.push(resp);
1552                }
1553                ComposedResponse::Or(challenges, responses)
1554            }
1555            ComposedRelation::Threshold(threshold, ps) => {
1556                if *threshold == 0 || *threshold > ps.len() {
1557                    return vec![ComposedResponse::Threshold(Vec::new(), Vec::new())];
1558                }
1559
1560                let degree = ps.len() - *threshold;
1561                let compressed_challenges = rng.random_scalars_vec::<G>(degree).to_vec();
1562                let mut responses = Vec::with_capacity(ps.len());
1563                for p in ps.iter() {
1564                    let mut r = p.simulate_response(&mut *rng);
1565                    let response = r
1566                        .pop()
1567                        .expect("simulate_response should return at least one element");
1568                    responses.push(response);
1569                }
1570                ComposedResponse::Threshold(compressed_challenges, responses)
1571            }
1572        };
1573        vec![response]
1574    }
1575
1576    fn simulate_transcript(
1577        &self,
1578        rng: &mut impl ScalarRng,
1579    ) -> Result<(Vec<Self::Commitment>, Self::Challenge, Vec<Self::Response>), Error> {
1580        match self {
1581            ComposedRelation::Simple(p) => {
1582                let (c, ch, r) = p.simulate_transcript(rng)?;
1583                Ok((
1584                    vec![ComposedCommitment::Simple(c)],
1585                    ch,
1586                    vec![ComposedResponse::Simple(r)],
1587                ))
1588            }
1589            ComposedRelation::And(ps) => {
1590                let [challenge] = rng.random_scalars::<G, _>();
1591                let mut responses = Vec::with_capacity(ps.len());
1592                for p in ps.iter() {
1593                    let mut resp = p.simulate_response(&mut *rng);
1594                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1595                    if !resp.is_empty() {
1596                        return Err(Error::InvalidInstanceWitnessPair);
1597                    }
1598                    responses.push(response);
1599                }
1600                let commitments = ps
1601                    .iter()
1602                    .enumerate()
1603                    .map(|(i, p)| {
1604                        p.simulate_commitment(&challenge, &[responses[i].clone()])
1605                            .and_then(|mut c| {
1606                                let first = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1607                                if !c.is_empty() {
1608                                    return Err(Error::InvalidInstanceWitnessPair);
1609                                }
1610                                Ok(first)
1611                            })
1612                    })
1613                    .collect::<Result<Vec<_>, Error>>()?;
1614
1615                Ok((
1616                    vec![ComposedCommitment::And(commitments)],
1617                    challenge,
1618                    vec![ComposedResponse::And(responses)],
1619                ))
1620            }
1621            ComposedRelation::Or(ps) => {
1622                let challenges = rng.random_scalars_vec::<G>(ps.len() - 1);
1623                let mut responses = Vec::with_capacity(ps.len());
1624                for p in ps.iter() {
1625                    let mut resp = p.simulate_response(&mut *rng);
1626                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1627                    if !resp.is_empty() {
1628                        return Err(Error::InvalidInstanceWitnessPair);
1629                    }
1630                    responses.push(response);
1631                }
1632
1633                let mut commitments = Vec::with_capacity(ps.len());
1634                for i in 0..ps.len() {
1635                    let mut commitment = ps[i].simulate_commitment(
1636                        &if i == ps.len() - 1 {
1637                            challenges.iter().fold(G::Scalar::ZERO, |acc, x| acc - x)
1638                        } else {
1639                            challenges[i]
1640                        },
1641                        &[responses[i].clone()],
1642                    )?;
1643                    let commitment = commitment.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1644                    commitments.push(commitment);
1645                }
1646
1647                Ok((
1648                    vec![ComposedCommitment::Or(commitments)],
1649                    challenges.iter().sum::<G::Scalar>(),
1650                    vec![ComposedResponse::Or(challenges, responses)],
1651                ))
1652            }
1653            ComposedRelation::Threshold(threshold, ps) => {
1654                if *threshold == 0 || *threshold > ps.len() {
1655                    return Err(Error::InvalidInstanceWitnessPair);
1656                }
1657
1658                let degree = ps.len() - *threshold;
1659                let compressed_challenges = rng.random_scalars_vec::<G>(degree);
1660                let mut responses = Vec::with_capacity(ps.len());
1661                for p in ps.iter() {
1662                    let mut resp = p.simulate_response(&mut *rng);
1663                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1664                    if !resp.is_empty() {
1665                        return Err(Error::InvalidInstanceWitnessPair);
1666                    }
1667                    responses.push(response);
1668                }
1669
1670                let [challenge] = rng.random_scalars::<G, _>();
1671                let full_challenges = expand_threshold_challenges(
1672                    *threshold,
1673                    ps.len(),
1674                    challenge,
1675                    &compressed_challenges,
1676                )?;
1677                let commitments = ps
1678                    .iter()
1679                    .zip_eq(full_challenges.iter())
1680                    .zip_eq(responses.iter())
1681                    .map(|((p, ch), r)| {
1682                        p.simulate_commitment(ch, core::slice::from_ref(r))
1683                            .and_then(|mut c| {
1684                                let first = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1685                                if !c.is_empty() {
1686                                    return Err(Error::InvalidInstanceWitnessPair);
1687                                }
1688                                Ok(first)
1689                            })
1690                    })
1691                    .collect::<Result<Vec<_>, Error>>()?;
1692                Ok((
1693                    vec![ComposedCommitment::Threshold(commitments)],
1694                    challenge,
1695                    vec![ComposedResponse::Threshold(
1696                        compressed_challenges,
1697                        responses,
1698                    )],
1699                ))
1700            }
1701        }
1702    }
1703}
1704
1705impl<G> ComposedRelation<G>
1706where
1707    G: PrimeGroup
1708        + ConstantTimeEq
1709        + ConditionallySelectable
1710        + Encoding<[u8]>
1711        + NargSerialize
1712        + NargDeserialize
1713        + MultiScalarMul,
1714    G::Scalar:
1715        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1716{
1717    /// Convert this Protocol into a non-interactive zero-knowledge proof
1718    /// using the Shake128DuplexSponge codec and a specified session identifier.
1719    ///
1720    /// This method provides a convenient way to create a NIZK from a Protocol
1721    /// without exposing the specific codec type to the API caller.
1722    ///
1723    /// # Parameters
1724    /// - `session_identifier`: Domain separator bytes for the Fiat-Shamir transform
1725    ///
1726    /// # Returns
1727    /// A `Nizk` instance ready for proving and verification
1728    pub fn into_nizk(self, session_identifier: &[u8]) -> Nizk<ComposedRelation<G>> {
1729        Nizk::new(session_identifier, self)
1730    }
1731}