simple_ecdh/
lib.rs

1use md5::{Digest, Md5};
2use num_bigint::{BigInt, Sign};
3use num_integer::Integer;
4use num_traits::{FromBytes, One, Signed, Zero};
5use rand::RngCore;
6use std::fmt;
7use std::io::Write;
8use std::ops::Neg;
9
/// A party that can derive a shared secret from its own private key and a
/// peer's encoded public key.
pub trait KeyExchange {
    /// Error reported when the exchange cannot be completed.
    type Error;

    /// Computes the shared secret with the peer whose encoded public key is `public`.
    ///
    /// The meaning of `hash` is left to the implementor (presumably: return a
    /// digest of the shared secret instead of the raw bytes).
    fn key_exchange<T>(
        &self,
        public: T,
        hash: bool,
    ) -> std::result::Result<Vec<u8>, Self::Error>
    where
        T: AsRef<[u8]>;
}
19
/// Errors produced by the ECDH routines in this module.
#[derive(Debug)]
pub enum EcdhError {
    /// A public key could not be decoded or is not acceptable.
    InvalidPublicKey,
    /// A secret key could not be decoded or is not acceptable.
    InvalidSecretKey,
    /// A point failed the curve-equation check.
    PointNotOnCurve,
    /// A modular inverse was requested for a value not coprime to the modulus.
    InverseDoesNotExist,
    /// Wraps an I/O error raised while encoding.
    IOError(std::io::Error),
}

impl fmt::Display for EcdhError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            EcdhError::InvalidPublicKey => write!(f, "Invalid public key."),
            EcdhError::InvalidSecretKey => write!(f, "Invalid secret key."),
            EcdhError::PointNotOnCurve => write!(f, "Point is not on the curve."),
            EcdhError::InverseDoesNotExist => write!(f, "Modular inverse does not exist."),
            EcdhError::IOError(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for EcdhError {
    /// Exposes the wrapped I/O error so error-chain walkers (`anyhow`, `{:#}`
    /// reporters, ...) can see the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            EcdhError::IOError(e) => Some(e),
            EcdhError::InvalidPublicKey
            | EcdhError::InvalidSecretKey
            | EcdhError::PointNotOnCurve
            | EcdhError::InverseDoesNotExist => None,
        }
    }
}

impl From<std::io::Error> for EcdhError {
    fn from(err: std::io::Error) -> Self {
        EcdhError::IOError(err)
    }
}

/// Result alias used throughout this module.
pub type Result<T> = std::result::Result<T, EcdhError>;
50
51#[derive(Debug, Clone)]
52pub struct Ecdh {
53    curve: EllipticCurve,
54    secret: BigInt,
55    public: EllipticPoint,
56}
57
58impl Ecdh {
59    pub fn new(curve: EllipticCurve) -> Result<Self> {
60        let mut ecdh = Ecdh {
61            curve,
62            secret: BigInt::default(),
63            public: EllipticPoint::default(),
64        };
65        ecdh.secret = ecdh.create_secret()?;
66        ecdh.public = ecdh.create_public()?;
67        Ok(ecdh)
68    }
69
70    pub fn new_with_secret<T: AsRef<[u8]>>(curve: EllipticCurve, secret: T) -> Result<Self> {
71        let mut ecdh = Ecdh {
72            curve,
73            secret: BigInt::default(),
74            public: EllipticPoint::default(),
75        };
76        ecdh.secret = Self::unpack_secret(secret.as_ref())?;
77        ecdh.public = ecdh.create_public()?;
78        Ok(ecdh)
79    }
80
81    pub fn pack_public(&self, compress: bool) -> Result<Vec<u8>> {
82        if compress {
83            let mut result = vec![0u8; self.curve.size + 1];
84            result[0] = if self.public.y.is_even() ^ self.public.y.is_negative() {
85                0x02
86            } else {
87                0x03
88            };
89            (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
90            Ok(result)
91        } else {
92            let mut result = vec![0u8; self.curve.size * 2 + 1];
93            result[0] = 0x04;
94            (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
95            (&mut result[self.curve.size + 1..])
96                .write_all(self.public.y.to_bytes_be().1.as_slice())?;
97            Ok(result)
98        }
99    }
100
101    pub fn pack_secret(&self) -> Result<Vec<u8>> {
102        let raw_length = self.secret.to_bytes_be().1.len();
103        let mut result = vec![0u8; raw_length + 4];
104        (&mut result[4..]).write_all(self.secret.to_bytes_be().1.as_slice())?;
105        result[3] = raw_length as u8;
106        Ok(result)
107    }
108
109    fn pack_shared(&self, ec_shared: &EllipticPoint, hash: bool) -> Vec<u8> {
110        let x = ec_shared.x.to_bytes_be().1;
111        if hash {
112            Md5::digest(x.as_slice()[..self.curve.pack_size].as_ref()).to_vec()
113        } else {
114            x
115        }
116    }
117
118    fn unpack_public<T: AsRef<[u8]>>(&self, public_key: T) -> Result<EllipticPoint> {
119        let length = public_key.as_ref().len();
120        if length != self.curve.size * 2 + 1 && length != self.curve.size + 1 {
121            return Err(EcdhError::InvalidPublicKey);
122        }
123        if public_key.as_ref()[0] == 0x04 {
124            Ok(EllipticPoint::new(
125                BigInt::from_bytes_be(
126                    Sign::Plus,
127                    public_key.as_ref()[1..self.curve.size + 1].as_ref(),
128                ),
129                BigInt::from_bytes_be(
130                    Sign::Plus,
131                    public_key.as_ref()[self.curve.size + 1..].as_ref(),
132                ),
133            ))
134        } else {
135            // find the y-coordinate from x-coordinate by y^2 = x^3 + ax + b
136            let px = BigInt::from_bytes_be(Sign::Plus, public_key.as_ref()[1..].as_ref());
137            let x3 = px.modpow(&BigInt::from(3), &self.curve.p);
138            let ax = (&px * &self.curve.a) % &self.curve.p;
139            let right = (&x3 + &ax + &self.curve.b) % &self.curve.p;
140
141            let tmp = (&self.curve.p + 1) >> 2;
142            let mut py = right.modpow(&tmp, &self.curve.p);
143            if !(py.is_even() && public_key.as_ref()[0] == 0x02
144                || !py.is_even() && public_key.as_ref()[0] == 0x03)
145            {
146                py = &self.curve.p - py;
147            }
148            Ok(EllipticPoint::new(px, py))
149        }
150    }
151
152    fn unpack_secret<T: AsRef<[u8]>>(ec_secret: T) -> Result<BigInt> {
153        let length = ec_secret.as_ref().len() - 4;
154        if length != ec_secret.as_ref()[3] as usize {
155            return Err(EcdhError::InvalidSecretKey);
156        }
157        Ok(BigInt::from_be_bytes(
158            ec_secret.as_ref()[4..length + 4].as_ref(),
159        ))
160    }
161
162    fn create_public(&self) -> Result<EllipticPoint> {
163        self.create_shared(&self.secret, &self.curve.g)
164    }
165
166    fn create_secret(&self) -> Result<BigInt> {
167        let mut rng = rand::thread_rng();
168        let mut arr = vec![0u8; self.curve.size];
169        loop {
170            rng.fill_bytes(&mut arr);
171            let result = BigInt::from_be_bytes(arr.as_slice());
172            if result >= BigInt::one() && result < self.curve.n {
173                return Ok(result);
174            }
175        }
176    }
177
178    fn create_shared(
179        &self,
180        ec_secret: &BigInt,
181        ec_public: &EllipticPoint,
182    ) -> Result<EllipticPoint> {
183        if ec_secret % &self.curve.n == BigInt::ZERO || ec_public.is_default() {
184            return Ok(EllipticPoint::default());
185        }
186        if ec_secret.is_negative() {
187            return self.create_shared(&-ec_secret, ec_public);
188        }
189        if !self.curve.check_on(ec_public) {
190            return Err(EcdhError::PointNotOnCurve);
191        }
192
193        let mut pr = EllipticPoint::default();
194        let mut pa = ec_public.clone();
195        let mut ps = ec_secret.clone();
196        while ps > BigInt::ZERO {
197            if (&ps & BigInt::one()) > BigInt::ZERO {
198                pr = self.point_add(&pr, &pa)?;
199            }
200            pa = self.point_add(&pa, &pa)?;
201            ps = ps >> 1;
202        }
203        if !self.curve.check_on(&pr) {
204            return Err(EcdhError::PointNotOnCurve);
205        }
206        Ok(pr)
207    }
208
209    fn point_add(&self, p1: &EllipticPoint, p2: &EllipticPoint) -> Result<EllipticPoint> {
210        if p1.is_default() {
211            return Ok(p2.clone());
212        };
213        if p2.is_default() {
214            return Ok(p1.clone());
215        };
216        if !self.curve.check_on(p1) || !self.curve.check_on(p2) {
217            return Err(EcdhError::PointNotOnCurve);
218        }
219
220        let (x1, x2, y1, y2) = (&p1.x, &p2.x, &p1.y, &p2.y);
221
222        let m = if x1 == x2 {
223            if y1 == y2 {
224                (3 * x1 * x1 + &self.curve.a) * mod_inverse(&(y1 << 1), &self.curve.p)?
225            } else {
226                return Ok(EllipticPoint::default());
227            }
228        } else {
229            (y1 - y2) * mod_inverse(&(x1 - x2), &self.curve.p)?
230        };
231        let xr = mod_positive(&(&m * &m - x1 - x2), &self.curve.p);
232        let yr = mod_positive(&(&m * (x1 - &xr) - y1), &self.curve.p);
233        let pr = EllipticPoint::new(xr, yr);
234        if !self.curve.check_on(&pr) {
235            return Err(EcdhError::PointNotOnCurve);
236        }
237        Ok(pr)
238    }
239}
240
241impl KeyExchange for Ecdh {
242    type Error = EcdhError;
243
244    fn key_exchange<T: AsRef<[u8]>>(&self, ec_pub: T, hash: bool) -> Result<Vec<u8>> {
245        let shared = self.create_shared(&self.secret, &self.unpack_public(ec_pub)?)?;
246        Ok(self.pack_shared(&shared, hash))
247    }
248}
249
250fn mod_inverse(a: &BigInt, b: &BigInt) -> Result<BigInt> {
251    if a.is_negative() {
252        return Ok(b - mod_inverse(&-a, b)?);
253    }
254    if a.gcd(b) != BigInt::one() {
255        return Err(EcdhError::InverseDoesNotExist);
256    }
257    Ok(a.modpow(&(b - 2), b))
258}
259
260fn mod_positive(a: &BigInt, b: &BigInt) -> BigInt {
261    let result = a % b;
262    if result.is_negative() {
263        result + b
264    } else {
265        result
266    }
267}
268
269#[derive(Debug, Clone)]
270pub struct EllipticCurve {
271    pub p: BigInt,
272    pub a: BigInt,
273    pub b: BigInt,
274    pub g: EllipticPoint,
275    pub n: BigInt,
276    // h: BigInt,
277    pub size: usize,
278    pub pack_size: usize,
279}
280
281impl EllipticCurve {
282    pub fn check_on(&self, point: &EllipticPoint) -> bool {
283        // ((&point.y.pow(2) - &point.x.pow(3) - &self.a * &point.x - &self.b) % &self.p) == BigInt::zero()
284        let lhs = point.y.modpow(&BigInt::from(2), &self.p); // y² mod p
285        let rhs = (point.x.modpow(&BigInt::from(3), &self.p)  // x³ mod p
286            + (&self.a * &point.x) % &self.p
287            + &self.b)
288            % &self.p;
289        lhs == rhs
290    }
291}
292
293#[derive(Debug, Clone)]
294pub struct EllipticPoint {
295    x: BigInt,
296    y: BigInt,
297}
298
299impl EllipticPoint {
300    pub fn new(x: BigInt, y: BigInt) -> EllipticPoint {
301        Self { x, y }
302    }
303
304    pub fn is_default(&self) -> bool {
305        self.x.is_zero() && self.y.is_zero()
306    }
307}
308
309impl Default for EllipticPoint {
310    fn default() -> EllipticPoint {
311        EllipticPoint {
312            x: BigInt::ZERO,
313            y: BigInt::ZERO,
314        }
315    }
316}
317
318impl Neg for EllipticPoint {
319    type Output = EllipticPoint;
320    fn neg(self) -> EllipticPoint {
321        Self {
322            x: -self.x,
323            y: -self.y,
324        }
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;

    // prime256v1 (NIST P-256) domain parameters. Every big number below is
    // written as little-endian bytes for `BigInt::from_bytes_le`.
    pub static PRIME256V1: Lazy<EllipticCurve> = Lazy::new(|| EllipticCurve {
        p: BigInt::from_bytes_le(
            Sign::Plus,
            &[
                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
                0xFF, 0xFF, 0xFF, 0xFF,
            ],
        ),
        a: BigInt::from_bytes_le(
            Sign::Plus,
            &[
                0xFC, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
                0xFF, 0xFF, 0xFF, 0xFF,
            ],
        ),
        b: BigInt::from_bytes_le(
            Sign::Plus,
            &[
                0x4B, 0x60, 0xD2, 0x27, 0x3E, 0x3C, 0xCE, 0x3B, 0xF6, 0xB0, 0x53, 0xCC, 0xB0, 0x06,
                0x1D, 0x65, 0xBC, 0x86, 0x98, 0x76, 0x55, 0xBD, 0xEB, 0xB3, 0xE7, 0x93, 0x3A, 0xAA,
                0xD8, 0x35, 0xC6, 0x5A,
            ],
        ),
        g: EllipticPoint {
            x: BigInt::from_bytes_le(
                Sign::Plus,
                &[
                    0x96, 0xC2, 0x98, 0xD8, 0x45, 0x39, 0xA1, 0xF4, 0xA0, 0x33, 0xEB, 0x2D, 0x81,
                    0x7D, 0x03, 0x77, 0xF2, 0x40, 0xA4, 0x63, 0xE5, 0xE6, 0xBC, 0xF8, 0x47, 0x42,
                    0x2C, 0xE1, 0xF2, 0xD1, 0x17, 0x6B,
                ],
            ),
            y: BigInt::from_bytes_le(
                Sign::Plus,
                &[
                    0xF5, 0x51, 0xBF, 0x37, 0x68, 0x40, 0xB6, 0xCB, 0xCE, 0x5E, 0x31, 0x6B, 0x57,
                    0x33, 0xCE, 0x2B, 0x16, 0x9E, 0x0F, 0x7C, 0x4A, 0xEB, 0xE7, 0x8E, 0x9B, 0x7F,
                    0x1A, 0xFE, 0xE2, 0x42, 0xE3, 0x4F,
                ],
            ),
        },
        n: BigInt::from_bytes_le(
            Sign::Plus,
            &[
                0x51, 0x25, 0x63, 0xFC, 0xC2, 0xCA, 0xB9, 0xF3, 0x84, 0x9E, 0x17, 0xA7, 0xAD, 0xFA,
                0xE6, 0xBC, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,
                0xFF, 0xFF, 0xFF, 0xFF,
            ],
        ),
        // Each coordinate encodes to 32 bytes.
        size: 32,
        // The hashed shared key is MD5 over the first 16 bytes of x.
        pack_size: 16,
    });

    // Scalar multiplication against fixed reference coordinates.
    #[test]
    fn test_create_shared() -> Result<()> {
        // Same value as PRIME256V1.g, written as little-endian u32 limbs,
        // so this checks `ec_sec * G`.
        let ec_pub = EllipticPoint {
            x: BigInt::from_slice(
                Sign::Plus,
                &[
                    3633889942, 4104206661, 770388896, 1996717441, 1671708914, 4173129445,
                    3777774151, 1796723186,
                ],
            ),
            y: BigInt::from_slice(
                Sign::Plus,
                &[
                    935285237, 3417718888, 1798397646, 734933847, 2081398294, 2397563722,
                    4263149467, 1340293858,
                ],
            ),
        };
        let ec_sec = BigInt::from_slice(
            Sign::Plus,
            &[
                2792767394, 3497172710, 1652332542, 1637215680, 2069543466, 3042786051, 1983615641,
                1413039311,
            ],
        );
        // create_shared only reads `curve`; `public` and `secret` are placeholders here.
        let ecdh = Ecdh {
            curve: PRIME256V1.clone(),
            public: EllipticPoint::default(),
            secret: BigInt::from_slice(
                Sign::Plus,
                &[
                    2792767394, 3497172710, 1652332542, 1637215680, 2069543466, 3042786051,
                    1983615641, 1413039311,
                ],
            ),
        };
        let shared = ecdh.create_shared(&ec_sec, &ec_pub)?;
        assert_eq!(
            shared.x.to_string(),
            "108127657902349890608713381133328869798433126817450190281707933471728911716600"
        );
        assert_eq!(
            shared.y.to_string(),
            "90081983369721618215244899859661517112066743317494361926578071297084718897517"
        );
        Ok(())
    }

    // Uncompressed encoding: 0x04 prefix followed by 32-byte x and 32-byte y.
    #[test]
    fn test_pack_public() -> Result<()> {
        let ecdh = Ecdh {
            curve: PRIME256V1.clone(),
            secret: BigInt::from_slice(
                Sign::Plus,
                &[
                    3008988791, 2223512191, 4101290493, 1349918355, 981201553, 1906136310,
                    3328719820, 837222639,
                ],
            ),
            public: EllipticPoint {
                x: BigInt::from_slice(
                    Sign::Plus,
                    &[
                        192391996, 42411882, 3555170359, 1596249776, 1928736291, 3341608003,
                        1479611319, 155190516,
                    ],
                ),
                y: BigInt::from_slice(
                    Sign::Plus,
                    &[
                        3113797789, 2959440885, 2287647648, 3290191409, 693816809, 3512895204,
                        3156405519, 100630943,
                    ],
                ),
            },
        };
        let public = ecdh.pack_public(false)?;
        assert_eq!(
            public.as_slice(),
            &[
                4, 9, 64, 4, 244, 88, 49, 19, 183, 199, 44, 228, 67, 114, 246, 46, 35, 95, 36, 214,
                176, 211, 231, 152, 55, 2, 135, 39, 106, 11, 119, 171, 60, 5, 255, 129, 159, 188,
                34, 237, 15, 209, 98, 134, 228, 41, 90, 205, 233, 196, 28, 86, 49, 136, 90, 187,
                160, 176, 101, 123, 245, 185, 152, 200, 157,
            ]
        );
        Ok(())
    }

    // Two freshly generated key pairs must agree on the shared secret.
    #[test]
    fn test_ecdh_key_exchange() -> Result<()> {
        let curve_p256 = &PRIME256V1.clone();
        let alice = Ecdh::new(curve_p256.clone())?;
        let alice_public_key = alice.pack_public(false)?;
        let bob = Ecdh::new(curve_p256.clone())?;
        let bob_public_key = bob.pack_public(false)?;
        let shared_by_alice = alice.key_exchange(bob_public_key, false)?;
        let shared_by_bob = bob.key_exchange(alice_public_key, false)?;
        assert_eq!(shared_by_alice, shared_by_bob);
        Ok(())
    }
}
487        Ok(())
488    }
489}