1use core::{
4    borrow::{Borrow, BorrowMut},
5    fmt,
6    ops::Shl,
7};
8
9use generic_array::{ArrayLength, GenericArray, IntoArrayLength};
10use subtle::{Choice, ConstantTimeEq};
11use typenum::{Const, Double, Unsigned, B1, U133, U32, U33, U48, U49, U65, U66, U67, U97};
12
13use crate::{
14    hex::ToHex,
15    import::{Import, ImportError, InvalidSizeError},
16    zeroize::{zeroize_flat_type, Zeroize, ZeroizeOnDrop},
17};
18
19pub trait Curve: Copy + Clone + Eq + PartialEq {
24    type ScalarSize: ArrayLength + Unsigned + Copy + Clone + Eq + PartialEq;
26
27    type CompressedSize: ArrayLength + Unsigned + Copy + Clone + Eq + PartialEq;
29
30    type UncompressedSize: ArrayLength + Unsigned + Copy + Clone + Eq + PartialEq;
32}
33
34macro_rules! curve_impl {
35    (
36        $name:ident,
37        $doc:expr,
38        $bytes:ty,
39        $comp_len:ty,
40        $uncomp_len:ty $(,)?
41    ) => {
42        #[doc = concat!($doc, ".")]
43        #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
44        pub struct $name;
45
46        impl Curve for $name {
47            type ScalarSize = $bytes;
48            type CompressedSize = $comp_len;
49            type UncompressedSize = $uncomp_len;
50        }
51    };
52}
53curve_impl!(Secp256r1, "NIST Curve P-256", U32, U33, U65);
54curve_impl!(Secp384r1, "NIST Curve P-384", U48, U49, U97);
55curve_impl!(Secp521r1, "NIST Curve P-521", U66, U67, U133);
56curve_impl!(Curve25519, "Curve25519", U32, U32, U32);
57
58macro_rules! pk_impl {
59    ($name:ident, $size:ident) => {
60        #[doc = concat!(stringify!($name), " elliptic curve point per [SEC] section 2.3.3.\n\n")]
61        #[doc = "This is equivalent to X9.62 encoding.\n\n"]
62        #[doc = "[SEC]: https://www.secg.org/sec1-v2.pdf"]
63        #[derive(Clone, Default, Eq, PartialEq)]
64        pub struct $name<C: Curve>(pub GenericArray<u8, C::$size>);
65
66        impl<C: Curve> $name<C> {
67            pub fn as_ptr(&self) -> *const u8 {
69                self.0.as_ptr()
70            }
71
72            pub fn as_mut_ptr(&mut self) -> *mut u8 {
74                self.0.as_mut_ptr()
75            }
76
77            #[allow(clippy::len_without_is_empty)]
79            pub const fn len(&self) -> usize {
80                C::$size::USIZE
81            }
82        }
83
84        impl<C: Curve> Copy for $name<C> where <C::$size as ArrayLength>::ArrayType<u8>: Copy {}
85
86        impl<C: Curve> AsRef<[u8]> for $name<C> {
87            #[inline]
88            fn as_ref(&self) -> &[u8] {
89                self.0.as_ref()
90            }
91        }
92
93        impl<C: Curve> AsMut<[u8]> for $name<C> {
94            #[inline]
95            fn as_mut(&mut self) -> &mut [u8] {
96                self.0.as_mut()
97            }
98        }
99
100        impl<C: Curve> Borrow<[u8]> for $name<C> {
101            #[inline]
102            fn borrow(&self) -> &[u8] {
103                self.0.as_ref()
104            }
105        }
106
107        impl<C: Curve> BorrowMut<[u8]> for $name<C> {
108            #[inline]
109            fn borrow_mut(&mut self) -> &mut [u8] {
110                self.0.as_mut()
111            }
112        }
113
114        impl<C: Curve, const N: usize> From<$name<C>> for [u8; N]
115        where
116            [u8; N]: From<GenericArray<u8, C::$size>>,
117        {
118            fn from(v: $name<C>) -> Self {
119                v.0.into()
120            }
121        }
122
123        impl<C: Curve, const N: usize> From<[u8; N]> for $name<C>
124        where
125            GenericArray<u8, C::$size>: From<[u8; N]>,
126        {
127            fn from(data: [u8; N]) -> Self {
128                Self(data.into())
129            }
130        }
131
132        impl<C: Curve> TryFrom<&[u8]> for $name<C> {
133            type Error = InvalidSizeError;
134
135            fn try_from(data: &[u8]) -> Result<Self, Self::Error> {
136                let v: &GenericArray<u8, _> = data.try_into().map_err(|_| InvalidSizeError {
137                    got: data.len(),
138                    want: C::$size::USIZE..C::$size::USIZE,
139                })?;
140                Ok(Self(v.clone()))
141            }
142        }
143
144        impl<C: Curve, const N: usize> Import<[u8; N]> for $name<C>
145        where
146            GenericArray<u8, C::$size>: From<[u8; N]>,
147        {
148            fn import(data: [u8; N]) -> Result<Self, ImportError> {
149                Ok(Self::from(data))
150            }
151        }
152
153        impl<C: Curve> Import<&[u8]> for $name<C> {
154            fn import(data: &[u8]) -> Result<Self, ImportError> {
155                Ok(Self(Import::<_>::import(data)?))
156            }
157        }
158
159        impl<C: Curve> Zeroize for $name<C> {
160            fn zeroize(&mut self) {
161                unsafe {
168                    zeroize_flat_type(&mut self.0);
169                }
170            }
171        }
172
173        impl<C: Curve> fmt::Debug for $name<C>
174        where
175            <C as Curve>::$size: ArrayLength + Shl<B1>,
176            Double<C::$size>: ArrayLength,
177        {
178            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179                f.debug_tuple(stringify!($name))
180                    .field(&self.to_hex())
181                    .finish()
182            }
183        }
184    };
185}
186pk_impl!(Compressed, CompressedSize);
187pk_impl!(Uncompressed, UncompressedSize);
188
189#[derive(Default)]
191pub struct Scalar<C: Curve>(pub GenericArray<u8, C::ScalarSize>);
192
193impl<C: Curve> Scalar<C> {
194    pub fn as_ptr(&self) -> *const u8 {
196        self.0.as_ptr()
197    }
198
199    pub fn as_mut_ptr(&mut self) -> *mut u8 {
201        self.0.as_mut_ptr()
202    }
203
204    #[allow(clippy::len_without_is_empty)]
206    pub const fn len(&self) -> usize {
207        C::ScalarSize::USIZE
208    }
209}
210
211impl<C: Curve> Clone for Scalar<C>
212where
213    <C::ScalarSize as ArrayLength>::ArrayType<u8>: Clone,
214{
215    fn clone(&self) -> Self {
216        Self(self.0.clone())
217    }
218}
219
220impl<C: Curve> ConstantTimeEq for Scalar<C> {
221    #[inline]
222    fn ct_eq(&self, other: &Self) -> Choice {
223        self.as_ref().ct_eq(other.as_ref())
224    }
225}
226
227impl<C: Curve> AsRef<[u8]> for Scalar<C> {
228    #[inline]
229    fn as_ref(&self) -> &[u8] {
230        self.0.as_ref()
231    }
232}
233
234impl<C: Curve> AsMut<[u8]> for Scalar<C> {
235    #[inline]
236    fn as_mut(&mut self) -> &mut [u8] {
237        self.0.as_mut()
238    }
239}
240
241impl<C: Curve> Borrow<[u8]> for Scalar<C> {
242    #[inline]
243    fn borrow(&self) -> &[u8] {
244        self.0.as_ref()
245    }
246}
247
248impl<C: Curve> BorrowMut<[u8]> for Scalar<C> {
249    #[inline]
250    fn borrow_mut(&mut self) -> &mut [u8] {
251        self.0.as_mut()
252    }
253}
254
255impl<C: Curve, const N: usize> From<Scalar<C>> for [u8; N]
256where
257    [u8; N]: From<GenericArray<u8, C::ScalarSize>>,
258{
259    fn from(v: Scalar<C>) -> Self {
260        v.0.clone().into()
261    }
262}
263
264impl<C: Curve, const N: usize> From<[u8; N]> for Scalar<C>
265where
266    Const<N>: IntoArrayLength,
267    GenericArray<u8, C::ScalarSize>: From<[u8; N]>,
268{
269    fn from(v: [u8; N]) -> Self {
270        Self(v.into())
271    }
272}
273
274impl<C: Curve> TryFrom<&[u8]> for Scalar<C> {
275    type Error = InvalidSizeError;
276
277    fn try_from(data: &[u8]) -> Result<Self, Self::Error> {
278        data.try_into()
279    }
280}
281
282impl<C: Curve, const N: usize> Import<[u8; N]> for Scalar<C>
283where
284    C::ScalarSize: ArrayLength,
285    Const<N>: IntoArrayLength,
286    GenericArray<u8, C::ScalarSize>: From<[u8; N]>,
287{
288    fn import(data: [u8; N]) -> Result<Self, ImportError> {
289        Ok(Self::from(data))
290    }
291}
292
293impl<C: Curve> Import<&[u8]> for Scalar<C> {
294    fn import(data: &[u8]) -> Result<Self, ImportError> {
295        let v: &GenericArray<u8, _> = data.try_into().map_err(|_| InvalidSizeError {
296            got: data.len(),
297            want: C::ScalarSize::USIZE..C::ScalarSize::USIZE,
298        })?;
299        Ok(Self(v.clone()))
300    }
301}
302
303impl<C: Curve> fmt::Debug for Scalar<C> {
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        f.debug_tuple("Scalar").finish_non_exhaustive()
306    }
307}
308
309impl<C: Curve> ZeroizeOnDrop for Scalar<C> {}
310impl<C: Curve> Drop for Scalar<C> {
311    fn drop(&mut self) {
312        unsafe {
319            zeroize_flat_type(&mut self.0);
320        }
321    }
322}