1use quote::{quote, ToTokens};
2use proc_macro::TokenStream;
3use syn::parse::{Parse, ParseStream};
4use syn::spanned::Spanned;
5
6#[proc_macro_attribute]
109pub fn create_type_info_crate(attributes: TokenStream, input: TokenStream) -> TokenStream {
110 struct CrateName(Option<syn::Ident>);
111 impl Parse for CrateName {
112 fn parse(cursor: ParseStream) -> syn::Result<Self> {
113 let ident: Option<syn::Ident> = cursor.parse().ok();
114
115 Ok(Self(ident))
116 }
117 }
118
119 struct InterfaceMethodArg {
120 pub ty: Box<syn::Type>,
122 pub pat: Box<syn::Pat>,
124 }
125
126 struct InterfaceMethod {
127 pub name: syn::Ident,
128 pub args: Vec<InterfaceMethodArg>,
130 pub ret: syn::ReturnType,
131 }
133 macro_rules! bail {
134 ($item:expr, $($msg:tt),*) => {
135 return Err(syn::Error::new($item.span(), std::fmt::format(format_args!($($msg),*))));
136 };
137
138 }
139 macro_rules! unexpected_token {
140 ($item:expr, $msg:expr) => {
141 if let Some(i) = $item {
142 bail!(i, "unexpected {}", $msg);
143 }
144 };
145 }
146 macro_rules! expected_token {
147 ($sig:tt.$item:tt(), $msg:expr) => {
148 if let None = $sig.$item() {
149 bail!($sig, "expected {}", $msg);
150 }
151 };
152 }
153 impl syn::parse::Parse for InterfaceMethod {
154 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
155 let docs = input.call(syn::Attribute::parse_outer)?;
156 let _visibility = input.parse::<syn::Visibility>()?;
157 let method = input.parse::<syn::TraitItemFn>()?;
158 unexpected_token!(docs.iter().find(|a| !a.path().is_ident("doc")), "attribute");
159 unexpected_token!(method.default, "default method implementation");
160 let sig = method.sig;
161 unexpected_token!(sig.abi, "abi declaration");
162 unexpected_token!(sig.asyncness, "async declaration");
163 unexpected_token!(sig.generics.params.iter().next(), "generics declaration");
164 unexpected_token!(sig.constness, "const declaration");
165 expected_token!(sig.receiver(), "the method to have &self as its first argument");
166 unexpected_token!(sig.variadic, "variadic args");
167 let args = sig
168 .inputs
169 .into_iter()
170 .filter_map(|a| match a {
171 syn::FnArg::Receiver(_) => None,
172 syn::FnArg::Typed(p) => Some(p),
173 })
174 .map(|p| Ok(InterfaceMethodArg { ty: p.ty, pat: p.pat }))
175 .collect::<Result<Vec<InterfaceMethodArg>, syn::Error>>()?;
176
177 let ret = sig.output;
178 Ok(InterfaceMethod { name: sig.ident, args, ret })
179 }
180 }
181
182 struct Interface {
183 visibility: syn::Visibility,
184 name: syn::Ident,
185 methods: Vec<InterfaceMethod>,
187 }
188 impl Parse for Interface {
189 fn parse(input: ParseStream) -> syn::Result<Self> {
190 let _attributes = input.call(syn::Attribute::parse_outer)?;
191 let visibility = input.parse::<syn::Visibility>()?;
202 _ = input.parse::<syn::Token![unsafe]>()?;
203 _ = input.parse::<syn::Token![trait]>()?;
204 let name = input.parse::<syn::Ident>()?;
205 _ = input.parse::<syn::Token![:]>();
206 let _parent = input.parse::<syn::Path>().ok();
207 let content;
208 syn::braced!(content in input);
209 let mut methods = Vec::new();
210 while !content.is_empty() {
211 methods.push(content.parse::<InterfaceMethod>()?);
212 }
213 Ok(Self { visibility, methods, name })
214 }
215 }
216
217 let crate_name = syn::parse_macro_input!(attributes as CrateName);
218 let input_clone = input.clone();
219 let interface = syn::parse_macro_input!(input_clone as Interface);
220 let vis = interface.visibility;
221 let name = if let Some(crate_name) = crate_name.0 {
222 crate_name
223 } else {
224 quote::format_ident!("{}_TypeInfo", interface.name)
225 };
226 let methods = interface
227 .methods
228 .iter()
229 .enumerate()
230 .map(|(i, m)| {
231 let name = &m.name;
232 let i = proc_macro2::Literal::usize_unsuffixed(i);
233 let name = proc_macro2::Literal::string(&name.to_string());
234 let args_len = proc_macro2::Literal::usize_unsuffixed(m.args.len());
235 let args = m.args
236 .iter()
237 .map(|a| {
238 let pat = proc_macro2::Literal::string(&a.pat.to_token_stream().to_string());
239 let ty = &a.ty;
240
241 quote! {{
242 ::windows::Win32::System::Ole::PARAMDATA {
243 szName: pwstr!(#pat),
244 vt: #ty::COM_VARIANT, }
246 }}
247 })
248 .collect::<Vec<_>>();
249 let ret_type = if let syn::ReturnType::Type(_, ret_type) = &m.ret {
250 if let syn::Type::Tuple(type_tuple) = &**ret_type {
251 if type_tuple.elems.len() == 0 { quote!{ ::windows::Win32::System::Variant::VT_VOID }
253 } else {
254 quote!{ #ret_type::COM_VARIANT }
255 }
256 } else {
257 quote!{ #ret_type::COM_VARIANT }
258 }
259 } else {
260 m.ret.to_token_stream()
261 };
262
263 quote! {{
264 static mut ARGS: SyncStatic<[::windows::Win32::System::Ole::PARAMDATA; #args_len]> = SyncStatic([
265 #(#args),*
266 ]);
267 use std::mem::size_of;
268 windows::Win32::System::Ole::METHODDATA {
269 szName: pwstr!(#name),
270 ppdata: unsafe { &mut ARGS.0 as *mut _ },
271 dispid: #i, iMeth: (size_of::<::windows::core::IUnknown_Vtbl>() / size_of::<fn()>() + #i as usize) as u32, cc: ::windows::Win32::System::Com::CC_STDCALL,
275 cArgs: unsafe { ARGS.0.len() as u32 },
276 wFlags: 1u32 as u16,
277 vtReturn: #ret_type, }
279 },}
280 })
281 .collect::<Vec<_>>();
282 let methods_len = proc_macro2::Literal::usize_unsuffixed(methods.len());
283
284 let extend = quote! {
285 #[allow(non_camel_case_types)]
286 #vis trait #name {
287 fn create_disp_type_info() -> ::windows::Win32::System::Com::ITypeInfo {
288 macro_rules! pwstr {
289 ($string:literal) => {{
290 const UTF16: &[u16] = ::const_utf16::encode_null_terminated!($string);
291 static mut MUTABLE_UTF16: [u16; UTF16.len()] = {
292 let mut out = [0; UTF16.len()];
293 unsafe {
294 ::std::ptr::copy_nonoverlapping(UTF16.as_ptr(), out.as_mut_ptr(), UTF16.len());
295 }
296 out
297 };
298 unsafe { ::windows::core::PWSTR(&mut MUTABLE_UTF16 as *mut _) }
299 }};
300 }
301 #[repr(transparent)]
302 struct SyncStatic<T>(T);
303
304 unsafe impl<T> Sync for SyncStatic<T> {}
305
306 use ::windows::Win32::System::Variant;
308 trait bstr_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_BSTR; }
309 impl bstr_com_variant for ::windows::core::BSTR {}
310
311 trait u16_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_UI2; }
312 impl u16_com_variant for u16 {}
313
314 trait u32_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_UI4; }
315 impl u32_com_variant for u32 {}
316
317 trait u64_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_UI8; }
318 impl u64_com_variant for u64 {}
319
320 trait i16_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_I2; }
321 impl i16_com_variant for i16 {}
322
323 trait i32_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_I4; }
324 impl i32_com_variant for i32 {}
325
326 trait i64_com_variant { const COM_VARIANT: Variant::VARENUM = Variant::VT_I8; }
327 impl i64_com_variant for i64 {}
328
329 static mut METHODS: SyncStatic<[::windows::Win32::System::Ole::METHODDATA; #methods_len]> = SyncStatic([
330 #(#methods)*
331 ]);
332
333 static mut INTERFACE: SyncStatic<::windows::Win32::System::Ole::INTERFACEDATA> = SyncStatic(::windows::Win32::System::Ole::INTERFACEDATA {
334 pmethdata: unsafe { &mut METHODS.0 as *mut _ },
335 cMembers: unsafe { METHODS.0.len() as u32 },
336 });
337 let invariant_locale = unsafe {
338 ::windows::Win32::Globalization::LocaleNameToLCID(::windows::Win32::Globalization::LOCALE_NAME_INVARIANT, 0)
339 };
340 let mut type_info = None;
341 unsafe {
342 ::windows::Win32::System::Ole::CreateDispTypeInfo(
343 &mut INTERFACE.0 as *mut _,
344 invariant_locale,
345 &mut type_info as *mut _,
346 ).unwrap();
347 }
348 type_info.unwrap()
349 }
350 }
351 };
352
353 let mut input_clone: proc_macro::TokenStream = input.clone().into();
354 let extend: proc_macro::TokenStream = extend.into();
355 input_clone.extend(core::iter::once(extend));
356 input_clone
357}