vrf_wasm/groups/
ristretto255.rs

1// Copyright (c) 2022, Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Implementations of the [ristretto255 group](https://www.ietf.org/archive/id/draft-irtf-cfrg-ristretto255-decaf448-03.html) which is a group of
5//! prime order 2^{252} + 27742317777372353535851937790883648493 built over Curve25519.
6
7use crate::error::FastCryptoResult;
8use crate::groups::{
9    Doubling, FiatShamirChallenge, GroupElement, HashToGroupElement, MultiScalarMul, Scalar,
10};
11use crate::hash::Sha512;
12use crate::serde_helpers::ToFromByteArray;
13use crate::traits::AllowedRng;
14use crate::{
15    error::FastCryptoError, hash::HashFunction, serialize_deserialize_with_to_from_byte_array,
16};
17use curve25519_dalek_ng;
18use curve25519_dalek_ng::constants::{BASEPOINT_ORDER, RISTRETTO_BASEPOINT_POINT};
19use curve25519_dalek_ng::ristretto::CompressedRistretto as ExternalCompressedRistrettoPoint;
20use curve25519_dalek_ng::ristretto::RistrettoPoint as ExternalRistrettoPoint;
21use curve25519_dalek_ng::scalar::Scalar as ExternalRistrettoScalar;
22use curve25519_dalek_ng::traits::Identity;
23use derive_more::{Add, Div, From, Neg, Sub};
24use fastcrypto_derive::GroupOpsExtend;
25use serde::{de, Deserialize};
26use std::ops::{Add, Div, Mul};
27use zeroize::Zeroize;
28
/// Length in bytes of a compressed Ristretto point, as used by the point's `ToFromByteArray` impl.
const RISTRETTO_POINT_BYTE_LENGTH: usize = 32;
/// Length in bytes of a serialized Ristretto scalar, as used by the scalar's `ToFromByteArray` impl.
const RISTRETTO_SCALAR_BYTE_LENGTH: usize = 32;
31
/// Represents a point in the Ristretto group for Curve25519.
///
/// This is a newtype around `curve25519_dalek_ng::ristretto::RistrettoPoint`; addition,
/// subtraction and negation are derived and delegate to the wrapped point.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, From, Add, Sub, Neg, GroupOpsExtend)]
pub struct RistrettoPoint(ExternalRistrettoPoint);
35
36impl RistrettoPoint {
37    /// Construct a RistrettoPoint from the given data using an Ristretto-flavoured Elligator 2 map.
38    /// If the input bytes are uniformly distributed, the resulting point will be uniformly
39    /// distributed over the Ristretto group.
40    pub fn from_uniform_bytes(bytes: &[u8; 64]) -> Self {
41        RistrettoPoint::from(ExternalRistrettoPoint::from_uniform_bytes(bytes))
42    }
43
44    /// Construct a RistrettoPoint from the given data using a given hash function.
45    pub fn map_to_point<H: HashFunction<64>>(bytes: &[u8]) -> Self {
46        Self::from_uniform_bytes(&H::digest(bytes).digest)
47    }
48
49    /// Return this point in compressed form.
50    pub fn compress(&self) -> [u8; 32] {
51        self.0.compress().0
52    }
53
54    /// Return this point in compressed form.
55    pub fn decompress(bytes: &[u8; 32]) -> Result<Self, FastCryptoError> {
56        RistrettoPoint::try_from(bytes.as_slice())
57    }
58}
59
60impl Doubling for RistrettoPoint {
61    fn double(self) -> Self {
62        Self(self.0.add(self.0))
63    }
64}
65
66impl MultiScalarMul for RistrettoPoint {
67    fn multi_scalar_mul(scalars: &[Self::ScalarType], points: &[Self]) -> FastCryptoResult<Self> {
68        if scalars.len() != points.len() {
69            return Err(FastCryptoError::InvalidInput);
70        }
71
72        Ok(RistrettoPoint(
73            scalars.iter().zip(points.iter())
74                .map(|(s, p)| p.0 * s.0)
75                .fold(ExternalRistrettoPoint::identity(), |acc, point| acc + point),
76            // ExternalRistrettoPoint::vartime_multiscalar_mul(
77            //     scalars.iter().map(|s| s.0),
78            //     points.iter().map(|g| g.0),
79            // ),
80        ))
81    }
82}
83
84#[allow(clippy::suspicious_arithmetic_impl)]
85impl Div<RistrettoScalar> for RistrettoPoint {
86    type Output = Result<Self, FastCryptoError>;
87
88    fn div(self, rhs: RistrettoScalar) -> Self::Output {
89        let inv = rhs.inverse()?;
90        Ok(self * inv)
91    }
92}
93
94impl Mul<RistrettoScalar> for RistrettoPoint {
95    type Output = RistrettoPoint;
96
97    fn mul(self, rhs: RistrettoScalar) -> RistrettoPoint {
98        RistrettoPoint::from(self.0 * rhs.0)
99    }
100}
101
102impl GroupElement for RistrettoPoint {
103    type ScalarType = RistrettoScalar;
104
105    fn zero() -> RistrettoPoint {
106        RistrettoPoint::from(ExternalRistrettoPoint::identity())
107    }
108
109    fn generator() -> Self {
110        RistrettoPoint::from(RISTRETTO_BASEPOINT_POINT)
111    }
112}
113
114impl TryFrom<&[u8]> for RistrettoPoint {
115    type Error = FastCryptoError;
116
117    /// Decode a ristretto point in compressed binary form.
118    fn try_from(bytes: &[u8]) -> Result<Self, FastCryptoError> {
119        let point = ExternalCompressedRistrettoPoint::from_slice(bytes);
120        let decompressed_point = point.decompress().ok_or(FastCryptoError::InvalidInput)?;
121        Ok(RistrettoPoint::from(decompressed_point))
122    }
123}
124
125impl HashToGroupElement for RistrettoPoint {
126    fn hash_to_group_element(msg: &[u8]) -> Self {
127        RistrettoPoint::map_to_point::<Sha512>(msg)
128    }
129}
130
131impl ToFromByteArray<RISTRETTO_POINT_BYTE_LENGTH> for RistrettoPoint {
132    fn from_byte_array(bytes: &[u8; RISTRETTO_POINT_BYTE_LENGTH]) -> Result<Self, FastCryptoError> {
133        Self::try_from(bytes.as_slice())
134    }
135
136    fn to_byte_array(&self) -> [u8; RISTRETTO_POINT_BYTE_LENGTH] {
137        self.compress()
138    }
139}
140
// NOTE(review): presumably generates the serde `Serialize`/`Deserialize` impls for
// `RistrettoPoint` on top of its `ToFromByteArray` impl; confirm against the macro definition.
serialize_deserialize_with_to_from_byte_array!(RistrettoPoint);
142
/// Represents a scalar.
///
/// This is a newtype around `curve25519_dalek_ng::scalar::Scalar`. Multiplication and the
/// fallible division by another `RistrettoScalar` are implemented by hand further down.
#[derive(Clone, Copy, Debug, PartialEq, Eq, From, Add, Sub, Neg, Div, GroupOpsExtend, Zeroize)]
pub struct RistrettoScalar(ExternalRistrettoScalar);
146
147impl RistrettoScalar {
148    /// The order of the base point.
149    pub fn group_order() -> RistrettoScalar {
150        RistrettoScalar(BASEPOINT_ORDER)
151    }
152
153    /// Construct a [RistrettoScalar] by reducing a 64-byte little-endian integer modulo the group order.
154    pub fn from_bytes_mod_order_wide(bytes: &[u8; 64]) -> Self {
155        RistrettoScalar(ExternalRistrettoScalar::from_bytes_mod_order_wide(bytes))
156    }
157
158    /// Construct a [RistrettoScalar] by reducing a 32-byte little-endian integer modulo the group order.
159    pub fn from_bytes_mod_order(bytes: &[u8; 32]) -> Self {
160        RistrettoScalar(ExternalRistrettoScalar::from_bytes_mod_order(*bytes))
161    }
162}
163
164impl From<u128> for RistrettoScalar {
165    fn from(value: u128) -> RistrettoScalar {
166        RistrettoScalar(ExternalRistrettoScalar::from(value))
167    }
168}
169
170impl Mul<RistrettoScalar> for RistrettoScalar {
171    type Output = RistrettoScalar;
172
173    fn mul(self, rhs: RistrettoScalar) -> RistrettoScalar {
174        RistrettoScalar::from(self.0 * rhs.0)
175    }
176}
177
178#[allow(clippy::suspicious_arithmetic_impl)]
179impl Div<RistrettoScalar> for RistrettoScalar {
180    type Output = Result<RistrettoScalar, FastCryptoError>;
181
182    fn div(self, rhs: RistrettoScalar) -> Result<RistrettoScalar, FastCryptoError> {
183        let inv = rhs.inverse()?;
184        Ok(self * inv)
185    }
186}
187
188impl GroupElement for RistrettoScalar {
189    type ScalarType = Self;
190
191    fn zero() -> Self {
192        RistrettoScalar::from(ExternalRistrettoScalar::zero())
193    }
194    fn generator() -> Self {
195        RistrettoScalar::from(ExternalRistrettoScalar::one())
196    }
197}
198
199impl Scalar for RistrettoScalar {
200    fn rand<R: AllowedRng>(rng: &mut R) -> Self {
201        Self(ExternalRistrettoScalar::random(rng))
202    }
203
204    fn inverse(&self) -> FastCryptoResult<Self> {
205        if self.0 == ExternalRistrettoScalar::zero() {
206            return Err(FastCryptoError::InvalidInput);
207        }
208        Ok(RistrettoScalar::from(self.0.invert()))
209    }
210}
211
212impl HashToGroupElement for RistrettoScalar {
213    fn hash_to_group_element(bytes: &[u8]) -> Self {
214        Self::from_bytes_mod_order_wide(&Sha512::digest(bytes).digest)
215    }
216}
217
218impl FiatShamirChallenge for RistrettoScalar {
219    fn fiat_shamir_reduction_to_group_element(msg: &[u8]) -> Self {
220        Self::hash_to_group_element(msg)
221    }
222}
223
224impl ToFromByteArray<RISTRETTO_SCALAR_BYTE_LENGTH> for RistrettoScalar {
225    fn from_byte_array(
226        bytes: &[u8; RISTRETTO_SCALAR_BYTE_LENGTH],
227    ) -> Result<Self, FastCryptoError> {
228        Ok(RistrettoScalar(
229            ExternalRistrettoScalar::from_canonical_bytes(*bytes)
230                .ok_or(FastCryptoError::InvalidInput)?,
231        ))
232    }
233
234    fn to_byte_array(&self) -> [u8; RISTRETTO_SCALAR_BYTE_LENGTH] {
235        self.0.to_bytes()
236    }
237}
238
// NOTE(review): presumably generates the serde `Serialize`/`Deserialize` impls for
// `RistrettoScalar` on top of its `ToFromByteArray` impl; confirm against the macro definition.
serialize_deserialize_with_to_from_byte_array!(RistrettoScalar);