sl_verifiable_enc/
lib.rs

1// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
2// This software is licensed under the Silence Laboratories License Agreement.
3//
4
5//! # Verifiable RSA Encryption
6//! This crate provides a simple implementation of verifiable RSA encryption. The implementation is based on the paper [Verifiable RSA Encryption](https://eprint.iacr.org/1999/008)
7
8use core::mem::size_of;
9#[doc = include_str!("../README.md")]
10use ff::{Field, PrimeField};
11use group::{Group, GroupEncoding};
12use num_bigint_dig::ModInverse;
13use rand::{Rng, SeedableRng};
14use rand_chacha::{rand_core::CryptoRngCore, ChaCha20Rng};
15use rsa::{
16    traits::PublicKeyParts, BigUint, Pkcs1v15Encrypt, RsaPrivateKey,
17    RsaPublicKey,
18};
19use sha2::{Digest, Sha256};
20use std::ops::Index;
21use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
22use thiserror::Error;
23
/// Default (and minimum accepted) number of cut-and-choose rounds in a proof.
pub const SECURITY_PARAM: usize = 128;
25
26//Re-Exports
27pub use rsa;
28
29// Intentionally not giving too much information about the error
30#[derive(Debug, Error)]
31pub enum RsaError {
32    #[error("Error during encryption")]
33    EncError,
34    #[error("Error during decryption")]
35    DecError,
36    #[error("Invalid label inverse")]
37    InvalidLabel,
38    #[error("Verification failed")]
39    VerificationFailed,
40    #[error("Invalid SIZE parameter, must be equal to the size of the scalar in bytes")]
41    InvalidSizeParam,
42    #[error("(de)Serialization error")]
43    SerdeError(String),
44    #[error("Invalid Security Parameter, cannot be more than 256")]
45    InvalidSecurityParam,
46}
47
48pub struct ProofData<G: Group + GroupEncoding> {
49    g_r: G::Repr,
50    enc_x_r: Vec<u8>,
51    enc_r: Vec<u8>,
52}
53
54pub struct VerifiableRsaEncryption<G>
55where
56    G: Group + GroupEncoding + ConstantTimeEq,
57    G::Scalar: ConditionallySelectable,
58{
59    pub seed: [u8; 32],
60    pub proofs: Vec<ProofData<G>>,
61    pub open_scalars: Vec<G::Scalar>,
62    security_param: usize,
63}
64
65impl<G> VerifiableRsaEncryption<G>
66where
67    G: Group + GroupEncoding + ConstantTimeEq,
68{
69    pub fn encrypt_with_proof<R: CryptoRngCore>(
70        x: &G::Scalar,
71        rsa_pubkey: &RsaPublicKey,
72        label: &[u8],
73        security_param: Option<usize>,
74        rng: &mut R,
75    ) -> Result<Self, RsaError> {
76        let seed = rng.gen::<[u8; 32]>();
77        let security_param = security_param.unwrap_or(SECURITY_PARAM);
78        // Security parameter must be at least 128 and at most 256
79        if !(SECURITY_PARAM..=256).contains(&security_param) {
80            return Err(RsaError::InvalidSizeParam);
81        }
82        let mut proofs = Vec::with_capacity(security_param);
83        let q_point = G::generator() * x;
84        let mut r_list = Vec::with_capacity(security_param);
85        let mut x_plus_r_list = Vec::with_capacity(security_param);
86
87        for _ in 0..security_param {
88            let r = G::Scalar::random(&mut *rng);
89            let g_r = G::generator() * r;
90            let x_plus_r = *x + r;
91            let enc_r =
92                rsa_encrypt_with_label(r.to_repr(), label, rsa_pubkey, seed)?;
93            let enc_x_plus_r = rsa_encrypt_with_label(
94                x_plus_r.to_repr(),
95                label,
96                rsa_pubkey,
97                seed,
98            )?;
99
100            r_list.push(r);
101            x_plus_r_list.push(x_plus_r);
102
103            proofs.push(ProofData {
104                g_r: g_r.to_bytes(),
105                enc_x_r: enc_x_plus_r,
106                enc_r,
107            });
108        }
109        let challenge = Self::challenge(&q_point, label, &proofs);
110        let mut open_scalars = Vec::with_capacity(security_param);
111        for i in 0..security_param {
112            let choice_bit = challenge.extract_bit(i);
113            let selected = G::Scalar::conditional_select(
114                &r_list[i],
115                &x_plus_r_list[i],
116                choice_bit,
117            );
118            open_scalars.push(selected);
119        }
120
121        Ok(Self {
122            open_scalars,
123            proofs,
124            seed,
125            security_param,
126        })
127    }
128
129    pub fn verify(
130        &self,
131        q_point: &G,
132        rsa_pubkey: &RsaPublicKey,
133        label: &[u8],
134    ) -> Result<(), RsaError> {
135        let challenge = Self::challenge(q_point, label, &self.proofs);
136        for i in 0..self.security_param {
137            let proof = &self.proofs[i];
138            let open_scalar = &self.open_scalars[i];
139            let scalar_expo = G::generator() * open_scalar;
140            let choice_bit = challenge.extract_bit(i);
141            let enc_open_scalar = rsa_encrypt_with_label(
142                open_scalar.to_repr(),
143                label,
144                rsa_pubkey,
145                self.seed,
146            )?;
147
148            let g_r_option = G::from_bytes(&proof.g_r);
149            let g_r = if g_r_option.is_some().unwrap_u8() == 1 {
150                g_r_option.unwrap()
151            } else {
152                return Err(RsaError::VerificationFailed);
153            };
154
155            // If choice bit is 0
156            let cond_a = {
157                let cond1 = g_r.ct_eq(&scalar_expo);
158                let cond2 = proof.enc_r.ct_eq(&enc_open_scalar);
159                cond1 & cond2
160            };
161            // If choice bit is 1
162            let cond_b = {
163                let calc_scalar_expo = *q_point + g_r;
164                let cond1 = calc_scalar_expo.ct_eq(&scalar_expo);
165                let cond2 = proof.enc_x_r.ct_eq(&enc_open_scalar);
166                cond1 & cond2
167            };
168
169            let verified =
170                Choice::conditional_select(&cond_a, &cond_b, choice_bit)
171                    .unwrap_u8();
172            if verified != 1 {
173                return Err(RsaError::VerificationFailed);
174            }
175        }
176        Ok(())
177    }
178
179    pub fn decrypt(
180        &self,
181        q_point: &G,
182        rsa_privkey: &RsaPrivateKey,
183        label: &[u8],
184    ) -> Result<G::Scalar, RsaError> {
185        if self.proofs.len() != self.security_param {
186            return Err(RsaError::VerificationFailed);
187        }
188
189        for proof in &self.proofs {
190            let enc_r = &proof.enc_r;
191            let enc_x_r = &proof.enc_x_r;
192
193            let r = rsa_decrypt_with_label(enc_r, label, rsa_privkey)?;
194
195            // If r is not a valid scalar, continue. We expect at least one of the proofs to be valid, assuming the proofs are verified.
196            let r = if let Some(r) = decode_scalar::<G::Scalar>(&r) {
197                r
198            } else {
199                continue;
200            };
201
202            let x_plus_r =
203                rsa_decrypt_with_label(enc_x_r, label, rsa_privkey)?;
204
205            let x_plus_r = if let Some(x_plus_r) =
206                decode_scalar::<G::Scalar>(&x_plus_r)
207            {
208                x_plus_r
209            } else {
210                continue;
211            };
212
213            let x = x_plus_r - r;
214            let calc_public_point = G::generator() * x;
215            if calc_public_point == *q_point {
216                return Ok(x);
217            }
218        }
219
220        Err(RsaError::DecError)
221    }
222
223    pub fn to_bytes(&self) -> Vec<u8> {
224        let mut bytes = Vec::new();
225        bytes.extend_from_slice(&self.seed);
226        // Adding the sizes
227        // Security parameter, g_r size, enc_x_r size (will be same as enc_r size) and scalar size
228        bytes.extend_from_slice(&(self.security_param as u16).to_be_bytes());
229        bytes.extend_from_slice(
230            &(self.proofs[0].g_r.as_ref().len() as u16).to_be_bytes(),
231        );
232        bytes.extend_from_slice(
233            &(self.proofs[0].enc_x_r.len() as u16).to_be_bytes(),
234        );
235        bytes.extend_from_slice(
236            &(size_of::<<G::Scalar as PrimeField>::Repr>() as u16)
237                .to_be_bytes(),
238        );
239        for proof in &self.proofs {
240            bytes.extend_from_slice(proof.g_r.as_ref());
241            bytes.extend_from_slice(proof.enc_x_r.as_ref());
242            bytes.extend_from_slice(proof.enc_r.as_ref());
243        }
244        for scalar in &self.open_scalars {
245            bytes.extend_from_slice(scalar.to_repr().as_ref());
246        }
247
248        bytes
249    }
250
251    pub fn from_bytes(data: &[u8]) -> Result<Self, RsaError> {
252        let res = || {
253            if data.len() < 32 + 8 {
254                // 32 (seed) + 8 (4 * u16 sizes)
255                return Err("Input data too short");
256            }
257
258            let mut offset = 0;
259
260            // Read seed
261            let mut seed = [0u8; 32];
262            seed.copy_from_slice(&data[offset..offset + 32]);
263            offset += 32;
264
265            // Read sizes
266            let security_param =
267                u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
268            offset += 2;
269            let g_r_size =
270                u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
271            offset += 2;
272            let enc_size =
273                u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
274            offset += 2;
275            let scalar_size =
276                u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
277            offset += 2;
278
279            if scalar_size
280                != core::mem::size_of::<<G::Scalar as PrimeField>::Repr>()
281            {
282                return Err("Inconsistent scalar size");
283            }
284
285            if g_r_size != core::mem::size_of::<G::Repr>() {
286                return Err("Inconsistent g_r size");
287            }
288
289            // Calculate number of proofs and open scalars
290            let proof_size = g_r_size + 2 * enc_size;
291            let remaining_data = data.len() - offset;
292            let num_proofs = remaining_data / (proof_size + scalar_size);
293
294            if security_param < SECURITY_PARAM {
295                return Err("Security param must at least be 128");
296            }
297
298            if num_proofs != security_param {
299                return Err("Inconsistent number of proofs, must be equal to the security parameter");
300            }
301
302            if remaining_data % (proof_size + scalar_size) != 0 {
303                return Err("Inconsistent data length");
304            }
305
306            // Read proofs
307            let mut proofs = Vec::with_capacity(num_proofs);
308            for _ in 0..num_proofs {
309                if offset + proof_size > data.len() {
310                    return Err(
311                        "Unexpected end of data while reading proofs",
312                    );
313                }
314
315                let mut g_r = G::Repr::default();
316                g_r.as_mut()
317                    .copy_from_slice(&data[offset..offset + g_r_size]);
318
319                offset += g_r_size;
320                let enc_x_r = data[offset..offset + enc_size].to_vec();
321                offset += enc_size;
322
323                let enc_r = data[offset..offset + enc_size].to_vec();
324                offset += enc_size;
325
326                proofs.push(ProofData {
327                    g_r,
328                    enc_x_r,
329                    enc_r,
330                });
331            }
332
333            // Read open scalars
334            let mut open_scalars = Vec::with_capacity(num_proofs);
335            let scalar_size = size_of::<<G::Scalar as PrimeField>::Repr>();
336            for _ in 0..num_proofs {
337                if offset + scalar_size > data.len() {
338                    return Err(
339                        "Unexpected end of data while reading scalars",
340                    );
341                }
342                let scalar =
343                    decode_scalar(&data[offset..offset + scalar_size])
344                        .ok_or("Invalid scalar")?;
345                offset += scalar_size;
346                open_scalars.push(scalar);
347            }
348
349            Ok(Self {
350                seed,
351                proofs,
352                open_scalars,
353                security_param,
354            })
355        };
356        res().map_err(|e| RsaError::SerdeError(e.to_string()))
357    }
358
359    fn challenge(
360        q_point: &G,
361        label: &[u8],
362        proofs: &[ProofData<G>],
363    ) -> [u8; 32] {
364        let mut hasher = Sha256::new();
365        hasher.update(b"Verified-RSA-encryption");
366        hasher.update(q_point.to_bytes());
367        for proof in proofs {
368            hasher.update(proof.g_r);
369            hasher.update(&proof.enc_x_r);
370            hasher.update(&proof.enc_r);
371        }
372        hasher.update(label);
373        hasher.finalize().into()
374    }
375}
376
377fn rsa_encrypt_with_label(
378    m: impl AsRef<[u8]>,
379    label: &[u8],
380    rsa_pubkey: &RsaPublicKey,
381    seed: [u8; 32],
382) -> Result<Vec<u8>, RsaError> {
383    let mut rng = ChaCha20Rng::from_seed(seed);
384    let m_int = BigUint::from_bytes_be(m.as_ref());
385    let label_int = label_int_from_bytes(label);
386    let plaintext = (m_int * label_int) % rsa_pubkey.n();
387    rsa_pubkey
388        .encrypt(&mut rng, Pkcs1v15Encrypt, &plaintext.to_bytes_be())
389        .map_err(|_| RsaError::EncError)
390}
391
392fn rsa_decrypt_with_label(
393    ciphertext: &[u8],
394    label: &[u8],
395    rsa_privkey: &RsaPrivateKey,
396) -> Result<Vec<u8>, RsaError> {
397    let plaintext = rsa_privkey
398        .decrypt(Pkcs1v15Encrypt, ciphertext)
399        .map_err(|_| RsaError::DecError)?;
400
401    let n = rsa_privkey.n();
402    let label_inv = label_int_from_bytes(label)
403        .mod_inverse(n)
404        .and_then(|num| num.to_biguint())
405        .ok_or(RsaError::InvalidLabel)?;
406
407    let plaintext_int = BigUint::from_bytes_be(&plaintext);
408    let message = (plaintext_int * label_inv) % n;
409    Ok(message.to_bytes_be())
410}
411
412fn label_int_from_bytes(label: &[u8]) -> BigUint {
413    let mut hasher = Sha256::new();
414    hasher.update(b"SL-label-for-RSA");
415    hasher.update(label);
416    let digest = hasher.finalize();
417    BigUint::from_bytes_be(&digest[..])
418}
419
420/// Simple trait to extract a bit from a byte array.
421pub trait ExtractBit: Index<usize, Output = u8> {
422    /// Extract a bit at given index (in little endian order) from a byte array.
423    fn extract_bit(&self, idx: usize) -> Choice {
424        let byte_idx = idx >> 3;
425        let bit_idx = idx & 0x7;
426        let byte = self[byte_idx];
427        let mask = 1 << bit_idx;
428        Choice::from(((byte & mask) != 0) as u8)
429    }
430}
431impl<const N: usize> ExtractBit for [u8; N] {}
432
433fn decode_scalar<S: PrimeField>(bytes: &[u8]) -> Option<S> {
434    if bytes.len() != size_of::<S::Repr>() {
435        return None;
436    }
437    let mut encoding = <S as PrimeField>::Repr::default();
438    encoding.as_mut().copy_from_slice(bytes);
439    S::from_repr(encoding).into()
440}
441
#[cfg(test)]
mod tests {
    use curve25519_dalek::EdwardsPoint;
    use group::Group;
    use k256::{ProjectivePoint, Scalar};
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use rsa::{RsaPrivateKey, RsaPublicKey};
    use subtle::Choice;

    use crate::*;

    /// Label shared by every round-trip test.
    const LABEL: &[u8] = b"test-label";

    /// Generates a `bits`-bit RSA key pair from `rng`.
    fn rsa_keys(
        rng: &mut ChaCha20Rng,
        bits: usize,
    ) -> (RsaPrivateKey, RsaPublicKey) {
        let private = RsaPrivateKey::new(rng, bits)
            .expect("Failed to generate RSA private key");
        let public = private.to_public_key();
        (private, public)
    }

    /// Encrypts a fresh ed25519 scalar under a `bits`-bit RSA key. The proof
    /// is then round-tripped through `to_bytes`/`from_bytes`, and the copy
    /// is verified and decrypted.
    fn serde_roundtrip_25519(bits: usize) -> Result<(), RsaError> {
        use curve25519_dalek::Scalar;
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::random(&mut rng);
        let public_key = EdwardsPoint::generator() * private_key;
        let (rsa_private_key, rsa_public_key) = rsa_keys(&mut rng, bits);

        let verifiable_rsa: VerifiableRsaEncryption<EdwardsPoint> =
            VerifiableRsaEncryption::encrypt_with_proof(
                &private_key,
                &rsa_public_key,
                LABEL,
                None,
                &mut rng,
            )?;

        let encoded = verifiable_rsa.to_bytes();
        let deserialized =
            VerifiableRsaEncryption::from_bytes(&encoded).unwrap();
        deserialized.verify(&public_key, &rsa_public_key, LABEL)?;

        let decrypted_x =
            deserialized.decrypt(&public_key, &rsa_private_key, LABEL)?;
        assert_eq!(private_key, decrypted_x);
        Ok(())
    }

    #[test]
    fn test_verifiable_rsa_ecdsa() -> Result<(), RsaError> {
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::generate_vartime(&mut rng);
        let public_key = ProjectivePoint::GENERATOR * private_key;
        let (rsa_private_key, rsa_public_key) = rsa_keys(&mut rng, 2048);

        let verifiable_rsa = VerifiableRsaEncryption::encrypt_with_proof(
            &private_key,
            &rsa_public_key,
            LABEL,
            None,
            &mut rng,
        )?;
        verifiable_rsa.verify(&public_key, &rsa_public_key, LABEL)?;

        let decrypted_x =
            verifiable_rsa.decrypt(&public_key, &rsa_private_key, LABEL)?;
        assert_eq!(private_key, decrypted_x);
        Ok(())
    }

    #[test]
    fn test_verifiable_rsa_25519() -> Result<(), RsaError> {
        use curve25519_dalek::Scalar;
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::random(&mut rng);
        let public_key = EdwardsPoint::generator() * private_key;
        let (rsa_private_key, rsa_public_key) = rsa_keys(&mut rng, 2048);

        let verifiable_rsa = VerifiableRsaEncryption::encrypt_with_proof(
            &private_key,
            &rsa_public_key,
            LABEL,
            None,
            &mut rng,
        )?;

        let encoded = verifiable_rsa.to_bytes();
        let deserialized: VerifiableRsaEncryption<EdwardsPoint> =
            VerifiableRsaEncryption::from_bytes(&encoded).unwrap();
        deserialized.verify(&public_key, &rsa_public_key, LABEL)?;

        verifiable_rsa.verify(&public_key, &rsa_public_key, LABEL)?;
        let decrypted_x =
            verifiable_rsa.decrypt(&public_key, &rsa_private_key, LABEL)?;
        assert_eq!(private_key, decrypted_x);
        Ok(())
    }

    #[test]
    fn test_serde_k256() -> Result<(), RsaError> {
        let mut rng = ChaCha20Rng::from_entropy();
        let private_key = Scalar::generate_vartime(&mut rng);
        let public_key = ProjectivePoint::GENERATOR * private_key;
        let (rsa_private_key, rsa_public_key) = rsa_keys(&mut rng, 2048);

        let verifiable_rsa: VerifiableRsaEncryption<ProjectivePoint> =
            VerifiableRsaEncryption::encrypt_with_proof(
                &private_key,
                &rsa_public_key,
                LABEL,
                None,
                &mut rng,
            )?;

        let encoded = verifiable_rsa.to_bytes();
        let deserialized =
            VerifiableRsaEncryption::from_bytes(&encoded).unwrap();
        deserialized.verify(&public_key, &rsa_public_key, LABEL)?;

        let decrypted_x =
            deserialized.decrypt(&public_key, &rsa_private_key, LABEL)?;
        assert_eq!(private_key, decrypted_x);
        Ok(())
    }

    #[test]
    fn test_serde_25519() -> Result<(), RsaError> {
        serde_roundtrip_25519(2048)
    }

    #[test]
    fn test_serde_rsa_4096() -> Result<(), RsaError> {
        // A 4096-bit key checks that (de)serialization handles larger
        // ciphertexts.
        serde_roundtrip_25519(4096)
    }

    #[test]
    fn test_extract_bit() {
        let array: [u8; 1] = [0b0100_1110];
        // Expected bits of 0b0100_1110, least significant first.
        let expected: [u8; 8] = [0, 1, 1, 1, 0, 0, 1, 0];

        for (idx, &bit) in expected.iter().enumerate() {
            assert!(
                array.extract_bit(idx).ct_eq(&Choice::from(bit)).unwrap_u8()
                    == 1
            );
        }
    }
}