1use crate::duplex_sponge::DuplexSpongeInterface;
4use crate::duplex_sponge::{keccak::KeccakDuplexSponge, shake::ShakeDuplexSponge};
5use alloc::vec;
6use ff::PrimeField;
7use group::prime::PrimeGroup;
8use num_bigint::BigUint;
9use num_traits::identities::One;
10
/// Interface for turning a sequence of prover messages into verifier
/// challenges.
///
/// An implementation is created once per proof from the protocol, session and
/// instance identifiers. It then receives prover messages in order and is
/// queried for challenges.
pub trait Codec {
    /// Type of value returned by [`Codec::verifier_challenge`].
    type Challenge;

    /// Builds a fresh codec bound to the given protocol identifier (exactly
    /// 64 bytes), session identifier and instance label.
    fn new(
        protocol_identifier: &[u8; 64],
        session_identifier: &[u8],
        instance_label: &[u8],
    ) -> Self;

    /// Feeds one prover message, given as raw bytes, into the codec.
    fn prover_message(&mut self, data: &[u8]);

    /// Produces the next verifier challenge from everything fed in so far.
    fn verifier_challenge(&mut self) -> Self::Challenge;
}
38
39fn cardinal<F: PrimeField>() -> BigUint {
40 let bytes = (F::ZERO - F::ONE).to_repr();
41 BigUint::from_bytes_le(bytes.as_ref()) + BigUint::one()
42}
43
44#[derive(Clone)]
49pub struct ByteSchnorrCodec<G, H>
50where
51 G: PrimeGroup,
52 H: DuplexSpongeInterface,
53{
54 hasher: H,
55 _marker: core::marker::PhantomData<G>,
56}
57
/// Width, in bytes, of the big-endian length prefix written before each
/// variable-length field absorbed into a sponge.
const WORD_SIZE: usize = core::mem::size_of::<u32>();

/// Encodes `x` as a `WORD_SIZE`-byte big-endian length prefix.
///
/// # Panics
///
/// Panics if `x` does not fit in a `u32`. Truncating instead (as an `as u32`
/// cast would) could give two fields of different lengths the same prefix,
/// breaking the unambiguous framing of the absorbed data.
fn length_to_bytes(x: usize) -> [u8; WORD_SIZE] {
    u32::try_from(x)
        .expect("length prefix overflow: field length exceeds u32::MAX")
        .to_be_bytes()
}
63
64pub fn compute_iv<H: DuplexSpongeInterface>(
69 protocol_id: &[u8; 64],
70 session_id: &[u8],
71 instance_label: &[u8],
72) -> [u8; 64] {
73 let mut tmp = H::new([0u8; 64]);
74 tmp.absorb(protocol_id);
75 tmp.absorb(&length_to_bytes(session_id.len()));
76 tmp.absorb(session_id);
77 tmp.absorb(&length_to_bytes(instance_label.len()));
78 tmp.absorb(instance_label);
79 tmp.squeeze(64).try_into().unwrap()
80}
81
82impl<G, H> Codec for ByteSchnorrCodec<G, H>
83where
84 G: PrimeGroup,
85 H: DuplexSpongeInterface,
86{
87 type Challenge = G::Scalar;
88
89 fn new(protocol_id: &[u8; 64], session_id: &[u8], instance_label: &[u8]) -> Self {
90 let mut hasher = H::new(*protocol_id);
91 hasher.absorb(&length_to_bytes(session_id.len()));
92 hasher.absorb(session_id);
93 hasher.absorb(&length_to_bytes(instance_label.len()));
94 hasher.absorb(instance_label);
95 Self {
96 hasher,
97 _marker: core::marker::PhantomData,
98 }
99 }
100
101 fn prover_message(&mut self, data: &[u8]) {
102 self.hasher.absorb(data);
103 }
104
105 fn verifier_challenge(&mut self) -> Self::Challenge {
106 #[allow(clippy::manual_div_ceil)]
107 let scalar_byte_length = (G::Scalar::NUM_BITS as usize + 7) / 8;
108
109 let uniform_bytes = self.hasher.squeeze(scalar_byte_length + 16);
110 let scalar = BigUint::from_bytes_be(&uniform_bytes);
111 let reduced = scalar % cardinal::<G::Scalar>();
112
113 let mut bytes = vec![0u8; scalar_byte_length];
114 let reduced_bytes = reduced.to_bytes_be();
115 let start = bytes.len() - reduced_bytes.len();
116 bytes[start..].copy_from_slice(&reduced_bytes);
117 bytes.reverse();
118
119 let mut repr = <G::Scalar as PrimeField>::Repr::default();
120 repr.as_mut().copy_from_slice(&bytes);
121
122 <G::Scalar as PrimeField>::from_repr(repr).expect("Error")
123 }
124}
125
126pub type KeccakByteSchnorrCodec<G> = ByteSchnorrCodec<G, KeccakDuplexSponge>;
129
130pub type Shake128DuplexSponge<G> = ByteSchnorrCodec<G, ShakeDuplexSponge>;