windows_implement/
lib.rs

1//! Implement COM interfaces for Rust types.
2//!
3//! Take a look at [macro@implement] for an example.
4//!
5//! Learn more about Rust for Windows here: <https://github.com/microsoft/windows-rs>
6
7use quote::{quote, ToTokens};
8
9mod r#gen;
10use r#gen::gen_all;
11
12#[cfg(test)]
13mod tests;
14
15/// Implements one or more COM interfaces.
16///
17/// # Example
18/// ```rust,no_run
19/// use windows_core::*;
20///
21/// #[interface("094d70d6-5202-44b8-abb8-43860da5aca2")]
22/// unsafe trait IValue: IUnknown {
23///     fn GetValue(&self, value: *mut i32) -> HRESULT;
24/// }
25///
26/// #[implement(IValue)]
27/// struct Value(i32);
28///
29/// impl IValue_Impl for Value_Impl {
30///     unsafe fn GetValue(&self, value: *mut i32) -> HRESULT {
31///         *value = self.0;
32///         HRESULT(0)
33///     }
34/// }
35///
36/// let object: IValue = Value(123).into();
37/// // Call interface methods...
38/// ```
39#[proc_macro_attribute]
40pub fn implement(
41    attributes: proc_macro::TokenStream,
42    type_tokens: proc_macro::TokenStream,
43) -> proc_macro::TokenStream {
44    implement_core(attributes.into(), type_tokens.into()).into()
45}
46
47fn implement_core(
48    attributes: proc_macro2::TokenStream,
49    item_tokens: proc_macro2::TokenStream,
50) -> proc_macro2::TokenStream {
51    let attributes = syn::parse2::<ImplementAttributes>(attributes).unwrap();
52    let original_type = syn::parse2::<syn::ItemStruct>(item_tokens).unwrap();
53
54    // Do a little thinking and assemble ImplementInputs.  We pass ImplementInputs to
55    // all of our gen_* function.
56    let inputs = ImplementInputs {
57        original_ident: original_type.ident.clone(),
58        interface_chains: convert_implements_to_interface_chains(attributes.implement),
59        trust_level: attributes.trust_level,
60        impl_ident: quote::format_ident!("{}_Impl", &original_type.ident),
61        constraints: {
62            if let Some(where_clause) = &original_type.generics.where_clause {
63                where_clause.predicates.to_token_stream()
64            } else {
65                quote!()
66            }
67        },
68        generics: if !original_type.generics.params.is_empty() {
69            let mut params = quote! {};
70            original_type.generics.params.to_tokens(&mut params);
71            quote! { <#params> }
72        } else {
73            quote! { <> }
74        },
75        is_generic: !original_type.generics.params.is_empty(),
76        original_type,
77    };
78
79    let items = gen_all(&inputs);
80    let mut tokens = inputs.original_type.into_token_stream();
81    for item in items {
82        tokens.extend(item.into_token_stream());
83    }
84
85    tokens
86}
87
88/// This provides the inputs to the `gen_*` functions, which generate the proc macro output.
89struct ImplementInputs {
90    /// The user's type that was marked with `#[implement]`.
91    original_type: syn::ItemStruct,
92
93    /// The identifier for the user's original type definition.
94    original_ident: syn::Ident,
95
96    /// The list of interface chains that this type implements.
97    interface_chains: Vec<InterfaceChain>,
98
99    /// The "trust level", which is returned by `IInspectable::GetTrustLevel`.
100    trust_level: usize,
101
102    /// The identifier of the `Foo_Impl` type.
103    impl_ident: syn::Ident,
104
105    /// The list of constraints needed for this `Foo_Impl` type.
106    constraints: proc_macro2::TokenStream,
107
108    /// The list of generic parameters for this `Foo_Impl` type, including `<` and `>`.
109    /// If there are no generics, this contains `<>`.
110    generics: proc_macro2::TokenStream,
111
112    /// True if the user type has any generic parameters.
113    is_generic: bool,
114}
115
116/// Describes one COM interface chain.
117struct InterfaceChain {
118    /// The name of the field for the vtable chain, e.g. `interface4_ifoo`.
119    field_ident: syn::Ident,
120
121    /// The name of the associated constant item for the vtable chain's initializer,
122    /// e.g. `INTERFACE4_IFOO_VTABLE`.
123    vtable_const_ident: syn::Ident,
124
125    implement: ImplementType,
126}
127
128struct ImplementType {
129    type_name: String,
130    generics: Vec<ImplementType>,
131
132    /// The best span for diagnostics.
133    span: proc_macro2::Span,
134}
135
136impl ImplementType {
137    fn to_ident(&self) -> proc_macro2::TokenStream {
138        let type_name = syn::parse_str::<proc_macro2::TokenStream>(&self.type_name)
139            .expect("Invalid token stream");
140        let generics = self.generics.iter().map(|g| g.to_ident());
141        quote! { #type_name<#(#generics,)*> }
142    }
143    fn to_vtbl_ident(&self) -> proc_macro2::TokenStream {
144        let ident = self.to_ident();
145        quote! {
146            <#ident as ::windows_core::Interface>::Vtable
147        }
148    }
149}
150
151#[derive(Default)]
152struct ImplementAttributes {
153    pub implement: Vec<ImplementType>,
154    pub trust_level: usize,
155}
156
157impl syn::parse::Parse for ImplementAttributes {
158    fn parse(cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
159        let mut input = Self::default();
160
161        while !cursor.is_empty() {
162            input.parse_implement(cursor)?;
163        }
164
165        Ok(input)
166    }
167}
168
169impl ImplementAttributes {
170    fn parse_implement(&mut self, cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result<()> {
171        let tree = cursor.parse::<UseTree2>()?;
172        self.walk_implement(&tree, &mut String::new())?;
173
174        if !cursor.is_empty() {
175            cursor.parse::<syn::Token![,]>()?;
176        }
177
178        Ok(())
179    }
180
181    fn walk_implement(
182        &mut self,
183        tree: &UseTree2,
184        namespace: &mut String,
185    ) -> syn::parse::Result<()> {
186        match tree {
187            UseTree2::Path(input) => {
188                if !namespace.is_empty() {
189                    namespace.push_str("::");
190                }
191
192                namespace.push_str(&input.ident.to_string());
193                self.walk_implement(&input.tree, namespace)?;
194            }
195            UseTree2::Name(_) => {
196                self.implement.push(tree.to_element_type(namespace)?);
197            }
198            UseTree2::Group(input) => {
199                for tree in &input.items {
200                    self.walk_implement(tree, namespace)?;
201                }
202            }
203            UseTree2::TrustLevel(input) => self.trust_level = *input,
204        }
205
206        Ok(())
207    }
208}
209
210enum UseTree2 {
211    Path(UsePath2),
212    Name(UseName2),
213    Group(UseGroup2),
214    TrustLevel(usize),
215}
216
217impl UseTree2 {
218    fn to_element_type(&self, namespace: &mut String) -> syn::parse::Result<ImplementType> {
219        match self {
220            UseTree2::Path(input) => {
221                if !namespace.is_empty() {
222                    namespace.push_str("::");
223                }
224
225                namespace.push_str(&input.ident.to_string());
226                input.tree.to_element_type(namespace)
227            }
228            UseTree2::Name(input) => {
229                let mut type_name = input.ident.to_string();
230                let span = input.ident.span();
231
232                if !namespace.is_empty() {
233                    type_name = format!("{namespace}::{type_name}");
234                }
235
236                let mut generics = vec![];
237
238                for g in &input.generics {
239                    generics.push(g.to_element_type(&mut String::new())?);
240                }
241
242                Ok(ImplementType {
243                    type_name,
244                    generics,
245                    span,
246                })
247            }
248            UseTree2::Group(input) => Err(syn::parse::Error::new(
249                input.brace_token.span.join(),
250                "Syntax not supported",
251            )),
252            _ => unimplemented!(),
253        }
254    }
255}
256
257struct UsePath2 {
258    pub ident: syn::Ident,
259    pub tree: Box<UseTree2>,
260}
261
262struct UseName2 {
263    pub ident: syn::Ident,
264    pub generics: Vec<UseTree2>,
265}
266
267struct UseGroup2 {
268    pub brace_token: syn::token::Brace,
269    pub items: syn::punctuated::Punctuated<UseTree2, syn::Token![,]>,
270}
271
272impl syn::parse::Parse for UseTree2 {
273    fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<UseTree2> {
274        let lookahead = input.lookahead1();
275        if lookahead.peek(syn::Ident) {
276            use syn::ext::IdentExt;
277            let ident = input.call(syn::Ident::parse_any)?;
278            if input.peek(syn::Token![::]) {
279                input.parse::<syn::Token![::]>()?;
280                Ok(UseTree2::Path(UsePath2 {
281                    ident,
282                    tree: Box::new(input.parse()?),
283                }))
284            } else if input.peek(syn::Token![=]) {
285                if ident != "TrustLevel" {
286                    return Err(syn::parse::Error::new(
287                        ident.span(),
288                        "Unrecognized key-value pair",
289                    ));
290                }
291                input.parse::<syn::Token![=]>()?;
292                let span = input.span();
293                let value = input.call(syn::Ident::parse_any)?;
294                match value.to_string().as_str() {
295                    "Partial" => Ok(UseTree2::TrustLevel(1)),
296                    "Full" => Ok(UseTree2::TrustLevel(2)),
297                    _ => Err(syn::parse::Error::new(
298                        span,
299                        "`TrustLevel` must be `Partial` or `Full`",
300                    )),
301                }
302            } else {
303                let generics = if input.peek(syn::Token![<]) {
304                    input.parse::<syn::Token![<]>()?;
305                    let mut generics = Vec::new();
306                    loop {
307                        generics.push(input.parse::<UseTree2>()?);
308
309                        if input.parse::<syn::Token![,]>().is_err() {
310                            break;
311                        }
312                    }
313                    input.parse::<syn::Token![>]>()?;
314                    generics
315                } else {
316                    Vec::new()
317                };
318
319                Ok(UseTree2::Name(UseName2 { ident, generics }))
320            }
321        } else if lookahead.peek(syn::token::Brace) {
322            let content;
323            let brace_token = syn::braced!(content in input);
324            let items = content.parse_terminated(UseTree2::parse, syn::Token![,])?;
325
326            Ok(UseTree2::Group(UseGroup2 { brace_token, items }))
327        } else {
328            Err(lookahead.error())
329        }
330    }
331}
332
333fn convert_implements_to_interface_chains(implements: Vec<ImplementType>) -> Vec<InterfaceChain> {
334    let mut chains = Vec::with_capacity(implements.len());
335
336    for (i, implement) in implements.into_iter().enumerate() {
337        // Create an identifier for this interface chain.
338        // We only use this for naming fields; it is never visible to the developer.
339        // This helps with debugging.
340        //
341        // We use i + 1 so that it matches the numbering of our interface offsets. Interface 0
342        // is the "identity" interface.
343
344        let mut ident_string = format!("interface{}", i + 1);
345
346        let suffix = get_interface_ident_suffix(&implement.type_name);
347        if !suffix.is_empty() {
348            ident_string.push('_');
349            ident_string.push_str(&suffix);
350        }
351        let field_ident = syn::Ident::new(&ident_string, implement.span);
352
353        let mut vtable_const_string = ident_string.clone();
354        vtable_const_string.make_ascii_uppercase();
355        vtable_const_string.insert_str(0, "VTABLE_");
356        let vtable_const_ident = syn::Ident::new(&vtable_const_string, implement.span);
357
358        chains.push(InterfaceChain {
359            implement,
360            field_ident,
361            vtable_const_ident,
362        });
363    }
364
365    chains
366}
367
/// Builds the readable part of a chain's field name from an interface's type name.
///
/// Keeps only ASCII letters and digits, lower-cased, and stops after 20 of them. For example,
/// `windows::Foundation::IStringable` becomes `windowsfoundationist`. The result is empty when
/// `type_name` contains no ASCII alphanumerics.
fn get_interface_ident_suffix(type_name: &str) -> String {
    const MAX_SUFFIX_CHARS: usize = 20;

    type_name
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|c| c.to_ascii_lowercase())
        .take(MAX_SUFFIX_CHARS)
        .collect()
}