Skip to main content

sigma_proofs/
composition.rs

//! # Protocol Composition with AND/OR/Threshold Logic
//!
//! This module defines the [`ComposedRelation`] enum, which generalizes [`CanonicalLinearRelation`]
//! by enabling compositional logic between multiple proof instances.
//!
//! Specifically, it supports:
//! - Simple atomic proofs (e.g., discrete logarithm, Pedersen commitments)
//! - Conjunctions (`And`) of multiple sub-protocols
//! - Disjunctions (`Or`) of multiple sub-protocols
//! - Thresholds (`Threshold`) over multiple sub-protocols
//!
//! ## Example Composition
//!
//! ```ignore
//! And(
//!    Or(dleq, pedersen_commitment),
//!    Simple(discrete_logarithm),
//!    And(pedersen_commitment_dleq, bbs_blind_commitment_computation)
//! )
//! ```
21
22use alloc::{vec, vec::Vec};
23use ff::{Field, PrimeField};
24use group::prime::PrimeGroup;
25use sha3::{Digest, Sha3_256};
26use spongefish::{
27    Decoding, Encoding, NargDeserialize, NargSerialize, VerificationError, VerificationResult,
28};
29use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
30
31use crate::errors::InvalidInstance;
32use crate::traits::ScalarRng;
33use crate::MultiScalarMul;
34use crate::{
35    errors::Error,
36    fiat_shamir::Nizk,
37    linear_relation::{CanonicalLinearRelation, LinearRelation},
38    traits::{SigmaProtocol, SigmaProtocolSimulator},
39};
40
41/// A protocol proving knowledge of a witness for a composition of linear relations.
42///
43/// This implementation generalizes [`CanonicalLinearRelation`] by using AND/OR links.
44///
45/// # Type Parameters
46/// - `G`: A cryptographic group implementing [`group::Group`] and [`group::GroupEncoding`].
47#[derive(Clone)]
48pub enum ComposedRelation<G: PrimeGroup> {
49    Simple(CanonicalLinearRelation<G>),
50    And(Vec<ComposedRelation<G>>),
51    Or(Vec<ComposedRelation<G>>),
52    Threshold(usize, Vec<ComposedRelation<G>>),
53}
54
55impl<G: PrimeGroup + ConstantTimeEq + ConditionallySelectable> ComposedRelation<G> {
56    /// Create a [ComposedRelation] for an AND relation from the given list of relations.
57    pub fn and<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
58        Self::And(witness.into_iter().map(|x| x.into()).collect())
59    }
60
61    /// Create a [ComposedRelation] for an OR relation from the given list of relations.
62    pub fn or<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
63        Self::Or(witness.into_iter().map(|x| x.into()).collect())
64    }
65
66    /// Create a [ComposedRelation] for a threshold relation from the given list of relations.
67    pub fn threshold<T: Into<ComposedRelation<G>>>(
68        threshold: usize,
69        witness: impl IntoIterator<Item = T>,
70    ) -> Self {
71        Self::Threshold(threshold, witness.into_iter().map(|x| x.into()).collect())
72    }
73}
74
75impl<G: PrimeGroup> From<CanonicalLinearRelation<G>> for ComposedRelation<G> {
76    fn from(value: CanonicalLinearRelation<G>) -> Self {
77        ComposedRelation::Simple(value)
78    }
79}
80
81impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for ComposedRelation<G> {
82    type Error = InvalidInstance;
83
84    fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
85        Ok(Self::Simple(CanonicalLinearRelation::try_from(value)?))
86    }
87}
88
89// Structure representing the Commitment type of Protocol as SigmaProtocol
90#[derive(Clone)]
91pub enum ComposedCommitment<G>
92where
93    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
94    G::Scalar:
95        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
96{
97    Simple(Vec<G>),
98    And(Vec<ComposedCommitment<G>>),
99    Or(Vec<ComposedCommitment<G>>),
100    Threshold(Vec<ComposedCommitment<G>>),
101}
102
103impl<G: PrimeGroup> ComposedCommitment<G>
104where
105    G: ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
106    G::Scalar:
107        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
108{
109    /// Conditionally select between two ComposedCommitment values.
110    /// This function performs constant-time selection of the commitment values.
111    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
112        match (a, b) {
113            (ComposedCommitment::Simple(a_elements), ComposedCommitment::Simple(b_elements)) => {
114                // Both vectors must have the same length for this to work
115                debug_assert_eq!(a_elements.len(), b_elements.len());
116                let selected: Vec<G> = a_elements
117                    .iter()
118                    .zip(b_elements.iter())
119                    .map(|(a, b)| G::conditional_select(a, b, choice))
120                    .collect();
121                ComposedCommitment::Simple(selected)
122            }
123            (ComposedCommitment::And(a_commitments), ComposedCommitment::And(b_commitments)) => {
124                debug_assert_eq!(a_commitments.len(), b_commitments.len());
125                let selected: Vec<ComposedCommitment<G>> = a_commitments
126                    .iter()
127                    .zip(b_commitments.iter())
128                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
129                    .collect();
130                ComposedCommitment::And(selected)
131            }
132            (ComposedCommitment::Or(a_commitments), ComposedCommitment::Or(b_commitments)) => {
133                debug_assert_eq!(a_commitments.len(), b_commitments.len());
134                let selected: Vec<ComposedCommitment<G>> = a_commitments
135                    .iter()
136                    .zip(b_commitments.iter())
137                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
138                    .collect();
139                ComposedCommitment::Or(selected)
140            }
141            (
142                ComposedCommitment::Threshold(a_commitments),
143                ComposedCommitment::Threshold(b_commitments),
144            ) => {
145                debug_assert_eq!(a_commitments.len(), b_commitments.len());
146                let selected: Vec<ComposedCommitment<G>> = a_commitments
147                    .iter()
148                    .zip(b_commitments.iter())
149                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
150                    .collect();
151                ComposedCommitment::Threshold(selected)
152            }
153            _ => {
154                unreachable!("Mismatched ComposedCommitment variants in conditional_select");
155            }
156        }
157    }
158}
159
160// Structure representing the ProverState type of Protocol as SigmaProtocol
161pub enum ComposedProverState<G>
162where
163    G: PrimeGroup
164        + ConstantTimeEq
165        + ConditionallySelectable
166        + Encoding<[u8]>
167        + NargSerialize
168        + NargDeserialize
169        + MultiScalarMul,
170    G::Scalar:
171        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
172{
173    Simple(<CanonicalLinearRelation<G> as SigmaProtocol>::ProverState),
174    And(Vec<ComposedProverState<G>>),
175    Or(ComposedOrProverState<G>),
176    Threshold(ComposedThresholdProverState<G>),
177}
178
179pub type ComposedOrProverState<G> = Vec<ComposedOrProverStateEntry<G>>;
180pub struct ComposedOrProverStateEntry<G>(
181    Choice,
182    ComposedProverState<G>,
183    ComposedChallenge<G>,
184    ComposedResponse<G>,
185)
186where
187    G: PrimeGroup
188        + ConstantTimeEq
189        + ConditionallySelectable
190        + Encoding<[u8]>
191        + NargSerialize
192        + NargDeserialize
193        + MultiScalarMul,
194    G::Scalar:
195        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable;
196
197pub type ComposedThresholdProverState<G> = Vec<ComposedThresholdProverStateEntry<G>>;
198pub struct ComposedThresholdProverStateEntry<G>
199where
200    G: PrimeGroup
201        + ConstantTimeEq
202        + ConditionallySelectable
203        + Encoding<[u8]>
204        + NargSerialize
205        + NargDeserialize
206        + MultiScalarMul,
207    G::Scalar:
208        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
209{
210    use_simulator: Choice,
211    prover_state: ComposedProverState<G>,
212    simulated_challenge: ComposedChallenge<G>,
213    simulated_response: ComposedResponse<G>,
214}
215
216// Structure representing the Response type of Protocol as SigmaProtocol
217#[derive(Clone)]
218pub enum ComposedResponse<G>
219where
220    G: PrimeGroup
221        + ConditionallySelectable
222        + Encoding<[u8]>
223        + NargSerialize
224        + NargDeserialize
225        + MultiScalarMul,
226    G::Scalar:
227        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
228{
229    Simple(Vec<<CanonicalLinearRelation<G> as SigmaProtocol>::Response>),
230    And(Vec<ComposedResponse<G>>),
231    Or(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>),
232    Threshold(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>),
233}
234
/// Wire tag that starts an encoded `Simple` node (commitments and responses).
const TAG_SIMPLE: u8 = 0x00;
/// Wire tag that starts an encoded `And` node.
const TAG_AND: u8 = 0x01;
/// Wire tag that starts an encoded `Or` node.
const TAG_OR: u8 = 0x02;
/// Wire tag that starts an encoded `Threshold` node.
const TAG_THRESHOLD: u8 = 0x03;
239
240fn read_u32(buf: &mut &[u8]) -> VerificationResult<u32> {
241    if buf.len() < 4 {
242        return Err(VerificationError);
243    }
244    let (head, tail) = buf.split_at(4);
245    *buf = tail;
246    Ok(u32::from_le_bytes(head.try_into().unwrap()))
247}
248
/// Appends `len` to `out` as a 4-byte little-endian `u32` length prefix.
///
/// # Panics
/// Panics if `len` does not fit in a `u32`. A silently truncated prefix would produce an
/// encoding that decodes to a different structure, so this is treated as a broken invariant.
fn write_len(out: &mut Vec<u8>, len: usize) {
    let len = u32::try_from(len).expect("length prefix exceeds u32::MAX");
    out.extend_from_slice(&len.to_le_bytes());
}
252
253impl<G> Encoding<[u8]> for ComposedCommitment<G>
254where
255    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
256    G::Scalar:
257        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
258{
259    fn encode(&self) -> impl AsRef<[u8]> {
260        let mut out = Vec::new();
261        match self {
262            ComposedCommitment::Simple(elems) => {
263                out.push(TAG_SIMPLE);
264                write_len(&mut out, elems.len());
265                for elem in elems {
266                    elem.serialize_into_narg(&mut out);
267                }
268            }
269            ComposedCommitment::And(cs) => {
270                out.push(TAG_AND);
271                write_len(&mut out, cs.len());
272                for c in cs {
273                    c.serialize_into_narg(&mut out);
274                }
275            }
276            ComposedCommitment::Or(cs) => {
277                out.push(TAG_OR);
278                write_len(&mut out, cs.len());
279                for c in cs {
280                    c.serialize_into_narg(&mut out);
281                }
282            }
283            ComposedCommitment::Threshold(cs) => {
284                out.push(TAG_THRESHOLD);
285                write_len(&mut out, cs.len());
286                for c in cs {
287                    c.serialize_into_narg(&mut out);
288                }
289            }
290        }
291        out
292    }
293}
294
295impl<G> NargDeserialize for ComposedCommitment<G>
296where
297    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
298    G::Scalar:
299        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
300{
301    fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
302        if buf.is_empty() {
303            return Err(VerificationError);
304        }
305        let (tag_bytes, rest) = buf.split_at(1);
306        *buf = rest;
307        match tag_bytes[0] {
308            TAG_SIMPLE => {
309                let len = read_u32(buf)? as usize;
310                let mut elems = Vec::with_capacity(len);
311                for _ in 0..len {
312                    elems.push(G::deserialize_from_narg(buf)?);
313                }
314                Ok(ComposedCommitment::Simple(elems))
315            }
316            TAG_AND => {
317                let len = read_u32(buf)? as usize;
318                let mut entries = Vec::with_capacity(len);
319                for _ in 0..len {
320                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
321                }
322                Ok(ComposedCommitment::And(entries))
323            }
324            TAG_OR => {
325                let len = read_u32(buf)? as usize;
326                let mut entries = Vec::with_capacity(len);
327                for _ in 0..len {
328                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
329                }
330                Ok(ComposedCommitment::Or(entries))
331            }
332            TAG_THRESHOLD => {
333                let len = read_u32(buf)? as usize;
334                let mut entries = Vec::with_capacity(len);
335                for _ in 0..len {
336                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
337                }
338                Ok(ComposedCommitment::Threshold(entries))
339            }
340            _ => Err(VerificationError),
341        }
342    }
343}
344
345impl<G> Encoding<[u8]> for ComposedResponse<G>
346where
347    G: PrimeGroup
348        + ConditionallySelectable
349        + Encoding<[u8]>
350        + NargSerialize
351        + NargDeserialize
352        + MultiScalarMul,
353    G::Scalar:
354        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
355{
356    fn encode(&self) -> impl AsRef<[u8]> {
357        let mut out = Vec::new();
358        match self {
359            ComposedResponse::Simple(responses) => {
360                out.push(TAG_SIMPLE);
361                write_len(&mut out, responses.len());
362                for r in responses {
363                    r.serialize_into_narg(&mut out);
364                }
365            }
366            ComposedResponse::And(entries) => {
367                out.push(TAG_AND);
368                write_len(&mut out, entries.len());
369                for r in entries {
370                    r.serialize_into_narg(&mut out);
371                }
372            }
373            ComposedResponse::Or(challenges, responses) => {
374                out.push(TAG_OR);
375                write_len(&mut out, challenges.len());
376                for c in challenges {
377                    c.serialize_into_narg(&mut out);
378                }
379                write_len(&mut out, responses.len());
380                for r in responses {
381                    r.serialize_into_narg(&mut out);
382                }
383            }
384            ComposedResponse::Threshold(challenges, responses) => {
385                out.push(TAG_THRESHOLD);
386                write_len(&mut out, challenges.len());
387                for c in challenges {
388                    c.serialize_into_narg(&mut out);
389                }
390                write_len(&mut out, responses.len());
391                for r in responses {
392                    r.serialize_into_narg(&mut out);
393                }
394            }
395        }
396        out
397    }
398}
399
400impl<G> NargDeserialize for ComposedResponse<G>
401where
402    G: PrimeGroup
403        + ConditionallySelectable
404        + Encoding<[u8]>
405        + NargSerialize
406        + NargDeserialize
407        + MultiScalarMul,
408    G::Scalar:
409        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
410{
411    fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
412        if buf.is_empty() {
413            return Err(VerificationError);
414        }
415        let (tag_bytes, rest) = buf.split_at(1);
416        *buf = rest;
417        match tag_bytes[0] {
418            TAG_SIMPLE => {
419                let len = read_u32(buf)? as usize;
420                let mut elems = Vec::with_capacity(len);
421                for _ in 0..len {
422                    elems.push(G::Scalar::deserialize_from_narg(buf)?);
423                }
424                Ok(ComposedResponse::Simple(elems))
425            }
426            TAG_AND => {
427                let len = read_u32(buf)? as usize;
428                let mut entries = Vec::with_capacity(len);
429                for _ in 0..len {
430                    entries.push(ComposedResponse::deserialize_from_narg(buf)?);
431                }
432                Ok(ComposedResponse::And(entries))
433            }
434            TAG_OR => {
435                let ch_len = read_u32(buf)? as usize;
436                let mut challenges = Vec::with_capacity(ch_len);
437                for _ in 0..ch_len {
438                    challenges.push(G::Scalar::deserialize_from_narg(buf)?);
439                }
440                let resp_len = read_u32(buf)? as usize;
441                let mut responses = Vec::with_capacity(resp_len);
442                for _ in 0..resp_len {
443                    responses.push(ComposedResponse::deserialize_from_narg(buf)?);
444                }
445                Ok(ComposedResponse::Or(challenges, responses))
446            }
447            TAG_THRESHOLD => {
448                let ch_len = read_u32(buf)? as usize;
449                let mut challenges = Vec::with_capacity(ch_len);
450                for _ in 0..ch_len {
451                    challenges.push(G::Scalar::deserialize_from_narg(buf)?);
452                }
453                let resp_len = read_u32(buf)? as usize;
454                let mut responses = Vec::with_capacity(resp_len);
455                for _ in 0..resp_len {
456                    responses.push(ComposedResponse::deserialize_from_narg(buf)?);
457                }
458                Ok(ComposedResponse::Threshold(challenges, responses))
459            }
460            _ => Err(VerificationError),
461        }
462    }
463}
464
465impl<G> ComposedResponse<G>
466where
467    G: PrimeGroup
468        + ConditionallySelectable
469        + Encoding<[u8]>
470        + NargSerialize
471        + NargDeserialize
472        + MultiScalarMul,
473    G::Scalar:
474        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
475{
476    /// Conditionally select between two ComposedResponse values.
477    /// This function performs constant-time selection of the response values.
478    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
479        match (a, b) {
480            (ComposedResponse::Simple(a_scalars), ComposedResponse::Simple(b_scalars)) => {
481                // Both vectors must have the same length for this to work
482                debug_assert_eq!(a_scalars.len(), b_scalars.len());
483                let selected: Vec<G::Scalar> = a_scalars
484                    .iter()
485                    .zip(b_scalars.iter())
486                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
487                    .collect();
488                ComposedResponse::Simple(selected)
489            }
490            (ComposedResponse::And(a_responses), ComposedResponse::And(b_responses)) => {
491                debug_assert_eq!(a_responses.len(), b_responses.len());
492                let selected: Vec<ComposedResponse<G>> = a_responses
493                    .iter()
494                    .zip(b_responses.iter())
495                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
496                    .collect();
497                ComposedResponse::And(selected)
498            }
499            (
500                ComposedResponse::Or(a_challenges, a_responses),
501                ComposedResponse::Or(b_challenges, b_responses),
502            ) => {
503                debug_assert_eq!(a_challenges.len(), b_challenges.len());
504                debug_assert_eq!(a_responses.len(), b_responses.len());
505
506                let selected_challenges: Vec<ComposedChallenge<G>> = a_challenges
507                    .iter()
508                    .zip(b_challenges.iter())
509                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
510                    .collect();
511
512                let selected_responses: Vec<ComposedResponse<G>> = a_responses
513                    .iter()
514                    .zip(b_responses.iter())
515                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
516                    .collect();
517
518                ComposedResponse::Or(selected_challenges, selected_responses)
519            }
520            (
521                ComposedResponse::Threshold(a_challenges, a_responses),
522                ComposedResponse::Threshold(b_challenges, b_responses),
523            ) => {
524                debug_assert_eq!(a_challenges.len(), b_challenges.len());
525                debug_assert_eq!(a_responses.len(), b_responses.len());
526
527                let selected_challenges: Vec<ComposedChallenge<G>> = a_challenges
528                    .iter()
529                    .zip(b_challenges.iter())
530                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
531                    .collect();
532
533                let selected_responses: Vec<ComposedResponse<G>> = a_responses
534                    .iter()
535                    .zip(b_responses.iter())
536                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
537                    .collect();
538
539                ComposedResponse::Threshold(selected_challenges, selected_responses)
540            }
541            _ => {
542                unreachable!("Mismatched ComposedResponse variants in conditional_select");
543            }
544        }
545    }
546}
547
548// Structure representing the Witness type of Protocol as SigmaProtocol
549#[derive(Clone)]
550pub enum ComposedWitness<G>
551where
552    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
553    G::Scalar: Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]>,
554{
555    Simple(<CanonicalLinearRelation<G> as SigmaProtocol>::Witness),
556    And(Vec<ComposedWitness<G>>),
557    Or(Vec<ComposedWitness<G>>),
558    Threshold(Vec<ComposedWitness<G>>),
559}
560
561impl<G> ComposedWitness<G>
562where
563    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
564    G::Scalar: Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]>,
565{
566    /// Create a [ComposedWitness] for an AND relation from the given list of witnesses.
567    pub fn and<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
568        Self::And(witness.into_iter().map(|x| x.into()).collect())
569    }
570
571    /// Create a [ComposedWitness] for an OR relation from the given list of witnesses.
572    pub fn or<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
573        Self::Or(witness.into_iter().map(|x| x.into()).collect())
574    }
575
576    /// Create a [ComposedWitness] for a threshold relation from the given list of witnesses.
577    pub fn threshold<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
578        Self::Threshold(witness.into_iter().map(|x| x.into()).collect())
579    }
580}
581
582impl<G> From<<CanonicalLinearRelation<G> as SigmaProtocol>::Witness> for ComposedWitness<G>
583where
584    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
585    G::Scalar:
586        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
587{
588    fn from(value: <CanonicalLinearRelation<G> as SigmaProtocol>::Witness) -> Self {
589        Self::Simple(value)
590    }
591}
592
593type ComposedChallenge<G> = <CanonicalLinearRelation<G> as SigmaProtocol>::Challenge;
594fn threshold_x<F: PrimeField>(index: usize) -> F {
595    F::from((index + 1) as u64)
596}
597
598fn poly_mul_linear<F: Field>(coeffs: &[F], constant: F) -> Vec<F> {
599    let mut out = vec![F::ZERO; coeffs.len() + 1];
600    for (i, coeff) in coeffs.iter().enumerate() {
601        out[i] += *coeff * constant;
602        out[i + 1] += *coeff;
603    }
604    out
605}
606
607fn interpolate_polynomial<F: Field>(points: &[Evaluation<F>]) -> Result<Vec<F>, Error> {
608    if points.is_empty() {
609        return Err(Error::InvalidInstanceWitnessPair);
610    }
611
612    let mut coeffs = vec![F::ZERO; points.len()];
613
614    for (i, point_i) in points.iter().enumerate() {
615        let mut basis = vec![F::ONE];
616        let mut denom = F::ONE;
617
618        for (j, point_j) in points.iter().enumerate() {
619            if i == j {
620                continue;
621            }
622            denom *= point_i.x - point_j.x;
623            basis = poly_mul_linear::<F>(&basis, -point_j.x);
624        }
625
626        let denom_inv = denom.invert();
627        if denom_inv.is_none().into() {
628            return Err(Error::InvalidInstanceWitnessPair);
629        }
630        let scale = point_i.y * denom_inv.unwrap_or(F::ZERO);
631        for (coeff, basis_coeff) in coeffs.iter_mut().zip(basis.iter()) {
632            *coeff += *basis_coeff * scale;
633        }
634    }
635
636    Ok(coeffs)
637}
638
639fn evaluate_polynomial<F: Field>(coeffs: &[F], x: F) -> F {
640    coeffs
641        .iter()
642        .rev()
643        .fold(F::ZERO, |acc, coeff| acc * x + coeff)
644}
645
646fn expand_threshold_challenges<F: PrimeField>(
647    threshold: usize,
648    total: usize,
649    challenge: F,
650    compressed_challenges: &[F],
651) -> Result<Vec<F>, Error> {
652    if threshold == 0 || threshold > total {
653        return Err(Error::InvalidInstanceWitnessPair);
654    }
655
656    let degree = total - threshold;
657    if compressed_challenges.len() != degree {
658        return Err(Error::InvalidInstanceWitnessPair);
659    }
660
661    let mut points = Vec::with_capacity(degree + 1);
662    points.push(Evaluation {
663        x: F::ZERO,
664        y: challenge,
665    });
666    for (index, share) in compressed_challenges.iter().enumerate() {
667        points.push(Evaluation {
668            x: threshold_x::<F>(index),
669            y: *share,
670        });
671    }
672
673    let coeffs = interpolate_polynomial::<F>(&points)?;
674    let mut challenges = Vec::with_capacity(total);
675    for index in 0..total {
676        challenges.push(evaluate_polynomial::<F>(&coeffs, threshold_x::<F>(index)));
677    }
678
679    Ok(challenges)
680}
681
682fn count_choices(choices: &[Choice]) -> usize {
683    let mut sum: u32 = 0;
684    for choice in choices {
685        let inc = sum.wrapping_add(1);
686        sum = u32::conditional_select(&sum, &inc, *choice);
687    }
688    sum as usize
689}
690
/// A point `(x, y)`, used as an interpolation node where `y` is the value at `x`.
#[derive(Copy, Clone)]
struct Evaluation<T> {
    /// Evaluation point.
    x: T,
    /// Value at `x`.
    y: T,
}
696
697impl<T: ConditionallySelectable> ConditionallySelectable for Evaluation<T> {
698    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
699        Evaluation {
700            x: T::conditional_select(&a.x, &b.x, choice),
701            y: T::conditional_select(&a.y, &b.y, choice),
702        }
703    }
704}
705
706impl<T> From<(T, T)> for Evaluation<T> {
707    fn from(value: (T, T)) -> Self {
708        Evaluation {
709            x: value.0,
710            y: value.1,
711        }
712    }
713}
714
715fn conditional_swap_point<T: ConditionallySelectable>(
716    points: &mut [T],
717    left: usize,
718    right: usize,
719    swap: Choice,
720) {
721    if left == right {
722        return;
723    }
724    if left < right {
725        let (head, tail) = points.split_at_mut(right);
726        T::conditional_swap(&mut head[left], &mut tail[0], swap);
727    } else {
728        let (head, tail) = points.split_at_mut(left);
729        T::conditional_swap(&mut tail[0], &mut head[right], swap);
730    }
731}
732
733fn oroffcompact_points<T: ConditionallySelectable>(
734    points: &mut [T],
735    marks: &[Choice],
736    offset: usize,
737) {
738    let n = points.len();
739    if n <= 1 {
740        return;
741    }
742    debug_assert_eq!(n, marks.len());
743    debug_assert!(n.is_power_of_two());
744
745    let half = n / 2;
746    let mut m = 0usize;
747    for mark in &marks[..half] {
748        m += mark.unwrap_u8() as usize;
749    }
750
751    if n == 2 {
752        let z = Choice::from((offset & 1) as u8);
753        let b = ((!marks[0]) & marks[1]) ^ z;
754        conditional_swap_point(points, 0, 1, b);
755        return;
756    }
757
758    let offset_mod = offset % half;
759    oroffcompact_points(&mut points[..half], &marks[..half], offset_mod);
760    let offset_plus_m_mod = (offset + m) % half;
761    oroffcompact_points(&mut points[half..], &marks[half..], offset_plus_m_mod);
762
763    let s = Choice::from(((offset_mod + m) >= half) as u8) ^ Choice::from((offset >= half) as u8);
764    for i in 0..half {
765        let b = s ^ Choice::from((i >= offset_plus_m_mod) as u8);
766        conditional_swap_point(points, i, i + half, b);
767    }
768}
769
770fn oblivious_compact_points<T: ConditionallySelectable>(points: &mut [T], marks: &[Choice]) {
771    let n = points.len();
772    if n == 0 {
773        return;
774    }
775    debug_assert_eq!(n, marks.len());
776
777    let n1 = 1usize << (usize::BITS as usize - 1 - n.leading_zeros() as usize);
778    let n2 = n - n1;
779    let mut m = 0usize;
780    for mark in &marks[..n2] {
781        m += mark.unwrap_u8() as usize;
782    }
783
784    if n2 > 0 {
785        oblivious_compact_points(&mut points[..n2], &marks[..n2]);
786    }
787    oroffcompact_points(&mut points[n2..], &marks[n2..], (n1 - n2 + m) % n1);
788
789    for i in 0..n2 {
790        let b = Choice::from((i >= m) as u8);
791        conditional_swap_point(points, i, i + n1, b);
792    }
793}
794
795impl<G> ComposedRelation<G>
796where
797    G: PrimeGroup
798        + ConstantTimeEq
799        + ConditionallySelectable
800        + Encoding<[u8]>
801        + NargSerialize
802        + NargDeserialize
803        + MultiScalarMul,
804    G::Scalar:
805        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
806{
807    fn is_witness_valid(&self, witness: &ComposedWitness<G>) -> Choice {
808        match (self, witness) {
809            (ComposedRelation::Simple(instance), ComposedWitness::Simple(witness)) => {
810                instance.is_witness_valid(witness)
811            }
812            (ComposedRelation::And(instances), ComposedWitness::And(witnesses)) => instances
813                .iter()
814                .zip(witnesses)
815                .fold(Choice::from(1), |bit, (instance, witness)| {
816                    bit & instance.is_witness_valid(witness)
817                }),
818            (ComposedRelation::Or(instances), ComposedWitness::Or(witnesses)) => instances
819                .iter()
820                .zip(witnesses)
821                .fold(Choice::from(0), |bit, (instance, witness)| {
822                    bit | instance.is_witness_valid(witness)
823                }),
824            (
825                ComposedRelation::Threshold(threshold, instances),
826                ComposedWitness::Threshold(witnesses),
827            ) => {
828                if *threshold == 0 || instances.len() != witnesses.len() {
829                    return Choice::from(0);
830                }
831                let mut count = 0usize;
832                for (instance, witness) in instances.iter().zip(witnesses) {
833                    if instance.is_witness_valid(witness).unwrap_u8() == 1 {
834                        count += 1;
835                    }
836                }
837                Choice::from((count >= *threshold) as u8)
838            }
839            _ => Choice::from(0),
840        }
841    }
842
843    fn prover_commit_simple(
844        protocol: &CanonicalLinearRelation<G>,
845        witness: &<CanonicalLinearRelation<G> as SigmaProtocol>::Witness,
846        rng: &mut impl ScalarRng,
847    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error> {
848        protocol.prover_commit(witness, rng).map(|(c, s)| {
849            (
850                ComposedCommitment::Simple(c),
851                ComposedProverState::Simple(s),
852            )
853        })
854    }
855
856    fn prover_response_simple(
857        instance: &CanonicalLinearRelation<G>,
858        state: <CanonicalLinearRelation<G> as SigmaProtocol>::ProverState,
859        challenge: &<CanonicalLinearRelation<G> as SigmaProtocol>::Challenge,
860    ) -> Result<ComposedResponse<G>, Error> {
861        instance
862            .prover_response(state, challenge)
863            .map(ComposedResponse::Simple)
864    }
865
866    fn prover_commit_and(
867        protocols: &[ComposedRelation<G>],
868        witnesses: &[ComposedWitness<G>],
869        rng: &mut impl ScalarRng,
870    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error> {
871        if protocols.len() != witnesses.len() {
872            return Err(Error::InvalidInstanceWitnessPair);
873        }
874
875        let mut commitments = Vec::with_capacity(protocols.len());
876        let mut prover_states = Vec::with_capacity(protocols.len());
877
878        for (p, w) in protocols.iter().zip(witnesses.iter()) {
879            let (mut c, s) = p.prover_commit(w, rng)?;
880            let commitment = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
881            if !c.is_empty() {
882                return Err(Error::InvalidInstanceWitnessPair);
883            }
884            commitments.push(commitment);
885            prover_states.push(s);
886        }
887
888        Ok((
889            ComposedCommitment::And(commitments),
890            ComposedProverState::And(prover_states),
891        ))
892    }
893
894    fn prover_response_and(
895        instances: &[ComposedRelation<G>],
896        prover_state: Vec<ComposedProverState<G>>,
897        challenge: &ComposedChallenge<G>,
898    ) -> Result<ComposedResponse<G>, Error> {
899        if instances.len() != prover_state.len() {
900            return Err(Error::InvalidInstanceWitnessPair);
901        }
902
903        let responses: Result<Vec<_>, _> = instances
904            .iter()
905            .zip(prover_state)
906            .map(|(p, s)| {
907                let mut res = p.prover_response(s, challenge)?;
908                res.pop().ok_or(Error::InvalidInstanceWitnessPair)
909            })
910            .collect();
911
912        Ok(ComposedResponse::And(responses?))
913    }
914
915    fn prover_commit_or(
916        instances: &[ComposedRelation<G>],
917        witnesses: &[ComposedWitness<G>],
918        rng: &mut impl ScalarRng,
919    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error>
920    where
921        G: ConditionallySelectable,
922    {
923        if instances.len() != witnesses.len() {
924            return Err(Error::InvalidInstanceWitnessPair);
925        }
926
927        let mut commitments = Vec::new();
928        let mut prover_states = Vec::new();
929
930        // Selector value set when the first valid witness is found.
931        let mut valid_witness_found = Choice::from(0);
932        for (i, w) in witnesses.iter().enumerate() {
933            let (mut commitment_vec, prover_state) = instances[i].prover_commit(w, rng)?;
934            let commitment = commitment_vec
935                .pop()
936                .ok_or(Error::InvalidInstanceWitnessPair)?;
937            if !commitment_vec.is_empty() {
938                return Err(Error::InvalidInstanceWitnessPair);
939            }
940
941            let (mut simulated_commitment_vec, simulated_challenge, mut simulated_response_vec) =
942                instances[i].simulate_transcript(rng)?;
943            let simulated_commitment = simulated_commitment_vec
944                .pop()
945                .ok_or(Error::InvalidInstanceWitnessPair)?;
946            if !simulated_commitment_vec.is_empty() {
947                return Err(Error::InvalidInstanceWitnessPair);
948            }
949            let simulated_response = simulated_response_vec
950                .pop()
951                .ok_or(Error::InvalidInstanceWitnessPair)?;
952            if !simulated_response_vec.is_empty() {
953                return Err(Error::InvalidInstanceWitnessPair);
954            }
955
956            let valid_witness = instances[i].is_witness_valid(w) & !valid_witness_found;
957            let select_witness = valid_witness;
958
959            let commitment = ComposedCommitment::conditional_select(
960                &simulated_commitment,
961                &commitment,
962                select_witness,
963            );
964
965            commitments.push(commitment);
966            prover_states.push(ComposedOrProverStateEntry(
967                select_witness,
968                prover_state,
969                simulated_challenge,
970                simulated_response,
971            ));
972
973            valid_witness_found |= valid_witness;
974        }
975
976        if valid_witness_found.unwrap_u8() == 0 {
977            Err(Error::InvalidInstanceWitnessPair)
978        } else {
979            Ok((
980                ComposedCommitment::Or(commitments),
981                ComposedProverState::Or(prover_states),
982            ))
983        }
984    }
985
986    fn prover_response_or(
987        instances: &[ComposedRelation<G>],
988        prover_state: ComposedOrProverState<G>,
989        challenge: &ComposedChallenge<G>,
990    ) -> Result<ComposedResponse<G>, Error> {
991        let mut result_challenges = Vec::with_capacity(instances.len());
992        let mut result_responses = Vec::with_capacity(instances.len());
993
994        let mut witness_challenge = *challenge;
995        for ComposedOrProverStateEntry(
996            valid_witness,
997            _prover_state,
998            simulated_challenge,
999            _simulated_response,
1000        ) in &prover_state
1001        {
1002            let c = G::Scalar::conditional_select(
1003                simulated_challenge,
1004                &G::Scalar::ZERO,
1005                *valid_witness,
1006            );
1007            witness_challenge -= c;
1008        }
1009        for (
1010            instance,
1011            ComposedOrProverStateEntry(
1012                valid_witness,
1013                prover_state,
1014                simulated_challenge,
1015                simulated_response,
1016            ),
1017        ) in instances.iter().zip(prover_state)
1018        {
1019            let challenge_i = G::Scalar::conditional_select(
1020                &simulated_challenge,
1021                &witness_challenge,
1022                valid_witness,
1023            );
1024
1025            let mut response_vec = instance.prover_response(prover_state, &challenge_i)?;
1026            let response = response_vec
1027                .pop()
1028                .ok_or(Error::InvalidInstanceWitnessPair)?;
1029            if !response_vec.is_empty() {
1030                return Err(Error::InvalidInstanceWitnessPair);
1031            }
1032            let response =
1033                ComposedResponse::conditional_select(&simulated_response, &response, valid_witness);
1034
1035            result_challenges.push(challenge_i);
1036            result_responses.push(response.clone());
1037        }
1038
1039        result_challenges.pop();
1040        Ok(ComposedResponse::Or(result_challenges, result_responses))
1041    }
1042
1043    fn prover_commit_threshold(
1044        threshold: usize,
1045        instances: &[ComposedRelation<G>],
1046        witnesses: &[ComposedWitness<G>],
1047        rng: &mut impl ScalarRng,
1048    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error>
1049    where
1050        G: ConditionallySelectable,
1051    {
1052        if instances.len() != witnesses.len() || threshold == 0 || threshold > instances.len() {
1053            return Err(Error::InvalidInstanceWitnessPair);
1054        }
1055        let degree = instances.len() - threshold;
1056
1057        let valid_witnesses = instances
1058            .iter()
1059            .zip(witnesses.iter())
1060            .map(|(x, w)| x.is_witness_valid(w))
1061            .collect::<Vec<Choice>>();
1062
1063        // Degree-(t-1) interpolation can only satisfy t fixed points.
1064        let invalid_count = instances.len() - count_choices(&valid_witnesses);
1065        if invalid_count > degree {
1066            return Err(Error::InvalidInstanceWitnessPair);
1067        }
1068
1069        let mut remaining_seeds = (degree - invalid_count) as u32;
1070        let mut commitments = Vec::with_capacity(instances.len());
1071        let mut prover_states = Vec::with_capacity(instances.len());
1072        for (i, (instance, witness)) in instances.iter().zip(witnesses.iter()).enumerate() {
1073            let (mut commitment_vec, prover_state) = instance.prover_commit(witness, rng)?;
1074            let commitment = commitment_vec
1075                .pop()
1076                .ok_or(Error::InvalidInstanceWitnessPair)?;
1077            if !commitment_vec.is_empty() {
1078                return Err(Error::InvalidInstanceWitnessPair);
1079            }
1080
1081            let (mut simulated_commitments, simulated_challenge, mut simulated_responses) =
1082                instance.simulate_transcript(rng)?;
1083            let simulated_commitment = simulated_commitments
1084                .pop()
1085                .ok_or(Error::InvalidInstanceWitnessPair)?;
1086            if !simulated_commitments.is_empty() {
1087                return Err(Error::InvalidInstanceWitnessPair);
1088            }
1089            let simulated_response = simulated_responses
1090                .pop()
1091                .ok_or(Error::InvalidInstanceWitnessPair)?;
1092            if !simulated_responses.is_empty() {
1093                return Err(Error::InvalidInstanceWitnessPair);
1094            }
1095
1096            let valid_witness = valid_witnesses[i];
1097            let should_seed = valid_witness & Choice::from((remaining_seeds != 0) as u8);
1098            remaining_seeds = remaining_seeds.wrapping_sub(should_seed.unwrap_u8() as u32);
1099            let use_simulator = (!valid_witness) | should_seed;
1100            let commitment = ComposedCommitment::conditional_select(
1101                &commitment,
1102                &simulated_commitment,
1103                use_simulator,
1104            );
1105            commitments.push(commitment);
1106            prover_states.push(ComposedThresholdProverStateEntry {
1107                use_simulator,
1108                prover_state,
1109                simulated_challenge,
1110                simulated_response,
1111            });
1112        }
1113
1114        Ok((
1115            ComposedCommitment::Threshold(commitments),
1116            ComposedProverState::Threshold(prover_states),
1117        ))
1118    }
1119
1120    fn prover_response_threshold(
1121        threshold: usize,
1122        instances: &[ComposedRelation<G>],
1123        prover_states: ComposedThresholdProverState<G>,
1124        challenge: &ComposedChallenge<G>,
1125    ) -> Result<ComposedResponse<G>, Error> {
1126        if threshold == 0 || threshold > instances.len() || instances.len() != prover_states.len() {
1127            return Err(Error::InvalidInstanceWitnessPair);
1128        }
1129        let degree = instances.len() - threshold;
1130
1131        let marks = prover_states
1132            .iter()
1133            .map(|entry| entry.use_simulator)
1134            .collect::<Vec<_>>();
1135        debug_assert_eq!(count_choices(&marks), degree);
1136
1137        let mut points = prover_states
1138            .iter()
1139            .enumerate()
1140            .map(|(i, entry)| Evaluation {
1141                x: threshold_x::<G::Scalar>(i),
1142                y: entry.simulated_challenge,
1143            })
1144            .collect::<Vec<Evaluation<G::Scalar>>>();
1145        oblivious_compact_points(&mut points, &marks);
1146        points.drain(degree..);
1147
1148        let mut full_points = Vec::with_capacity(degree + 1);
1149        full_points.push(Evaluation {
1150            x: G::Scalar::ZERO,
1151            y: *challenge,
1152        });
1153        full_points.extend_from_slice(&points);
1154
1155        let coeffs = interpolate_polynomial::<G::Scalar>(&full_points)?;
1156        let mut compressed_challenges = Vec::with_capacity(degree);
1157        for index in 0..degree {
1158            compressed_challenges.push(evaluate_polynomial::<G::Scalar>(
1159                &coeffs,
1160                threshold_x::<G::Scalar>(index),
1161            ));
1162        }
1163
1164        let expanded_challenges = expand_threshold_challenges::<G::Scalar>(
1165            threshold,
1166            instances.len(),
1167            *challenge,
1168            &compressed_challenges,
1169        )?;
1170
1171        let mut responses = Vec::with_capacity(instances.len());
1172
1173        for (i, (instance, prover_state)) in instances.iter().zip(prover_states).enumerate() {
1174            let poly_challenge = expanded_challenges[i];
1175            let challenge = G::Scalar::conditional_select(
1176                &poly_challenge,
1177                &prover_state.simulated_challenge,
1178                prover_state.use_simulator,
1179            );
1180
1181            let mut response_vec =
1182                instance.prover_response(prover_state.prover_state, &challenge)?;
1183            let response = response_vec
1184                .pop()
1185                .ok_or(Error::InvalidInstanceWitnessPair)?;
1186            if !response_vec.is_empty() {
1187                return Err(Error::InvalidInstanceWitnessPair);
1188            }
1189            let response = ComposedResponse::conditional_select(
1190                &response,
1191                &prover_state.simulated_response,
1192                prover_state.use_simulator,
1193            );
1194
1195            responses.push(response);
1196        }
1197
1198        Ok(ComposedResponse::Threshold(
1199            compressed_challenges,
1200            responses,
1201        ))
1202    }
1203}
1204
1205impl<G> SigmaProtocol for ComposedRelation<G>
1206where
1207    G: PrimeGroup
1208        + ConstantTimeEq
1209        + ConditionallySelectable
1210        + Encoding<[u8]>
1211        + NargSerialize
1212        + NargDeserialize
1213        + MultiScalarMul,
1214    G::Scalar:
1215        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1216{
1217    type Commitment = ComposedCommitment<G>;
1218    type ProverState = ComposedProverState<G>;
1219    type Response = ComposedResponse<G>;
1220    type Witness = ComposedWitness<G>;
1221    type Challenge = ComposedChallenge<G>;
1222
1223    fn prover_commit(
1224        &self,
1225        witness: &Self::Witness,
1226        rng: &mut impl ScalarRng,
1227    ) -> Result<(Vec<Self::Commitment>, Self::ProverState), Error> {
1228        let (commitment, state) = match (self, witness) {
1229            (ComposedRelation::Simple(p), ComposedWitness::Simple(w)) => {
1230                Self::prover_commit_simple(p, w, rng)
1231            }
1232            (ComposedRelation::And(ps), ComposedWitness::And(ws)) => {
1233                Self::prover_commit_and(ps, ws, rng)
1234            }
1235            (ComposedRelation::Or(ps), ComposedWitness::Or(witnesses)) => {
1236                Self::prover_commit_or(ps, witnesses, rng)
1237            }
1238            (ComposedRelation::Threshold(threshold, ps), ComposedWitness::Threshold(witnesses)) => {
1239                Self::prover_commit_threshold(*threshold, ps, witnesses, rng)
1240            }
1241            _ => Err(Error::InvalidInstanceWitnessPair),
1242        }?;
1243        Ok((vec![commitment], state))
1244    }
1245
1246    fn prover_response(
1247        &self,
1248        state: Self::ProverState,
1249        challenge: &Self::Challenge,
1250    ) -> Result<Vec<Self::Response>, Error> {
1251        let response = match (self, state) {
1252            (ComposedRelation::Simple(instance), ComposedProverState::Simple(state)) => {
1253                Self::prover_response_simple(instance, state, challenge)
1254            }
1255            (ComposedRelation::And(instances), ComposedProverState::And(prover_state)) => {
1256                Self::prover_response_and(instances, prover_state, challenge)
1257            }
1258            (ComposedRelation::Or(instances), ComposedProverState::Or(prover_state)) => {
1259                Self::prover_response_or(instances, prover_state, challenge)
1260            }
1261            (
1262                ComposedRelation::Threshold(threshold, instances),
1263                ComposedProverState::Threshold(prover_state),
1264            ) => Self::prover_response_threshold(*threshold, instances, prover_state, challenge),
1265            _ => Err(Error::InvalidInstanceWitnessPair),
1266        }?;
1267        Ok(vec![response])
1268    }
1269
1270    fn verifier(
1271        &self,
1272        commitment: &[Self::Commitment],
1273        challenge: &Self::Challenge,
1274        response: &[Self::Response],
1275    ) -> Result<(), Error> {
1276        let (commitment, response) = match (commitment.first(), response.first()) {
1277            (Some(c), Some(r)) => (c, r),
1278            _ => return Err(Error::InvalidInstanceWitnessPair),
1279        };
1280
1281        match (self, commitment, response) {
1282            (
1283                ComposedRelation::Simple(p),
1284                ComposedCommitment::Simple(c),
1285                ComposedResponse::Simple(r),
1286            ) => p.verifier(c, challenge, r),
1287            (
1288                ComposedRelation::And(ps),
1289                ComposedCommitment::And(commitments),
1290                ComposedResponse::And(responses),
1291            ) => {
1292                if ps.len() != commitments.len() || commitments.len() != responses.len() {
1293                    return Err(Error::InvalidInstanceWitnessPair);
1294                }
1295                ps.iter()
1296                    .zip(commitments)
1297                    .zip(responses)
1298                    .try_for_each(|((p, c), r)| {
1299                        p.verifier(
1300                            core::slice::from_ref(c),
1301                            challenge,
1302                            core::slice::from_ref(r),
1303                        )
1304                    })
1305            }
1306            (
1307                ComposedRelation::Or(ps),
1308                ComposedCommitment::Or(commitments),
1309                ComposedResponse::Or(challenges, responses),
1310            ) => {
1311                if ps.len() != commitments.len() || commitments.len() != responses.len() {
1312                    return Err(Error::InvalidInstanceWitnessPair);
1313                }
1314                let last_challenge = *challenge - challenges.iter().sum::<G::Scalar>();
1315                ps.iter()
1316                    .zip(commitments)
1317                    .zip(challenges.iter().chain(&Some(last_challenge)))
1318                    .zip(responses)
1319                    .try_for_each(|(((p, commitment), challenge), response)| {
1320                        p.verifier(
1321                            core::slice::from_ref(commitment),
1322                            challenge,
1323                            core::slice::from_ref(response),
1324                        )
1325                    })
1326            }
1327            (
1328                ComposedRelation::Threshold(threshold, ps),
1329                ComposedCommitment::Threshold(commitments),
1330                ComposedResponse::Threshold(challenges, responses),
1331            ) => {
1332                if *threshold == 0
1333                    || *threshold > ps.len()
1334                    || commitments.len() != ps.len()
1335                    || challenges.len() != ps.len() - *threshold
1336                    || responses.len() != ps.len()
1337                {
1338                    return Err(Error::InvalidInstanceWitnessPair);
1339                }
1340
1341                let full_challenges = expand_threshold_challenges::<G::Scalar>(
1342                    *threshold,
1343                    ps.len(),
1344                    *challenge,
1345                    challenges,
1346                )?;
1347
1348                ps.iter()
1349                    .zip(commitments)
1350                    .zip(full_challenges.iter())
1351                    .zip(responses)
1352                    .try_for_each(|(((p, commitment), challenge), response)| {
1353                        p.verifier(
1354                            core::slice::from_ref(commitment),
1355                            challenge,
1356                            core::slice::from_ref(response),
1357                        )
1358                    })
1359            }
1360            _ => Err(Error::InvalidInstanceWitnessPair),
1361        }
1362    }
1363
1364    fn commitment_len(&self) -> usize {
1365        1
1366    }
1367
1368    fn response_len(&self) -> usize {
1369        1
1370    }
1371
1372    fn instance_label(&self) -> impl AsRef<[u8]> {
1373        match self {
1374            ComposedRelation::Simple(p) => {
1375                let label = p.instance_label();
1376                label.as_ref().to_vec()
1377            }
1378            ComposedRelation::And(ps) => {
1379                let mut bytes = Vec::new();
1380                for p in ps {
1381                    bytes.extend(p.instance_label().as_ref());
1382                }
1383                bytes
1384            }
1385            ComposedRelation::Or(ps) => {
1386                let mut bytes = Vec::new();
1387                for p in ps {
1388                    bytes.extend(p.instance_label().as_ref());
1389                }
1390                bytes
1391            }
1392            ComposedRelation::Threshold(threshold, ps) => {
1393                let mut bytes = Vec::new();
1394                bytes.extend_from_slice(&((*threshold as u64).to_le_bytes()));
1395                for p in ps {
1396                    bytes.extend(p.instance_label().as_ref());
1397                }
1398                bytes
1399            }
1400        }
1401    }
1402
1403    fn protocol_identifier(&self) -> [u8; 64] {
1404        let mut hasher = Sha3_256::new();
1405
1406        match self {
1407            ComposedRelation::Simple(p) => {
1408                // take the digest of the simple protocol id
1409                hasher.update([0u8; 32]);
1410                hasher.update(p.protocol_identifier());
1411            }
1412            ComposedRelation::And(protocols) => {
1413                hasher.update([1u8; 32]);
1414                for p in protocols {
1415                    hasher.update(p.protocol_identifier().as_ref());
1416                }
1417            }
1418            ComposedRelation::Or(protocols) => {
1419                hasher.update([2u8; 32]);
1420                for p in protocols {
1421                    hasher.update(p.protocol_identifier().as_ref());
1422                }
1423            }
1424            ComposedRelation::Threshold(threshold, protocols) => {
1425                hasher.update([3u8; 32]);
1426                hasher.update(((*threshold as u64).to_le_bytes()).as_ref());
1427                for p in protocols {
1428                    hasher.update(p.protocol_identifier().as_ref());
1429                }
1430            }
1431        }
1432
1433        let mut protocol_id = [0u8; 64];
1434        protocol_id[..32].clone_from_slice(&hasher.finalize());
1435        protocol_id
1436    }
1437}
1438
1439impl<G> SigmaProtocolSimulator for ComposedRelation<G>
1440where
1441    G: PrimeGroup
1442        + ConstantTimeEq
1443        + ConditionallySelectable
1444        + Encoding<[u8]>
1445        + NargSerialize
1446        + NargDeserialize
1447        + MultiScalarMul,
1448    G::Scalar:
1449        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1450{
1451    fn simulate_commitment(
1452        &self,
1453        challenge: &Self::Challenge,
1454        response: &[Self::Response],
1455    ) -> Result<Vec<Self::Commitment>, Error> {
1456        let response = response.first().ok_or(Error::InvalidInstanceWitnessPair)?;
1457        let commitment = match (self, response) {
1458            (ComposedRelation::Simple(p), ComposedResponse::Simple(r)) => {
1459                ComposedCommitment::Simple(p.simulate_commitment(challenge, r)?)
1460            }
1461            (ComposedRelation::And(ps), ComposedResponse::And(rs)) => {
1462                let commitments = ps
1463                    .iter()
1464                    .zip(rs)
1465                    .map(|(p, r)| {
1466                        p.simulate_commitment(challenge, core::slice::from_ref(r))
1467                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1468                    })
1469                    .collect::<Result<Vec<_>, _>>()?;
1470                ComposedCommitment::And(commitments)
1471            }
1472            (ComposedRelation::Or(ps), ComposedResponse::Or(challenges, rs)) => {
1473                let last_challenge = *challenge - challenges.iter().sum::<G::Scalar>();
1474                let commitments = ps
1475                    .iter()
1476                    .zip(challenges.iter().chain(&Some(last_challenge)))
1477                    .zip(rs)
1478                    .map(|((p, ch), r)| {
1479                        p.simulate_commitment(ch, core::slice::from_ref(r))
1480                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1481                    })
1482                    .collect::<Result<Vec<_>, _>>()?;
1483                ComposedCommitment::Or(commitments)
1484            }
1485            (
1486                ComposedRelation::Threshold(threshold, ps),
1487                ComposedResponse::Threshold(challenges, rs),
1488            ) => {
1489                if rs.len() != ps.len() || challenges.len() != ps.len() - threshold {
1490                    return Err(Error::InvalidInstanceWitnessPair);
1491                }
1492
1493                let full_challenges = expand_threshold_challenges::<G::Scalar>(
1494                    *threshold,
1495                    ps.len(),
1496                    *challenge,
1497                    challenges,
1498                )?;
1499                let commitments = ps
1500                    .iter()
1501                    .zip(full_challenges.iter())
1502                    .zip(rs)
1503                    .map(|((p, ch), r)| {
1504                        p.simulate_commitment(ch, core::slice::from_ref(r))
1505                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1506                    })
1507                    .collect::<Result<Vec<_>, _>>()?;
1508                ComposedCommitment::Threshold(commitments)
1509            }
1510            _ => unreachable!(),
1511        };
1512
1513        Ok(vec![commitment])
1514    }
1515
1516    fn simulate_response(&self, rng: &mut impl ScalarRng) -> Vec<Self::Response> {
1517        let response = match self {
1518            ComposedRelation::Simple(p) => ComposedResponse::Simple(p.simulate_response(rng)),
1519            ComposedRelation::And(ps) => {
1520                let responses = ps
1521                    .iter()
1522                    .map(|p| {
1523                        let mut r = p.simulate_response(rng);
1524                        r.pop().ok_or(Error::InvalidInstanceWitnessPair)
1525                    })
1526                    .collect::<Result<Vec<_>, _>>()
1527                    .expect("simulate_response invariant");
1528                ComposedResponse::And(responses)
1529            }
1530            ComposedRelation::Or(ps) => {
1531                let challenges = rng.random_scalars_vec::<G>(ps.len()).to_vec();
1532                let mut responses = Vec::with_capacity(ps.len());
1533                for p in ps.iter() {
1534                    let mut r = p.simulate_response(&mut *rng);
1535                    let resp = r
1536                        .pop()
1537                        .expect("simulate_response should return at least one element");
1538                    responses.push(resp);
1539                }
1540                ComposedResponse::Or(challenges, responses)
1541            }
1542            ComposedRelation::Threshold(threshold, ps) => {
1543                if *threshold == 0 || *threshold > ps.len() {
1544                    return vec![ComposedResponse::Threshold(Vec::new(), Vec::new())];
1545                }
1546
1547                let degree = ps.len() - *threshold;
1548                let compressed_challenges = rng.random_scalars_vec::<G>(degree).to_vec();
1549                let mut responses = Vec::with_capacity(ps.len());
1550                for p in ps.iter() {
1551                    let mut r = p.simulate_response(&mut *rng);
1552                    let response = r
1553                        .pop()
1554                        .expect("simulate_response should return at least one element");
1555                    responses.push(response);
1556                }
1557                ComposedResponse::Threshold(compressed_challenges, responses)
1558            }
1559        };
1560        vec![response]
1561    }
1562
1563    fn simulate_transcript(
1564        &self,
1565        rng: &mut impl ScalarRng,
1566    ) -> Result<(Vec<Self::Commitment>, Self::Challenge, Vec<Self::Response>), Error> {
1567        match self {
1568            ComposedRelation::Simple(p) => {
1569                let (c, ch, r) = p.simulate_transcript(rng)?;
1570                Ok((
1571                    vec![ComposedCommitment::Simple(c)],
1572                    ch,
1573                    vec![ComposedResponse::Simple(r)],
1574                ))
1575            }
1576            ComposedRelation::And(ps) => {
1577                let [challenge] = rng.random_scalars::<G, _>();
1578                let mut responses = Vec::with_capacity(ps.len());
1579                for p in ps.iter() {
1580                    let mut resp = p.simulate_response(&mut *rng);
1581                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1582                    if !resp.is_empty() {
1583                        return Err(Error::InvalidInstanceWitnessPair);
1584                    }
1585                    responses.push(response);
1586                }
1587                let commitments = ps
1588                    .iter()
1589                    .enumerate()
1590                    .map(|(i, p)| {
1591                        p.simulate_commitment(&challenge, &[responses[i].clone()])
1592                            .and_then(|mut c| {
1593                                let first = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1594                                if !c.is_empty() {
1595                                    return Err(Error::InvalidInstanceWitnessPair);
1596                                }
1597                                Ok(first)
1598                            })
1599                    })
1600                    .collect::<Result<Vec<_>, Error>>()?;
1601
1602                Ok((
1603                    vec![ComposedCommitment::And(commitments)],
1604                    challenge,
1605                    vec![ComposedResponse::And(responses)],
1606                ))
1607            }
1608            ComposedRelation::Or(ps) => {
1609                let challenges = rng.random_scalars_vec::<G>(ps.len() - 1);
1610                let mut responses = Vec::with_capacity(ps.len());
1611                for p in ps.iter() {
1612                    let mut resp = p.simulate_response(&mut *rng);
1613                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1614                    if !resp.is_empty() {
1615                        return Err(Error::InvalidInstanceWitnessPair);
1616                    }
1617                    responses.push(response);
1618                }
1619
1620                let mut commitments = Vec::with_capacity(ps.len());
1621                for i in 0..ps.len() {
1622                    let mut commitment = ps[i].simulate_commitment(
1623                        &if i == ps.len() - 1 {
1624                            challenges.iter().fold(G::Scalar::ZERO, |acc, x| acc - x)
1625                        } else {
1626                            challenges[i]
1627                        },
1628                        &[responses[i].clone()],
1629                    )?;
1630                    let commitment = commitment.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1631                    commitments.push(commitment);
1632                }
1633
1634                Ok((
1635                    vec![ComposedCommitment::Or(commitments)],
1636                    challenges.iter().sum::<G::Scalar>(),
1637                    vec![ComposedResponse::Or(challenges, responses)],
1638                ))
1639            }
1640            ComposedRelation::Threshold(threshold, ps) => {
1641                if *threshold == 0 || *threshold > ps.len() {
1642                    return Err(Error::InvalidInstanceWitnessPair);
1643                }
1644
1645                let degree = ps.len() - *threshold;
1646                let compressed_challenges = rng.random_scalars_vec::<G>(degree);
1647                let mut responses = Vec::with_capacity(ps.len());
1648                for p in ps.iter() {
1649                    let mut resp = p.simulate_response(&mut *rng);
1650                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1651                    if !resp.is_empty() {
1652                        return Err(Error::InvalidInstanceWitnessPair);
1653                    }
1654                    responses.push(response);
1655                }
1656
1657                let [challenge] = rng.random_scalars::<G, _>();
1658                let full_challenges = expand_threshold_challenges(
1659                    *threshold,
1660                    ps.len(),
1661                    challenge,
1662                    &compressed_challenges,
1663                )?;
1664                let commitments = ps
1665                    .iter()
1666                    .zip(full_challenges.iter())
1667                    .zip(responses.iter())
1668                    .map(|((p, ch), r)| {
1669                        p.simulate_commitment(ch, core::slice::from_ref(r))
1670                            .and_then(|mut c| {
1671                                let first = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1672                                if !c.is_empty() {
1673                                    return Err(Error::InvalidInstanceWitnessPair);
1674                                }
1675                                Ok(first)
1676                            })
1677                    })
1678                    .collect::<Result<Vec<_>, Error>>()?;
1679                Ok((
1680                    vec![ComposedCommitment::Threshold(commitments)],
1681                    challenge,
1682                    vec![ComposedResponse::Threshold(
1683                        compressed_challenges,
1684                        responses,
1685                    )],
1686                ))
1687            }
1688        }
1689    }
1690}
1691
1692impl<G> ComposedRelation<G>
1693where
1694    G: PrimeGroup
1695        + ConstantTimeEq
1696        + ConditionallySelectable
1697        + Encoding<[u8]>
1698        + NargSerialize
1699        + NargDeserialize
1700        + MultiScalarMul,
1701    G::Scalar:
1702        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1703{
1704    /// Convert this Protocol into a non-interactive zero-knowledge proof
1705    /// using the Shake128DuplexSponge codec and a specified session identifier.
1706    ///
1707    /// This method provides a convenient way to create a NIZK from a Protocol
1708    /// without exposing the specific codec type to the API caller.
1709    ///
1710    /// # Parameters
1711    /// - `session_identifier`: Domain separator bytes for the Fiat-Shamir transform
1712    ///
1713    /// # Returns
1714    /// A `Nizk` instance ready for proving and verification
1715    pub fn into_nizk(self, session_identifier: &[u8]) -> Nizk<ComposedRelation<G>> {
1716        Nizk::new(session_identifier, self)
1717    }
1718}