1use crate::{public::RsaPublicKey, Error, Mpint, Result};
4use core::fmt;
5use encoding::{CheckedSum, Decode, Encode, Reader, Writer};
6use subtle::{Choice, ConstantTimeEq};
7use zeroize::Zeroize;
8
9#[cfg(feature = "rsa")]
10use {
11 rand_core::CryptoRngCore,
12 rsa::{
13 pkcs1v15,
14 traits::{PrivateKeyParts, PublicKeyParts},
15 },
16 sha2::{digest::const_oid::AssociatedOid, Digest},
17};
18
19#[derive(Clone)]
21pub struct RsaPrivateKey {
22 pub d: Mpint,
24
25 pub iqmp: Mpint,
27
28 pub p: Mpint,
30
31 pub q: Mpint,
33}
34
35impl ConstantTimeEq for RsaPrivateKey {
36 fn ct_eq(&self, other: &Self) -> Choice {
37 self.d.ct_eq(&other.d)
38 & self.iqmp.ct_eq(&self.iqmp)
39 & self.p.ct_eq(&other.p)
40 & self.q.ct_eq(&other.q)
41 }
42}
43
44impl Eq for RsaPrivateKey {}
45
46impl PartialEq for RsaPrivateKey {
47 fn eq(&self, other: &Self) -> bool {
48 self.ct_eq(other).into()
49 }
50}
51
52impl Decode for RsaPrivateKey {
53 type Error = Error;
54
55 fn decode(reader: &mut impl Reader) -> Result<Self> {
56 let d = Mpint::decode(reader)?;
57 let iqmp = Mpint::decode(reader)?;
58 let p = Mpint::decode(reader)?;
59 let q = Mpint::decode(reader)?;
60 Ok(Self { d, iqmp, p, q })
61 }
62}
63
64impl Encode for RsaPrivateKey {
65 fn encoded_len(&self) -> encoding::Result<usize> {
66 [
67 self.d.encoded_len()?,
68 self.iqmp.encoded_len()?,
69 self.p.encoded_len()?,
70 self.q.encoded_len()?,
71 ]
72 .checked_sum()
73 }
74
75 fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> {
76 self.d.encode(writer)?;
77 self.iqmp.encode(writer)?;
78 self.p.encode(writer)?;
79 self.q.encode(writer)?;
80 Ok(())
81 }
82}
83
84impl Drop for RsaPrivateKey {
85 fn drop(&mut self) {
86 self.d.zeroize();
87 self.iqmp.zeroize();
88 self.p.zeroize();
89 self.q.zeroize();
90 }
91}
92
93#[derive(Clone)]
95pub struct RsaKeypair {
96 pub public: RsaPublicKey,
98
99 pub private: RsaPrivateKey,
101}
102
103impl RsaKeypair {
104 #[cfg(feature = "rsa")]
106 pub(crate) const MIN_KEY_SIZE: usize = 2048;
107
108 #[cfg(feature = "rsa")]
110 pub fn random(rng: &mut impl CryptoRngCore, bit_size: usize) -> Result<Self> {
111 if bit_size >= Self::MIN_KEY_SIZE {
112 rsa::RsaPrivateKey::new(rng, bit_size)?.try_into()
113 } else {
114 Err(Error::Crypto)
115 }
116 }
117}
118
119impl ConstantTimeEq for RsaKeypair {
120 fn ct_eq(&self, other: &Self) -> Choice {
121 Choice::from((self.public == other.public) as u8) & self.private.ct_eq(&other.private)
122 }
123}
124
125impl Eq for RsaKeypair {}
126
127impl PartialEq for RsaKeypair {
128 fn eq(&self, other: &Self) -> bool {
129 self.ct_eq(other).into()
130 }
131}
132
133impl Decode for RsaKeypair {
134 type Error = Error;
135
136 fn decode(reader: &mut impl Reader) -> Result<Self> {
137 let n = Mpint::decode(reader)?;
138 let e = Mpint::decode(reader)?;
139 let public = RsaPublicKey { n, e };
140 let private = RsaPrivateKey::decode(reader)?;
141 Ok(RsaKeypair { public, private })
142 }
143}
144
145impl Encode for RsaKeypair {
146 fn encoded_len(&self) -> encoding::Result<usize> {
147 [
148 self.public.n.encoded_len()?,
149 self.public.e.encoded_len()?,
150 self.private.encoded_len()?,
151 ]
152 .checked_sum()
153 }
154
155 fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> {
156 self.public.n.encode(writer)?;
157 self.public.e.encode(writer)?;
158 self.private.encode(writer)
159 }
160}
161
162impl From<RsaKeypair> for RsaPublicKey {
163 fn from(keypair: RsaKeypair) -> RsaPublicKey {
164 keypair.public
165 }
166}
167
168impl From<&RsaKeypair> for RsaPublicKey {
169 fn from(keypair: &RsaKeypair) -> RsaPublicKey {
170 keypair.public.clone()
171 }
172}
173
174impl fmt::Debug for RsaKeypair {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.debug_struct("RsaKeypair")
177 .field("public", &self.public)
178 .finish_non_exhaustive()
179 }
180}
181
#[cfg(feature = "rsa")]
impl TryFrom<RsaKeypair> for rsa::RsaPrivateKey {
    type Error = Error;

    /// Owned variant: delegates to the `&RsaKeypair` conversion.
    fn try_from(key: RsaKeypair) -> Result<rsa::RsaPrivateKey> {
        Self::try_from(&key)
    }
}
190
#[cfg(feature = "rsa")]
impl TryFrom<&RsaKeypair> for rsa::RsaPrivateKey {
    type Error = Error;

    /// Convert an SSH RSA keypair into an `rsa::RsaPrivateKey`.
    ///
    /// # Errors
    /// Propagates any component conversion error or any error from
    /// `rsa::RsaPrivateKey::from_components`. Returns [`Error::Crypto`] if the
    /// resulting modulus is smaller than `RsaKeypair::MIN_KEY_SIZE` bits.
    fn try_from(key: &RsaKeypair) -> Result<rsa::RsaPrivateKey> {
        let ret = rsa::RsaPrivateKey::from_components(
            rsa::BigUint::try_from(&key.public.n)?,
            rsa::BigUint::try_from(&key.public.e)?,
            rsa::BigUint::try_from(&key.private.d)?,
            // Both distinct prime factors are required. Supplying `p` twice
            // describes a key whose primes do not multiply to `n`.
            vec![
                rsa::BigUint::try_from(&key.private.p)?,
                rsa::BigUint::try_from(&key.private.q)?,
            ],
        )?;

        // `size()` reports the modulus length in bytes; compare in bits.
        if ret.size().saturating_mul(8) >= RsaKeypair::MIN_KEY_SIZE {
            Ok(ret)
        } else {
            Err(Error::Crypto)
        }
    }
}
213
#[cfg(feature = "rsa")]
impl TryFrom<rsa::RsaPrivateKey> for RsaKeypair {
    type Error = Error;

    /// Owned variant: delegates to the `&rsa::RsaPrivateKey` conversion.
    fn try_from(key: rsa::RsaPrivateKey) -> Result<RsaKeypair> {
        Self::try_from(&key)
    }
}
222
#[cfg(feature = "rsa")]
impl TryFrom<&rsa::RsaPrivateKey> for RsaKeypair {
    type Error = Error;

    /// Convert an `rsa::RsaPrivateKey` into an SSH RSA keypair.
    ///
    /// # Errors
    /// Returns [`Error::Crypto`] unless the key has exactly two prime factors
    /// (this key format stores only `p` and `q`, so multi-prime keys cannot be
    /// represented), or if the CRT coefficient is unavailable. Propagates any
    /// public-key or component conversion error.
    fn try_from(key: &rsa::RsaPrivateKey) -> Result<RsaKeypair> {
        // Match on the slice rather than indexing, so a key with fewer than two
        // primes is rejected with an error instead of panicking.
        let (p, q) = match key.primes() {
            [p, q] => (p, q),
            _ => return Err(Error::Crypto),
        };

        let public = RsaPublicKey::try_from(key.to_public_key())?;
        let iqmp = key.crt_coefficient().ok_or(Error::Crypto)?;

        let private = RsaPrivateKey {
            d: key.d().try_into()?,
            iqmp: iqmp.try_into()?,
            p: p.try_into()?,
            q: q.try_into()?,
        };

        Ok(RsaKeypair { public, private })
    }
}
249
#[cfg(feature = "rsa")]
impl<D> TryFrom<&RsaKeypair> for pkcs1v15::SigningKey<D>
where
    D: Digest + AssociatedOid,
{
    type Error = Error;

    /// Build a PKCS#1 v1.5 signing key using digest `D` from an SSH RSA keypair.
    ///
    /// # Errors
    /// Propagates any error from converting the keypair into an
    /// `rsa::RsaPrivateKey`.
    fn try_from(keypair: &RsaKeypair) -> Result<pkcs1v15::SigningKey<D>> {
        let private_key = rsa::RsaPrivateKey::try_from(keypair)?;
        Ok(Self::new(private_key))
    }
}