umbral_pre/
curve.rs

1//! This module is an adapter to the ECC backend.
2//! `elliptic_curves` has a somewhat unstable API,
3//! and we isolate all the related logic here.
4
5use alloc::boxed::Box;
6use alloc::format;
7use alloc::string::String;
8use core::default::Default;
9use core::ops::{Add, Mul, Sub};
10
11use k256::{
12    elliptic_curve::{
13        bigint::U256, // Note that this type is different from typenum::U256
14        generic_array::GenericArray,
15        hash2curve::{ExpandMsgXmd, GroupDigest},
16        ops::Reduce,
17        sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
18        CurveArithmetic,
19        FieldBytesSize,
20        NonZeroScalar,
21        Scalar,
22    },
23    Secp256k1,
24};
25use rand_core::{CryptoRng, RngCore};
26use sha2::{digest::Digest, Sha256};
27use subtle::CtOption;
28use zeroize::{DefaultIsZeroes, Zeroize};
29
30#[cfg(any(feature = "serde", test))]
31use k256::elliptic_curve::group::ff::PrimeField;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Deserializer, Serialize, Serializer};
35
36#[cfg(feature = "serde")]
37use crate::serde_bytes::{
38    deserialize_with_encoding, serialize_with_encoding, Encoding, TryFromBytes,
39};
40
/// The elliptic curve that the scalar and point types in this module are built on.
pub(crate) type CurveType = Secp256k1;
/// Type-level length of a compressed point encoding for [`CurveType`],
/// taken from the backend's `ModulusSize` implementation.
pub(crate) type CompressedPointSize =
    <FieldBytesSize<CurveType> as ModulusSize>::CompressedPointSize;
44
/// The backend's scalar type for [`CurveType`].
type BackendScalar = Scalar<CurveType>;
/// Type-level length of the byte representation of a scalar of [`CurveType`].
pub(crate) type ScalarSize = FieldBytesSize<CurveType>;
/// The backend's non-zero scalar type for [`CurveType`].
pub(crate) type BackendNonZeroScalar = NonZeroScalar<CurveType>;
48
49// We have to define newtypes for scalar and point here because the compiler
50// is not currently smart enough to resolve `BackendScalar` and `BackendPoint`
51// as specific types, so we cannot implement local traits for them.
52//
53// They also have to be public because Rust isn't smart enough to understand that
54//     type PointSize = <Point as RepresentableAsArray>::Size;
55// isn't leaking the `Point` (probably because type aliases are just inlined).
56
57#[derive(Clone, Copy, Debug, PartialEq, Default)]
58pub struct CurveScalar(BackendScalar);
59
60impl CurveScalar {
61    pub(crate) fn invert(&self) -> CtOption<Self> {
62        self.0.invert().map(Self)
63    }
64
65    pub(crate) fn one() -> Self {
66        Self(BackendScalar::ONE)
67    }
68
69    pub(crate) fn to_array(self) -> k256::FieldBytes {
70        self.0.to_bytes()
71    }
72
73    #[cfg(any(feature = "serde", test))]
74    pub(crate) fn try_from_bytes(bytes: &[u8]) -> Result<Self, String> {
75        let arr = GenericArray::<u8, ScalarSize>::from_exact_iter(bytes.iter().cloned())
76            .ok_or("Invalid length of a curve scalar")?;
77
78        // unwrap CtOption into Option
79        let maybe_scalar: Option<BackendScalar> = BackendScalar::from_repr(arr).into();
80        maybe_scalar
81            .map(Self)
82            .ok_or_else(|| "Invalid curve scalar representation".into())
83    }
84}
85
#[cfg(feature = "serde")]
impl Serialize for CurveScalar {
    /// Serializes the scalar's byte representation using the hex encoding.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let repr = self.0.to_bytes();
        serialize_with_encoding(&repr, serializer, Encoding::Hex)
    }
}
95
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for CurveScalar {
    /// Deserializes a scalar via the crate's hex-encoding helper.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        deserialize_with_encoding(deserializer, Encoding::Hex)
    }
}
105
#[cfg(feature = "serde")]
impl TryFromBytes for CurveScalar {
    type Error = String;

    /// Forwards to the inherent [`CurveScalar::try_from_bytes`].
    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
        // Inherent associated functions take priority over trait ones,
        // so this does not recurse.
        CurveScalar::try_from_bytes(bytes)
    }
}
114
// Opts into `zeroize`'s `DefaultIsZeroes` marker trait,
// declaring that the `Default` value of `CurveScalar` is its zeroized state.
impl DefaultIsZeroes for CurveScalar {}
116
117#[derive(Clone, Zeroize)]
118pub struct NonZeroCurveScalar(BackendNonZeroScalar);
119
120impl NonZeroCurveScalar {
121    /// Generates a random non-zero scalar (in nearly constant-time).
122    pub(crate) fn random(rng: &mut (impl CryptoRng + RngCore)) -> Self {
123        Self(BackendNonZeroScalar::random(rng))
124    }
125
126    pub(crate) fn from_backend_scalar(source: BackendNonZeroScalar) -> Self {
127        Self(source)
128    }
129
130    pub(crate) fn as_backend_scalar(&self) -> &BackendNonZeroScalar {
131        &self.0
132    }
133
134    pub(crate) fn invert(&self) -> Self {
135        // At the moment there is no infallible invert() for non-zero scalars
136        // (see https://github.com/RustCrypto/elliptic-curves/issues/499).
137        // But we know it will never fail.
138        let inv = self.0.invert().unwrap();
139        // We know that the inversion of a nonzero scalar is nonzero,
140        // so it is safe to unwrap again.
141        Self(BackendNonZeroScalar::new(inv).unwrap())
142    }
143
144    pub(crate) fn from_digest(d: impl Digest<OutputSize = ScalarSize>) -> Self {
145        // There's currently no way to make the required digest output size
146        // depend on the target scalar size, so we are hardcoding it to 256 bit
147        // (that is, equal to the scalar size).
148        Self(<BackendNonZeroScalar as Reduce<U256>>::reduce_bytes(
149            &d.finalize(),
150        ))
151    }
152}
153
154impl From<NonZeroCurveScalar> for CurveScalar {
155    fn from(source: NonZeroCurveScalar) -> Self {
156        CurveScalar(*source.0)
157    }
158}
159
160impl From<&NonZeroCurveScalar> for CurveScalar {
161    fn from(source: &NonZeroCurveScalar) -> Self {
162        CurveScalar(*source.0)
163    }
164}
165
/// The backend's projective point type for [`CurveType`].
type BackendPoint = <CurveType as CurveArithmetic>::ProjectivePoint;
167
168/// A point on the elliptic curve.
169#[derive(Clone, Copy, Debug, PartialEq)]
170pub struct CurvePoint(BackendPoint);
171
172impl CurvePoint {
173    pub(crate) fn from_backend_point(point: &BackendPoint) -> Self {
174        Self(*point)
175    }
176
177    pub(crate) fn as_backend_point(&self) -> &BackendPoint {
178        &self.0
179    }
180
181    pub(crate) fn generator() -> Self {
182        Self(BackendPoint::GENERATOR)
183    }
184
185    pub(crate) fn identity() -> Self {
186        Self(BackendPoint::IDENTITY)
187    }
188
189    /// Returns `x` and `y` coordinates serialized as big-endian bytes,
190    /// or `None` if it is the infinity point.
191    pub fn coordinates(&self) -> Option<(k256::FieldBytes, k256::FieldBytes)> {
192        let point = self.0.to_encoded_point(false);
193        // x() may be None if it is the infinity point.
194        // If x() is not None, y() is not None either because we requested
195        // an uncompressed point in the line above; can safely unwrap.
196        point.x().map(|x| (*x, *point.y().unwrap()))
197    }
198
199    pub(crate) fn try_from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
200        let ep = EncodedPoint::<CurveType>::from_bytes(bytes).map_err(|err| format!("{err}"))?;
201
202        // Unwrap CtOption into Option
203        let cp_opt: Option<BackendPoint> = BackendPoint::from_encoded_point(&ep).into();
204        cp_opt
205            .map(Self)
206            .ok_or_else(|| "Invalid curve point representation".into())
207    }
208
209    pub(crate) fn to_compressed_array(self) -> GenericArray<u8, CompressedPointSize> {
210        *GenericArray::<u8, CompressedPointSize>::from_slice(
211            self.0.to_affine().to_encoded_point(true).as_bytes(),
212        )
213    }
214
215    pub(crate) fn to_uncompressed_bytes(self) -> Box<[u8]> {
216        self.0.to_affine().to_encoded_point(false).as_bytes().into()
217    }
218
219    /// Hashes arbitrary data with the given domain separation tag
220    /// into a valid EC point of the specified curve, using the algorithm described in the
221    /// [IETF hash-to-curve standard](https://datatracker.ietf.org/doc/draft-irtf-cfrg-hash-to-curve/)
222    pub(crate) fn from_data(dst: &[u8], data: &[u8]) -> Option<Self> {
223        Some(Self(
224            CurveType::hash_from_bytes::<ExpandMsgXmd<Sha256>>(&[data], &[dst]).ok()?,
225        ))
226    }
227}
228
229impl Default for CurvePoint {
230    fn default() -> Self {
231        CurvePoint::identity()
232    }
233}
234
#[cfg(feature = "serde")]
impl Serialize for CurvePoint {
    /// Serializes the compressed encoding of the point using the hex encoding.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let compressed = self.to_compressed_array();
        serialize_with_encoding(&compressed, serializer, Encoding::Hex)
    }
}
244
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for CurvePoint {
    /// Deserializes a point via the crate's hex-encoding helper.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        deserialize_with_encoding(deserializer, Encoding::Hex)
    }
}
254
#[cfg(feature = "serde")]
impl TryFromBytes for CurvePoint {
    type Error = String;

    /// Forwards to [`CurvePoint::try_from_compressed_bytes`].
    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
        CurvePoint::try_from_compressed_bytes(bytes)
    }
}
263
// Opts into `zeroize`'s `DefaultIsZeroes` marker trait,
// declaring that the `Default` value of `CurvePoint` is its zeroized state.
impl DefaultIsZeroes for CurvePoint {}
265
266impl Add<&CurveScalar> for &CurveScalar {
267    type Output = CurveScalar;
268
269    fn add(self, other: &CurveScalar) -> CurveScalar {
270        CurveScalar(self.0.add(&(other.0)))
271    }
272}
273
274impl Add<&NonZeroCurveScalar> for &CurveScalar {
275    type Output = CurveScalar;
276
277    fn add(self, other: &NonZeroCurveScalar) -> CurveScalar {
278        CurveScalar(self.0.add(&(*other.0)))
279    }
280}
281
282impl Add<&NonZeroCurveScalar> for &NonZeroCurveScalar {
283    type Output = CurveScalar;
284
285    fn add(self, other: &NonZeroCurveScalar) -> CurveScalar {
286        CurveScalar(self.0.add(&(*other.0)))
287    }
288}
289
290impl Add<&CurvePoint> for &CurvePoint {
291    type Output = CurvePoint;
292
293    fn add(self, other: &CurvePoint) -> CurvePoint {
294        CurvePoint(self.0.add(&(other.0)))
295    }
296}
297
298impl Sub<&CurveScalar> for &CurveScalar {
299    type Output = CurveScalar;
300
301    fn sub(self, other: &CurveScalar) -> CurveScalar {
302        CurveScalar(self.0.sub(&(other.0)))
303    }
304}
305
306impl Sub<&NonZeroCurveScalar> for &NonZeroCurveScalar {
307    type Output = CurveScalar;
308
309    fn sub(self, other: &NonZeroCurveScalar) -> CurveScalar {
310        CurveScalar(self.0.sub(&(*other.0)))
311    }
312}
313
314impl Mul<&CurveScalar> for &CurvePoint {
315    type Output = CurvePoint;
316
317    fn mul(self, other: &CurveScalar) -> CurvePoint {
318        CurvePoint(self.0.mul(&(other.0)))
319    }
320}
321
322impl Mul<&NonZeroCurveScalar> for &CurvePoint {
323    type Output = CurvePoint;
324
325    fn mul(self, other: &NonZeroCurveScalar) -> CurvePoint {
326        CurvePoint(self.0.mul(&(*other.0)))
327    }
328}
329
330impl Mul<&CurveScalar> for &CurveScalar {
331    type Output = CurveScalar;
332
333    fn mul(self, other: &CurveScalar) -> CurveScalar {
334        CurveScalar(self.0.mul(&(other.0)))
335    }
336}
337
338impl Mul<&NonZeroCurveScalar> for &CurveScalar {
339    type Output = CurveScalar;
340
341    fn mul(self, other: &NonZeroCurveScalar) -> CurveScalar {
342        CurveScalar(self.0.mul(&(*other.0)))
343    }
344}
345
346impl Mul<&NonZeroCurveScalar> for &NonZeroCurveScalar {
347    type Output = NonZeroCurveScalar;
348
349    fn mul(self, other: &NonZeroCurveScalar) -> NonZeroCurveScalar {
350        NonZeroCurveScalar(self.0.mul(other.0))
351    }
352}