// windows_helpers/dual_call.rs

1use crate::windows;
2use windows::{
3    core::HRESULT,
4    Win32::Foundation::{E_UNEXPECTED, WIN32_ERROR},
5};
6
7pub fn dual_call<F, T>(
8    first_call_expectation: FirstCallExpectation<T>,
9    mut call: F,
10) -> windows::core::Result<T>
11where
12    F: FnMut(bool) -> windows::core::Result<T>,
13    T: PartialEq,
14{
15    //! For functions that are to be called with a preparation step, normally to determine the required buffer size.
16    //!
17    //! You may find that this is easier to verify for correctness than a `loop` approach - so, less straining on the mind and less time-consuming. It's also more versatile than a `for` approach.
18    //!
19    //! The closure parameter will be `true` for the first call. It can be called something like `getting_buffer_size`.
20    //!
21    //! If the expectation after the first call isn't met and it returned an `Err`, the function ends with that `Err`. If the first call returned `Ok`, however, and this didn't harmonize with the expectation, `Err` including `HRESULT` `E_UNEXPECTED` is returned.
22
23    match first_call_expectation {
24        FirstCallExpectation::Ok => {
25            call(true)?;
26            call(false)
27        }
28        FirstCallExpectation::OkValue(expected_value) => {
29            if call(true)? == expected_value {
30                call(false)
31            } else {
32                Err(E_UNEXPECTED.into())
33            }
34        }
35        other_expectation => {
36            let expected_h_result = match other_expectation {
37                FirstCallExpectation::Win32Error(win_32_error) => win_32_error.to_hresult(),
38                FirstCallExpectation::HResultError(h_result) => h_result,
39                _ => unreachable!(),
40            };
41
42            match call(true) {
43                Err(error) => {
44                    if error.code() == expected_h_result {
45                        call(false)
46                    } else {
47                        Err(error)
48                    }
49                }
50                Ok(_) => Err(E_UNEXPECTED.into()),
51            }
52        }
53    }
54}
55
56/// Defining the return value of the first call of [`dual_call()`] that is the precondition to continue with the second call.
57#[non_exhaustive]
58pub enum FirstCallExpectation<T> {
59    /// Useful with a function like [`GetKeyboardLayoutList()`][1].
60    ///
61    /// [1]: https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-getkeyboardlayoutlist
62    Ok,
63
64    /// Useful with a function like [`AssocQueryStringW()`][1].
65    ///
66    /// [1]: https://learn.microsoft.com/en-us/windows/win32/api/shlwapi/nf-shlwapi-assocquerystringw
67    OkValue(T),
68
69    /// The most useful. Requires `ERROR_INSUFFICIENT_BUFFER` most often, if not documented.
70    Win32Error(WIN32_ERROR),
71
72    /// Useful with a function like [`AssocQueryStringW()`][1] (in `ASSOCF_NOTRUNCATE` mode).
73    ///
74    /// [1]: https://learn.microsoft.com/en-us/windows/win32/api/shlwapi/nf-shlwapi-assocquerystringw
75    HResultError(HRESULT),
76}
77
#[cfg(all(test, feature = "windows_latest_compatible_all"))]
mod tests {
    //! Exercises every `FirstCallExpectation` variant against real Win32 APIs.

    use super::{dual_call, FirstCallExpectation};
    use crate::{
        core::{CheckNumberError, HResultExt},
        windows, Null, ResGuard,
    };
    use regex::Regex;
    use windows::{
        core::{w, PCWSTR, PWSTR},
        Win32::{
            Foundation::{
                ERROR_BUFFER_OVERFLOW, ERROR_INSUFFICIENT_BUFFER, ERROR_MORE_DATA, E_FAIL,
                E_POINTER, S_FALSE, S_OK, WIN32_ERROR,
            },
            NetworkManagement::IpHelper::{
                GetAdaptersAddresses, GET_ADAPTERS_ADDRESSES_FLAGS, IP_ADAPTER_ADDRESSES_LH,
            },
            Networking::WinSock::AF_UNSPEC,
            Security::{
                Authorization::ConvertSidToStringSidW, GetTokenInformation, TokenUser,
                SID_AND_ATTRIBUTES, TOKEN_QUERY,
            },
            System::{
                SystemInformation::{ComputerNameNetBIOS, GetComputerNameExW},
                Threading::{GetCurrentProcess, OpenProcessToken},
            },
            UI::{
                Input::KeyboardAndMouse::{GetKeyboardLayoutList, HKL},
                Shell::{AssocQueryStringW, ASSOCF_NONE, ASSOCF_NOTRUNCATE, ASSOCSTR_EXECUTABLE},
            },
        },
    };

    /// `FirstCallExpectation::Ok`: the count returned by the buffer-less first call sizes the
    /// buffer for the second call.
    #[test]
    fn expect_ok() -> windows::core::Result<()> {
        let mut ids = Vec::<HKL>::new();
        let mut num_ids = 0; // Will equal number of input locales in Windows UI.

        dual_call(FirstCallExpectation::Ok, |getting_buffer_size| {
            num_ids = unsafe {
                GetKeyboardLayoutList((!getting_buffer_size).then(|| {
                    ids.resize(num_ids as _, HKL::default());
                    ids.as_mut_slice()
                }))
            };

            num_ids.nonzero_or_win32_err()
        })?;

        assert!((1..=20).contains(&num_ids) && ids.iter().all(|hkl| !hkl.is_invalid()));

        Ok(())
    }

    /// `FirstCallExpectation::Win32Error` with `ERROR_MORE_DATA`: a null buffer makes the first
    /// call fail and report the required length in `len`.
    #[test]
    fn expect_win32_error_more_data() -> windows::core::Result<()> {
        let mut buffer = Vec::new();
        let mut len = 0;

        dual_call(
            FirstCallExpectation::Win32Error(ERROR_MORE_DATA),
            |getting_buffer_size| unsafe {
                GetComputerNameExW(
                    ComputerNameNetBIOS,
                    if getting_buffer_size {
                        PWSTR::NULL
                    } else {
                        buffer.resize(len as _, 0);
                        PWSTR(buffer.as_mut_ptr())
                    },
                    &mut len,
                )
            },
        )?;

        let computer_name = String::from_utf16(&buffer[..len as _])?;
        assert!(
            Regex::new(r"^[\w!@#$%^()\-'{}\.~]{1,15}$") // https://stackoverflow.com/a/24095455
                .unwrap()
                .is_match(&computer_name)
        );

        Ok(())
    }

    /// `FirstCallExpectation::Win32Error` with `ERROR_INSUFFICIENT_BUFFER`: queries the current
    /// process token's user SID.
    #[test]
    fn expect_win32_error_insufficient_buffer() -> windows::core::Result<()> {
        let process_token_handle = ResGuard::with_mut_acq_and_close_handle(|handle| unsafe {
            OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, handle)
        })?;

        let mut sid_and_attrs_buffer = Vec::<u8>::new();
        let mut sid_and_attrs_buffer_size = 0;

        dual_call(
            FirstCallExpectation::Win32Error(ERROR_INSUFFICIENT_BUFFER),
            |getting_buffer_size| unsafe {
                GetTokenInformation(
                    *process_token_handle,
                    TokenUser,
                    (!getting_buffer_size).then(|| {
                        sid_and_attrs_buffer.resize(sid_and_attrs_buffer_size as _, 0);
                        sid_and_attrs_buffer.as_mut_ptr().cast()
                    }),
                    sid_and_attrs_buffer_size,
                    &mut sid_and_attrs_buffer_size,
                )
            },
        )?;

        // A `Vec<u8>` only guarantees 1-byte alignment, so a `&SID_AND_ATTRIBUTES` into it would be
        // undefined behavior if the allocation happened to be misaligned. Copy the struct out with
        // an unaligned read instead. `sid_and_attrs_buffer` is still alive below, which matters
        // because (per the `GetTokenInformation` docs) the SID data is returned in that buffer.
        let sid = unsafe {
            sid_and_attrs_buffer
                .as_ptr()
                .cast::<SID_AND_ATTRIBUTES>()
                .read_unaligned()
                .Sid
        };

        let string_sid = unsafe {
            ResGuard::<PWSTR>::with_mut_acq_and_local_free(|pwstr| {
                ConvertSidToStringSidW(sid, pwstr)
            })?
            .to_string()?
        };

        assert!(Regex::new(r"^S-1-5(?:-\d+)+$")
            .unwrap()
            .is_match(&string_sid));

        Ok(())
    }

    /// `FirstCallExpectation::Win32Error` with `ERROR_BUFFER_OVERFLOW`, where the Win32 error
    /// comes from the function's return value rather than from `GetLastError()`.
    #[test]
    fn expect_win32_error_buffer_overflow_from_return() -> windows::core::Result<()> {
        let mut byte_buffer = Vec::<u8>::new();
        let mut buffer_size = 0;

        dual_call(
            FirstCallExpectation::Win32Error(ERROR_BUFFER_OVERFLOW),
            |getting_buffer_size| {
                WIN32_ERROR(unsafe {
                    GetAdaptersAddresses(
                        AF_UNSPEC.0 as _,
                        GET_ADAPTERS_ADDRESSES_FLAGS(0),
                        None,
                        (!getting_buffer_size).then(|| {
                            byte_buffer.resize(buffer_size as _, 0);
                            byte_buffer.as_mut_ptr().cast()
                        }),
                        &mut buffer_size,
                    )
                })
                .to_hresult()
                .ok()
            },
        )?;

        let mut adapter_names = Vec::new();
        let mut node_ptr = byte_buffer.as_ptr().cast::<IP_ADAPTER_ADDRESSES_LH>();

        while !node_ptr.is_null() {
            // The list lives in a `Vec<u8>`, which only guarantees 1-byte alignment, so copy each
            // node out with an unaligned read instead of forming a (possibly misaligned) reference.
            let node = unsafe { node_ptr.read_unaligned() };

            let adapter_name = unsafe { node.FriendlyName.to_string()? };
            if !adapter_names.contains(&adapter_name) {
                adapter_names.push(adapter_name);
            }

            node_ptr = node.Next;
        }

        // Printable ASCII is 0x20..=0x7E; 0x7F is the DEL control character.
        let validate_regex = Regex::new(r"^[\x20-\x7e\p{Letter}]+$").unwrap();
        assert!(adapter_names
            .iter()
            .all(|name| validate_regex.is_match(name)));

        Ok(())
    }

    /// `FirstCallExpectation::OkValue`: without `ASSOCF_NOTRUNCATE`, the buffer-less first call is
    /// expected to succeed with `S_FALSE`.
    #[test]
    fn expect_ok_value() -> windows::core::Result<()> {
        let mut buffer = Vec::new();
        let mut buffer_size = 0;

        let success_hresult = dual_call(
            FirstCallExpectation::OkValue(S_FALSE),
            |getting_buffer_size| {
                unsafe {
                    AssocQueryStringW(
                        ASSOCF_NONE,
                        ASSOCSTR_EXECUTABLE,
                        w!(".msi"),
                        PCWSTR::NULL,
                        if getting_buffer_size {
                            PWSTR::NULL
                        } else {
                            buffer.resize(buffer_size as _, 0);
                            PWSTR(buffer.as_mut_ptr())
                        },
                        &mut buffer_size,
                    )
                }
                .ok_with_hresult()
            },
        )?;

        if success_hresult == S_OK && buffer_size > 0 {
            // `buffer_size` includes the terminating null, which is excluded here.
            let string = String::from_utf16(&buffer[..(buffer_size - 1) as _])?;
            assert!(Regex::new(r"(?i)\\System32\\msiexec.exe$")
                .unwrap()
                .is_match(&string));
        } else {
            return Err(E_FAIL.into());
        }

        Ok(())
    }

    /// `FirstCallExpectation::HResultError`: with `ASSOCF_NOTRUNCATE`, the first call with an empty
    /// buffer is expected to fail with `E_POINTER`.
    #[test]
    fn expect_hresult_error() -> windows::core::Result<()> {
        let mut buffer = Vec::new();
        let mut buffer_size = 0;

        let success_hresult = dual_call(
            FirstCallExpectation::HResultError(E_POINTER),
            |getting_buffer_size| {
                if !getting_buffer_size {
                    buffer.resize(buffer_size as _, 0);
                }

                unsafe {
                    AssocQueryStringW(
                        ASSOCF_NOTRUNCATE,
                        ASSOCSTR_EXECUTABLE,
                        w!(".msi"),
                        PCWSTR::NULL,
                        PWSTR(buffer.as_mut_ptr()),
                        &mut buffer_size,
                    )
                }
                .ok_with_hresult()
            },
        )?;

        if success_hresult == S_OK && buffer_size > 0 {
            // `buffer_size` includes the terminating null, which is excluded here.
            let string = String::from_utf16(&buffer[..(buffer_size - 1) as _])?;
            assert!(Regex::new(r"(?i)\\System32\\msiexec.exe$")
                .unwrap()
                .is_match(&string));
        } else {
            return Err(E_FAIL.into());
        }

        Ok(())
    }
}