1use curve25519_dalek::{RistrettoPoint, Scalar};
5use elliptic_curve::{group::GroupEncoding, subtle::ConstantTimeEq, Group};
6use rand::{CryptoRng, Rng, RngCore, SeedableRng};
7use rand_chacha::ChaCha20Rng;
8use sl_mpc_derive::{
9 impls::ristretto::VrfGroup,
10 math::{get_lagrange_coeff, participant_public_share},
11 ro_hash_string, ED25519_VRF_OUTPUT_BITS,
12};
13
14use crate::{
15 crypto::{calculate_final_session_id_pairs, hash_consistency, validate_input_messages},
16 dh_tuple::{dh_tuple_transcript, DhTuplePoints, DhTupleProof},
17 messages::{VrfMsg0, VrfMsg1},
18 types::{SessionId, VrfError},
19};
20
21#[derive(Clone, Debug, PartialEq, Eq)]
23pub struct VrfOutput {
24 pub output: Vec<u8>,
25 pub session_id: SessionId,
26 pub pid_list: Vec<u8>,
27}
28
29#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
31#[derive(Clone)]
32pub struct Context {
33 party_id: u8,
34 threshold: u8,
35 total_parties: u8,
36 message: Vec<u8>,
37 output_bits: usize,
38 shamir_share: Scalar,
39 public_key: RistrettoPoint,
40 party_public_shares: Vec<RistrettoPoint>,
41
42 session_id: SessionId,
43 seed: [u8; 32],
44 keygen_session_id: Option<SessionId>,
46
47 h_con: [u8; 32],
48 final_session_id: SessionId,
49 pid_list: Vec<u8>,
50 m_point: RistrettoPoint,
51
52 msg0_generated: bool,
53 round0_complete: bool,
54}
55
56impl Context {
57 #[allow(clippy::too_many_arguments)]
58 pub fn new<R: RngCore + CryptoRng>(
59 party_id: u8,
60 threshold: u8,
61 total_parties: u8,
62 message: Vec<u8>,
63 shamir_share: Scalar,
64 public_key: RistrettoPoint,
65 party_public_shares: Vec<RistrettoPoint>,
66 rng: &mut R,
67 ) -> Result<Self, VrfError> {
68 Self::new_with_output_bits(
69 party_id,
70 threshold,
71 total_parties,
72 message,
73 ED25519_VRF_OUTPUT_BITS,
74 shamir_share,
75 public_key,
76 party_public_shares,
77 None,
78 rng,
79 )
80 }
81
82 #[allow(clippy::too_many_arguments)]
83 pub fn new_with_output_bits<R: RngCore + CryptoRng>(
84 party_id: u8,
85 threshold: u8,
86 total_parties: u8,
87 message: Vec<u8>,
88 output_bits: usize,
89 shamir_share: Scalar,
90 public_key: RistrettoPoint,
91 party_public_shares: Vec<RistrettoPoint>,
92 keygen_session_id: Option<SessionId>,
93 rng: &mut R,
94 ) -> Result<Self, VrfError> {
95 if party_id >= total_parties {
96 return Err(VrfError::InvalidKeyshare);
97 }
98 if party_public_shares.len() != total_parties as usize {
99 return Err(VrfError::InvalidKeyshare);
100 }
101
102 Ok(Self {
103 party_id,
104 threshold,
105 total_parties,
106 message,
107 output_bits,
108 shamir_share,
109 public_key,
110 party_public_shares,
111 session_id: rng.gen(),
112 seed: rng.gen(),
113 keygen_session_id,
114 h_con: [0u8; 32],
115 final_session_id: [0u8; 32],
116 pid_list: Vec::new(),
117 m_point: RistrettoPoint::identity(),
118 msg0_generated: false,
119 round0_complete: false,
120 })
121 }
122
123 pub fn party_id(&self) -> u8 {
124 self.party_id
125 }
126
127 pub fn h_con(&self) -> &[u8; 32] {
128 &self.h_con
129 }
130
131 pub fn final_session_id(&self) -> &SessionId {
132 &self.final_session_id
133 }
134
135 pub fn pid_list(&self) -> &[u8] {
136 &self.pid_list
137 }
138
139 pub fn round0_out(&mut self) -> Result<VrfMsg0, VrfError> {
141 let vrf_ki = get_lagrange_coeff::<RistrettoPoint>(&self.party_id, 0..self.total_parties)
142 * self.shamir_share;
143 let ki = self.party_public_shares[self.party_id as usize];
144 let expected_ki = RistrettoPoint::generator() * vrf_ki;
145 if !bool::from(expected_ki.ct_eq(&ki)) {
146 return Err(VrfError::InvalidLocalKey);
147 }
148
149 let mut sum_k = RistrettoPoint::identity();
150 for share in &self.party_public_shares {
151 sum_k += *share;
152 }
153 if !bool::from(sum_k.ct_eq(&self.public_key)) {
154 return Err(VrfError::InvalidPublicShares);
155 }
156
157 self.h_con = hash_consistency(
158 &self.public_key,
159 &self.party_public_shares,
160 &self.message,
161 self.output_bits,
162 );
163 self.msg0_generated = true;
164
165 Ok(VrfMsg0 {
166 from_party: self.party_id,
167 session_id: self.session_id,
168 h_con: self.h_con,
169 })
170 }
171
172 pub fn round0_in(
177 &mut self,
178 messages: Vec<VrfMsg0>,
179 quorum: Option<&[u8]>,
180 ) -> Result<VrfMsg1, VrfError> {
181 if !self.msg0_generated {
182 return Err(VrfError::InvalidState);
183 }
184
185 let mut messages = messages;
186 messages.sort_by_key(|m| m.from_party);
187
188 if let Some(quorum) = quorum {
189 if messages.len() != quorum.len() {
190 return Err(VrfError::InvalidMsgCount);
191 }
192 let mut expected = quorum.to_vec();
193 expected.sort_unstable();
194 let mut actual: Vec<u8> = messages.iter().map(|m| m.from_party).collect();
195 actual.sort_unstable();
196 if actual != expected {
197 return Err(VrfError::InvalidParticipantSet);
198 }
199 } else if messages.len() != self.threshold as usize {
200 return Err(VrfError::InvalidMsgCount);
201 }
202
203 let local_msg = messages
204 .iter()
205 .find(|msg| msg.from_party == self.party_id)
206 .ok_or(VrfError::InvalidParticipantSet)?;
207
208 if local_msg.session_id != self.session_id {
209 return Err(VrfError::InvalidParticipantSet);
210 }
211
212 let mut party_ids: Vec<u8> = messages.iter().map(|m| m.from_party).collect();
213 party_ids.sort_unstable();
214 party_ids.dedup();
215 if party_ids.len() != messages.len() {
216 return Err(VrfError::InvalidParticipantSet);
217 }
218
219 for msg in &messages {
220 if msg.from_party >= self.total_parties {
221 return Err(VrfError::InvalidMsgPartyId);
222 }
223 if msg.h_con != self.h_con {
224 return Err(VrfError::ConsistencyHashMismatch(msg.from_party));
225 }
226 }
227
228 let pairs: Vec<_> = messages
229 .iter()
230 .map(|m| (m.from_party, m.session_id))
231 .collect();
232 let extra = self.session_id_extra();
233 self.final_session_id = calculate_final_session_id_pairs(pairs, &extra);
234 self.pid_list = party_ids;
235
236 self.m_point = RistrettoPoint::hash_vrf_message(&[self.message.as_slice()])
237 .map_err(|_| VrfError::HashToCurve)?;
238
239 let coeff =
240 get_lagrange_coeff::<RistrettoPoint>(&self.party_id, self.pid_list.iter().copied());
241 let vrf_ki = coeff * self.shamir_share;
242 let party_ki = RistrettoPoint::generator() * vrf_ki;
243
244 let aux = (self.party_id as u32).to_be_bytes();
245 let mut proof_rng = ChaCha20Rng::from_seed(self.seed);
246 let z_i = self.m_point * vrf_ki;
247 let mut transcript = dh_tuple_transcript(&self.final_session_id, &aux);
248 let pi = DhTupleProof::prove(
249 DhTuplePoints {
250 g: &RistrettoPoint::generator(),
251 q: &self.m_point,
252 a: &party_ki,
253 b: &z_i,
254 },
255 &vrf_ki,
256 &mut transcript,
257 &mut proof_rng,
258 );
259
260 self.round0_complete = true;
261
262 Ok(VrfMsg1 {
263 from_party: self.party_id,
264 session_id: self.final_session_id,
265 z_i: z_i.compress().to_bytes().to_vec(),
266 pi,
267 })
268 }
269
270 pub fn round1_in(&self, messages: Vec<VrfMsg1>) -> Result<VrfOutput, VrfError> {
272 if !self.round0_complete {
273 return Err(VrfError::InvalidState);
274 }
275 let messages = validate_input_messages(messages, &self.pid_list)?;
276
277 let g = RistrettoPoint::generator();
278 let mut z = RistrettoPoint::identity();
279
280 for msg in &messages {
281 if msg.from_party >= self.total_parties {
282 return Err(VrfError::InvalidMsgPartyId);
283 }
284 if msg.session_id != self.final_session_id {
285 return Err(VrfError::InvalidDhProof(msg.from_party));
286 }
287 let z_j = decode_point(&msg.z_i).ok_or(VrfError::InvalidZ(msg.from_party))?;
288 if !RistrettoPoint::is_valid_partial_vrf_point(&z_j) {
289 return Err(VrfError::InvalidZ(msg.from_party));
290 }
291
292 let aux = (msg.from_party as u32).to_be_bytes();
293 let k_j = participant_public_share(
294 &self.party_public_shares[msg.from_party as usize],
295 msg.from_party,
296 self.total_parties,
297 self.pid_list.iter().copied(),
298 );
299 let mut transcript = dh_tuple_transcript(&self.final_session_id, &aux);
300 if !msg.pi.verify(
301 DhTuplePoints {
302 g: &g,
303 q: &self.m_point,
304 a: &k_j,
305 b: &z_j,
306 },
307 &mut transcript,
308 ) {
309 return Err(VrfError::InvalidDhProof(msg.from_party));
310 }
311
312 z += z_j;
313 }
314
315 let output = ro_hash_string(&[z.compress().as_bytes()], self.output_bits);
316
317 Ok(VrfOutput {
318 output,
319 session_id: self.final_session_id,
320 pid_list: self.pid_list.clone(),
321 })
322 }
323
324 fn session_id_extra(&self) -> Vec<&[u8]> {
325 match self.keygen_session_id {
326 Some(ref sid) => vec![self.message.as_slice(), sid.as_slice()],
327 None => vec![self.message.as_slice()],
328 }
329 }
330}
331
332fn decode_point(bytes: &[u8]) -> Option<RistrettoPoint> {
333 let mut encoding = <RistrettoPoint as GroupEncoding>::Repr::default();
334 if encoding.as_ref().len() != bytes.len() {
335 return None;
336 }
337 encoding.as_mut().copy_from_slice(bytes);
338 Option::from(RistrettoPoint::from_bytes(&encoding))
339}
340
#[cfg(test)]
pub(crate) mod test_support {
    use super::*;
    use crate::dkg::VrfKeyshare;

    /// Builds one evaluation [`Context`] per keyshare, all over the same
    /// `message`, drawing randomness from the thread-local RNG.
    ///
    /// # Panics
    ///
    /// Panics if any keyshare is rejected by [`Context::new`].
    pub fn init_eval_states(shares: &[VrfKeyshare], message: &[u8]) -> Vec<Context> {
        let mut rng = rand::thread_rng();
        let mut contexts = Vec::with_capacity(shares.len());
        for share in shares {
            let ctx = Context::new(
                share.party_id,
                share.threshold,
                share.total_parties,
                message.to_vec(),
                *share.shamir_share(),
                *share.public_key(),
                share.party_public_shares().to_vec(),
                &mut rng,
            )
            .unwrap();
            contexts.push(ctx);
        }
        contexts
    }

    /// Runs the full evaluation protocol in-process.
    ///
    /// Every party produces a round-0 message. Only the first `threshold`
    /// parties, using the first `threshold` messages, run the remaining rounds.
    /// Returns one [`VrfOutput`] per active party.
    ///
    /// # Panics
    ///
    /// Panics if any protocol step fails.
    pub fn vrf_eval_inner(mut parties: Vec<Context>, threshold: usize) -> Vec<VrfOutput> {
        let mut round0_msgs = Vec::with_capacity(parties.len());
        for party in parties.iter_mut() {
            round0_msgs.push(party.round0_out().unwrap());
        }
        round0_msgs.truncate(threshold);
        parties.truncate(threshold);

        let mut round1_msgs = Vec::with_capacity(parties.len());
        for party in parties.iter_mut() {
            round1_msgs.push(party.round0_in(round0_msgs.clone(), None).unwrap());
        }

        parties
            .iter()
            .map(|party| party.round1_in(round1_msgs.clone()).unwrap())
            .collect()
    }
}
385
#[cfg(test)]
mod tests {
    use super::test_support::{init_eval_states, vrf_eval_inner};
    use super::{Context, VrfError, VrfOutput};
    use crate::dkg::test_support::{init_states, vrf_dkg_inner};

    /// Asserts that every output agrees with the first one on the VRF value,
    /// the session id and the participant list.
    fn assert_matching_outputs(outputs: &[VrfOutput]) {
        let first = &outputs[0];
        for other in outputs.iter().skip(1) {
            assert_eq!(&other.output, &first.output);
            assert_eq!(other.session_id, first.session_id);
            assert_eq!(other.pid_list, first.pid_list);
        }
    }

    #[test]
    fn vrf_eval_3_out_of_3() {
        let keyshares = vrf_dkg_inner(init_states(3, 3));
        let outputs = vrf_eval_inner(init_eval_states(&keyshares, b"vrf-eval-3x3"), 3);
        assert_eq!(outputs.len(), 3);
        assert_matching_outputs(&outputs);
    }

    #[test]
    fn vrf_eval_2_out_of_3() {
        let keyshares = vrf_dkg_inner(init_states(3, 2));
        let outputs = vrf_eval_inner(init_eval_states(&keyshares, b"vrf-eval-2x3"), 2);
        assert_eq!(outputs.len(), 2);
        assert_matching_outputs(&outputs);
    }

    #[test]
    fn vrf_eval_same_message_same_output() {
        let keyshares = vrf_dkg_inner(init_states(3, 2));
        let msg = b"deterministic-input";

        // Two independent sessions with fresh randomness must agree on the output.
        let run = || vrf_eval_inner(init_eval_states(&keyshares, msg), 2);
        let (first_run, second_run) = (run(), run());

        assert_eq!(first_run[0].output, second_run[0].output);
    }

    #[test]
    fn new_rejects_malformed_keyshare_public_shares_len() {
        let keyshares = vrf_dkg_inner(init_states(3, 2));
        let share = &keyshares[0];
        let mut short_shares = share.party_public_shares().to_vec();
        short_shares.pop();

        let result = Context::new(
            share.party_id,
            share.threshold,
            share.total_parties,
            b"msg".to_vec(),
            *share.shamir_share(),
            *share.public_key(),
            short_shares,
            &mut rand::thread_rng(),
        );
        assert!(matches!(result, Err(VrfError::InvalidKeyshare)));
    }

    #[test]
    fn new_rejects_malformed_keyshare_party_id() {
        let keyshares = vrf_dkg_inner(init_states(3, 2));
        let share = &keyshares[0];

        // A party id equal to `total_parties` is one past the last valid index.
        let result = Context::new(
            share.total_parties,
            share.threshold,
            share.total_parties,
            b"msg".to_vec(),
            *share.shamir_share(),
            *share.public_key(),
            share.party_public_shares().to_vec(),
            &mut rand::thread_rng(),
        );
        assert!(matches!(result, Err(VrfError::InvalidKeyshare)));
    }
}