1use crate::windows;
2use windows::{
3 core::HRESULT,
4 Win32::Foundation::{E_UNEXPECTED, WIN32_ERROR},
5};
6
7pub fn dual_call<F, T>(
8 first_call_expectation: FirstCallExpectation<T>,
9 mut call: F,
10) -> windows::core::Result<T>
11where
12 F: FnMut(bool) -> windows::core::Result<T>,
13 T: PartialEq,
14{
15 match first_call_expectation {
24 FirstCallExpectation::Ok => {
25 call(true)?;
26 call(false)
27 }
28 FirstCallExpectation::OkValue(expected_value) => {
29 if call(true)? == expected_value {
30 call(false)
31 } else {
32 Err(E_UNEXPECTED.into())
33 }
34 }
35 other_expectation => {
36 let expected_h_result = match other_expectation {
37 FirstCallExpectation::Win32Error(win_32_error) => win_32_error.to_hresult(),
38 FirstCallExpectation::HResultError(h_result) => h_result,
39 _ => unreachable!(),
40 };
41
42 match call(true) {
43 Err(error) => {
44 if error.code() == expected_h_result {
45 call(false)
46 } else {
47 Err(error)
48 }
49 }
50 Ok(_) => Err(E_UNEXPECTED.into()),
51 }
52 }
53 }
54}
55
/// Outcome that the first (preparatory) call made by [`dual_call`] is required to produce
/// before the second call is attempted.
#[non_exhaustive]
pub enum FirstCallExpectation<T> {
    /// The first call must return `Ok`; the contained value is not inspected.
    Ok,

    /// The first call must return `Ok` with a value equal to the given one.
    OkValue(T),

    /// The first call must fail with the `HRESULT` obtained from this Win32 error code
    /// via `to_hresult()`.
    Win32Error(WIN32_ERROR),

    /// The first call must fail with exactly this `HRESULT`.
    HResultError(HRESULT),
}
77
78#[cfg(all(test, feature = "windows_latest_compatible_all"))]
79mod tests {
80 use super::{dual_call, FirstCallExpectation};
81 use crate::{
82 core::{CheckNumberError, HResultExt},
83 windows, Null, ResGuard,
84 };
85 use regex::Regex;
86 use windows::{
87 core::{w, PCWSTR, PWSTR},
88 Win32::{
89 Foundation::{
90 ERROR_BUFFER_OVERFLOW, ERROR_INSUFFICIENT_BUFFER, ERROR_MORE_DATA, E_FAIL,
91 E_POINTER, S_FALSE, S_OK, WIN32_ERROR,
92 },
93 NetworkManagement::IpHelper::{
94 GetAdaptersAddresses, GET_ADAPTERS_ADDRESSES_FLAGS, IP_ADAPTER_ADDRESSES_LH,
95 },
96 Networking::WinSock::AF_UNSPEC,
97 Security::{
98 Authorization::ConvertSidToStringSidW, GetTokenInformation, TokenUser,
99 SID_AND_ATTRIBUTES, TOKEN_QUERY,
100 },
101 System::{
102 SystemInformation::{ComputerNameNetBIOS, GetComputerNameExW},
103 Threading::{GetCurrentProcess, OpenProcessToken},
104 },
105 UI::{
106 Input::KeyboardAndMouse::{GetKeyboardLayoutList, HKL},
107 Shell::{AssocQueryStringW, ASSOCF_NONE, ASSOCF_NOTRUNCATE, ASSOCSTR_EXECUTABLE},
108 },
109 },
110 };
111
112 #[test]
113 fn expect_ok() -> windows::core::Result<()> {
114 let mut ids = Vec::<HKL>::new();
115 let mut num_ids = 0; dual_call(FirstCallExpectation::Ok, |getting_buffer_size| {
118 num_ids = unsafe {
119 GetKeyboardLayoutList((!getting_buffer_size).then(|| {
120 ids.resize(num_ids as _, HKL::default());
121 ids.as_mut_slice()
122 }))
123 };
124
125 num_ids.nonzero_or_win32_err()
126 })?;
127
128 assert!(num_ids >= 1 && num_ids <= 20 && ids.iter().all(|hkl| !hkl.is_invalid()));
129
130 Ok(())
131 }
132
133 #[test]
134 fn expect_win32_error_more_data() -> windows::core::Result<()> {
135 let mut buffer = Vec::new();
136 let mut len = 0;
137
138 dual_call(
139 FirstCallExpectation::Win32Error(ERROR_MORE_DATA),
140 |getting_buffer_size| unsafe {
141 GetComputerNameExW(
142 ComputerNameNetBIOS,
143 if getting_buffer_size {
144 PWSTR::NULL
145 } else {
146 buffer.resize(len as _, 0);
147 PWSTR(buffer.as_mut_ptr())
148 },
149 &mut len,
150 )
151 },
152 )?;
153
154 let computer_name = String::from_utf16(&buffer[..len as _])?;
155 assert!(
156 Regex::new(r"^[\w!@#$%^()\-'{}\.~]{1,15}$") .unwrap()
158 .is_match(&computer_name)
159 );
160
161 Ok(())
162 }
163
164 #[test]
165 fn expect_win32_error_insufficient_buffer() -> windows::core::Result<()> {
166 let process_token_handle = ResGuard::with_mut_acq_and_close_handle(|handle| unsafe {
167 OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, handle)
168 })?;
169
170 let mut sid_and_attrs_buffer = Vec::<u8>::new();
171 let mut sid_and_attrs_buffer_size = 0;
172
173 dual_call(
174 FirstCallExpectation::Win32Error(ERROR_INSUFFICIENT_BUFFER),
175 |getting_buffer_size| unsafe {
176 GetTokenInformation(
177 *process_token_handle,
178 TokenUser,
179 (!getting_buffer_size).then(|| {
180 sid_and_attrs_buffer.resize(sid_and_attrs_buffer_size as _, 0);
181 sid_and_attrs_buffer.as_mut_ptr().cast()
182 }),
183 sid_and_attrs_buffer_size,
184 &mut sid_and_attrs_buffer_size,
185 )
186 },
187 )?;
188
189 let string_sid = unsafe {
190 ResGuard::<PWSTR>::with_mut_acq_and_local_free(|pwstr| {
191 ConvertSidToStringSidW(
192 (&*sid_and_attrs_buffer.as_ptr().cast::<SID_AND_ATTRIBUTES>()).Sid,
193 pwstr,
194 )
195 })?
196 .to_string()?
197 };
198
199 assert!(Regex::new(r"^S-1-5(?:-\d+)+$")
200 .unwrap()
201 .is_match(&string_sid));
202
203 Ok(())
204 }
205
206 #[test]
207 fn expect_win32_error_buffer_overflow_from_return() -> windows::core::Result<()> {
208 let mut byte_buffer = Vec::<u8>::new();
209 let mut buffer_size = 0;
210
211 dual_call(
212 FirstCallExpectation::Win32Error(ERROR_BUFFER_OVERFLOW),
213 |getting_buffer_size| {
214 WIN32_ERROR(unsafe {
215 GetAdaptersAddresses(
216 AF_UNSPEC.0 as _,
217 GET_ADAPTERS_ADDRESSES_FLAGS(0),
218 None,
219 (!getting_buffer_size).then(|| {
220 byte_buffer.resize(buffer_size as _, 0);
221 byte_buffer.as_mut_ptr().cast()
222 }),
223 &mut buffer_size,
224 )
225 })
226 .to_hresult()
227 .ok()
228 },
229 )?;
230
231 let mut adapter_names = Vec::new();
232 let mut ip_adapter_addresses =
233 unsafe { &*byte_buffer.as_ptr().cast::<IP_ADAPTER_ADDRESSES_LH>() };
234
235 loop {
236 let adapter_name = unsafe { ip_adapter_addresses.FriendlyName.to_string()? };
237 if !adapter_names.contains(&adapter_name) {
238 adapter_names.push(adapter_name);
239 }
240
241 if ip_adapter_addresses.Next.is_null() {
242 break;
243 }
244 ip_adapter_addresses = unsafe { &*ip_adapter_addresses.Next };
245 }
246
247 let validate_regex = Regex::new(r"^[\x20-\x7f\p{Letter}]+$").unwrap();
248 assert!(adapter_names
249 .iter()
250 .all(|name| validate_regex.is_match(&name)));
251
252 Ok(())
253 }
254
255 #[test]
256 fn expect_ok_value() -> windows::core::Result<()> {
257 let mut buffer = Vec::new();
258 let mut buffer_size = 0;
259
260 let success_hresult = dual_call(
261 FirstCallExpectation::OkValue(S_FALSE),
262 |getting_buffer_size| {
263 unsafe {
264 AssocQueryStringW(
265 ASSOCF_NONE,
266 ASSOCSTR_EXECUTABLE,
267 w!(".msi"),
268 PCWSTR::NULL,
269 if getting_buffer_size {
270 PWSTR::NULL
271 } else {
272 buffer.resize(buffer_size as _, 0);
273 PWSTR(buffer.as_mut_ptr())
274 },
275 &mut buffer_size,
276 )
277 }
278 .ok_with_hresult()
279 },
280 )?;
281
282 if success_hresult == S_OK && buffer_size > 0 {
283 let string = String::from_utf16(&buffer[..(buffer_size - 1) as _])?;
284 assert!(Regex::new(r"(?i)\\System32\\msiexec.exe$")
285 .unwrap()
286 .is_match(&string));
287 } else {
288 return Err(E_FAIL.into());
289 }
290
291 Ok(())
292 }
293
294 #[test]
295 fn expect_hresult_error() -> windows::core::Result<()> {
296 let mut buffer = Vec::new();
297 let mut buffer_size = 0;
298
299 let success_hresult = dual_call(
300 FirstCallExpectation::HResultError(E_POINTER),
301 |getting_buffer_size| {
302 if !getting_buffer_size {
303 buffer.resize(buffer_size as _, 0);
304 }
305
306 unsafe {
307 AssocQueryStringW(
308 ASSOCF_NOTRUNCATE,
309 ASSOCSTR_EXECUTABLE,
310 w!(".msi"),
311 PCWSTR::NULL,
312 PWSTR(buffer.as_mut_ptr()),
313 &mut buffer_size,
314 )
315 }
316 .ok_with_hresult()
317 },
318 )?;
319
320 if success_hresult == S_OK && buffer_size > 0 {
321 let string = String::from_utf16(&buffer[..(buffer_size - 1) as _])?;
322 assert!(Regex::new(r"(?i)\\System32\\msiexec.exe$")
323 .unwrap()
324 .is_match(&string));
325 } else {
326 return Err(E_FAIL.into());
327 }
328
329 Ok(())
330 }
331}