1#[cfg(any(ossl102, libressl))]
43use libc::c_int;
44use std::{marker::PhantomData, ptr};
45
46use crate::error::ErrorStack;
47use crate::hash::MessageDigest;
48use crate::pkey::{HasPrivate, HasPublic, PKeyRef};
49use crate::rsa::Padding;
50use crate::{cvt, cvt_p};
51use foreign_types::ForeignTypeRef;
52use openssl_macros::corresponds;
53
54pub struct Encrypter<'a> {
56 pctx: *mut ffi::EVP_PKEY_CTX,
57 _p: PhantomData<&'a ()>,
58}
59
60unsafe impl Sync for Encrypter<'_> {}
61unsafe impl Send for Encrypter<'_> {}
62
63impl Drop for Encrypter<'_> {
64 fn drop(&mut self) {
65 unsafe {
66 ffi::EVP_PKEY_CTX_free(self.pctx);
67 }
68 }
69}
70
71impl<'a> Encrypter<'a> {
72 #[corresponds(EVP_PKEY_encrypt_init)]
74 pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Encrypter<'a>, ErrorStack>
75 where
76 T: HasPublic,
77 {
78 unsafe {
79 ffi::init();
80
81 let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
82 let r = ffi::EVP_PKEY_encrypt_init(pctx);
83 if r != 1 {
84 ffi::EVP_PKEY_CTX_free(pctx);
85 return Err(ErrorStack::get());
86 }
87
88 Ok(Encrypter {
89 pctx,
90 _p: PhantomData,
91 })
92 }
93 }
94
95 pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
101 unsafe {
102 let mut pad = 0;
103 cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
104 .map(|_| Padding::from_raw(pad))
105 }
106 }
107
108 #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
112 pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
113 unsafe {
114 cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
115 self.pctx,
116 padding.as_raw(),
117 ))
118 .map(|_| ())
119 }
120 }
121
122 #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
126 pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
127 unsafe {
128 cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
129 self.pctx,
130 md.as_ptr() as *mut _,
131 ))
132 .map(|_| ())
133 }
134 }
135
136 #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
140 pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
141 unsafe {
142 cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
143 self.pctx,
144 md.as_ptr() as *mut _,
145 ))
146 .map(|_| ())
147 }
148 }
149
150 #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
154 #[cfg(any(ossl102, libressl))]
155 pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
156 unsafe {
157 let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
158 ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
159
160 cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
161 self.pctx,
162 p,
163 label.len() as c_int,
164 ))
165 .map(|_| ())
166 .map_err(|e| {
167 ffi::OPENSSL_free(p);
168 e
169 })
170 }
171 }
172
173 #[corresponds(EVP_PKEY_encrypt)]
204 pub fn encrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
205 let mut written = to.len();
206 unsafe {
207 cvt(ffi::EVP_PKEY_encrypt(
208 self.pctx,
209 to.as_mut_ptr(),
210 &mut written,
211 from.as_ptr(),
212 from.len(),
213 ))?;
214 }
215
216 Ok(written)
217 }
218
219 #[corresponds(EVP_PKEY_encrypt)]
223 pub fn encrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
224 let mut written = 0;
225 unsafe {
226 cvt(ffi::EVP_PKEY_encrypt(
227 self.pctx,
228 ptr::null_mut(),
229 &mut written,
230 from.as_ptr(),
231 from.len(),
232 ))?;
233 }
234
235 Ok(written)
236 }
237}
238
239pub struct Decrypter<'a> {
241 pctx: *mut ffi::EVP_PKEY_CTX,
242 _p: PhantomData<&'a ()>,
243}
244
245unsafe impl Sync for Decrypter<'_> {}
246unsafe impl Send for Decrypter<'_> {}
247
248impl Drop for Decrypter<'_> {
249 fn drop(&mut self) {
250 unsafe {
251 ffi::EVP_PKEY_CTX_free(self.pctx);
252 }
253 }
254}
255
256impl<'a> Decrypter<'a> {
257 #[corresponds(EVP_PKEY_decrypt_init)]
259 pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Decrypter<'a>, ErrorStack>
260 where
261 T: HasPrivate,
262 {
263 unsafe {
264 ffi::init();
265
266 let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
267 let r = ffi::EVP_PKEY_decrypt_init(pctx);
268 if r != 1 {
269 ffi::EVP_PKEY_CTX_free(pctx);
270 return Err(ErrorStack::get());
271 }
272
273 Ok(Decrypter {
274 pctx,
275 _p: PhantomData,
276 })
277 }
278 }
279
280 pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
286 unsafe {
287 let mut pad = 0;
288 cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
289 .map(|_| Padding::from_raw(pad))
290 }
291 }
292
293 #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
297 pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
298 unsafe {
299 cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
300 self.pctx,
301 padding.as_raw(),
302 ))
303 .map(|_| ())
304 }
305 }
306
307 #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
311 pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
312 unsafe {
313 cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
314 self.pctx,
315 md.as_ptr() as *mut _,
316 ))
317 .map(|_| ())
318 }
319 }
320
321 #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
325 pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
326 unsafe {
327 cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
328 self.pctx,
329 md.as_ptr() as *mut _,
330 ))
331 .map(|_| ())
332 }
333 }
334
335 #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
339 #[cfg(any(ossl102, libressl))]
340 pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
341 unsafe {
342 let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
343 ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
344
345 cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
346 self.pctx,
347 p,
348 label.len() as c_int,
349 ))
350 .map(|_| ())
351 .map_err(|e| {
352 ffi::OPENSSL_free(p);
353 e
354 })
355 }
356 }
357
358 #[corresponds(EVP_PKEY_decrypt)]
404 pub fn decrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
405 let mut written = to.len();
406 unsafe {
407 cvt(ffi::EVP_PKEY_decrypt(
408 self.pctx,
409 to.as_mut_ptr(),
410 &mut written,
411 from.as_ptr(),
412 from.len(),
413 ))?;
414 }
415
416 Ok(written)
417 }
418
419 #[corresponds(EVP_PKEY_decrypt)]
423 pub fn decrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
424 let mut written = 0;
425 unsafe {
426 cvt(ffi::EVP_PKEY_decrypt(
427 self.pctx,
428 ptr::null_mut(),
429 &mut written,
430 from.as_ptr(),
431 from.len(),
432 ))?;
433 }
434
435 Ok(written)
436 }
437}
438
#[cfg(test)]
mod test {
    use hex::FromHex;

    use crate::encrypt::{Decrypter, Encrypter};
    use crate::hash::MessageDigest;
    use crate::pkey::{PKey, Private};
    use crate::rsa::{Padding, Rsa};

    const INPUT: &str =
        "65794a68624763694f694a53557a49314e694a392e65794a7063334d694f694a71623255694c41304b49434a6c\
         654841694f6a457a4d4441344d546b7a4f44417344516f67496d6830644841364c79396c654746746347786c4c\
         6d4e76625339706331397962323930496a7030636e566c6651";

    /// Loads the RSA test key shared by every test in this module.
    fn test_pkey() -> PKey<Private> {
        let pem = include_bytes!("../test/rsa.pem");
        let rsa = Rsa::private_key_from_pem(pem).unwrap();
        PKey::from_rsa(rsa).unwrap()
    }

    /// Encrypts `plaintext`, sizing the buffer via `encrypt_len` and trimming it
    /// to the bytes actually written.
    fn encrypt_all(encrypter: &Encrypter<'_>, plaintext: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; encrypter.encrypt_len(plaintext).unwrap()];
        let written = encrypter.encrypt(plaintext, &mut out).unwrap();
        out.truncate(written);
        out
    }

    /// Decrypts `ciphertext`, sizing the buffer via `decrypt_len` and trimming
    /// it to the bytes actually written.
    fn decrypt_all(decrypter: &Decrypter<'_>, ciphertext: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; decrypter.decrypt_len(ciphertext).unwrap()];
        let written = decrypter.decrypt(ciphertext, &mut out).unwrap();
        out.truncate(written);
        out
    }

    #[test]
    fn rsa_encrypt_decrypt() {
        let pkey = test_pkey();
        let input = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let encoded = encrypt_all(&encrypter, &input);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let decoded = decrypt_all(&decrypter, &encoded);

        assert_eq!(decoded, input);
    }

    #[test]
    fn rsa_encrypt_decrypt_with_sha256() {
        let pkey = test_pkey();
        let input = Vec::from_hex(INPUT).unwrap();
        let md = MessageDigest::sha256();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_md(md).unwrap();
        encrypter.set_rsa_mgf1_md(md).unwrap();
        let encoded = encrypt_all(&encrypter, &input);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_md(md).unwrap();
        decrypter.set_rsa_mgf1_md(md).unwrap();
        let decoded = decrypt_all(&decrypter, &encoded);

        assert_eq!(decoded, input);
    }

    #[test]
    #[cfg(any(ossl102, libressl))]
    fn rsa_encrypt_decrypt_oaep_label() {
        let pkey = test_pkey();
        let input = Vec::from_hex(INPUT).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let encoded = encrypt_all(&encrypter, &input);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let decoded = decrypt_all(&decrypter, &encoded);

        assert_eq!(decoded, input);

        // A mismatched label must make decryption fail rather than return garbage.
        decrypter.set_rsa_oaep_label(b"wrong_oaep_label").unwrap();
        let mut scratch = vec![0u8; decrypter.decrypt_len(&encoded).unwrap()];

        assert!(decrypter.decrypt(&encoded, &mut scratch).is_err());
    }
}
539}