variant_ssl/
encrypt.rs

//! Message encryption.
//!
//! The [`Encrypter`] allows for encryption of data given a public key. The [`Decrypter`] can be
//! used with the corresponding private key to decrypt the data.
//!
//! # Examples
//!
//! Encrypt and decrypt data given an RSA keypair:
//!
//! ```rust
//! use openssl::encrypt::{Encrypter, Decrypter};
//! use openssl::rsa::{Rsa, Padding};
//! use openssl::pkey::PKey;
//!
//! // Generate a keypair
//! let keypair = Rsa::generate(2048).unwrap();
//! let keypair = PKey::from_rsa(keypair).unwrap();
//!
//! let data = b"hello, world!";
//!
//! // Encrypt the data with RSA PKCS1
//! let mut encrypter = Encrypter::new(&keypair).unwrap();
//! encrypter.set_rsa_padding(Padding::PKCS1).unwrap();
//! // Create an output buffer
//! let buffer_len = encrypter.encrypt_len(data).unwrap();
//! let mut encrypted = vec![0; buffer_len];
//! // Encrypt and truncate the buffer
//! let encrypted_len = encrypter.encrypt(data, &mut encrypted).unwrap();
//! encrypted.truncate(encrypted_len);
//!
//! // Decrypt the data
//! let mut decrypter = Decrypter::new(&keypair).unwrap();
//! decrypter.set_rsa_padding(Padding::PKCS1).unwrap();
//! // Create an output buffer
//! let buffer_len = decrypter.decrypt_len(&encrypted).unwrap();
//! let mut decrypted = vec![0; buffer_len];
//! // Decrypt and truncate the buffer
//! let decrypted_len = decrypter.decrypt(&encrypted, &mut decrypted).unwrap();
//! decrypted.truncate(decrypted_len);
//! assert_eq!(&*decrypted, data);
//! ```
42#[cfg(any(ossl102, libressl))]
43use libc::c_int;
44use std::{marker::PhantomData, ptr};
45
46use crate::error::ErrorStack;
47use crate::hash::MessageDigest;
48use crate::pkey::{HasPrivate, HasPublic, PKeyRef};
49use crate::rsa::Padding;
50use crate::{cvt, cvt_p};
51use foreign_types::ForeignTypeRef;
52use openssl_macros::corresponds;
53
54/// A type which encrypts data.
55pub struct Encrypter<'a> {
56    pctx: *mut ffi::EVP_PKEY_CTX,
57    _p: PhantomData<&'a ()>,
58}
59
60unsafe impl Sync for Encrypter<'_> {}
61unsafe impl Send for Encrypter<'_> {}
62
63impl Drop for Encrypter<'_> {
64    fn drop(&mut self) {
65        unsafe {
66            ffi::EVP_PKEY_CTX_free(self.pctx);
67        }
68    }
69}
70
71impl<'a> Encrypter<'a> {
72    /// Creates a new `Encrypter`.
73    #[corresponds(EVP_PKEY_encrypt_init)]
74    pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Encrypter<'a>, ErrorStack>
75    where
76        T: HasPublic,
77    {
78        unsafe {
79            ffi::init();
80
81            let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
82            let r = ffi::EVP_PKEY_encrypt_init(pctx);
83            if r != 1 {
84                ffi::EVP_PKEY_CTX_free(pctx);
85                return Err(ErrorStack::get());
86            }
87
88            Ok(Encrypter {
89                pctx,
90                _p: PhantomData,
91            })
92        }
93    }
94
95    /// Returns the RSA padding mode in use.
96    ///
97    /// This is only useful for RSA keys.
98    ///
99    /// This corresponds to `EVP_PKEY_CTX_get_rsa_padding`.
100    pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
101        unsafe {
102            let mut pad = 0;
103            cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
104                .map(|_| Padding::from_raw(pad))
105        }
106    }
107
108    /// Sets the RSA padding mode.
109    ///
110    /// This is only useful for RSA keys.
111    #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
112    pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
113        unsafe {
114            cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
115                self.pctx,
116                padding.as_raw(),
117            ))
118            .map(|_| ())
119        }
120    }
121
122    /// Sets the RSA MGF1 algorithm.
123    ///
124    /// This is only useful for RSA keys.
125    #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
126    pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
127        unsafe {
128            cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
129                self.pctx,
130                md.as_ptr() as *mut _,
131            ))
132            .map(|_| ())
133        }
134    }
135
136    /// Sets the RSA OAEP algorithm.
137    ///
138    /// This is only useful for RSA keys.
139    #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
140    pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
141        unsafe {
142            cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
143                self.pctx,
144                md.as_ptr() as *mut _,
145            ))
146            .map(|_| ())
147        }
148    }
149
150    /// Sets the RSA OAEP label.
151    ///
152    /// This is only useful for RSA keys.
153    #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
154    #[cfg(any(ossl102, libressl))]
155    pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
156        unsafe {
157            let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
158            ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
159
160            cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
161                self.pctx,
162                p,
163                label.len() as c_int,
164            ))
165            .map(|_| ())
166            .map_err(|e| {
167                ffi::OPENSSL_free(p);
168                e
169            })
170        }
171    }
172
173    /// Performs public key encryption.
174    ///
175    /// In order to know the size needed for the output buffer, use [`encrypt_len`](Encrypter::encrypt_len).
176    /// Note that the length of the output buffer can be greater of the length of the encoded data.
177    /// ```
178    /// # use openssl::{
179    /// #   encrypt::Encrypter,
180    /// #   pkey::PKey,
181    /// #   rsa::{Rsa, Padding},
182    /// # };
183    /// #
184    /// # let key = include_bytes!("../test/rsa.pem");
185    /// # let private_key = Rsa::private_key_from_pem(key).unwrap();
186    /// # let pkey = PKey::from_rsa(private_key).unwrap();
187    /// # let input = b"hello world".to_vec();
188    /// #
189    /// let mut encrypter = Encrypter::new(&pkey).unwrap();
190    /// encrypter.set_rsa_padding(Padding::PKCS1).unwrap();
191    ///
192    /// // Get the length of the output buffer
193    /// let buffer_len = encrypter.encrypt_len(&input).unwrap();
194    /// let mut encoded = vec![0u8; buffer_len];
195    ///
196    /// // Encode the data and get its length
197    /// let encoded_len = encrypter.encrypt(&input, &mut encoded).unwrap();
198    ///
199    /// // Use only the part of the buffer with the encoded data
200    /// let encoded = &encoded[..encoded_len];
201    /// ```
202    ///
203    #[corresponds(EVP_PKEY_encrypt)]
204    pub fn encrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
205        let mut written = to.len();
206        unsafe {
207            cvt(ffi::EVP_PKEY_encrypt(
208                self.pctx,
209                to.as_mut_ptr(),
210                &mut written,
211                from.as_ptr(),
212                from.len(),
213            ))?;
214        }
215
216        Ok(written)
217    }
218
219    /// Gets the size of the buffer needed to encrypt the input data.
220    ///
221    /// This corresponds to `EVP_PKEY_encrypt` called with a null pointer as output argument.
222    #[corresponds(EVP_PKEY_encrypt)]
223    pub fn encrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
224        let mut written = 0;
225        unsafe {
226            cvt(ffi::EVP_PKEY_encrypt(
227                self.pctx,
228                ptr::null_mut(),
229                &mut written,
230                from.as_ptr(),
231                from.len(),
232            ))?;
233        }
234
235        Ok(written)
236    }
237}
238
239/// A type which decrypts data.
240pub struct Decrypter<'a> {
241    pctx: *mut ffi::EVP_PKEY_CTX,
242    _p: PhantomData<&'a ()>,
243}
244
245unsafe impl Sync for Decrypter<'_> {}
246unsafe impl Send for Decrypter<'_> {}
247
248impl Drop for Decrypter<'_> {
249    fn drop(&mut self) {
250        unsafe {
251            ffi::EVP_PKEY_CTX_free(self.pctx);
252        }
253    }
254}
255
256impl<'a> Decrypter<'a> {
257    /// Creates a new `Decrypter`.
258    #[corresponds(EVP_PKEY_decrypt_init)]
259    pub fn new<T>(pkey: &'a PKeyRef<T>) -> Result<Decrypter<'a>, ErrorStack>
260    where
261        T: HasPrivate,
262    {
263        unsafe {
264            ffi::init();
265
266            let pctx = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
267            let r = ffi::EVP_PKEY_decrypt_init(pctx);
268            if r != 1 {
269                ffi::EVP_PKEY_CTX_free(pctx);
270                return Err(ErrorStack::get());
271            }
272
273            Ok(Decrypter {
274                pctx,
275                _p: PhantomData,
276            })
277        }
278    }
279
280    /// Returns the RSA padding mode in use.
281    ///
282    /// This is only useful for RSA keys.
283    ///
284    /// This corresponds to `EVP_PKEY_CTX_get_rsa_padding`.
285    pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
286        unsafe {
287            let mut pad = 0;
288            cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.pctx, &mut pad))
289                .map(|_| Padding::from_raw(pad))
290        }
291    }
292
293    /// Sets the RSA padding mode.
294    ///
295    /// This is only useful for RSA keys.
296    #[corresponds(EVP_PKEY_CTX_set_rsa_padding)]
297    pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
298        unsafe {
299            cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
300                self.pctx,
301                padding.as_raw(),
302            ))
303            .map(|_| ())
304        }
305    }
306
307    /// Sets the RSA MGF1 algorithm.
308    ///
309    /// This is only useful for RSA keys.
310    #[corresponds(EVP_PKEY_CTX_set_rsa_mgf1_md)]
311    pub fn set_rsa_mgf1_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
312        unsafe {
313            cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
314                self.pctx,
315                md.as_ptr() as *mut _,
316            ))
317            .map(|_| ())
318        }
319    }
320
321    /// Sets the RSA OAEP algorithm.
322    ///
323    /// This is only useful for RSA keys.
324    #[corresponds(EVP_PKEY_CTX_set_rsa_oaep_md)]
325    pub fn set_rsa_oaep_md(&mut self, md: MessageDigest) -> Result<(), ErrorStack> {
326        unsafe {
327            cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
328                self.pctx,
329                md.as_ptr() as *mut _,
330            ))
331            .map(|_| ())
332        }
333    }
334
335    /// Sets the RSA OAEP label.
336    ///
337    /// This is only useful for RSA keys.
338    #[corresponds(EVP_PKEY_CTX_set0_rsa_oaep_label)]
339    #[cfg(any(ossl102, libressl))]
340    pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
341        unsafe {
342            let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
343            ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());
344
345            cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
346                self.pctx,
347                p,
348                label.len() as c_int,
349            ))
350            .map(|_| ())
351            .map_err(|e| {
352                ffi::OPENSSL_free(p);
353                e
354            })
355        }
356    }
357
358    /// Performs public key decryption.
359    ///
360    /// In order to know the size needed for the output buffer, use [`decrypt_len`](Decrypter::decrypt_len).
361    /// Note that the length of the output buffer can be greater of the length of the decoded data.
362    /// ```
363    /// # use openssl::{
364    /// #   encrypt::Decrypter,
365    /// #   pkey::PKey,
366    /// #   rsa::{Rsa, Padding},
367    /// # };
368    /// #
369    /// # const INPUT: &[u8] = b"\
370    /// #     \x26\xa1\xc1\x13\xc5\x7f\xb4\x9f\xa0\xb4\xde\x61\x5e\x2e\xc6\xfb\x76\x5c\xd1\x2b\x5f\
371    /// #     \x1d\x36\x60\xfa\xf8\xe8\xb3\x21\xf4\x9c\x70\xbc\x03\xea\xea\xac\xce\x4b\xb3\xf6\x45\
372    /// #     \xcc\xb3\x80\x9e\xa8\xf7\xc3\x5d\x06\x12\x7a\xa3\x0c\x30\x67\xf1\xe7\x94\x6c\xf6\x26\
373    /// #     \xac\x28\x17\x59\x69\xe1\xdc\xed\x7e\xc0\xe9\x62\x57\x49\xce\xdd\x13\x07\xde\x18\x03\
374    /// #     \x0f\x9d\x61\x65\xb9\x23\x8c\x78\x4b\xad\x23\x49\x75\x47\x64\xa0\xa0\xa2\x90\xc1\x49\
375    /// #     \x1b\x05\x24\xc2\xe9\x2c\x0d\x49\x78\x72\x61\x72\xed\x8b\x6f\x8a\xe8\xca\x05\x5c\x58\
376    /// #     \xd6\x95\xd6\x7b\xe3\x2d\x0d\xaa\x3e\x6d\x3c\x9a\x1c\x1d\xb4\x6c\x42\x9d\x9a\x82\x55\
377    /// #     \xd9\xde\xc8\x08\x7b\x17\xac\xd7\xaf\x86\x7b\x69\x9e\x3c\xf4\x5e\x1c\x39\x52\x6d\x62\
378    /// #     \x50\x51\xbd\xa6\xc8\x4e\xe9\x34\xf0\x37\x0d\xa9\xa9\x77\xe6\xf5\xc2\x47\x2d\xa8\xee\
379    /// #     \x3f\x69\x78\xff\xa9\xdc\x70\x22\x20\x9a\x5c\x9b\x70\x15\x90\xd3\xb4\x0e\x54\x9e\x48\
380    /// #     \xed\xb6\x2c\x88\xfc\xb4\xa9\x37\x10\xfa\x71\xb2\xec\x75\xe7\xe7\x0e\xf4\x60\x2c\x7b\
381    /// #     \x58\xaf\xa0\x53\xbd\x24\xf1\x12\xe3\x2e\x99\x25\x0a\x54\x54\x9d\xa1\xdb\xca\x41\x85\
382    /// #     \xf4\x62\x78\x64";
383    /// #
384    /// # let key = include_bytes!("../test/rsa.pem");
385    /// # let private_key = Rsa::private_key_from_pem(key).unwrap();
386    /// # let pkey = PKey::from_rsa(private_key).unwrap();
387    /// # let input = INPUT.to_vec();
388    /// #
389    /// let mut decrypter = Decrypter::new(&pkey).unwrap();
390    /// decrypter.set_rsa_padding(Padding::PKCS1).unwrap();
391    ///
392    /// // Get the length of the output buffer
393    /// let buffer_len = decrypter.decrypt_len(&input).unwrap();
394    /// let mut decoded = vec![0u8; buffer_len];
395    ///
396    /// // Decrypt the data and get its length
397    /// let decoded_len = decrypter.decrypt(&input, &mut decoded).unwrap();
398    ///
399    /// // Use only the part of the buffer with the decrypted data
400    /// let decoded = &decoded[..decoded_len];
401    /// ```
402    ///
403    #[corresponds(EVP_PKEY_decrypt)]
404    pub fn decrypt(&self, from: &[u8], to: &mut [u8]) -> Result<usize, ErrorStack> {
405        let mut written = to.len();
406        unsafe {
407            cvt(ffi::EVP_PKEY_decrypt(
408                self.pctx,
409                to.as_mut_ptr(),
410                &mut written,
411                from.as_ptr(),
412                from.len(),
413            ))?;
414        }
415
416        Ok(written)
417    }
418
419    /// Gets the size of the buffer needed to decrypt the input data.
420    ///
421    /// This corresponds to `EVP_PKEY_decrypt` called with a null pointer as output argument.
422    #[corresponds(EVP_PKEY_decrypt)]
423    pub fn decrypt_len(&self, from: &[u8]) -> Result<usize, ErrorStack> {
424        let mut written = 0;
425        unsafe {
426            cvt(ffi::EVP_PKEY_decrypt(
427                self.pctx,
428                ptr::null_mut(),
429                &mut written,
430                from.as_ptr(),
431                from.len(),
432            ))?;
433        }
434
435        Ok(written)
436    }
437}
438
#[cfg(test)]
mod test {
    use hex::FromHex;

    use crate::encrypt::{Decrypter, Encrypter};
    use crate::hash::MessageDigest;
    use crate::pkey::{PKey, Private};
    use crate::rsa::{Padding, Rsa};

    const INPUT: &str =
        "65794a68624763694f694a53557a49314e694a392e65794a7063334d694f694a71623255694c41304b49434a6c\
         654841694f6a457a4d4441344d546b7a4f44417344516f67496d6830644841364c79396c654746746347786c4c\
         6d4e76625339706331397962323930496a7030636e566c6651";

    /// Loads the RSA key pair from `test/rsa.pem` that every test in this module uses.
    fn load_pkey() -> PKey<Private> {
        let pem = include_bytes!("../test/rsa.pem");
        let rsa = Rsa::private_key_from_pem(pem).unwrap();
        PKey::from_rsa(rsa).unwrap()
    }

    /// Decodes the hex-encoded `INPUT` fixture into plaintext bytes.
    fn plaintext() -> Vec<u8> {
        Vec::from_hex(INPUT).unwrap()
    }

    /// Encrypts `data`, sizing the buffer via `encrypt_len` and trimming it to the bytes written.
    fn encrypt_all(encrypter: &Encrypter<'_>, data: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; encrypter.encrypt_len(data).unwrap()];
        let written = encrypter.encrypt(data, &mut out).unwrap();
        out.truncate(written);
        out
    }

    /// Decrypts `data`, sizing the buffer via `decrypt_len` and trimming it to the bytes written.
    fn decrypt_all(decrypter: &Decrypter<'_>, data: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; decrypter.decrypt_len(data).unwrap()];
        let written = decrypter.decrypt(data, &mut out).unwrap();
        out.truncate(written);
        out
    }

    #[test]
    fn rsa_encrypt_decrypt() {
        let pkey = load_pkey();
        let input = plaintext();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let encoded = encrypt_all(&encrypter, &input);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1).unwrap();
        let decoded = decrypt_all(&decrypter, &encoded);

        assert_eq!(decoded, input);
    }

    #[test]
    fn rsa_encrypt_decrypt_with_sha256() {
        let pkey = load_pkey();
        let input = plaintext();
        let md = MessageDigest::sha256();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_md(md).unwrap();
        encrypter.set_rsa_mgf1_md(md).unwrap();
        let encoded = encrypt_all(&encrypter, &input);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_md(md).unwrap();
        decrypter.set_rsa_mgf1_md(md).unwrap();
        let decoded = decrypt_all(&decrypter, &encoded);

        assert_eq!(decoded, input);
    }

    #[test]
    #[cfg(any(ossl102, libressl))]
    fn rsa_encrypt_decrypt_oaep_label() {
        let pkey = load_pkey();
        let input = plaintext();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let encoded = encrypt_all(&encrypter, &input);

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        assert_eq!(decrypt_all(&decrypter, &encoded), input);

        // A label that differs from the one used for encryption must make decryption fail.
        decrypter.set_rsa_oaep_label(b"wrong_oaep_label").unwrap();
        let mut scratch = vec![0u8; decrypter.decrypt_len(&encoded).unwrap()];
        assert!(decrypter.decrypt(&encoded, &mut scratch).is_err());
    }
}
539}