// wvwasi_macro/lib.rs

1use quote::{quote, ToTokens};
2use proc_macro::TokenStream;
3use syn::parse::{Parse, ParseStream};
4use syn::spanned::Spanned;
5
// Debugging tip: run `cargo expand > test.rs` to inspect the code this macro generates.

9/// Create a trait with a method called create_disp_type_info. The create_disp_type_info method returns an ITypeInfo for the following trait.
10///
11/// # Example
12/// ```rust,ignore
13/// #[wvwasi_macro::create_type_info_crate(ISyncIPCHandler_TypeInfo)]
14/// #[windows::core::interface("094d70d6-5202-44b8-abb8-43860da5aca2")]
15/// unsafe trait ISyncIPCHandler: windows::core::IUnknown {
16///   unsafe fn test(&self, message: windows::core::BSTR) -> windows::core::BSTR;
17///   unsafe fn test2(&self) -> u16;
18/// }
19///
20/// #[windows::core::implement(windows::Win32::System::Com::IDispatch, ISyncIPCHandler)]
21/// struct SyncIPCHandler {
22///   type_info: ITypeInfo,
23/// }
24///
25/// impl ISyncIPCHandler_TypeInfo for SyncIPCHandler {}
26/// impl ISyncIPCHandler_Impl for SyncIPCHandler {
27///   unsafe fn test(&self, message: windows::core::BSTR) {
28///     windows::core::BSTR::default()
29///   }
30///   unsafe fn test2(&self) -> u16 {
31///     123u16
32///   }
33/// }
34/// impl IDispatch_Impl for SyncIPCHandler {
35///   fn GetTypeInfoCount(&self) -> windows::core::Result<u32> {
36///     Ok(1)
37///   }
38///   fn GetTypeInfo(&self, itinfo: u32, _lcid: u32) -> windows::core::Result<ITypeInfo> {
39///     if itinfo != 0 {
40///       Err(windows::Win32::Foundation::DISP_E_BADINDEX.into())
41///     } else {
42///       Ok(self.type_info.clone())
43///     }
44///   }
45///   fn GetIDsOfNames(
46///     &self,
47///     riid: *const windows::core::GUID,
48///     rgsznames: *const windows::core::PCWSTR,
49///     cnames: u32,
50///     _lcid: u32,
51///     rgdispid: *mut i32,
52///   ) -> windows::core::Result<()> {
53///     unsafe {
54///       if riid.is_null() || *riid != windows::core::GUID::default() {
55///         Err(windows::Win32::Foundation::DISP_E_UNKNOWNINTERFACE.into())
56///       } else {
57///         windows::Win32::System::Ole::DispGetIDsOfNames(&self.type_info, rgsznames, cnames, rgdispid)
58///       }
59///     }
60///   }
61///   fn Invoke(
62///     &self,
63///     dispidmember: i32,
64///     riid: *const windows::core::GUID,
65///     _lcid: u32,
66///     wflags: DISPATCH_FLAGS,
67///     pdispparams: *const DISPPARAMS,
68///     pvarresult: *mut VARIANT,
69///     pexcepinfo: *mut EXCEPINFO,
70///     puargerr: *mut u32,
71///   ) -> windows::core::Result<()> {
72///     unsafe {
73///       if riid.is_null() || *riid != windows::core::GUID::default() {
74///         Err(windows::Win32::Foundation::DISP_E_UNKNOWNINTERFACE.into())
75///       } else {
76///         let this: ISyncIPCHandler = self.cast()?;
77///     
78///         let mut dispparams = if pdispparams.is_null() {
79///           None
80///         } else {
81///           Some(*pdispparams)
82///         };
83///         let pdispparams_mut = dispparams
84///           .as_mut()
85///           .map(|x| x as _)
86///           .unwrap_or(std::ptr::null_mut());
87///     
88///         windows::Win32::System::Ole::DispInvoke(
89///           windows::core::Interface::as_raw(&this),
90///           &self.type_info,
91///           dispidmember,
92///           wflags.0,
93///           pdispparams_mut,
94///           pvarresult,
95///           pexcepinfo,
96///           puargerr,
97///         )
98///       }
99///     }
100///   }
101/// }
102///
103/// fn main() {
104///   let type_info = SyncIPCHandler::create_disp_type_info();
105///   let host_object = SyncIPCHandler { type_info };
106/// }
107/// ```
108#[proc_macro_attribute]
109pub fn create_type_info_crate(attributes: TokenStream, input: TokenStream) -> TokenStream {
110  struct CrateName(Option<syn::Ident>);
111  impl Parse for CrateName {
112    fn parse(cursor: ParseStream) -> syn::Result<Self> {
113        let ident: Option<syn::Ident> = cursor.parse().ok();
114
115        Ok(Self(ident))
116    }
117  }
118
119  struct InterfaceMethodArg {
120    /// The type of the argument
121    pub ty: Box<syn::Type>,
122    /// The name of the argument
123    pub pat: Box<syn::Pat>,
124  }
125
126  struct InterfaceMethod {
127    pub name: syn::Ident,
128    // pub visibility: syn::Visibility,
129    pub args: Vec<InterfaceMethodArg>,
130    pub ret: syn::ReturnType,
131    // pub docs: Vec<syn::Attribute>,
132  }
133  macro_rules! bail {
134    ($item:expr, $($msg:tt),*) => {
135        return Err(syn::Error::new($item.span(), std::fmt::format(format_args!($($msg),*))));
136    };
137
138  }
139  macro_rules! unexpected_token {
140    ($item:expr, $msg:expr) => {
141        if let Some(i) = $item {
142            bail!(i, "unexpected {}", $msg);
143        }
144    };
145  }
146  macro_rules! expected_token {
147    ($sig:tt.$item:tt(), $msg:expr) => {
148        if let None = $sig.$item() {
149            bail!($sig, "expected {}", $msg);
150        }
151    };
152  }
153  impl syn::parse::Parse for InterfaceMethod {
154    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
155      let docs = input.call(syn::Attribute::parse_outer)?;
156      let _visibility = input.parse::<syn::Visibility>()?;
157      let method = input.parse::<syn::TraitItemFn>()?;
158      unexpected_token!(docs.iter().find(|a| !a.path().is_ident("doc")), "attribute");
159      unexpected_token!(method.default, "default method implementation");
160      let sig = method.sig;
161      unexpected_token!(sig.abi, "abi declaration");
162      unexpected_token!(sig.asyncness, "async declaration");
163      unexpected_token!(sig.generics.params.iter().next(), "generics declaration");
164      unexpected_token!(sig.constness, "const declaration");
165      expected_token!(sig.receiver(), "the method to have &self as its first argument");
166      unexpected_token!(sig.variadic, "variadic args");
167      let args = sig
168        .inputs
169        .into_iter()
170        .filter_map(|a| match a {
171          syn::FnArg::Receiver(_) => None,
172          syn::FnArg::Typed(p) => Some(p),
173        })
174        .map(|p| Ok(InterfaceMethodArg { ty: p.ty, pat: p.pat }))
175        .collect::<Result<Vec<InterfaceMethodArg>, syn::Error>>()?;
176
177      let ret = sig.output;
178      Ok(InterfaceMethod { name: sig.ident, args, ret })
179    }
180  }
181
182  struct Interface {
183    visibility: syn::Visibility,
184    name: syn::Ident,
185    // parent: Option<syn::Path>,
186    methods: Vec<InterfaceMethod>,
187  }
188  impl Parse for Interface {
189    fn parse(input: ParseStream) -> syn::Result<Self> {
190      let _attributes = input.call(syn::Attribute::parse_outer)?;
191      // let mut attributes_vec = Vec::new();
192      // for attr in attributes.into_iter() {
193      //     let path = attr.path();
194      //     if path.is_ident("interface") {
195      //       attributes_vec.push(attr);
196      //     } else {
197      //       return Err(syn::Error::new(path.span(), "Unrecognized attribute "));
198      //     }
199      // }
200
201      let visibility = input.parse::<syn::Visibility>()?;
202      _ = input.parse::<syn::Token![unsafe]>()?;
203      _ = input.parse::<syn::Token![trait]>()?;
204      let name = input.parse::<syn::Ident>()?;
205      _ = input.parse::<syn::Token![:]>();
206      let _parent = input.parse::<syn::Path>().ok();
207      let content;
208      syn::braced!(content in input);
209      let mut methods = Vec::new();
210      while !content.is_empty() {
211        methods.push(content.parse::<InterfaceMethod>()?);
212      }
213      Ok(Self { visibility, methods, name })
214    }
215  }
216
217  let crate_name = syn::parse_macro_input!(attributes as CrateName);
218  let input_clone = input.clone();
219  let interface = syn::parse_macro_input!(input_clone as Interface);
220  let vis = interface.visibility;
221  let name = if let Some(crate_name) = crate_name.0 {
222    crate_name
223  } else {
224    quote::format_ident!("{}_TypeInfo", interface.name)
225  };
226  let methods = interface
227    .methods
228    .iter()
229    .enumerate()
230    .map(|(i, m)| {
231        let name = &m.name;
232        let i = proc_macro2::Literal::usize_unsuffixed(i);
233        let name = proc_macro2::Literal::string(&name.to_string());
234        let args_len = proc_macro2::Literal::usize_unsuffixed(m.args.len());
235        let args = m.args
236          .iter()
237          .map(|a| {
238              let pat = proc_macro2::Literal::string(&a.pat.to_token_stream().to_string());
239              let ty = &a.ty;
240              
241              quote! {{
242                ::windows::Win32::System::Ole::PARAMDATA {
243                  szName: pwstr!(#pat),
244                  vt: #ty::COM_VARIANT, // e.g. VT_BSTR
245                }
246              }}
247          })
248          .collect::<Vec<_>>();
249        let ret_type = if let syn::ReturnType::Type(_, ret_type) = &m.ret {
250          if let syn::Type::Tuple(type_tuple) = &**ret_type {
251            if type_tuple.elems.len() == 0 { // -> ()
252              quote!{ ::windows::Win32::System::Variant::VT_VOID }
253            } else {
254              quote!{ #ret_type::COM_VARIANT }
255            }
256          } else {
257            quote!{ #ret_type::COM_VARIANT }
258          }
259        } else {
260          m.ret.to_token_stream()
261        };
262
263        quote! {{
264          static mut ARGS: SyncStatic<[::windows::Win32::System::Ole::PARAMDATA; #args_len]> = SyncStatic([
265            #(#args),*
266          ]);
267          use std::mem::size_of;
268          windows::Win32::System::Ole::METHODDATA {
269            szName: pwstr!(#name),
270            ppdata: unsafe { &mut ARGS.0 as *mut _ },
271            dispid: #i, // method id
272            // #[allow(clippy::identity_op)]
273            iMeth: (size_of::<::windows::core::IUnknown_Vtbl>() / size_of::<fn()>() + #i as usize) as u32, // method index
274            cc: ::windows::Win32::System::Com::CC_STDCALL,
275            cArgs: unsafe { ARGS.0.len() as u32 },
276            wFlags: 1u32 as u16,
277            vtReturn: #ret_type, // e.g. VT_BSTR
278          }
279        },}
280    })
281    .collect::<Vec<_>>();
282  let methods_len = proc_macro2::Literal::usize_unsuffixed(methods.len());
283
284  let extend = quote! {
285    #[allow(non_camel_case_types)]
286    #vis trait #name {
287      fn create_disp_type_info() -> ::windows::Win32::System::Com::ITypeInfo {
288        macro_rules! pwstr {
289          ($string:literal) => {{
290            const UTF16: &[u16] = ::const_utf16::encode_null_terminated!($string);
291            static mut MUTABLE_UTF16: [u16; UTF16.len()] = {
292              let mut out = [0; UTF16.len()];
293              unsafe {
294                ::std::ptr::copy_nonoverlapping(UTF16.as_ptr(), out.as_mut_ptr(), UTF16.len());
295              }
296              out
297            };
298            unsafe { ::windows::core::PWSTR(&mut MUTABLE_UTF16 as *mut _) }
299          }};
300        }
301        #[repr(transparent)]
302        struct SyncStatic<T>(T);
303
304        unsafe impl<T> Sync for SyncStatic<T> {}
305
306        // map rust type to com variant, usage: `use windows::core::BSTR; let a:BSTR::COM_VARIANT;` or `let a:windows::core::BSTR::COM_VARIANT;`. The reason for using traits is that type paths can only be obtained at runtime, and cannot be obtained directly at compile time. Traits have limited scope and do not pollute the original type.
307        use ::windows::Win32::System::Variant;
308        trait bstr_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_BSTR; }
309        impl bstr_com_variant for ::windows::core::BSTR {}
310
311        trait u16_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_UI2; }
312        impl u16_com_variant for u16 {}
313
314        trait u32_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_UI4; }
315        impl u32_com_variant for u32 {}
316
317        trait u64_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_UI8; }
318        impl u64_com_variant for u64 {}
319
320        trait i16_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_I2; }
321        impl i16_com_variant for i16 {}
322
323        trait i32_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_I4; }
324        impl i32_com_variant for i32 {}
325
326        trait i64_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_I8; }
327        impl i64_com_variant for i64 {}
328
329        static mut METHODS: SyncStatic<[::windows::Win32::System::Ole::METHODDATA; #methods_len]> = SyncStatic([
330          #(#methods)*
331        ]);
332
333        static mut INTERFACE: SyncStatic<::windows::Win32::System::Ole::INTERFACEDATA> = SyncStatic(::windows::Win32::System::Ole::INTERFACEDATA {
334          pmethdata: unsafe { &mut METHODS.0 as *mut _ },
335          cMembers: unsafe { METHODS.0.len() as u32 },
336        });
337        let invariant_locale = unsafe {
338          ::windows::Win32::Globalization::LocaleNameToLCID(::windows::Win32::Globalization::LOCALE_NAME_INVARIANT, 0)
339        };
340        let mut type_info = None;
341        unsafe {
342          ::windows::Win32::System::Ole::CreateDispTypeInfo(
343            &mut INTERFACE.0 as *mut _,
344            invariant_locale,
345            &mut type_info as *mut _,
346          ).unwrap();
347        }
348        type_info.unwrap()
349      }
350    }
351  };
352
353  let mut input_clone: proc_macro::TokenStream = input.clone().into();
354  let extend: proc_macro::TokenStream = extend.into();
355  input_clone.extend(core::iter::once(extend));
356  input_clone
357}