1use std::{os::raw::c_void, ptr};
4
5use windows_sys::Win32::Security::Cryptography::*;
6
7use crate::{cert::CertContext, error::CngError, Result};
8
9const MY_ENCODING_TYPE: CERT_QUERY_ENCODING_TYPE = PKCS_7_ASN_ENCODING | X509_ASN_ENCODING;
10
/// Encodes a string expression as UTF-16 and appends a terminating `0`, yielding a
/// `Vec<u16>` whose `.as_ptr()` can be handed to wide-character (`...W`) Win32 APIs.
macro_rules! utf16z {
    ($str: expr) => {{
        let mut wide: Vec<u16> = $str.encode_utf16().collect();
        wide.push(0);
        wide
    }};
}
16
/// Location of a Windows system certificate store, as passed to `CertStore::open`.
///
/// The variant order is significant: it defines the derived `PartialOrd`/`Ord` ordering.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum CertStoreType {
    /// System stores of the local machine.
    LocalMachine,
    /// System stores of the current user.
    CurrentUser,
    /// System stores of the current service.
    CurrentService,
}
24
25impl CertStoreType {
26 fn as_flags(&self) -> u32 {
27 match self {
28 CertStoreType::LocalMachine => {
29 CERT_SYSTEM_STORE_LOCAL_MACHINE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
30 }
31 CertStoreType::CurrentUser => {
32 CERT_SYSTEM_STORE_CURRENT_USER_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
33 }
34 CertStoreType::CurrentService => {
35 CERT_SYSTEM_STORE_CURRENT_SERVICE_ID << CERT_SYSTEM_STORE_LOCATION_SHIFT
36 }
37 }
38 }
39}
40
41#[derive(Debug)]
43pub struct CertStore(HCERTSTORE);
44
45unsafe impl Send for CertStore {}
46unsafe impl Sync for CertStore {}
47
48impl CertStore {
49 pub fn inner(&self) -> HCERTSTORE {
51 self.0
52 }
53
54 pub fn open(store_type: CertStoreType, store_name: &str) -> Result<CertStore> {
56 unsafe {
57 let store_name = utf16z!(store_name);
58 let handle = CertOpenStore(
59 CERT_STORE_PROV_SYSTEM_W,
60 CERT_QUERY_ENCODING_TYPE::default(),
61 HCRYPTPROV_LEGACY::default(),
62 store_type.as_flags() | CERT_STORE_OPEN_EXISTING_FLAG,
63 store_name.as_ptr() as _,
64 );
65 if handle.is_null() {
66 Err(CngError::from_win32_error())
67 } else {
68 Ok(CertStore(handle))
69 }
70 }
71 }
72
73 pub fn from_pkcs12(data: &[u8], password: &str) -> Result<CertStore> {
75 unsafe {
76 let blob = CRYPT_INTEGER_BLOB {
77 cbData: data.len() as u32,
78 pbData: data.as_ptr() as _,
79 };
80
81 let password = utf16z!(password);
82 let store = PFXImportCertStore(
83 &blob,
84 password.as_ptr(),
85 CRYPT_EXPORTABLE | PKCS12_INCLUDE_EXTENDED_PROPERTIES | PKCS12_PREFER_CNG_KSP,
86 );
87 if store.is_null() {
88 Err(CngError::from_win32_error())
89 } else {
90 Ok(CertStore(store))
91 }
92 }
93 }
94
95 pub fn find_by_subject_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
97 where
98 S: AsRef<str>,
99 {
100 self.find_by_str(subject.as_ref(), CERT_FIND_SUBJECT_STR)
101 }
102
103 pub fn find_by_subject_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
105 where
106 S: AsRef<str>,
107 {
108 self.find_by_name(subject.as_ref(), CERT_FIND_SUBJECT_NAME)
109 }
110
111 pub fn find_by_issuer_str<S>(&self, subject: S) -> Result<Vec<CertContext>>
113 where
114 S: AsRef<str>,
115 {
116 self.find_by_str(subject.as_ref(), CERT_FIND_ISSUER_STR)
117 }
118
119 pub fn find_by_issuer_name<S>(&self, subject: S) -> Result<Vec<CertContext>>
121 where
122 S: AsRef<str>,
123 {
124 self.find_by_name(subject.as_ref(), CERT_FIND_ISSUER_NAME)
125 }
126
127 pub fn find_by_sha1<D>(&self, hash: D) -> Result<Vec<CertContext>>
129 where
130 D: AsRef<[u8]>,
131 {
132 let hash_blob = CRYPT_INTEGER_BLOB {
133 cbData: hash.as_ref().len() as u32,
134 pbData: hash.as_ref().as_ptr() as _,
135 };
136 unsafe { self.do_find(CERT_FIND_HASH, &hash_blob as *const _ as _) }
137 }
138
139 pub fn find_by_sha256<D>(&self, hash: D) -> Result<Vec<CertContext>>
148 where
149 D: AsRef<[u8]>,
150 {
151 let hash_blob = CRYPT_INTEGER_BLOB {
152 cbData: hash.as_ref().len() as u32,
153 pbData: hash.as_ref().as_ptr() as _,
154 };
155 unsafe { self.do_find_by_sha256_property(&hash_blob as *const _ as _) }
156 }
157
158 pub fn find_by_key_id<D>(&self, key_id: D) -> Result<Vec<CertContext>>
160 where
161 D: AsRef<[u8]>,
162 {
163 let cert_id = CERT_ID {
164 dwIdChoice: CERT_ID_KEY_IDENTIFIER,
165 Anonymous: CERT_ID_0 {
166 KeyId: CRYPT_INTEGER_BLOB {
167 cbData: key_id.as_ref().len() as u32,
168 pbData: key_id.as_ref().as_ptr() as _,
169 },
170 },
171 };
172 unsafe { self.do_find(CERT_FIND_CERT_ID, &cert_id as *const _ as _) }
173 }
174
175 pub fn find_all(&self) -> Result<Vec<CertContext>> {
177 unsafe { self.do_find(CERT_FIND_ANY, ptr::null()) }
178 }
179
180 unsafe fn do_find(
181 &self,
182 flags: CERT_FIND_FLAGS,
183 find_param: *const c_void,
184 ) -> Result<Vec<CertContext>> {
185 let mut certs = Vec::new();
186
187 let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
188
189 loop {
190 cert = CertFindCertificateInStore(self.0, MY_ENCODING_TYPE, 0, flags, find_param, cert);
191 if cert.is_null() {
192 break;
193 } else {
194 let cert = CertDuplicateCertificateContext(cert);
196 certs.push(CertContext::new_owned(cert))
197 }
198 }
199 Ok(certs)
200 }
201
202 unsafe fn do_find_by_sha256_property(
203 &self,
204 find_param: *const c_void,
205 ) -> Result<Vec<CertContext>> {
206 let mut certs = Vec::new();
207 let mut cert: *mut CERT_CONTEXT = ptr::null_mut();
208 let hash_blob = &*(find_param as *const CRYPT_INTEGER_BLOB);
209 let sha256_hash = std::slice::from_raw_parts(hash_blob.pbData, hash_blob.cbData as usize);
210 loop {
211 cert = CertFindCertificateInStore(
212 self.0,
213 MY_ENCODING_TYPE,
214 0,
215 CERT_FIND_ANY,
216 find_param,
217 cert,
218 );
219 if cert.is_null() {
220 break;
221 } else {
222 let mut prop_data = [0u8; 32];
223 let mut prop_data_len = prop_data.len() as u32;
224
225 if CertGetCertificateContextProperty(
226 cert,
227 CERT_SHA256_HASH_PROP_ID,
228 prop_data.as_mut_ptr() as *mut c_void,
229 &mut prop_data_len,
230 ) != 0
231 {
232 if prop_data[..prop_data_len as usize] == sha256_hash[..] {
233 let cert = CertDuplicateCertificateContext(cert);
234 certs.push(CertContext::new_owned(cert))
235 }
236 }
237 }
238 }
239 Ok(certs)
240 }
241
242 fn find_by_str(&self, pattern: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
243 let u16pattern = utf16z!(pattern);
244 unsafe { self.do_find(flags, u16pattern.as_ptr() as _) }
245 }
246
247 fn find_by_name(&self, field: &str, flags: CERT_FIND_FLAGS) -> Result<Vec<CertContext>> {
248 let mut name_size = 0;
249
250 unsafe {
251 let field_name = utf16z!(field);
252 if CertStrToNameW(
253 MY_ENCODING_TYPE,
254 field_name.as_ptr(),
255 CERT_X500_NAME_STR,
256 ptr::null(),
257 ptr::null_mut(),
258 &mut name_size,
259 ptr::null_mut(),
260 ) == 0
261 {
262 return Err(CngError::from_win32_error());
263 }
264
265 let mut x509name = vec![0u8; name_size as usize];
266 if CertStrToNameW(
267 MY_ENCODING_TYPE,
268 field_name.as_ptr(),
269 CERT_X500_NAME_STR,
270 ptr::null(),
271 x509name.as_mut_ptr(),
272 &mut name_size,
273 ptr::null_mut(),
274 ) == 0
275 {
276 return Err(CngError::from_win32_error());
277 }
278
279 let name_blob = CRYPT_INTEGER_BLOB {
280 cbData: x509name.len() as _,
281 pbData: x509name.as_mut_ptr(),
282 };
283
284 self.do_find(flags, &name_blob as *const _ as _)
285 }
286 }
287}
288
289impl Drop for CertStore {
290 fn drop(&mut self) {
291 unsafe { CertCloseStore(self.0, 0) };
292 }
293}