simple_ecdh/
lib.rs

1pub mod curve;
2
3use md5::{Digest, Md5};
4use num_bigint::{BigInt, Sign};
5use num_integer::Integer;
6use num_traits::{FromBytes, One, Signed, Zero};
7use rand::RngCore;
8use std::fmt;
9use std::io::Write;
10use std::ops::Neg;
11
/// A key-agreement scheme that derives shared key material from an encoded peer public key.
pub trait KeyExchange {
    /// Error type produced when an exchange fails.
    type Error;

    /// Combines this party's secret with the encoded peer `public` key.
    ///
    /// When `hash` is `true` the implementation post-processes (hashes) the raw shared value
    /// before returning it; otherwise the raw shared value is returned.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if `public` cannot be decoded or the computation fails.
    fn key_exchange<T: AsRef<[u8]>>(
        &self,
        public: T,
        hash: bool,
    ) -> std::result::Result<Vec<u8>, Self::Error>;
}
21
/// Everything that can go wrong while creating keys or performing an exchange.
#[derive(Debug)]
pub enum EcdhError {
    /// The peer public key has a bad length, prefix, or coordinates.
    InvalidPublicKey,
    /// The packed secret key is malformed.
    InvalidSecretKey,
    /// A point failed the curve-equation check.
    PointNotOnCurve,
    /// A modular inverse was requested for a value that has none.
    InverseDoesNotExist,
    /// Writing into an encoding buffer failed.
    IOError(std::io::Error),
}
30
31impl fmt::Display for EcdhError {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        match self {
34            EcdhError::InvalidPublicKey => write!(f, "Invalid public key."),
35            EcdhError::InvalidSecretKey => write!(f, "Invalid secret key."),
36            EcdhError::PointNotOnCurve => write!(f, "Point is not on the curve."),
37            EcdhError::InverseDoesNotExist => write!(f, "Modular inverse does not exist."),
38            EcdhError::IOError(e) => write!(f, "IO error: {}", e),
39        }
40    }
41}
42
43impl std::error::Error for EcdhError {}
44
45impl From<std::io::Error> for EcdhError {
46    fn from(err: std::io::Error) -> Self {
47        EcdhError::IOError(err)
48    }
49}
50
51pub type Result<T> = std::result::Result<T, EcdhError>;
52
53#[derive(Debug, Clone)]
54pub struct Ecdh {
55    curve: EllipticCurve,
56    secret: BigInt,
57    public: EllipticPoint,
58}
59
60impl Ecdh {
61    pub fn new(curve: EllipticCurve) -> Result<Self> {
62        let mut ecdh = Ecdh {
63            curve,
64            secret: BigInt::default(),
65            public: EllipticPoint::default(),
66        };
67        ecdh.secret = ecdh.create_secret()?;
68        ecdh.public = ecdh.create_public()?;
69        Ok(ecdh)
70    }
71
72    pub fn new_with_secret<T: AsRef<[u8]>>(curve: EllipticCurve, secret: T) -> Result<Self> {
73        let mut ecdh = Ecdh {
74            curve,
75            secret: BigInt::default(),
76            public: EllipticPoint::default(),
77        };
78        ecdh.secret = Self::unpack_secret(secret.as_ref())?;
79        ecdh.public = ecdh.create_public()?;
80        Ok(ecdh)
81    }
82
83    pub fn pack_public(&self, compress: bool) -> Result<Vec<u8>> {
84        if compress {
85            let mut result = vec![0u8; self.curve.size + 1];
86            result[0] = if self.public.y.is_even() ^ self.public.y.is_negative() {
87                0x02
88            } else {
89                0x03
90            };
91            (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
92            Ok(result)
93        } else {
94            let mut result = vec![0u8; self.curve.size * 2 + 1];
95            result[0] = 0x04;
96            (&mut result[1..]).write_all(self.public.x.to_bytes_be().1.as_slice())?;
97            (&mut result[self.curve.size + 1..])
98                .write_all(self.public.y.to_bytes_be().1.as_slice())?;
99            Ok(result)
100        }
101    }
102
103    pub fn pack_secret(&self) -> Result<Vec<u8>> {
104        let raw_length = self.secret.to_bytes_be().1.len();
105        let mut result = vec![0u8; raw_length + 4];
106        (&mut result[4..]).write_all(self.secret.to_bytes_be().1.as_slice())?;
107        result[3] = raw_length as u8;
108        Ok(result)
109    }
110
111    fn pack_shared(&self, ec_shared: &EllipticPoint, hash: bool) -> Vec<u8> {
112        let x = ec_shared.x.to_bytes_be().1;
113        if hash {
114            Md5::digest(x.as_slice()[..self.curve.pack_size].as_ref()).to_vec()
115        } else {
116            x
117        }
118    }
119
120    fn unpack_public<T: AsRef<[u8]>>(&self, public_key: T) -> Result<EllipticPoint> {
121        let length = public_key.as_ref().len();
122        if length != self.curve.size * 2 + 1 && length != self.curve.size + 1 {
123            return Err(EcdhError::InvalidPublicKey);
124        }
125        if public_key.as_ref()[0] == 0x04 {
126            Ok(EllipticPoint::new(
127                BigInt::from_bytes_be(
128                    Sign::Plus,
129                    public_key.as_ref()[1..self.curve.size + 1].as_ref(),
130                ),
131                BigInt::from_bytes_be(
132                    Sign::Plus,
133                    public_key.as_ref()[self.curve.size + 1..].as_ref(),
134                ),
135            ))
136        } else {
137            // find the y-coordinate from x-coordinate by y^2 = x^3 + ax + b
138            let px = BigInt::from_bytes_be(Sign::Plus, public_key.as_ref()[1..].as_ref());
139            let x3 = px.modpow(&BigInt::from(3), &self.curve.p);
140            let ax = (&px * &self.curve.a) % &self.curve.p;
141            let right = (&x3 + &ax + &self.curve.b) % &self.curve.p;
142
143            let tmp = (&self.curve.p + 1) >> 2;
144            let mut py = right.modpow(&tmp, &self.curve.p);
145            if !(py.is_even() && public_key.as_ref()[0] == 0x02
146                || !py.is_even() && public_key.as_ref()[0] == 0x03)
147            {
148                py = &self.curve.p - py;
149            }
150            Ok(EllipticPoint::new(px, py))
151        }
152    }
153
154    fn unpack_secret<T: AsRef<[u8]>>(ec_secret: T) -> Result<BigInt> {
155        let length = ec_secret.as_ref().len() - 4;
156        if length != ec_secret.as_ref()[3] as usize {
157            return Err(EcdhError::InvalidSecretKey);
158        }
159        Ok(BigInt::from_be_bytes(
160            ec_secret.as_ref()[4..length + 4].as_ref(),
161        ))
162    }
163
164    fn create_public(&self) -> Result<EllipticPoint> {
165        self.create_shared(&self.secret, &self.curve.g)
166    }
167
168    fn create_secret(&self) -> Result<BigInt> {
169        let mut rng = rand::thread_rng();
170        let mut arr = vec![0u8; self.curve.size];
171        loop {
172            rng.fill_bytes(&mut arr);
173            let result = BigInt::from_be_bytes(arr.as_slice());
174            if result >= BigInt::one() && result < self.curve.n {
175                return Ok(result);
176            }
177        }
178    }
179
180    fn create_shared(
181        &self,
182        ec_secret: &BigInt,
183        ec_public: &EllipticPoint,
184    ) -> Result<EllipticPoint> {
185        if ec_secret % &self.curve.n == BigInt::ZERO || ec_public.is_default() {
186            return Ok(EllipticPoint::default());
187        }
188        if ec_secret.is_negative() {
189            return self.create_shared(&-ec_secret, ec_public);
190        }
191        if !self.curve.check_on(ec_public) {
192            return Err(EcdhError::PointNotOnCurve);
193        }
194
195        let mut pr = EllipticPoint::default();
196        let mut pa = ec_public.clone();
197        let mut ps = ec_secret.clone();
198        while ps > BigInt::ZERO {
199            if (&ps & BigInt::one()) > BigInt::ZERO {
200                pr = self.point_add(&pr, &pa)?;
201            }
202            pa = self.point_add(&pa, &pa)?;
203            ps >>= 1;
204        }
205        if !self.curve.check_on(&pr) {
206            return Err(EcdhError::PointNotOnCurve);
207        }
208        Ok(pr)
209    }
210
211    fn point_add(&self, p1: &EllipticPoint, p2: &EllipticPoint) -> Result<EllipticPoint> {
212        if p1.is_default() {
213            return Ok(p2.clone());
214        };
215        if p2.is_default() {
216            return Ok(p1.clone());
217        };
218        if !self.curve.check_on(p1) || !self.curve.check_on(p2) {
219            return Err(EcdhError::PointNotOnCurve);
220        }
221
222        let (x1, x2, y1, y2) = (&p1.x, &p2.x, &p1.y, &p2.y);
223
224        let m = if x1 == x2 {
225            if y1 == y2 {
226                (3 * x1 * x1 + &self.curve.a) * mod_inverse(&(y1 << 1), &self.curve.p)?
227            } else {
228                return Ok(EllipticPoint::default());
229            }
230        } else {
231            (y1 - y2) * mod_inverse(&(x1 - x2), &self.curve.p)?
232        };
233        let xr = mod_positive(&(&m * &m - x1 - x2), &self.curve.p);
234        let yr = mod_positive(&(&m * (x1 - &xr) - y1), &self.curve.p);
235        let pr = EllipticPoint::new(xr, yr);
236        if !self.curve.check_on(&pr) {
237            return Err(EcdhError::PointNotOnCurve);
238        }
239        Ok(pr)
240    }
241}
242
243impl KeyExchange for Ecdh {
244    type Error = EcdhError;
245
246    fn key_exchange<T: AsRef<[u8]>>(&self, ec_pub: T, hash: bool) -> Result<Vec<u8>> {
247        let shared = self.create_shared(&self.secret, &self.unpack_public(ec_pub)?)?;
248        Ok(self.pack_shared(&shared, hash))
249    }
250}
251
252fn mod_inverse(a: &BigInt, b: &BigInt) -> Result<BigInt> {
253    if a.is_negative() {
254        return Ok(b - mod_inverse(&-a, b)?);
255    }
256    if a.gcd(b) != BigInt::one() {
257        return Err(EcdhError::InverseDoesNotExist);
258    }
259    Ok(a.modpow(&(b - 2), b))
260}
261
262fn mod_positive(a: &BigInt, b: &BigInt) -> BigInt {
263    let result = a % b;
264    if result.is_negative() {
265        result + b
266    } else {
267        result
268    }
269}
270
271#[derive(Debug, Clone)]
272pub struct EllipticCurve {
273    pub p: BigInt,
274    pub a: BigInt,
275    pub b: BigInt,
276    pub g: EllipticPoint,
277    pub n: BigInt,
278    // h: BigInt,
279    pub size: usize,
280    pub pack_size: usize,
281}
282
283impl EllipticCurve {
284    pub fn check_on(&self, point: &EllipticPoint) -> bool {
285        // ((&point.y.pow(2) - &point.x.pow(3) - &self.a * &point.x - &self.b) % &self.p) == BigInt::zero()
286        let lhs = point.y.modpow(&BigInt::from(2), &self.p); // y² mod p
287        let rhs = (point.x.modpow(&BigInt::from(3), &self.p)  // x³ mod p
288            + (&self.a * &point.x) % &self.p
289            + &self.b)
290            % &self.p;
291        lhs == rhs
292    }
293}
294
295#[derive(Debug, Clone)]
296pub struct EllipticPoint {
297    x: BigInt,
298    y: BigInt,
299}
300
301impl EllipticPoint {
302    pub fn new(x: BigInt, y: BigInt) -> EllipticPoint {
303        Self { x, y }
304    }
305
306    pub fn is_default(&self) -> bool {
307        self.x.is_zero() && self.y.is_zero()
308    }
309}
310
311impl Default for EllipticPoint {
312    fn default() -> EllipticPoint {
313        EllipticPoint {
314            x: BigInt::ZERO,
315            y: BigInt::ZERO,
316        }
317    }
318}
319
320impl Neg for EllipticPoint {
321    type Output = EllipticPoint;
322    fn neg(self) -> EllipticPoint {
323        Self {
324            x: -self.x,
325            y: -self.y,
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-negative `BigInt` from base-2^32 digits, least significant digit first.
    fn big(digits: &[u32]) -> BigInt {
        BigInt::from_slice(Sign::Plus, digits)
    }

    #[test]
    fn test_create_shared() -> Result<()> {
        let ec_sec = big(&[
            2792767394, 3497172710, 1652332542, 1637215680, 2069543466, 3042786051, 1983615641,
            1413039311,
        ]);
        let ec_pub = EllipticPoint::new(
            big(&[
                3633889942, 4104206661, 770388896, 1996717441, 1671708914, 4173129445,
                3777774151, 1796723186,
            ]),
            big(&[
                935285237, 3417718888, 1798397646, 734933847, 2081398294, 2397563722,
                4263149467, 1340293858,
            ]),
        );
        let ecdh = Ecdh {
            curve: curve::PRIME256V1.clone(),
            secret: ec_sec.clone(),
            public: EllipticPoint::default(),
        };

        let shared = ecdh.create_shared(&ec_sec, &ec_pub)?;

        assert_eq!(
            shared.x.to_string(),
            "108127657902349890608713381133328869798433126817450190281707933471728911716600"
        );
        assert_eq!(
            shared.y.to_string(),
            "90081983369721618215244899859661517112066743317494361926578071297084718897517"
        );
        Ok(())
    }

    #[test]
    fn test_pack_public() -> Result<()> {
        let public = EllipticPoint::new(
            big(&[
                192391996, 42411882, 3555170359, 1596249776, 1928736291, 3341608003, 1479611319,
                155190516,
            ]),
            big(&[
                3113797789, 2959440885, 2287647648, 3290191409, 693816809, 3512895204,
                3156405519, 100630943,
            ]),
        );
        let ecdh = Ecdh {
            curve: curve::PRIME256V1.clone(),
            secret: big(&[
                3008988791, 2223512191, 4101290493, 1349918355, 981201553, 1906136310,
                3328719820, 837222639,
            ]),
            public,
        };
        let expected: [u8; 65] = [
            4, 9, 64, 4, 244, 88, 49, 19, 183, 199, 44, 228, 67, 114, 246, 46, 35, 95, 36, 214,
            176, 211, 231, 152, 55, 2, 135, 39, 106, 11, 119, 171, 60, 5, 255, 129, 159, 188,
            34, 237, 15, 209, 98, 134, 228, 41, 90, 205, 233, 196, 28, 86, 49, 136, 90, 187,
            160, 176, 101, 123, 245, 185, 152, 200, 157,
        ];

        let packed = ecdh.pack_public(false)?;

        assert_eq!(packed.as_slice(), &expected[..]);
        Ok(())
    }

    #[test]
    fn test_ecdh_key_exchange() -> Result<()> {
        let p256 = curve::PRIME256V1.clone();
        let alice = Ecdh::new(p256.clone())?;
        let alice_public_key = alice.pack_public(false)?;
        let bob = Ecdh::new(p256)?;
        let bob_public_key = bob.pack_public(false)?;

        let shared_by_alice = alice.key_exchange(bob_public_key, false)?;
        let shared_by_bob = bob.key_exchange(alice_public_key, false)?;

        assert_eq!(shared_by_alice, shared_by_bob);
        Ok(())
    }
}