Skip to main content

sl_mpc_vrf/eval/
context.rs

// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
// This software is licensed under the Silence Laboratories License Agreement.
3
4use curve25519_dalek::{RistrettoPoint, Scalar};
5use elliptic_curve::{group::GroupEncoding, subtle::ConstantTimeEq, Group};
6use rand::{CryptoRng, Rng, RngCore, SeedableRng};
7use rand_chacha::ChaCha20Rng;
8use sl_mpc_derive::{
9    impls::ristretto::VrfGroup,
10    math::{get_lagrange_coeff, participant_public_share},
11    ro_hash_string, ED25519_VRF_OUTPUT_BITS,
12};
13
14use crate::{
15    crypto::{calculate_final_session_id_pairs, hash_consistency, validate_input_messages},
16    dh_tuple::{dh_tuple_transcript, DhTuplePoints, DhTupleProof},
17    messages::{VrfMsg0, VrfMsg1},
18    types::{SessionId, VrfError},
19};
20
/// Output of a successful MPC VRF evaluation.
///
/// Returned by `Context::round1_in`. Parties that complete round 1 with the
/// same set of round-1 messages derive equal values for every field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VrfOutput {
    /// VRF output bytes: `ro_hash_string` over the compressed sum of all verified
    /// partial points, requested with the context's `output_bits`.
    pub output: Vec<u8>,
    /// Final session id agreed on in round 0 (`Context::final_session_id`).
    pub session_id: SessionId,
    /// Sorted, de-duplicated ids of the parties that took part in the evaluation.
    pub pid_list: Vec<u8>,
}
28
/// Key material and round state for MPC VRF evaluation.
///
/// Created by `Context::new` / `Context::new_with_output_bits`, then driven
/// through `round0_out` -> `round0_in` -> `round1_in`.
///
/// NOTE(review): the type is `Clone` and (with the `serde` feature)
/// serializable. Proof randomness is derived deterministically from `seed`, so
/// a copy or restored snapshot taken before round 0 completes must never be
/// driven through round 0 a second time.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone)]
pub struct Context {
    /// This party's id; the constructor enforces `party_id < total_parties`.
    party_id: u8,
    /// Number of round-0 messages required when `round0_in` gets no explicit quorum.
    threshold: u8,
    /// Total number of key holders.
    total_parties: u8,
    /// VRF input message; hashed to `m_point` and mixed into the session id.
    message: Vec<u8>,
    /// Requested output length, passed to `hash_consistency` and `ro_hash_string`.
    output_bits: usize,
    /// This party's Shamir share of the secret key.
    shamir_share: Scalar,
    /// Aggregate public key; `round0_out` checks it equals the sum of `party_public_shares`.
    public_key: RistrettoPoint,
    /// One public share per party. `round0_out` checks that this party's own entry
    /// equals `G * (lagrange(party_id, 0..total_parties) * shamir_share)`.
    party_public_shares: Vec<RistrettoPoint>,

    /// Random per-context contribution to the final session id.
    session_id: SessionId,
    /// Random seed for the `ChaCha20Rng` that supplies DH-tuple proof randomness.
    seed: [u8; 32],
    /// When set (e.g. `keyshare-session-id` in multi-party-schnorr), mixed into the eval SID.
    keygen_session_id: Option<SessionId>,

    /// Consistency hash computed in `round0_out`.
    h_con: [u8; 32],
    /// Session id agreed on in `round0_in`.
    final_session_id: SessionId,
    /// Sorted, de-duplicated participant ids fixed in `round0_in`.
    pid_list: Vec<u8>,
    /// `message` hashed to the group in `round0_in`.
    m_point: RistrettoPoint,

    /// Set once `round0_out` has succeeded.
    msg0_generated: bool,
    /// Set once `round0_in` has produced this party's round-1 message.
    round0_complete: bool,
}
55
56impl Context {
57    #[allow(clippy::too_many_arguments)]
58    pub fn new<R: RngCore + CryptoRng>(
59        party_id: u8,
60        threshold: u8,
61        total_parties: u8,
62        message: Vec<u8>,
63        shamir_share: Scalar,
64        public_key: RistrettoPoint,
65        party_public_shares: Vec<RistrettoPoint>,
66        rng: &mut R,
67    ) -> Result<Self, VrfError> {
68        Self::new_with_output_bits(
69            party_id,
70            threshold,
71            total_parties,
72            message,
73            ED25519_VRF_OUTPUT_BITS,
74            shamir_share,
75            public_key,
76            party_public_shares,
77            None,
78            rng,
79        )
80    }
81
82    #[allow(clippy::too_many_arguments)]
83    pub fn new_with_output_bits<R: RngCore + CryptoRng>(
84        party_id: u8,
85        threshold: u8,
86        total_parties: u8,
87        message: Vec<u8>,
88        output_bits: usize,
89        shamir_share: Scalar,
90        public_key: RistrettoPoint,
91        party_public_shares: Vec<RistrettoPoint>,
92        keygen_session_id: Option<SessionId>,
93        rng: &mut R,
94    ) -> Result<Self, VrfError> {
95        if party_id >= total_parties {
96            return Err(VrfError::InvalidKeyshare);
97        }
98        if party_public_shares.len() != total_parties as usize {
99            return Err(VrfError::InvalidKeyshare);
100        }
101
102        Ok(Self {
103            party_id,
104            threshold,
105            total_parties,
106            message,
107            output_bits,
108            shamir_share,
109            public_key,
110            party_public_shares,
111            session_id: rng.gen(),
112            seed: rng.gen(),
113            keygen_session_id,
114            h_con: [0u8; 32],
115            final_session_id: [0u8; 32],
116            pid_list: Vec::new(),
117            m_point: RistrettoPoint::identity(),
118            msg0_generated: false,
119            round0_complete: false,
120        })
121    }
122
123    pub fn party_id(&self) -> u8 {
124        self.party_id
125    }
126
127    pub fn h_con(&self) -> &[u8; 32] {
128        &self.h_con
129    }
130
131    pub fn final_session_id(&self) -> &SessionId {
132        &self.final_session_id
133    }
134
135    pub fn pid_list(&self) -> &[u8] {
136        &self.pid_list
137    }
138
139    /// Round 0 outbound: consistency hash and session id contribution.
140    pub fn round0_out(&mut self) -> Result<VrfMsg0, VrfError> {
141        let vrf_ki = get_lagrange_coeff::<RistrettoPoint>(&self.party_id, 0..self.total_parties)
142            * self.shamir_share;
143        let ki = self.party_public_shares[self.party_id as usize];
144        let expected_ki = RistrettoPoint::generator() * vrf_ki;
145        if !bool::from(expected_ki.ct_eq(&ki)) {
146            return Err(VrfError::InvalidLocalKey);
147        }
148
149        let mut sum_k = RistrettoPoint::identity();
150        for share in &self.party_public_shares {
151            sum_k += *share;
152        }
153        if !bool::from(sum_k.ct_eq(&self.public_key)) {
154            return Err(VrfError::InvalidPublicShares);
155        }
156
157        self.h_con = hash_consistency(
158            &self.public_key,
159            &self.party_public_shares,
160            &self.message,
161            self.output_bits,
162        );
163        self.msg0_generated = true;
164
165        Ok(VrfMsg0 {
166            from_party: self.party_id,
167            session_id: self.session_id,
168            h_con: self.h_con,
169        })
170    }
171
172    /// Round 0 inbound: `messages` must include this party's round-0 message.
173    ///
174    /// When `quorum` is `None`, exactly `threshold` messages are required.
175    /// When `quorum` is `Some`, senders must match that party-id set exactly.
176    pub fn round0_in(
177        &mut self,
178        messages: Vec<VrfMsg0>,
179        quorum: Option<&[u8]>,
180    ) -> Result<VrfMsg1, VrfError> {
181        if !self.msg0_generated {
182            return Err(VrfError::InvalidState);
183        }
184
185        let mut messages = messages;
186        messages.sort_by_key(|m| m.from_party);
187
188        if let Some(quorum) = quorum {
189            if messages.len() != quorum.len() {
190                return Err(VrfError::InvalidMsgCount);
191            }
192            let mut expected = quorum.to_vec();
193            expected.sort_unstable();
194            let mut actual: Vec<u8> = messages.iter().map(|m| m.from_party).collect();
195            actual.sort_unstable();
196            if actual != expected {
197                return Err(VrfError::InvalidParticipantSet);
198            }
199        } else if messages.len() != self.threshold as usize {
200            return Err(VrfError::InvalidMsgCount);
201        }
202
203        let local_msg = messages
204            .iter()
205            .find(|msg| msg.from_party == self.party_id)
206            .ok_or(VrfError::InvalidParticipantSet)?;
207
208        if local_msg.session_id != self.session_id {
209            return Err(VrfError::InvalidParticipantSet);
210        }
211
212        let mut party_ids: Vec<u8> = messages.iter().map(|m| m.from_party).collect();
213        party_ids.sort_unstable();
214        party_ids.dedup();
215        if party_ids.len() != messages.len() {
216            return Err(VrfError::InvalidParticipantSet);
217        }
218
219        for msg in &messages {
220            if msg.from_party >= self.total_parties {
221                return Err(VrfError::InvalidMsgPartyId);
222            }
223            if msg.h_con != self.h_con {
224                return Err(VrfError::ConsistencyHashMismatch(msg.from_party));
225            }
226        }
227
228        let pairs: Vec<_> = messages
229            .iter()
230            .map(|m| (m.from_party, m.session_id))
231            .collect();
232        let extra = self.session_id_extra();
233        self.final_session_id = calculate_final_session_id_pairs(pairs, &extra);
234        self.pid_list = party_ids;
235
236        self.m_point = RistrettoPoint::hash_vrf_message(&[self.message.as_slice()])
237            .map_err(|_| VrfError::HashToCurve)?;
238
239        let coeff =
240            get_lagrange_coeff::<RistrettoPoint>(&self.party_id, self.pid_list.iter().copied());
241        let vrf_ki = coeff * self.shamir_share;
242        let party_ki = RistrettoPoint::generator() * vrf_ki;
243
244        let aux = (self.party_id as u32).to_be_bytes();
245        let mut proof_rng = ChaCha20Rng::from_seed(self.seed);
246        let z_i = self.m_point * vrf_ki;
247        let mut transcript = dh_tuple_transcript(&self.final_session_id, &aux);
248        let pi = DhTupleProof::prove(
249            DhTuplePoints {
250                g: &RistrettoPoint::generator(),
251                q: &self.m_point,
252                a: &party_ki,
253                b: &z_i,
254            },
255            &vrf_ki,
256            &mut transcript,
257            &mut proof_rng,
258        );
259
260        self.round0_complete = true;
261
262        Ok(VrfMsg1 {
263            from_party: self.party_id,
264            session_id: self.final_session_id,
265            z_i: z_i.compress().to_bytes().to_vec(),
266            pi,
267        })
268    }
269
270    /// Round 1 inbound: verify partial points and derive the VRF output.
271    pub fn round1_in(&self, messages: Vec<VrfMsg1>) -> Result<VrfOutput, VrfError> {
272        if !self.round0_complete {
273            return Err(VrfError::InvalidState);
274        }
275        let messages = validate_input_messages(messages, &self.pid_list)?;
276
277        let g = RistrettoPoint::generator();
278        let mut z = RistrettoPoint::identity();
279
280        for msg in &messages {
281            if msg.from_party >= self.total_parties {
282                return Err(VrfError::InvalidMsgPartyId);
283            }
284            if msg.session_id != self.final_session_id {
285                return Err(VrfError::InvalidDhProof(msg.from_party));
286            }
287            let z_j = decode_point(&msg.z_i).ok_or(VrfError::InvalidZ(msg.from_party))?;
288            if !RistrettoPoint::is_valid_partial_vrf_point(&z_j) {
289                return Err(VrfError::InvalidZ(msg.from_party));
290            }
291
292            let aux = (msg.from_party as u32).to_be_bytes();
293            let k_j = participant_public_share(
294                &self.party_public_shares[msg.from_party as usize],
295                msg.from_party,
296                self.total_parties,
297                self.pid_list.iter().copied(),
298            );
299            let mut transcript = dh_tuple_transcript(&self.final_session_id, &aux);
300            if !msg.pi.verify(
301                DhTuplePoints {
302                    g: &g,
303                    q: &self.m_point,
304                    a: &k_j,
305                    b: &z_j,
306                },
307                &mut transcript,
308            ) {
309                return Err(VrfError::InvalidDhProof(msg.from_party));
310            }
311
312            z += z_j;
313        }
314
315        let output = ro_hash_string(&[z.compress().as_bytes()], self.output_bits);
316
317        Ok(VrfOutput {
318            output,
319            session_id: self.final_session_id,
320            pid_list: self.pid_list.clone(),
321        })
322    }
323
324    fn session_id_extra(&self) -> Vec<&[u8]> {
325        match self.keygen_session_id {
326            Some(ref sid) => vec![self.message.as_slice(), sid.as_slice()],
327            None => vec![self.message.as_slice()],
328        }
329    }
330}
331
332fn decode_point(bytes: &[u8]) -> Option<RistrettoPoint> {
333    let mut encoding = <RistrettoPoint as GroupEncoding>::Repr::default();
334    if encoding.as_ref().len() != bytes.len() {
335        return None;
336    }
337    encoding.as_mut().copy_from_slice(bytes);
338    Option::from(RistrettoPoint::from_bytes(&encoding))
339}
340
341#[cfg(test)]
342pub(crate) mod test_support {
343    use super::*;
344    use crate::dkg::VrfKeyshare;
345
346    pub fn init_eval_states(shares: &[VrfKeyshare], message: &[u8]) -> Vec<Context> {
347        let mut rng = rand::thread_rng();
348        shares
349            .iter()
350            .map(|ks| {
351                Context::new(
352                    ks.party_id,
353                    ks.threshold,
354                    ks.total_parties,
355                    message.to_vec(),
356                    *ks.shamir_share(),
357                    *ks.public_key(),
358                    ks.party_public_shares().to_vec(),
359                    &mut rng,
360                )
361                .unwrap()
362            })
363            .collect()
364    }
365
366    pub fn vrf_eval_inner(mut parties: Vec<Context>, threshold: usize) -> Vec<VrfOutput> {
367        let msg0: Vec<VrfMsg0> = parties
368            .iter_mut()
369            .map(|party| party.round0_out().unwrap())
370            .collect();
371        let msg0_subset: Vec<_> = msg0.into_iter().take(threshold).collect();
372
373        let mut active: Vec<Context> = parties.into_iter().take(threshold).collect();
374        let msg1: Vec<VrfMsg1> = active
375            .iter_mut()
376            .map(|party| party.round0_in(msg0_subset.clone(), None).unwrap())
377            .collect();
378
379        active
380            .into_iter()
381            .map(|party| party.round1_in(msg1.clone()).unwrap())
382            .collect()
383    }
384}
385
386#[cfg(test)]
387mod tests {
388    use super::test_support::{init_eval_states, vrf_eval_inner};
389    use super::{Context, VrfError};
390    use crate::dkg::test_support::{init_states, vrf_dkg_inner};
391
392    fn assert_matching_outputs(outputs: &[super::VrfOutput]) {
393        let reference = &outputs[0].output;
394        for output in &outputs[1..] {
395            assert_eq!(&output.output, reference);
396            assert_eq!(output.session_id, outputs[0].session_id);
397            assert_eq!(output.pid_list, outputs[0].pid_list);
398        }
399    }
400
401    #[test]
402    fn vrf_eval_3_out_of_3() {
403        let shares = vrf_dkg_inner(init_states(3, 3));
404        let parties = init_eval_states(&shares, b"vrf-eval-3x3");
405        let outputs = vrf_eval_inner(parties, 3);
406        assert_eq!(outputs.len(), 3);
407        assert_matching_outputs(&outputs);
408    }
409
410    #[test]
411    fn vrf_eval_2_out_of_3() {
412        let shares = vrf_dkg_inner(init_states(3, 2));
413        let parties = init_eval_states(&shares, b"vrf-eval-2x3");
414        let outputs = vrf_eval_inner(parties, 2);
415        assert_eq!(outputs.len(), 2);
416        assert_matching_outputs(&outputs);
417    }
418
419    #[test]
420    fn vrf_eval_same_message_same_output() {
421        let shares = vrf_dkg_inner(init_states(3, 2));
422        let message = b"deterministic-input";
423
424        let outputs_a = vrf_eval_inner(init_eval_states(&shares, message), 2);
425        let outputs_b = vrf_eval_inner(init_eval_states(&shares, message), 2);
426
427        assert_eq!(outputs_a[0].output, outputs_b[0].output);
428    }
429
430    #[test]
431    fn new_rejects_malformed_keyshare_public_shares_len() {
432        let shares = vrf_dkg_inner(init_states(3, 2));
433        let ks = &shares[0];
434        let mut party_public_shares = ks.party_public_shares().to_vec();
435        party_public_shares.pop();
436
437        assert!(matches!(
438            Context::new(
439                ks.party_id,
440                ks.threshold,
441                ks.total_parties,
442                b"msg".to_vec(),
443                *ks.shamir_share(),
444                *ks.public_key(),
445                party_public_shares,
446                &mut rand::thread_rng(),
447            ),
448            Err(VrfError::InvalidKeyshare)
449        ));
450    }
451
452    #[test]
453    fn new_rejects_malformed_keyshare_party_id() {
454        let shares = vrf_dkg_inner(init_states(3, 2));
455        let ks = &shares[0];
456
457        assert!(matches!(
458            Context::new(
459                ks.total_parties,
460                ks.threshold,
461                ks.total_parties,
462                b"msg".to_vec(),
463                *ks.shamir_share(),
464                *ks.public_key(),
465                ks.party_public_shares().to_vec(),
466                &mut rand::thread_rng(),
467            ),
468            Err(VrfError::InvalidKeyshare)
469        ));
470    }
471}