1use num_bigint::{BigInt, BigUint, ToBigInt};
2use num_integer::Integer;
3use num_prime::RandPrime;
4use num_traits::{One, Zero};
5use rand::Rng;
6use std::str::FromStr;
7use wasm_bindgen::prelude::*;
8
// Bindings to JavaScript host functions imported through wasm-bindgen.
#[wasm_bindgen]
extern "C" {
    /// Writes `s` to the JS `console.log`; used for diagnostic output from Rust code and tests.
    #[wasm_bindgen(js_namespace = console)]
    fn log(s: &str);
}
14
15#[wasm_bindgen]
16pub fn js_generate_phi_n(bit_size: usize) -> JsValue {
17 let (phi, n) = generate_phi_n(bit_size);
18 let obj = js_sys::Object::new();
19 js_sys::Reflect::set(&obj, &"phi".into(), &JsValue::from_str(&phi.to_string())).unwrap();
20 js_sys::Reflect::set(&obj, &"n".into(), &JsValue::from_str(&n.to_string())).unwrap();
21 obj.into()
22}
23
24#[wasm_bindgen]
25pub fn js_generate_key_pair(js_phi: &str) -> JsValue {
26 let phi = BigInt::from_str(js_phi).unwrap();
27 let (e, d) = generate_key_pair(&phi);
28 let obj = js_sys::Object::new();
29 js_sys::Reflect::set(&obj, &"e".into(), &JsValue::from_str(&e.to_string())).unwrap();
30 js_sys::Reflect::set(&obj, &"d".into(), &JsValue::from_str(&d.to_string())).unwrap();
31 obj.into()
32}
33
34#[wasm_bindgen]
35pub fn js_encrypt(js_message: &str, js_e: &str, js_n: &str) -> JsValue {
36 let message: u32 = js_message.parse().unwrap();
38 let e = BigInt::from_str(js_e).unwrap();
39 let n = BigInt::from_str(js_n).unwrap();
40
41 let cipher = encrypt(&BigInt::from(message), &e, &n);
42 JsValue::from_str(&cipher.to_string())
43}
44
45#[wasm_bindgen]
46pub fn js_decrypt(js_cipher: &str, js_d: &str, js_n: &str) -> JsValue {
47 let cipher: BigInt = BigInt::from_str(js_cipher).unwrap();
49 let d = BigInt::from_str(js_d).unwrap();
50 let n = BigInt::from_str(js_n).unwrap();
51
52 let decrypted = decrypt(&cipher, &d, &n);
53 JsValue::from_str(&decrypted.to_string())
54}
55
56fn exp_by_squaring(base: &BigInt, exp: &BigInt, modulus: &BigInt) -> BigInt {
57 if *exp == Zero::zero() {
58 One::one()
59 } else if exp.is_even() {
60 let half = exp.clone() >> 1; let half_exp = exp_by_squaring(base, &half, modulus);
62 return (&half_exp * &half_exp) % modulus;
63 } else {
64 let half = (exp.clone() - BigInt::one()) >> 1; let half_exp = exp_by_squaring(base, &half, modulus);
66 return (base * &half_exp * &half_exp) % modulus;
67 }
68}
69
70fn encrypt(message: &BigInt, e: &BigInt, n: &BigInt) -> BigInt {
71 let cipher: BigInt = exp_by_squaring(message, e, n);
72 log(&format!(
73 "cipher: {} cipher bits: {}",
74 &cipher,
75 &cipher.bits()
76 ));
77 cipher
78}
79
80fn decrypt(cipher: &BigInt, d: &BigInt, n: &BigInt) -> BigInt {
81 let message: BigInt = exp_by_squaring(cipher, d, n);
82 log(&format!(
83 "message: {} message bits: {}",
84 &message,
85 &message.bits()
86 ));
87 message
88}
89
90fn get_fixed_sized_prime(bit_size: usize) -> BigInt {
91 let mut rng = rand::thread_rng();
92
93 let mut prime: BigUint;
94 loop {
95 prime = rng.gen_prime(bit_size, None);
96 if prime.bits() == bit_size as u64 {
97 break;
98 }
99 }
100 prime.to_bigint().unwrap()
101}
102
103fn generate_phi_n(bit_size: usize) -> (BigInt, BigInt) {
105 let p = get_fixed_sized_prime(bit_size / 2);
106 let q = get_fixed_sized_prime(bit_size / 2);
107 let phi = (p.clone() - BigInt::one()) * (q.clone() - BigInt::one());
108 let n = p.clone() * q.clone();
109 if n.bits() != bit_size as u64 {
110 return generate_phi_n(bit_size);
111 }
112 (phi, n)
113}
114
115fn extended_gcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
116 if *a == BigInt::zero() {
117 (b.clone(), BigInt::zero(), BigInt::one())
118 } else {
119 let (g, x, y) = extended_gcd(&(b % a), a);
120 (g, y.clone() - (b / a) * x.clone(), x)
121 }
122}
123
124fn mod_inverse(a: &BigInt, m: &BigInt) -> Option<BigInt> {
125 let (g, x, _) = extended_gcd(a, m);
126 if g != BigInt::one() {
127 None
128 } else {
129 Some((x % m + m) % m)
130 }
131}
132
133fn generate_key_pair(phi: &BigInt) -> (BigInt, BigInt) {
134 let mut rng = rand::thread_rng();
135
136 loop {
137 let e =
139 rng.gen_range(BigInt::one() << (phi.bits() / 2)..=BigInt::one() << (phi.bits() - 1));
140
141 if phi.gcd(&e) == BigInt::one() {
142 if let Some(d) = mod_inverse(&e, phi) {
144 return (e, d);
145 }
146 }
148 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Draws 16-bit primes repeatedly and checks their size and that consecutive draws differ.
    #[wasm_bindgen_test]
    fn test_get_fixed_prime() {
        log("\n\n");
        let mut old_prime: BigInt = BigInt::zero();
        for i in 0..52 {
            let prime = get_fixed_sized_prime(32 / 2);
            // NOTE(review): there are finitely many 16-bit primes, so two consecutive draws
            // can legitimately coincide; this assertion is occasionally flaky.
            assert!(old_prime != prime);
            assert_eq!(prime.bits(), 16);
            if (i % 8) == 0 {
                log("===================");
                log(&format!("prime = {}", prime));
                log(&format!("bits = {}", prime.bits()));
            }
            old_prime = prime;
        }
        log("===================");
    }

    /// Checks that generated 32-bit moduli have the requested size and vary between calls.
    #[wasm_bindgen_test]
    fn test_generate_phi_n() {
        log("\n\n");
        let mut old_phi_n: (BigInt, BigInt) = (BigInt::zero(), BigInt::zero());
        for i in 0..64 {
            let (phi, n) = generate_phi_n(32);
            assert!(n.bits() == 32);
            assert!(old_phi_n != (phi.clone(), n.clone()));
            if (i % 8) == 0 {
                log("===================");
                log(&format!("phi = {} n = {}", phi, n));
                log(&format!("phi bits = {}", phi.bits()));
                log(&format!("n bits = {}", n.bits()));
            }
            old_phi_n = (phi, n);
        }
        log("===================");
    }

    /// Builds decks of key pairs through the JS entry point and validates one key pair.
    #[wasm_bindgen_test]
    fn test_generate_key_pair() {
        log("\n\n");
        let (phi, n) = generate_phi_n(32);
        let (e, d) = generate_key_pair(&phi);
        let midpoint = 52u8 / 2;

        let mut deck_e: [[String; 26]; 2] = Default::default();
        let mut deck_d: [[String; 26]; 2] = Default::default();

        for i in 0u8..52u8 {
            let value: JsValue = js_generate_key_pair(&phi.to_string());
            let e1 = js_sys::Reflect::get(&value, &"e".into())
                .unwrap()
                .as_string()
                .unwrap()
                .parse::<u32>()
                .unwrap();
            let d1 = js_sys::Reflect::get(&value, &"d".into())
                .unwrap()
                .as_string()
                .unwrap()
                .parse::<u32>()
                .unwrap();

            if i < midpoint {
                deck_e[0][i as usize] = format!("{}u32", e1);
                deck_d[0][i as usize] = format!("{}u32", d1);
            } else {
                deck_e[1][i as usize - midpoint as usize] = format!("{}u32", e1);
                deck_d[1][i as usize - midpoint as usize] = format!("{}u32", d1);
            }
        }

        log("===================");
        log(&format!("deck_e {:?}", deck_e));
        log(&format!("deck_d {:?}", deck_d));
        log("===================");
        log("\n\n");
        log("===================");
        log(&format!("phi = {} n = {}", phi, n));
        log(&format!("e = {} d = {}", e, d));
        log(&format!("phi bits = {}", phi.bits()));
        log(&format!("n bits = {}", n.bits()));
        log(&format!("e bits = {}", e.bits()));
        log(&format!("d bits = {}", d.bits()));
        assert!(BigInt::one() < e);
        assert!(e < phi);
        assert!(BigInt::one() < d);
        assert!(d < phi);
        assert!(d != BigInt::zero() && e != BigInt::zero());
        assert_eq!((e * d) % phi, BigInt::one());
        log("===================");
    }

    /// Validates moduli and key pairs at realistic RSA sizes.
    #[wasm_bindgen_test]
    fn test_generate_large_key_pair() {
        log("\n\n");
        let mut old_phi_n: (BigInt, BigInt) = (BigInt::zero(), BigInt::zero());
        let bit_size_samples: &[usize] = &[256, 512, 1024, 2048];
        for bit_size in bit_size_samples.iter() {
            for _ in 0..2 {
                let (phi, n) = generate_phi_n(*bit_size);
                assert!(n.bits() == *bit_size as u64);
                assert!(old_phi_n != (phi.clone(), n.clone()));
                let (e, d) = generate_key_pair(&phi);
                assert!(BigInt::one() < e);
                assert!(e < phi);
                assert!(BigInt::one() < d);
                assert!(d < phi);
                assert!(d != BigInt::zero() && e != BigInt::zero());
                // Borrow `e` and `d` so both remain available for the log lines below.
                assert_eq!((&e * &d) % &phi, BigInt::one());
                log("===================");
                log(&format!("phi = {} n = {}", phi, n));
                log(&format!("phi bits = {}", phi.bits()));
                log(&format!("n bits = {}", n.bits()));
                log(&format!("e = {} d = {}", e, d));
                log(&format!("e bits = {}", e.bits()));
                log(&format!("d bits = {}", d.bits()));
                old_phi_n = (phi, n);
            }
        }
    }

    /// SRA commutativity: two parties sharing `n` can decrypt in either order.
    #[wasm_bindgen_test]
    fn test_sra() {
        log("\n\n");
        let (phi, n) = generate_phi_n(32);
        let (e1, d1) = generate_key_pair(&phi);
        let (e2, d2) = generate_key_pair(&phi);
        assert!(e1 < n);
        assert!(e2 < n);
        assert!(e1 != e2);
        let message = BigInt::from(63u8);
        log("===================");
        log(&format!("phi = {} n = {}", phi, n));
        log("===================");
        log(&format!("e1 = {} d1 = {}", e1, d1));
        log(&format!("e2 = {} d2 = {}", e2, d2));
        log(&format!("n = {}", n));
        log(&format!("Message = {}", message));
        log("===================");
        let alice_cipher = encrypt(&message, &e1, &n);
        log(&format!(
            " Cipher result (after Alice encrypt): {}",
            alice_cipher
        ));

        let bob_cipher = encrypt(&alice_cipher, &e2, &n);
        log(&format!(
            " Cipher result (after Bob encrypt): {}",
            bob_cipher
        ));

        let decipher_1 = decrypt(&bob_cipher, &d2, &n);
        log(&format!(
            " Cipher result (After Bob decrypt): {}",
            decipher_1
        ));

        let decipher_2 = decrypt(&decipher_1, &d1, &n);
        log(&format!(
            " Cipher result (After Alice decrypt): {}",
            decipher_2
        ));

        log(&format!("A -> B -> B -> A: {}", decipher_2));
        // The first decryption order must also recover the message.
        assert_eq!(decipher_2, message);

        let decipher_1 = decrypt(&bob_cipher, &d1, &n);
        log(&format!(
            " Cipher result (After Alice decrypt): {}",
            decipher_1
        ));

        let decipher_2 = decrypt(&decipher_1, &d2, &n);
        log(&format!(
            " Cipher result (After Bob decrypt): {}",
            decipher_2
        ));

        log(&format!("A -> B -> A -> B: {}", decipher_2));

        assert_eq!(decipher_2, message);
    }

    /// SRA commutativity with fixed keys over n = 91 (= 7 * 13) and known intermediates.
    #[wasm_bindgen_test]
    fn test_sra_mock() {
        log("\n\n");
        let e1: BigInt = BigInt::from(5u8);
        let d1: BigInt = BigInt::from(29u8);
        let e2: BigInt = BigInt::from(7u8);
        let d2: BigInt = BigInt::from(31u8);
        let n: BigInt = BigInt::from(91u8);
        let message: BigInt = BigInt::from(5u8);

        log("===================");
        log(&format!("d1 = {}, d2 = {}", d1, d2));
        log(&format!("e1 = {}, e2 = {}", e1, e2));
        log(&format!("n = {}", n));
        log(&format!("Message = {}", message));
        log("===================");

        let alice_cipher = encrypt(&message, &e1, &n);
        log(&format!(
            "Cipher result (after Alice encrypt): {}",
            alice_cipher
        ));
        assert_eq!(BigInt::from(31u8), alice_cipher);

        let bob_cipher = encrypt(&alice_cipher, &e2, &n);
        log(&format!(
            "Cipher result (after Bob encrypt): {}",
            bob_cipher
        ));
        assert_eq!(BigInt::from(73u8), bob_cipher);

        log("===================");

        let decipher_1 = decrypt(&bob_cipher, &d2, &n);
        log(&format!(
            "Cipher result (After Bob decrypt): {}",
            decipher_1
        ));
        assert_eq!(BigInt::from(31u8), decipher_1);
        let decipher_2 = decrypt(&decipher_1, &d1, &n);
        log(&format!(
            "Cipher result (After Alice decrypt): {}",
            decipher_2
        ));
        log(&format!("A -> B -> B -> A: {}", decipher_2));
        assert_eq!(BigInt::from(5u8), decipher_2);

        let decipher_1 = decrypt(&bob_cipher, &d1, &n);
        log(&format!(
            "Cipher result (After Alice decrypt): {}",
            decipher_1
        ));
        assert_eq!(BigInt::from(47u8), decipher_1);
        let decipher_2 = decrypt(&decipher_1, &d2, &n);
        log(&format!(
            "Cipher result (After Bob decrypt): {}",
            decipher_2
        ));
        log(&format!("A -> B -> A -> B: {}", decipher_2));
        assert_eq!(BigInt::from(5u8), decipher_2);
    }
}