1use foreign_types::ForeignTypeRef;
53use std::marker::PhantomData;
54use std::ptr;
55
56use crate::error::ErrorStack;
57use crate::pkey::{HasPrivate, HasPublic, PKeyRef};
58use crate::{cvt, cvt_p};
59use openssl_macros::corresponds;
60
/// A type used to derive a shared secret between two keys.
///
/// Wraps an owned `EVP_PKEY_CTX` pointer (freed in `Drop`). The lifetime `'a`
/// ties the deriver to the private key it was created from and to any peer
/// key passed to `set_peer`/`set_peer_ex`.
pub struct Deriver<'a>(*mut ffi::EVP_PKEY_CTX, PhantomData<&'a ()>);

// SAFETY: every method on `Deriver` takes `&mut self`, so a shared `&Deriver`
// gives no access to the context and cannot race on it.
unsafe impl Sync for Deriver<'_> {}
// SAFETY: the context is exclusively owned by this value (the field is private
// and only set in `Deriver::new`). Moving it to another thread transfers that
// ownership. NOTE(review): this relies on OpenSSL contexts not being bound to
// the creating thread — confirm against the OpenSSL threading documentation.
unsafe impl Send for Deriver<'_> {}
66
67#[allow(clippy::len_without_is_empty)]
68impl<'a> Deriver<'a> {
69 #[corresponds(EVP_PKEY_derive_init)]
71 pub fn new<T>(key: &'a PKeyRef<T>) -> Result<Deriver<'a>, ErrorStack>
72 where
73 T: HasPrivate,
74 {
75 unsafe {
76 cvt_p(ffi::EVP_PKEY_CTX_new(key.as_ptr(), ptr::null_mut()))
77 .map(|p| Deriver(p, PhantomData))
78 .and_then(|ctx| cvt(ffi::EVP_PKEY_derive_init(ctx.0)).map(|_| ctx))
79 }
80 }
81
82 #[corresponds(EVP_PKEY_derive_set_peer)]
84 pub fn set_peer<T>(&mut self, key: &'a PKeyRef<T>) -> Result<(), ErrorStack>
85 where
86 T: HasPublic,
87 {
88 unsafe { cvt(ffi::EVP_PKEY_derive_set_peer(self.0, key.as_ptr())).map(|_| ()) }
89 }
90
91 #[corresponds(EVP_PKEY_derive_set_peer_ex)]
95 #[cfg(ossl300)]
96 pub fn set_peer_ex<T>(
97 &mut self,
98 key: &'a PKeyRef<T>,
99 validate_peer: bool,
100 ) -> Result<(), ErrorStack>
101 where
102 T: HasPublic,
103 {
104 unsafe {
105 cvt(ffi::EVP_PKEY_derive_set_peer_ex(
106 self.0,
107 key.as_ptr(),
108 validate_peer as i32,
109 ))
110 .map(|_| ())
111 }
112 }
113
114 #[corresponds(EVP_PKEY_derive)]
122 pub fn len(&mut self) -> Result<usize, ErrorStack> {
123 unsafe {
124 let mut len = 0;
125 cvt(ffi::EVP_PKEY_derive(self.0, ptr::null_mut(), &mut len)).map(|_| len)
126 }
127 }
128
129 #[corresponds(EVP_PKEY_derive)]
133 pub fn derive(&mut self, buf: &mut [u8]) -> Result<usize, ErrorStack> {
134 let mut len = buf.len();
135 unsafe {
136 cvt(ffi::EVP_PKEY_derive(
137 self.0,
138 buf.as_mut_ptr() as *mut _,
139 &mut len,
140 ))
141 .map(|_| len)
142 }
143 }
144
145 pub fn derive_to_vec(&mut self) -> Result<Vec<u8>, ErrorStack> {
152 let len = self.len()?;
153 let mut buf = vec![0; len];
154 let len = self.derive(&mut buf)?;
155 buf.truncate(len);
156 Ok(buf)
157 }
158}
159
160impl Drop for Deriver<'_> {
161 fn drop(&mut self) {
162 unsafe {
163 ffi::EVP_PKEY_CTX_free(self.0);
164 }
165 }
166}
167
#[cfg(test)]
mod test {
    use super::*;

    use crate::ec::{EcGroup, EcKey};
    use crate::nid::Nid;
    use crate::pkey::{PKey, Private};

    /// Generates a fresh P-256 (prime256v1) EC key pair wrapped in a `PKey`.
    fn p256_key() -> PKey<Private> {
        let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let ec_key = EcKey::generate(&group).unwrap();
        PKey::from_ec_key(ec_key).unwrap()
    }

    #[test]
    fn derive_without_peer() {
        let pkey = p256_key();
        let mut deriver = Deriver::new(&pkey).unwrap();
        // No peer was set, so derivation must fail instead of producing output.
        deriver.derive_to_vec().unwrap_err();
    }

    #[test]
    fn test_ec_key_derive() {
        let pkey = p256_key();
        let pkey2 = p256_key();

        let mut deriver = Deriver::new(&pkey).unwrap();
        deriver.set_peer(&pkey2).unwrap();
        let shared = deriver.derive_to_vec().unwrap();
        assert!(!shared.is_empty());

        // ECDH is symmetric: the other party must arrive at the same secret.
        let mut deriver2 = Deriver::new(&pkey2).unwrap();
        deriver2.set_peer(&pkey).unwrap();
        assert_eq!(shared, deriver2.derive_to_vec().unwrap());
    }

    #[test]
    #[cfg(ossl300)]
    fn test_ec_key_derive_ex() {
        let pkey = p256_key();
        let pkey2 = p256_key();

        let mut deriver = Deriver::new(&pkey).unwrap();
        deriver.set_peer_ex(&pkey2, true).unwrap();
        let shared = deriver.derive_to_vec().unwrap();
        assert!(!shared.is_empty());

        // ECDH is symmetric: the other party must arrive at the same secret.
        let mut deriver2 = Deriver::new(&pkey2).unwrap();
        deriver2.set_peer_ex(&pkey, true).unwrap();
        assert_eq!(shared, deriver2.derive_to_vec().unwrap());
    }
}