Skip to main content

surrealism_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::punctuated::Punctuated;
4use syn::token::Comma;
5use syn::{
6	Expr, ExprLit, FnArg, GenericArgument, ItemFn, Lit, Meta, MetaNameValue, PatType,
7	PathArguments, ReturnType, Type, TypePath, parse_macro_input,
8};
9
10#[proc_macro_attribute]
11pub fn surrealism(attr: TokenStream, item: TokenStream) -> TokenStream {
12	let args = parse_macro_input!(attr with Punctuated::<Meta, Comma>::parse_terminated);
13	let input_fn = parse_macro_input!(item as ItemFn);
14
15	let mut is_default = false;
16	let mut export_name_override: Option<String> = None;
17	let mut is_init = false;
18
19	for meta in args.iter() {
20		match meta {
21			Meta::NameValue(MetaNameValue {
22				path,
23				value,
24				..
25			}) if path.is_ident("name") => {
26				if let Expr::Lit(ExprLit {
27					lit: Lit::Str(s),
28					..
29				}) = value
30				{
31					let val = s.value();
32					if !val.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
33						panic!(
34							"#[surrealism(name = \"...\")] must use only ASCII letters, digits, and underscores"
35						);
36					}
37					export_name_override = Some(val);
38				}
39			}
40			Meta::Path(path) if path.is_ident("default") => {
41				is_default = true;
42			}
43			Meta::Path(path) if path.is_ident("init") => {
44				is_init = true;
45			}
46			_ => panic!(
47				"Unsupported attribute: expected #[surrealism], #[surrealism(default)], #[surrealism(init)], or #[surrealism(name = \"...\")]"
48			),
49		}
50	}
51
52	let fn_name = &input_fn.sig.ident;
53	let fn_vis = &input_fn.vis;
54	let fn_sig = &input_fn.sig;
55	let fn_block = &input_fn.block;
56
57	// Collect argument patterns and types
58	let mut arg_patterns = Vec::new();
59	let mut arg_types = Vec::new();
60
61	for arg in &fn_sig.inputs {
62		match arg {
63			FnArg::Typed(PatType {
64				pat,
65				ty,
66				..
67			}) => {
68				arg_patterns.push(pat.clone());
69				arg_types.push(ty);
70			}
71			FnArg::Receiver(_) => panic!("`self` is not supported in #[surrealism] functions"),
72		}
73	}
74
75	// Compose tuple type and pattern (single args are passed directly)
76	let (tuple_type, tuple_pattern) = if arg_types.is_empty() {
77		(quote! { () }, quote! { () })
78	} else {
79		(quote! { ( #(#arg_types),*, ) }, quote! { ( #(#arg_patterns),*, ) })
80	};
81
82	// Return type analysis
83	let (result_type, is_result) = match &fn_sig.output {
84		ReturnType::Default => (quote! { () }, false),
85		ReturnType::Type(_, ty) => {
86			// Check if the return type is Result<T, E>
87			if let Type::Path(TypePath {
88				path,
89				..
90			}) = &**ty
91			{
92				if let Some(last_segment) = path.segments.last() {
93					if last_segment.ident == "Result" {
94						if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
95							if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
96								(quote! { #inner_type }, true)
97							} else {
98								(quote! { #ty }, false)
99							}
100						} else {
101							(quote! { #ty }, false)
102						}
103					} else {
104						(quote! { #ty }, false)
105					}
106				} else {
107					(quote! { #ty }, false)
108				}
109			} else {
110				(quote! { #ty }, false)
111			}
112		}
113	};
114
115	// Export function names
116	let export_suffix = if is_default {
117		String::new()
118	} else {
119		export_name_override.unwrap_or_else(|| fn_name.to_string())
120	};
121
122	let export_ident = format_ident!("__sr_fnc__{}", export_suffix);
123	let args_ident = format_ident!("__sr_args__{}", export_suffix);
124	let returns_ident = format_ident!("__sr_returns__{}", export_suffix);
125
126	// DRY error handling pattern
127	let try_or_fail = |expr: proc_macro2::TokenStream, context: &str| {
128		let context = syn::LitStr::new(context, proc_macro2::Span::call_site());
129		quote! {
130			match #expr {
131				Ok(val) => val,
132				Err(e) => {
133					eprintln!(concat!(#context, " error: {}"), e);
134					return -1;
135				}
136			}
137		}
138	};
139
140	let expanded = if is_init {
141		let init_call = if is_result {
142			let expr = quote! { #fn_name() };
143			quote! {
144				match #expr {
145					Ok(()) => 0,
146					Err(e) => {
147						eprintln!("Init error: {}", e);
148						-1
149					}
150				}
151			}
152		} else {
153			quote! {
154				#fn_name();
155				0
156			}
157		};
158
159		quote! {
160			#fn_vis #fn_sig #fn_block
161
162			#[unsafe(no_mangle)]
163			pub extern "C" fn __sr_init() -> i32 {
164				#init_call
165			}
166		}
167	} else {
168		let function_call = if is_result {
169			quote! {
170				#fn_name(#(#arg_patterns),*).map_err(|e| e.to_string())
171			}
172		} else {
173			quote! {
174				Ok(#fn_name(#(#arg_patterns),*))
175			}
176		};
177
178		let transfer_call = if is_result {
179			let expr = quote! { f.invoke_raw(&mut controller, ptr.into()) };
180			let try_or_fail_result = try_or_fail(expr, "Function invocation");
181			quote! {
182				(*#try_or_fail_result)
183				.try_into()
184				.unwrap_or_else(|_| {
185					eprintln!("Transfer error: pointer overflow");
186					-1
187				})
188			}
189		} else {
190			quote! {
191				match f.invoke_raw(&mut controller, ptr.into()) {
192					Ok(result) => match (*result).try_into() {
193						Ok(ptr) => ptr,
194						Err(_) => {
195							eprintln!("Transfer error: pointer overflow");
196							-1
197						}
198					},
199					Err(e) => {
200						eprintln!("Function invocation error: {}", e);
201						-1
202					}
203				}
204			}
205		};
206
207		let args_call = if is_result {
208			let expr = quote! { f.args_raw(&mut controller) };
209			let try_or_fail_result = try_or_fail(expr, "Args");
210			quote! {
211				(*#try_or_fail_result)
212				.try_into()
213				.unwrap_or_else(|_| {
214					eprintln!("Transfer error: pointer overflow");
215					-1
216				})
217			}
218		} else {
219			quote! {
220				match f.args_raw(&mut controller) {
221					Ok(result) => match (*result).try_into() {
222						Ok(ptr) => ptr,
223						Err(_) => {
224							eprintln!("Transfer error: pointer overflow");
225							-1
226						}
227					},
228					Err(e) => {
229						eprintln!("Args error: {}", e);
230						-1
231					}
232				}
233			}
234		};
235
236		let returns_call = if is_result {
237			let expr = quote! { f.returns_raw(&mut controller) };
238			let try_or_fail_result = try_or_fail(expr, "Returns");
239			quote! {
240				(*#try_or_fail_result)
241				.try_into()
242				.unwrap_or_else(|_| {
243					eprintln!("Transfer error: pointer overflow");
244					-1
245				})
246			}
247		} else {
248			quote! {
249				match f.returns_raw(&mut controller) {
250					Ok(result) => match (*result).try_into() {
251						Ok(ptr) => ptr,
252						Err(_) => {
253							eprintln!("Transfer error: pointer overflow");
254							-1
255						}
256					},
257					Err(e) => {
258						eprintln!("Returns error: {}", e);
259						-1
260					}
261				}
262			}
263		};
264
265		quote! {
266			#fn_vis #fn_sig #fn_block
267
268			#[unsafe(no_mangle)]
269			pub extern "C" fn #export_ident(ptr: u32) -> i32 {
270				use surrealism::types::transfer::Transfer;
271				let mut controller = surrealism::Controller {};
272				let f = surrealism::SurrealismFunction::<#tuple_type, #result_type, _>::from(
273					|#tuple_pattern: #tuple_type| #function_call
274				);
275				#transfer_call
276			}
277
278			#[unsafe(no_mangle)]
279			pub extern "C" fn #args_ident() -> i32 {
280				use surrealism::types::transfer::Transfer;
281				let mut controller = surrealism::Controller {};
282				let f = surrealism::SurrealismFunction::<#tuple_type, #result_type, _>::from(
283					|#tuple_pattern: #tuple_type| #function_call
284				);
285				#args_call
286			}
287
288			#[unsafe(no_mangle)]
289			pub extern "C" fn #returns_ident() -> i32 {
290				use surrealism::types::transfer::Transfer;
291				let mut controller = surrealism::Controller {};
292				let f = surrealism::SurrealismFunction::<#tuple_type, #result_type, _>::from(
293					|#tuple_pattern: #tuple_type| #function_call
294				);
295				#returns_call
296			}
297		}
298	};
299
300	TokenStream::from(expanded)
301}