1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11use syn::{parse_macro_input, FnArg, ItemFn, LitInt, Pat, ReturnType, Token, Type, TypeReference};
12use telepath_wire::cmd_id::derive_cmd_id as compute_cmd_id;
13
/// Arguments accepted inside `#[command(...)]`.
///
/// Either empty (`#[command]`) or a single `cmd_id = <integer>` pair.
struct CommandArgs {
    /// The pinned wire id from `cmd_id = N`, or `None` to let the macro derive
    /// the id from the function name, argument types and return type.
    explicit_cmd_id: Option<u16>,
}
22
23impl syn::parse::Parse for CommandArgs {
24 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
25 if input.is_empty() {
26 return Ok(CommandArgs {
27 explicit_cmd_id: None,
28 });
29 }
30 let key: syn::Ident = input.parse()?;
31 if key != "cmd_id" {
32 return Err(syn::Error::new_spanned(
33 key,
34 "#[command]: unknown attribute key (expected `cmd_id`)",
35 ));
36 }
37 let _eq: Token![=] = input.parse()?;
38 let lit: LitInt = input.parse()?;
39 let value: u16 = lit.base10_parse().map_err(|_| {
40 syn::Error::new_spanned(&lit, "#[command(cmd_id = ...)]: value must fit in u16")
41 })?;
42 Ok(CommandArgs {
43 explicit_cmd_id: Some(value),
44 })
45 }
46}
47
/// Returns the lazily-created map of cmd_ids claimed so far by `#[command]`
/// expansions in this process, keyed by id with the claiming function's name
/// as the value.
fn seen_cmd_ids() -> &'static Mutex<HashMap<u16, String>> {
    static REGISTRY: OnceLock<Mutex<HashMap<u16, String>>> = OnceLock::new();
    REGISTRY.get_or_init(Default::default)
}
52
53#[proc_macro_attribute]
110pub fn command(attr: TokenStream, item: TokenStream) -> TokenStream {
111 let args = match syn::parse2::<CommandArgs>(TokenStream2::from(attr)) {
112 Ok(a) => a,
113 Err(e) => return e.to_compile_error().into(),
114 };
115 let input = parse_macro_input!(item as ItemFn);
116 match expand_command(input, args.explicit_cmd_id) {
117 Ok(ts) => ts.into(),
118 Err(e) => e.to_compile_error().into(),
119 }
120}
121
122fn expand_command(
123 func: ItemFn,
124 explicit_cmd_id: Option<u16>,
125) -> syn::Result<proc_macro2::TokenStream> {
126 let fn_ident = &func.sig.ident;
127 let fn_name_str = fn_ident.to_string();
128
129 if let Some(tok) = &func.sig.asyncness {
132 return Err(syn::Error::new_spanned(
133 tok,
134 "#[command] does not support async fn",
135 ));
136 }
137 if let Some(tok) = &func.sig.unsafety {
138 return Err(syn::Error::new_spanned(
139 tok,
140 "#[command] does not support unsafe fn",
141 ));
142 }
143 if !func.sig.generics.params.is_empty() {
144 return Err(syn::Error::new_spanned(
145 &func.sig.generics,
146 "#[command] does not support generic functions",
147 ));
148 }
149 if let Some(wc) = &func.sig.generics.where_clause {
150 return Err(syn::Error::new_spanned(
151 wc,
152 "#[command] does not support where clauses",
153 ));
154 }
155
156 let mut wire_idents = Vec::new();
160 let mut wire_types: Vec<Box<Type>> = Vec::new();
161 let mut wire_type_strs = Vec::new();
162
163 struct ResourceArg {
165 ident: syn::Ident,
166 inner_ty: Box<Type>,
167 is_mut: bool,
168 }
169 let mut resource_args: Vec<ResourceArg> = Vec::new();
170
171 let mut all_arg_idents: Vec<syn::Ident> = Vec::new();
173
174 for fn_arg in &func.sig.inputs {
175 match fn_arg {
176 FnArg::Receiver(recv) => {
177 return Err(syn::Error::new_spanned(
178 recv,
179 "#[command] cannot be applied to methods",
180 ));
181 }
182 FnArg::Typed(pat_type) => {
183 let ident = match pat_type.pat.as_ref() {
184 Pat::Ident(pi) => pi.ident.clone(),
185 other => {
186 return Err(syn::Error::new_spanned(
187 other,
188 "#[command] requires simple named arguments (patterns not supported)",
189 ));
190 }
191 };
192
193 let is_resource = pat_type.attrs.iter().any(|a| a.path().is_ident("resource"));
194
195 if is_resource {
196 let Type::Reference(TypeReference {
197 elem, mutability, ..
198 }) = pat_type.ty.as_ref()
199 else {
200 return Err(syn::Error::new_spanned(
201 &pat_type.ty,
202 "#[resource] arguments must be &T or &mut T",
203 ));
204 };
205
206 let inner_str = quote! { #elem }.to_string();
211 for existing in &resource_args {
212 let existing_ty = &existing.inner_ty;
213 let existing_str = quote! { #existing_ty }.to_string();
214 if existing_str == inner_str {
215 return Err(syn::Error::new_spanned(
216 &pat_type.ty,
217 "duplicate #[resource] type; each resource type may appear at most once",
218 ));
219 }
220 }
221
222 resource_args.push(ResourceArg {
223 ident: ident.clone(),
224 inner_ty: elem.clone(),
225 is_mut: mutability.is_some(),
226 });
227 all_arg_idents.push(ident);
228 } else {
229 if let Type::Reference(r) = pat_type.ty.as_ref() {
230 return Err(syn::Error::new_spanned(
231 r,
232 "#[command] does not support reference arguments \
233 (use #[resource] for injected references)",
234 ));
235 }
236 let ty = &*pat_type.ty;
237 wire_type_strs.push(quote! { #ty }.to_string());
238 wire_idents.push(ident.clone());
239 wire_types.push(pat_type.ty.clone());
240 all_arg_idents.push(ident);
241 }
242 }
243 }
244 }
245
246 let ret_type_str = match &func.sig.output {
249 ReturnType::Default => "()".to_string(),
250 ReturnType::Type(_, ty) => {
251 if let Type::Reference(r) = ty.as_ref() {
252 return Err(syn::Error::new_spanned(
253 r,
254 "#[command] does not support reference return types",
255 ));
256 }
257 quote! { #ty }.to_string()
258 }
259 };
260
261 let arg_names_str: String = wire_idents
265 .iter()
266 .map(|id| id.to_string())
267 .collect::<Vec<_>>()
268 .join(",");
269
270 let args_type_str = if wire_type_strs.is_empty() {
275 "()".to_string()
276 } else if wire_type_strs.len() == 1 {
277 format!("({},)", wire_type_strs[0])
278 } else {
279 format!("({})", wire_type_strs.join(", "))
280 };
281
282 let cmd_id_value = explicit_cmd_id
291 .unwrap_or_else(|| compute_cmd_id(&fn_name_str, &args_type_str, &ret_type_str));
292
293 {
294 let mut seen = seen_cmd_ids().lock().unwrap();
295 if let Some(existing) = seen.get(&cmd_id_value) {
296 return Err(syn::Error::new_spanned(
297 fn_ident,
298 format!(
299 "#[command] cmd_id collision: `{}` and `{}` both map to 0x{:04X}. \
300 Rename one of the commands to avoid the collision.",
301 fn_name_str, existing, cmd_id_value
302 ),
303 ));
304 }
305 seen.insert(cmd_id_value, fn_name_str.clone());
306 }
307
308 let cmd_id_expr: proc_macro2::TokenStream = if explicit_cmd_id.is_some() {
312 let v = cmd_id_value;
313 quote! { #v }
314 } else {
315 quote! {
316 ::telepath_server::__derive_cmd_id(
317 #fn_name_str,
318 #args_type_str,
319 #ret_type_str,
320 )
321 }
322 };
323
324 let collision_export = format!("__telepath_cmd_id_{:04X}", cmd_id_value);
325 let guard_ident = format_ident!("__TELEPATH_CMDID_GUARD_{}", fn_name_str.to_uppercase());
326
327 let shim_ident = format_ident!("__telepath_shim_{}", fn_name_str);
330 let args_schema_ident = format_ident!("__telepath_args_schema_{}", fn_name_str);
331 let ret_schema_ident = format_ident!("__telepath_ret_schema_{}", fn_name_str);
332 let static_ident = format_ident!("__TELEPATH_CMD_{}", fn_name_str.to_uppercase());
333 let reg_ident = format_ident!("__TELEPATH_REG_{}", fn_name_str.to_uppercase());
334
335 let args_schema_type = if wire_types.is_empty() {
339 quote! { () }
340 } else if wire_types.len() == 1 {
341 let t = &*wire_types[0];
342 quote! { (#t,) }
343 } else {
344 quote! { (#(#wire_types),*) }
345 };
346
347 let ret_schema_type = match &func.sig.output {
348 ReturnType::Default => quote! { () },
349 ReturnType::Type(_, ty) => quote! { #ty },
350 };
351
352 let wire_deser = if wire_idents.is_empty() {
356 quote! {
357 if !input.is_empty() {
358 return ::core::result::Result::Err(
359 ::telepath_server::DispatchError::DeserializeError
360 );
361 }
362 }
363 } else {
364 let wire_tuple_type = if wire_types.len() == 1 {
365 let t = &*wire_types[0];
366 quote! { (#t,) }
367 } else {
368 quote! { (#(#wire_types),*) }
369 };
370 let wire_pat = if wire_idents.len() == 1 {
371 let id = &wire_idents[0];
372 quote! { (#id,) }
373 } else {
374 quote! { (#(#wire_idents),*) }
375 };
376 quote! {
377 let #wire_pat: #wire_tuple_type = match ::postcard::from_bytes(input) {
378 Ok(v) => v,
379 Err(_) => return ::core::result::Result::Err(
380 ::telepath_server::DispatchError::DeserializeError
381 ),
382 };
383 }
384 };
385
386 let resource_lookups: Vec<_> = resource_args
388 .iter()
389 .map(|ra| {
390 let ident = &ra.ident;
391 let inner_ty = &ra.inner_ty;
392 if ra.is_mut {
393 quote! {
394 let #ident: &mut #inner_ty = unsafe {
395 &mut *__resources.get_ptr::<#inner_ty>()
396 .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
397 };
398 }
399 } else {
400 quote! {
401 let #ident: &#inner_ty = unsafe {
402 &*__resources.get_ptr::<#inner_ty>()
403 .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
404 };
405 }
406 }
407 })
408 .collect();
409
410 let call_args: Vec<_> = all_arg_idents
412 .iter()
413 .map(|ident| quote! { #ident })
414 .collect();
415
416 let shim_body = quote! {
417 #wire_deser
418 #(#resource_lookups)*
419 let __ret = #fn_ident(#(#call_args),*);
420 match ::postcard::to_slice(&__ret, output) {
421 Ok(s) => ::core::result::Result::Ok(s.len()),
422 Err(_) => ::core::result::Result::Err(
423 ::telepath_server::DispatchError::SerializeError
424 ),
425 }
426 };
427
428 let mut clean_func = func.clone();
431 for fn_arg in &mut clean_func.sig.inputs {
432 if let FnArg::Typed(pat_type) = fn_arg {
433 pat_type.attrs.retain(|a| !a.path().is_ident("resource"));
434 }
435 }
436
437 Ok(quote! {
438 #clean_func
439
440 #[allow(non_snake_case)]
441 fn #shim_ident(
442 input: &[u8],
443 output: &mut [u8],
444 __resources: &::telepath_server::ResourceRegistry,
445 ) -> ::core::result::Result<usize, ::telepath_server::DispatchError> {
446 #shim_body
447 }
448
449 #[allow(non_snake_case)]
450 fn #args_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
451 ::postcard::to_slice(
452 <#args_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
453 out,
454 )
455 .map(|s| s.len())
456 .map_err(|_| ())
457 }
458
459 #[allow(non_snake_case)]
460 fn #ret_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
461 ::postcard::to_slice(
462 <#ret_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
463 out,
464 )
465 .map(|s| s.len())
466 .map_err(|_| ())
467 }
468
469 pub const #static_ident: ::telepath_server::CommandMetadata =
470 ::telepath_server::CommandMetadata {
471 name: #fn_name_str,
472 id: #cmd_id_expr,
473 invoke: #shim_ident,
474 args_schema: #args_schema_ident,
475 ret_schema: #ret_schema_ident,
476 arg_names: #arg_names_str,
477 };
478
479 #[allow(non_upper_case_globals, non_snake_case)]
480 #[::telepath_server::__linkme::distributed_slice(::telepath_server::TELEPATH_COMMANDS)]
481 #[linkme(crate = ::telepath_server::__linkme)]
482 static #reg_ident: ::telepath_server::CommandMetadata = #static_ident;
483
484 #[doc(hidden)]
495 #[allow(non_upper_case_globals, dead_code)]
496 #[used]
497 #[export_name = #collision_export]
498 pub static #guard_ident: u8 = 0;
499
500 })
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that read or write the process-wide `seen_cmd_ids()` map.
    static TEST_GUARD: Mutex<()> = Mutex::new(());

    /// RAII scope for one test. It holds `TEST_GUARD` and empties the cmd_id
    /// registry when dropped. `Drop` also runs while a failing test unwinds,
    /// so a failure cannot leave stale ids behind for the next test.
    struct Isolated {
        _lock: MutexGuard<'static, ()>,
    }

    impl Drop for Isolated {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so the clear is still serialized.
            clear_seen();
        }
    }

    /// Empties the registry, recovering the map if the mutex was poisoned.
    fn clear_seen() {
        seen_cmd_ids()
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .clear();
    }

    /// Starts an isolated test with an empty registry. If an earlier test
    /// panicked while holding `TEST_GUARD`, the poisoned lock is recovered
    /// instead of failing every later test.
    fn isolated() -> Isolated {
        let lock = TEST_GUARD.lock().unwrap_or_else(PoisonError::into_inner);
        clear_seen();
        Isolated { _lock: lock }
    }

    fn parse_fn(src: &str) -> ItemFn {
        syn::parse_str(src).unwrap()
    }

    #[test]
    fn same_crate_collision_is_rejected() {
        let _iso = isolated();
        assert!(expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None).is_ok());
        let err = expand_command(parse_fn("fn cmd_470() -> u32 { 0 }"), None)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error, got: {err}"
        );
        assert!(
            err.contains("0x43AE"),
            "expected hex id 0x43AE in error, got: {err}"
        );
        assert!(
            err.contains("cmd_446") && err.contains("cmd_470"),
            "expected both command names in error, got: {err}"
        );
    }

    #[test]
    fn guard_symbol_has_correct_export_name() {
        let _iso = isolated();
        let ts = expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None)
            .unwrap()
            .to_string();
        assert!(
            ts.contains("__telepath_cmd_id_43AE"),
            "guard symbol export_name not found in generated code: {ts}"
        );
    }

    #[test]
    fn distinct_commands_do_not_collide() {
        let _iso = isolated();
        assert!(expand_command(parse_fn("fn ping() -> u32 { 0 }"), None).is_ok());
        assert!(expand_command(parse_fn("fn echo(x: u32) -> u32 { x }"), None).is_ok());
    }

    #[test]
    fn explicit_cmd_id_overrides_derive() {
        let _iso = isolated();
        let ts = expand_command(parse_fn("fn get_metrics() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap()
            .to_string();
        assert!(
            ts.contains("65534"),
            "explicit cmd_id 0xFFFE not found as literal in generated code: {ts}"
        );
        assert!(
            ts.contains("__telepath_cmd_id_FFFE"),
            "guard symbol for explicit cmd_id not found in generated code: {ts}"
        );
    }

    #[test]
    fn explicit_cmd_id_collision_rejected() {
        let _iso = isolated();
        assert!(expand_command(parse_fn("fn foo() -> u32 { 0 }"), Some(0xFFFE)).is_ok());
        let err = expand_command(parse_fn("fn bar() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error for duplicate explicit cmd_id, got: {err}"
        );
    }
}