1use md5::{Digest, Md5};
2use num_bigint::{BigInt, Sign};
3use num_integer::Integer;
4use num_traits::{FromBytes, One, Signed, Zero};
5use rand::RngCore;
6use std::fmt;
7use std::io::Write;
8use std::ops::Neg;
9
/// A party that can combine its own key material with a peer's public key.
pub trait KeyExchange {
    /// Error type returned when the exchange fails.
    type Error;

    /// Produces shared key material from the peer's encoded public key.
    ///
    /// * `public` - the peer's public key as raw bytes; the accepted encoding is
    ///   defined by the implementor.
    /// * `hash` - implementor-defined switch between a raw and a hashed result.
    ///   In the `Ecdh` implementation in this file, `true` returns an MD5 digest
    ///   derived from the shared x-coordinate and `false` returns the x-coordinate
    ///   bytes themselves.
    fn key_exchange<T: AsRef<[u8]>>(
        &self,
        public: T,
        hash: bool,
    ) -> std::result::Result<Vec<u8>, Self::Error>;
}
19
/// Errors produced by the ECDH routines in this file.
#[derive(Debug)]
pub enum EcdhError {
    /// An encoded public key was rejected while unpacking.
    InvalidPublicKey,
    /// An encoded secret key was rejected while unpacking.
    InvalidSecretKey,
    /// A point failed the curve-equation check.
    PointNotOnCurve,
    /// A modular inverse was requested for a value not coprime to the modulus.
    InverseDoesNotExist,
    /// An underlying I/O operation failed (produced via `From<std::io::Error>`).
    IOError(std::io::Error),
}
28
29impl fmt::Display for EcdhError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 EcdhError::InvalidPublicKey => write!(f, "Invalid public key."),
33 EcdhError::InvalidSecretKey => write!(f, "Invalid secret key."),
34 EcdhError::PointNotOnCurve => write!(f, "Point is not on the curve."),
35 EcdhError::InverseDoesNotExist => write!(f, "Modular inverse does not exist."),
36 EcdhError::IOError(e) => write!(f, "IO error: {}", e),
37 }
38 }
39}
40
// Uses the trait's default methods only, so `source()` returns `None`; the wrapped
// `std::io::Error` of `IOError` is surfaced through the `Display` message instead.
impl std::error::Error for EcdhError {}
42
43impl From<std::io::Error> for EcdhError {
44 fn from(err: std::io::Error) -> Self {
45 EcdhError::IOError(err)
46 }
47}
48
/// Shorthand for results whose error type is [`EcdhError`].
pub type Result<T> = std::result::Result<T, EcdhError>;
50
/// An ECDH key pair bound to a specific curve.
///
/// NOTE(review): the derived `Debug` implementation prints `secret` along with the
/// other fields; avoid logging values of this type.
#[derive(Debug, Clone)]
pub struct Ecdh {
    /// Curve parameters used for every operation on this key pair.
    curve: EllipticCurve,
    /// Private scalar.
    secret: BigInt,
    /// Public point, derived as `secret * curve.g`.
    public: EllipticPoint,
}
57
58impl Ecdh {
59 pub fn new(curve: EllipticCurve) -> Result<Self> {
60 let mut ecdh = Ecdh {
61 curve,
62 secret: BigInt::default(),
63 public: EllipticPoint::default(),
64 };
65 ecdh.secret = ecdh.create_secret()?;
66 ecdh.public = ecdh.create_public()?;
67 Ok(ecdh)
68 }
69
70 pub fn new_with_secret<T: AsRef<[u8]>>(curve: EllipticCurve, secret: T) -> Result<Self> {
71 let mut ecdh = Ecdh {
72 curve,
73 secret: BigInt::default(),
74 public: EllipticPoint::default(),
75 };
76 ecdh.secret = Self::unpack_secret(secret.as_ref())?;
77 ecdh.public = ecdh.create_public()?;
78 Ok(ecdh)
79 }
80
81 pub fn pack_public(&self, compress: bool) -> Result<Vec<u8>> {
82 if compress {
83 let mut result = vec![0u8; self.curve.size + 1];
84 result[0] = if self.public.y.is_even() ^ self.public.y.is_negative() {
85 0x02
86 } else {
87 0x03
88 };
89 (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
90 Ok(result)
91 } else {
92 let mut result = vec![0u8; self.curve.size * 2 + 1];
93 result[0] = 0x04;
94 (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
95 (&mut result[self.curve.size + 1..])
96 .write_all(self.public.y.to_bytes_be().1.as_slice())?;
97 Ok(result)
98 }
99 }
100
101 pub fn pack_secret(&self) -> Result<Vec<u8>> {
102 let raw_length = self.secret.to_bytes_be().1.len();
103 let mut result = vec![0u8; raw_length + 4];
104 (&mut result[4..]).write_all(self.secret.to_bytes_be().1.as_slice())?;
105 result[3] = raw_length as u8;
106 Ok(result)
107 }
108
109 fn pack_shared(&self, ec_shared: &EllipticPoint, hash: bool) -> Vec<u8> {
110 let x = ec_shared.x.to_bytes_be().1;
111 if hash {
112 Md5::digest(x.as_slice()[..self.curve.pack_size].as_ref()).to_vec()
113 } else {
114 x
115 }
116 }
117
118 fn unpack_public<T: AsRef<[u8]>>(&self, public_key: T) -> Result<EllipticPoint> {
119 let length = public_key.as_ref().len();
120 if length != self.curve.size * 2 + 1 && length != self.curve.size + 1 {
121 return Err(EcdhError::InvalidPublicKey);
122 }
123 if public_key.as_ref()[0] == 0x04 {
124 Ok(EllipticPoint::new(
125 BigInt::from_bytes_be(
126 Sign::Plus,
127 public_key.as_ref()[1..self.curve.size + 1].as_ref(),
128 ),
129 BigInt::from_bytes_be(
130 Sign::Plus,
131 public_key.as_ref()[self.curve.size + 1..].as_ref(),
132 ),
133 ))
134 } else {
135 let px = BigInt::from_bytes_be(Sign::Plus, public_key.as_ref()[1..].as_ref());
137 let x3 = px.modpow(&BigInt::from(3), &self.curve.p);
138 let ax = (&px * &self.curve.a) % &self.curve.p;
139 let right = (&x3 + &ax + &self.curve.b) % &self.curve.p;
140
141 let tmp = (&self.curve.p + 1) >> 2;
142 let mut py = right.modpow(&tmp, &self.curve.p);
143 if !(py.is_even() && public_key.as_ref()[0] == 0x02
144 || !py.is_even() && public_key.as_ref()[0] == 0x03)
145 {
146 py = &self.curve.p - py;
147 }
148 Ok(EllipticPoint::new(px, py))
149 }
150 }
151
152 fn unpack_secret<T: AsRef<[u8]>>(ec_secret: T) -> Result<BigInt> {
153 let length = ec_secret.as_ref().len() - 4;
154 if length != ec_secret.as_ref()[3] as usize {
155 return Err(EcdhError::InvalidSecretKey);
156 }
157 Ok(BigInt::from_be_bytes(
158 ec_secret.as_ref()[4..length + 4].as_ref(),
159 ))
160 }
161
162 fn create_public(&self) -> Result<EllipticPoint> {
163 self.create_shared(&self.secret, &self.curve.g)
164 }
165
166 fn create_secret(&self) -> Result<BigInt> {
167 let mut rng = rand::thread_rng();
168 let mut arr = vec![0u8; self.curve.size];
169 loop {
170 rng.fill_bytes(&mut arr);
171 let result = BigInt::from_be_bytes(arr.as_slice());
172 if result >= BigInt::one() && result < self.curve.n {
173 return Ok(result);
174 }
175 }
176 }
177
178 fn create_shared(
179 &self,
180 ec_secret: &BigInt,
181 ec_public: &EllipticPoint,
182 ) -> Result<EllipticPoint> {
183 if ec_secret % &self.curve.n == BigInt::ZERO || ec_public.is_default() {
184 return Ok(EllipticPoint::default());
185 }
186 if ec_secret.is_negative() {
187 return self.create_shared(&-ec_secret, ec_public);
188 }
189 if !self.curve.check_on(ec_public) {
190 return Err(EcdhError::PointNotOnCurve);
191 }
192
193 let mut pr = EllipticPoint::default();
194 let mut pa = ec_public.clone();
195 let mut ps = ec_secret.clone();
196 while ps > BigInt::ZERO {
197 if (&ps & BigInt::one()) > BigInt::ZERO {
198 pr = self.point_add(&pr, &pa)?;
199 }
200 pa = self.point_add(&pa, &pa)?;
201 ps = ps >> 1;
202 }
203 if !self.curve.check_on(&pr) {
204 return Err(EcdhError::PointNotOnCurve);
205 }
206 Ok(pr)
207 }
208
209 fn point_add(&self, p1: &EllipticPoint, p2: &EllipticPoint) -> Result<EllipticPoint> {
210 if p1.is_default() {
211 return Ok(p2.clone());
212 };
213 if p2.is_default() {
214 return Ok(p1.clone());
215 };
216 if !self.curve.check_on(p1) || !self.curve.check_on(p2) {
217 return Err(EcdhError::PointNotOnCurve);
218 }
219
220 let (x1, x2, y1, y2) = (&p1.x, &p2.x, &p1.y, &p2.y);
221
222 let m = if x1 == x2 {
223 if y1 == y2 {
224 (3 * x1 * x1 + &self.curve.a) * mod_inverse(&(y1 << 1), &self.curve.p)?
225 } else {
226 return Ok(EllipticPoint::default());
227 }
228 } else {
229 (y1 - y2) * mod_inverse(&(x1 - x2), &self.curve.p)?
230 };
231 let xr = mod_positive(&(&m * &m - x1 - x2), &self.curve.p);
232 let yr = mod_positive(&(&m * (x1 - &xr) - y1), &self.curve.p);
233 let pr = EllipticPoint::new(xr, yr);
234 if !self.curve.check_on(&pr) {
235 return Err(EcdhError::PointNotOnCurve);
236 }
237 Ok(pr)
238 }
239}
240
241impl KeyExchange for Ecdh {
242 type Error = EcdhError;
243
244 fn key_exchange<T: AsRef<[u8]>>(&self, ec_pub: T, hash: bool) -> Result<Vec<u8>> {
245 let shared = self.create_shared(&self.secret, &self.unpack_public(ec_pub)?)?;
246 Ok(self.pack_shared(&shared, hash))
247 }
248}
249
250fn mod_inverse(a: &BigInt, b: &BigInt) -> Result<BigInt> {
251 if a.is_negative() {
252 return Ok(b - mod_inverse(&-a, b)?);
253 }
254 if a.gcd(b) != BigInt::one() {
255 return Err(EcdhError::InverseDoesNotExist);
256 }
257 Ok(a.modpow(&(b - 2), b))
258}
259
260fn mod_positive(a: &BigInt, b: &BigInt) -> BigInt {
261 let result = a % b;
262 if result.is_negative() {
263 result + b
264 } else {
265 result
266 }
267}
268
/// Parameters of a curve `y^2 = x^3 + a*x + b (mod p)` together with the encoding
/// sizes used by [`Ecdh`].
#[derive(Debug, Clone)]
pub struct EllipticCurve {
    /// Modulus of the underlying field.
    pub p: BigInt,
    /// Coefficient `a` of the curve equation.
    pub a: BigInt,
    /// Coefficient `b` of the curve equation.
    pub b: BigInt,
    /// Base point; public keys are computed as `secret * g`.
    pub g: EllipticPoint,
    /// Exclusive upper bound for generated secrets; scalars are reduced modulo
    /// this value (presumably the order of `g` - confirm for custom curves).
    pub n: BigInt,
    /// Number of bytes used to encode one coordinate.
    pub size: usize,
    /// Number of leading x-coordinate bytes that are hashed when a hashed shared
    /// secret is requested.
    pub pack_size: usize,
}
280
281impl EllipticCurve {
282 pub fn check_on(&self, point: &EllipticPoint) -> bool {
283 let lhs = point.y.modpow(&BigInt::from(2), &self.p); let rhs = (point.x.modpow(&BigInt::from(3), &self.p) + (&self.a * &point.x) % &self.p
287 + &self.b)
288 % &self.p;
289 lhs == rhs
290 }
291}
292
/// A point in affine coordinates.
///
/// The value `(0, 0)` is the default and is what `is_default` tests for; the ECDH
/// code in this file uses it as the identity element.
#[derive(Debug, Clone)]
pub struct EllipticPoint {
    /// Affine x-coordinate.
    x: BigInt,
    /// Affine y-coordinate.
    y: BigInt,
}
298
299impl EllipticPoint {
300 pub fn new(x: BigInt, y: BigInt) -> EllipticPoint {
301 Self { x, y }
302 }
303
304 pub fn is_default(&self) -> bool {
305 self.x.is_zero() && self.y.is_zero()
306 }
307}
308
309impl Default for EllipticPoint {
310 fn default() -> EllipticPoint {
311 EllipticPoint {
312 x: BigInt::ZERO,
313 y: BigInt::ZERO,
314 }
315 }
316}
317
318impl Neg for EllipticPoint {
319 type Output = EllipticPoint;
320 fn neg(self) -> EllipticPoint {
321 Self {
322 x: -self.x,
323 y: -self.y,
324 }
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;

    /// Builds a non-negative integer from little-endian bytes.
    fn from_le(bytes: &[u8]) -> BigInt {
        BigInt::from_bytes_le(Sign::Plus, bytes)
    }

    /// Builds a non-negative integer from little-endian 32-bit digits.
    fn from_limbs(limbs: &[u32]) -> BigInt {
        BigInt::from_slice(Sign::Plus, limbs)
    }

    /// Curve parameters used by the tests, stored as little-endian byte strings.
    pub static PRIME256V1: Lazy<EllipticCurve> = Lazy::new(|| EllipticCurve {
        p: from_le(&[
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
            0xFF, 0xFF, 0xFF, 0xFF,
        ]),
        a: from_le(&[
            0xFC, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
            0xFF, 0xFF, 0xFF, 0xFF,
        ]),
        b: from_le(&[
            0x4B, 0x60, 0xD2, 0x27, 0x3E, 0x3C, 0xCE, 0x3B, 0xF6, 0xB0, 0x53, 0xCC, 0xB0, 0x06,
            0x1D, 0x65, 0xBC, 0x86, 0x98, 0x76, 0x55, 0xBD, 0xEB, 0xB3, 0xE7, 0x93, 0x3A, 0xAA,
            0xD8, 0x35, 0xC6, 0x5A,
        ]),
        g: EllipticPoint::new(
            from_le(&[
                0x96, 0xC2, 0x98, 0xD8, 0x45, 0x39, 0xA1, 0xF4, 0xA0, 0x33, 0xEB, 0x2D, 0x81,
                0x7D, 0x03, 0x77, 0xF2, 0x40, 0xA4, 0x63, 0xE5, 0xE6, 0xBC, 0xF8, 0x47, 0x42,
                0x2C, 0xE1, 0xF2, 0xD1, 0x17, 0x6B,
            ]),
            from_le(&[
                0xF5, 0x51, 0xBF, 0x37, 0x68, 0x40, 0xB6, 0xCB, 0xCE, 0x5E, 0x31, 0x6B, 0x57,
                0x33, 0xCE, 0x2B, 0x16, 0x9E, 0x0F, 0x7C, 0x4A, 0xEB, 0xE7, 0x8E, 0x9B, 0x7F,
                0x1A, 0xFE, 0xE2, 0x42, 0xE3, 0x4F,
            ]),
        ),
        n: from_le(&[
            0x51, 0x25, 0x63, 0xFC, 0xC2, 0xCA, 0xB9, 0xF3, 0x84, 0x9E, 0x17, 0xA7, 0xAD, 0xFA,
            0xE6, 0xBC, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,
            0xFF, 0xFF, 0xFF, 0xFF,
        ]),
        size: 32,
        pack_size: 16,
    });

    /// Multiplies a fixed point by a fixed scalar and compares both coordinates of
    /// the product against known decimal values.
    #[test]
    fn test_create_shared() -> Result<()> {
        let ec_pub = EllipticPoint::new(
            from_limbs(&[
                3633889942, 4104206661, 770388896, 1996717441, 1671708914, 4173129445,
                3777774151, 1796723186,
            ]),
            from_limbs(&[
                935285237, 3417718888, 1798397646, 734933847, 2081398294, 2397563722,
                4263149467, 1340293858,
            ]),
        );
        let ec_sec = from_limbs(&[
            2792767394, 3497172710, 1652332542, 1637215680, 2069543466, 3042786051, 1983615641,
            1413039311,
        ]);
        // The key pair carries the same scalar that is passed explicitly below.
        let ecdh = Ecdh {
            curve: PRIME256V1.clone(),
            public: EllipticPoint::default(),
            secret: ec_sec.clone(),
        };

        let shared = ecdh.create_shared(&ec_sec, &ec_pub)?;

        assert_eq!(
            shared.x.to_string(),
            "108127657902349890608713381133328869798433126817450190281707933471728911716600"
        );
        assert_eq!(
            shared.y.to_string(),
            "90081983369721618215244899859661517112066743317494361926578071297084718897517"
        );
        Ok(())
    }

    /// Checks the uncompressed encoding of a fixed public point byte for byte.
    #[test]
    fn test_pack_public() -> Result<()> {
        let ecdh = Ecdh {
            curve: PRIME256V1.clone(),
            secret: from_limbs(&[
                3008988791, 2223512191, 4101290493, 1349918355, 981201553, 1906136310,
                3328719820, 837222639,
            ]),
            public: EllipticPoint::new(
                from_limbs(&[
                    192391996, 42411882, 3555170359, 1596249776, 1928736291, 3341608003,
                    1479611319, 155190516,
                ]),
                from_limbs(&[
                    3113797789, 2959440885, 2287647648, 3290191409, 693816809, 3512895204,
                    3156405519, 100630943,
                ]),
            ),
        };
        let expected: [u8; 65] = [
            4, 9, 64, 4, 244, 88, 49, 19, 183, 199, 44, 228, 67, 114, 246, 46, 35, 95, 36, 214,
            176, 211, 231, 152, 55, 2, 135, 39, 106, 11, 119, 171, 60, 5, 255, 129, 159, 188,
            34, 237, 15, 209, 98, 134, 228, 41, 90, 205, 233, 196, 28, 86, 49, 136, 90, 187,
            160, 176, 101, 123, 245, 185, 152, 200, 157,
        ];

        let public = ecdh.pack_public(false)?;

        assert_eq!(public.as_slice(), &expected[..]);
        Ok(())
    }

    /// Two freshly generated key pairs must derive the same unhashed shared secret
    /// from each other's uncompressed public keys.
    #[test]
    fn test_ecdh_key_exchange() -> Result<()> {
        let curve: &EllipticCurve = &PRIME256V1;

        let alice = Ecdh::new(curve.clone())?;
        let bob = Ecdh::new(curve.clone())?;
        let alice_public_key = alice.pack_public(false)?;
        let bob_public_key = bob.pack_public(false)?;

        let shared_by_alice = alice.key_exchange(bob_public_key, false)?;
        let shared_by_bob = bob.key_exchange(alice_public_key, false)?;

        assert_eq!(shared_by_alice, shared_by_bob);
        Ok(())
    }
}