windows_implement/
lib.rs

1//! Implement COM interfaces for Rust types.
2//!
3//! Take a look at [macro@implement] for an example.
4//!
5//! Learn more about Rust for Windows here: <https://github.com/microsoft/windows-rs>
6
7use quote::{quote, ToTokens};
8
9mod r#gen;
10use r#gen::gen_all;
11
12#[cfg(test)]
13mod tests;
14
15/// Implements one or more COM interfaces.
16///
17/// # Example
18/// ```rust,no_run
19/// use windows_core::*;
20///
21/// #[interface("094d70d6-5202-44b8-abb8-43860da5aca2")]
22/// unsafe trait IValue: IUnknown {
23///     fn GetValue(&self, value: *mut i32) -> HRESULT;
24/// }
25///
26/// #[implement(IValue)]
27/// struct Value(i32);
28///
29/// impl IValue_Impl for Value_Impl {
30///     unsafe fn GetValue(&self, value: *mut i32) -> HRESULT {
31///         *value = self.0;
32///         HRESULT(0)
33///     }
34/// }
35///
36/// let object: IValue = Value(123).into();
37/// // Call interface methods...
38/// ```
39#[proc_macro_attribute]
40pub fn implement(
41    attributes: proc_macro::TokenStream,
42    type_tokens: proc_macro::TokenStream,
43) -> proc_macro::TokenStream {
44    implement_core(attributes.into(), type_tokens.into()).into()
45}
46
47fn implement_core(
48    attributes: proc_macro2::TokenStream,
49    item_tokens: proc_macro2::TokenStream,
50) -> proc_macro2::TokenStream {
51    let attributes = syn::parse2::<ImplementAttributes>(attributes).unwrap();
52    let original_type = syn::parse2::<syn::ItemStruct>(item_tokens).unwrap();
53
54    // Do a little thinking and assemble ImplementInputs.  We pass ImplementInputs to
55    // all of our gen_* function.
56    let inputs = ImplementInputs {
57        original_ident: original_type.ident.clone(),
58        interface_chains: convert_implements_to_interface_chains(attributes.implement),
59        trust_level: attributes.trust_level,
60        agile: attributes.agile,
61        impl_ident: quote::format_ident!("{}_Impl", &original_type.ident),
62        constraints: {
63            if let Some(where_clause) = &original_type.generics.where_clause {
64                where_clause.predicates.to_token_stream()
65            } else {
66                quote!()
67            }
68        },
69        generics: if !original_type.generics.params.is_empty() {
70            let mut params = quote! {};
71            original_type.generics.params.to_tokens(&mut params);
72            quote! { <#params> }
73        } else {
74            quote! { <> }
75        },
76        is_generic: !original_type.generics.params.is_empty(),
77        original_type,
78    };
79
80    let items = gen_all(&inputs);
81    let mut tokens = inputs.original_type.into_token_stream();
82    for item in items {
83        tokens.extend(item.into_token_stream());
84    }
85
86    tokens
87}
88
/// This provides the inputs to the `gen_*` functions, which generate the proc macro output.
///
/// It is assembled once in `implement_core` from the parsed attribute arguments and the
/// parsed struct definition.
struct ImplementInputs {
    /// The user's type that was marked with `#[implement]`.
    original_type: syn::ItemStruct,

    /// The identifier for the user's original type definition.
    original_ident: syn::Ident,

    /// The list of interface chains that this type implements, one per interface named
    /// in the attribute, in the order they were written.
    interface_chains: Vec<InterfaceChain>,

    /// The "trust level", which is returned by `IInspectable::GetTrustLevel`.
    /// Copied from the parsed attribute arguments.
    trust_level: usize,

    /// Determines whether `IAgileObject` and `IMarshal` are implemented automatically.
    agile: bool,

    /// The identifier of the `Foo_Impl` type, formed by appending `_Impl` to the
    /// original identifier.
    impl_ident: syn::Ident,

    /// The list of constraints needed for this `Foo_Impl` type. These are the predicates
    /// of the original type's `where` clause; empty if it has none.
    constraints: proc_macro2::TokenStream,

    /// The list of generic parameters for this `Foo_Impl` type, including `<` and `>`.
    /// If there are no generics, this contains `<>`.
    generics: proc_macro2::TokenStream,

    /// True if the user type has any generic parameters.
    is_generic: bool,
}
119
/// Describes one COM interface chain.
struct InterfaceChain {
    /// The name of the field for the vtable chain, e.g. `interface4_ifoo`.
    field_ident: syn::Ident,

    /// The name of the associated constant item for the vtable chain's initializer,
    /// e.g. `VTABLE_INTERFACE4_IFOO` (the upper-cased field name with a `VTABLE_` prefix).
    vtable_const_ident: syn::Ident,

    /// The interface type, as written in the `#[implement]` attribute, that this chain
    /// is for.
    implement: ImplementType,
}
131
/// One interface type named in the `#[implement]` attribute.
struct ImplementType {
    /// The `::`-separated path of the type, without generic arguments.
    type_name: String,

    /// The generic arguments of the type, in order; empty if none were written.
    generics: Vec<ImplementType>,

    /// The best span for diagnostics.
    span: proc_macro2::Span,
}
139
140impl ImplementType {
141    fn to_ident(&self) -> proc_macro2::TokenStream {
142        let type_name = syn::parse_str::<proc_macro2::TokenStream>(&self.type_name)
143            .expect("Invalid token stream");
144        let generics = self.generics.iter().map(|g| g.to_ident());
145        quote! { #type_name<#(#generics,)*> }
146    }
147    fn to_vtbl_ident(&self) -> proc_macro2::TokenStream {
148        let ident = self.to_ident();
149        quote! {
150            <#ident as ::windows_core::Interface>::Vtable
151        }
152    }
153}
154
/// The parsed arguments of the `#[implement(...)]` attribute.
#[derive(Default)]
struct ImplementAttributes {
    /// The interface types listed in the attribute, in the order they were written.
    pub implement: Vec<ImplementType>,

    /// The value set by `TrustLevel = ...`: 1 for `Partial`, 2 for `Full`.
    /// Stays 0 when the option is not given.
    pub trust_level: usize,

    /// The value set by `Agile = ...`. The derived `Default` yields `false`, but the
    /// `Parse` implementation starts from `true`.
    pub agile: bool,
}
161
162impl syn::parse::Parse for ImplementAttributes {
163    fn parse(cursor: syn::parse::ParseStream) -> syn::parse::Result<Self> {
164        let mut input = Self {
165            agile: true,
166            ..Default::default()
167        };
168
169        while !cursor.is_empty() {
170            input.parse_implement(cursor)?;
171        }
172
173        Ok(input)
174    }
175}
176
177impl ImplementAttributes {
178    fn parse_implement(&mut self, cursor: syn::parse::ParseStream) -> syn::parse::Result<()> {
179        let tree = cursor.parse::<UseTree2>()?;
180        self.walk_implement(&tree, &mut String::new())?;
181
182        if !cursor.is_empty() {
183            cursor.parse::<syn::Token![,]>()?;
184        }
185
186        Ok(())
187    }
188
189    fn walk_implement(
190        &mut self,
191        tree: &UseTree2,
192        namespace: &mut String,
193    ) -> syn::parse::Result<()> {
194        match tree {
195            UseTree2::Path(input) => {
196                if !namespace.is_empty() {
197                    namespace.push_str("::");
198                }
199
200                namespace.push_str(&input.ident.to_string());
201                self.walk_implement(&input.tree, namespace)?;
202            }
203            UseTree2::Name(_) => {
204                self.implement.push(tree.to_element_type(namespace)?);
205            }
206            UseTree2::Group(input) => {
207                for tree in &input.items {
208                    self.walk_implement(tree, namespace)?;
209                }
210            }
211            UseTree2::TrustLevel(input) => self.trust_level = *input,
212            UseTree2::Agile(agile) => self.agile = *agile,
213        }
214
215        Ok(())
216    }
217}
218
/// One node of the `use`-like syntax accepted inside `#[implement(...)]`.
enum UseTree2 {
    /// A path segment followed by `::` and a nested tree, e.g. `Namespace::IFoo`.
    Path(UsePath2),
    /// A final identifier with optional generic arguments, e.g. `IFoo<T>`.
    Name(UseName2),
    /// A braced, comma-separated list of trees, e.g. `{IFoo, IBar}`.
    Group(UseGroup2),
    /// The `TrustLevel = ...` option: 1 for `Partial`, 2 for `Full`.
    TrustLevel(usize),
    /// The `Agile = true` / `Agile = false` option.
    Agile(bool),
}
226
227impl UseTree2 {
228    fn to_element_type(&self, namespace: &mut String) -> syn::parse::Result<ImplementType> {
229        match self {
230            Self::Path(input) => {
231                if !namespace.is_empty() {
232                    namespace.push_str("::");
233                }
234
235                namespace.push_str(&input.ident.to_string());
236                input.tree.to_element_type(namespace)
237            }
238            Self::Name(input) => {
239                let mut type_name = input.ident.to_string();
240                let span = input.ident.span();
241
242                if !namespace.is_empty() {
243                    type_name = format!("{namespace}::{type_name}");
244                }
245
246                let mut generics = vec![];
247
248                for g in &input.generics {
249                    generics.push(g.to_element_type(&mut String::new())?);
250                }
251
252                Ok(ImplementType {
253                    type_name,
254                    generics,
255                    span,
256                })
257            }
258            Self::Group(input) => Err(syn::parse::Error::new(
259                input.brace_token.span.join(),
260                "Syntax not supported",
261            )),
262            _ => unimplemented!(),
263        }
264    }
265}
266
/// A path segment followed by `::` and the rest of the tree.
struct UsePath2 {
    /// The segment written before the `::`.
    pub ident: syn::Ident,
    /// Everything written after the `::`.
    pub tree: Box<UseTree2>,
}
271
/// A final identifier, optionally followed by generic arguments.
struct UseName2 {
    /// The identifier itself.
    pub ident: syn::Ident,
    /// The trees written between `<` and `>`; empty if no angle brackets were given.
    pub generics: Vec<UseTree2>,
}
276
/// A braced, comma-separated list of trees.
struct UseGroup2 {
    /// The surrounding braces; their span is used when reporting errors about the group.
    pub brace_token: syn::token::Brace,
    /// The trees inside the braces.
    pub items: syn::punctuated::Punctuated<UseTree2, syn::Token![,]>,
}
281
282impl syn::parse::Parse for UseTree2 {
283    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
284        let lookahead = input.lookahead1();
285        if lookahead.peek(syn::Ident) {
286            use syn::ext::IdentExt;
287            let ident = input.call(syn::Ident::parse_any)?;
288            if input.peek(syn::Token![::]) {
289                input.parse::<syn::Token![::]>()?;
290                Ok(Self::Path(UsePath2 {
291                    ident,
292                    tree: Box::new(input.parse()?),
293                }))
294            } else if input.peek(syn::Token![=]) {
295                if ident == "TrustLevel" {
296                    input.parse::<syn::Token![=]>()?;
297                    let span = input.span();
298                    let value = input.call(syn::Ident::parse_any)?;
299                    match value.to_string().as_str() {
300                        "Partial" => Ok(Self::TrustLevel(1)),
301                        "Full" => Ok(Self::TrustLevel(2)),
302                        _ => Err(syn::parse::Error::new(
303                            span,
304                            "`TrustLevel` must be `Partial` or `Full`",
305                        )),
306                    }
307                } else if ident == "Agile" {
308                    input.parse::<syn::Token![=]>()?;
309                    let span = input.span();
310                    let value = input.call(syn::Ident::parse_any)?;
311                    match value.to_string().as_str() {
312                        "true" => Ok(Self::Agile(true)),
313                        "false" => Ok(Self::Agile(false)),
314                        _ => Err(syn::parse::Error::new(
315                            span,
316                            "`Agile` must be `true` or `false`",
317                        )),
318                    }
319                } else {
320                    Err(syn::parse::Error::new(
321                        ident.span(),
322                        "Unrecognized key-value pair",
323                    ))
324                }
325            } else {
326                let generics = if input.peek(syn::Token![<]) {
327                    input.parse::<syn::Token![<]>()?;
328                    let mut generics = Vec::new();
329                    loop {
330                        generics.push(input.parse::<Self>()?);
331
332                        if input.parse::<syn::Token![,]>().is_err() {
333                            break;
334                        }
335                    }
336                    input.parse::<syn::Token![>]>()?;
337                    generics
338                } else {
339                    Vec::new()
340                };
341
342                Ok(Self::Name(UseName2 { ident, generics }))
343            }
344        } else if lookahead.peek(syn::token::Brace) {
345            let content;
346            let brace_token = syn::braced!(content in input);
347            let items = content.parse_terminated(Self::parse, syn::Token![,])?;
348
349            Ok(Self::Group(UseGroup2 { brace_token, items }))
350        } else {
351            Err(lookahead.error())
352        }
353    }
354}
355
356fn convert_implements_to_interface_chains(implements: Vec<ImplementType>) -> Vec<InterfaceChain> {
357    let mut chains = Vec::with_capacity(implements.len());
358
359    for (i, implement) in implements.into_iter().enumerate() {
360        // Create an identifier for this interface chain.
361        // We only use this for naming fields; it is never visible to the developer.
362        // This helps with debugging.
363        //
364        // We use i + 1 so that it matches the numbering of our interface offsets. Interface 0
365        // is the "identity" interface.
366
367        let mut ident_string = format!("interface{}", i + 1);
368
369        let suffix = get_interface_ident_suffix(&implement.type_name);
370        if !suffix.is_empty() {
371            ident_string.push('_');
372            ident_string.push_str(&suffix);
373        }
374        let field_ident = syn::Ident::new(&ident_string, implement.span);
375
376        let mut vtable_const_string = ident_string.clone();
377        vtable_const_string.make_ascii_uppercase();
378        vtable_const_string.insert_str(0, "VTABLE_");
379        let vtable_const_ident = syn::Ident::new(&vtable_const_string, implement.span);
380
381        chains.push(InterfaceChain {
382            implement,
383            field_ident,
384            vtable_const_ident,
385        });
386    }
387
388    chains
389}
390
/// Derives a short, identifier-safe suffix from a type name.
///
/// Keeps only the ASCII alphanumeric characters of `type_name`, lower-cases them,
/// and stops after the first 20. Returns an empty string if there are none.
fn get_interface_ident_suffix(type_name: &str) -> String {
    const MAX_SUFFIX_LEN: usize = 20;

    type_name
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .take(MAX_SUFFIX_LEN)
        .map(|c| c.to_ascii_lowercase())
        .collect()
}