// sra_wasm/lib.rs

1use num_bigint::{BigInt, BigUint, ToBigInt};
2use num_integer::Integer;
3use num_prime::RandPrime;
4use num_traits::{One, Zero};
5use rand::Rng;
6use std::str::FromStr;
7use wasm_bindgen::prelude::*;
8
9#[wasm_bindgen]
10extern "C" {
11    #[wasm_bindgen(js_namespace = console)]
12    fn log(s: &str);
13}
14
15#[wasm_bindgen]
16pub fn js_generate_phi_n(bit_size: usize) -> JsValue {
17    let (phi, n) = generate_phi_n(bit_size);
18    let obj = js_sys::Object::new();
19    js_sys::Reflect::set(&obj, &"phi".into(), &JsValue::from_str(&phi.to_string())).unwrap();
20    js_sys::Reflect::set(&obj, &"n".into(), &JsValue::from_str(&n.to_string())).unwrap();
21    obj.into()
22}
23
24#[wasm_bindgen]
25pub fn js_generate_key_pair(js_phi: &str) -> JsValue {
26    let phi = BigInt::from_str(js_phi).unwrap();
27    let (e, d) = generate_key_pair(&phi);
28    let obj = js_sys::Object::new();
29    js_sys::Reflect::set(&obj, &"e".into(), &JsValue::from_str(&e.to_string())).unwrap();
30    js_sys::Reflect::set(&obj, &"d".into(), &JsValue::from_str(&d.to_string())).unwrap();
31    obj.into()
32}
33
34#[wasm_bindgen]
35pub fn js_encrypt(js_message: &str, js_e: &str, js_n: &str) -> JsValue {
36    // parse string into u32 instead of BigInt
37    let message: u32 = js_message.parse().unwrap();
38    let e = BigInt::from_str(js_e).unwrap();
39    let n = BigInt::from_str(js_n).unwrap();
40
41    let cipher = encrypt(&BigInt::from(message), &e, &n);
42    JsValue::from_str(&cipher.to_string())
43}
44
45#[wasm_bindgen]
46pub fn js_decrypt(js_cipher: &str, js_d: &str, js_n: &str) -> JsValue {
47    // parse string into u32 instead of BigInt
48    let cipher: BigInt = BigInt::from_str(js_cipher).unwrap();
49    let d = BigInt::from_str(js_d).unwrap();
50    let n = BigInt::from_str(js_n).unwrap();
51
52    let decrypted = decrypt(&cipher, &d, &n);
53    JsValue::from_str(&decrypted.to_string())
54}
55
56fn exp_by_squaring(base: &BigInt, exp: &BigInt, modulus: &BigInt) -> BigInt {
57    if *exp == Zero::zero() {
58        One::one()
59    } else if exp.is_even() {
60        let half = exp.clone() >> 1; // Divide by 2
61        let half_exp = exp_by_squaring(base, &half, modulus);
62        return (&half_exp * &half_exp) % modulus;
63    } else {
64        let half = (exp.clone() - BigInt::one()) >> 1; // (exp - 1) / 2
65        let half_exp = exp_by_squaring(base, &half, modulus);
66        return (base * &half_exp * &half_exp) % modulus;
67    }
68}
69
70fn encrypt(message: &BigInt, e: &BigInt, n: &BigInt) -> BigInt {
71    let cipher: BigInt = exp_by_squaring(message, e, n);
72    log(&format!(
73        "cipher: {} cipher bits: {}",
74        &cipher,
75        &cipher.bits()
76    ));
77    cipher
78}
79
80fn decrypt(cipher: &BigInt, d: &BigInt, n: &BigInt) -> BigInt {
81    let message: BigInt = exp_by_squaring(cipher, d, n);
82    log(&format!(
83        "message: {} message bits: {}",
84        &message,
85        &message.bits()
86    ));
87    message
88}
89
90fn get_fixed_sized_prime(bit_size: usize) -> BigInt {
91    let mut rng = rand::thread_rng();
92
93    let mut prime: BigUint;
94    loop {
95        prime = rng.gen_prime(bit_size, None);
96        if prime.bits() == bit_size as u64 {
97            break;
98        }
99    }
100    prime.to_bigint().unwrap()
101}
102
103// Generate a shared phi and N, while keeping p and q secret.
104fn generate_phi_n(bit_size: usize) -> (BigInt, BigInt) {
105    let p = get_fixed_sized_prime(bit_size / 2);
106    let q = get_fixed_sized_prime(bit_size / 2);
107    let phi = (p.clone() - BigInt::one()) * (q.clone() - BigInt::one());
108    let n = p.clone() * q.clone();
109    if n.bits() != bit_size as u64 {
110        return generate_phi_n(bit_size);
111    }
112    (phi, n)
113}
114
115fn extended_gcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
116    if *a == BigInt::zero() {
117        (b.clone(), BigInt::zero(), BigInt::one())
118    } else {
119        let (g, x, y) = extended_gcd(&(b % a), a);
120        (g, y.clone() - (b / a) * x.clone(), x)
121    }
122}
123
124fn mod_inverse(a: &BigInt, m: &BigInt) -> Option<BigInt> {
125    let (g, x, _) = extended_gcd(a, m);
126    if g != BigInt::one() {
127        None
128    } else {
129        Some((x % m + m) % m)
130    }
131}
132
133fn generate_key_pair(phi: &BigInt) -> (BigInt, BigInt) {
134    let mut rng = rand::thread_rng();
135
136    loop {
137        // Generate a random `e`
138        let e =
139            rng.gen_range(BigInt::one() << (phi.bits() / 2)..=BigInt::one() << (phi.bits() - 1));
140
141        if phi.gcd(&e) == BigInt::one() {
142            // Try to compute the modular inverse of `e` modulo `phi`
143            if let Some(d) = mod_inverse(&e, phi) {
144                return (e, d);
145            }
146            // If mod_inverse returns None, continue the loop to try again
147        }
148        // If the gcd is not 1, the loop will continue to try again
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Visual separator between logged sections.
    const SEPARATOR: &str = "===================";

    /// Reads property `key` of a JS object as a decimal string and parses it as `u32`.
    fn js_field_u32(obj: &JsValue, key: &str) -> u32 {
        js_sys::Reflect::get(obj, &key.into())
            .unwrap()
            .as_string()
            .unwrap()
            .parse::<u32>()
            .unwrap()
    }

    #[wasm_bindgen_test]
    fn test_get_fixed_prime() {
        log("\n\n");
        // Only ~3000 primes have exactly 16 bits, so requiring *consecutive*
        // draws to differ fails by chance. Instead, require that the generator
        // produced more than one distinct prime overall.
        let mut seen: HashSet<BigInt> = HashSet::new();
        for i in 0..52 {
            let prime = get_fixed_sized_prime(32 / 2);
            assert_eq!(prime.bits(), 16);
            if i % 8 == 0 {
                log(SEPARATOR);
                log(&format!("prime = {}", prime));
                log(&format!("bits = {}", prime.bits()));
            }
            seen.insert(prime);
        }
        assert!(
            seen.len() > 1,
            "get_fixed_sized_prime returned the same prime every time"
        );
        log(SEPARATOR);
    }

    #[wasm_bindgen_test]
    fn test_generate_phi_n() {
        log("\n\n");
        let mut old_phi_n: (BigInt, BigInt) = (BigInt::zero(), BigInt::zero());
        for i in 0..64 {
            let (phi, n) = generate_phi_n(32);
            assert_eq!(n.bits(), 32);
            assert!(old_phi_n != (phi.clone(), n.clone()));
            if i % 8 == 0 {
                log(SEPARATOR);
                log(&format!("phi = {} n = {}", phi, n));
                log(&format!("phi bits = {}", phi.bits()));
                log(&format!("n bits = {}", n.bits()));
            }
            old_phi_n = (phi, n);
        }
        log(SEPARATOR);
    }

    #[wasm_bindgen_test]
    fn test_generate_key_pair() {
        log("\n\n");
        let (phi, n) = generate_phi_n(32);
        let (e, d) = generate_key_pair(&phi);
        let midpoint = 52u8 / 2;

        // Two half-decks of 26 key pairs each, rendered as Rust `u32` literals.
        let mut deck_e: [[String; 26]; 2] = Default::default();
        let mut deck_d: [[String; 26]; 2] = Default::default();

        for i in 0u8..52u8 {
            let value: JsValue = js_generate_key_pair(&phi.to_string());
            let e1 = js_field_u32(&value, "e");
            let d1 = js_field_u32(&value, "d");
            // Pairs produced through the JS wrapper must be inverses mod phi as well.
            assert_eq!(
                (BigInt::from(e1) * BigInt::from(d1)) % &phi,
                BigInt::one()
            );

            let (half, slot) = if i < midpoint {
                (0, usize::from(i))
            } else {
                (1, usize::from(i - midpoint))
            };
            deck_e[half][slot] = format!("{}u32", e1);
            deck_d[half][slot] = format!("{}u32", d1);
        }

        log(SEPARATOR);
        log(&format!("deck_e {:?}", deck_e));
        log(&format!("deck_d {:?}", deck_d));
        log(SEPARATOR);
        log("\n\n");
        // e and d are the encryption and decryption key pair.
        // e is the public key, d is the private key.
        log(SEPARATOR);
        log(&format!("phi = {} n = {}", phi, n));
        log(&format!("e = {} d = {}", e, d));
        log(&format!("phi bits = {}", phi.bits()));
        log(&format!("n bits = {}", n.bits()));
        log(&format!("e bits = {}", e.bits()));
        log(&format!("d bits = {}", d.bits()));
        assert!(BigInt::one() < e);
        assert!(e < phi);
        assert!(BigInt::one() < d);
        assert!(d < phi);
        assert!(d != BigInt::zero() && e != BigInt::zero());
        assert_eq!((&e * &d) % &phi, BigInt::one());
        log(SEPARATOR);
    }

    #[wasm_bindgen_test]
    fn test_generate_large_key_pair() {
        log("\n\n");
        let mut old_phi_n: (BigInt, BigInt) = (BigInt::zero(), BigInt::zero());
        let bit_size_samples: &[usize] = &[256, 512, 1024, 2048];
        for &bit_size in bit_size_samples {
            for _ in 0..2 {
                let (phi, n) = generate_phi_n(bit_size);
                assert_eq!(n.bits(), bit_size as u64);
                assert!(old_phi_n != (phi.clone(), n.clone()));
                let (e, d) = generate_key_pair(&phi);
                assert!(BigInt::one() < e);
                assert!(e < phi);
                assert!(BigInt::one() < d);
                assert!(d < phi);
                assert!(d != BigInt::zero() && e != BigInt::zero());
                assert_eq!((&e * &d) % &phi, BigInt::one());
                log(SEPARATOR);
                log(&format!("phi = {} n = {}", phi, n));
                log(&format!("phi bits = {}", phi.bits()));
                log(&format!("n bits = {}", n.bits()));
                // Log the actual private exponent (previously `e` was logged in its place).
                log(&format!("e = {} d = {}", e, d));
                log(&format!("e bits = {}", e.bits()));
                log(&format!("d bits = {}", d.bits()));
                old_phi_n = (phi, n);
            }
        }
    }

    #[wasm_bindgen_test]
    fn test_sra() {
        log("\n\n");
        // Shared p, q, n
        let (phi, n) = generate_phi_n(32);
        // Alice key pair (e1, d1)
        let (e1, d1) = generate_key_pair(&phi);
        // Bob key pair (e2, d2)
        let (e2, d2) = generate_key_pair(&phi);
        assert!(e1 < n);
        assert!(e2 < n);
        assert!(e1 != e2);
        // The card
        let message = BigInt::from(63u8);
        log(SEPARATOR);
        log(&format!("phi = {} n = {}", phi, n));
        log(SEPARATOR);
        log(&format!("e1 = {} d1 = {}", e1, d1));
        log(&format!("e2 = {} d2 = {}", e2, d2));
        log(&format!("n = {}", n));
        log(&format!("Message = {}", message));
        log(SEPARATOR);
        let alice_cipher = encrypt(&message, &e1, &n);
        log(&format!(
            "  Cipher result (after Alice encrypt): {}",
            alice_cipher
        ));

        let bob_cipher = encrypt(&alice_cipher, &e2, &n);
        log(&format!("  Cipher result (after Bob encrypt): {}", bob_cipher));

        // Decrypt in reverse order of encryption: Bob first, then Alice.
        let bob_first = decrypt(&bob_cipher, &d2, &n);
        log(&format!("  Cipher result (After Bob decrypt): {}", bob_first));
        let bob_then_alice = decrypt(&bob_first, &d1, &n);
        log(&format!(
            "  Cipher result (After Alice decrypt): {}",
            bob_then_alice
        ));
        log(&format!("A -> B -> B -> A: {}", bob_then_alice));
        assert_eq!(bob_then_alice, message);

        // Commutativity: decrypting in the same order as encryption must also work.
        let alice_first = decrypt(&bob_cipher, &d1, &n);
        log(&format!(
            "  Cipher result (After Alice decrypt): {}",
            alice_first
        ));
        let alice_then_bob = decrypt(&alice_first, &d2, &n);
        log(&format!(
            "  Cipher result (After Bob decrypt): {}",
            alice_then_bob
        ));
        log(&format!("A -> B -> A -> B: {}", alice_then_bob));
        assert_eq!(alice_then_bob, message);
    }

    #[wasm_bindgen_test]
    fn test_sra_mock() {
        log("\n\n");
        // n = 91 = 7 * 13, phi = 72; 5 * 29 ≡ 1 and 7 * 31 ≡ 1 (mod 72).
        let e1: BigInt = BigInt::from(5u8);
        let d1: BigInt = BigInt::from(29u8);
        let e2: BigInt = BigInt::from(7u8);
        let d2: BigInt = BigInt::from(31u8);
        let n: BigInt = BigInt::from(91u8);
        let message: BigInt = BigInt::from(5u8);

        log(SEPARATOR);
        log(&format!("d1 = {}, d2 = {}", d1, d2));
        log(&format!("e1 = {}, e2 = {}", e1, e2));
        log(&format!("n = {}", n));
        log(&format!("Message = {}", message));
        log(SEPARATOR);

        let alice_cipher = encrypt(&message, &e1, &n);
        log(&format!(
            "Cipher result (after Alice encrypt): {}",
            alice_cipher
        ));
        assert_eq!(BigInt::from(31u8), alice_cipher);

        let bob_cipher = encrypt(&alice_cipher, &e2, &n);
        log(&format!("Cipher result (after Bob encrypt): {}", bob_cipher));
        assert_eq!(BigInt::from(73u8), bob_cipher);

        log(SEPARATOR);

        let bob_first = decrypt(&bob_cipher, &d2, &n);
        log(&format!("Cipher result (After Bob decrypt): {}", bob_first));
        assert_eq!(BigInt::from(31u8), bob_first);
        let bob_then_alice = decrypt(&bob_first, &d1, &n);
        log(&format!(
            "Cipher result (After Alice decrypt): {}",
            bob_then_alice
        ));
        log(&format!("A -> B -> B -> A: {}", bob_then_alice));
        assert_eq!(BigInt::from(5u8), bob_then_alice);

        let alice_first = decrypt(&bob_cipher, &d1, &n);
        log(&format!("Cipher result (After Alice decrypt): {}", alice_first));
        assert_eq!(BigInt::from(47u8), alice_first);
        let alice_then_bob = decrypt(&alice_first, &d2, &n);
        log(&format!(
            "Cipher result (After Bob decrypt): {}",
            alice_then_bob
        ));
        log(&format!("A -> B -> A -> B: {}", alice_then_bob));
        assert_eq!(BigInt::from(5u8), alice_then_bob);
    }
}