1use std::{collections::BTreeMap, path::PathBuf, str::FromStr};
8
9use fs4::fs_std::FileExt;
10use itertools::Itertools;
11use proc_macro::TokenStream;
12use proc_macro2::{Span, TokenStream as TokenStream2};
13use quote::{format_ident, quote, ToTokens};
14use serde::{Deserialize, Serialize};
15use syn::{
16 parse::{Parse, ParseStream, Parser},
17 parse2,
18 parse_file,
19 parse_quote,
20 punctuated::Punctuated,
21 AngleBracketedGenericArguments,
22 Attribute,
23 BareFnArg,
24 Error,
25 Expr,
26 ExprCall,
27 File,
28 GenericArgument,
29 Ident,
30 Item,
31 ItemType,
32 LitStr,
33 Path,
34 PathArguments,
35 PathSegment,
36 Result,
37 ReturnType,
38 Signature,
39 Stmt,
40 Token,
41 Type,
42 TypeBareFn,
43 TypePath,
44};
45
/// Name of the module in the generated types file whose `<WdfFunctionName>TableIndex`
/// constants give each WDF function's index into the WDF function table.
const WDF_FUNC_ENUM_MOD_NAME: &str = "_WDFFUNCENUM";
49
50#[proc_macro]
57pub fn call_unsafe_wdf_function_binding(input_tokens: TokenStream) -> TokenStream {
58 call_unsafe_wdf_function_binding_impl(TokenStream2::from(input_tokens)).into()
59}
60
/// String case-conversion helpers used when deriving Rust identifiers from the
/// names found in the generated types file.
trait StringExt {
    /// Converts `self` (camelCase, PascalCase, or SCREAMING_SNAKE_CASE) to snake_case.
    fn to_snake_case(&self) -> String;
}
66
/// Converts arbitrary `Result`s into `syn::Result`s so that failures can be
/// reported as compile errors at a chosen span.
trait ResultExt<T, E> {
    /// Maps `Err(err)` to a [`syn::Error`] at `span` whose message is the
    /// description followed by `", "` and `err`; `Ok` values pass through unchanged.
    fn to_syn_result(self, span: Span, error: &str) -> syn::Result<T>;
}
71
/// Signature information for one WDF function, as stored in the JSON cache file.
///
/// Both fields hold token streams rendered with `to_string()`. They are parsed back
/// into syntax trees when the macro expands.
#[derive(Debug, Deserialize, PartialEq, Serialize)]
struct CachedFunctionInfo {
    /// Parameter list of the generated wrapper, without the leading driver-globals
    /// parameter (e.g. `driver_object__ : PDRIVER_OBJECT , ...`). It is empty when the
    /// function takes no other parameters.
    parameters: String,
    /// Return type including the arrow (e.g. `-> NTSTATUS`). It is empty when the
    /// function has no declared return type.
    return_type: String,
}
79
/// Parsed input of `call_unsafe_wdf_function_binding!`.
#[derive(Debug, PartialEq)]
struct Inputs {
    /// String literal holding the path to the generated types file. That file
    /// contains the `_WDFFUNCENUM` module and the `PFN_*` function pointer aliases.
    types_path: LitStr,
    /// Name of the WDF function to call, e.g. `WdfDriverCreate`.
    wdf_function_identifier: Ident,
    /// Argument expressions to pass to the WDF function (possibly empty).
    wdf_function_arguments: Punctuated<Expr, Token![,]>,
}
93
/// Syntax fragments derived from [`Inputs`] and the cached function information.
/// They are used to build the generated wrapper function and its call.
#[derive(Debug, PartialEq)]
struct DerivedASTFragments {
    /// `PFN_<FUNCTION NAME IN UPPERCASE>`, the function pointer type alias name.
    function_pointer_type: Ident,
    /// `<WdfFunctionName>TableIndex`, the constant in `_WDFFUNCENUM` that indexes
    /// the function table.
    function_table_index: Ident,
    /// The wrapper's parameters: the WDF function's parameters minus the leading
    /// driver-globals parameter, renamed to snake_case with a `__` suffix.
    parameters: Punctuated<BareFnArg, Token![,]>,
    /// The names of `parameters`, in order.
    parameter_identifiers: Punctuated<Ident, Token![,]>,
    /// The WDF function's return type.
    return_type: ReturnType,
    /// The caller-supplied argument expressions.
    arguments: Punctuated<Expr, Token![,]>,
    /// `<function_name_in_snake_case>_impl`, the generated wrapper's name.
    inline_wdf_fn_name: Ident,
}
107
/// The pieces of the generated wrapper function, ready to be assembled into the
/// final macro output.
struct IntermediateOutputASTFragments {
    /// `Some(#[must_use])` when the WDF function returns a value, otherwise `None`.
    must_use_attribute: Option<Attribute>,
    /// `unsafe fn <inline_wdf_fn_name>(<parameters>) <return_type>`.
    inline_wdf_fn_signature: Signature,
    /// Statements that look up the function in the WDF function table and call it.
    inline_wdf_fn_body_statments: Vec<Stmt>,
    /// `<inline_wdf_fn_name>(<arguments>)`, the call of the wrapper.
    inline_wdf_fn_invocation: ExprCall,
}
116
/// Owns a file on which an exclusive lock is held. The lock is released
/// (best-effort) when the guard is dropped.
struct FileLockGuard {
    /// The file the exclusive lock is held on.
    file: std::fs::File,
}
122
123impl FileLockGuard {
124 fn new(file: std::fs::File, span: Span) -> Result<Self> {
125 FileExt::lock_exclusive(&file).to_syn_result(span, "unable to obtain file lock")?;
126 Ok(Self { file })
127 }
128}
129
130impl Drop for FileLockGuard {
131 fn drop(&mut self) {
132 let _ = FileExt::unlock(&self.file);
133 }
134}
135
136impl StringExt for String {
137 fn to_snake_case(&self) -> String {
138 const MAX_PADDING_NEEDED: usize = 2;
141
142 let mut snake_case_string = Self::with_capacity(self.len());
143
144 for (current_char, next_char, next_next_char) in self
145 .chars()
146 .map(Some)
147 .chain([None; MAX_PADDING_NEEDED])
148 .tuple_windows()
149 .filter_map(|(c1, c2, c3)| Some((c1?, c2, c3)))
150 {
151 if current_char.is_lowercase() && next_char.is_some_and(|c| c.is_ascii_uppercase()) {
153 snake_case_string.push(current_char);
154 snake_case_string.push('_');
155 }
156 else if current_char.is_uppercase()
158 && next_char.is_some_and(|c| c.is_ascii_uppercase())
159 && next_next_char.is_some_and(|c| c.is_ascii_lowercase())
160 {
161 snake_case_string.push(current_char.to_ascii_lowercase());
162 snake_case_string.push('_');
163 } else {
164 snake_case_string.push(current_char.to_ascii_lowercase());
165 }
166 }
167
168 snake_case_string
169 }
170}
171
172impl<T, E: std::error::Error> ResultExt<T, E> for std::result::Result<T, E> {
173 fn to_syn_result(self, span: Span, error_description: &str) -> syn::Result<T> {
174 self.map_err(|err| Error::new(span, format!("{error_description}, {err}")))
175 }
176}
177
178impl From<(Punctuated<BareFnArg, Token![,]>, ReturnType)> for CachedFunctionInfo {
179 fn from((parameters, return_type): (Punctuated<BareFnArg, Token![,]>, ReturnType)) -> Self {
180 Self {
181 parameters: parameters.to_token_stream().to_string(),
182 return_type: return_type.to_token_stream().to_string(),
183 }
184 }
185}
186
187impl Parse for Inputs {
188 fn parse(input: ParseStream) -> Result<Self> {
189 let types_path = input.parse::<LitStr>()?;
190
191 input.parse::<Token![,]>()?;
192 let c_wdf_function_identifier = input.parse::<Ident>()?;
193
194 if input.is_empty() {
196 return Ok(Self {
197 types_path,
198 wdf_function_identifier: c_wdf_function_identifier,
199 wdf_function_arguments: Punctuated::new(),
200 });
201 }
202
203 input.parse::<Token![,]>()?;
204 let wdf_function_arguments = input.parse_terminated(Expr::parse, Token![,])?;
205
206 Ok(Self {
207 types_path,
208 wdf_function_identifier: c_wdf_function_identifier,
209 wdf_function_arguments,
210 })
211 }
212}
213
214impl Inputs {
215 fn generate_derived_ast_fragments(self) -> Result<DerivedASTFragments> {
216 let function_pointer_type = format_ident!(
217 "PFN_{uppercase_c_function_name}",
218 uppercase_c_function_name = self.wdf_function_identifier.to_string().to_uppercase(),
219 span = self.wdf_function_identifier.span()
220 );
221 let function_table_index = format_ident!(
222 "{wdf_function_identifier}TableIndex",
223 wdf_function_identifier = self.wdf_function_identifier,
224 span = self.wdf_function_identifier.span()
225 );
226
227 let function_name_to_info_map: BTreeMap<String, CachedFunctionInfo> =
228 get_wdf_function_info_map(&self.types_path, self.wdf_function_identifier.span())?;
229 let function_info = function_name_to_info_map
230 .get(&self.wdf_function_identifier.to_string())
231 .ok_or_else(|| {
232 Error::new(
233 self.wdf_function_identifier.span(),
234 format!(
235 "Failed to find function info for {}",
236 self.wdf_function_identifier
237 ),
238 )
239 })?;
240 let parameters_tokens = TokenStream2::from_str(&function_info.parameters).to_syn_result(
241 self.wdf_function_identifier.span(),
242 "unable to parse parameter tokens",
243 )?;
244 let return_type_tokens = TokenStream2::from_str(&function_info.return_type).to_syn_result(
245 self.wdf_function_identifier.span(),
246 "unable to parse return type tokens",
247 )?;
248 let parameters =
249 Punctuated::<BareFnArg, Token![,]>::parse_terminated.parse2(parameters_tokens)?;
250 let return_type = ReturnType::parse.parse2(return_type_tokens)?;
251
252 let parameter_identifiers = parameters
253 .iter()
254 .cloned()
255 .map(|bare_fn_arg| {
256 if let Some((identifier, _)) = bare_fn_arg.name {
257 return Ok(identifier);
258 }
259 Err(Error::new(
260 function_pointer_type.span(),
261 format!("Expected fn parameter to have a name: {bare_fn_arg:#?}"),
262 ))
263 })
264 .collect::<Result<_>>()?;
265 let inline_wdf_fn_name = format_ident!(
266 "{c_function_name_snake_case}_impl",
267 c_function_name_snake_case = self.wdf_function_identifier.to_string().to_snake_case()
268 );
269
270 Ok(DerivedASTFragments {
271 function_pointer_type,
272 function_table_index,
273 parameters,
274 parameter_identifiers,
275 return_type,
276 arguments: self.wdf_function_arguments,
277 inline_wdf_fn_name,
278 })
279 }
280}
281
impl DerivedASTFragments {
    /// Builds the pieces of the generated wrapper function from the derived fragments:
    /// - an optional `#[must_use]` attribute;
    /// - the wrapper's `unsafe fn` signature;
    /// - its body statements;
    /// - the call expression that invokes it with the caller's arguments.
    fn generate_intermediate_output_ast_fragments(self) -> IntermediateOutputASTFragments {
        let Self {
            function_pointer_type,
            function_table_index,
            parameters,
            parameter_identifiers,
            return_type,
            arguments,
            inline_wdf_fn_name,
        } = self;

        // `#[must_use]` is only emitted when the WDF function returns a value.
        let must_use_attribute = generate_must_use_attribute(&return_type);

        // The wrapper takes the WDF function's parameters minus the leading
        // driver-globals parameter, which the body supplies itself.
        let inline_wdf_fn_signature = parse_quote! {
            unsafe fn #inline_wdf_fn_name(#parameters) #return_type
        };

        let inline_wdf_fn_body_statments = parse_quote! {
            let wdf_function: wdk_sys::#function_pointer_type = Some(
                // Look up this function's entry in the WDF function table and
                // reinterpret it as the `PFN_*` function pointer type.
                unsafe {
                    // NOTE(review): assumes `WdfFunctions` points to at least
                    // `wdf_function_count` valid entries; confirm this invariant in wdk-sys.
                    let wdf_function_table = wdk_sys::WdfFunctions;
                    let wdf_function_count = wdk_sys::wdf::__private::get_wdf_function_count();

                    // `core::slice::from_raw_parts` requires the slice's total size in
                    // bytes to be at most `isize::MAX`.
                    // NOTE(review): the multiplication itself is unchecked; with overflow
                    // checks disabled it could wrap, so consider `checked_mul`.
                    debug_assert!(isize::try_from(wdf_function_count * core::mem::size_of::<wdk_sys::WDFFUNC>()).is_ok());
                    let wdf_function_table = core::slice::from_raw_parts(wdf_function_table, wdf_function_count);

                    core::mem::transmute(
                        // Slice indexing is bounds-checked against `wdf_function_count`.
                        wdf_function_table[wdk_sys::_WDFFUNCENUM::#function_table_index as usize],
                    )
                }
            );

            // `wdf_function` was constructed with `Some` above, so the `else` branch
            // cannot be taken.
            if let Some(wdf_function) = wdf_function {
                // Call the WDF function. `wdk_sys::WdfDriverGlobals` is passed as the
                // driver-globals argument, followed by the wrapper's own parameters.
                unsafe {
                    (wdf_function)(
                        wdk_sys::WdfDriverGlobals,
                        #parameter_identifiers
                    )
                }
            } else {
                unreachable!("Option should never be None");
            }
        };

        // The call of the wrapper with the caller-supplied argument expressions.
        let inline_wdf_fn_invocation = parse_quote! {
            #inline_wdf_fn_name(#arguments)
        };

        IntermediateOutputASTFragments {
            must_use_attribute,
            inline_wdf_fn_signature,
            inline_wdf_fn_body_statments,
            inline_wdf_fn_invocation,
        }
    }
}
359
360impl IntermediateOutputASTFragments {
361 fn assemble_final_output(self) -> TokenStream2 {
362 let Self {
363 must_use_attribute,
364 inline_wdf_fn_signature,
365 inline_wdf_fn_body_statments,
366 inline_wdf_fn_invocation,
367 } = self;
368
369 let conditional_must_use_attribute =
370 must_use_attribute.map_or_else(TokenStream2::new, quote::ToTokens::into_token_stream);
371
372 quote! {
373 {
374 mod private__ {
376 use wdk_sys::*;
379
380 #conditional_must_use_attribute
382 #[inline(always)]
385 pub #inline_wdf_fn_signature {
386 #(#inline_wdf_fn_body_statments)*
387 }
388 }
389
390 private__::#inline_wdf_fn_invocation
391 }
392 }
393 }
394}
395
396fn call_unsafe_wdf_function_binding_impl(input_tokens: TokenStream2) -> TokenStream2 {
397 let inputs = match parse2::<Inputs>(input_tokens) {
398 Ok(syntax_tree) => syntax_tree,
399 Err(err) => return err.to_compile_error(),
400 };
401
402 let derived_ast_fragments = match inputs.generate_derived_ast_fragments() {
403 Ok(derived_ast_fragments) => derived_ast_fragments,
404 Err(err) => return err.to_compile_error(),
405 };
406
407 derived_ast_fragments
408 .generate_intermediate_output_ast_fragments()
409 .assemble_final_output()
410}
411
/// Returns the map from WDF function name to its cached parameter and return-type
/// information. On first use, the cache is generated from `types_path` and persisted.
///
/// The cache is `cached_function_info_map.json` in a `scratch` directory. A separate
/// directory is used under `cfg(test)`.
///
/// If the cache file does not exist, the function takes an exclusive lock on `.lock`
/// in the same directory and then repeats the existence check while holding the lock.
/// That way only one expansion generates the cache, and concurrent expansions block
/// on the lock until it is done.
///
/// NOTE(review): the fast path (cache file already exists) reads the file without
/// holding the lock. That is only safe if the writer makes the cache file appear
/// atomically, fully written.
/// NOTE(review): the cache path does not depend on `types_path`. A cache generated
/// from one types file is therefore reused for any other; confirm this is intended.
///
/// # Errors
///
/// Returns an error in these cases:
/// - The lock file cannot be created or locked.
/// - The cache cannot be generated or written.
/// - An existing cache cannot be read or parsed.
fn get_wdf_function_info_map(
    types_path: &LitStr,
    span: Span,
) -> Result<BTreeMap<String, CachedFunctionInfo>> {
    cfg_if::cfg_if! {
        if #[cfg(test)] {
            let scratch_dir = scratch::path(concat!(env!("CARGO_CRATE_NAME"), "_ast_fragments_test"));
        } else {
            let scratch_dir = scratch::path(concat!(env!("CARGO_CRATE_NAME"), "_ast_fragments"));
        }
    }

    let cached_function_info_map_path = scratch_dir.join("cached_function_info_map.json");

    if !cached_function_info_map_path.exists() {
        let flock = std::fs::File::create(scratch_dir.join(".lock"))
            .to_syn_result(span, "unable to create file lock")?;

        // Blocks until the exclusive lock is held. The lock is released when
        // `_flock_guard` goes out of scope at the end of this `if` block.
        let _flock_guard = FileLockGuard::new(flock, span)
            .to_syn_result(span, "unable to create file lock guard")?;

        // Re-check under the lock: another expansion may have created the cache
        // while this one was waiting.
        if !cached_function_info_map_path.exists() {
            let function_info_map = create_wdf_function_info_file_cache(
                types_path,
                cached_function_info_map_path.as_path(),
                span,
            )?;
            return Ok(function_info_map);
        }
    }
    // The cache already existed (or was created by another expansion); read it.
    let function_info_map =
        read_wdf_function_info_file_cache(cached_function_info_map_path.as_path(), span)?;
    Ok(function_info_map)
}
461
462fn read_wdf_function_info_file_cache(
465 cached_function_info_map_path: &std::path::Path,
466 span: Span,
467) -> Result<BTreeMap<String, CachedFunctionInfo>> {
468 let generated_map_string = std::fs::read_to_string(cached_function_info_map_path)
469 .to_syn_result(span, "unable to read cache to string")?;
470 let map: BTreeMap<String, CachedFunctionInfo> = serde_json::from_str(&generated_map_string)
471 .to_syn_result(span, "unable to parse cache to BTreeMap")?;
472 Ok(map)
473}
474
475fn create_wdf_function_info_file_cache(
480 types_path: &LitStr,
481 cached_function_info_map_path: &std::path::Path,
482 span: Span,
483) -> Result<BTreeMap<String, CachedFunctionInfo>> {
484 let generated_map = generate_wdf_function_info_file_cache(types_path, span)?;
485 let generated_map_string = serde_json::to_string(&generated_map)
486 .to_syn_result(span, "unable to parse cache to JSON string")?;
487 std::fs::write(cached_function_info_map_path, generated_map_string)
488 .to_syn_result(span, "unable to write cache to file")?;
489 Ok(generated_map)
490}
491
492fn generate_wdf_function_info_file_cache(
496 types_path: &LitStr,
497 span: Span,
498) -> Result<BTreeMap<String, CachedFunctionInfo>> {
499 let types_ast = parse_types_ast(types_path)?;
500 let func_enum_mod = types_ast
501 .items
502 .iter()
503 .find_map(|item| {
504 if let Item::Mod(mod_alias) = item {
505 if mod_alias.ident == WDF_FUNC_ENUM_MOD_NAME {
506 return Some(mod_alias);
507 }
508 }
509 None
510 })
511 .ok_or_else(|| {
512 Error::new(
513 span,
514 format!("Failed to find {WDF_FUNC_ENUM_MOD_NAME} module in types.rs file",),
515 )
516 })?;
517
518 let (_brace, func_enum_mod_contents) = &func_enum_mod.content.as_ref().ok_or_else(|| {
519 Error::new(
520 span,
521 format!("Failed to find {WDF_FUNC_ENUM_MOD_NAME} module contents in types.rs file",),
522 )
523 })?;
524
525 func_enum_mod_contents
526 .iter()
527 .filter_map(|item| {
528 if let Item::Const(const_alias) = item {
529 return const_alias
530 .ident
531 .to_string()
532 .strip_suffix("TableIndex")
533 .and_then(|function_name| {
534 let function_pointer_type = format_ident!(
535 "PFN_{uppercase_c_function_name}",
536 uppercase_c_function_name = function_name.to_uppercase(),
537 span = span
538 );
539 generate_cached_function_info(&types_ast, &function_pointer_type)
540 .transpose()
541 .map(|generate_cached_function_info_result| {
542 generate_cached_function_info_result.map(|cached_function_info| {
543 (function_name.to_string(), cached_function_info)
544 })
545 })
546 });
547 }
548 None
549 })
550 .collect()
551}
552
553fn parse_types_ast(path: &LitStr) -> Result<File> {
554 let types_path = PathBuf::from(path.value());
555 let types_path = match types_path.canonicalize() {
556 Ok(types_path) => types_path,
557 Err(err) => {
558 return Err(Error::new(
559 path.span(),
560 format!(
561 "Failed to canonicalize types_path ({}): {err}",
562 types_path.display()
563 ),
564 ));
565 }
566 };
567
568 let types_file_contents = match std::fs::read_to_string(&types_path) {
569 Ok(contents) => contents,
570 Err(err) => {
571 return Err(Error::new(
572 path.span(),
573 format!(
574 "Failed to read wdk-sys types information from {}: {err}",
575 types_path.display(),
576 ),
577 ));
578 }
579 };
580
581 match parse_file(&types_file_contents) {
582 Ok(wdk_sys_types_rs_abstract_syntax_tree) => Ok(wdk_sys_types_rs_abstract_syntax_tree),
583 Err(err) => Err(Error::new(
584 path.span(),
585 format!(
586 "Failed to parse wdk-sys types information from {} into AST: {err}",
587 types_path.display(),
588 ),
589 )),
590 }
591}
592
593fn generate_cached_function_info(
612 types_ast: &File,
613 function_pointer_type: &Ident,
614) -> Result<Option<CachedFunctionInfo>> {
615 match find_type_alias_definition(types_ast, function_pointer_type) {
616 Ok(type_alias_definition) => {
617 let fn_pointer_definition =
618 extract_fn_pointer_definition(type_alias_definition, function_pointer_type.span())?;
619 Ok(Some(
620 parse_fn_pointer_definition(fn_pointer_definition, function_pointer_type.span())?
621 .into(),
622 ))
623 }
624 Err(_err) => Ok(None),
627 }
628}
629
630fn find_type_alias_definition<'a>(
651 types_ast: &'a File,
652 function_pointer_type: &Ident,
653) -> Result<&'a ItemType> {
654 types_ast
655 .items
656 .iter()
657 .find_map(|item| {
658 if let Item::Type(type_alias) = item {
659 if type_alias.ident == *function_pointer_type {
660 return Some(type_alias);
661 }
662 }
663 None
664 })
665 .ok_or_else(|| {
666 Error::new(
667 function_pointer_type.span(),
668 format!("Failed to find type alias definition for {function_pointer_type}"),
669 )
670 })
671}
672
673fn extract_fn_pointer_definition(type_alias: &ItemType, error_span: Span) -> Result<&TypePath> {
708 if let Type::Path(fn_pointer) = type_alias.ty.as_ref() {
709 Ok(fn_pointer)
710 } else {
711 Err(Error::new(
712 error_span,
713 format!("Expected Type::Path when parsing ItemType.ty:\n{type_alias:#?}"),
714 ))
715 }
716}
717
718fn parse_fn_pointer_definition(
751 fn_pointer_typepath: &TypePath,
752 error_span: Span,
753) -> Result<(Punctuated<BareFnArg, Token![,]>, ReturnType)> {
754 let bare_fn_type = extract_bare_fn_type(fn_pointer_typepath, error_span)?;
755 let fn_parameters = compute_fn_parameters(bare_fn_type, error_span)?;
756 let return_type = compute_return_type(bare_fn_type);
757
758 Ok((fn_parameters, return_type))
759}
760
761fn extract_bare_fn_type(fn_pointer_typepath: &TypePath, error_span: Span) -> Result<&TypeBareFn> {
794 let option_path_segment: &PathSegment =
795 fn_pointer_typepath.path.segments.last().ok_or_else(|| {
796 Error::new(
797 error_span,
798 format!("Expected at least one PathSegment in TypePath:\n{fn_pointer_typepath:#?}"),
799 )
800 })?;
801 if option_path_segment.ident != "Option" {
802 return Err(Error::new(
803 error_span,
804 format!("Expected Option as last PathSegment in TypePath:\n{fn_pointer_typepath:#?}"),
805 ));
806 }
807 let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
808 args: ref option_angle_bracketed_args,
809 ..
810 }) = option_path_segment.arguments
811 else {
812 return Err(Error::new(
813 error_span,
814 format!(
815 "Expected AngleBracketed PathArguments in Option \
816 PathSegment:\n{option_path_segment:#?}"
817 ),
818 ));
819 };
820 let bracketed_argument = option_angle_bracketed_args.first().ok_or_else(|| {
821 Error::new(
822 error_span,
823 format!(
824 "Expected exactly one GenericArgument in AngleBracketedGenericArguments:\n{:#?}",
825 option_path_segment.arguments
826 ),
827 )
828 })?;
829 let GenericArgument::Type(Type::BareFn(bare_fn_type)) = bracketed_argument else {
830 return Err(Error::new(
831 error_span,
832 format!("Expected TypeBareFn in GenericArgument:\n{bracketed_argument:#?}"),
833 ));
834 };
835 Ok(bare_fn_type)
836}
837
838fn compute_fn_parameters(
864 bare_fn_type: &syn::TypeBareFn,
865 error_span: Span,
866) -> Result<Punctuated<BareFnArg, Token![,]>> {
867 let Some(BareFnArg {
869 ty:
870 Type::Path(TypePath {
871 path:
872 Path {
873 segments: first_parameter_type_path,
874 ..
875 },
876 ..
877 }),
878 ..
879 }) = bare_fn_type.inputs.first()
880 else {
881 return Err(Error::new(
882 error_span,
883 format!(
884 "Expected at least one input parameter of type Path in \
885 BareFnType:\n{bare_fn_type:#?}"
886 ),
887 ));
888 };
889 let Some(last_path_segment) = first_parameter_type_path.last() else {
890 return Err(Error::new(
891 error_span,
892 format!("Expected at least one PathSegment in TypePath:\n{bare_fn_type:#?}"),
893 ));
894 };
895 if last_path_segment.ident != "PWDF_DRIVER_GLOBALS" {
896 return Err(Error::new(
897 error_span,
898 format!(
899 "Expected PWDF_DRIVER_GLOBALS as last PathSegment in TypePath of first BareFnArg \
900 input:\n{bare_fn_type:#?}"
901 ),
902 ));
903 }
904
905 Ok(bare_fn_type
906 .inputs
907 .iter()
908 .skip(1)
909 .map(|fn_arg| {
912 let arg_name = fn_arg.name.as_ref().map(|(ident, colon_token)| {
913 let modified_name = {
914 let mut name = ident.to_string().to_snake_case();
915 name.push_str("__");
916 name
917 };
918 (Ident::new(&modified_name, ident.span()), *colon_token)
919 });
920
921 BareFnArg {
922 name: arg_name,
923 ..fn_arg.clone()
924 }
925 })
926 .collect())
927}
928
929fn compute_return_type(bare_fn_type: &syn::TypeBareFn) -> ReturnType {
948 bare_fn_type.output.clone()
949}
950
951fn generate_must_use_attribute(return_type: &ReturnType) -> Option<Attribute> {
953 if matches!(return_type, ReturnType::Type(..)) {
954 Some(parse_quote! { #[must_use] })
955 } else {
956 None
957 }
958}
959
960#[cfg(test)]
961mod tests {
962 use std::sync::LazyLock;
963
964 use pretty_assertions::assert_eq as pretty_assert_eq;
965 use quote::ToTokens;
966
967 use super::*;
968
969 static SCRATCH_DIR: LazyLock<PathBuf> =
970 LazyLock::new(|| scratch::path(concat!(env!("CARGO_CRATE_NAME"), "_ast_fragments_test")));
971 const CACHE_FILE_NAME: &str = "cached_function_info_map.json";
972
973 fn with_file_lock_clean_env<F>(f: F)
974 where
975 F: FnOnce(),
976 {
977 let test_flock: std::fs::File =
978 std::fs::File::create(SCRATCH_DIR.join("test.lock")).unwrap();
979 FileExt::lock_exclusive(&test_flock).unwrap();
980
981 let cached_function_info_map_path = SCRATCH_DIR.join(CACHE_FILE_NAME);
982
983 pretty_assert_eq!(
984 cached_function_info_map_path.exists(),
985 false,
986 "could not remove file {}",
987 cached_function_info_map_path.display()
988 );
989
990 f();
991
992 if cached_function_info_map_path.exists() {
993 std::fs::remove_file(cached_function_info_map_path).unwrap();
994 }
995
996 FileExt::unlock(&test_flock).unwrap();
997 }
998
999 mod to_snake_case {
1000 use super::*;
1001
1002 #[test]
1003 fn camel_case() {
1004 let input = "camelCaseString".to_string();
1005 let expected = "camel_case_string";
1006
1007 pretty_assert_eq!(input.to_snake_case(), expected);
1008 }
1009
1010 #[test]
1011 fn short_camel_case() {
1012 let input = "aB".to_string();
1013 let expected = "a_b";
1014
1015 pretty_assert_eq!(input.to_snake_case(), expected);
1016 }
1017
1018 #[test]
1019 fn pascal_case() {
1020 let input = "PascalCaseString".to_string();
1021 let expected = "pascal_case_string";
1022
1023 pretty_assert_eq!(input.to_snake_case(), expected);
1024 }
1025
1026 #[test]
1027 fn pascal_case_with_leading_acronym() {
1028 let input = "ASCIIEncodedString".to_string();
1029 let expected = "ascii_encoded_string";
1030
1031 pretty_assert_eq!(input.to_snake_case(), expected);
1032 }
1033
1034 #[test]
1035 fn pascal_case_with_trailing_acronym() {
1036 let input = "IsASCII".to_string();
1037 let expected = "is_ascii";
1038
1039 pretty_assert_eq!(input.to_snake_case(), expected);
1040 }
1041
1042 #[test]
1043 fn screaming_snake_case() {
1044 let input = "PFN_WDF_DRIVER_DEVICE_ADD".to_string();
1045 let expected = "pfn_wdf_driver_device_add";
1046
1047 pretty_assert_eq!(input.to_snake_case(), expected);
1048 }
1049
1050 #[test]
1051 fn screaming_snake_case_with_leading_acronym() {
1052 let input = "ASCII_STRING".to_string();
1053 let expected = "ascii_string";
1054
1055 pretty_assert_eq!(input.to_snake_case(), expected);
1056 }
1057
1058 #[test]
1059 fn screaming_snake_case_with_leading_underscore() {
1060 let input = "_WDF_DRIVER_INIT_FLAGS".to_string();
1061 let expected = "_wdf_driver_init_flags";
1062
1063 pretty_assert_eq!(input.to_snake_case(), expected);
1064 }
1065
1066 #[test]
1067 fn snake_case() {
1068 let input = "snake_case_string".to_string();
1069 let expected = "snake_case_string";
1070
1071 pretty_assert_eq!(input.to_snake_case(), expected);
1072 }
1073
1074 #[test]
1075 fn snake_case_with_leading_underscore() {
1076 let input = "_snake_case_with_leading_underscore".to_string();
1077 let expected = "_snake_case_with_leading_underscore";
1078
1079 pretty_assert_eq!(input.to_snake_case(), expected);
1080 }
1081 }
1082
1083 mod inputs {
1084 use super::*;
1085
1086 mod parse {
1087 use super::*;
1088
1089 #[test]
1090 fn valid_input() {
1091 let input_tokens = quote! { "/path/to/generated/types/file.rs", WdfDriverCreate, driver, registry_path, WDF_NO_OBJECT_ATTRIBUTES, &mut driver_config, driver_handle_output };
1092 let expected = Inputs {
1093 types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1094 wdf_function_identifier: format_ident!("WdfDriverCreate"),
1095 wdf_function_arguments: parse_quote! {
1096 driver,
1097 registry_path,
1098 WDF_NO_OBJECT_ATTRIBUTES,
1099 &mut driver_config,
1100 driver_handle_output
1101 },
1102 };
1103
1104 pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1105 }
1106
1107 #[test]
1108 fn valid_input_with_trailing_comma() {
1109 let input_tokens = quote! { "/path/to/generated/types/file.rs" , WdfDriverCreate, driver, registry_path, WDF_NO_OBJECT_ATTRIBUTES, &mut driver_config, driver_handle_output, };
1110 let expected = Inputs {
1111 types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1112 wdf_function_identifier: format_ident!("WdfDriverCreate"),
1113 wdf_function_arguments: parse_quote! {
1114 driver,
1115 registry_path,
1116 WDF_NO_OBJECT_ATTRIBUTES,
1117 &mut driver_config,
1118 driver_handle_output,
1119 },
1120 };
1121
1122 pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1123 }
1124
1125 #[test]
1126 fn wdf_function_with_no_arguments() {
1127 let input_tokens =
1128 quote! { "/path/to/generated/types/file.rs", WdfVerifierDbgBreakPoint };
1129 let expected = Inputs {
1130 types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1131 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1132 wdf_function_arguments: Punctuated::new(),
1133 };
1134
1135 pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1136 }
1137
1138 #[test]
1139 fn wdf_function_with_no_arguments_and_trailing_comma() {
1140 let input_tokens =
1141 quote! { "/path/to/generated/types/file.rs", WdfVerifierDbgBreakPoint, };
1142 let expected = Inputs {
1143 types_path: parse_quote! { "/path/to/generated/types/file.rs" },
1144 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1145 wdf_function_arguments: Punctuated::new(),
1146 };
1147
1148 pretty_assert_eq!(parse2::<Inputs>(input_tokens).unwrap(), expected);
1149 }
1150
1151 #[test]
1152 fn invalid_ident() {
1153 let input_tokens = quote! { "/path/to/generated/types/file.rs", 23InvalidIdent, driver, registry_path, WDF_NO_OBJECT_ATTRIBUTES, &mut driver_config, driver_handle_output, };
1154 let expected = Error::new(Span::call_site(), "expected identifier");
1155
1156 pretty_assert_eq!(
1157 parse2::<Inputs>(input_tokens).unwrap_err().to_string(),
1158 expected.to_string()
1159 );
1160 }
1161 }
1162
1163 mod generate_derived_ast_fragments {
1164 use super::*;
1165
1166 #[test]
1167 fn valid_input() {
1168 with_file_lock_clean_env(|| {
1169 let inputs = Inputs {
1170 types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1171 wdf_function_identifier: format_ident!("WdfDriverCreate"),
1172 wdf_function_arguments: parse_quote! {
1173 driver,
1174 registry_path,
1175 WDF_NO_OBJECT_ATTRIBUTES,
1176 &mut driver_config,
1177 driver_handle_output,
1178 },
1179 };
1180 let expected = DerivedASTFragments {
1181 function_pointer_type: format_ident!("PFN_WDFDRIVERCREATE"),
1182 function_table_index: format_ident!("WdfDriverCreateTableIndex"),
1183 parameters: parse_quote! {
1184 driver_object__: PDRIVER_OBJECT,
1185 registry_path__: PCUNICODE_STRING,
1186 driver_attributes__: PWDF_OBJECT_ATTRIBUTES,
1187 driver_config__: PWDF_DRIVER_CONFIG,
1188 driver__: *mut WDFDRIVER
1189 },
1190 parameter_identifiers: parse_quote! {
1191 driver_object__,
1192 registry_path__,
1193 driver_attributes__,
1194 driver_config__,
1195 driver__
1196 },
1197 return_type: parse_quote! { -> NTSTATUS },
1198 arguments: parse_quote! {
1199 driver,
1200 registry_path,
1201 WDF_NO_OBJECT_ATTRIBUTES,
1202 &mut driver_config,
1203 driver_handle_output,
1204 },
1205 inline_wdf_fn_name: format_ident!("wdf_driver_create_impl"),
1206 };
1207
1208 pretty_assert_eq!(inputs.generate_derived_ast_fragments().unwrap(), expected);
1209 });
1210 }
1211
1212 #[test]
1213 fn valid_input_with_no_arguments() {
1214 with_file_lock_clean_env(|| {
1215 let inputs = Inputs {
1216 types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1217 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1218 wdf_function_arguments: Punctuated::new(),
1219 };
1220 let expected = DerivedASTFragments {
1221 function_pointer_type: format_ident!("PFN_WDFVERIFIERDBGBREAKPOINT"),
1222 function_table_index: format_ident!("WdfVerifierDbgBreakPointTableIndex"),
1223 parameters: Punctuated::new(),
1224 parameter_identifiers: Punctuated::new(),
1225 return_type: ReturnType::Default,
1226 arguments: Punctuated::new(),
1227 inline_wdf_fn_name: format_ident!("wdf_verifier_dbg_break_point_impl"),
1228 };
1229
1230 pretty_assert_eq!(inputs.generate_derived_ast_fragments().unwrap(), expected);
1231 });
1232 }
1233 }
1234 }
1235
1236 mod get_wdf_function_info_map {
1237 use super::*;
1238
1239 #[test]
1240 fn valid_input_no_cache() {
1241 with_file_lock_clean_env(|| {
1242 let inputs = Inputs {
1243 types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1244 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1245 wdf_function_arguments: Punctuated::new(),
1246 };
1247
1248 let mut expected: BTreeMap<String, CachedFunctionInfo> = BTreeMap::new();
1249 expected.insert(
1250 "WdfDriverCreate".into(),
1251 CachedFunctionInfo {
1252 parameters: "driver_object__ : PDRIVER_OBJECT , registry_path__ : \
1253 PCUNICODE_STRING , driver_attributes__ : \
1254 PWDF_OBJECT_ATTRIBUTES , driver_config__ : \
1255 PWDF_DRIVER_CONFIG , driver__ : * mut WDFDRIVER"
1256 .into(),
1257 return_type: "-> NTSTATUS".into(),
1258 },
1259 );
1260
1261 expected.insert(
1262 "WdfVerifierDbgBreakPoint".into(),
1263 CachedFunctionInfo {
1264 parameters: String::new(),
1265 return_type: String::new(),
1266 },
1267 );
1268 pretty_assert_eq!(
1269 get_wdf_function_info_map(
1270 &inputs.types_path,
1271 inputs.wdf_function_identifier.span()
1272 )
1273 .unwrap(),
1274 expected
1275 );
1276
1277 pretty_assert_eq!(SCRATCH_DIR.join(CACHE_FILE_NAME).exists(), true);
1278 });
1279 }
1280
1281 #[test]
1282 fn valid_input_cache_exists() {
1283 with_file_lock_clean_env(|| {
1284 let inputs = Inputs {
1285 types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1286 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1287 wdf_function_arguments: Punctuated::new(),
1288 };
1289 get_wdf_function_info_map(
1292 &inputs.types_path,
1293 inputs.wdf_function_identifier.span(),
1294 )
1295 .unwrap();
1296
1297 pretty_assert_eq!(SCRATCH_DIR.join(CACHE_FILE_NAME).exists(), true);
1299
1300 let mut expected: BTreeMap<String, CachedFunctionInfo> = BTreeMap::new();
1301 expected.insert(
1302 "WdfDriverCreate".into(),
1303 CachedFunctionInfo {
1304 parameters: "driver_object__ : PDRIVER_OBJECT , registry_path__ : \
1305 PCUNICODE_STRING , driver_attributes__ : \
1306 PWDF_OBJECT_ATTRIBUTES , driver_config__ : \
1307 PWDF_DRIVER_CONFIG , driver__ : * mut WDFDRIVER"
1308 .into(),
1309 return_type: "-> NTSTATUS".into(),
1310 },
1311 );
1312
1313 expected.insert(
1314 "WdfVerifierDbgBreakPoint".into(),
1315 CachedFunctionInfo {
1316 parameters: String::new(),
1317 return_type: String::new(),
1318 },
1319 );
1320 pretty_assert_eq!(
1321 get_wdf_function_info_map(
1322 &inputs.types_path,
1323 inputs.wdf_function_identifier.span()
1324 )
1325 .unwrap(),
1326 expected
1327 );
1328 });
1329 }
1330 }
1331
1332 mod generate_wdf_function_info_file_cache {
1333 use super::*;
1334
1335 #[test]
1336 fn valid_input() {
1337 let inputs = Inputs {
1338 types_path: parse_quote! { "tests/unit-tests-input/generated-types.rs" },
1339 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1340 wdf_function_arguments: Punctuated::new(),
1341 };
1342
1343 let mut expected: BTreeMap<String, CachedFunctionInfo> = BTreeMap::new();
1344 expected.insert(
1345 "WdfDriverCreate".into(),
1346 CachedFunctionInfo {
1347 parameters: "driver_object__ : PDRIVER_OBJECT , registry_path__ : \
1348 PCUNICODE_STRING , driver_attributes__ : PWDF_OBJECT_ATTRIBUTES \
1349 , driver_config__ : PWDF_DRIVER_CONFIG , driver__ : * mut \
1350 WDFDRIVER"
1351 .into(),
1352 return_type: "-> NTSTATUS".into(),
1353 },
1354 );
1355
1356 expected.insert(
1357 "WdfVerifierDbgBreakPoint".into(),
1358 CachedFunctionInfo {
1359 parameters: String::new(),
1360 return_type: String::new(),
1361 },
1362 );
1363
1364 pretty_assert_eq!(
1365 generate_wdf_function_info_file_cache(
1366 &inputs.types_path,
1367 inputs.wdf_function_identifier.span()
1368 )
1369 .unwrap(),
1370 expected
1371 );
1372 }
1373
1374 #[test]
1375 fn invalid_input_missing_wdf_func_enum() {
1376 let inputs = Inputs {
1377 types_path: parse_quote! { "tests/unit-tests-input/missing-wdf-func-enum.rs" },
1378 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1379 wdf_function_arguments: Punctuated::new(),
1380 };
1381
1382 let expected = Error::new(
1383 Span::call_site(),
1384 "Failed to find _WDFFUNCENUM module in types.rs file",
1385 );
1386
1387 pretty_assert_eq!(
1388 generate_wdf_function_info_file_cache(
1389 &inputs.types_path,
1390 inputs.wdf_function_identifier.span()
1391 )
1392 .unwrap_err()
1393 .to_string(),
1394 expected.to_string()
1395 );
1396 }
1397
1398 #[test]
1399 fn invalid_input_missing_wdf_func_enum_contents() {
1400 let inputs = Inputs {
1401 types_path: parse_quote! { "tests/unit-tests-input/missing-wdf-func-enum-contents.rs" },
1402 wdf_function_identifier: format_ident!("WdfVerifierDbgBreakPoint"),
1403 wdf_function_arguments: Punctuated::new(),
1404 };
1405
1406 let expected = Error::new(
1407 Span::call_site(),
1408 "Failed to find _WDFFUNCENUM module contents in types.rs file",
1409 );
1410
1411 pretty_assert_eq!(
1412 generate_wdf_function_info_file_cache(
1413 &inputs.types_path,
1414 inputs.wdf_function_identifier.span()
1415 )
1416 .unwrap_err()
1417 .to_string(),
1418 expected.to_string()
1419 );
1420 }
1421 }
1422
mod generate_cached_function_info {
    use super::*;

    /// A function pointer taking `DriverGlobals` plus a queue handle and returning unit is
    /// cached as the single `queue__` parameter with a default (unit) return type.
    #[test]
    fn valid_input() {
        let pfn_name = format_ident!("PFN_WDFIOQUEUEPURGESYNCHRONOUSLY");
        let types_ast = parse_quote! {
            pub type PFN_WDFIOQUEUEPURGESYNCHRONOUSLY = ::core::option::Option<
                unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS, Queue: WDFQUEUE),
            >;
        };

        let expected_info: CachedFunctionInfo =
            (parse_quote! { queue__: WDFQUEUE }, ReturnType::Default).into();
        let cached = generate_cached_function_info(&types_ast, &pfn_name).unwrap();

        pretty_assert_eq!(cached, Some(expected_info));
    }
}
1452
mod find_type_alias_definition {
    use super::*;

    /// The lookup returns the alias named by the requested function pointer type, even when
    /// other aliases and the `_WDFFUNCENUM` module precede it in the file.
    #[test]
    fn valid_input() {
        let wanted_alias = format_ident!("PFN_WDFGETTRIAGEINFO");
        let expected_alias = parse_quote! {
            pub type PFN_WDFGETTRIAGEINFO = ::core::option::Option<
                unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS) -> PVOID,
            >;
        };
        let types_ast = parse_quote! {
            pub type WDF_DRIVER_GLOBALS = _WDF_DRIVER_GLOBALS;
            pub type PWDF_DRIVER_GLOBALS = *mut _WDF_DRIVER_GLOBALS;
            pub mod _WDFFUNCENUM {
                pub type Type = ::core::ffi::c_int;
                pub const WdfChildListCreateTableIndex: Type = 0;
                pub const WdfChildListGetDeviceTableIndex: Type = 1;
                pub const WdfChildListRetrievePdoTableIndex: Type = 2;
                pub const WdfChildListRetrieveAddressDescriptionTableIndex: Type = 3;
                pub const WdfChildListBeginScanTableIndex: Type = 4;
                pub const WdfChildListEndScanTableIndex: Type = 5;
                pub const WdfChildListBeginIterationTableIndex: Type = 6;
                pub const WdfChildListRetrieveNextDeviceTableIndex: Type = 7;
                pub const WdfChildListEndIterationTableIndex: Type = 8;
                pub const WdfChildListAddOrUpdateChildDescriptionAsPresentTableIndex: Type = 9;
                pub const WdfChildListUpdateChildDescriptionAsMissingTableIndex: Type = 10;
                pub const WdfChildListUpdateAllChildDescriptionsAsPresentTableIndex: Type = 11;
                pub const WdfChildListRequestChildEjectTableIndex: Type = 12;
            }
            pub type PFN_WDFGETTRIAGEINFO = ::core::option::Option<
                unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS) -> PVOID,
            >;
        };

        let found_alias = find_type_alias_definition(&types_ast, &wanted_alias).unwrap();

        pretty_assert_eq!(found_alias, &expected_alias);
    }
}
1496
mod extract_fn_pointer_definition {
    use super::*;

    /// For a multi-parameter `PFN_*` alias, the right-hand `Option<unsafe extern "C" fn ..>`
    /// type is returned unchanged.
    #[test]
    fn valid_input() {
        let alias_item = parse_quote! {
            pub type PFN_WDFDRIVERCREATE = ::core::option::Option<
                unsafe extern "C" fn(
                    DriverGlobals: PWDF_DRIVER_GLOBALS,
                    DriverObject: PDRIVER_OBJECT,
                    RegistryPath: PCUNICODE_STRING,
                    DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                    DriverConfig: PWDF_DRIVER_CONFIG,
                    Driver: *mut WDFDRIVER,
                ) -> NTSTATUS
            >;
        };
        let expected_rhs = parse_quote! {
            ::core::option::Option<
                unsafe extern "C" fn(
                    DriverGlobals: PWDF_DRIVER_GLOBALS,
                    DriverObject: PDRIVER_OBJECT,
                    RegistryPath: PCUNICODE_STRING,
                    DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                    DriverConfig: PWDF_DRIVER_CONFIG,
                    Driver: *mut WDFDRIVER,
                ) -> NTSTATUS
            >
        };

        let extracted = extract_fn_pointer_definition(&alias_item, Span::call_site()).unwrap();
        pretty_assert_eq!(extracted, &expected_rhs);
    }

    /// An alias whose function takes only `DriverGlobals` is extracted just the same.
    #[test]
    fn valid_input_with_no_arguments() {
        let alias_item = parse_quote! {
            pub type PFN_WDFVERIFIERDBGBREAKPOINT =
                ::core::option::Option<unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)>;
        };
        let expected_rhs = parse_quote! {
            ::core::option::Option<unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)>
        };

        let extracted = extract_fn_pointer_definition(&alias_item, Span::call_site()).unwrap();
        pretty_assert_eq!(extracted, &expected_rhs);
    }
}
1548
mod parse_fn_pointer_definition {
    use super::*;

    /// Parsing an `Option<unsafe extern "C" fn(..) -> NTSTATUS>` yields every parameter after
    /// `DriverGlobals` (renamed to snake_case with a `__` suffix) and the `-> NTSTATUS` return type.
    #[test]
    fn valid_input() {
        let fn_pointer_typepath = parse_quote! {
            ::core::option::Option<unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS>
        };
        // Parse the arrow form directly instead of hand-assembling
        // `ReturnType::Type(Token![->](..), Box::new(..))`. syn's `PartialEq` impls do not
        // compare spans, so this equals the value the function computes.
        let expected_return_type: ReturnType = parse_quote! { -> NTSTATUS };
        let expected = (
            parse_quote! {
                driver_object__: PDRIVER_OBJECT,
                registry_path__: PCUNICODE_STRING,
                driver_attributes__: PWDF_OBJECT_ATTRIBUTES,
                driver_config__: PWDF_DRIVER_CONFIG,
                driver__: *mut WDFDRIVER
            },
            expected_return_type,
        );

        pretty_assert_eq!(
            parse_fn_pointer_definition(&fn_pointer_typepath, Span::call_site()).unwrap(),
            expected
        );
    }

    /// A function pointer taking only `DriverGlobals` and returning unit yields no parameters
    /// and the default (unit) return type.
    #[test]
    fn valid_input_with_no_arguments() {
        let fn_pointer_typepath = parse_quote! {
            ::core::option::Option<unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)>
        };
        let expected = (Punctuated::new(), ReturnType::Default);

        pretty_assert_eq!(
            parse_fn_pointer_definition(&fn_pointer_typepath, Span::call_site()).unwrap(),
            expected
        );
    }
}
1599
mod extract_bare_fn_type {
    use super::*;

    /// The `unsafe extern "C" fn(..)` wrapped in `::core::option::Option<..>` is returned
    /// as-is, including its `DriverGlobals` parameter and return type.
    #[test]
    fn valid_input() {
        let expected_bare_fn: TypeBareFn = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS
        };
        let option_wrapped_fn = parse_quote! {
            ::core::option::Option<
                unsafe extern "C" fn(
                    DriverGlobals: PWDF_DRIVER_GLOBALS,
                    DriverObject: PDRIVER_OBJECT,
                    RegistryPath: PCUNICODE_STRING,
                    DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                    DriverConfig: PWDF_DRIVER_CONFIG,
                    Driver: *mut WDFDRIVER,
                ) -> NTSTATUS,
            >
        };

        let bare_fn = extract_bare_fn_type(&option_wrapped_fn, Span::call_site()).unwrap();
        pretty_assert_eq!(bare_fn, &expected_bare_fn);
    }
}
1635
mod compute_fn_parameters {
    use super::*;

    /// Every parameter after `DriverGlobals` is kept, renamed to snake_case with a `__` suffix.
    #[test]
    fn valid_input() {
        let expected_parameters = parse_quote! {
            driver_object__: PDRIVER_OBJECT,
            registry_path__: PCUNICODE_STRING,
            driver_attributes__: PWDF_OBJECT_ATTRIBUTES,
            driver_config__: PWDF_DRIVER_CONFIG,
            driver__: *mut WDFDRIVER
        };
        let bare_fn_type = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS
        };

        let computed = compute_fn_parameters(&bare_fn_type, Span::call_site()).unwrap();
        pretty_assert_eq!(computed, expected_parameters);
    }

    /// A function whose only parameter is `DriverGlobals` yields an empty parameter list.
    #[test]
    fn valid_input_with_no_arguments() {
        let bare_fn_type = parse_quote! {
            unsafe extern "C" fn(DriverGlobals: PWDF_DRIVER_GLOBALS)
        };

        let computed = compute_fn_parameters(&bare_fn_type, Span::call_site()).unwrap();
        pretty_assert_eq!(computed, Punctuated::new());
    }
}
1680
mod compute_return_type {
    use super::*;

    /// A function pointer declared with `-> NTSTATUS` yields that same return type.
    #[test]
    fn ntstatus() {
        let bare_fn_type = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                DriverObject: PDRIVER_OBJECT,
                RegistryPath: PCUNICODE_STRING,
                DriverAttributes: PWDF_OBJECT_ATTRIBUTES,
                DriverConfig: PWDF_DRIVER_CONFIG,
                Driver: *mut WDFDRIVER,
            ) -> NTSTATUS
        };
        // Parse the arrow form directly instead of hand-assembling
        // `ReturnType::Type(Token![->](..), Box::new(..))`. syn's `PartialEq` impls do not
        // compare spans, so this equals the value the function computes.
        let expected: ReturnType = parse_quote! { -> NTSTATUS };

        pretty_assert_eq!(compute_return_type(&bare_fn_type), expected);
    }

    /// A function pointer with no `->` clause yields the default (unit) return type.
    #[test]
    fn unit() {
        let bare_fn_type = parse_quote! {
            unsafe extern "C" fn(
                DriverGlobals: PWDF_DRIVER_GLOBALS,
                SpinLock: WDFSPINLOCK
            )
        };
        let expected = ReturnType::Default;

        pretty_assert_eq!(compute_return_type(&bare_fn_type), expected);
    }
}
1719
mod generate_must_use_attribute {
    use super::*;

    /// A unit (default) return type produces no `#[must_use]` attribute.
    #[test]
    fn unit_return_type() {
        pretty_assert_eq!(generate_must_use_attribute(&ReturnType::Default), None);
    }

    /// An `NTSTATUS` return type produces exactly `#[must_use]`.
    #[test]
    fn ntstatus_return_type() {
        let ntstatus: ReturnType = parse_quote! { -> NTSTATUS };
        let rendered_attribute = generate_must_use_attribute(&ntstatus)
            .unwrap()
            .into_token_stream()
            .to_string();

        pretty_assert_eq!(rendered_attribute, quote! { #[must_use] }.to_string());
    }
}
1746}