1use crate::error::{Error, Result};
7use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
8use curve25519_dalek::scalar::Scalar;
9
/// Byte-serializable commitment value, used as `SigmaProtocol::Commitment`.
///
/// Implementors must be `Clone + Send + Sync`. `PointCommitment` and
/// `MultiPointCommitment` are the implementations visible in this file.
pub trait Commitment: Clone + Send + Sync {
    /// Serializes the commitment into a freshly allocated byte vector.
    fn to_bytes(&self) -> Vec<u8>;

    /// Parses a commitment from `bytes`.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` is not an encoding the implementor
    /// accepts; the error variant is chosen by the implementor.
    // NOTE(review): presumably `from_bytes(&x.to_bytes())` is expected to
    // round-trip; the trait itself does not enforce this.
    fn from_bytes(bytes: &[u8]) -> Result<Self>
    where
        Self: Sized;
}
17
/// Byte-serializable challenge value, used as `SigmaProtocol::Challenge`.
///
/// Implementors must be `Clone + Send + Sync`. `ScalarChallenge` is the
/// implementation visible in this file.
pub trait Challenge: Clone + Send + Sync {
    /// Serializes the challenge into a freshly allocated byte vector.
    fn to_bytes(&self) -> Vec<u8>;

    /// Parses a challenge from `bytes`.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` is not an encoding the implementor
    /// accepts; the error variant is chosen by the implementor.
    // NOTE(review): presumably `from_bytes(&x.to_bytes())` is expected to
    // round-trip; the trait itself does not enforce this.
    fn from_bytes(bytes: &[u8]) -> Result<Self>
    where
        Self: Sized;
}
25
/// Byte-serializable response value, used as `SigmaProtocol::Response`.
///
/// Implementors must be `Clone + Send + Sync`. `ScalarResponse` and
/// `MultiScalarResponse` are the implementations visible in this file.
pub trait Response: Clone + Send + Sync {
    /// Serializes the response into a freshly allocated byte vector.
    fn to_bytes(&self) -> Vec<u8>;

    /// Parses a response from `bytes`.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` is not an encoding the implementor
    /// accepts; the error variant is chosen by the implementor.
    // NOTE(review): presumably `from_bytes(&x.to_bytes())` is expected to
    // round-trip; the trait itself does not enforce this.
    fn from_bytes(bytes: &[u8]) -> Result<Self>
    where
        Self: Sized;
}
33
/// Interface of a commit / challenge / response proof protocol.
///
/// The prover side is split into `prover_commit` and `prover_response`; the
/// verifier side is the single `verifier` check over the full transcript
/// (commitment, challenge, response).
pub trait SigmaProtocol {
    /// Public input that both `prover_*` functions and `verifier` receive.
    type Statement: Clone + Send + Sync;
    /// Input passed only to the `prover_*` functions, never to `verifier`.
    type Witness: Clone + Send + Sync;
    /// Value returned by `prover_commit` and checked by `verifier`.
    type Commitment: Commitment;
    /// Challenge consumed by `prover_response` and `verifier`.
    type Challenge: Challenge;
    /// Value returned by `prover_response` and checked by `verifier`.
    type Response: Response;

    /// Produces the prover's commitment for `statement` and `witness`.
    ///
    /// Returns the commitment together with an opaque byte vector.
    // NOTE(review): the `Vec<u8>` is presumably the serialized prover state
    // (e.g. commitment randomness) that must be handed back unchanged as
    // `state` to `prover_response` — confirm against the implementors.
    fn prover_commit(
        statement: &Self::Statement,
        witness: &Self::Witness,
    ) -> (Self::Commitment, Vec<u8>);

    /// Computes the prover's response to `challenge`.
    ///
    /// `state` is a byte slice supplied by the caller; see the note on
    /// `prover_commit` for its presumed origin.
    ///
    /// # Errors
    ///
    /// Failure conditions are implementation-defined.
    // NOTE(review): presumably fails when `state` cannot be decoded — confirm.
    fn prover_response(
        statement: &Self::Statement,
        witness: &Self::Witness,
        state: &[u8],
        challenge: &Self::Challenge,
    ) -> Result<Self::Response>;

    /// Checks a transcript (`commitment`, `challenge`, `response`) against
    /// `statement`.
    ///
    /// # Errors
    ///
    /// Failure conditions are implementation-defined.
    // NOTE(review): presumably `Ok(())` means "proof accepted" and any `Err`
    // means "rejected" — confirm which error variant implementors use.
    fn verifier(
        statement: &Self::Statement,
        commitment: &Self::Commitment,
        challenge: &Self::Challenge,
        response: &Self::Response,
    ) -> Result<()>;
}
80
81fn scalar_from_bytes(bytes: &[u8]) -> Result<Scalar> {
83 if bytes.len() != 32 {
84 return Err(Error::InvalidScalar);
85 }
86 let mut array = [0u8; 32];
87 array.copy_from_slice(bytes);
88
89 Scalar::from_canonical_bytes(array)
90 .into_option()
91 .ok_or(Error::InvalidScalar)
92}
93
94fn point_from_bytes(bytes: &[u8]) -> Result<RistrettoPoint> {
96 if bytes.len() != 32 {
97 return Err(Error::InvalidPoint);
98 }
99 let mut array = [0u8; 32];
100 array.copy_from_slice(bytes);
101
102 CompressedRistretto::from_slice(&array)
103 .map_err(|_| Error::InvalidPoint)?
104 .decompress()
105 .ok_or(Error::InvalidPoint)
106}
107
108#[derive(Clone, Debug)]
110pub struct ScalarChallenge(pub Scalar);
111
112impl Challenge for ScalarChallenge {
113 fn to_bytes(&self) -> Vec<u8> {
114 self.0.to_bytes().to_vec()
115 }
116
117 fn from_bytes(bytes: &[u8]) -> Result<Self> {
118 if bytes.len() != 32 {
119 return Err(Error::InvalidChallenge);
120 }
121 let mut array = [0u8; 32];
122 array.copy_from_slice(bytes);
123 Ok(ScalarChallenge(Scalar::from_bytes_mod_order(array)))
125 }
126}
127
128#[derive(Clone, Debug)]
130pub struct PointCommitment(pub RistrettoPoint);
131
132impl Commitment for PointCommitment {
133 fn to_bytes(&self) -> Vec<u8> {
134 self.0.compress().to_bytes().to_vec()
135 }
136
137 fn from_bytes(bytes: &[u8]) -> Result<Self> {
138 point_from_bytes(bytes)
139 .map(PointCommitment)
140 .map_err(|_| Error::InvalidCommitment)
141 }
142}
143
144#[derive(Clone, Debug)]
146pub struct ScalarResponse(pub Scalar);
147
148impl Response for ScalarResponse {
149 fn to_bytes(&self) -> Vec<u8> {
150 self.0.to_bytes().to_vec()
151 }
152
153 fn from_bytes(bytes: &[u8]) -> Result<Self> {
154 scalar_from_bytes(bytes)
155 .map(ScalarResponse)
156 .map_err(|_| Error::InvalidResponse)
157 }
158}
159
160#[derive(Clone, Debug)]
162pub struct MultiPointCommitment(pub Vec<RistrettoPoint>);
163
164impl Commitment for MultiPointCommitment {
165 fn to_bytes(&self) -> Vec<u8> {
166 let mut bytes = Vec::with_capacity(self.0.len() * 32 + 4);
167 bytes.extend_from_slice(&(self.0.len() as u32).to_le_bytes());
168
169 for point in &self.0 {
170 bytes.extend_from_slice(&point.compress().to_bytes());
171 }
172
173 bytes
174 }
175
176 fn from_bytes(bytes: &[u8]) -> Result<Self> {
177 if bytes.len() < 4 {
178 return Err(Error::InvalidCommitment);
179 }
180
181 let len = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
182
183 if bytes.len() != 4 + len * 32 {
184 return Err(Error::InvalidCommitment);
185 }
186
187 let mut points = Vec::with_capacity(len);
188 for i in 0..len {
189 let start = 4 + i * 32;
190 let end = start + 32;
191 points.push(point_from_bytes(&bytes[start..end])?)
192 }
193
194 Ok(MultiPointCommitment(points))
195 }
196}
197
198#[derive(Clone, Debug)]
200pub struct MultiScalarResponse(pub Vec<Scalar>);
201
202impl Response for MultiScalarResponse {
203 fn to_bytes(&self) -> Vec<u8> {
204 let mut bytes = Vec::with_capacity(self.0.len() * 32 + 4);
205 bytes.extend_from_slice(&(self.0.len() as u32).to_le_bytes());
206
207 for scalar in &self.0 {
208 bytes.extend_from_slice(&scalar.to_bytes());
209 }
210
211 bytes
212 }
213
214 fn from_bytes(bytes: &[u8]) -> Result<Self> {
215 if bytes.len() < 4 {
216 return Err(Error::InvalidResponse);
217 }
218
219 let len = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
220
221 if bytes.len() != 4 + len * 32 {
222 return Err(Error::InvalidResponse);
223 }
224
225 let mut scalars = Vec::with_capacity(len);
226 for i in 0..len {
227 let start = 4 + i * 32;
228 let end = start + 32;
229 scalars.push(scalar_from_bytes(&bytes[start..end])?)
230 }
231
232 Ok(MultiScalarResponse(scalars))
233 }
234}