Skip to main content

sm_ext_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
4use quote::{format_ident, quote, quote_spanned, ToTokens, TokenStreamExt};
5use syn;
6use syn::spanned::Spanned;
7
8/// Creates the entry point for SourceMod to recognise this library as an extension and set the required metadata.
9///
10/// The `#[extension]` attribute recognises the following optional keys using the *MetaListNameValueStr* syntax:
11///   * `name`
12///   * `description`
13///   * `url`
14///   * `author`
15///   * `version`
16///   * `tag`
17///   * `date`
18///
19/// If not overridden, all extension metadata will be set to suitable values using the Cargo package metadata.
20///
21/// An instance of the struct this is applied to will be created with [`Default::default()`] to serve
22/// as the global singleton instance, and you can implement the [`IExtensionInterface`] trait on the
23/// type to handle SourceMod lifecycle callbacks.
24#[proc_macro_derive(SMExtension, attributes(extension))]
25pub fn derive_extension_metadata(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
26    let ast: syn::DeriveInput = syn::parse(input).unwrap();
27
28    let name = &ast.ident;
29    let input = MetadataInput::new(&ast);
30
31    let extension_name = CStringToken(input.name);
32    let extension_description = CStringToken(input.description);
33    let extension_url = CStringToken(input.url);
34    let extension_author = CStringToken(input.author);
35    let extension_version = CStringToken(input.version);
36    let extension_tag = CStringToken(input.tag);
37    let extension_date = CStringToken(input.date);
38
39    let expanded = quote! {
40        // TODO: Checking for a test build here doesn't work when a dependent crate is being tested.
41        #[cfg(all(windows, not(target_feature = "crt-static"), not(test)))]
42        compile_error!("SourceMod requires the Windows CRT to be statically linked (pass `-C target-feature=+crt-static` to rustc)");
43
44        thread_local! {
45            // TODO: This should probably be on the chopping block, it is fairly gross and not just because
46            // it is storing a raw pointer, but I can't currently think of a better way for consumers to
47            // be able to share the SM interfaces with natives.
48            //
49            // One more long-term option might be to handle this internally to the library and pass the
50            // singleton into native callbacks as a param - if we go that route I think this and the
51            // IPluginContext arguments need to be opt-in, possibly using attributes, so that all native
52            // callbacks don't end up with 2 params that 90% of them don't use.
53            //
54            // Even then this isn't a great solution for the interfaces, maybe we should store those in
55            // thread_local variables directly as part of the wrapping API (similar to how SM stores
56            // the requested interfaces in globals) and offer static methods to fetch them automatically.
57            // To do that we would only need to store IShareSys and IExtension globally, but should
58            // probably cache all requested interfaces individually (and ideally force checking them
59            // on load, but that is likely unrealistic.)
60            static EXTENSION_GLOBAL: std::cell::RefCell<Option<*mut sm_ext::IExtensionInterfaceAdapter<#name>>> = std::cell::RefCell::new(None);
61        }
62
63        #[no_mangle]
64        pub extern "C" fn GetSMExtAPI() -> *mut sm_ext::IExtensionInterfaceAdapter<#name> {
65            let delegate: #name = Default::default();
66            let extension = sm_ext::IExtensionInterfaceAdapter::new(delegate);
67            let ptr = Box::into_raw(Box::new(extension));
68            EXTENSION_GLOBAL.with(|ext| {
69                *ext.borrow_mut() = Some(ptr);
70                ptr
71            })
72        }
73
74        // impl #name {
75        //     fn get() -> &'static Self {
76        //         EXTENSION_GLOBAL.with(|ext| {
77        //             unsafe { &(*ext.borrow().unwrap()).delegate }
78        //         })
79        //     }
80        // }
81
82        impl sm_ext::ExtensionMetadata for #name {
83            fn get_extension_name(&self) -> &'static ::std::ffi::CStr {
84                #extension_name
85            }
86            fn get_extension_url(&self) -> &'static ::std::ffi::CStr {
87                #extension_url
88            }
89            fn get_extension_tag(&self) -> &'static ::std::ffi::CStr {
90                #extension_tag
91            }
92            fn get_extension_author(&self) -> &'static ::std::ffi::CStr {
93                #extension_author
94            }
95            fn get_extension_ver_string(&self) -> &'static ::std::ffi::CStr {
96                #extension_version
97            }
98            fn get_extension_description(&self) -> &'static ::std::ffi::CStr {
99                #extension_description
100            }
101            fn get_extension_date_string(&self) -> &'static ::std::ffi::CStr {
102                #extension_date
103            }
104        }
105    };
106
107    expanded.into()
108}
109
110struct CStringToken(MetadataString);
111
112impl ToTokens for CStringToken {
113    fn to_tokens(&self, tokens: &mut TokenStream) {
114        let value = match &self.0 {
115            MetadataString::String(str) => str.to_token_stream(),
116            MetadataString::EnvVar(var) => quote! {
117                env!(#var)
118            },
119        };
120
121        // Inspired by https://crates.io/crates/c_str_macro
122        tokens.append_all(quote! {
123            unsafe {
124                ::std::ffi::CStr::from_ptr(concat!(#value, "\0").as_ptr() as *const ::std::os::raw::c_char)
125            }
126        });
127    }
128}
129
130enum MetadataString {
131    String(String),
132    EnvVar(String),
133}
134
135struct MetadataInput {
136    pub name: MetadataString,
137    pub description: MetadataString,
138    pub url: MetadataString,
139    pub author: MetadataString,
140    pub version: MetadataString,
141    pub tag: MetadataString,
142    pub date: MetadataString,
143}
144
145impl MetadataInput {
146    #[allow(clippy::cognitive_complexity)]
147    pub fn new(ast: &syn::DeriveInput) -> MetadataInput {
148        let mut name = None;
149        let mut description = None;
150        let mut url = None;
151        let mut author = None;
152        let mut version = None;
153        let mut tag = None;
154        let mut date = None;
155
156        let meta = ast.attrs.iter().find_map(|attr| match attr.parse_meta() {
157            Ok(m) => {
158                if m.path().is_ident("extension") {
159                    Some(m)
160                } else {
161                    None
162                }
163            }
164            Err(e) => panic!("unable to parse attribute: {}", e),
165        });
166
167        if let Some(meta) = meta {
168            let meta_list = match meta {
169                syn::Meta::List(inner) => inner,
170                _ => panic!("attribute 'extension' has incorrect type"),
171            };
172
173            for item in meta_list.nested {
174                let pair = match item {
175                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref pair)) => pair,
176                    _ => panic!("unsupported attribute argument {:?}", item.to_token_stream()),
177                };
178
179                if pair.path.is_ident("name") {
180                    if let syn::Lit::Str(ref s) = pair.lit {
181                        name = Some(s.value());
182                    } else {
183                        panic!("name value must be string literal");
184                    }
185                } else if pair.path.is_ident("description") {
186                    if let syn::Lit::Str(ref s) = pair.lit {
187                        description = Some(s.value())
188                    } else {
189                        panic!("description value must be string literal");
190                    }
191                } else if pair.path.is_ident("url") {
192                    if let syn::Lit::Str(ref s) = pair.lit {
193                        url = Some(s.value())
194                    } else {
195                        panic!("url value must be string literal");
196                    }
197                } else if pair.path.is_ident("author") {
198                    if let syn::Lit::Str(ref s) = pair.lit {
199                        author = Some(s.value())
200                    } else {
201                        panic!("author value must be string literal");
202                    }
203                } else if pair.path.is_ident("version") {
204                    if let syn::Lit::Str(ref s) = pair.lit {
205                        version = Some(s.value())
206                    } else {
207                        panic!("version value must be string literal");
208                    }
209                } else if pair.path.is_ident("tag") {
210                    if let syn::Lit::Str(ref s) = pair.lit {
211                        tag = Some(s.value())
212                    } else {
213                        panic!("tag value must be string literal");
214                    }
215                } else if pair.path.is_ident("date") {
216                    if let syn::Lit::Str(ref s) = pair.lit {
217                        date = Some(s.value())
218                    } else {
219                        panic!("date value must be string literal");
220                    }
221                } else {
222                    panic!("unsupported attribute key '{}' found", pair.path.to_token_stream())
223                }
224            }
225        }
226
227        let name = match name {
228            Some(name) => MetadataString::String(name),
229            None => MetadataString::EnvVar("CARGO_PKG_NAME".into()),
230        };
231
232        let description = match description {
233            Some(description) => MetadataString::String(description),
234            None => MetadataString::EnvVar("CARGO_PKG_DESCRIPTION".into()),
235        };
236
237        let url = match url {
238            Some(url) => MetadataString::String(url),
239            None => MetadataString::EnvVar("CARGO_PKG_HOMEPAGE".into()),
240        };
241
242        // TODO: This probably needs a special type to post-process the author list later.
243        let author = match author {
244            Some(author) => MetadataString::String(author),
245            None => MetadataString::EnvVar("CARGO_PKG_AUTHORS".into()),
246        };
247
248        let version = match version {
249            Some(version) => MetadataString::String(version),
250            None => MetadataString::EnvVar("CARGO_PKG_VERSION".into()),
251        };
252
253        // TODO: This probably should have a special type to slugify/uppercase the package name later.
254        let tag = match tag {
255            Some(tag) => MetadataString::String(tag),
256            None => MetadataString::EnvVar("CARGO_PKG_NAME".into()),
257        };
258
259        let date = match date {
260            Some(date) => MetadataString::String(date),
261            None => MetadataString::String("with Rust".into()),
262        };
263
264        MetadataInput { name, description, url, author, version, tag, date }
265    }
266}
267
268#[proc_macro_derive(SMInterfaceApi, attributes(interface))]
269pub fn derive_interface_api(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
270    let input = syn::parse_macro_input!(input as syn::DeriveInput);
271
272    let ident = input.ident;
273
274    let attribute = input.attrs.iter().find_map(|attr| match attr.parse_meta() {
275        Ok(m) => {
276            if m.path().is_ident("interface") {
277                Some(m)
278            } else {
279                None
280            }
281        }
282        Err(e) => panic!("unable to parse attribute: {}", e),
283    });
284
285    let mut output = TokenStream::new();
286
287    if let Some(attribute) = attribute {
288        let nested = match attribute {
289            syn::Meta::List(inner) => inner.nested,
290            _ => panic!("attribute 'interface' has incorrect type"),
291        };
292
293        if nested.len() != 2 {
294            panic!("attribute 'interface' expected 2 params: name, version")
295        }
296
297        let interface_name = match &nested[0] {
298            syn::NestedMeta::Lit(lit) => match lit {
299                syn::Lit::Str(str) => str,
300                _ => panic!("attribute 'interface' param 1 should be a string"),
301            },
302            _ => panic!("attribute 'interface' param 1 should be a literal string"),
303        };
304
305        let interface_version = match &nested[1] {
306            syn::NestedMeta::Lit(lit) => match lit {
307                syn::Lit::Int(int) => int,
308                _ => panic!("attribute 'interface' param 2 should be an integer"),
309            },
310            _ => panic!("attribute 'interface' param 2 should be a literal integer"),
311        };
312
313        output.extend(quote! {
314            impl RequestableInterface for #ident {
315                fn get_interface_name() -> &'static str {
316                    #interface_name
317                }
318
319                fn get_interface_version() -> u32 {
320                    #interface_version
321                }
322
323                #[allow(clippy::transmute_ptr_to_ptr)]
324                unsafe fn from_raw_interface(iface: SMInterface) -> #ident {
325                    #ident(std::mem::transmute(iface.0))
326                }
327            }
328        });
329    }
330
331    output.extend(quote! {
332        impl SMInterfaceApi for #ident {
333            fn get_interface_version(&self) -> u32 {
334                unsafe { virtual_call!(GetInterfaceVersion, self.0) }
335            }
336
337            fn get_interface_name(&self) -> &str {
338                unsafe {
339                    let c_name = virtual_call!(GetInterfaceName, self.0);
340
341                    std::ffi::CStr::from_ptr(c_name).to_str().unwrap()
342                }
343            }
344
345            fn is_version_compatible(&self, version: u32) -> bool {
346                unsafe { virtual_call!(IsVersionCompatible, self.0, version) }
347            }
348        }
349    });
350
351    output.into()
352}
353
354#[proc_macro_derive(ICallableApi)]
355pub fn derive_callable_api(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
356    let input = syn::parse_macro_input!(input as syn::DeriveInput);
357
358    let ident = &input.ident;
359    let generics = &input.generics;
360    let output = quote! {
361        impl #generics ICallableApi for #ident #generics {
362            fn push_int(&mut self, cell: i32) -> Result<(), SPError> {
363                unsafe {
364                    let res = virtual_call!(PushCell, self.0, cell.into());
365                    match res {
366                        SPError::None => Ok(()),
367                        _ => Err(res),
368                    }
369                }
370            }
371
372            fn push_float(&mut self, number: f32) -> Result<(), SPError> {
373                unsafe {
374                    let res = virtual_call!(PushFloat, self.0, number);
375                    match res {
376                        SPError::None => Ok(()),
377                        _ => Err(res),
378                    }
379                }
380            }
381
382            fn push_string(&mut self, string: &CStr) -> Result<(), SPError> {
383                unsafe {
384                    let res = virtual_call!(PushString, self.0, string.as_ptr());
385                    match res {
386                        SPError::None => Ok(()),
387                        _ => Err(res),
388                    }
389                }
390            }
391        }
392    };
393
394    output.into()
395}
396
397/// Declares a function as a native callback and generates internal support code.
398///
399/// A valid native callback must be a free function that is not async, not unsafe, not extern, has
400/// no generic parameters, the first argument takes a [`&IPluginContext`](IPluginContext), any
401/// remaining arguments are convertible to [`cell_t`] using [`TryIntoPlugin`] (possibly wrapped in
402/// an [`Option`]), and returns a type that satisfies the [`NativeResult`] trait.
403///
404/// When the native is invoked by SourceMod the input arguments will be checked to ensure all required
405/// arguments have been passed and are of the correct type, and panics or error results will automatically
406/// be converted into a SourceMod native error using [`safe_native_invoke`].
407///
408/// # Example
409///
410/// ```ignore
411/// use sm_ext::{native, IPluginContext};
412///
413/// #[native]
414/// fn simple_add_native(_ctx: &IPluginContext, a: i32, b: i32) -> i32 {
415///     a + b
416/// }
417/// ```
418#[proc_macro_attribute]
419pub fn native(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
420    let mut input = syn::parse_macro_input!(item as syn::ItemFn);
421    // println!("{:#?}", input);
422
423    let mut output = TokenStream::new();
424
425    if let Some(asyncness) = &input.sig.asyncness {
426        output.extend(error("Native callback must not be async", asyncness.span()));
427    }
428
429    if let Some(unsafety) = &input.sig.unsafety {
430        output.extend(error("Native callback must not be unsafe", unsafety.span()));
431    }
432
433    if let Some(abi) = &input.sig.abi {
434        output.extend(error("Native callback must use the default Rust ABI", abi.span()));
435    }
436
437    if !input.sig.generics.params.is_empty() {
438        output.extend(error("Native callback must not have any generic parameters", input.sig.generics.span()));
439    }
440
441    let mut param_count: i32 = 0;
442    let mut trailing_optional_count = 0;
443    let mut param_output = TokenStream::new();
444    for param in &input.sig.inputs {
445        match param {
446            syn::FnArg::Receiver(param) => {
447                output.extend(error("Native callback must not be a method", param.span()));
448            }
449            syn::FnArg::Typed(param) => {
450                param_count += 1;
451                if param_count == 1 {
452                    param_output.extend(quote_spanned!(param.span() => &ctx,));
453                    continue;
454                }
455
456                let mut is_optional = false;
457                if let syn::Type::Path(path) = &*param.ty {
458                    if path.path.segments.last().unwrap().ident == "Option" {
459                        is_optional = true;
460                        trailing_optional_count += 1;
461                    } else {
462                        trailing_optional_count = 0;
463                    }
464                } else {
465                    trailing_optional_count = 0;
466                }
467
468                let param_idx = param_count - 1;
469                let convert_param = quote_spanned!(param.span() =>
470                    (*(args.offset(#param_idx as isize)))
471                        .try_into_plugin(&ctx)
472                        .map_err(|err| format!("Error processing argument {}: {}", #param_idx, err))?
473                );
474
475                if is_optional {
476                    param_output.extend(quote! {
477                        if #param_idx <= count {
478                            Some(#convert_param)
479                        } else {
480                            None
481                        },
482                    });
483                } else {
484                    param_output.extend(quote! {
485                        #convert_param,
486                    });
487                }
488            }
489        };
490    }
491
492    let args_minimum = (param_count - 1) - trailing_optional_count;
493    let wrapper_ident = &input.sig.ident;
494    let callback_ident = format_ident!("__{}_impl", wrapper_ident);
495    output.extend(quote! {
496        unsafe extern "C" fn #wrapper_ident(ctx: sm_ext::IPluginContextPtr, args: *const sm_ext::cell_t) -> sm_ext::cell_t {
497            sm_ext::safe_native_invoke(ctx, |ctx| -> Result<sm_ext::cell_t, Box<dyn std::error::Error>> {
498                use sm_ext::NativeResult;
499                use sm_ext::TryIntoPlugin;
500
501                let count: i32 = (*args).into();
502                if count < #args_minimum {
503                    return Err(format!("not enough arguments, got {}, expected at least {}", count, #args_minimum).into());
504                }
505
506                let result = #callback_ident(
507                    #param_output
508                ).into_result()?;
509
510                Ok(result.try_into_plugin(&ctx)
511                    .map_err(|err| format!("error processing return value: {}", err))?)
512            })
513        }
514    });
515
516    input.sig.ident = callback_ident;
517    output.extend(input.to_token_stream());
518
519    // println!("{}", output.to_string());
520    output.into()
521}
522
523struct ForwardInfo {
524    ident: syn::Ident,
525    name: Option<syn::LitStr>,
526    exec_type: syn::Path,
527    params: Vec<syn::BareFnArg>,
528    ret: syn::Type,
529}
530
531fn parse_forward_from_field(field: &syn::Field, output: &mut TokenStream) -> Option<ForwardInfo> {
532    // TODO: It would improve diagnostics to remove the attribute if it is found.
533    let attribute = field.attrs.iter().find_map(|attr| match attr.parse_meta() {
534        Ok(m) => {
535            if m.path().is_ident("global_forward") || m.path().is_ident("private_forward") {
536                Some(m)
537            } else {
538                None
539            }
540        }
541        Err(e) => {
542            output.extend(e.to_compile_error());
543            None
544        }
545    })?;
546
547    let (params, ret): (Vec<syn::BareFnArg>, _) = match &field.ty {
548        syn::Type::BareFn(ty) => (
549            ty.inputs.iter().cloned().collect(),
550            match &ty.output {
551                syn::ReturnType::Default => syn::parse_quote!(()),
552                syn::ReturnType::Type(_, ty) => (*ty.as_ref()).clone(),
553            },
554        ),
555        _ => {
556            output.extend(error("expected bare function", field.ty.span()));
557            return None;
558        }
559    };
560
561    let nested = match &attribute {
562        syn::Meta::List(inner) => &inner.nested,
563        _ => {
564            output.extend(error(&format!("attribute '{}' has incorrect type", attribute.path().get_ident().unwrap()), attribute.span()));
565            return None;
566        }
567    };
568
569    if attribute.path().is_ident("global_forward") {
570        if nested.len() != 2 {
571            output.extend(error("Usage: #[global_forward(Forward_Name, ExecType::)]", attribute.span()));
572            return None;
573        }
574
575        let forward_name = match &nested[0] {
576            syn::NestedMeta::Lit(lit) => match lit {
577                syn::Lit::Str(str) => str,
578                _ => {
579                    output.extend(error("expected string literal", nested[0].span()));
580                    return None;
581                }
582            },
583            _ => {
584                output.extend(error("expected string literal", nested[0].span()));
585                return None;
586            }
587        };
588
589        let forward_exec_type = match &nested[1] {
590            syn::NestedMeta::Meta(meta) => match meta {
591                syn::Meta::Path(path) => path,
592                _ => {
593                    output.extend(error("expected type path", nested[1].span()));
594                    return None;
595                }
596            },
597            _ => {
598                output.extend(error("expected type path", nested[1].span()));
599                return None;
600            }
601        };
602
603        Some(ForwardInfo { ident: field.ident.as_ref().unwrap().clone(), name: Some((*forward_name).clone()), exec_type: (*forward_exec_type).clone(), params, ret })
604    } else if attribute.path().is_ident("private_forward") {
605        output.extend(error("#[private_forward] not implemented", attribute.span()));
606
607        if nested.len() != 1 {
608            output.extend(error("Usage: #[private_forward(ExecType::)]", attribute.span()));
609            return None;
610        }
611
612        let forward_exec_type = match &nested[0] {
613            syn::NestedMeta::Meta(meta) => match meta {
614                syn::Meta::Path(path) => path,
615                _ => {
616                    output.extend(error("expected type path", nested[0].span()));
617                    return None;
618                }
619            },
620            _ => {
621                output.extend(error("expected type path", nested[0].span()));
622                return None;
623            }
624        };
625
626        Some(ForwardInfo { ident: field.ident.as_ref().unwrap().clone(), name: None, exec_type: (*forward_exec_type).clone(), params, ret })
627    } else {
628        None
629    }
630}
631
632#[proc_macro_attribute]
633pub fn forwards(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
634    let mut input = syn::parse_macro_input!(item as syn::ItemStruct);
635    // println!("{:#?}", input);
636
637    let mut fields = match &mut input.fields {
638        syn::Fields::Named(fields) => fields,
639        _ => panic!("Expected a struct with named fields"),
640    };
641
642    let mut output = TokenStream::new();
643
644    let mut forwards = Vec::new();
645    let mut filtered_fields: syn::punctuated::Punctuated<syn::Field, syn::Token![,]> = syn::punctuated::Punctuated::new();
646
647    for field in &fields.named {
648        if let Some(forward) = parse_forward_from_field(field, &mut output) {
649            forwards.push(forward);
650        } else {
651            filtered_fields.push((*field).clone());
652        }
653    }
654
655    fields.named = filtered_fields;
656
657    output.extend(input.to_token_stream());
658
659    if forwards.is_empty() {
660        output.extend(error("#[forwards] attribute used on struct with no forward members", input.ident.span()));
661        return output.into();
662    }
663
664    let mut output_thread_locals = TokenStream::new();
665    let mut output_trait = TokenStream::new();
666    let mut output_trait_impl = TokenStream::new();
667    let mut output_trait_impl_register = TokenStream::new();
668    let mut output_trait_impl_unregister = TokenStream::new();
669
670    for forward in forwards {
671        let forward_ident = &forward.ident;
672        let type_ident = format_ident!("__{}_forward", forward.ident);
673        let global_ident = format_ident!("__g_{}_forward", forward.ident);
674
675        let forward_name = forward.name.unwrap(); // TODO: Handle private forwards.
676        let forward_exec_type = forward.exec_type;
677
678        let mut forward_param_types = TokenStream::new();
679
680        let forward_call_return = forward.ret;
681        let mut forward_call_args = TokenStream::new();
682        let mut forward_call_pushes = TokenStream::new();
683
684        for param in forward.params {
685            let param_type = &param.ty;
686            let param_name = &param.name.as_ref().unwrap().0;
687            forward_param_types.extend(quote_spanned!(param_type.span() =>
688                <#param_type>::param_type(),
689            ));
690            forward_call_args.extend(quote_spanned!(param.span() =>
691                #param,
692            ));
693            forward_call_pushes.extend(quote_spanned!(param_name.span() =>
694                self.0.push(#param_name)?;
695            ));
696        }
697
698        output.extend(quote_spanned!(forward.ident.span() =>
699            #[allow(non_camel_case_types)]
700            struct #type_ident<'a>(&'a mut sm_ext::Forward);
701        ));
702
703        let execute_return = match &forward_call_return {
704            syn::Type::Tuple(tuple) if tuple.elems.is_empty() => quote!(self.0.execute()?; Ok(())),
705            _ => quote!(Ok(self.0.execute()?.into())),
706        };
707
708        output.extend(quote_spanned!(forward.ident.span() =>
709            impl #type_ident<'_> {
710                fn execute(&mut self, #forward_call_args) -> Result<#forward_call_return, sm_ext::SPError> {
711                    use sm_ext::Executable;
712                    #forward_call_pushes
713                    #execute_return
714                }
715            }
716        ));
717
718        output_thread_locals.extend(quote_spanned!(forward.ident.span() =>
719            #[allow(non_upper_case_globals)]
720            static #global_ident: std::cell::RefCell<Option<sm_ext::Forward>> = std::cell::RefCell::new(None);
721        ));
722
723        output_trait.extend(quote_spanned!(forward.ident.span() =>
724            fn #forward_ident<F, R>(f: F) -> R where F: FnOnce(&mut #type_ident) -> R;
725        ));
726
727        output_trait_impl_register.extend(quote_spanned!(forward.ident.span() =>
728            let #forward_ident = forward_manager.create_global_forward(#forward_name, #forward_exec_type, &[#forward_param_types])?;
729            #global_ident.with(|fwd| {
730                *fwd.borrow_mut() = Some(#forward_ident);
731            });
732        ));
733
734        output_trait_impl_unregister.extend(quote_spanned!(forward.ident.span() =>
735            #global_ident.with(|fwd| {
736                *fwd.borrow_mut() = None;
737            });
738        ));
739
740        output_trait_impl.extend(quote_spanned!(forward.ident.span() =>
741            fn #forward_ident<F, R>(f: F) -> R where F: FnOnce(&mut #type_ident) -> R {
742                #global_ident.with(|fwd| {
743                    let mut fwd = fwd.borrow_mut();
744                    let fwd = fwd.as_mut().unwrap();
745                    let mut fwd = #type_ident(fwd);
746                    f(&mut fwd)
747                })
748            }
749        ));
750    }
751
752    output.extend(quote! {
753        thread_local! {
754            #output_thread_locals
755        }
756    });
757
758    let struct_ident = &input.ident;
759    let trait_ident = format_ident!("__{}_forwards", input.ident);
760
761    output.extend(quote! {
762        #[allow(non_camel_case_types)]
763        trait #trait_ident {
764            fn register(forward_manager: &sm_ext::IForwardManager) -> Result<(), sm_ext::CreateForwardError>;
765            fn unregister();
766            #output_trait
767        }
768    });
769
770    output.extend(quote! {
771        impl #trait_ident for #struct_ident {
772            fn register(forward_manager: &sm_ext::IForwardManager) -> Result<(), sm_ext::CreateForwardError> {
773                use sm_ext::CallableParam;
774                #output_trait_impl_register
775                Ok(())
776            }
777
778            fn unregister() {
779                #output_trait_impl_unregister
780            }
781
782            #output_trait_impl
783        }
784    });
785
786    // println!("{}", output.to_string());
787    output.into()
788}
789
790#[proc_macro_attribute]
791pub fn vtable(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
792    let this_ptr_type = syn::parse_macro_input!(attr as syn::Path);
793    let mut input = syn::parse_macro_input!(item as syn::ItemStruct);
794    let mut output = TokenStream::new();
795
796    // println!("{}", input.to_token_stream().to_string());
797
798    input.attrs.push(syn::parse_quote!(#[doc(hidden)]));
799    input.attrs.push(syn::parse_quote!(#[repr(C)]));
800
801    let mut did_error = false;
802    for field in &mut input.fields {
803        if let syn::Type::BareFn(ty) = &mut field.ty {
804            ty.unsafety = syn::parse_quote!(unsafe);
805            ty.abi = syn::parse_quote!(extern "C");
806
807            // Prepend the thisptr argument
808            ty.inputs.insert(0, syn::parse_quote!(this: #this_ptr_type));
809        } else {
810            output.extend(error("All vtable struct fields must be bare functions", field.span()));
811            did_error = true;
812        }
813    }
814
815    if !did_error {
816        input.attrs.push(syn::parse_quote!(#[cfg(not(all(windows, target_arch = "x86")))]));
817    }
818
819    output.extend(input.to_token_stream());
820
821    if did_error {
822        return output.into();
823    }
824
825    input.attrs.pop();
826    input.attrs.push(syn::parse_quote!(#[cfg(all(windows, target_arch = "x86", feature = "abi_thiscall"))]));
827
828    for field in &mut input.fields {
829        if let syn::Type::BareFn(ty) = &mut field.ty {
830            if ty.variadic.is_none() {
831                ty.abi = syn::parse_quote!(extern "thiscall");
832            }
833        }
834    }
835
836    output.extend(input.to_token_stream());
837
838    input.attrs.pop();
839    input.attrs.push(syn::parse_quote!(#[cfg(all(windows, target_arch = "x86", not(feature = "abi_thiscall")))]));
840
841    for field in &mut input.fields {
842        if let syn::Type::BareFn(ty) = &mut field.ty {
843            if ty.variadic.is_none() {
844                ty.abi = syn::parse_quote!(extern "fastcall");
845
846                // Add a dummy argument to be passed in edx
847                ty.inputs.insert(1, syn::parse_quote!(_dummy: *const usize));
848            }
849        }
850    }
851
852    output.extend(input.to_token_stream());
853
854    // println!("{}", output.to_string());
855
856    output.into()
857}
858
859// TODO: This needs a lot of input checking and error reporting work
860#[proc_macro_attribute]
861pub fn vtable_override(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
862    let mut input = syn::parse_macro_input!(item as syn::ItemFn);
863    let mut output = TokenStream::new();
864
865    // println!("{}", input.to_token_stream().to_string());
866
867    input.attrs.push(syn::parse_quote!(#[cfg(not(all(windows, target_arch = "x86")))]));
868
869    input.sig.abi = syn::parse_quote!(extern "C");
870
871    output.extend(input.to_token_stream());
872
873    input.attrs.pop();
874    input.attrs.push(syn::parse_quote!(#[cfg(all(windows, target_arch = "x86", feature = "abi_thiscall"))]));
875
876    input.sig.abi = syn::parse_quote!(extern "thiscall");
877
878    output.extend(input.to_token_stream());
879
880    input.attrs.pop();
881    input.attrs.push(syn::parse_quote!(#[cfg(all(windows, target_arch = "x86", not(feature = "abi_thiscall")))]));
882
883    // Add a dummy argument to be passed in edx
884    input.sig.inputs.insert(1, syn::parse_quote!(_dummy: *const usize));
885
886    input.sig.abi = syn::parse_quote!(extern "fastcall");
887
888    output.extend(input.to_token_stream());
889
890    // println!("{}", output.to_string());
891
892    output.into()
893}
894
895fn error(s: &str, span: Span) -> TokenStream {
896    let mut v = Vec::new();
897    v.push(respan(Literal::string(&s), Span::call_site()));
898    let group = v.into_iter().collect();
899
900    let mut r = Vec::<TokenTree>::new();
901    r.push(respan(Ident::new("compile_error", span), span));
902    r.push(respan(Punct::new('!', Spacing::Alone), span));
903    r.push(respan(Group::new(Delimiter::Brace, group), span));
904
905    r.into_iter().collect()
906}
907
908fn respan<T: Into<TokenTree>>(t: T, span: Span) -> TokenTree {
909    let mut t = t.into();
910    t.set_span(span);
911    t
912}