1use proc_macro::TokenStream;
6use quote::{format_ident, quote};
7use syn::{DeriveInput, parse_macro_input};
8
9#[proc_macro_derive(SwampExport, attributes(swamp))]
10pub fn derive_swamp_export(input: TokenStream) -> TokenStream {
11 let input = parse_macro_input!(input as DeriveInput);
12 let name = &input.ident;
13
14 let fields = match input.data {
16 syn::Data::Struct(ref data) => &data.fields,
17 _ => panic!("SwampExport can only be derived for structs"),
18 };
19
20 let from_field_extractions = fields.iter().enumerate().map(|(index, f)| {
22 let field_name = &f.ident.as_ref().unwrap();
23 let field_type = &f.ty;
24 quote! {
25 let #field_name = <#field_type>::from_swamp_value(&values[#index])?;
26 }
27 });
28
29 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
31 let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
32
33 let expanded = quote! {
34 impl SwampExport for #name {
35
36 fn get_resolved_type(registry: &TypeRegistry) -> ResolvedType {
37 let fields = vec![
38 #((stringify!(#field_names), <#field_types>::get_resolved_type(registry))),*
39 ];
40 registry.register_derived_struct(stringify!(#name), fields)
41 }
42
43 fn to_swamp_value(&self, registry: &TypeRegistry) -> Value {
44 let mut values = Vec::new();
45 #(values.push(self.#field_names.to_swamp_value(registry));)*
46
47 let resolved_type = Self::get_resolved_type(registry);
48 match &resolved_type {
49 ResolvedType::Struct(struct_type) => {
50 Value::Struct(struct_type.clone(), values, resolved_type)
51 },
52 _ => unreachable!("get_resolved_type returned non-struct type")
53 }
54 }
55
56 fn from_swamp_value(value: &Value) -> Result<Self, String> {
57 match value {
58 Value::Struct(struct_type_ref, values, _) => {
59 if struct_type_ref.borrow().name.text != stringify!(#name) {
60 return Err(format!(
61 "Expected {} struct, got {}",
62 stringify!(#name),
63 struct_type_ref.borrow().name.text
64 ));
65 }
66 #(#from_field_extractions)*
67 Ok(Self {
68 #(#field_names),*
69 })
70 }
71 _ => Err(format!("Expected {} struct", stringify!(#name)))
72 }
73 }
74 }
75 };
76
77 TokenStream::from(expanded)
78}
79
80#[proc_macro_attribute]
81pub fn swamp_fn(_attr: TokenStream, item: TokenStream) -> TokenStream {
82 let input_fn = parse_macro_input!(item as syn::ItemFn);
83 let fn_name = &input_fn.sig.ident;
84 let module_name = format_ident!("swamp_{}", fn_name.to_string().to_lowercase());
85
86 let context_type = match &input_fn.sig.inputs[0] {
88 syn::FnArg::Typed(pat_type) => &*pat_type.ty,
89 _ => panic!("First parameter must be the context type"),
90 };
91
92 let context_inner_type = match context_type {
94 syn::Type::Reference(type_ref) => &*type_ref.elem,
95 _ => panic!("Context parameter must be a mutable reference"),
96 };
97
98 let return_type = match &input_fn.sig.output {
100 syn::ReturnType::Default => quote!(<()>::get_resolved_type(registry)),
101 syn::ReturnType::Type(_, ty) => quote!(<#ty>::get_resolved_type(registry)),
102 };
103
104 let args = input_fn
106 .sig
107 .inputs
108 .iter()
109 .skip(1)
110 .map(|arg| {
111 if let syn::FnArg::Typed(pat_type) = arg {
112 let pat = &pat_type.pat;
113 let ty = &pat_type.ty;
114 (pat, ty)
115 } else {
116 panic!("self parameters not supported yet")
117 }
118 })
119 .collect::<Vec<_>>();
120
121 let arg_count = args.len();
122 let arg_indices = 0..arg_count;
123 let (patterns, types): (Vec<_>, Vec<_>) = args.iter().copied().unzip();
124
125 let expanded = quote! {
126 #input_fn mod #module_name {
129 use super::*;
130 use swamp_script_core_extra::prelude::*;
131
132 pub struct Function {
133 pub name: &'static str,
134 pub function_id: ExternalFunctionId,
135 }
136
137 impl Function {
138 pub fn new(function_id: ExternalFunctionId) -> Self {
139 Self {
140 name: stringify!(#fn_name),
141 function_id,
142 }
143 }
144
145 pub fn handler<'a>(
146 &'a self,
147 registry: &'a TypeRegistry,
148 ) -> Box<dyn FnMut(&[Value], &mut #context_inner_type) -> Result<Value, ValueError> + 'a> {
149 Box::new(move |args: &[Value], ctx: &mut #context_inner_type| {
150 if args.len() != #arg_count {
151 return Err(ValueError::WrongNumberOfArguments {
152 expected: #arg_count,
153 got: args.len(),
154 });
155 }
156
157 #(
159 let #patterns = <#types>::from_swamp_value(&args[#arg_indices])
160 .map_err(|e| ValueError::TypeError(e))?;
161 )*
162
163 let result = super::#fn_name(ctx, #(#patterns),*);
165
166 Ok(result.to_swamp_value(registry))
168 })
169 }
170
171 pub fn get_definition(&self, registry: &TypeRegistry) -> ResolvedExternalFunctionDefinition {
172 ResolvedExternalFunctionDefinition {
173 name: LocalIdentifier::from_str(self.name),
174 signature: ResolvedFunctionSignature {
175 parameters: vec![
176 #(ResolvedParameter {
177 name: stringify!(#patterns).to_string(),
178 resolved_type: <#types>::get_resolved_type(registry),
179 ast_parameter: Parameter::default(),
180 is_mutable: false,
181 },)*
182 ],
183 return_type: #return_type,
184 },
185 id: self.function_id,
186 }
187 }
188 }
189 }
190 };
191
192 TokenStream::from(expanded)
193}
194
195#[proc_macro_derive(SwampExportEnum, attributes(swamp))]
196pub fn derive_swamp_export_enum(input: TokenStream) -> TokenStream {
197 let input = parse_macro_input!(input as DeriveInput);
198 let name = &input.ident;
199
200 let expanded = match input.data {
201 syn::Data::Enum(ref data) => {
202 let variant_matches = data.variants.iter().enumerate().map(|(variant_index, variant)| {
203 let variant_name = &variant.ident;
204
205 match &variant.fields {
206 syn::Fields::Unit => {
207 quote! {
208 #name::#variant_name => {
209 let variant_type = ResolvedEnumVariantType {
210 owner: enum_type.clone(),
211 data: ResolvedEnumVariantContainerType::Nothing,
212 name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
213 number: #variant_index as TypeNumber,
214 };
215 Value::EnumVariantSimple(Rc::new(variant_type))
216 }
217 }
218 }
219 syn::Fields::Named(fields) => {
220 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
221 let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
222
223 let field_type_conversions = field_types.iter().map(|ty| {
224 match quote!(#ty).to_string().as_str() {
225 "f32" => quote! { registry.get_float_type() },
226 "i32" => quote! { registry.get_int_type() },
227 "bool" => quote! { registry.get_bool_type() },
228 "String" => quote! { registry.get_string_type() },
229 ty => quote! { panic!("Unsupported type: {}", #ty) },
230 }
231 });
232
233 let field_value_conversions = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
234 match quote!(#ty).to_string().as_str() {
235 "f32" => quote! { Value::Float(Fp::from(*#name)) },
236 "i32" => quote! { Value::Int(*#name) },
237 "bool" => quote! { Value::Bool(*#name) },
238 "String" => quote! { Value::String(#name.clone()) },
239 ty => quote! { panic!("Unsupported type: {}", #ty) },
240 }
241 });
242
243 quote! {
244 #name::#variant_name { #(ref #field_names),* } => {
245 let mut fields = SeqMap::new();
246 #(
247 fields.insert(
248 IdentifierName(stringify!(#field_names).to_string()),
249 #field_type_conversions
250 );
251 )*
252
253 let common = CommonEnumVariantType {
254 number: #variant_index as TypeNumber,
255 module_path: ModulePath::new(),
256 variant_name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
257 enum_ref: enum_type.clone(),
258 };
259
260 let variant_struct = Rc::new(ResolvedEnumVariantStructType {
261 common,
262 fields,
263 ast_struct: AnonymousStruct::default(),
264 });
265
266 let values = vec![
267 #(#field_value_conversions),*
268 ];
269
270 Value::EnumVariantStruct(variant_struct, values)
271 }
272 }
273 }
274
275
276 syn::Fields::Unnamed(fields) => {
277 let field_types: Vec<_> = fields.unnamed.iter().map(|f| &f.ty).collect();
278 let field_names: Vec<_> = (0..field_types.len())
279 .map(|i| format_ident!("field_{}", i))
280 .collect::<Vec<_>>();
281
282 let field_type_conversions = field_types.iter().map(|ty| {
283 match quote!(#ty).to_string().as_str() {
284 "f32" => quote! { registry.get_float_type() },
285 "i32" => quote! { registry.get_int_type() },
286 "bool" => quote! { registry.get_bool_type() },
287 "String" => quote! { registry.get_string_type() },
288 ty => quote! { panic!("Unsupported type: {}", #ty) },
289 }
290 });
291
292 let field_value_conversions = field_names.iter().zip(field_types.iter()).map(|(name, ty)| {
293 match quote!(#ty).to_string().as_str() {
294 "f32" => quote! { Value::Float(Fp::from(*#name)) },
295 "i32" => quote! { Value::Int(*#name) },
296 "bool" => quote! { Value::Bool(*#name) },
297 "String" => quote! { Value::String(#name.clone()) },
298 ty => quote! { panic!("Unsupported type: {}", #ty) },
299 }
300 });
301
302 quote! {
303 #name::#variant_name(#(ref #field_names),*) => {
304 let fields_in_order = vec![
305 #(#field_type_conversions),*
306 ];
307
308 let common = CommonEnumVariantType {
309 number: #variant_index as TypeNumber,
310 module_path: ModulePath::new(),
311 variant_name: LocalTypeIdentifier::from_str(stringify!(#variant_name)),
312 enum_ref: enum_type.clone(),
313 };
314
315 let variant_tuple = Rc::new(ResolvedEnumVariantTupleType {
316 common,
317 fields_in_order,
318 });
319
320 let values = vec![
321 #(#field_value_conversions),*
322 ];
323
324 Value::EnumVariantTuple(variant_tuple, values)
325 }
326 }
327 }
328 }
329 });
330
331 quote! {
332 impl SwampExport for #name {
333 fn get_resolved_type(registry: &TypeRegistry) -> ResolvedType {
334 let enum_type = Rc::new(ResolvedEnumType {
335 name: LocalTypeIdentifier::from_str(stringify!(#name)),
336 number: registry.allocate_type_number(),
337 module_path: ModulePath(vec![]),
338 });
339 ResolvedType::Enum(enum_type)
340 }
341
342 fn to_swamp_value(&self, registry: &TypeRegistry) -> Value {
343 let enum_type = match Self::get_resolved_type(registry) {
344 ResolvedType::Enum(t) => t,
345 _ => unreachable!(),
346 };
347
348 match self {
349 #(#variant_matches),*
350 }
351 }
352
353 fn from_swamp_value(value: &Value) -> Result<Self, String> {
354 match value {
355 Value::EnumVariantSimple(_) |
356 Value::EnumVariantTuple(_, _) |
357 Value::EnumVariantStruct(_, _) => {
358 todo!("Implement from_swamp_value for enums") }
360 _ => Err(format!("Expected enum variant, got {:?}", value))
361 }
362 }
363 }
364 }
365 }
366 _ => panic!("SwampExportEnum can only be derived for enums"),
367 };
368
369 TokenStream::from(expanded)
370}