wdk_macros/
lib.rs

1// Copyright (c) Microsoft Corporation
2// License: MIT OR Apache-2.0
3
4//! A collection of macros that help make it easier to interact with
5//! [`wdk-sys`]'s direct bindings to the Windows Driver Kit (WDK).
6
7use std::{collections::BTreeMap, path::PathBuf, str::FromStr};
8
9use fs4::fs_std::FileExt;
10use itertools::Itertools;
11use proc_macro::TokenStream;
12use proc_macro2::{Span, TokenStream as TokenStream2};
13use quote::{format_ident, quote, ToTokens};
14use serde::{Deserialize, Serialize};
15use syn::{
16    parse::{Parse, ParseStream, Parser},
17    parse2,
18    parse_file,
19    parse_quote,
20    punctuated::Punctuated,
21    AngleBracketedGenericArguments,
22    Attribute,
23    BareFnArg,
24    Error,
25    Expr,
26    ExprCall,
27    File,
28    GenericArgument,
29    Ident,
30    Item,
31    ItemType,
32    LitStr,
33    Path,
34    PathArguments,
35    PathSegment,
36    Result,
37    ReturnType,
38    Signature,
39    Stmt,
40    Token,
41    Type,
42    TypeBareFn,
43    TypePath,
44};
45
/// Name of the `bindgen`-generated Rust module that contains the `TableIndex`
/// constants for the `WDF`'s function table.
///
/// The parsed types file is searched for a top-level module with this name
/// when the function information cache is generated.
const WDF_FUNC_ENUM_MOD_NAME: &str = "_WDFFUNCENUM";
49
50/// A procedural macro that allows WDF functions to be called by name.
51///
52/// This macro is only intended to be used in the `wdk-sys` crate. Users wanting
53/// to call WDF functions should use the macro in `wdk-sys`. This macro differs
54/// from the one in [`wdk-sys`] in that it must pass in the generated types from
55/// `wdk-sys` as an argument to the macro.
56#[proc_macro]
57pub fn call_unsafe_wdf_function_binding(input_tokens: TokenStream) -> TokenStream {
58    call_unsafe_wdf_function_binding_impl(TokenStream2::from(input_tokens)).into()
59}
60
/// A trait to provide additional functionality to the [`String`] type
trait StringExt {
    /// Convert a string to `snake_case`.
    ///
    /// Returns a newly allocated [`String`]; the receiver is left untouched.
    fn to_snake_case(&self) -> String;
}
66
67/// A trait to provide additional functionality to `std::result::Result`
68trait ResultExt<T, E> {
69    fn to_syn_result(self, span: Span, error: &str) -> syn::Result<T>;
70}
71
/// Struct storing string representations of the information we want to cache
/// from `types.rs`.
///
/// The fields are plain strings so that the struct can be serialized to and
/// deserialized from the on-disk cache.
#[derive(Debug, Deserialize, PartialEq, Serialize)]
struct CachedFunctionInfo {
    /// Stringified token stream of the function's parameter list.
    parameters: String,
    /// Stringified token stream of the function's return type.
    return_type: String,
}
79
/// Struct storing the input tokens directly parsed from calls to
/// `call_unsafe_wdf_function_binding` macro
///
/// The expected input shape is a string literal, a comma, the function
/// identifier, and then optionally a comma followed by a comma-separated list
/// of argument expressions.
#[derive(Debug, PartialEq)]
struct Inputs {
    /// Path to file where generated type information resides.
    types_path: LitStr,
    /// The name of the WDF function to call. This matches the name of the
    /// function in C/C++.
    wdf_function_identifier: Ident,
    /// The arguments to pass to the WDF function. These should match the
    /// function signature of the WDF function. Empty when the macro is invoked
    /// with no arguments after the function name.
    wdf_function_arguments: Punctuated<Expr, Token![,]>,
}
93
/// Struct storing all the AST fragments derived from [`Inputs`]. This
/// represents all the ASTs derived from [`Inputs`]. These ultimately get used
/// in the final generated code.
#[derive(Debug, PartialEq)]
struct DerivedASTFragments {
    /// Identifier of the function pointer type alias, of the form
    /// `PFN_<UPPERCASED C FUNCTION NAME>`.
    function_pointer_type: Ident,
    /// Identifier of the function table index constant, of the form
    /// `<C function name>TableIndex`.
    function_table_index: Ident,
    /// Parameter list (names and types) parsed from the cached function
    /// information.
    parameters: Punctuated<BareFnArg, Token![,]>,
    /// Just the names of the entries in `parameters`, in the same order.
    parameter_identifiers: Punctuated<Ident, Token![,]>,
    /// Return type parsed from the cached function information.
    return_type: ReturnType,
    /// Argument expressions supplied by the macro caller.
    arguments: Punctuated<Expr, Token![,]>,
    /// Name of the generated inline wrapper function, of the form
    /// `<snake_case C function name>_impl`.
    inline_wdf_fn_name: Ident,
}
107
/// Struct storing the AST fragments that form distinct sections of the final
/// generated code. Each field is derived from [`DerivedASTFragments`].
struct IntermediateOutputASTFragments {
    /// Optional attribute computed from the return type by
    /// `generate_must_use_attribute`; emitted on the inline function when
    /// present.
    must_use_attribute: Option<Attribute>,
    /// Signature of the generated inline function:
    /// `unsafe fn <name>(<parameters>) <return type>`.
    inline_wdf_fn_signature: Signature,
    /// Statements making up the body of the generated inline function.
    inline_wdf_fn_body_statments: Vec<Stmt>,
    /// Call expression invoking the generated inline function with the
    /// caller-supplied arguments.
    inline_wdf_fn_invocation: ExprCall,
}
116
/// Struct to represent a file lock guard. This struct enforces RAII, ensuring
/// that the file lock is released when the guard goes out of scope.
struct FileLockGuard {
    /// The file on which the exclusive lock is held for the guard's lifetime.
    file: std::fs::File,
}
122
123impl FileLockGuard {
124    fn new(file: std::fs::File, span: Span) -> Result<Self> {
125        FileExt::lock_exclusive(&file).to_syn_result(span, "unable to obtain file lock")?;
126        Ok(Self { file })
127    }
128}
129
130impl Drop for FileLockGuard {
131    fn drop(&mut self) {
132        let _ = FileExt::unlock(&self.file);
133    }
134}
135
136impl StringExt for String {
137    fn to_snake_case(&self) -> String {
138        // There will be, at max, 2 characters unhandled by the 3-char windows. It is
139        // only less than 2 when the string has length less than 2
140        const MAX_PADDING_NEEDED: usize = 2;
141
142        let mut snake_case_string = Self::with_capacity(self.len());
143
144        for (current_char, next_char, next_next_char) in self
145            .chars()
146            .map(Some)
147            .chain([None; MAX_PADDING_NEEDED])
148            .tuple_windows()
149            .filter_map(|(c1, c2, c3)| Some((c1?, c2, c3)))
150        {
151            // Handle camelCase or PascalCase word boundary (e.g. lC in camelCase)
152            if current_char.is_lowercase() && next_char.is_some_and(|c| c.is_ascii_uppercase()) {
153                snake_case_string.push(current_char);
154                snake_case_string.push('_');
155            }
156            // Handle UPPERCASE acronym word boundary (e.g. ISt in ASCIIString)
157            else if current_char.is_uppercase()
158                && next_char.is_some_and(|c| c.is_ascii_uppercase())
159                && next_next_char.is_some_and(|c| c.is_ascii_lowercase())
160            {
161                snake_case_string.push(current_char.to_ascii_lowercase());
162                snake_case_string.push('_');
163            } else {
164                snake_case_string.push(current_char.to_ascii_lowercase());
165            }
166        }
167
168        snake_case_string
169    }
170}
171
172impl<T, E: std::error::Error> ResultExt<T, E> for std::result::Result<T, E> {
173    fn to_syn_result(self, span: Span, error_description: &str) -> syn::Result<T> {
174        self.map_err(|err| Error::new(span, format!("{error_description}, {err}")))
175    }
176}
177
178impl From<(Punctuated<BareFnArg, Token![,]>, ReturnType)> for CachedFunctionInfo {
179    fn from((parameters, return_type): (Punctuated<BareFnArg, Token![,]>, ReturnType)) -> Self {
180        Self {
181            parameters: parameters.to_token_stream().to_string(),
182            return_type: return_type.to_token_stream().to_string(),
183        }
184    }
185}
186
187impl Parse for Inputs {
188    fn parse(input: ParseStream) -> Result<Self> {
189        let types_path = input.parse::<LitStr>()?;
190
191        input.parse::<Token![,]>()?;
192        let c_wdf_function_identifier = input.parse::<Ident>()?;
193
194        // Support WDF apis with no arguments
195        if input.is_empty() {
196            return Ok(Self {
197                types_path,
198                wdf_function_identifier: c_wdf_function_identifier,
199                wdf_function_arguments: Punctuated::new(),
200            });
201        }
202
203        input.parse::<Token![,]>()?;
204        let wdf_function_arguments = input.parse_terminated(Expr::parse, Token![,])?;
205
206        Ok(Self {
207            types_path,
208            wdf_function_identifier: c_wdf_function_identifier,
209            wdf_function_arguments,
210        })
211    }
212}
213
214impl Inputs {
215    fn generate_derived_ast_fragments(self) -> Result<DerivedASTFragments> {
216        let function_pointer_type = format_ident!(
217            "PFN_{uppercase_c_function_name}",
218            uppercase_c_function_name = self.wdf_function_identifier.to_string().to_uppercase(),
219            span = self.wdf_function_identifier.span()
220        );
221        let function_table_index = format_ident!(
222            "{wdf_function_identifier}TableIndex",
223            wdf_function_identifier = self.wdf_function_identifier,
224            span = self.wdf_function_identifier.span()
225        );
226
227        let function_name_to_info_map: BTreeMap<String, CachedFunctionInfo> =
228            get_wdf_function_info_map(&self.types_path, self.wdf_function_identifier.span())?;
229        let function_info = function_name_to_info_map
230            .get(&self.wdf_function_identifier.to_string())
231            .ok_or_else(|| {
232                Error::new(
233                    self.wdf_function_identifier.span(),
234                    format!(
235                        "Failed to find function info for {}",
236                        self.wdf_function_identifier
237                    ),
238                )
239            })?;
240        let parameters_tokens = TokenStream2::from_str(&function_info.parameters).to_syn_result(
241            self.wdf_function_identifier.span(),
242            "unable to parse parameter tokens",
243        )?;
244        let return_type_tokens = TokenStream2::from_str(&function_info.return_type).to_syn_result(
245            self.wdf_function_identifier.span(),
246            "unable to parse return type tokens",
247        )?;
248        let parameters =
249            Punctuated::<BareFnArg, Token![,]>::parse_terminated.parse2(parameters_tokens)?;
250        let return_type = ReturnType::parse.parse2(return_type_tokens)?;
251
252        let parameter_identifiers = parameters
253            .iter()
254            .cloned()
255            .map(|bare_fn_arg| {
256                if let Some((identifier, _)) = bare_fn_arg.name {
257                    return Ok(identifier);
258                }
259                Err(Error::new(
260                    function_pointer_type.span(),
261                    format!("Expected fn parameter to have a name: {bare_fn_arg:#?}"),
262                ))
263            })
264            .collect::<Result<_>>()?;
265        let inline_wdf_fn_name = format_ident!(
266            "{c_function_name_snake_case}_impl",
267            c_function_name_snake_case = self.wdf_function_identifier.to_string().to_snake_case()
268        );
269
270        Ok(DerivedASTFragments {
271            function_pointer_type,
272            function_table_index,
273            parameters,
274            parameter_identifiers,
275            return_type,
276            arguments: self.wdf_function_arguments,
277            inline_wdf_fn_name,
278        })
279    }
280}
281
impl DerivedASTFragments {
    /// Builds the distinct sections of the generated code (optional attribute,
    /// inline function signature, inline function body, and the invocation of
    /// that function) from the derived AST fragments.
    ///
    /// Note that the `//` comments inside the `parse_quote!` invocations below
    /// document the generated code for readers of this file; they are not part
    /// of the emitted tokens.
    fn generate_intermediate_output_ast_fragments(self) -> IntermediateOutputASTFragments {
        let Self {
            function_pointer_type,
            function_table_index,
            parameters,
            parameter_identifiers,
            return_type,
            arguments,
            inline_wdf_fn_name,
        } = self;

        let must_use_attribute = generate_must_use_attribute(&return_type);

        let inline_wdf_fn_signature = parse_quote! {
            unsafe fn #inline_wdf_fn_name(#parameters) #return_type
        };

        let inline_wdf_fn_body_statments = parse_quote! {
            // Get handle to WDF function from the function table
            let wdf_function: wdk_sys::#function_pointer_type = Some(
                // SAFETY: This `transmute` from a no-argument function pointer to a function pointer with the correct
                //         arguments for the WDF function is safe because WDF maintains the strict mapping between the
                //         function table index and the correct function pointer type.
                unsafe {
                    let wdf_function_table = wdk_sys::WdfFunctions;
                    let wdf_function_count = wdk_sys::wdf::__private::get_wdf_function_count();

                    // SAFETY: This is safe because:
                    //         1. `WdfFunctions` is valid for reads for `wdf_function_count` * `core::mem::size_of::<WDFFUNC>()`
                    //            bytes, and is properly aligned.
                    //         2. `WdfFunctions` points to `wdf_function_count` consecutive properly initialized values of
                    //            type `WDFFUNC`.
                    //         3. WDF does not mutate the memory referenced by the returned slice for its entire `'static` lifetime.
                    //         4. The total size, `wdf_function_count` * `core::mem::size_of::<WDFFUNC>()`, of the slice must be no
                    //            larger than `isize::MAX`. This is checked in debug builds by the `debug_assert!` below.

                    debug_assert!(isize::try_from(wdf_function_count * core::mem::size_of::<wdk_sys::WDFFUNC>()).is_ok());
                    let wdf_function_table = core::slice::from_raw_parts(wdf_function_table, wdf_function_count);

                    core::mem::transmute(
                        // FIXME: investigate why _WDFFUNCENUM does not have a generated type alias without the underscore prefix
                        wdf_function_table[wdk_sys::_WDFFUNCENUM::#function_table_index as usize],
                    )
                }
            );

            // Call the WDF function with the supplied args. This mirrors what happens in the inlined WDF function in
            // the various wdf headers(ex. wdfdriver.h)
            if let Some(wdf_function) = wdf_function {
                // SAFETY: The WDF function pointer is always valid because it's an entry in the
                // `wdk_sys::WdfFunctions` table indexed by `function_table_index` and guarded by the type-safety of
                // `function_pointer_type`. The passed arguments are also guaranteed to be of a compatible type due to
                // `function_pointer_type`.
                unsafe {
                    (wdf_function)(
                        wdk_sys::WdfDriverGlobals,
                        #parameter_identifiers
                    )
                }
            } else {
                unreachable!("Option should never be None");
            }
        };

        // The caller-supplied argument expressions are forwarded to the inline function
        let inline_wdf_fn_invocation = parse_quote! {
            #inline_wdf_fn_name(#arguments)
        };

        IntermediateOutputASTFragments {
            must_use_attribute,
            inline_wdf_fn_signature,
            inline_wdf_fn_body_statments,
            inline_wdf_fn_invocation,
        }
    }
}
359
360impl IntermediateOutputASTFragments {
361    fn assemble_final_output(self) -> TokenStream2 {
362        let Self {
363            must_use_attribute,
364            inline_wdf_fn_signature,
365            inline_wdf_fn_body_statments,
366            inline_wdf_fn_invocation,
367        } = self;
368
369        let conditional_must_use_attribute =
370            must_use_attribute.map_or_else(TokenStream2::new, quote::ToTokens::into_token_stream);
371
372        quote! {
373            {
374                // Use a private module to prevent leaking of glob import into inline_wdf_fn_invocation's parameters
375                mod private__ {
376                    // Glob import types from wdk_sys. glob importing is done instead of blindly prepending the
377                    // paramters types with wdk_sys:: because bindgen generates some paramters as native rust types
378                    use wdk_sys::*;
379
380                    // If the function returns a value, add a `#[must_use]` attribute to the function
381                    #conditional_must_use_attribute
382                    // Encapsulate the code in an inline functions to allow for condition must_use attribute.
383                    //  core::hint::must_use is not stable yet: https://github.com/rust-lang/rust/issues/94745
384                    #[inline(always)]
385                    pub #inline_wdf_fn_signature {
386                        #(#inline_wdf_fn_body_statments)*
387                    }
388                }
389
390                private__::#inline_wdf_fn_invocation
391            }
392        }
393    }
394}
395
396fn call_unsafe_wdf_function_binding_impl(input_tokens: TokenStream2) -> TokenStream2 {
397    let inputs = match parse2::<Inputs>(input_tokens) {
398        Ok(syntax_tree) => syntax_tree,
399        Err(err) => return err.to_compile_error(),
400    };
401
402    let derived_ast_fragments = match inputs.generate_derived_ast_fragments() {
403        Ok(derived_ast_fragments) => derived_ast_fragments,
404        Err(err) => return err.to_compile_error(),
405    };
406
407    derived_ast_fragments
408        .generate_intermediate_output_ast_fragments()
409        .assemble_final_output()
410}
411
412/// Fetch the function table information from the cache, if
413/// it exists. If not, create the cache by reading the
414/// `types.rs` file. Returns a `BTreeMap`, where
415/// `key` is the function name and `value` is the cached function table
416/// information.
417///
418/// Instead of parsing `types.rs` for relevant data on
419/// every macro invocation, all relevant function
420/// table information is extracted during the first `proc-macro` invocation and
421/// serialized to a location accessible by all proc-macro invocations.
422/// Subsequent invocations fetching from the cache significantly reduces
423/// compilation time.
424fn get_wdf_function_info_map(
425    types_path: &LitStr,
426    span: Span,
427) -> Result<BTreeMap<String, CachedFunctionInfo>> {
428    cfg_if::cfg_if! {
429        if #[cfg(test)] {
430            let scratch_dir = scratch::path(concat!(env!("CARGO_CRATE_NAME"), "_ast_fragments_test"));
431        } else {
432            let scratch_dir = scratch::path(concat!(env!("CARGO_CRATE_NAME"), "_ast_fragments"));
433        }
434    }
435
436    let cached_function_info_map_path = scratch_dir.join("cached_function_info_map.json");
437
438    if !cached_function_info_map_path.exists() {
439        let flock = std::fs::File::create(scratch_dir.join(".lock"))
440            .to_syn_result(span, "unable to create file lock")?;
441
442        // When _flock_guard goes out of scope, the file lock is released
443        let _flock_guard = FileLockGuard::new(flock, span)
444            .to_syn_result(span, "unable to create file lock guard")?;
445
446        // Before this thread acquires the lock, it's possible that a concurrent thread
447        // already created the cache. If so, this thread skips cache generation.
448        if !cached_function_info_map_path.exists() {
449            let function_info_map = create_wdf_function_info_file_cache(
450                types_path,
451                cached_function_info_map_path.as_path(),
452                span,
453            )?;
454            return Ok(function_info_map);
455        }
456    }
457    let function_info_map =
458        read_wdf_function_info_file_cache(cached_function_info_map_path.as_path(), span)?;
459    Ok(function_info_map)
460}
461
462/// Reads the cache of function information, then deserializes it into a
463/// `BTreeMap`.
464fn read_wdf_function_info_file_cache(
465    cached_function_info_map_path: &std::path::Path,
466    span: Span,
467) -> Result<BTreeMap<String, CachedFunctionInfo>> {
468    let generated_map_string = std::fs::read_to_string(cached_function_info_map_path)
469        .to_syn_result(span, "unable to read cache to string")?;
470    let map: BTreeMap<String, CachedFunctionInfo> = serde_json::from_str(&generated_map_string)
471        .to_syn_result(span, "unable to parse cache to BTreeMap")?;
472    Ok(map)
473}
474
475/// Generates the cache of function information, then
476/// serializes it into a JSON string and writes it to a designated location.
477/// Must obtain an exclusive file lock prior to calling this function to prevent
478/// concurrent threads from reading and writing to the same file.
479fn create_wdf_function_info_file_cache(
480    types_path: &LitStr,
481    cached_function_info_map_path: &std::path::Path,
482    span: Span,
483) -> Result<BTreeMap<String, CachedFunctionInfo>> {
484    let generated_map = generate_wdf_function_info_file_cache(types_path, span)?;
485    let generated_map_string = serde_json::to_string(&generated_map)
486        .to_syn_result(span, "unable to parse cache to JSON string")?;
487    std::fs::write(cached_function_info_map_path, generated_map_string)
488        .to_syn_result(span, "unable to write cache to file")?;
489    Ok(generated_map)
490}
491
492/// Parses file from `types_path` to generate a `BTreeMap` of
493/// function information, where `key` is the function name and `value` is
494/// the cached function table information.
495fn generate_wdf_function_info_file_cache(
496    types_path: &LitStr,
497    span: Span,
498) -> Result<BTreeMap<String, CachedFunctionInfo>> {
499    let types_ast = parse_types_ast(types_path)?;
500    let func_enum_mod = types_ast
501        .items
502        .iter()
503        .find_map(|item| {
504            if let Item::Mod(mod_alias) = item {
505                if mod_alias.ident == WDF_FUNC_ENUM_MOD_NAME {
506                    return Some(mod_alias);
507                }
508            }
509            None
510        })
511        .ok_or_else(|| {
512            Error::new(
513                span,
514                format!("Failed to find {WDF_FUNC_ENUM_MOD_NAME} module in types.rs file",),
515            )
516        })?;
517
518    let (_brace, func_enum_mod_contents) = &func_enum_mod.content.as_ref().ok_or_else(|| {
519        Error::new(
520            span,
521            format!("Failed to find {WDF_FUNC_ENUM_MOD_NAME} module contents in types.rs file",),
522        )
523    })?;
524
525    func_enum_mod_contents
526        .iter()
527        .filter_map(|item| {
528            if let Item::Const(const_alias) = item {
529                return const_alias
530                    .ident
531                    .to_string()
532                    .strip_suffix("TableIndex")
533                    .and_then(|function_name| {
534                        let function_pointer_type = format_ident!(
535                            "PFN_{uppercase_c_function_name}",
536                            uppercase_c_function_name = function_name.to_uppercase(),
537                            span = span
538                        );
539                        generate_cached_function_info(&types_ast, &function_pointer_type)
540                            .transpose()
541                            .map(|generate_cached_function_info_result| {
542                                generate_cached_function_info_result.map(|cached_function_info| {
543                                    (function_name.to_string(), cached_function_info)
544                                })
545                            })
546                    });
547            }
548            None
549        })
550        .collect()
551}
552
553fn parse_types_ast(path: &LitStr) -> Result<File> {
554    let types_path = PathBuf::from(path.value());
555    let types_path = match types_path.canonicalize() {
556        Ok(types_path) => types_path,
557        Err(err) => {
558            return Err(Error::new(
559                path.span(),
560                format!(
561                    "Failed to canonicalize types_path ({}): {err}",
562                    types_path.display()
563                ),
564            ));
565        }
566    };
567
568    let types_file_contents = match std::fs::read_to_string(&types_path) {
569        Ok(contents) => contents,
570        Err(err) => {
571            return Err(Error::new(
572                path.span(),
573                format!(
574                    "Failed to read wdk-sys types information from {}: {err}",
575                    types_path.display(),
576                ),
577            ));
578        }
579    };
580
581    match parse_file(&types_file_contents) {
582        Ok(wdk_sys_types_rs_abstract_syntax_tree) => Ok(wdk_sys_types_rs_abstract_syntax_tree),
583        Err(err) => Err(Error::new(
584            path.span(),
585            format!(
586                "Failed to parse wdk-sys types information from {} into AST: {err}",
587                types_path.display(),
588            ),
589        )),
590    }
591}
592
593/// Generate the function parameters and return type corresponding to the
594/// function signature of the `function_pointer_type` type alias found in
595/// bindgen-generated types information
596///
597/// # Examples
598///
599/// Passing the `PFN_WDFDRIVERCREATE` [`Ident`] as `function_pointer_type` would
600/// return a [`Punctuated`] representation of
601///
602/// ```rust, compile_fail
603/// DriverObject: PDRIVER_OBJECT,
604/// RegistryPath: PCUNICODE_STRING,
605/// DriverAttributes: WDF_OBJECT_ATTRIBUTES,
606/// DriverConfig: PWDF_DRIVER_CONFIG,
607/// Driver: *mut WDFDRIVER
608/// ```
609///
610/// and return type as the [`ReturnType`] representation of `wdk_sys::NTSTATUS`
611fn generate_cached_function_info(
612    types_ast: &File,
613    function_pointer_type: &Ident,
614) -> Result<Option<CachedFunctionInfo>> {
615    match find_type_alias_definition(types_ast, function_pointer_type) {
616        Ok(type_alias_definition) => {
617            let fn_pointer_definition =
618                extract_fn_pointer_definition(type_alias_definition, function_pointer_type.span())?;
619            Ok(Some(
620                parse_fn_pointer_definition(fn_pointer_definition, function_pointer_type.span())?
621                    .into(),
622            ))
623        }
624        // `types.rs` includes only a subset of types listed in _WDFFUNCENUM. Therefore, not finding
625        // a type alias definition is expected behavior.
626        Err(_err) => Ok(None),
627    }
628}
629
630/// Find type alias declaration and definition that matches the Ident of
631/// `function_pointer_type` in `syn::File` AST
632///
633/// # Examples
634///
635/// Passing the `PFN_WDFDRIVERCREATE` [`Ident`] as `function_pointer_type` would
636/// return a [`ItemType`] representation of:
637///
638/// ```rust, compile_fail
639/// pub type PFN_WDFDRIVERCREATE = ::core::option::Option<
640///     unsafe extern "C" fn(
641///         DriverGlobals: PWDF_DRIVER_GLOBALS,
642///         DriverObject: PDRIVER_OBJECT,
643///         RegistryPath: PCUNICODE_STRING,
644///         DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
645///         DriverConfig: PWDF_DRIVER_CONFIG,
646///         Driver: *mut WDFDRIVER,
647///     ) -> NTSTATUS,
648/// >;
649/// ```
650fn find_type_alias_definition<'a>(
651    types_ast: &'a File,
652    function_pointer_type: &Ident,
653) -> Result<&'a ItemType> {
654    types_ast
655        .items
656        .iter()
657        .find_map(|item| {
658            if let Item::Type(type_alias) = item {
659                if type_alias.ident == *function_pointer_type {
660                    return Some(type_alias);
661                }
662            }
663            None
664        })
665        .ok_or_else(|| {
666            Error::new(
667                function_pointer_type.span(),
668                format!("Failed to find type alias definition for {function_pointer_type}"),
669            )
670        })
671}
672
673/// Extract the [`TypePath`] representing the function pointer definition from
674/// the [`ItemType`]
675///
676/// # Examples
677///
678/// The [`ItemType`] representation of
679///
680/// ```rust, compile_fail
681/// pub type PFN_WDFDRIVERCREATE = ::core::option::Option<
682///     unsafe extern "C" fn(
683///         DriverGlobals: PWDF_DRIVER_GLOBALS,
684///         DriverObject: PDRIVER_OBJECT,
685///         RegistryPath: PCUNICODE_STRING,
686///         DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
687///         DriverConfig: PWDF_DRIVER_CONFIG,
688///         Driver: *mut WDFDRIVER,
689///     ) -> NTSTATUS,
690/// >;
691/// ```
692///
693/// would return the [`TypePath`] representation of
694///
695/// ```rust, compile_fail
696/// ::core::option::Option<
697///     unsafe extern "C" fn(
698///         DriverGlobals: PWDF_DRIVER_GLOBALS,
699///         DriverObject: PDRIVER_OBJECT,
700///         RegistryPath: PCUNICODE_STRING,
701///         DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
702///         DriverConfig: PWDF_DRIVER_CONFIG,
703///         Driver: *mut WDFDRIVER,
704///     ) -> NTSTATUS,
705/// >
706/// ```
707fn extract_fn_pointer_definition(type_alias: &ItemType, error_span: Span) -> Result<&TypePath> {
708    if let Type::Path(fn_pointer) = type_alias.ty.as_ref() {
709        Ok(fn_pointer)
710    } else {
711        Err(Error::new(
712            error_span,
713            format!("Expected Type::Path when parsing  ItemType.ty:\n{type_alias:#?}"),
714        ))
715    }
716}
717
718/// Parse the parameter list (both names and types) and the return type from the
719/// [`TypePath`] representing the function pointer definition
720///
721/// # Examples
722///
723/// The [`TypePath`] representation of
724///
725/// ```rust, compile_fail
726/// ::core::option::Option<
727///     unsafe extern "C" fn(
728///         DriverGlobals: PWDF_DRIVER_GLOBALS,
729///         DriverObject: PDRIVER_OBJECT,
730///         RegistryPath: PCUNICODE_STRING,
731///         DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
732///         DriverConfig: PWDF_DRIVER_CONFIG,
733///         Driver: *mut WDFDRIVER,
734///     ) -> NTSTATUS,
735/// >
736/// ```
737///
738/// would return the parsed parameter list as the [`Punctuated`] representation
739/// of
740///
741/// ```rust, compile_fail
742/// DriverObject: PDRIVER_OBJECT,
743/// RegistryPath: PCUNICODE_STRING,
744/// DriverAttributes: WDF_OBJECT_ATTRIBUTES,
745/// DriverConfig: PWDF_DRIVER_CONFIG,
746/// Driver: *mut WDFDRIVER
747/// ```
748///
749/// and return type as the [`ReturnType`] representation of `wdk_sys::NTSTATUS`
750fn parse_fn_pointer_definition(
751    fn_pointer_typepath: &TypePath,
752    error_span: Span,
753) -> Result<(Punctuated<BareFnArg, Token![,]>, ReturnType)> {
754    let bare_fn_type = extract_bare_fn_type(fn_pointer_typepath, error_span)?;
755    let fn_parameters = compute_fn_parameters(bare_fn_type, error_span)?;
756    let return_type = compute_return_type(bare_fn_type);
757
758    Ok((fn_parameters, return_type))
759}
760
761/// Extract the [`TypeBareFn`] (i.e. function definition) from the [`TypePath`]
762/// (i.e. the function pointer option) representing the function
763///
764/// # Examples
765///
766/// The [`TypePath`] representation of
767///
768/// ```rust, compile_fail
769/// ::core::option::Option<
770///     unsafe extern "C" fn(
771///         DriverGlobals: PWDF_DRIVER_GLOBALS,
772///         DriverObject: PDRIVER_OBJECT,
773///         RegistryPath: PCUNICODE_STRING,
774///         DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
775///         DriverConfig: PWDF_DRIVER_CONFIG,
776///         Driver: *mut WDFDRIVER,
777///     ) -> NTSTATUS,
778/// >
779/// ```
780///
781/// would return the [`TypeBareFn`] representation of
782///
783/// ```rust, compile_fail
784/// unsafe extern "C" fn(
785///     DriverGlobals: PWDF_DRIVER_GLOBALS,
786///     DriverObject: PDRIVER_OBJECT,
787///     RegistryPath: PCUNICODE_STRING,
788///     DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
789///     DriverConfig: PWDF_DRIVER_CONFIG,
790///     Driver: *mut WDFDRIVER,
791/// ) -> NTSTATUS,
792/// ```
793fn extract_bare_fn_type(fn_pointer_typepath: &TypePath, error_span: Span) -> Result<&TypeBareFn> {
794    let option_path_segment: &PathSegment =
795        fn_pointer_typepath.path.segments.last().ok_or_else(|| {
796            Error::new(
797                error_span,
798                format!("Expected at least one PathSegment in TypePath:\n{fn_pointer_typepath:#?}"),
799            )
800        })?;
801    if option_path_segment.ident != "Option" {
802        return Err(Error::new(
803            error_span,
804            format!("Expected Option as last PathSegment in TypePath:\n{fn_pointer_typepath:#?}"),
805        ));
806    }
807    let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
808        args: ref option_angle_bracketed_args,
809        ..
810    }) = option_path_segment.arguments
811    else {
812        return Err(Error::new(
813            error_span,
814            format!(
815                "Expected AngleBracketed PathArguments in Option \
816                 PathSegment:\n{option_path_segment:#?}"
817            ),
818        ));
819    };
820    let bracketed_argument = option_angle_bracketed_args.first().ok_or_else(|| {
821        Error::new(
822            error_span,
823            format!(
824                "Expected exactly one GenericArgument in AngleBracketedGenericArguments:\n{:#?}",
825                option_path_segment.arguments
826            ),
827        )
828    })?;
829    let GenericArgument::Type(Type::BareFn(bare_fn_type)) = bracketed_argument else {
830        return Err(Error::new(
831            error_span,
832            format!("Expected TypeBareFn in GenericArgument:\n{bracketed_argument:#?}"),
833        ));
834    };
835    Ok(bare_fn_type)
836}
837
838/// Compute the function parameters based on the function definition
839///
840/// # Examples
841///
842/// The [`TypeBareFn`] representation of
843///
844/// ```rust, compile_fail
845/// unsafe extern "C" fn(
846///     DriverGlobals: PWDF_DRIVER_GLOBALS,
847///     DriverObject: PDRIVER_OBJECT,
848///     RegistryPath: PCUNICODE_STRING,
849///     DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
850///     DriverConfig: PWDF_DRIVER_CONFIG,
851///     Driver: *mut WDFDRIVER,
852/// ) -> NTSTATUS,
853/// ```
854///
855/// would return the [`Punctuated`] representation of
856/// ```rust, compile_fail
857/// DriverObject: PDRIVER_OBJECT,
858/// RegistryPath: PCUNICODE_STRING,
859/// DriverAttributes: WDF_OBJECT_ATTRIBUTES,
860/// DriverConfig: PWDF_DRIVER_CONFIG,
861/// Driver: *mut WDFDRIVER
862/// ```
863fn compute_fn_parameters(
864    bare_fn_type: &syn::TypeBareFn,
865    error_span: Span,
866) -> Result<Punctuated<BareFnArg, Token![,]>> {
867    // Validate that the first parameter is PWDF_DRIVER_GLOBALS
868    let Some(BareFnArg {
869        ty:
870            Type::Path(TypePath {
871                path:
872                    Path {
873                        segments: first_parameter_type_path,
874                        ..
875                    },
876                ..
877            }),
878        ..
879    }) = bare_fn_type.inputs.first()
880    else {
881        return Err(Error::new(
882            error_span,
883            format!(
884                "Expected at least one input parameter of type Path in \
885                 BareFnType:\n{bare_fn_type:#?}"
886            ),
887        ));
888    };
889    let Some(last_path_segment) = first_parameter_type_path.last() else {
890        return Err(Error::new(
891            error_span,
892            format!("Expected at least one PathSegment in TypePath:\n{bare_fn_type:#?}"),
893        ));
894    };
895    if last_path_segment.ident != "PWDF_DRIVER_GLOBALS" {
896        return Err(Error::new(
897            error_span,
898            format!(
899                "Expected PWDF_DRIVER_GLOBALS as last PathSegment in TypePath of first BareFnArg \
900                 input:\n{bare_fn_type:#?}"
901            ),
902        ));
903    }
904
905    Ok(bare_fn_type
906        .inputs
907        .iter()
908        .skip(1)
909        // transform argument names to snake_case with trailing underscores to lessen likelihood
910        // of shadowing issues
911        .map(|fn_arg| {
912            let arg_name = fn_arg.name.as_ref().map(|(ident, colon_token)| {
913                let modified_name = {
914                    let mut name = ident.to_string().to_snake_case();
915                    name.push_str("__");
916                    name
917                };
918                (Ident::new(&modified_name, ident.span()), *colon_token)
919            });
920
921            BareFnArg {
922                name: arg_name,
923                ..fn_arg.clone()
924            }
925        })
926        .collect())
927}
928
929/// Compute the return type based on the function defintion
930///
931/// # Examples
932///
933/// The [`TypeBareFn`] representation of
934///
935/// ```rust, compile_fail
936/// unsafe extern "C" fn(
937///     DriverGlobals: PWDF_DRIVER_GLOBALS,
938///     DriverObject: PDRIVER_OBJECT,
939///     RegistryPath: PCUNICODE_STRING,
940///     DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
941///     DriverConfig: PWDF_DRIVER_CONFIG,
942///     Driver: *mut WDFDRIVER,
943/// ) -> NTSTATUS,
944/// ```
945///
946/// would return the [`ReturnType`] representation of `wdk_sys::NTSTATUS`
947fn compute_return_type(bare_fn_type: &syn::TypeBareFn) -> ReturnType {
948    bare_fn_type.output.clone()
949}
950
951/// Generate the `#[must_use]` attribute if the return type is not `()`
952fn generate_must_use_attribute(return_type: &ReturnType) -> Option<Attribute> {
953    if matches!(return_type, ReturnType::Type(..)) {
954        Some(parse_quote! { #[must_use] })
955    } else {
956        None
957    }
958}
959
960#[cfg(test)]
961mod tests {
962    use std::sync::LazyLock;
963
964    use pretty_assertions::assert_eq as pretty_assert_eq;
965    use quote::ToTokens;
966
967    use super::*;
968
969    static SCRATCH_DIR: LazyLock<PathBuf> =
970        LazyLock::new(|| scratch::path(concat!(env!("CARGO_CRATE_NAME"), "_ast_fragments_test")));
971    const CACHE_FILE_NAME: &str = "cached_function_info_map.json";
972
973    fn with_file_lock_clean_env<F>(f: F)
974    where
975        F: FnOnce(),
976    {
977        let test_flock: std::fs::File =
978            std::fs::File::create(SCRATCH_DIR.join("test.lock")).unwrap();
979        FileExt::lock_exclusive(&test_flock).unwrap();
980
981        let cached_function_info_map_path = SCRATCH_DIR.join(CACHE_FILE_NAME);
982
983        pretty_assert_eq!(
984            cached_function_info_map_path.exists(),
985            false,
986            "could not remove file {}",
987            cached_function_info_map_path.display()
988        );
989
990        f();
991
992        if cached_function_info_map_path.exists() {
993            std::fs::remove_file(cached_function_info_map_path).unwrap();
994        }
995
996        FileExt::unlock(&test_flock).unwrap();
997    }
998
999    mod to_snake_case {
1000        use super::*;
1001
1002        #[test]
1003        fn camel_case() {
1004            let input = "camelCaseString".to_string();
1005            let expected = "camel_case_string";
1006
1007            pretty_assert_eq!(input.to_snake_case(), expected);
1008        }
1009
1010        #[test]
1011        fn short_camel_case() {
1012            let input = "aB".to_string();
1013            let expected = "a_b";
1014
1015            pretty_assert_eq!(input.to_snake_case(), expected);
1016        }
1017
1018        #[test]
1019        fn pascal_case() {
1020            let input = "PascalCaseString".to_string();
1021            let expected = "pascal_case_string";
1022
1023            pretty_assert_eq!(input.to_snake_case(), expected);
1024        }
1025
1026        #[test]
1027        fn pascal_case_with_leading_acronym() {
1028            let input = "ASCIIEncodedString".to_string();
1029            let expected = "ascii_encoded_string";
1030
1031            pretty_assert_eq!(input.to_snake_case(), expected);
1032        }
1033
1034        #[test]
1035        fn pascal_case_with_trailing_acronym() {
1036            let input = "IsASCII".to_string();
1037            let expected = "is_ascii";
1038
1039            pretty_assert_eq!(input.to_snake_case(), expected);
1040        }
1041
1042        #[test]
1043        fn screaming_snake_case() {
1044            let input = "PFN_WDF_DRIVER_DEVICE_ADD".to_string();
1045            let expected = "pfn_wdf_driver_device_add";
1046
1047            pretty_assert_eq!(input.to_snake_case(), expected);
1048        }
1049
1050        #[test]
1051        fn screaming_snake_case_with_leading_acronym() {
1052            let input = "ASCII_STRING".to_string();
1053            let expected = "ascii_string";
1054
1055            pretty_assert_eq!(input.to_snake_case(), expected);
1056        }
1057
1058        #[test]
1059        fn screaming_snake_case_with_leading_underscore() {
1060            let input = "_WDF_DRIVER_INIT_FLAGS".to_string();
1061            let expected = "_wdf_driver_init_flags";
1062
1063            pretty_assert_eq!(input.to_snake_case(), expected);
1064        }
1065
1066        #[test]
1067        fn snake_case() {
1068            let input = "snake_case_string".to_string();
1069            let expected = "snake_case_string";
1070
1071            pretty_assert_eq!(input.to_snake_case(), expected);
1072        }
1073
1074        #[test]
1075        fn snake_case_with_leading_underscore() {
1076            let input = "_snake_case_with_leading_underscore".to_string();
1077            let expected = "_snake_case_with_leading_underscore";
1078
1079            pretty_assert_eq!(input.to_snake_case(), expected);
1080        }
1081    }
1082
1083    mod inputs {
1084        use super::*;
1085
1086        mod parse {
1087            use super::*;
1088
1089            #[test]
1090            fn valid_input() {
1091                let input_tokens = quote! { "/path/to/generated/types/file.rs", WdfDriverCreate, driver, registry_path, WDF_NO_OBJECT_ATTRIBUTES, &mut driver_config, driver_handle_output };
1092                let expected = Inputs {
1093                    types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1094                    wdf_function_identifier: format_ident!("WdfDriverCreate"),
1095                    wdf_function_arguments: parse_quote! {
1096                        driver,
1097                        registry_path,
1098                        WDF_NO_OBJECT_ATTRIBUTES,
1099                        &mut driver_config,
1100                        driver_handle_output
1101                    },
1102                };
1103
1104                pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1105            }
1106
1107            #[test]
1108            fn valid_input_with_trailing_comma() {
1109                let input_tokens = quote! { "/path/to/generated/types/file.rs" , WdfDriverCreate, driver, registry_path, WDF_NO_OBJECT_ATTRIBUTES, &mut driver_config, driver_handle_output, };
1110                let expected = Inputs {
1111                    types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1112                    wdf_function_identifier: format_ident!("WdfDriverCreate"),
1113                    wdf_function_arguments: parse_quote! {
1114                        driver,
1115                        registry_path,
1116                        WDF_NO_OBJECT_ATTRIBUTES,
1117                        &mut driver_config,
1118                        driver_handle_output,
1119                    },
1120                };
1121
1122                pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1123            }
1124
1125            #[test]
1126            fn wdf_function_with_no_arguments() {
1127                let input_tokens =
1128                    quote! { "/path/to/generated/types/file.rs", WdfVerifierDbgBreakPoint };
1129                let expected = Inputs {
1130                    types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1131                    wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1132                    wdf_function_arguments: Punctuated::new(),
1133                };
1134
1135                pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1136            }
1137
1138            #[test]
1139            fn wdf_function_with_no_arguments_and_trailing_comma() {
1140                let input_tokens =
1141                    quote! { "/path/to/generated/types/file.rs", WdfVerifierDbgBreakPoint, };
1142                let expected = Inputs {
1143                    types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1144                    wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1145                    wdf_function_arguments: Punctuated::new(),
1146                };
1147
1148                pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1149            }
1150
1151            #[test]
1152            fn invalid_ident() {
1153                let input_tokens = quote! { "/path/to/generated/types/file.rs", 23InvalidIdent, driver, registry_path, WDF_NO_OBJECT_ATTRIBUTES, &mut driver_config, driver_handle_output, };
1154                let expected = Error::new(Span::call_site(), "expected identifier");
1155
1156                pretty_assert_eq!(
1157                    parse2::<Inputs>(input_tokens).unwrap_err().to_string(),
1158                    expected.to_string()
1159                );
1160            }
1161        }
1162
1163        mod generate_derived_ast_fragments {
1164            use super::*;
1165
1166            #[test]
1167            fn valid_input() {
1168                with_file_lock_clean_env(|| {
1169                    let inputs = Inputs {
1170                        types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1171                        wdf_function_identifier: format_ident!("WdfDriverCreate"),
1172                        wdf_function_arguments: parse_quote! {
1173                            driver,
1174                            registry_path,
1175                            WDF_NO_OBJECT_ATTRIBUTES,
1176                            &mut driver_config,
1177                            driver_handle_output,
1178                        },
1179                    };
1180                    let expected = DerivedASTFragments {
1181                        function_pointer_type: format_ident!("PFN_WDFDRIVERCREATE"),
1182                        function_table_index: format_ident!("WdfDriverCreateTableIndex"),
1183                        parameters: parse_quote! {
1184                            driver_object__: PDRIVER_OBJECT,
1185                            registry_path__: PCUNICODE_STRING,
1186                            driver_attributes__: PWDF_OBJECT_ATTRIBUTES,
1187                            driver_config__: PWDF_DRIVER_CONFIG,
1188                            driver__: *mut WDFDRIVER
1189                        },
1190                        parameter_identifiers: parse_quote! {
1191                            driver_object__,
1192                            registry_path__,
1193                            driver_attributes__,
1194                            driver_config__,
1195                            driver__
1196                        },
1197                        return_type: parse_quote! { -> NTSTATUS },
1198                        arguments: parse_quote! {
1199                            driver,
1200                            registry_path,
1201                            WDF_NO_OBJECT_ATTRIBUTES,
1202                            &mut driver_config,
1203                            driver_handle_output,
1204                        },
1205                        inline_wdf_fn_name: format_ident!("wdf_driver_create_impl"),
1206                    };
1207
1208                    pretty_assert_eq!(inputs.generate_derived_ast_fragments().unwrap(), expected);
1209                });
1210            }
1211
1212            #[test]
1213            fn valid_input_with_no_arguments() {
1214                with_file_lock_clean_env(|| {
1215                    let inputs = Inputs {
1216                        types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1217                        wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1218                        wdf_function_arguments: Punctuated::new(),
1219                    };
1220                    let expected = DerivedASTFragments {
1221                        function_pointer_type: format_ident!("PFN_WDFVERIFIERDBGBREAKPOINT"),
1222                        function_table_index: format_ident!("WdfVerifierDbgBreakPointTableIndex"),
1223                        parameters: Punctuated::new(),
1224                        parameter_identifiers: Punctuated::new(),
1225                        return_type: ReturnType::Default,
1226                        arguments: Punctuated::new(),
1227                        inline_wdf_fn_name: format_ident!("wdf_verifier_dbg_break_point_impl"),
1228                    };
1229
1230                    pretty_assert_eq!(inputs.generate_derived_ast_fragments().unwrap(), expected);
1231                });
1232            }
1233        }
1234    }
1235
1236    mod get_wdf_function_info_map {
1237        use super::*;
1238
1239        #[test]
1240        fn valid_input_no_cache() {
1241            with_file_lock_clean_env(|| {
1242                let inputs = Inputs {
1243                    types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1244                    wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1245                    wdf_function_arguments: Punctuated::new(),
1246                };
1247
1248                let mut expected: BTreeMap<String, CachedFunctionInfo> = BTreeMap::new();
1249                expected.insert(
1250                    "WdfDriverCreate".into(),
1251                    CachedFunctionInfo {
1252                        parameters: "driver_object__ : PDRIVER_OBJECT , registry_path__ : \
1253                                     PCUNICODE_STRING , driver_attributes__ : \
1254                                     PWDF_OBJECT_ATTRIBUTES , driver_config__ : \
1255                                     PWDF_DRIVER_CONFIG , driver__ : * mut WDFDRIVER"
1256                            .into(),
1257                        return_type: "-> NTSTATUS".into(),
1258                    },
1259                );
1260
1261                expected.insert(
1262                    "WdfVerifierDbgBreakPoint".into(),
1263                    CachedFunctionInfo {
1264                        parameters: String::new(),
1265                        return_type: String::new(),
1266                    },
1267                );
1268                pretty_assert_eq!(
1269                    get_wdf_function_info_map(
1270                        &inputs.types_path,
1271                        inputs.wdf_function_identifier.span()
1272                    )
1273                    .unwrap(),
1274                    expected
1275                );
1276
1277                pretty_assert_eq!(SCRATCH_DIR.join(CACHE_FILE_NAME).exists(), true);
1278            });
1279        }
1280
1281        #[test]
1282        fn valid_input_cache_exists() {
1283            with_file_lock_clean_env(|| {
1284                let inputs = Inputs {
1285                    types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1286                    wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1287                    wdf_function_arguments: Punctuated::new(),
1288                };
1289                // create cache with first call to get_wdf_function_info_map
1290
1291                get_wdf_function_info_map(
1292                    &inputs.types_path,
1293                    inputs.wdf_function_identifier.span(),
1294                )
1295                .unwrap();
1296
1297                // make sure cache exists
1298                pretty_assert_eq!(SCRATCH_DIR.join(CACHE_FILE_NAME).exists(), true);
1299
1300                let mut expected: BTreeMap<String, CachedFunctionInfo> = BTreeMap::new();
1301                expected.insert(
1302                    "WdfDriverCreate".into(),
1303                    CachedFunctionInfo {
1304                        parameters: "driver_object__ : PDRIVER_OBJECT , registry_path__ : \
1305                                     PCUNICODE_STRING , driver_attributes__ : \
1306                                     PWDF_OBJECT_ATTRIBUTES , driver_config__ : \
1307                                     PWDF_DRIVER_CONFIG , driver__ : * mut WDFDRIVER"
1308                            .into(),
1309                        return_type: "-> NTSTATUS".into(),
1310                    },
1311                );
1312
1313                expected.insert(
1314                    "WdfVerifierDbgBreakPoint".into(),
1315                    CachedFunctionInfo {
1316                        parameters: String::new(),
1317                        return_type: String::new(),
1318                    },
1319                );
1320                pretty_assert_eq!(
1321                    get_wdf_function_info_map(
1322                        &inputs.types_path,
1323                        inputs.wdf_function_identifier.span()
1324                    )
1325                    .unwrap(),
1326                    expected
1327                );
1328            });
1329        }
1330    }
1331
1332    mod generate_wdf_function_info_file_cache {
1333        use super::*;
1334
1335        #[test]
1336        fn valid_input() {
1337            let inputs = Inputs {
1338                types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1339                wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1340                wdf_function_arguments: Punctuated::new(),
1341            };
1342
1343            let mut expected: BTreeMap<String, CachedFunctionInfo> = BTreeMap::new();
1344            expected.insert(
1345                "WdfDriverCreate".into(),
1346                CachedFunctionInfo {
1347                    parameters: "driver_object__ : PDRIVER_OBJECT , registry_path__ : \
1348                                 PCUNICODE_STRING , driver_attributes__ : PWDF_OBJECT_ATTRIBUTES \
1349                                 , driver_config__ : PWDF_DRIVER_CONFIG , driver__ : * mut \
1350                                 WDFDRIVER"
1351                        .into(),
1352                    return_type: "-> NTSTATUS".into(),
1353                },
1354            );
1355
1356            expected.insert(
1357                "WdfVerifierDbgBreakPoint".into(),
1358                CachedFunctionInfo {
1359                    parameters: String::new(),
1360                    return_type: String::new(),
1361                },
1362            );
1363
1364            pretty_assert_eq!(
1365                generate_wdf_function_info_file_cache(
1366                    &inputs.types_path,
1367                    inputs.wdf_function_identifier.span()
1368                )
1369                .unwrap(),
1370                expected
1371            );
1372        }
1373
1374        #[test]
1375        fn invalid_input_missing_wdf_func_enum() {
1376            let inputs = Inputs {
1377                types_path: parse_quote! { "tests/unit-tests-input/missing-wdf-func-enum.rs" },
1378                wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1379                wdf_function_arguments: Punctuated::new(),
1380            };
1381
1382            let expected = Error::new(
1383                Span::call_site(),
1384                "Failed to find _WDFFUNCENUM module in types.rs file",
1385            );
1386
1387            pretty_assert_eq!(
1388                generate_wdf_function_info_file_cache(
1389                    &inputs.types_path,
1390                    inputs.wdf_function_identifier.span()
1391                )
1392                .unwrap_err()
1393                .to_string(),
1394                expected.to_string()
1395            );
1396        }
1397
1398        #[test]
1399        fn invalid_input_missing_wdf_func_enum_contents() {
1400            let inputs = Inputs {
1401                types_path: parse_quote! { "tests/unit-tests-input/missing-wdf-func-enum-contents.rs" },
1402                wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1403                wdf_function_arguments: Punctuated::new(),
1404            };
1405
1406            let expected = Error::new(
1407                Span::call_site(),
1408                "Failed to find _WDFFUNCENUM module contents in types.rs file",
1409            );
1410
1411            pretty_assert_eq!(
1412                generate_wdf_function_info_file_cache(
1413                    &inputs.types_path,
1414                    inputs.wdf_function_identifier.span()
1415                )
1416                .unwrap_err()
1417                .to_string(),
1418                expected.to_string()
1419            );
1420        }
1421    }
1422
1423    mod generate_cached_function_info {
1424        use super::*;
1425
1426        #[test]
1427        fn valid_input() {
1428            // This is a snippet of a bindgen-generated file containing types information
1429            // used by tests for [`wdk_macros::call_unsafe_wdf_function_binding!`]
1430            let types_ast = parse_quote! {
1431                pub type PFN_WDFIOQUEUEPURGESYNCHRONOUSLY = ::core::option::Option<
1432                    unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS, Queue: WDFQUEUE),
1433                >;
1434            };
1435            let function_pointer_type = format_ident!("PFN_WDFIOQUEUEPURGESYNCHRONOUSLY");
1436            let expected: Option<CachedFunctionInfo> = Some(
1437                (
1438                    parse_quote! {
1439                        queue__: WDFQUEUE
1440                    },
1441                    ReturnType::Default,
1442                )
1443                    .into(),
1444            );
1445
1446            pretty_assert_eq!(
1447                generate_cached_function_info(&types_ast, &function_pointer_type).unwrap(),
1448                expected
1449            );
1450        }
1451    }
1452
1453    mod find_type_alias_definition {
1454        use super::*;
1455
1456        #[test]
1457        fn valid_input() {
1458            // This is a snippet of a bindgen-generated file containing types information
1459            // used by tests for [`wdk_macros::call_unsafe_wdf_function_binding!`]
1460            let types_ast = parse_quote! {
1461                pub type WDF_DRIVER_GLOBALS = _WDF_DRIVER_GLOBALS;
1462                pub type PWDF_DRIVER_GLOBALS = *mut _WDF_DRIVER_GLOBALS;
1463                pub mod _WDFFUNCENUM {
1464                    pub type Type = ::core::ffi::c_int;
1465                    pub const WdfChildListCreateTableIndex: Type = 0;
1466                    pub const WdfChildListGetDeviceTableIndex: Type = 1;
1467                    pub const WdfChildListRetrievePdoTableIndex: Type = 2;
1468                    pub const WdfChildListRetrieveAddressDescriptionTableIndex: Type = 3;
1469                    pub const WdfChildListBeginScanTableIndex: Type = 4;
1470                    pub const WdfChildListEndScanTableIndex: Type = 5;
1471                    pub const WdfChildListBeginIterationTableIndex: Type = 6;
1472                    pub const WdfChildListRetrieveNextDeviceTableIndex: Type = 7;
1473                    pub const WdfChildListEndIterationTableIndex: Type = 8;
1474                    pub const WdfChildListAddOrUpdateChildDescriptionAsPresentTableIndex: Type = 9;
1475                    pub const WdfChildListUpdateChildDescriptionAsMissingTableIndex: Type = 10;
1476                    pub const WdfChildListUpdateAllChildDescriptionsAsPresentTableIndex: Type = 11;
1477                    pub const WdfChildListRequestChildEjectTableIndex: Type = 12;
1478                }
1479                pub type PFN_WDFGETTRIAGEINFO = ::core::option::Option<
1480                    unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS) -> PVOID,
1481                >;
1482            };
1483            let function_pointer_type = format_ident!("PFN_WDFGETTRIAGEINFO");
1484            let expected = parse_quote! {
1485                pub type PFN_WDFGETTRIAGEINFO = ::core::option::Option<
1486                    unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS) -> PVOID,
1487                >;
1488            };
1489
1490            pretty_assert_eq!(
1491                find_type_alias_definition(&types_ast, &function_pointer_type).unwrap(),
1492                &expected
1493            );
1494        }
1495    }
1496
1497    mod extract_fn_pointer_definition {
1498        use super::*;
1499
1500        #[test]
1501        fn valid_input() {
1502            let fn_type_alias = parse_quote! {
1503                pub type PFN_WDFDRIVERCREATE = ::core::option::Option<
1504                    unsafe extern "C" fn(
1505                        DriverGlobals: PWDF_DRIVER_GLOBALS,
1506                        DriverObject: PDRIVER_OBJECT,
1507                        RegistryPath: PCUNICODE_STRING,
1508                        DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
1509                        DriverConfig: PWDF_DRIVER_CONFIG,
1510                        Driver: *mut WDFDRIVER,
1511                    ) -> NTSTATUS
1512                >;
1513            };
1514            let expected = parse_quote! {
1515                ::core::option::Option<
1516                    unsafe extern "C" fn(
1517                        DriverGlobals: PWDF_DRIVER_GLOBALS,
1518                        DriverObject: PDRIVER_OBJECT,
1519                        RegistryPath: PCUNICODE_STRING,
1520                        DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
1521                        DriverConfig: PWDF_DRIVER_CONFIG,
1522                        Driver: *mut WDFDRIVER,
1523                    ) -> NTSTATUS
1524                >
1525            };
1526
1527            pretty_assert_eq!(
1528                extract_fn_pointer_definition(&fn_type_alias, Span::call_site()).unwrap(),
1529                &expected
1530            );
1531        }
1532
1533        #[test]
1534        fn valid_input_with_no_arguments() {
1535            let fn_type_alias = parse_quote! {
1536                pub type PFN_WDFVERIFIERDBGBREAKPOINT = ::core::option::Option<unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)>;
1537            };
1538            let expected = parse_quote! {
1539                ::core::option::Option<unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)>
1540            };
1541
1542            pretty_assert_eq!(
1543                extract_fn_pointer_definition(&fn_type_alias, Span::call_site()).unwrap(),
1544                &expected
1545            );
1546        }
1547    }
1548
mod parse_fn_pointer_definition {
    use super::*;

    // Parses the generated signature of `WdfDriverCreate`. The expected
    // parameter list leaves out `DriverGlobals` and names every remaining
    // parameter in snake_case with a trailing `__`; the expected return type
    // is `NTSTATUS`.
    #[test]
    fn valid_input() {
        let expected_return_type = ReturnType::Type(
            Token![->](Span::call_site()),
            Box::new(Type::Path(parse_quote! { NTSTATUS })),
        );
        let expected_parameters = parse_quote! {
            driver_object__: PDRIVER_OBJECT,
            registry_path__: PCUNICODE_STRING,
            driver_attributes__: PWDF_OBJECT_ATTRIBUTES,
            driver_config__: PWDF_DRIVER_CONFIG,
            driver__: *mut WDFDRIVER
        };
        let wdf_driver_create = parse_quote! {
            ::core::option::Option<unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS>
        };

        let parsed = parse_fn_pointer_definition(&wdf_driver_create, Span::call_site()).unwrap();

        pretty_assert_eq!(parsed, (expected_parameters, expected_return_type));
    }

    // Parses the generated signature of `WdfVerifierDbgBreakPoint`, which has
    // `DriverGlobals` as its only parameter and no explicit return type. The
    // expected result is an empty parameter list and the default return type.
    #[test]
    fn valid_input_with_no_arguments() {
        let wdf_verifier_dbg_break_point = parse_quote! {
            ::core::option::Option<unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)>
        };
        let no_parameters = Punctuated::new();

        let parsed =
            parse_fn_pointer_definition(&wdf_verifier_dbg_break_point, Span::call_site()).unwrap();

        pretty_assert_eq!(parsed, (no_parameters, ReturnType::Default));
    }
}
1599
mod extract_bare_fn_type {
    use super::*;

    // Feeds in the generated signature of `WdfDriverCreate` wrapped in
    // `::core::option::Option<...>` and expects the inner bare function type
    // back.
    #[test]
    fn valid_input() {
        // The bare function type that sits inside the `Option`.
        let inner_bare_fn: TypeBareFn = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS
        };
        // The `Option`-wrapped form handed to the function under test.
        let option_wrapped_fn = parse_quote! {
            ::core::option::Option<
                unsafe extern "C" fn(
                    DriverGlobals: PWDF_DRIVER_GLOBALS,
                    DriverObject: PDRIVER_OBJECT,
                    RegistryPath: PCUNICODE_STRING,
                    DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                    DriverConfig: PWDF_DRIVER_CONFIG,
                    Driver: *mut WDFDRIVER,
                ) -> NTSTATUS,
            >
        };

        let extracted = extract_bare_fn_type(&option_wrapped_fn, Span::call_site()).unwrap();

        pretty_assert_eq!(extracted, &inner_bare_fn);
    }
}
1635
mod compute_fn_parameters {
    use super::*;

    // Uses the generated signature of `WdfDriverCreate`. The expected
    // parameters leave out `DriverGlobals`, keep each remaining type, and name
    // each parameter in snake_case with a trailing `__`.
    #[test]
    fn valid_input() {
        let expected_parameters = parse_quote! {
            driver_object__: PDRIVER_OBJECT,
            registry_path__: PCUNICODE_STRING,
            driver_attributes__: PWDF_OBJECT_ATTRIBUTES,
            driver_config__: PWDF_DRIVER_CONFIG,
            driver__: *mut WDFDRIVER
        };
        let wdf_driver_create = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS
        };

        let computed = compute_fn_parameters(&wdf_driver_create, Span::call_site()).unwrap();

        pretty_assert_eq!(computed, expected_parameters);
    }

    // Uses the generated signature of `WdfVerifierDbgBreakPoint`, where
    // `DriverGlobals` is the only parameter. The expected parameter list is
    // empty.
    #[test]
    fn valid_input_with_no_arguments() {
        let no_parameters = Punctuated::new();
        let wdf_verifier_dbg_break_point = parse_quote! {
            unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)
        };

        let computed =
            compute_fn_parameters(&wdf_verifier_dbg_break_point, Span::call_site()).unwrap();

        pretty_assert_eq!(computed, no_parameters);
    }
}
1680
mod compute_return_type {
    use super::*;

    // Uses the generated signature of `WdfDriverCreate`, which is declared
    // with `-> NTSTATUS`.
    #[test]
    fn ntstatus() {
        // Expected return type, assembled by hand from its two parts.
        let arrow = Token![->](Span::call_site());
        let ntstatus_type = Type::Path(parse_quote! { NTSTATUS });

        let wdf_driver_create = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS
        };

        let computed = compute_return_type(&wdf_driver_create);

        pretty_assert_eq!(computed, ReturnType::Type(arrow, Box::new(ntstatus_type)));
    }

    // Uses the generated signature of `WdfSpinLockAcquire`, which declares no
    // return type, so the default return type is expected.
    #[test]
    fn unit() {
        let wdf_spin_lock_acquire = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                SpinLock: WDFSPINLOCK
            )
        };

        let computed = compute_return_type(&wdf_spin_lock_acquire);

        pretty_assert_eq!(computed, ReturnType::Default);
    }
}
1719
mod generate_must_use_attribute {
    use super::*;

    // With the default return type, no attribute is expected.
    #[test]
    fn unit_return_type() {
        let unit_return = ReturnType::Default;

        let attribute = generate_must_use_attribute(&unit_return);

        pretty_assert_eq!(attribute, None);
    }

    // With a `-> NTSTATUS` return type, the generated attribute must render to
    // the same token string as `#[must_use]`.
    #[test]
    fn ntstatus_return_type() {
        let ntstatus_return: ReturnType = parse_quote! { -> NTSTATUS };

        let attribute = generate_must_use_attribute(&ntstatus_return).unwrap();
        let rendered_attribute = attribute.into_token_stream().to_string();
        let rendered_expectation = quote! { #[must_use] }.to_string();

        pretty_assert_eq!(rendered_attribute, rendered_expectation);
    }
}
1746}