1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::{ItemEnum, ItemStruct, ItemType};

fn get_name(item: &proc_macro2::TokenStream) -> Option<Ident> {
    let ident = if let Ok(ItemStruct { ident, .. }) = syn::parse2(item.clone()) {
        ident
    } else if let Ok(ItemEnum { ident, .. }) = syn::parse2(item.clone()) {
        ident
    } else if let Ok(ItemType { ident, .. }) = syn::parse2(item.clone()) {
        ident
    } else {
        return None;
    };

    Some(ident)
}

#[allow(clippy::too_many_lines)]
fn stateroom_wasm_impl(item: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let name =
        get_name(item).expect("Can only use #[stateroom_wasm] on a struct, enum, or type alias.");

    quote! {
        #item

        mod _stateroom_wasm_macro_autogenerated {
            extern crate alloc;

            use super::#name;

            // Functions implemented by the host.
            mod ffi {
                extern "C" {
                    pub fn stateroom_send(message_ptr: *const u8, message_len: u32);
                }
            }

            // Instance-global stateroom service.
            static mut SERVER_STATE: Option<stateroom_wasm::WrappedStateroomService<#name>> = None;

            #[no_mangle]
            pub static STATEROOM_API_VERSION: i32 = 1;

            #[no_mangle]
            pub static STATEROOM_API_PROTOCOL: i32 = 0;

            #[no_mangle]
            extern "C" fn stateroom_recv(message_ptr: *const u8, message_len: u32) {
                let state = unsafe {
                    match SERVER_STATE.as_mut() {
                        Some(s) => s,
                        None => {
                            let s = stateroom_wasm::WrappedStateroomService::new(#name::default(), ffi::stateroom_send);
                            SERVER_STATE.replace(s);
                            SERVER_STATE.as_mut().unwrap()
                        }
                    }
                };
                state.recv(message_ptr, message_len);
            }

            #[no_mangle]
            pub unsafe extern "C" fn stateroom_malloc(size: u32) -> *mut u8 {
                if size == 0 {
                    return core::ptr::null_mut();
                }
                let layout = core::alloc::Layout::from_size_align_unchecked(size as usize, 0);
                alloc::alloc::alloc(layout)
            }

            #[no_mangle]
            pub unsafe extern "C" fn stateroom_free(ptr: *mut u8, size: u32) {
                if size == 0 {
                    return;
                }
                let layout = core::alloc::Layout::from_size_align_unchecked(size as usize, 0);
                alloc::alloc::dealloc(ptr, layout);
            }
        }
    }
}

/// Exposes a `stateroom_wasm::StateroomService`-implementing trait as a WebAssembly module.
#[proc_macro_attribute]
pub fn stateroom_wasm(_attr: TokenStream, item: TokenStream) -> TokenStream {
    #[allow(clippy::needless_borrow)]
    stateroom_wasm_impl(&item.into()).into()
}

#[cfg(test)]
mod test {
    use super::get_name;
    use quote::quote;

    /// `get_name` recovers the identifier of every supported item kind and rejects others.
    #[test]
    fn test_parse_name() {
        let supported = [
            ("MyStruct", quote! { struct MyStruct {} }),
            ("AnotherStruct", quote! { struct AnotherStruct; }),
            ("ATupleStruct", quote! { struct ATupleStruct(u32, u32, u32); }),
            (
                "AnEnum",
                quote! {
                    enum AnEnum {
                        Option1,
                        Option2(u32),
                    }
                },
            ),
            ("ATypeDecl", quote! { type ATypeDecl = u32; }),
        ];

        for (expected, tokens) in &supported {
            let ident = get_name(tokens)
                .unwrap_or_else(|| panic!("expected a name to be parsed for {}", expected));
            assert_eq!(*expected, ident.to_string());
        }

        // `impl` blocks are not a supported target.
        assert!(get_name(&quote! { impl Foo {} }).is_none());
    }
}