Skip to main content

traceforge_macros/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro2::{Ident, TokenStream};
5use quote::quote;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::{Comma, PathSep};
10use syn::{DeriveInput, Generics, PathSegment, TypePath};
11
12struct MsgTypes {
13    types: Vec<MsgVariant>,
14}
15
16// name is the name of the struct type (defining the monitor) and mtype is the type of the messages
17// the monitor is observing
18struct MsgVariant {
19    name: Ident,     // FOO or M2
20    mtype: TypePath, // Foo or M2
21}
22
23// this is to define types of the form
24// pub enum MyMonitorMsg (the monitor name written with upper case letters) {
25//    Foo(Foo),
26//    M2(M2),
27//    Terminate
28//}
29// where MyMonitor is the name of the monitor, Foo and M2 are the types of the messages observed by the monitor
30impl MsgTypes {
31    fn enum_stream(&self, name: &Ident) -> TokenStream {
32        let vars = self.types.iter().map(|t| {
33            let MsgVariant { name, mtype } = t;
34            quote! {
35                #name(#mtype),
36            }
37        });
38
39        quote! {
40            #[derive(Clone, Debug, PartialEq)]
41            pub enum #name {
42                #(#vars)*
43                Terminate
44            }
45        }
46    }
47}
48
49// this seems to be about parsing something and transforming to a MsgTypes structure
50impl Parse for MsgTypes {
51    fn parse(input: ParseStream) -> Result<Self> {
52        let vars = Punctuated::<TypePath, Comma>::parse_terminated(input)?;
53
54        Ok(MsgTypes {
55            types: vars
56                .into_iter()
57                .map(|t| MsgVariant {
58                    name: get_name(&t.path.segments),
59                    mtype: t,
60                })
61                .collect::<Vec<_>>(),
62        })
63    }
64}
65
66// this is used only in the "parse" function above
67fn get_name(segments: &Punctuated<PathSegment, PathSep>) -> Ident {
68    let vname = segments
69        .iter()
70        .map(|seg| {
71            let ident = format!("{}", seg.ident);
72            ident
73                .split('_')
74                .map(|s| {
75                    let mut s = s.to_string();
76                    if let Some(c) = s.get_mut(0..1) {
77                        c.make_ascii_uppercase();
78                    }
79                    s
80                })
81                .collect::<String>()
82        })
83        .collect::<String>();
84    syn::Ident::new(&vname, segments.span())
85}
86
87// this seems to be related to the name of the macro "#[monitor(M1, M2, M3)]"
88#[proc_macro_attribute]
89pub fn monitor(
90    attr: proc_macro::TokenStream,
91    input: proc_macro::TokenStream,
92) -> proc_macro::TokenStream {
93    let i = input.clone();
94    let ast = syn::parse_macro_input!(i as DeriveInput);
95
96    let name = format!("{}Msg", ast.ident);
97    let name = syn::Ident::new(&name, ast.ident.span());
98
99    // this seems to be related to taking the inputs in the paranthesis of the macro definition and transforming them to a MsgTypes
100    let types = syn::parse_macro_input!(attr as MsgTypes);
101
102    let menum = types.enum_stream(&name);
103    let intos = intos(&name, &types);
104    let rec = monitor_code(&ast.ident, &ast.generics, &name, &types);
105
106    let input: TokenStream = input.into();
107    let gen = quote! {
108        #input
109
110        #menum
111        #intos
112
113        #rec
114    };
115
116    gen.into()
117}
118
119fn intos(name: &Ident, types: &MsgTypes) -> TokenStream {
120    let intos = types
121        .types
122        .iter()
123        .map(|t| impl_into(name, &t.name, &t.mtype));
124    quote! {
125        #(#intos)*
126    }
127}
128
129fn monitor_code(aname: &Ident, gen: &Generics, name: &Ident, types: &MsgTypes) -> TokenStream {
130    let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
131
132    let vars = types.types.iter().map(|t| {
133        let vname = &t.name;
134        let tname = &t.mtype;
135        quote! {
136            #name::#vname(msg) => (self as &mut dyn ::traceforge::monitor_types::Observer<#tname>).notify(who, whom, msg),
137        }
138    });
139
140    let start_monitor_fnname = format!("start_monitor_{}", aname).to_case(Case::Snake);
141    let start_monitor_fnname = syn::Ident::new(&start_monitor_fnname, aname.span());
142
143    let create_msg_for_monitor_fnname =
144        format!("create_msg_for_monitor_{}", aname).to_case(Case::Snake);
145    let create_msg_for_monitor_fnname =
146        syn::Ident::new(&create_msg_for_monitor_fnname, aname.span());
147
148    let accept_msg_for_monitor_fnname =
149        format!("accept_msg_for_monitor_{}", aname).to_case(Case::Snake);
150    let accept_msg_for_monitor_fnname =
151        syn::Ident::new(&accept_msg_for_monitor_fnname, aname.span());
152
153    let terminate_monitor_fnname = format!("terminate_monitor_{}", aname).to_case(Case::Snake);
154    let terminate_monitor_fnname = syn::Ident::new(&terminate_monitor_fnname, aname.span());
155
156    let wrappings = types.types.iter().map(|t| {
157        let tname = &t.mtype;
158        quote! {
159            let m = msg.clone();
160            if let Ok(msg) = m.as_any().downcast::<(#tname)>() {
161                let msg = *msg;
162                let m: #name = <#tname as Into<#name>>::into(msg);
163                return Some(traceforge::Val::new((Some(who),Some(whom),m)));
164            }
165        }
166    });
167
168    let acceptors = types.types.iter().map(|t| {
169        let tname = &t.mtype;
170        quote! {
171            let m = msg.clone();
172            if let Ok(msg) = m.as_any().downcast::<(#tname)>() {
173                let msg = *msg;
174                let mut mon = #aname::default();
175                return (&mut mon as &mut dyn Acceptor<#tname>).accept(who, whom, &msg);
176            }
177        }
178    });
179
180    quote! {
181        impl #impl_generics ::traceforge::monitor_types::Observer<#name> for #aname #ty_generics #where_clause {
182            fn notify(&mut self,
183                        who: ::traceforge::thread::ThreadId,
184                        whom: ::traceforge::thread::ThreadId,
185                        msg: &#name,
186                        ) -> ::traceforge::monitor_types::MonitorResult {
187                match msg {
188                    #(#vars)*
189                    #name::Terminate => Ok(()),
190                }
191            }
192        }
193
194        pub fn #start_monitor_fnname(m: #aname) -> ::traceforge::thread::JoinHandle<::traceforge::monitor_types::MonitorResult> {
195            let cloned = m.clone();
196            let mon1 = std::sync::Arc::new(std::sync::Mutex::new(cloned));
197            let mon2 = mon1.clone();
198            let jh = ::traceforge::spawn_monitor(move || {
199                loop {
200                    let (who, whom, msg): (Option<::traceforge::thread::ThreadId>, Option<::traceforge::thread::ThreadId>, #name) = traceforge::recv_msg_block();
201                    let unwrapped = &mut (*mon1.lock().expect("Failed to lock mon1"));
202                    if let #name::Terminate = msg {
203                        let res = traceforge::invoke_on_stop(unwrapped);
204                        return res;
205                    }
206                    // here I call the handler of that message
207                    let observer = unwrapped as &mut dyn ::traceforge::monitor_types::Observer<#name>;
208                    let res = observer.notify(who.unwrap(), whom.unwrap(), &msg);
209                    if let Err(e) = res {
210                        println!(" {e:?} is the error returned");
211                        traceforge::assert(false);
212                    }
213                }
214            },
215                #create_msg_for_monitor_fnname as fn(::traceforge::thread::ThreadId,::traceforge::thread::ThreadId,traceforge::Val) -> Option<traceforge::Val>,
216                #accept_msg_for_monitor_fnname as fn(::traceforge::thread::ThreadId,::traceforge::thread::ThreadId,traceforge::Val) -> bool,
217                mon2);
218            return jh;
219        }
220
221        pub fn #create_msg_for_monitor_fnname(who: ::traceforge::thread::ThreadId, whom: ::traceforge::thread::ThreadId, msg: traceforge::Val) -> Option<(traceforge::Val)> { //#name
222            #(#wrappings)*
223
224            return None;
225        }
226
227        pub fn #accept_msg_for_monitor_fnname(who: ::traceforge::thread::ThreadId, whom: ::traceforge::thread::ThreadId, msg: traceforge::Val) -> bool { //#name
228            #(#acceptors)*
229
230            return false;
231        }
232
233        pub fn #terminate_monitor_fnname(t: ::traceforge::thread::ThreadId) {
234            let who: Option<::traceforge::thread::ThreadId> = None;
235            let whom: Option<::traceforge::thread::ThreadId> = None;
236            traceforge::send_msg(t, (who, whom, #name::Terminate));
237        }
238    }
239}
240
241fn impl_into(name: &Ident, vname: &Ident, ty: &TypePath) -> TokenStream {
242    quote! {
243        impl Into<#name> for #ty {
244            fn into(self) -> #name {
245                #name::#vname(self)
246            }
247        }
248    }
249}