1use alloc::boxed::Box;
6use alloc::format;
7use alloc::string::String;
8use core::default::Default;
9use core::ops::{Add, Mul, Sub};
10
11use k256::{
12 elliptic_curve::{
13 bigint::U256, generic_array::GenericArray,
15 hash2curve::{ExpandMsgXmd, GroupDigest},
16 ops::Reduce,
17 sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
18 CurveArithmetic,
19 FieldBytesSize,
20 NonZeroScalar,
21 Scalar,
22 },
23 Secp256k1,
24};
25use rand_core::{CryptoRng, RngCore};
26use sha2::{digest::Digest, Sha256};
27use subtle::CtOption;
28use zeroize::{DefaultIsZeroes, Zeroize};
29
30#[cfg(any(feature = "serde", test))]
31use k256::elliptic_curve::group::ff::PrimeField;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Deserializer, Serialize, Serializer};
35
36#[cfg(feature = "serde")]
37use crate::serde_bytes::{
38 deserialize_with_encoding, serialize_with_encoding, Encoding, TryFromBytes,
39};
40
/// The elliptic curve this module is built on (`k256::Secp256k1`).
pub(crate) type CurveType = Secp256k1;
/// Length of a compressed SEC1 point encoding for [`CurveType`], taken from the
/// curve's `ModulusSize::CompressedPointSize`.
pub(crate) type CompressedPointSize =
    <FieldBytesSize<CurveType> as ModulusSize>::CompressedPointSize;

/// Backend scalar type of [`CurveType`] (may be zero).
type BackendScalar = Scalar<CurveType>;
/// Byte length used for serialized scalars (the curve's field-bytes size).
pub(crate) type ScalarSize = FieldBytesSize<CurveType>;
/// Backend scalar type of [`CurveType`] that cannot be zero.
pub(crate) type BackendNonZeroScalar = NonZeroScalar<CurveType>;
48
49#[derive(Clone, Copy, Debug, PartialEq, Default)]
58pub struct CurveScalar(BackendScalar);
59
60impl CurveScalar {
61 pub(crate) fn invert(&self) -> CtOption<Self> {
62 self.0.invert().map(Self)
63 }
64
65 pub(crate) fn one() -> Self {
66 Self(BackendScalar::ONE)
67 }
68
69 pub(crate) fn to_array(self) -> k256::FieldBytes {
70 self.0.to_bytes()
71 }
72
73 #[cfg(any(feature = "serde", test))]
74 pub(crate) fn try_from_bytes(bytes: &[u8]) -> Result<Self, String> {
75 let arr = GenericArray::<u8, ScalarSize>::from_exact_iter(bytes.iter().cloned())
76 .ok_or("Invalid length of a curve scalar")?;
77
78 let maybe_scalar: Option<BackendScalar> = BackendScalar::from_repr(arr).into();
80 maybe_scalar
81 .map(Self)
82 .ok_or_else(|| "Invalid curve scalar representation".into())
83 }
84}
85
86#[cfg(feature = "serde")]
87impl Serialize for CurveScalar {
88 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
89 where
90 S: Serializer,
91 {
92 serialize_with_encoding(&self.0.to_bytes(), serializer, Encoding::Hex)
93 }
94}
95
96#[cfg(feature = "serde")]
97impl<'de> Deserialize<'de> for CurveScalar {
98 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
99 where
100 D: Deserializer<'de>,
101 {
102 deserialize_with_encoding(deserializer, Encoding::Hex)
103 }
104}
105
106#[cfg(feature = "serde")]
107impl TryFromBytes for CurveScalar {
108 type Error = String;
109
110 fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
111 Self::try_from_bytes(bytes)
112 }
113}
114
115impl DefaultIsZeroes for CurveScalar {}
116
117#[derive(Clone, Zeroize)]
118pub struct NonZeroCurveScalar(BackendNonZeroScalar);
119
120impl NonZeroCurveScalar {
121 pub(crate) fn random(rng: &mut (impl CryptoRng + RngCore)) -> Self {
123 Self(BackendNonZeroScalar::random(rng))
124 }
125
126 pub(crate) fn from_backend_scalar(source: BackendNonZeroScalar) -> Self {
127 Self(source)
128 }
129
130 pub(crate) fn as_backend_scalar(&self) -> &BackendNonZeroScalar {
131 &self.0
132 }
133
134 pub(crate) fn invert(&self) -> Self {
135 let inv = self.0.invert().unwrap();
139 Self(BackendNonZeroScalar::new(inv).unwrap())
142 }
143
144 pub(crate) fn from_digest(d: impl Digest<OutputSize = ScalarSize>) -> Self {
145 Self(<BackendNonZeroScalar as Reduce<U256>>::reduce_bytes(
149 &d.finalize(),
150 ))
151 }
152}
153
154impl From<NonZeroCurveScalar> for CurveScalar {
155 fn from(source: NonZeroCurveScalar) -> Self {
156 CurveScalar(*source.0)
157 }
158}
159
160impl From<&NonZeroCurveScalar> for CurveScalar {
161 fn from(source: &NonZeroCurveScalar) -> Self {
162 CurveScalar(*source.0)
163 }
164}
165
166type BackendPoint = <CurveType as CurveArithmetic>::ProjectivePoint;
167
168#[derive(Clone, Copy, Debug, PartialEq)]
170pub struct CurvePoint(BackendPoint);
171
172impl CurvePoint {
173 pub(crate) fn from_backend_point(point: &BackendPoint) -> Self {
174 Self(*point)
175 }
176
177 pub(crate) fn as_backend_point(&self) -> &BackendPoint {
178 &self.0
179 }
180
181 pub(crate) fn generator() -> Self {
182 Self(BackendPoint::GENERATOR)
183 }
184
185 pub(crate) fn identity() -> Self {
186 Self(BackendPoint::IDENTITY)
187 }
188
189 pub fn coordinates(&self) -> Option<(k256::FieldBytes, k256::FieldBytes)> {
192 let point = self.0.to_encoded_point(false);
193 point.x().map(|x| (*x, *point.y().unwrap()))
197 }
198
199 pub(crate) fn try_from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
200 let ep = EncodedPoint::<CurveType>::from_bytes(bytes).map_err(|err| format!("{err}"))?;
201
202 let cp_opt: Option<BackendPoint> = BackendPoint::from_encoded_point(&ep).into();
204 cp_opt
205 .map(Self)
206 .ok_or_else(|| "Invalid curve point representation".into())
207 }
208
209 pub(crate) fn to_compressed_array(self) -> GenericArray<u8, CompressedPointSize> {
210 *GenericArray::<u8, CompressedPointSize>::from_slice(
211 self.0.to_affine().to_encoded_point(true).as_bytes(),
212 )
213 }
214
215 pub(crate) fn to_uncompressed_bytes(self) -> Box<[u8]> {
216 self.0.to_affine().to_encoded_point(false).as_bytes().into()
217 }
218
219 pub(crate) fn from_data(dst: &[u8], data: &[u8]) -> Option<Self> {
223 Some(Self(
224 CurveType::hash_from_bytes::<ExpandMsgXmd<Sha256>>(&[data], &[dst]).ok()?,
225 ))
226 }
227}
228
229impl Default for CurvePoint {
230 fn default() -> Self {
231 CurvePoint::identity()
232 }
233}
234
235#[cfg(feature = "serde")]
236impl Serialize for CurvePoint {
237 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
238 where
239 S: Serializer,
240 {
241 serialize_with_encoding(&self.to_compressed_array(), serializer, Encoding::Hex)
242 }
243}
244
245#[cfg(feature = "serde")]
246impl<'de> Deserialize<'de> for CurvePoint {
247 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
248 where
249 D: Deserializer<'de>,
250 {
251 deserialize_with_encoding(deserializer, Encoding::Hex)
252 }
253}
254
255#[cfg(feature = "serde")]
256impl TryFromBytes for CurvePoint {
257 type Error = String;
258
259 fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
260 Self::try_from_compressed_bytes(bytes)
261 }
262}
263
264impl DefaultIsZeroes for CurvePoint {}
265
266impl Add<&CurveScalar> for &CurveScalar {
267 type Output = CurveScalar;
268
269 fn add(self, other: &CurveScalar) -> CurveScalar {
270 CurveScalar(self.0.add(&(other.0)))
271 }
272}
273
274impl Add<&NonZeroCurveScalar> for &CurveScalar {
275 type Output = CurveScalar;
276
277 fn add(self, other: &NonZeroCurveScalar) -> CurveScalar {
278 CurveScalar(self.0.add(&(*other.0)))
279 }
280}
281
282impl Add<&NonZeroCurveScalar> for &NonZeroCurveScalar {
283 type Output = CurveScalar;
284
285 fn add(self, other: &NonZeroCurveScalar) -> CurveScalar {
286 CurveScalar(self.0.add(&(*other.0)))
287 }
288}
289
290impl Add<&CurvePoint> for &CurvePoint {
291 type Output = CurvePoint;
292
293 fn add(self, other: &CurvePoint) -> CurvePoint {
294 CurvePoint(self.0.add(&(other.0)))
295 }
296}
297
298impl Sub<&CurveScalar> for &CurveScalar {
299 type Output = CurveScalar;
300
301 fn sub(self, other: &CurveScalar) -> CurveScalar {
302 CurveScalar(self.0.sub(&(other.0)))
303 }
304}
305
306impl Sub<&NonZeroCurveScalar> for &NonZeroCurveScalar {
307 type Output = CurveScalar;
308
309 fn sub(self, other: &NonZeroCurveScalar) -> CurveScalar {
310 CurveScalar(self.0.sub(&(*other.0)))
311 }
312}
313
314impl Mul<&CurveScalar> for &CurvePoint {
315 type Output = CurvePoint;
316
317 fn mul(self, other: &CurveScalar) -> CurvePoint {
318 CurvePoint(self.0.mul(&(other.0)))
319 }
320}
321
322impl Mul<&NonZeroCurveScalar> for &CurvePoint {
323 type Output = CurvePoint;
324
325 fn mul(self, other: &NonZeroCurveScalar) -> CurvePoint {
326 CurvePoint(self.0.mul(&(*other.0)))
327 }
328}
329
330impl Mul<&CurveScalar> for &CurveScalar {
331 type Output = CurveScalar;
332
333 fn mul(self, other: &CurveScalar) -> CurveScalar {
334 CurveScalar(self.0.mul(&(other.0)))
335 }
336}
337
338impl Mul<&NonZeroCurveScalar> for &CurveScalar {
339 type Output = CurveScalar;
340
341 fn mul(self, other: &NonZeroCurveScalar) -> CurveScalar {
342 CurveScalar(self.0.mul(&(*other.0)))
343 }
344}
345
346impl Mul<&NonZeroCurveScalar> for &NonZeroCurveScalar {
347 type Output = NonZeroCurveScalar;
348
349 fn mul(self, other: &NonZeroCurveScalar) -> NonZeroCurveScalar {
350 NonZeroCurveScalar(self.0.mul(other.0))
351 }
352}