1use core::mem::size_of;
9#[doc = include_str!("../README.md")]
10use ff::{Field, PrimeField};
11use group::{Group, GroupEncoding};
12use num_bigint_dig::ModInverse;
13use rand::{Rng, SeedableRng};
14use rand_chacha::{rand_core::CryptoRngCore, ChaCha20Rng};
15use rsa::{
16 traits::PublicKeyParts, BigUint, Pkcs1v15Encrypt, RsaPrivateKey,
17 RsaPublicKey,
18};
19use sha2::{Digest, Sha256};
20use std::ops::Index;
21use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
22use thiserror::Error;
23
/// Default number of cut-and-choose rounds (each consumes one challenge bit).
/// It is also the minimum accepted when creating or deserializing a proof.
pub const SECURITY_PARAM: usize = 128;
25
26pub use rsa;
28
29#[derive(Debug, Error)]
31pub enum RsaError {
32 #[error("Error during encryption")]
33 EncError,
34 #[error("Error during decryption")]
35 DecError,
36 #[error("Invalid label inverse")]
37 InvalidLabel,
38 #[error("Verification failed")]
39 VerificationFailed,
40 #[error("Invalid SIZE parameter, must be equal to the size of the scalar in bytes")]
41 InvalidSizeParam,
42 #[error("(de)Serialization error")]
43 SerdeError(String),
44 #[error("Invalid Security Parameter, cannot be more than 256")]
45 InvalidSecurityParam,
46}
47
/// Commitment for one cut-and-choose round: the encoding of `g^r` plus the
/// label-bound RSA encryptions of `x + r` and of `r`.
pub struct ProofData<G: Group + GroupEncoding> {
    /// Encoding of `G::generator() * r`.
    g_r: G::Repr,
    /// RSA ciphertext of `x + r`.
    enc_x_r: Vec<u8>,
    /// RSA ciphertext of `r`.
    enc_r: Vec<u8>,
}
53
/// Label-bound RSA encryption of a scalar `x` together with a non-interactive
/// cut-and-choose proof that the ciphertexts hide the discrete logarithm of a
/// public point `G::generator() * x`.
pub struct VerifiableRsaEncryption<G>
where
    G: Group + GroupEncoding + ConstantTimeEq,
    G::Scalar: ConditionallySelectable,
{
    /// Seed for the ChaCha20 RNG that supplies the PKCS#1 v1.5 padding.
    /// Every encryption can therefore be recomputed by a verifier.
    pub seed: [u8; 32],
    /// One commitment per round.
    pub proofs: Vec<ProofData<G>>,
    /// Per round, the value opened by the challenge bit: `r` for bit 0,
    /// `x + r` for bit 1.
    pub open_scalars: Vec<G::Scalar>,
    /// Number of rounds, i.e. how many challenge bits are used.
    security_param: usize,
}
64
65impl<G> VerifiableRsaEncryption<G>
66where
67 G: Group + GroupEncoding + ConstantTimeEq,
68{
69 pub fn encrypt_with_proof<R: CryptoRngCore>(
70 x: &G::Scalar,
71 rsa_pubkey: &RsaPublicKey,
72 label: &[u8],
73 security_param: Option<usize>,
74 rng: &mut R,
75 ) -> Result<Self, RsaError> {
76 let seed = rng.gen::<[u8; 32]>();
77 let security_param = security_param.unwrap_or(SECURITY_PARAM);
78 if !(SECURITY_PARAM..=256).contains(&security_param) {
80 return Err(RsaError::InvalidSizeParam);
81 }
82 let mut proofs = Vec::with_capacity(security_param);
83 let q_point = G::generator() * x;
84 let mut r_list = Vec::with_capacity(security_param);
85 let mut x_plus_r_list = Vec::with_capacity(security_param);
86
87 for _ in 0..security_param {
88 let r = G::Scalar::random(&mut *rng);
89 let g_r = G::generator() * r;
90 let x_plus_r = *x + r;
91 let enc_r =
92 rsa_encrypt_with_label(r.to_repr(), label, rsa_pubkey, seed)?;
93 let enc_x_plus_r = rsa_encrypt_with_label(
94 x_plus_r.to_repr(),
95 label,
96 rsa_pubkey,
97 seed,
98 )?;
99
100 r_list.push(r);
101 x_plus_r_list.push(x_plus_r);
102
103 proofs.push(ProofData {
104 g_r: g_r.to_bytes(),
105 enc_x_r: enc_x_plus_r,
106 enc_r,
107 });
108 }
109 let challenge = Self::challenge(&q_point, label, &proofs);
110 let mut open_scalars = Vec::with_capacity(security_param);
111 for i in 0..security_param {
112 let choice_bit = challenge.extract_bit(i);
113 let selected = G::Scalar::conditional_select(
114 &r_list[i],
115 &x_plus_r_list[i],
116 choice_bit,
117 );
118 open_scalars.push(selected);
119 }
120
121 Ok(Self {
122 open_scalars,
123 proofs,
124 seed,
125 security_param,
126 })
127 }
128
129 pub fn verify(
130 &self,
131 q_point: &G,
132 rsa_pubkey: &RsaPublicKey,
133 label: &[u8],
134 ) -> Result<(), RsaError> {
135 let challenge = Self::challenge(q_point, label, &self.proofs);
136 for i in 0..self.security_param {
137 let proof = &self.proofs[i];
138 let open_scalar = &self.open_scalars[i];
139 let scalar_expo = G::generator() * open_scalar;
140 let choice_bit = challenge.extract_bit(i);
141 let enc_open_scalar = rsa_encrypt_with_label(
142 open_scalar.to_repr(),
143 label,
144 rsa_pubkey,
145 self.seed,
146 )?;
147
148 let g_r_option = G::from_bytes(&proof.g_r);
149 let g_r = if g_r_option.is_some().unwrap_u8() == 1 {
150 g_r_option.unwrap()
151 } else {
152 return Err(RsaError::VerificationFailed);
153 };
154
155 let cond_a = {
157 let cond1 = g_r.ct_eq(&scalar_expo);
158 let cond2 = proof.enc_r.ct_eq(&enc_open_scalar);
159 cond1 & cond2
160 };
161 let cond_b = {
163 let calc_scalar_expo = *q_point + g_r;
164 let cond1 = calc_scalar_expo.ct_eq(&scalar_expo);
165 let cond2 = proof.enc_x_r.ct_eq(&enc_open_scalar);
166 cond1 & cond2
167 };
168
169 let verified =
170 Choice::conditional_select(&cond_a, &cond_b, choice_bit)
171 .unwrap_u8();
172 if verified != 1 {
173 return Err(RsaError::VerificationFailed);
174 }
175 }
176 Ok(())
177 }
178
179 pub fn decrypt(
180 &self,
181 q_point: &G,
182 rsa_privkey: &RsaPrivateKey,
183 label: &[u8],
184 ) -> Result<G::Scalar, RsaError> {
185 if self.proofs.len() != self.security_param {
186 return Err(RsaError::VerificationFailed);
187 }
188
189 for proof in &self.proofs {
190 let enc_r = &proof.enc_r;
191 let enc_x_r = &proof.enc_x_r;
192
193 let r = rsa_decrypt_with_label(enc_r, label, rsa_privkey)?;
194
195 let r = if let Some(r) = decode_scalar::<G::Scalar>(&r) {
197 r
198 } else {
199 continue;
200 };
201
202 let x_plus_r =
203 rsa_decrypt_with_label(enc_x_r, label, rsa_privkey)?;
204
205 let x_plus_r = if let Some(x_plus_r) =
206 decode_scalar::<G::Scalar>(&x_plus_r)
207 {
208 x_plus_r
209 } else {
210 continue;
211 };
212
213 let x = x_plus_r - r;
214 let calc_public_point = G::generator() * x;
215 if calc_public_point == *q_point {
216 return Ok(x);
217 }
218 }
219
220 Err(RsaError::DecError)
221 }
222
223 pub fn to_bytes(&self) -> Vec<u8> {
224 let mut bytes = Vec::new();
225 bytes.extend_from_slice(&self.seed);
226 bytes.extend_from_slice(&(self.security_param as u16).to_be_bytes());
229 bytes.extend_from_slice(
230 &(self.proofs[0].g_r.as_ref().len() as u16).to_be_bytes(),
231 );
232 bytes.extend_from_slice(
233 &(self.proofs[0].enc_x_r.len() as u16).to_be_bytes(),
234 );
235 bytes.extend_from_slice(
236 &(size_of::<<G::Scalar as PrimeField>::Repr>() as u16)
237 .to_be_bytes(),
238 );
239 for proof in &self.proofs {
240 bytes.extend_from_slice(proof.g_r.as_ref());
241 bytes.extend_from_slice(proof.enc_x_r.as_ref());
242 bytes.extend_from_slice(proof.enc_r.as_ref());
243 }
244 for scalar in &self.open_scalars {
245 bytes.extend_from_slice(scalar.to_repr().as_ref());
246 }
247
248 bytes
249 }
250
251 pub fn from_bytes(data: &[u8]) -> Result<Self, RsaError> {
252 let res = || {
253 if data.len() < 32 + 8 {
254 return Err("Input data too short");
256 }
257
258 let mut offset = 0;
259
260 let mut seed = [0u8; 32];
262 seed.copy_from_slice(&data[offset..offset + 32]);
263 offset += 32;
264
265 let security_param =
267 u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
268 offset += 2;
269 let g_r_size =
270 u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
271 offset += 2;
272 let enc_size =
273 u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
274 offset += 2;
275 let scalar_size =
276 u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
277 offset += 2;
278
279 if scalar_size
280 != core::mem::size_of::<<G::Scalar as PrimeField>::Repr>()
281 {
282 return Err("Inconsistent scalar size");
283 }
284
285 if g_r_size != core::mem::size_of::<G::Repr>() {
286 return Err("Inconsistent g_r size");
287 }
288
289 let proof_size = g_r_size + 2 * enc_size;
291 let remaining_data = data.len() - offset;
292 let num_proofs = remaining_data / (proof_size + scalar_size);
293
294 if security_param < SECURITY_PARAM {
295 return Err("Security param must at least be 128");
296 }
297
298 if num_proofs != security_param {
299 return Err("Inconsistent number of proofs, must be equal to the security parameter");
300 }
301
302 if remaining_data % (proof_size + scalar_size) != 0 {
303 return Err("Inconsistent data length");
304 }
305
306 let mut proofs = Vec::with_capacity(num_proofs);
308 for _ in 0..num_proofs {
309 if offset + proof_size > data.len() {
310 return Err(
311 "Unexpected end of data while reading proofs",
312 );
313 }
314
315 let mut g_r = G::Repr::default();
316 g_r.as_mut()
317 .copy_from_slice(&data[offset..offset + g_r_size]);
318
319 offset += g_r_size;
320 let enc_x_r = data[offset..offset + enc_size].to_vec();
321 offset += enc_size;
322
323 let enc_r = data[offset..offset + enc_size].to_vec();
324 offset += enc_size;
325
326 proofs.push(ProofData {
327 g_r,
328 enc_x_r,
329 enc_r,
330 });
331 }
332
333 let mut open_scalars = Vec::with_capacity(num_proofs);
335 let scalar_size = size_of::<<G::Scalar as PrimeField>::Repr>();
336 for _ in 0..num_proofs {
337 if offset + scalar_size > data.len() {
338 return Err(
339 "Unexpected end of data while reading scalars",
340 );
341 }
342 let scalar =
343 decode_scalar(&data[offset..offset + scalar_size])
344 .ok_or("Invalid scalar")?;
345 offset += scalar_size;
346 open_scalars.push(scalar);
347 }
348
349 Ok(Self {
350 seed,
351 proofs,
352 open_scalars,
353 security_param,
354 })
355 };
356 res().map_err(|e| RsaError::SerdeError(e.to_string()))
357 }
358
359 fn challenge(
360 q_point: &G,
361 label: &[u8],
362 proofs: &[ProofData<G>],
363 ) -> [u8; 32] {
364 let mut hasher = Sha256::new();
365 hasher.update(b"Verified-RSA-encryption");
366 hasher.update(q_point.to_bytes());
367 for proof in proofs {
368 hasher.update(proof.g_r);
369 hasher.update(&proof.enc_x_r);
370 hasher.update(&proof.enc_r);
371 }
372 hasher.update(label);
373 hasher.finalize().into()
374 }
375}
376
377fn rsa_encrypt_with_label(
378 m: impl AsRef<[u8]>,
379 label: &[u8],
380 rsa_pubkey: &RsaPublicKey,
381 seed: [u8; 32],
382) -> Result<Vec<u8>, RsaError> {
383 let mut rng = ChaCha20Rng::from_seed(seed);
384 let m_int = BigUint::from_bytes_be(m.as_ref());
385 let label_int = label_int_from_bytes(label);
386 let plaintext = (m_int * label_int) % rsa_pubkey.n();
387 rsa_pubkey
388 .encrypt(&mut rng, Pkcs1v15Encrypt, &plaintext.to_bytes_be())
389 .map_err(|_| RsaError::EncError)
390}
391
392fn rsa_decrypt_with_label(
393 ciphertext: &[u8],
394 label: &[u8],
395 rsa_privkey: &RsaPrivateKey,
396) -> Result<Vec<u8>, RsaError> {
397 let plaintext = rsa_privkey
398 .decrypt(Pkcs1v15Encrypt, ciphertext)
399 .map_err(|_| RsaError::DecError)?;
400
401 let n = rsa_privkey.n();
402 let label_inv = label_int_from_bytes(label)
403 .mod_inverse(n)
404 .and_then(|num| num.to_biguint())
405 .ok_or(RsaError::InvalidLabel)?;
406
407 let plaintext_int = BigUint::from_bytes_be(&plaintext);
408 let message = (plaintext_int * label_inv) % n;
409 Ok(message.to_bytes_be())
410}
411
412fn label_int_from_bytes(label: &[u8]) -> BigUint {
413 let mut hasher = Sha256::new();
414 hasher.update(b"SL-label-for-RSA");
415 hasher.update(label);
416 let digest = hasher.finalize();
417 BigUint::from_bytes_be(&digest[..])
418}
419
420pub trait ExtractBit: Index<usize, Output = u8> {
422 fn extract_bit(&self, idx: usize) -> Choice {
424 let byte_idx = idx >> 3;
425 let bit_idx = idx & 0x7;
426 let byte = self[byte_idx];
427 let mask = 1 << bit_idx;
428 Choice::from(((byte & mask) != 0) as u8)
429 }
430}
431impl<const N: usize> ExtractBit for [u8; N] {}
432
433fn decode_scalar<S: PrimeField>(bytes: &[u8]) -> Option<S> {
434 if bytes.len() != size_of::<S::Repr>() {
435 return None;
436 }
437 let mut encoding = <S as PrimeField>::Repr::default();
438 encoding.as_mut().copy_from_slice(bytes);
439 S::from_repr(encoding).into()
440}
441
// End-to-end tests: encrypt with proof, (optionally) serialize/deserialize,
// verify, and decrypt, over secp256k1 (k256) and Ed25519 (curve25519-dalek)
// with 2048- and 4096-bit RSA keys. Also a unit test for `ExtractBit`.
#[cfg(test)]
mod tests {
    use curve25519_dalek::EdwardsPoint;
    use group::Group;
    use k256::{ProjectivePoint, Scalar};
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use rsa::RsaPrivateKey;
    use subtle::Choice;

    use crate::*;

    // secp256k1: encrypt -> verify -> decrypt without serialization.
    #[test]
    fn test_verifiable_rsa_ecdsa() -> Result<(), RsaError> {
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::generate_vartime(&mut rng);

        let public_key = ProjectivePoint::GENERATOR * private_key;
        let rsa_private_key = RsaPrivateKey::new(&mut rng, 2048)
            .expect("Failed to generate RSA private key");
        let rsa_public_key = rsa_private_key.to_public_key();
        let label = b"test-label";
        let verifiable_rsa = VerifiableRsaEncryption::encrypt_with_proof(
            &private_key,
            &rsa_public_key,
            label,
            None,
            &mut rng,
        )?;

        verifiable_rsa.verify(&public_key, &rsa_public_key, label)?;

        let decrypted_x =
            verifiable_rsa.decrypt(&public_key, &rsa_private_key, label)?;

        assert_eq!(private_key, decrypted_x);

        Ok(())
    }

    // Ed25519: both the deserialized copy and the original must verify, and
    // the original must decrypt back to the private key.
    #[test]
    fn test_verifiable_rsa_25519() -> Result<(), RsaError> {
        // Shadows the k256 `Scalar` imported at module level.
        use curve25519_dalek::Scalar;
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::random(&mut rng);
        let public_key = EdwardsPoint::generator() * private_key;
        let rsa_private_key = RsaPrivateKey::new(&mut rng, 2048)
            .expect("Failed to generate RSA private key");
        let rsa_public_key = rsa_private_key.to_public_key();
        let label = b"test-label";
        let verifiable_rsa = VerifiableRsaEncryption::encrypt_with_proof(
            &private_key,
            &rsa_public_key,
            label,
            None,
            &mut rng,
        )?;
        let bytes = verifiable_rsa.to_bytes();

        let deserialized: VerifiableRsaEncryption<EdwardsPoint> =
            VerifiableRsaEncryption::from_bytes(&bytes).unwrap();

        deserialized.verify(&public_key, &rsa_public_key, label)?;

        verifiable_rsa.verify(&public_key, &rsa_public_key, label)?;
        let decrypted_x =
            verifiable_rsa.decrypt(&public_key, &rsa_private_key, label)?;
        assert_eq!(private_key, decrypted_x);

        Ok(())
    }

    // secp256k1: serialize, deserialize, then verify and decrypt the copy.
    #[test]
    fn test_serde_k256() -> Result<(), RsaError> {
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::generate_vartime(&mut rng);

        let public_key = ProjectivePoint::GENERATOR * private_key;
        let rsa_private_key = RsaPrivateKey::new(&mut rng, 2048)
            .expect("Failed to generate RSA private key");
        let rsa_public_key = rsa_private_key.to_public_key();
        let label = b"test-label";
        let verifiable_rsa: VerifiableRsaEncryption<ProjectivePoint> =
            VerifiableRsaEncryption::encrypt_with_proof(
                &private_key,
                &rsa_public_key,
                label,
                None,
                &mut rng,
            )?;

        let bytes = verifiable_rsa.to_bytes();
        let deserialized =
            VerifiableRsaEncryption::from_bytes(&bytes).unwrap();
        deserialized.verify(&public_key, &rsa_public_key, label)?;

        let decrypted_x =
            deserialized.decrypt(&public_key, &rsa_private_key, label)?;
        assert_eq!(private_key, decrypted_x);

        Ok(())
    }

    // Ed25519: serialize, deserialize, then verify and decrypt the copy.
    #[test]
    fn test_serde_25519() -> Result<(), RsaError> {
        use curve25519_dalek::Scalar;
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::random(&mut rng);
        let public_key = EdwardsPoint::generator() * private_key;
        let rsa_private_key = RsaPrivateKey::new(&mut rng, 2048)
            .expect("Failed to generate RSA private key");
        let rsa_public_key = rsa_private_key.to_public_key();
        let label = b"test-label";
        let verifiable_rsa: VerifiableRsaEncryption<EdwardsPoint> =
            VerifiableRsaEncryption::encrypt_with_proof(
                &private_key,
                &rsa_public_key,
                label,
                None,
                &mut rng,
            )?;

        let bytes = verifiable_rsa.to_bytes();
        let deserialized =
            VerifiableRsaEncryption::from_bytes(&bytes).unwrap();
        deserialized.verify(&public_key, &rsa_public_key, label)?;

        let decrypted_x =
            deserialized.decrypt(&public_key, &rsa_private_key, label)?;
        assert_eq!(private_key, decrypted_x);

        Ok(())
    }

    // Same as `test_serde_25519` but with a 4096-bit RSA key, exercising a
    // larger ciphertext size in the serialized header.
    #[test]
    fn test_serde_rsa_4096() -> Result<(), RsaError> {
        use curve25519_dalek::Scalar;
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::random(&mut rng);
        let public_key = EdwardsPoint::generator() * private_key;
        let rsa_private_key = RsaPrivateKey::new(&mut rng, 4096)
            .expect("Failed to generate RSA private key");
        let rsa_public_key = rsa_private_key.to_public_key();
        let label = b"test-label";
        let verifiable_rsa: VerifiableRsaEncryption<EdwardsPoint> =
            VerifiableRsaEncryption::encrypt_with_proof(
                &private_key,
                &rsa_public_key,
                label,
                None,
                &mut rng,
            )?;

        let bytes = verifiable_rsa.to_bytes();
        let deserialized =
            VerifiableRsaEncryption::from_bytes(&bytes).unwrap();
        deserialized.verify(&public_key, &rsa_public_key, label)?;

        let decrypted_x =
            deserialized.decrypt(&public_key, &rsa_private_key, label)?;
        assert_eq!(private_key, decrypted_x);

        Ok(())
    }

    // Bits are read least-significant first: 0b0100_1110 -> 0,1,1,1,0,0,1,0.
    #[test]
    fn test_extract_bit() {
        let array: [u8; 1] = [0b0100_1110];

        assert!(
            array.extract_bit(0).ct_eq(&Choice::from(0)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(1).ct_eq(&Choice::from(1)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(2).ct_eq(&Choice::from(1)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(3).ct_eq(&Choice::from(1)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(4).ct_eq(&Choice::from(0)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(5).ct_eq(&Choice::from(0)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(6).ct_eq(&Choice::from(1)).unwrap_u8() == 1
        );
        assert!(
            array.extract_bit(7).ct_eq(&Choice::from(0)).unwrap_u8() == 1
        );
    }
}