1pub mod curve;
2
3use md5::{Digest, Md5};
4use num_bigint::{BigInt, Sign};
5use num_integer::Integer;
6use num_traits::{FromBytes, One, Signed, Zero};
7use rand::RngCore;
8use std::fmt;
9use std::io::Write;
10use std::ops::Neg;
11
/// Something that can derive a shared secret from a peer's encoded public key.
pub trait KeyExchange {
    /// Failure type reported by [`KeyExchange::key_exchange`].
    type Error;

    /// Combines this party's secret with the peer's encoded `public` key.
    ///
    /// `hash` selects whether the raw shared value or a digest of it is returned
    /// (the `Ecdh` implementation in this module uses MD5).
    fn key_exchange<T>(
        &self,
        public: T,
        hash: bool,
    ) -> std::result::Result<Vec<u8>, Self::Error>
    where
        T: AsRef<[u8]>;
}
21
/// Errors produced while generating, encoding, decoding, or combining ECDH keys.
#[derive(Debug)]
pub enum EcdhError {
    /// A peer public key could not be decoded (bad length or prefix byte).
    InvalidPublicKey,
    /// A packed secret key is malformed.
    InvalidSecretKey,
    /// A point failed the curve-equation check.
    PointNotOnCurve,
    /// A modular inverse was requested for a value that has none.
    InverseDoesNotExist,
    /// An underlying I/O error, e.g. while writing an encoded key into its buffer.
    IOError(std::io::Error),
}

impl fmt::Display for EcdhError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            EcdhError::InvalidPublicKey => write!(f, "Invalid public key."),
            EcdhError::InvalidSecretKey => write!(f, "Invalid secret key."),
            EcdhError::PointNotOnCurve => write!(f, "Point is not on the curve."),
            EcdhError::InverseDoesNotExist => write!(f, "Modular inverse does not exist."),
            EcdhError::IOError(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for EcdhError {
    /// Exposes the wrapped `std::io::Error` so error-chain reporters can show the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            EcdhError::IOError(e) => Some(e),
            EcdhError::InvalidPublicKey
            | EcdhError::InvalidSecretKey
            | EcdhError::PointNotOnCurve
            | EcdhError::InverseDoesNotExist => None,
        }
    }
}

impl From<std::io::Error> for EcdhError {
    fn from(err: std::io::Error) -> Self {
        EcdhError::IOError(err)
    }
}

/// Result type used throughout this module.
pub type Result<T> = std::result::Result<T, EcdhError>;
52
53#[derive(Debug, Clone)]
54pub struct Ecdh {
55 curve: EllipticCurve,
56 secret: BigInt,
57 public: EllipticPoint,
58}
59
60impl Ecdh {
61 pub fn new(curve: EllipticCurve) -> Result<Self> {
62 let mut ecdh = Ecdh {
63 curve,
64 secret: BigInt::default(),
65 public: EllipticPoint::default(),
66 };
67 ecdh.secret = ecdh.create_secret()?;
68 ecdh.public = ecdh.create_public()?;
69 Ok(ecdh)
70 }
71
72 pub fn new_with_secret<T: AsRef<[u8]>>(curve: EllipticCurve, secret: T) -> Result<Self> {
73 let mut ecdh = Ecdh {
74 curve,
75 secret: BigInt::default(),
76 public: EllipticPoint::default(),
77 };
78 ecdh.secret = Self::unpack_secret(secret.as_ref())?;
79 ecdh.public = ecdh.create_public()?;
80 Ok(ecdh)
81 }
82
83 pub fn pack_public(&self, compress: bool) -> Result<Vec<u8>> {
84 if compress {
85 let mut result = vec![0u8; self.curve.size + 1];
86 result[0] = if self.public.y.is_even() ^ self.public.y.is_negative() {
87 0x02
88 } else {
89 0x03
90 };
91 (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
92 Ok(result)
93 } else {
94 let mut result = vec![0u8; self.curve.size * 2 + 1];
95 result[0] = 0x04;
96 (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
97 (&mut result[self.curve.size + 1..])
98 .write_all(self.public.y.to_bytes_be().1.as_slice())?;
99 Ok(result)
100 }
101 }
102
103 pub fn pack_secret(&self) -> Result<Vec<u8>> {
104 let raw_length = self.secret.to_bytes_be().1.len();
105 let mut result = vec![0u8; raw_length + 4];
106 (&mut result[4..]).write_all(self.secret.to_bytes_be().1.as_slice())?;
107 result[3] = raw_length as u8;
108 Ok(result)
109 }
110
111 fn pack_shared(&self, ec_shared: &EllipticPoint, hash: bool) -> Vec<u8> {
112 let x = ec_shared.x.to_bytes_be().1;
113 if hash {
114 Md5::digest(x.as_slice()[..self.curve.pack_size].as_ref()).to_vec()
115 } else {
116 x
117 }
118 }
119
120 fn unpack_public<T: AsRef<[u8]>>(&self, public_key: T) -> Result<EllipticPoint> {
121 let length = public_key.as_ref().len();
122 if length != self.curve.size * 2 + 1 && length != self.curve.size + 1 {
123 return Err(EcdhError::InvalidPublicKey);
124 }
125 if public_key.as_ref()[0] == 0x04 {
126 Ok(EllipticPoint::new(
127 BigInt::from_bytes_be(
128 Sign::Plus,
129 public_key.as_ref()[1..self.curve.size + 1].as_ref(),
130 ),
131 BigInt::from_bytes_be(
132 Sign::Plus,
133 public_key.as_ref()[self.curve.size + 1..].as_ref(),
134 ),
135 ))
136 } else {
137 let px = BigInt::from_bytes_be(Sign::Plus, public_key.as_ref()[1..].as_ref());
139 let x3 = px.modpow(&BigInt::from(3), &self.curve.p);
140 let ax = (&px * &self.curve.a) % &self.curve.p;
141 let right = (&x3 + &ax + &self.curve.b) % &self.curve.p;
142
143 let tmp = (&self.curve.p + 1) >> 2;
144 let mut py = right.modpow(&tmp, &self.curve.p);
145 if !(py.is_even() && public_key.as_ref()[0] == 0x02
146 || !py.is_even() && public_key.as_ref()[0] == 0x03)
147 {
148 py = &self.curve.p - py;
149 }
150 Ok(EllipticPoint::new(px, py))
151 }
152 }
153
154 fn unpack_secret<T: AsRef<[u8]>>(ec_secret: T) -> Result<BigInt> {
155 let length = ec_secret.as_ref().len() - 4;
156 if length != ec_secret.as_ref()[3] as usize {
157 return Err(EcdhError::InvalidSecretKey);
158 }
159 Ok(BigInt::from_be_bytes(
160 ec_secret.as_ref()[4..length + 4].as_ref(),
161 ))
162 }
163
164 fn create_public(&self) -> Result<EllipticPoint> {
165 self.create_shared(&self.secret, &self.curve.g)
166 }
167
168 fn create_secret(&self) -> Result<BigInt> {
169 let mut rng = rand::thread_rng();
170 let mut arr = vec![0u8; self.curve.size];
171 loop {
172 rng.fill_bytes(&mut arr);
173 let result = BigInt::from_be_bytes(arr.as_slice());
174 if result >= BigInt::one() && result < self.curve.n {
175 return Ok(result);
176 }
177 }
178 }
179
180 fn create_shared(
181 &self,
182 ec_secret: &BigInt,
183 ec_public: &EllipticPoint,
184 ) -> Result<EllipticPoint> {
185 if ec_secret % &self.curve.n == BigInt::ZERO || ec_public.is_default() {
186 return Ok(EllipticPoint::default());
187 }
188 if ec_secret.is_negative() {
189 return self.create_shared(&-ec_secret, ec_public);
190 }
191 if !self.curve.check_on(ec_public) {
192 return Err(EcdhError::PointNotOnCurve);
193 }
194
195 let mut pr = EllipticPoint::default();
196 let mut pa = ec_public.clone();
197 let mut ps = ec_secret.clone();
198 while ps > BigInt::ZERO {
199 if (&ps & BigInt::one()) > BigInt::ZERO {
200 pr = self.point_add(&pr, &pa)?;
201 }
202 pa = self.point_add(&pa, &pa)?;
203 ps >>= 1;
204 }
205 if !self.curve.check_on(&pr) {
206 return Err(EcdhError::PointNotOnCurve);
207 }
208 Ok(pr)
209 }
210
211 fn point_add(&self, p1: &EllipticPoint, p2: &EllipticPoint) -> Result<EllipticPoint> {
212 if p1.is_default() {
213 return Ok(p2.clone());
214 };
215 if p2.is_default() {
216 return Ok(p1.clone());
217 };
218 if !self.curve.check_on(p1) || !self.curve.check_on(p2) {
219 return Err(EcdhError::PointNotOnCurve);
220 }
221
222 let (x1, x2, y1, y2) = (&p1.x, &p2.x, &p1.y, &p2.y);
223
224 let m = if x1 == x2 {
225 if y1 == y2 {
226 (3 * x1 * x1 + &self.curve.a) * mod_inverse(&(y1 << 1), &self.curve.p)?
227 } else {
228 return Ok(EllipticPoint::default());
229 }
230 } else {
231 (y1 - y2) * mod_inverse(&(x1 - x2), &self.curve.p)?
232 };
233 let xr = mod_positive(&(&m * &m - x1 - x2), &self.curve.p);
234 let yr = mod_positive(&(&m * (x1 - &xr) - y1), &self.curve.p);
235 let pr = EllipticPoint::new(xr, yr);
236 if !self.curve.check_on(&pr) {
237 return Err(EcdhError::PointNotOnCurve);
238 }
239 Ok(pr)
240 }
241}
242
243impl KeyExchange for Ecdh {
244 type Error = EcdhError;
245
246 fn key_exchange<T: AsRef<[u8]>>(&self, ec_pub: T, hash: bool) -> Result<Vec<u8>> {
247 let shared = self.create_shared(&self.secret, &self.unpack_public(ec_pub)?)?;
248 Ok(self.pack_shared(&shared, hash))
249 }
250}
251
252fn mod_inverse(a: &BigInt, b: &BigInt) -> Result<BigInt> {
253 if a.is_negative() {
254 return Ok(b - mod_inverse(&-a, b)?);
255 }
256 if a.gcd(b) != BigInt::one() {
257 return Err(EcdhError::InverseDoesNotExist);
258 }
259 Ok(a.modpow(&(b - 2), b))
260}
261
262fn mod_positive(a: &BigInt, b: &BigInt) -> BigInt {
263 let result = a % b;
264 if result.is_negative() {
265 result + b
266 } else {
267 result
268 }
269}
270
271#[derive(Debug, Clone)]
272pub struct EllipticCurve {
273 pub p: BigInt,
274 pub a: BigInt,
275 pub b: BigInt,
276 pub g: EllipticPoint,
277 pub n: BigInt,
278 pub size: usize,
280 pub pack_size: usize,
281}
282
283impl EllipticCurve {
284 pub fn check_on(&self, point: &EllipticPoint) -> bool {
285 let lhs = point.y.modpow(&BigInt::from(2), &self.p); let rhs = (point.x.modpow(&BigInt::from(3), &self.p) + (&self.a * &point.x) % &self.p
289 + &self.b)
290 % &self.p;
291 lhs == rhs
292 }
293}
294
295#[derive(Debug, Clone)]
296pub struct EllipticPoint {
297 x: BigInt,
298 y: BigInt,
299}
300
301impl EllipticPoint {
302 pub fn new(x: BigInt, y: BigInt) -> EllipticPoint {
303 Self { x, y }
304 }
305
306 pub fn is_default(&self) -> bool {
307 self.x.is_zero() && self.y.is_zero()
308 }
309}
310
311impl Default for EllipticPoint {
312 fn default() -> EllipticPoint {
313 EllipticPoint {
314 x: BigInt::ZERO,
315 y: BigInt::ZERO,
316 }
317 }
318}
319
320impl Neg for EllipticPoint {
321 type Output = EllipticPoint;
322 fn neg(self) -> EllipticPoint {
323 Self {
324 x: -self.x,
325 y: -self.y,
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-negative `BigInt` from base-2³² digits, least significant first.
    fn digits(words: &[u32]) -> BigInt {
        BigInt::from_slice(Sign::Plus, words)
    }

    #[test]
    fn test_create_shared() -> Result<()> {
        let secret = digits(&[
            2792767394, 3497172710, 1652332542, 1637215680, 2069543466, 3042786051, 1983615641,
            1413039311,
        ]);
        let peer = EllipticPoint::new(
            digits(&[
                3633889942, 4104206661, 770388896, 1996717441, 1671708914, 4173129445,
                3777774151, 1796723186,
            ]),
            digits(&[
                935285237, 3417718888, 1798397646, 734933847, 2081398294, 2397563722,
                4263149467, 1340293858,
            ]),
        );
        let ecdh = Ecdh {
            curve: curve::PRIME256V1.clone(),
            secret: secret.clone(),
            public: EllipticPoint::default(),
        };

        let shared = ecdh.create_shared(&secret, &peer)?;

        assert_eq!(
            shared.x.to_string(),
            "108127657902349890608713381133328869798433126817450190281707933471728911716600"
        );
        assert_eq!(
            shared.y.to_string(),
            "90081983369721618215244899859661517112066743317494361926578071297084718897517"
        );
        Ok(())
    }

    #[test]
    fn test_pack_public() -> Result<()> {
        let public = EllipticPoint::new(
            digits(&[
                192391996, 42411882, 3555170359, 1596249776, 1928736291, 3341608003,
                1479611319, 155190516,
            ]),
            digits(&[
                3113797789, 2959440885, 2287647648, 3290191409, 693816809, 3512895204,
                3156405519, 100630943,
            ]),
        );
        let ecdh = Ecdh {
            curve: curve::PRIME256V1.clone(),
            secret: digits(&[
                3008988791, 2223512191, 4101290493, 1349918355, 981201553, 1906136310,
                3328719820, 837222639,
            ]),
            public,
        };

        let expected: &[u8] = &[
            4, 9, 64, 4, 244, 88, 49, 19, 183, 199, 44, 228, 67, 114, 246, 46, 35, 95, 36, 214,
            176, 211, 231, 152, 55, 2, 135, 39, 106, 11, 119, 171, 60, 5, 255, 129, 159, 188,
            34, 237, 15, 209, 98, 134, 228, 41, 90, 205, 233, 196, 28, 86, 49, 136, 90, 187,
            160, 176, 101, 123, 245, 185, 152, 200, 157,
        ];
        assert_eq!(ecdh.pack_public(false)?.as_slice(), expected);
        Ok(())
    }

    #[test]
    fn test_ecdh_key_exchange() -> Result<()> {
        let curve_p256 = curve::PRIME256V1.clone();
        let alice = Ecdh::new(curve_p256.clone())?;
        let bob = Ecdh::new(curve_p256)?;

        let alice_public_key = alice.pack_public(false)?;
        let bob_public_key = bob.pack_public(false)?;

        let shared_by_alice = alice.key_exchange(bob_public_key, false)?;
        let shared_by_bob = bob.key_exchange(alice_public_key, false)?;
        assert_eq!(shared_by_alice, shared_by_bob);
        Ok(())
    }
}