skf_rs/helper/
mod.rs

1pub mod auth;
2pub mod easy;
3
4pub mod mem {
5    use crate::{ECCEncryptedData, EnvelopedKeyData};
6
7    use skf_api::native::types::ECCPublicKeyBlob;
8    use std::cmp::min;
9    use std::ffi::CStr;
10    use std::slice;
11
12    /// Returns the position of the first null byte
13    ///
14    /// [ptr] - The pointer to the buffer
15    ///
16    /// [len] - The length of the buffer
17    ///
18    /// # Examples
19    /// ```
20    /// use skf_rs::helper::mem::first_null_byte;
21    /// let ptr = b"Hello\0World\0".as_ptr();
22    /// unsafe {
23    ///     assert_eq!(Some(5), first_null_byte(ptr, 12));
24    ///     assert_eq!(None, first_null_byte(ptr, 5));
25    ///     assert_eq!(Some(4), first_null_byte(ptr.add(1), 11));
26    ///     assert_eq!(Some(0), first_null_byte(ptr.add(5), 7));
27    /// }
28    /// ```
29    #[inline]
30    #[must_use]
31    pub unsafe fn first_null_byte(ptr: *const u8, len: usize) -> Option<usize> {
32        let slice = unsafe { slice::from_raw_parts(ptr, len) };
33        slice.iter().position(|&x| x == 0)
34    }
35
36    /// Returns the position of the first two null byte
37    ///
38    /// [ptr] - The pointer to the buffer
39    ///
40    /// [len] - The length of the buffer
41    ///
42    /// # Examples
43    /// ```
44    /// use skf_rs::helper::mem::first_two_null_byte;
45    /// let ptr = b"Hello\0World\0\0".as_ptr();
46    /// unsafe {
47    ///     assert_eq!(Some(12), first_two_null_byte(ptr, 13));
48    ///     assert_eq!(Some(11), first_two_null_byte(ptr.add(1), 12));
49    ///     assert_eq!(None, first_two_null_byte(ptr, 12));
50    /// }
51    /// ```
52    #[inline]
53    #[must_use]
54    pub const unsafe fn first_two_null_byte(ptr: *const u8, len: usize) -> Option<usize> {
55        let mut pos = 0;
56        while pos < len {
57            if *ptr.add(pos) == 0 && pos + 1 < len && *ptr.add(pos + 1) == 0 {
58                return Some(pos + 1);
59            }
60            pos += 1;
61        }
62        None
63    }
64
65    /// Parse a C string from buffer
66    ///
67    /// [ptr] - The pointer to the buffer
68    ///
69    /// [len] - The length of the buffer
70    /// # Examples
71    /// ```
72    /// use std::ffi::CStr;
73    /// use skf_rs::helper::mem::parse_cstr;
74    /// let ptr = b"Hello\0World\0".as_ptr();
75    /// unsafe {
76    ///     assert_eq!(Some(CStr::from_bytes_with_nul(b"Hello\0").unwrap()), parse_cstr(ptr, 12));
77    ///     assert_eq!(Some(CStr::from_bytes_with_nul(b"lo\0").unwrap()), parse_cstr(ptr.add(3), 12));
78    ///     assert_eq!(Some(CStr::from_bytes_with_nul(b"World\0").unwrap()), parse_cstr(ptr.add(6), 12));
79    ///     assert_eq!(Some(CStr::from_bytes_with_nul(b"\0").unwrap()), parse_cstr(ptr.add(5), 1));
80    ///     assert_eq!(None, parse_cstr(ptr, 1));
81    /// }
82    /// ```
83    #[inline]
84    #[must_use]
85    pub unsafe fn parse_cstr<'a>(ptr: *const u8, len: usize) -> Option<&'a CStr> {
86        let slice = unsafe { slice::from_raw_parts(ptr, len) };
87        CStr::from_bytes_until_nul(slice).ok()
88    }
89
90    /// Parse a C string from buffer, use `CStr::to_string_lossy` to convert data
91    ///
92    /// [ptr] - The pointer to the buffer
93    ///
94    /// [len] - The length of the buffer
95    #[must_use]
96    pub unsafe fn parse_cstr_lossy(ptr: *const u8, len: usize) -> Option<String> {
97        let val = unsafe { parse_cstr(ptr, len) };
98        val.map(|s| s.to_string_lossy().to_string())
99    }
100
101    /// Parse C string list from buffer, the list may end with two null byte
102    ///
103    /// [ptr] - The pointer to the buffer
104    ///
105    /// [len] - The length of the buffer
106    /// # Examples
107    /// ```
108    /// use std::ffi::CStr;
109    /// use skf_rs::helper::mem::parse_cstr_list;
110    /// unsafe {
111    ///     let list = parse_cstr_list(b"Hello\0World\0\0".as_ptr(), 13);
112    ///     assert_eq!(CStr::from_bytes_with_nul(b"Hello\0").unwrap(), *list.get(0).unwrap());
113    ///     assert_eq!(CStr::from_bytes_with_nul(b"World\0").unwrap(), *list.get(1).unwrap());
114    ///
115    ///     let list = parse_cstr_list(b"Hello\0World\0".as_ptr(), 12);
116    ///     assert_eq!(CStr::from_bytes_with_nul(b"Hello\0").unwrap(), *list.get(0).unwrap());
117    ///     assert_eq!(CStr::from_bytes_with_nul(b"World\0").unwrap(), *list.get(1).unwrap());
118    ///
119    ///     let list = parse_cstr_list(b"Hello\0World".as_ptr(), 11);
120    ///     assert_eq!(CStr::from_bytes_with_nul(b"Hello\0").unwrap(), *list.get(0).unwrap());
121    ///
122    ///     let list = parse_cstr_list(b"Hello".as_ptr(), 5);
123    ///     assert!(list.is_empty());
124    /// }
125    /// ```
126    #[inline]
127    #[must_use]
128    pub unsafe fn parse_cstr_list<'a>(ptr: *const u8, len: usize) -> Vec<&'a CStr> {
129        let mut list: Vec<&CStr> = Vec::new();
130        let mut next_str = 0;
131        let mut pos = 0;
132        while pos < len {
133            if *ptr.add(pos) == 0 {
134                let bytes = slice::from_raw_parts(ptr.add(next_str), pos - next_str + 1);
135                list.push(CStr::from_bytes_with_nul_unchecked(bytes));
136                next_str = pos + 1;
137                if next_str < len && *ptr.add(next_str) == 0 {
138                    break;
139                }
140            }
141            pos += 1;
142        }
143        list
144    }
145
146    /// Parse C string list from buffer, the list may end with two null byte
147    ///
148    /// [ptr] - The pointer to the buffer
149    ///
150    /// [len] - The length of the buffer
151    #[must_use]
152    pub unsafe fn parse_cstr_list_lossy(ptr: *const u8, len: usize) -> Vec<String> {
153        let list = unsafe { parse_cstr_list(ptr, len) };
154        list.iter()
155            .map(|s| s.to_string_lossy().to_string())
156            .collect()
157    }
158
159    /// Write string to buffer
160    ///
161    /// [src] - The string to write,if too long, it will be truncated
162    ///
163    /// [buffer] - The buffer to write to,at least one byte to fill with null byte
164    ///
165    /// ## Memory copy
166    ///
167    /// - if the string is too long,it will be truncated,and the last byte will be set to null byte
168    /// - if the string is smaller than the buffer size,it will be filled with null byte
169    ///
170    /// ## example
171    /// ```
172    /// use skf_rs::helper::mem::write_cstr;
173    ///
174    /// let mut buffer = [0u8; 11];
175    /// unsafe {
176    ///     write_cstr("Hello World", &mut buffer);
177    ///}
178    ///assert_eq!(b"Hello Worl\0", &buffer);
179    ///```
180    pub unsafe fn write_cstr(src: impl AsRef<str>, buffer: &mut [u8]) {
181        let src = src.as_ref().as_bytes();
182        let len = min(src.len(), buffer.len());
183        debug_assert!(len > 0);
184        unsafe {
185            std::ptr::copy(src.as_ptr(), buffer.as_mut_ptr(), len);
186        }
187        if len < buffer.len() {
188            buffer[len] = 0;
189        } else {
190            buffer[len - 1] = 0;
191        }
192    }
193
194    /// Write string to buffer
195    ///
196    /// [src] - The string to write
197    ///
198    /// [buffer_ptr] - The buffer to write to
199    ///
200    /// [buffer_len] - The length of the buffer
201    pub unsafe fn write_cstr_ptr(src: impl AsRef<str>, buffer_ptr: *mut u8, buffer_len: usize) {
202        let bytes = slice::from_raw_parts_mut(buffer_ptr, buffer_len);
203        write_cstr(src, bytes);
204    }
205
206    impl ECCEncryptedData {
207        /// Convert to bytes of `ECCCipherBlob`
208        pub fn blob_bytes(&self) -> Vec<u8> {
209            use skf_api::native::types::ULONG;
210
211            let len = 64 + 64 + 32 + 4 + self.cipher.len();
212            let mut vec: Vec<u8> = Vec::with_capacity(len);
213            let cipher_len: [u8; 4] = (self.cipher.len() as ULONG).to_ne_bytes();
214            vec.extend_from_slice(&self.ec_x);
215            vec.extend_from_slice(&self.ec_y);
216            vec.extend_from_slice(&self.hash);
217            vec.extend_from_slice(&cipher_len);
218            vec.extend_from_slice(&self.cipher);
219            vec
220        }
221    }
222
223    impl EnvelopedKeyData {
224        /// Convert to bytes of `EnvelopedKeyBlob`
225        pub fn blob_bytes(&self) -> Vec<u8> {
226            use skf_api::native::types::ULONG;
227
228            let cipher_blob = self.ecc_cipher.blob_bytes();
229            let len = 4 + 4 + 4 + 64 + std::mem::size_of::<ECCPublicKeyBlob>() + cipher_blob.len();
230            let mut vec: Vec<u8> = Vec::with_capacity(len);
231
232            // version
233            let bytes: [u8; 4] = (self.version as ULONG).to_ne_bytes();
234            vec.extend_from_slice(&bytes);
235            // sym_alg_id
236            let bytes: [u8; 4] = (self.sym_alg_id as ULONG).to_ne_bytes();
237            vec.extend_from_slice(&bytes);
238            // bits
239            let bytes: [u8; 4] = (self.bits as ULONG).to_ne_bytes();
240            vec.extend_from_slice(&bytes);
241            // encrypted_pri_key
242            vec.extend_from_slice(&self.encrypted_pri_key);
243
244            // pub_key.bit_len
245            let bytes: [u8; 4] = (self.pub_key.bit_len as ULONG).to_ne_bytes();
246            vec.extend_from_slice(&bytes);
247
248            // pub_key.x_coordinate
249            vec.extend_from_slice(&self.pub_key.x_coordinate);
250
251            // pub_key.y_coordinate
252            vec.extend_from_slice(&self.pub_key.y_coordinate);
253
254            // cipher
255            vec.extend_from_slice(&cipher_blob);
256            vec
257        }
258    }
259    #[cfg(test)]
260    mod tests {
261        use super::*;
262        use crate::ECCEncryptedData;
263
264        #[test]
265        fn parse_terminated_cstr_list_test() {
266            unsafe {
267                let list = parse_cstr_list(b"Hello\0\0".as_ptr(), 7);
268                assert_eq!(1, list.len());
269
270                let list = parse_cstr_list(b"Hello\0World\0\0".as_ptr(), 13);
271                assert_eq!(
272                    CStr::from_bytes_with_nul(b"Hello\0").unwrap(),
273                    *list.first().unwrap()
274                );
275                assert_eq!(
276                    CStr::from_bytes_with_nul(b"World\0").unwrap(),
277                    *list.get(1).unwrap()
278                );
279            }
280        }
281        #[test]
282        fn write_cstr_test() {
283            let input = "Hello World";
284            let mut buffer = [0u8; 12];
285            unsafe {
286                write_cstr(input, &mut buffer);
287            }
288            assert_eq!(b"Hello World\0", &buffer);
289
290            let mut buffer = [0u8; 11];
291            unsafe {
292                write_cstr(input, &mut buffer);
293            }
294            assert_eq!(b"Hello Worl\0", &buffer);
295
296            let mut buffer = [0u8; 1];
297            unsafe {
298                write_cstr(input, &mut buffer);
299            }
300            assert_eq!(b"\0", &buffer);
301        }
302
303        #[test]
304        fn cipher_blob_data_test() {
305            use skf_api::native::types::ECCCipherBlob;
306            let data = ECCEncryptedData {
307                ec_x: [1u8; 64],
308                ec_y: [2u8; 64],
309                hash: [3u8; 32],
310                cipher: vec![1u8, 2u8, 3u8, 4u8, 5u8],
311            };
312            let mem = data.blob_bytes();
313            assert_eq!(mem.len(), 64 + 64 + 32 + 4 + 5);
314            unsafe {
315                let blob_ptr = mem.as_ptr() as *const ECCCipherBlob;
316                let blob = &*blob_ptr;
317
318                assert_eq!(blob.x_coordinate, [1u8; 64]);
319                assert_eq!(blob.y_coordinate, [2u8; 64]);
320                assert_eq!(blob.hash, [3u8; 32]);
321                assert_eq!(std::ptr::addr_of!(blob.cipher_len).read_unaligned(), 5);
322                assert_eq!(blob.cipher, [1u8]);
323            }
324        }
325    }
326}
327
328pub mod param {
329    use crate::error::InvalidArgumentError;
330    use crate::Result;
331    use std::ffi::CString;
332
333    /// Convert `&str` to `CString`
334    ///
335    /// ## Errors
336    /// This function will return an error if conversion from `&str` to `CString` fails,The error message use `param_name` to describe the parameter.
337    pub fn as_cstring(
338        param_name: impl AsRef<str>,
339        param_value: impl AsRef<str>,
340    ) -> Result<CString> {
341        let value = CString::new(param_value.as_ref()).map_err(|e| {
342            InvalidArgumentError::new(
343                format!("parameter '{}' is invalid", param_name.as_ref()),
344                Some(anyhow::Error::new(e)),
345            )
346        })?;
347        Ok(value)
348    }
349}