1use rand::Rng;
2use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
3use serde::Serialize;
4use vrf::openssl::{CipherSuite, ECVRF};
5use vrf::VRF;
6
7use crate::error::{Error, Result};
8use crate::types::Role;
9
/// Inputs that are bound into a VRF evaluation.
///
/// `VrfClient::prove` and `VrfClient::verify` serialize this struct with
/// bincode and use the bytes as the VRF input. The field order and field types
/// are therefore part of that encoding: reordering or retyping a field changes
/// the input bytes, and proofs made before the change will no longer verify.
#[derive(Debug, Serialize)]
pub struct VrfParams {
    /// Weight of the participant being evaluated.
    /// NOTE(review): presumably stake units; confirm with callers.
    pub weight: u64,
    /// Round number this evaluation belongs to.
    pub round: u64,
    /// Seed value mixed into the VRF input.
    /// NOTE(review): presumably the per-round randomness seed; confirm.
    pub seed: u64,
    /// Role being evaluated (type defined in `crate::types`).
    pub role: Role,
}
17
/// Output of a VRF evaluation: the proof and the hash derived from it.
///
/// Both fields are raw byte vectors produced by the underlying ECVRF
/// implementation. NOTE(review): `hash` is presumably what callers pass to
/// `sortition` as `vrf_hash`; confirm against call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VrfProof {
    /// Proof bytes produced by the ECVRF `prove` operation.
    pub proof: Vec<u8>,
    /// VRF output obtained from `proof` via `proof_to_hash`.
    pub hash: Vec<u8>,
}
22
23pub struct VrfClient {
24 inner: ECVRF,
25}
26
27impl VrfClient {
28 pub fn new() -> Result<Self> {
29 let inner = ECVRF::from_suite(CipherSuite::SECP256K1_SHA256_TAI)
30 .map_err(|e| Error::Vrf(format!("{:?}", e)))?;
31 Ok(Self { inner })
32 }
33
34 pub fn generate_keys(&mut self, seed: u64) -> Result<(Vec<u8>, Vec<u8>)> {
36 let mut rng = ChaCha8Rng::seed_from_u64(seed);
37 let secret_key: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
38 let public_key = self
39 .inner
40 .derive_public_key(&secret_key)
41 .map_err(|e| Error::Vrf(format!("{:?}", e)))?;
42 Ok((secret_key, public_key))
43 }
44
45 pub fn prove(&mut self, secret_key: &[u8], params: &VrfParams) -> Result<VrfProof> {
47 let data = bincode::serialize(params).map_err(|e| Error::Vrf(e.to_string()))?;
48 let proof = self
49 .inner
50 .prove(secret_key, &data)
51 .map_err(|e| Error::Vrf(format!("{:?}", e)))?;
52 let hash = self
53 .inner
54 .proof_to_hash(&proof)
55 .map_err(|e| Error::Vrf(format!("{:?}", e)))?;
56 Ok(VrfProof { proof, hash })
57 }
58
59 pub fn verify(
61 &mut self,
62 public_key: &[u8],
63 vrf_proof: &VrfProof,
64 params: &VrfParams,
65 ) -> Result<bool> {
66 let data = bincode::serialize(params).map_err(|e| Error::Vrf(e.to_string()))?;
67 let beta = self
68 .inner
69 .verify(public_key, &vrf_proof.proof, &data)
70 .map_err(|e| Error::Vrf(format!("{:?}", e)))?;
71 Ok(beta == vrf_proof.hash)
72 }
73}
74
75fn log_binom_pmf(k: u64, n: u64, p: f64) -> f64 {
78 if p <= 0.0 {
79 return if k == 0 { 0.0 } else { f64::NEG_INFINITY };
80 }
81 if p >= 1.0 {
82 return if k == n { 0.0 } else { f64::NEG_INFINITY };
83 }
84 if k == 0 {
85 return (n as f64) * (1.0 - p).ln();
86 }
87 if k == n {
88 return (n as f64) * p.ln();
89 }
90 let (lgn1, _) = libm::lgamma_r((n + 1) as f64);
91 let (lgk1, _) = libm::lgamma_r((k + 1) as f64);
92 let (lgnk1, _) = libm::lgamma_r((n - k + 1) as f64);
93 lgn1 - lgk1 - lgnk1 + (k as f64) * p.ln() + ((n - k) as f64) * (1.0 - p).ln()
94}
95
96pub fn sortition(vrf_hash: &[u8], weight: u64, expected: u64, total: u64) -> u64 {
99 if weight == 0 || total == 0 || expected == 0 {
100 return 0;
101 }
102
103 let p = expected as f64 / total as f64;
104
105 let mut bytes = [0u8; 8];
107 let len = vrf_hash.len().min(8);
108 bytes[..len].copy_from_slice(&vrf_hash[..len]);
109 let val = u64::from_le_bytes(bytes) as f64 / u64::MAX as f64;
110
111 let mut cumulative = 0.0;
113 for k in 0..=weight {
114 cumulative += log_binom_pmf(k, weight, p).exp();
115 if val < cumulative {
116 return k;
117 }
118 }
119 weight
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sortition_zero_weight_returns_zero() {
        assert_eq!(sortition(&[0; 32], 0, 10, 100), 0);
    }

    #[test]
    fn sortition_zero_expected_returns_zero() {
        assert_eq!(sortition(&[0; 32], 100, 0, 1000), 0);
    }

    #[test]
    fn sortition_zero_total_returns_zero() {
        assert_eq!(sortition(&[0; 32], 100, 10, 0), 0);
    }

    #[test]
    fn sortition_zero_hash_selects_nobody_when_p_below_one() {
        // An all-zero hash maps to u == 0, and P(X = 0) > 0 whenever p < 1.
        assert_eq!(sortition(&[0; 32], 100, 10, 1000), 0);
        // Hashes shorter than 8 bytes are zero-padded, so an empty hash behaves the same.
        assert_eq!(sortition(&[], 100, 10, 1000), 0);
    }

    #[test]
    fn sortition_selects_everyone_when_p_is_one() {
        // expected == total gives p == 1, so all probability mass is on k == weight.
        assert_eq!(sortition(&[0; 32], 10, 100, 100), 10);
    }

    #[test]
    fn sortition_never_exceeds_weight() {
        for hash in [[0u8; 32], [0x80; 32], [0xFF; 32]] {
            assert!(sortition(&hash, 50, 10, 100) <= 50);
        }
    }

    #[test]
    fn log_binom_pmf_matches_closed_form() {
        // C(4, 2) * 0.5^4 = 6 / 16 and C(4, 0) * 0.5^4 = 1 / 16.
        assert!((log_binom_pmf(2, 4, 0.5).exp() - 0.375).abs() < 1e-12);
        assert!((log_binom_pmf(0, 4, 0.5).exp() - 0.0625).abs() < 1e-12);
    }

    #[test]
    fn log_binom_pmf_sums_to_one() {
        let mass: f64 = (0..=20).map(|k| log_binom_pmf(k, 20, 0.3).exp()).sum();
        assert!((mass - 1.0).abs() < 1e-9);
    }
}