1use std::{mem, ptr, slice, sync::Arc};
4
5use windows_sys::Win32::Security::Cryptography::*;
6
7use crate::{error::CngError, key::NCryptKey, Result};
8
9const HCCE_LOCAL_MACHINE: HCERTCHAINENGINE = 0x1 as HCERTCHAINENGINE;
10
11#[derive(Debug)]
12enum InnerContext {
13 Owned(*const CERT_CONTEXT),
14 Borrowed(*const CERT_CONTEXT),
15}
16
17unsafe impl Send for InnerContext {}
18unsafe impl Sync for InnerContext {}
19
20impl InnerContext {
21 fn inner(&self) -> *const CERT_CONTEXT {
22 match self {
23 Self::Owned(handle) => *handle,
24 Self::Borrowed(handle) => *handle,
25 }
26 }
27}
28
29impl Drop for InnerContext {
30 fn drop(&mut self) {
31 match self {
32 Self::Owned(handle) => unsafe {
33 CertFreeCertificateContext(*handle);
34 },
35 Self::Borrowed(_) => {}
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct CertContext(Arc<InnerContext>);
43
44impl CertContext {
45 pub fn new_owned(context: *const CERT_CONTEXT) -> Self {
47 Self(Arc::new(InnerContext::Owned(context)))
48 }
49
50 pub fn new_borrowed(context: *const CERT_CONTEXT) -> Self {
52 Self(Arc::new(InnerContext::Borrowed(context)))
53 }
54
55 pub fn inner(&self) -> &CERT_CONTEXT {
57 unsafe { &*self.0.inner() }
58 }
59
60 pub fn acquire_key(&self, silent: bool) -> Result<NCryptKey> {
63 let mut handle = HCRYPTPROV_OR_NCRYPT_KEY_HANDLE::default();
64 let mut key_spec = CERT_KEY_SPEC::default();
65
66 let flags =
67 if silent { CRYPT_ACQUIRE_SILENT_FLAG } else { 0 } | CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG;
68
69 unsafe {
70 let result = CryptAcquireCertificatePrivateKey(
71 self.inner(),
72 flags,
73 ptr::null(),
74 &mut handle,
75 &mut key_spec,
76 ptr::null_mut(),
77 ) != 0;
78 if result {
79 let mut key = NCryptKey::new_owned(handle);
80 key.set_silent(silent);
81 Ok(key)
82 } else {
83 Err(CngError::from_win32_error())
84 }
85 }
86 }
87
88 pub fn as_der(&self) -> &[u8] {
90 unsafe {
91 slice::from_raw_parts(
92 self.inner().pbCertEncoded,
93 self.inner().cbCertEncoded as usize,
94 )
95 }
96 }
97
98 pub fn as_chain_der(&self) -> Result<Vec<Vec<u8>>> {
101 unsafe {
102 let param = CERT_CHAIN_PARA {
103 cbSize: mem::size_of::<CERT_CHAIN_PARA>() as u32,
104 RequestedUsage: std::mem::zeroed(),
105 };
106 let mut context: *mut CERT_CHAIN_CONTEXT = ptr::null_mut();
107 let mut dw_access_state_flags: u32 = 0;
108 let mut cb_data = mem::size_of::<u32>() as u32;
109
110 let chain_engine = if CertGetCertificateContextProperty(
111 self.inner(),
112 CERT_ACCESS_STATE_PROP_ID,
113 &mut dw_access_state_flags as *mut _ as *mut _,
114 &mut cb_data as *mut _,
115 ) != 0
116 && (dw_access_state_flags & CERT_ACCESS_STATE_LM_SYSTEM_STORE_FLAG) != 0
117 {
118 HCCE_LOCAL_MACHINE
119 } else {
120 HCERTCHAINENGINE::default()
121 };
122
123 let result = CertGetCertificateChain(
124 chain_engine,
125 self.inner(),
126 ptr::null(),
127 ptr::null_mut(),
128 ¶m,
129 0,
130 ptr::null(),
131 &mut context,
132 ) != 0;
133
134 if result {
135 let mut chain = vec![];
136
137 if (*context).cChain > 0 {
138 let chain_ptr = *(*context).rgpChain;
139 let elements = slice::from_raw_parts(
140 (*chain_ptr).rgpElement,
141 (*chain_ptr).cElement as usize,
142 );
143
144 for (index, element) in elements.iter().enumerate() {
145 if index != 0 {
146 if 0 != ((**element).TrustStatus.dwInfoStatus
147 & CERT_TRUST_IS_SELF_SIGNED)
148 {
149 break;
150 }
151 }
152
153 let context = (**element).pCertContext;
154 chain.push(Self::new_borrowed(context).as_der().to_vec());
155 }
156 }
157
158 CertFreeCertificateChain(&*context);
159 Ok(chain)
160 } else {
161 Err(CngError::from_win32_error())
162 }
163 }
164 }
165}