// swamp_script_derive/lib.rs

/*
 * Copyright (c) Peter Bjorklund. All rights reserved. https://github.com/swamp/script
 * Licensed under the MIT License. See LICENSE in the project root for license information.
 */
use proc_macro::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::{DeriveInput, parse_macro_input};

9#[proc_macro_derive(SwampExport, attributes(swamp))]
10pub fn derive_swamp_export(input: TokenStream) -> TokenStream {
11    let input = parse_macro_input!(input as DeriveInput);
12    let name = &input.ident;
13
14    // Extract fields from struct
15    let fields = match input.data {
16        syn::Data::Struct(ref data) => &data.fields,
17        _ => panic!("SwampExport can only be derived for structs"),
18    };
19
20    // Generate field extractions for from_swamp_value
21    let from_field_extractions = fields.iter().enumerate().map(|(index, f)| {
22        let field_name = &f.ident.as_ref().unwrap();
23        let field_type = &f.ty;
24        quote! {
25            let #field_name = <#field_type>::from_swamp_value(&values[#index])?;
26        }
27    });
28
29    // Collect field names and types for struct construction
30    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
31    let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
32
33    let expanded = quote! {
34        impl SwampExport for #name {
35
36            fn get_resolved_type(registry: &TypeRegistry) -> ResolvedType {
37                let fields = vec![
38                    #((stringify!(#field_names), <#field_types>::get_resolved_type(registry))),*
39                ];
40                registry.register_derived_struct(stringify!(#name), fields)
41            }
42
43            fn to_swamp_value(&self, registry: &TypeRegistry) -> Value {
44                let mut values = Vec::new();
45                #(values.push(self.#field_names.to_swamp_value(registry));)*
46
47                let resolved_type = Self::get_resolved_type(registry);
48                match &resolved_type {
49                    ResolvedType::Struct(struct_type) => {
50                        Value::Struct(struct_type.clone(), values, resolved_type)
51                    },
52                    _ => unreachable!("get_resolved_type returned non-struct type")
53                }
54            }
55
56            fn from_swamp_value(value: &Value) -> Result<Self, String> {
57                match value {
58                    Value::Struct(struct_type_ref, values, _) => {
59                        if struct_type_ref.borrow().name.text != stringify!(#name) {
60                            return Err(format!(
61                                "Expected {} struct, got {}",
62                                stringify!(#name),
63                                struct_type_ref.borrow().name.text
64                            ));
65                        }
66                        #(#from_field_extractions)*
67                        Ok(Self {
68                            #(#field_names),*
69                        })
70                    }
71                    _ => Err(format!("Expected {} struct", stringify!(#name)))
72                }
73            }
74        }
75    };
76
77    TokenStream::from(expanded)
78}
79
80#[proc_macro_attribute]
81pub fn swamp_fn(_attr: TokenStream, item: TokenStream) -> TokenStream {
82    let input_fn = parse_macro_input!(item as syn::ItemFn);
83    let fn_name = &input_fn.sig.ident;
84    let module_name = format_ident!("swamp_{}", fn_name.to_string().to_lowercase());
85
86    // Get the context type from the first parameter
87    let context_type = match &input_fn.sig.inputs[0] {
88        syn::FnArg::Typed(pat_type) => &*pat_type.ty,
89        _ => panic!("First parameter must be the context type"),
90    };
91
92    // Extract the inner type from &mut MyContext
93    let context_inner_type = match context_type {
94        syn::Type::Reference(type_ref) => &*type_ref.elem,
95        _ => panic!("Context parameter must be a mutable reference"),
96    };
97
98    // Extract return type
99    let return_type = match &input_fn.sig.output {
100        syn::ReturnType::Default => quote!(<()>::get_resolved_type(registry)),
101        syn::ReturnType::Type(_, ty) => quote!(<#ty>::get_resolved_type(registry)),
102    };
103
104    // Skip the context parameter
105    let args = input_fn
106        .sig
107        .inputs
108        .iter()
109        .skip(1)
110        .map(|arg| {
111            if let syn::FnArg::Typed(pat_type) = arg {
112                let pat = &pat_type.pat;
113                let ty = &pat_type.ty;
114                (pat, ty)
115            } else {
116                panic!("self parameters not supported yet")
117            }
118        })
119        .collect::<Vec<_>>();
120
121    let arg_count = args.len();
122    let arg_indices = 0..arg_count;
123    let (patterns, types): (Vec<_>, Vec<_>) = args.iter().copied().unzip();
124
125    let expanded = quote! {
126        #input_fn  // Keep the original function
127
128        mod #module_name {
129            use super::*;
130            use swamp_script_core_extra::prelude::*;
131
132            pub struct Function {
133                pub name: &'static str,
134                pub function_id: ExternalFunctionId,
135            }
136
137            impl Function {
138                pub fn new(function_id: ExternalFunctionId) -> Self {
139                    Self {
140                        name: stringify!(#fn_name),
141                        function_id,
142                    }
143                }
144
145                pub fn handler<'a>(
146                    &'a self,
147                    registry: &'a TypeRegistry,
148                ) -> Box<dyn FnMut(&[Value], &mut #context_inner_type) -> Result<Value, ValueError> + 'a> {
149                    Box::new(move |args: &[Value], ctx: &mut #context_inner_type| {
150                        if args.len() != #arg_count {
151                            return Err(ValueError::WrongNumberOfArguments {
152                                expected: #arg_count,
153                                got: args.len(),
154                            });
155                        }
156
157                        // Convert arguments
158                        #(
159                            let #patterns = <#types>::from_swamp_value(&args[#arg_indices])
160                                .map_err(|e| ValueError::TypeError(e))?;
161                        )*
162
163                        // Call the function with context
164                        let result = super::#fn_name(ctx, #(#patterns),*);
165
166                        // Convert result back to Value
167                        Ok(result.to_swamp_value(registry))
168                    })
169                }
170
171                pub fn get_definition(&self, registry: &TypeRegistry) -> ResolvedExternalFunctionDefinition {
172                    ResolvedExternalFunctionDefinition {
173                        name: LocalIdentifier::from_str(self.name),
174                        signature: ResolvedFunctionSignature {
175                            parameters: vec![
176                                #(ResolvedParameter {
177                                    name: stringify!(#patterns).to_string(),
178                                    resolved_type: <#types>::get_resolved_type(registry),
179                                    ast_parameter: Parameter::default(),
180                                    is_mutable: false,
181                                },)*
182                            ],
183                            return_type: #return_type,
184                        },
185                        id: self.function_id,
186                    }
187                }
188            }
189        }
190    };
191
192    TokenStream::from(expanded)
193}
194
195#[proc_macro_derive(SwampExportEnum, attributes(swamp))]
196pub fn derive_swamp_export_enum(input: TokenStream) -> TokenStream {
197    let input = parse_macro_input!(input as DeriveInput);
198    let name = &input.ident;
199
200    let expanded = match input.data {
201        syn::Data::Enum(ref data) => {
202            let variant_matches = data.variants.iter().enumerate().map(|(variant_index, variant)| {
203                let variant_name = &variant.ident;
204
205                match &variant.fields {
206                    syn::Fields::Unit => {
207                        quote! {
208                            #name::#variant_name => {
209                                let variant_type = ResolvedEnumVariantType {
210                                    owner: enum_type.clone(),
211                                    data: ResolvedEnumVariantContainerType::Nothing,
212                                    name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
213                                    number: #variant_index as TypeNumber,
214                                };
215                                Value::EnumVariantSimple(Rc::new(variant_type))
216                            }
217                        }
218                    }
219                    syn::Fields::Named(fields) => {
220                        let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
221                        let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
222
223                        let field_type_conversions = field_types.iter().map(|ty| {
224                            match quote!(#ty).to_string().as_str() {
225                                "f32" => quote! { registry.get_float_type() },
226                                "i32" => quote! { registry.get_int_type() },
227                                "bool" => quote! { registry.get_bool_type() },
228                                "String" => quote! { registry.get_string_type() },
229                                ty => quote! { panic!("Unsupported type: {}", #ty) },
230                            }
231                        });
232
233                        let field_value_conversions = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
234                            match quote!(#ty).to_string().as_str() {
235                                "f32" => quote! { Value::Float(Fp::from(*#name)) },
236                                "i32" => quote! { Value::Int(*#name) },
237                                "bool" => quote! { Value::Bool(*#name) },
238                                "String" => quote! { Value::String(#name.clone()) },
239                                ty => quote! { panic!("Unsupported type: {}", #ty) },
240                            }
241                        });
242
243                        quote! {
244                            #name::#variant_name { #(ref #field_names),* } => {
245                                let mut fields = SeqMap::new();
246                                #(
247                                    fields.insert(
248                                        IdentifierName(stringify!(#field_names).to_string()),
249                                        #field_type_conversions
250                                    );
251                                )*
252
253                                let common = CommonEnumVariantType {
254                                    number: #variant_index as TypeNumber,
255                                    module_path: ModulePath::new(),
256                                    variant_name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
257                                    enum_ref: enum_type.clone(),
258                                };
259
260                                let variant_struct = Rc::new(ResolvedEnumVariantStructType {
261                                    common,
262                                    fields,
263                                    ast_struct: AnonymousStruct::default(),
264                                });
265
266                                let values = vec![
267                                    #(#field_value_conversions),*
268                                ];
269
270                                Value::EnumVariantStruct(variant_struct, values)
271                            }
272                        }
273                    }
274
275
276                    syn::Fields::Unnamed(fields) => {
277                        let field_types: Vec<_> = fields.unnamed.iter().map(|f| &f.ty).collect();
278                        let field_names: Vec<_> = (0..field_types.len())
279                            .map(|i| format_ident!("field_{}", i))
280                            .collect::<Vec<_>>();
281
282                        let field_type_conversions = field_types.iter().map(|ty| {
283                            match quote!(#ty).to_string().as_str() {
284                                "f32" => quote! { registry.get_float_type() },
285                                "i32" => quote! { registry.get_int_type() },
286                                "bool" => quote! { registry.get_bool_type() },
287                                "String" => quote! { registry.get_string_type() },
288                                ty => quote! { panic!("Unsupported type: {}", #ty) },
289                            }
290                        });
291
292                        let field_value_conversions = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
293                            match quote!(#ty).to_string().as_str() {
294                                "f32" => quote! { Value::Float(Fp::from(*#name)) },
295                                "i32" => quote! { Value::Int(*#name) },
296                                "bool" => quote! { Value::Bool(*#name) },
297                                "String" => quote! { Value::String(#name.clone()) },
298                                ty => quote! { panic!("Unsupported type: {}", #ty) },
299                            }
300                        });
301
302                        quote! {
303                            #name::#variant_name(#(ref #field_names),*) => {
304                                let fields_in_order = vec![
305                                    #(#field_type_conversions),*
306                                ];
307
308                                let common = CommonEnumVariantType {
309                                    number: #variant_index as TypeNumber,
310                                    module_path: ModulePath::new(),
311                                    variant_name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
312                                    enum_ref: enum_type.clone(),
313                                };
314
315                                let variant_tuple = Rc::new(ResolvedEnumVariantTupleType {
316                                    common,
317                                    fields_in_order,
318                                });
319
320                                let values = vec![
321                                    #(#field_value_conversions),*
322                                ];
323
324                                Value::EnumVariantTuple(variant_tuple, values)
325                            }
326                        }
327                    }
328                }
329            });
330
331            quote! {
332                impl SwampExport for #name {
333                    fn get_resolved_type(registry: &TypeRegistry) -> ResolvedType {
334                        let enum_type = Rc::new(ResolvedEnumType {
335                            name: LocalTypeIdentifier::from_str(stringify!(#name)),
336                            number: registry.allocate_type_number(),
337                            module_path: ModulePath(vec![]),
338                        });
339                        ResolvedType::Enum(enum_type)
340                    }
341
342                    fn to_swamp_value(&self, registry: &TypeRegistry) -> Value {
343                        let enum_type = match Self::get_resolved_type(registry) {
344                            ResolvedType::Enum(t) => t,
345                            _ => unreachable!(),
346                        };
347
348                        match self {
349                            #(#variant_matches),*
350                        }
351                    }
352
353                    fn from_swamp_value(value: &Value) -> Result<Self, String> {
354                        match value {
355                            Value::EnumVariantSimple(_) |
356                            Value::EnumVariantTuple(_, _) |
357                            Value::EnumVariantStruct(_, _) => {
358                                todo!("Implement from_swamp_value for enums") // TODO: PBJ: Fix this when needed
359                            }
360                            _ => Err(format!("Expected enum variant, got {:?}", value))
361                        }
362                    }
363                }
364            }
365        }
366        _ => panic!("SwampExportEnum can only be derived for enums"),
367    };
368
369    TokenStream::from(expanded)
370}