skf_rs/engine/
crypto.rs

1use crate::engine::symbol::{crypto_fn, ModBlockCipher, SymbolBundle};
2use crate::error::{InvalidArgumentError, SkfErr};
3use crate::{BlockCipherParameter, Error, ManagedKey, Result, SkfBlockCipher};
4use skf_api::native::error::SAR_OK;
5use skf_api::native::types::{BlockCipherParam, BYTE, HANDLE, MAX_IV_LEN, ULONG};
6use std::fmt::{Debug, Formatter};
7use std::sync::Arc;
8use tracing::{instrument, trace};
9
10pub(crate) struct ManagedKeyImpl {
11    close_fn: crypto_fn::SKF_CloseHandle,
12    handle: HANDLE,
13}
14impl Debug for ManagedKeyImpl {
15    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
16        write!(f, "ManagedKeyImpl")
17    }
18}
19impl Drop for ManagedKeyImpl {
20    fn drop(&mut self) {
21        let _ = self.close();
22    }
23}
24
25impl AsRef<HANDLE> for ManagedKeyImpl {
26    fn as_ref(&self) -> &HANDLE {
27        &self.handle
28    }
29}
30
31impl ManagedKey for ManagedKeyImpl {}
32impl Debug for dyn ManagedKey {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        write!(f, "Handle:{:p}", self.as_ref())
35    }
36}
37impl ManagedKeyImpl {
38    pub(crate) fn try_new(handle: HANDLE, lib: &Arc<libloading::Library>) -> Result<Self> {
39        let close_fn = unsafe { SymbolBundle::new(lib, b"SKF_CloseHandle\0")? };
40        Ok(Self { close_fn, handle })
41    }
42
43    #[instrument]
44    pub(crate) fn close(&mut self) -> Result<()> {
45        let ret = unsafe { (self.close_fn)(self.handle) };
46        trace!("[SKF_CloseHandle]: ret = {}", ret);
47        if ret != SAR_OK {
48            return Err(Error::Skf(SkfErr::of_code(ret)));
49        }
50        self.handle = std::ptr::null();
51        Ok(())
52    }
53}
54
55pub(crate) struct SkfBlockCipherImpl {
56    symbols: ModBlockCipher,
57}
58impl Debug for SkfBlockCipherImpl {
59    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
60        write!(f, "SkfBlockCipherImpl")
61    }
62}
63impl SkfBlockCipherImpl {
64    pub fn new(lib: &Arc<libloading::Library>) -> Result<Self> {
65        let symbols = ModBlockCipher::load_symbols(lib)?;
66        Ok(Self { symbols })
67    }
68}
69
70impl SkfBlockCipher for SkfBlockCipherImpl {
71    #[instrument(skip(key))]
72    fn encrypt_init(&self, key: &dyn ManagedKey, param: &BlockCipherParameter) -> Result<()> {
73        let func = self.symbols.encrypt_init.as_ref().expect("Symbol not load");
74        let param = make_cipher_param(param)?;
75        let ret = unsafe { func(*key.as_ref(), param) };
76        trace!("[SKF_EncryptInit]: ret = {}", ret);
77        if ret != SAR_OK {
78            return Err(Error::Skf(SkfErr::of_code(ret)));
79        }
80        Ok(())
81    }
82
83    #[instrument(skip(key, data))]
84    fn encrypt(&self, key: &dyn ManagedKey, data: &[u8], buffer_size: usize) -> Result<Vec<u8>> {
85        let func = self.symbols.encrypt.as_ref().expect("Symbol not load");
86        let mut len = buffer_size as ULONG;
87        let mut buffer = Vec::<u8>::with_capacity(buffer_size);
88        let ret = unsafe {
89            func(
90                *key.as_ref(),
91                data.as_ptr() as *const BYTE,
92                data.len() as ULONG,
93                buffer.as_mut_ptr() as *mut BYTE,
94                &mut len,
95            )
96        };
97        trace!("[SKF_Encrypt]: ret = {}", ret);
98        if ret != SAR_OK {
99            return Err(Error::Skf(SkfErr::of_code(ret)));
100        }
101        trace!("[SKF_Encrypt]: output len = {}", len);
102        unsafe { buffer.set_len(len as usize) };
103        Ok(buffer)
104    }
105
106    #[instrument(skip(key, data))]
107    fn encrypt_update(
108        &self,
109        key: &dyn ManagedKey,
110        data: &[u8],
111        buffer_size: usize,
112    ) -> Result<Vec<u8>> {
113        let func = self
114            .symbols
115            .encrypt_update
116            .as_ref()
117            .expect("Symbol not load");
118        let mut len = buffer_size as ULONG;
119        let mut buffer = Vec::<u8>::with_capacity(buffer_size);
120        let ret = unsafe {
121            func(
122                *key.as_ref(),
123                data.as_ptr() as *const BYTE,
124                data.len() as ULONG,
125                buffer.as_mut_ptr() as *mut BYTE,
126                &mut len,
127            )
128        };
129        trace!("[SKF_EncryptUpdate]: ret = {}", ret);
130        if ret != SAR_OK {
131            return Err(Error::Skf(SkfErr::of_code(ret)));
132        }
133        trace!("[SKF_EncryptUpdate]: output len = {}", len);
134        unsafe { buffer.set_len(len as usize) };
135        Ok(buffer)
136    }
137
138    #[instrument(skip(key))]
139    fn encrypt_final(&self, key: &dyn ManagedKey, buffer_size: usize) -> Result<Vec<u8>> {
140        let func = self
141            .symbols
142            .encrypt_final
143            .as_ref()
144            .expect("Symbol not load");
145        let mut len = buffer_size as ULONG;
146        let mut buffer = Vec::<u8>::with_capacity(buffer_size);
147        let ret = unsafe { func(*key.as_ref(), buffer.as_mut_ptr() as *mut BYTE, &mut len) };
148        trace!("[SKF_EncryptFinal]: ret = {}", ret);
149        if ret != SAR_OK {
150            return Err(Error::Skf(SkfErr::of_code(ret)));
151        }
152        trace!("[SKF_EncryptFinal]: output len = {}", len);
153        unsafe { buffer.set_len(len as usize) };
154        Ok(buffer)
155    }
156
157    #[instrument(skip(key))]
158    fn decrypt_init(&self, key: &dyn ManagedKey, param: &BlockCipherParameter) -> Result<()> {
159        let func = self.symbols.decrypt_init.as_ref().expect("Symbol not load");
160        let param = make_cipher_param(param)?;
161        let ret = unsafe { func(*key.as_ref(), param) };
162        trace!("[SKF_DecryptInit]: ret = {}", ret);
163        if ret != SAR_OK {
164            return Err(Error::Skf(SkfErr::of_code(ret)));
165        }
166        Ok(())
167    }
168
169    #[instrument(skip(key, data))]
170    fn decrypt(&self, key: &dyn ManagedKey, data: &[u8], buffer_size: usize) -> Result<Vec<u8>> {
171        let func = self.symbols.decrypt.as_ref().expect("Symbol not load");
172        let mut len = buffer_size as ULONG;
173        let mut buffer = Vec::<u8>::with_capacity(buffer_size);
174        let ret = unsafe {
175            func(
176                *key.as_ref(),
177                data.as_ptr() as *const BYTE,
178                data.len() as ULONG,
179                buffer.as_mut_ptr() as *mut BYTE,
180                &mut len,
181            )
182        };
183        trace!("[SKF_Decrypt]: ret = {}", ret);
184        if ret != SAR_OK {
185            return Err(Error::Skf(SkfErr::of_code(ret)));
186        }
187        trace!("[SKF_Decrypt]: output len = {}", len);
188        unsafe { buffer.set_len(len as usize) };
189        Ok(buffer)
190    }
191
192    #[instrument(skip(key, data))]
193    fn decrypt_update(
194        &self,
195        key: &dyn ManagedKey,
196        data: &[u8],
197        buffer_size: usize,
198    ) -> Result<Vec<u8>> {
199        let func = self
200            .symbols
201            .decrypt_update
202            .as_ref()
203            .expect("Symbol not load");
204        let mut len = buffer_size as ULONG;
205        let mut buffer = Vec::<u8>::with_capacity(buffer_size);
206        let ret = unsafe {
207            func(
208                *key.as_ref(),
209                data.as_ptr() as *const BYTE,
210                data.len() as ULONG,
211                buffer.as_mut_ptr() as *mut BYTE,
212                &mut len,
213            )
214        };
215        trace!("[SKF_DecryptUpdate]: ret = {}", ret);
216        if ret != SAR_OK {
217            return Err(Error::Skf(SkfErr::of_code(ret)));
218        }
219        trace!("[SKF_DecryptUpdate]: output len = {}", len);
220        unsafe { buffer.set_len(len as usize) };
221        Ok(buffer)
222    }
223
224    #[instrument(skip(key))]
225    fn decrypt_final(&self, key: &dyn ManagedKey, buffer_size: usize) -> Result<Vec<u8>> {
226        let func = self
227            .symbols
228            .decrypt_final
229            .as_ref()
230            .expect("Symbol not load");
231        let mut len = buffer_size as ULONG;
232        let mut buffer = Vec::<u8>::with_capacity(buffer_size);
233        let ret = unsafe { func(*key.as_ref(), buffer.as_mut_ptr() as *mut BYTE, &mut len) };
234        trace!("[SKF_DecryptFinal]: ret = {}", ret);
235        if ret != SAR_OK {
236            return Err(Error::Skf(SkfErr::of_code(ret)));
237        }
238        trace!("[SKF_DecryptFinal]: output len = {}", len);
239        unsafe { buffer.set_len(len as usize) };
240        Ok(buffer)
241    }
242}
243
244fn make_cipher_param(src: &BlockCipherParameter) -> Result<BlockCipherParam> {
245    if src.iv.len() > MAX_IV_LEN {
246        let err = InvalidArgumentError::new(format!("max iv length is {}", MAX_IV_LEN), None);
247        return Err(Error::InvalidArgument(err));
248    }
249    let mut iv = [0u8 as BYTE; MAX_IV_LEN];
250    unsafe { std::ptr::copy(src.iv.as_ptr(), iv.as_mut_ptr(), src.iv.len()) };
251    Ok(BlockCipherParam {
252        iv,
253        iv_len: src.iv.len() as ULONG,
254        padding_type: src.padding_type as ULONG,
255        feed_bit_len: src.feed_bit_len as ULONG,
256    })
257}
258
#[cfg(test)]
mod test {
    use super::*;

    /// Length validation: empty, short and exactly-`MAX_IV_LEN` IVs are accepted;
    /// anything longer is rejected.
    #[test]
    fn make_cipher_param_test() {
        let src = BlockCipherParameter {
            iv: vec![],
            padding_type: 0,
            feed_bit_len: 0,
        };
        assert!(make_cipher_param(&src).is_ok());

        let src = BlockCipherParameter {
            iv: [0u8; 1].to_vec(),
            padding_type: 0,
            feed_bit_len: 0,
        };
        assert!(make_cipher_param(&src).is_ok());

        // boundary: exactly MAX_IV_LEN is accepted
        let src = BlockCipherParameter {
            iv: vec![0u8; MAX_IV_LEN],
            padding_type: 0,
            feed_bit_len: 0,
        };
        assert!(make_cipher_param(&src).is_ok());

        // fail case: iv length > MAX_IV_LEN
        let src = BlockCipherParameter {
            iv: vec![0u8; MAX_IV_LEN + 1],
            padding_type: 0,
            feed_bit_len: 0,
        };
        assert!(matches!(
            make_cipher_param(&src).unwrap_err(),
            Error::InvalidArgument(_)
        ));
    }

    /// The IV bytes, its length and the scalar fields are carried into the native struct,
    /// with the unused tail of the IV array left zeroed.
    #[test]
    fn make_cipher_param_copies_fields() {
        let src = BlockCipherParameter {
            iv: vec![1, 2, 3],
            padding_type: 1,
            feed_bit_len: 8,
        };
        let param = make_cipher_param(&src).unwrap();
        assert_eq!(param.iv_len, 3);
        assert_eq!(&param.iv[..3], &[1u8, 2, 3][..]);
        assert!(param.iv[3..].iter().all(|&b| b == 0));
        assert_eq!(param.padding_type, 1);
        assert_eq!(param.feed_bit_len, 8);
    }
}