1use std::{marker::PhantomData, ptr};
43
44use crate::error::ErrorStack;
45use crate::hash::MessageDigest;
46use crate::pkey::{HasPrivate, HasPublic, PKeyRef};
47use crate::rsa::Padding;
48use crate::{cvt, cvt_p};
49use foreign_types::ForeignTypeRef;
50use openssl_macros::corresponds;
51
52pub struct Encrypter<'a> {
54 pctx: *mut ffi::EVP_PKEY_CTX,
55 _p: PhantomData<&'a ()>,
56}
57
58unsafe impl Sync for Encrypter<'_> {}
59unsafe impl Send for Encrypter<'_> {}
60
61impl Drop for Encrypter<'_> {
62 fn drop(&mut self) {
63 unsafe {
64 ffi::EVP_PKEY_CTX_free(self.pctx);
65 }
66 }
67}
68
69impl<'a> Encrypter<'a> {
70 #[corresponds(EVP_PKEY_encrypt_init)]
72 pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Encrypter<'a>, ErrorStack>
73 where
74 T: HasPublic,
75 {
76 unsafe {
77 ffi::init();
78
79 let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
80 let r = ffi::EVP_PKEY_encrypt_init(pctx);
81 if r != 1 {
82 ffi::EVP_PKEY_CTX_free(pctx);
83 return Err(ErrorStack::get());
84 }
85
86 Ok(Encrypter {
87 pctx,
88 _p: PhantomData,
89 })
90 }
91 }
92
93 pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
99 unsafe {
100 let mut pad = 0;
101 cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
102 .map(|_| Padding::from_raw(pad))
103 }
104 }
105
106 #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
110 pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
111 unsafe {
112 cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
113 self.pctx,
114 padding.as_raw(),
115 ))
116 .map(|_| ())
117 }
118 }
119
120 #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
124 pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
125 unsafe {
126 cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
127 self.pctx,
128 md.as_ptr() as *mut _,
129 ))
130 .map(|_| ())
131 }
132 }
133
134 #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
138 pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
139 unsafe {
140 cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
141 self.pctx,
142 md.as_ptr() as *mut _,
143 ))
144 .map(|_| ())
145 }
146 }
147
148 #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
152 pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
153 unsafe {
154 let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
155 ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
156
157 cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
158 self.pctx,
159 p.cast(),
160 label.len() as _,
161 ))
162 .map(|_| ())
163 .inspect_err(|_e| {
164 ffi::OPENSSL_free(p);
165 })
166 }
167 }
168
169 #[corresponds(EVP_PKEY_encrypt)]
200 pub fn encrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
201 let mut written = to.len();
202 unsafe {
203 cvt(ffi::EVP_PKEY_encrypt(
204 self.pctx,
205 to.as_mut_ptr(),
206 &mut written,
207 from.as_ptr(),
208 from.len(),
209 ))?;
210 }
211
212 Ok(written)
213 }
214
215 #[corresponds(EVP_PKEY_encrypt)]
219 pub fn encrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
220 let mut written = 0;
221 unsafe {
222 cvt(ffi::EVP_PKEY_encrypt(
223 self.pctx,
224 ptr::null_mut(),
225 &mut written,
226 from.as_ptr(),
227 from.len(),
228 ))?;
229 }
230
231 Ok(written)
232 }
233}
234
235pub struct Decrypter<'a> {
237 pctx: *mut ffi::EVP_PKEY_CTX,
238 _p: PhantomData<&'a ()>,
239}
240
241unsafe impl Sync for Decrypter<'_> {}
242unsafe impl Send for Decrypter<'_> {}
243
244impl Drop for Decrypter<'_> {
245 fn drop(&mut self) {
246 unsafe {
247 ffi::EVP_PKEY_CTX_free(self.pctx);
248 }
249 }
250}
251
252impl<'a> Decrypter<'a> {
253 #[corresponds(EVP_PKEY_decrypt_init)]
255 pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Decrypter<'a>, ErrorStack>
256 where
257 T: HasPrivate,
258 {
259 unsafe {
260 ffi::init();
261
262 let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
263 let r = ffi::EVP_PKEY_decrypt_init(pctx);
264 if r != 1 {
265 ffi::EVP_PKEY_CTX_free(pctx);
266 return Err(ErrorStack::get());
267 }
268
269 Ok(Decrypter {
270 pctx,
271 _p: PhantomData,
272 })
273 }
274 }
275
276 pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
282 unsafe {
283 let mut pad = 0;
284 cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
285 .map(|_| Padding::from_raw(pad))
286 }
287 }
288
289 #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
293 pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
294 unsafe {
295 cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
296 self.pctx,
297 padding.as_raw(),
298 ))
299 .map(|_| ())
300 }
301 }
302
303 #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
307 pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
308 unsafe {
309 cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
310 self.pctx,
311 md.as_ptr() as *mut _,
312 ))
313 .map(|_| ())
314 }
315 }
316
317 #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
321 pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
322 unsafe {
323 cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
324 self.pctx,
325 md.as_ptr() as *mut _,
326 ))
327 .map(|_| ())
328 }
329 }
330
331 #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
335 pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
336 unsafe {
337 let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
338 ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
339
340 cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
341 self.pctx,
342 p.cast(),
343 label.len() as _,
344 ))
345 .map(|_| ())
346 .inspect_err(|_e| {
347 ffi::OPENSSL_free(p);
348 })
349 }
350 }
351
352 #[corresponds(EVP_PKEY_decrypt)]
398 pub fn decrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
399 let mut written = to.len();
400 unsafe {
401 cvt(ffi::EVP_PKEY_decrypt(
402 self.pctx,
403 to.as_mut_ptr(),
404 &mut written,
405 from.as_ptr(),
406 from.len(),
407 ))?;
408 }
409
410 Ok(written)
411 }
412
413 #[corresponds(EVP_PKEY_decrypt)]
417 pub fn decrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
418 let mut written = 0;
419 unsafe {
420 cvt(ffi::EVP_PKEY_decrypt(
421 self.pctx,
422 ptr::null_mut(),
423 &mut written,
424 from.as_ptr(),
425 from.len(),
426 ))?;
427 }
428
429 Ok(written)
430 }
431}
432
#[cfg(test)]
mod test {
    use hex::FromHex;

    use crate::encrypt::{Decrypter, Encrypter};
    use crate::hash::MessageDigest;
    use crate::pkey::PKey;
    use crate::rsa::{Padding, Rsa};

    // Hex-encoded plaintext shared by every round-trip test below.
    const INPUT: &str =
        "65794a68624763694f694a53557a49314e694a392e65794a7063334d694f694a71623255694c41304b49434a6c\
         654841694f6a457a4d4441344d546b7a4f44417344516f67496d6830644841364c79396c654746746347786c4c\
         6d4e76625339706331397962323930496a7030636e566c6651";

    /// Encrypts `plaintext` into a buffer sized via `encrypt_len` and keeps only
    /// the bytes that `encrypt` reports as written.
    fn encrypt_all(encrypter: &Encrypter<'_>, plaintext: &[u8]) -> Vec<u8> {
        let capacity = encrypter.encrypt_len(plaintext).unwrap();
        let mut ciphertext = vec![0u8; capacity];
        let produced = encrypter.encrypt(plaintext, &mut ciphertext).unwrap();
        ciphertext.truncate(produced);
        ciphertext
    }

    /// Decrypts `ciphertext` into a buffer sized via `decrypt_len` and keeps only
    /// the bytes that `decrypt` reports as written.
    fn decrypt_all(decrypter: &Decrypter<'_>, ciphertext: &[u8]) -> Vec<u8> {
        let capacity = decrypter.decrypt_len(ciphertext).unwrap();
        let mut plaintext = vec![0u8; capacity];
        let produced = decrypter.decrypt(ciphertext, &mut plaintext).unwrap();
        plaintext.truncate(produced);
        plaintext
    }

    // PKCS#1 padding round trip: decrypting the ciphertext must give back the input.
    #[test]
    fn rsa_encrypt_decrypt() {
        let pkey = Rsa::private_key_from_pem(include_bytes!("../test/rsa.pem"))
            .and_then(PKey::from_rsa)
            .unwrap();
        let plaintext = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let ciphertext = encrypt_all(&encrypter, &plaintext);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let recovered = decrypt_all(&decrypter, &ciphertext);

        assert_eq!(recovered, plaintext);
    }

    // OAEP round trip with SHA-256 configured for both the OAEP digest and MGF1.
    #[test]
    fn rsa_encrypt_decrypt_with_sha256() {
        let pkey = Rsa::private_key_from_pem(include_bytes!("../test/rsa.pem"))
            .and_then(PKey::from_rsa)
            .unwrap();
        let plaintext = Vec::from_hex(INPUT).unwrap();
        let md = MessageDigest::sha256();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_md(md).unwrap();
        encrypter.set_rsa_mgf1_md(md).unwrap();
        let ciphertext = encrypt_all(&encrypter, &plaintext);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_md(md).unwrap();
        decrypter.set_rsa_mgf1_md(md).unwrap();
        let recovered = decrypt_all(&decrypter, &ciphertext);

        assert_eq!(recovered, plaintext);
    }

    // OAEP round trip with a label; decrypting with a different label must fail.
    #[test]
    fn rsa_encrypt_decrypt_oaep_label() {
        let pkey = Rsa::private_key_from_pem(include_bytes!("../test/rsa.pem"))
            .and_then(PKey::from_rsa)
            .unwrap();
        let plaintext = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let ciphertext = encrypt_all(&encrypter, &plaintext);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let recovered = decrypt_all(&decrypter, &ciphertext);

        assert_eq!(recovered, plaintext);

        // Switch to a mismatching label: sizing still succeeds, decryption does not.
        decrypter.set_rsa_oaep_label(b"wrong_oaep_label").unwrap();
        let capacity = decrypter.decrypt_len(&ciphertext).unwrap();
        let mut scratch = vec![0u8; capacity];

        assert!(decrypter.decrypt(&ciphertext, &mut scratch).is_err());
    }
}