use crate::types::ExportKeyFormat;
use rsa::{
    pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
    Oaep, Pss, RsaPrivateKey, RsaPublicKey,
};
use sha2::{Digest, Sha256, Sha384, Sha512};
9
/// Modulus size in bits used by `RSA::generate_key_pair` when no length is given,
/// and the size of the placeholder private key created for SPKI imports.
const DEFAULT_RSA_KEY_SIZE: usize = 2_048;
/// RSA-PSS salt length passed to `Pss::new_with_salt` when the caller supplies none.
const DEFAULT_SALT_LENGTH: usize = 32;
12
13#[derive(Clone)]
14pub struct RsaKeyPair {
15 pub private_key: RsaPrivateKey,
16 pub public_key: RsaPublicKey,
17 pub hash_algorithm: HashAlgorithm,
18}
19
/// SHA-2 variant used for RSA-OAEP padding and for hashing messages before RSA-PSS.
///
/// `Debug`, `PartialEq`, `Eq` and `Hash` are derived so callers can log, compare and
/// key on the selected algorithm; the type stays `Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum HashAlgorithm {
    /// SHA-256; the default whenever no hash is specified.
    SHA256,
    /// SHA-384.
    SHA384,
    /// SHA-512.
    SHA512,
}
26
27impl HashAlgorithm {
28 fn digest(&self, data: &[u8]) -> Vec<u8> {
29 match self {
30 HashAlgorithm::SHA256 => {
31 let mut hasher = Sha256::new();
32 hasher.update(data);
33 hasher.finalize().to_vec()
34 }
35 HashAlgorithm::SHA384 => {
36 let mut hasher = Sha384::new();
37 hasher.update(data);
38 hasher.finalize().to_vec()
39 }
40 HashAlgorithm::SHA512 => {
41 let mut hasher = Sha512::new();
42 hasher.update(data);
43 hasher.finalize().to_vec()
44 }
45 }
46 }
47
48 fn oaep_padding(&self) -> Oaep {
49 match self {
50 HashAlgorithm::SHA256 => Oaep::new::<Sha256>(),
51 HashAlgorithm::SHA384 => Oaep::new::<Sha384>(),
52 HashAlgorithm::SHA512 => Oaep::new::<Sha512>(),
53 }
54 }
55}
56
57pub struct RSA;
58
59impl RSA {
60 pub async fn generate_key_pair(
61 modulus_length: Option<usize>,
62 hash: Option<HashAlgorithm>,
63 ) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
64 let mut rng = rand::thread_rng();
65 let bits = modulus_length.unwrap_or(DEFAULT_RSA_KEY_SIZE);
66 let private_key = RsaPrivateKey::new(&mut rng, bits)?;
67 let public_key = RsaPublicKey::from(&private_key);
68 let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
69 Ok(RsaKeyPair {
70 private_key,
71 public_key,
72 hash_algorithm,
73 })
74 }
75
76 pub async fn export_key_public(
77 key: &RsaKeyPair,
78 format: ExportKeyFormat,
79 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
80 match format {
81 ExportKeyFormat::SPKI => Ok(key.public_key.to_public_key_der()?.as_bytes().to_vec()),
82 ExportKeyFormat::JWK => {
83 Err("JWK format not yet implemented".into())
85 }
86 _ => Err("Unsupported export format for public key".into()),
87 }
88 }
89
90 pub async fn export_key_private(
91 key: &RsaKeyPair,
92 format: ExportKeyFormat,
93 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
94 match format {
95 ExportKeyFormat::PKCS8 => Ok(key.private_key.to_pkcs8_der()?.as_bytes().to_vec()),
96 ExportKeyFormat::JWK => {
97 Err("JWK format not yet implemented".into())
99 }
100 _ => Err("Unsupported export format for private key".into()),
101 }
102 }
103
104 pub async fn import_key(
105 key_data: &[u8],
106 format: ExportKeyFormat,
107 _for_encryption: bool,
108 hash: Option<HashAlgorithm>,
109 ) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
110 let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
111 match format {
112 ExportKeyFormat::SPKI => {
113 let public_key = RsaPublicKey::from_public_key_der(key_data)?;
114 let private_key =
116 RsaPrivateKey::new(&mut rand::thread_rng(), DEFAULT_RSA_KEY_SIZE)?;
117 Ok(RsaKeyPair {
118 private_key,
119 public_key,
120 hash_algorithm,
121 })
122 }
123 ExportKeyFormat::JWK => {
124 Err("JWK format not yet implemented".into())
126 }
127 _ => Err("Unsupported import format".into()),
128 }
129 }
130
131 pub async fn encrypt(
132 key: &RsaKeyPair,
133 data: impl AsRef<[u8]>,
134 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
135 let mut rng = rand::thread_rng();
136 let padding = key.hash_algorithm.oaep_padding();
137 Ok(key.public_key.encrypt(&mut rng, padding, data.as_ref())?)
138 }
139
140 pub async fn decrypt(
141 key: &RsaKeyPair,
142 encrypted_data: impl AsRef<[u8]>,
143 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
144 let padding = key.hash_algorithm.oaep_padding();
145 Ok(key.private_key.decrypt(padding, encrypted_data.as_ref())?)
146 }
147
148 pub async fn sign(
149 key: &RsaKeyPair,
150 data: impl AsRef<[u8]>,
151 salt_length: Option<usize>,
152 hash: Option<HashAlgorithm>,
153 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
154 let hash_alg = hash.unwrap_or(key.hash_algorithm);
155 let hashed = hash_alg.digest(data.as_ref());
156
157 let mut rng = rand::thread_rng();
158 let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
159 let padding = match hash_alg {
160 HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
161 HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
162 HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
163 };
164 Ok(key.private_key.sign_with_rng(&mut rng, padding, &hashed)?)
165 }
166
167 pub async fn verify(
168 key: &RsaKeyPair,
169 signature: impl AsRef<[u8]>,
170 data: impl AsRef<[u8]>,
171 salt_length: Option<usize>,
172 hash: Option<HashAlgorithm>,
173 ) -> Result<bool, Box<dyn std::error::Error>> {
174 let hash_alg = hash.unwrap_or(key.hash_algorithm);
175 let hashed = hash_alg.digest(data.as_ref());
176
177 let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
178 let padding = match hash_alg {
179 HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
180 HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
181 HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
182 };
183 Ok(key
184 .public_key
185 .verify(padding, &hashed, signature.as_ref())
186 .is_ok())
187 }
188}