rustls_cng/
store.rs

1//! Windows certificate store wrapper
2
3use std::{os::raw::c_void, ptr};
4
5use windows_sys::Win32::Security::Cryptography::*;
6
7use crate::{cert::CertContext, error::CngError, Result};
8
9const MY_ENCODING_TYPE: CERT_QUERY_ENCODING_TYPE = PKCS_7_ASN_ENCODING | X509_ASN_ENCODING;
10
/// Encodes a string expression as UTF-16 and appends a terminating NUL code unit,
/// producing a `Vec<u16>` whose pointer can be passed to wide-character APIs.
macro_rules! utf16z {
    ($text:expr) => {{
        let mut wide: Vec<u16> = $text.encode_utf16().collect();
        wide.push(0);
        wide
    }};
}
16
/// Location of the system certificate store to open
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)]
pub enum CertStoreType {
    /// Store of the local machine
    LocalMachine,
    /// Store of the current user
    CurrentUser,
    /// Store of the current service
    CurrentService,
}
24
25impl CertStoreType {
26    fn as_flags(&self) -> u32 {
27        match self {
28            CertStoreType::LocalMachine => {
29                CERT_SYSTEM_STORE_LOCAL_MACHINE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
30            }
31            CertStoreType::CurrentUser => {
32                CERT_SYSTEM_STORE_CURRENT_USER_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
33            }
34            CertStoreType::CurrentService => {
35                CERT_SYSTEM_STORE_CURRENT_SERVICE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
36            }
37        }
38    }
39}
40
41/// Windows certificate store wrapper
42#[derive(Debug)]
43pub struct CertStore(HCERTSTORE);
44
45unsafe impl Send for CertStore {}
46unsafe impl Sync for CertStore {}
47
48impl CertStore {
49    /// Return an inner handle to the store
50    pub fn inner(&self) -> HCERTSTORE {
51        self.0
52    }
53
54    /// Open certificate store of the given type and name
55    pub fn open(store_type: CertStoreType, store_name: &str) -> Result<CertStore> {
56        unsafe {
57            let store_name = utf16z!(store_name);
58            let handle = CertOpenStore(
59                CERT_STORE_PROV_SYSTEM_W,
60                CERT_QUERY_ENCODING_TYPE::default(),
61                HCRYPTPROV_LEGACY::default(),
62                store_type.as_flags() | CERT_STORE_OPEN_EXISTING_FLAG,
63                store_name.as_ptr() as _,
64            );
65            if handle.is_null() {
66                Err(CngError::from_win32_error())
67            } else {
68                Ok(CertStore(handle))
69            }
70        }
71    }
72
73    /// Import certificate store from PKCS12 file
74    pub fn from_pkcs12(data: &[u8], password: &str) -> Result<CertStore> {
75        unsafe {
76            let blob = CRYPT_INTEGER_BLOB {
77                cbData: data.len() as u32,
78                pbData: data.as_ptr() as _,
79            };
80
81            let password = utf16z!(password);
82            let store = PFXImportCertStore(
83                &blob,
84                password.as_ptr(),
85                CRYPT_EXPORTABLE | PKCS12_INCLUDE_EXTENDED_PROPERTIES | PKCS12_PREFER_CNG_KSP,
86            );
87            if store.is_null() {
88                Err(CngError::from_win32_error())
89            } else {
90                Ok(CertStore(store))
91            }
92        }
93    }
94
95    /// Find list of certificates matching the subject substring
96    pub fn find_by_subject_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
97    where
98        S: AsRef<str>,
99    {
100        self.find_by_str(subject.as_ref(), CERT_FIND_SUBJECT_STR)
101    }
102
103    /// Find list of certificates matching the exact subject name
104    pub fn find_by_subject_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
105    where
106        S: AsRef<str>,
107    {
108        self.find_by_name(subject.as_ref(), CERT_FIND_SUBJECT_NAME)
109    }
110
111    /// Find list of certificates matching the issuer substring
112    pub fn find_by_issuer_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
113    where
114        S: AsRef<str>,
115    {
116        self.find_by_str(subject.as_ref(), CERT_FIND_ISSUER_STR)
117    }
118
119    /// Find list of certificates matching the exact issuer name
120    pub fn find_by_issuer_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
121    where
122        S: AsRef<str>,
123    {
124        self.find_by_name(subject.as_ref(), CERT_FIND_ISSUER_NAME)
125    }
126
127    /// Find list of certificates matching the SHA1 hash
128    pub fn find_by_sha1<D>(&self, hash: D) -> Result<Vec<CertContext>>
129    where
130        D: AsRef<[u8]>,
131    {
132        let hash_blob = CRYPT_INTEGER_BLOB {
133            cbData: hash.as_ref().len() as u32,
134            pbData: hash.as_ref().as_ptr() as _,
135        };
136        unsafe { self.do_find(CERT_FIND_HASH, &hash_blob as *const _ as _) }
137    }
138
    // Later Windows releases added `CERT_FIND_SHA256_HASH`.
    // However, rustls-cng may be installed on an earlier OS release where
    // `CERT_FIND_SHA256_HASH` is not available, while `CERT_SHA256_HASH_PROP_ID` is.
    // So an internal find function is needed that gets and compares the SHA256 property.
    // Also, since SHA1 is being deprecated, Windows components should not use it;
    // hence the need to find via SHA256 instead of SHA1.
145
146    /// Find list of certificates matching the SHA256 hash
147    pub fn find_by_sha256<D>(&self, hash: D) -> Result<Vec<CertContext>>
148    where
149        D: AsRef<[u8]>,
150    {
151        let hash_blob = CRYPT_INTEGER_BLOB {
152            cbData: hash.as_ref().len() as u32,
153            pbData: hash.as_ref().as_ptr() as _,
154        };
155        unsafe { self.do_find_by_sha256_property(&hash_blob as *const _ as _) }
156    }
157
158    /// Find list of certificates matching the key identifier
159    pub fn find_by_key_id<D>(&self, key_id: D) -> Result<Vec<CertContext>>
160    where
161        D: AsRef<[u8]>,
162    {
163        let cert_id = CERT_ID {
164            dwIdChoice: CERT_ID_KEY_IDENTIFIER,
165            Anonymous: CERT_ID_0 {
166                KeyId: CRYPT_INTEGER_BLOB {
167                    cbData: key_id.as_ref().len() as u32,
168                    pbData: key_id.as_ref().as_ptr() as _,
169                },
170            },
171        };
172        unsafe { self.do_find(CERT_FIND_CERT_ID, &cert_id as *const _ as _) }
173    }
174
175    /// Get all certificates
176    pub fn find_all(&self) -> Result<Vec<CertContext>> {
177        unsafe { self.do_find(CERT_FIND_ANY, ptr::null()) }
178    }
179
180    unsafe fn do_find(
181        &self,
182        flags: CERT_FIND_FLAGS,
183        find_param: *const c_void,
184    ) -> Result<Vec<CertContext>> {
185        let mut certs = Vec::new();
186
187        let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
188
189        loop {
190            cert = CertFindCertificateInStore(self.0, MY_ENCODING_TYPE, 0, flags, find_param, cert);
191            if cert.is_null() {
192                break;
193            } else {
194                // increase refcount because it will be released by next call to CertFindCertificateInStore
195                let cert = CertDuplicateCertificateContext(cert);
196                certs.push(CertContext::new_owned(cert))
197            }
198        }
199        Ok(certs)
200    }
201
202    unsafe fn do_find_by_sha256_property(
203        &self,
204        find_param: *const c_void,
205    ) -> Result<Vec<CertContext>> {
206        let mut certs = Vec::new();
207        let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
208        let hash_blob = &*(find_param as *const CRYPT_INTEGER_BLOB);
209        let sha256_hash = std::slice::from_raw_parts(hash_blob.pbData, hash_blob.cbData as usize);
210        loop {
211            cert = CertFindCertificateInStore(
212                self.0,
213                MY_ENCODING_TYPE,
214                0,
215                CERT_FIND_ANY,
216                find_param,
217                cert,
218            );
219            if cert.is_null() {
220                break;
221            } else {
222                let mut prop_data = [0u8; 32];
223                let mut prop_data_len = prop_data.len() as u32;
224
225                if CertGetCertificateContextProperty(
226                    cert,
227                    CERT_SHA256_HASH_PROP_ID,
228                    prop_data.as_mut_ptr() as *mut c_void,
229                    &mut prop_data_len,
230                ) != 0
231                {
232                    if prop_data[..prop_data_len as usize] == sha256_hash[..] {
233                        let cert = CertDuplicateCertificateContext(cert);
234                        certs.push(CertContext::new_owned(cert))
235                    }
236                }
237            }
238        }
239        Ok(certs)
240    }
241
242    fn find_by_str(&self, pattern: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
243        let u16pattern = utf16z!(pattern);
244        unsafe { self.do_find(flags, u16pattern.as_ptr() as _) }
245    }
246
247    fn find_by_name(&self, field: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
248        let mut name_size = 0;
249
250        unsafe {
251            let field_name = utf16z!(field);
252            if CertStrToNameW(
253                MY_ENCODING_TYPE,
254                field_name.as_ptr(),
255                CERT_X500_NAME_STR,
256                ptr::null(),
257                ptr::null_mut(),
258                &mut name_size,
259                ptr::null_mut(),
260            ) == 0
261            {
262                return Err(CngError::from_win32_error());
263            }
264
265            let mut x509name = vec![0u8; name_size as usize];
266            if CertStrToNameW(
267                MY_ENCODING_TYPE,
268                field_name.as_ptr(),
269                CERT_X500_NAME_STR,
270                ptr::null(),
271                x509name.as_mut_ptr(),
272                &mut name_size,
273                ptr::null_mut(),
274            ) == 0
275            {
276                return Err(CngError::from_win32_error());
277            }
278
279            let name_blob = CRYPT_INTEGER_BLOB {
280                cbData: x509name.len() as _,
281                pbData: x509name.as_mut_ptr(),
282            };
283
284            self.do_find(flags, &name_blob as *const _ as _)
285        }
286    }
287}
288
289impl Drop for CertStore {
290    fn drop(&mut self) {
291        unsafe { CertCloseStore(self.0, 0) };
292    }
293}