1#![doc(html_root_url = "https://docs.rs/sdforge-macros/0.1.0")]
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::quote;
11use syn::{parse_macro_input, FnArg, ItemFn, ItemMod, Pat};
12
/// Result of `parse_service_api_args`: the `#[service_api(...)]` arguments
/// as a positional tuple, in this order:
///
/// 0. `name` (required)
/// 1. `version` (required)
/// 2. `description`
/// 3. `path`
/// 4. `method`
/// 5. `tool_name`
/// 6. `stream`
/// 7. `cache_ttl`
/// 8. `ws_path`
/// 9. `grpc_method`
type ServiceApiArgs = Result<
    (
        String,         // name
        String,         // version
        Option<String>, // description
        Option<String>, // path
        Option<String>, // method
        Option<String>, // tool_name
        Option<bool>,   // stream
        Option<u64>,    // cache_ttl
        Option<String>, // ws_path
        Option<String>, // grpc_method
    ),
    syn::Error,
>;
29
30fn parse_kv_pairs(args: TokenStream2) -> Result<Vec<(String, String)>, syn::Error> {
33 let args_str = args.to_string();
34 let mut pairs = Vec::new();
35
36 let mut chars = args_str.chars().peekable();
37 while let Some(&c) = chars.peek() {
38 if c.is_whitespace() || c == ',' {
39 chars.next();
40 continue;
41 }
42
43 let mut key = String::new();
44 while let Some(&c) = chars.peek() {
45 if c == '=' || c.is_whitespace() {
46 break;
47 }
48 key.push(c);
49 chars.next();
50 }
51
52 while let Some(&c) = chars.peek() {
53 if c == '=' {
54 chars.next();
55 break;
56 }
57 chars.next();
58 }
59
60 while let Some(&c) = chars.peek() {
61 if c.is_whitespace() {
62 chars.next();
63 } else {
64 break;
65 }
66 }
67
68 let mut value = String::new();
69 if let Some(&'"') = chars.peek() {
70 chars.next();
71 for c in chars.by_ref() {
72 if c == '"' {
73 break;
74 }
75 value.push(c);
76 }
77 }
78
79 if !key.is_empty() && !value.is_empty() {
80 pairs.push((key, value));
81 }
82
83 if chars.peek().is_none() {
84 break;
85 }
86 }
87
88 Ok(pairs)
89}
90
91#[allow(dead_code)]
94#[inline]
95fn api_metadata_tokens(
96 name: TokenStream2,
97 version: TokenStream2,
98 description: TokenStream2,
99 cache_ttl: TokenStream2,
100 is_streaming: TokenStream2,
101) -> Result<TokenStream2, syn::Error> {
102 let validated_name = validate_api_name(&name.to_string())?;
105 let validated_version = validate_version(&version.to_string())?;
106
107 Ok(quote! {
108 sdforge::core::ApiMetadata::new(
109 #validated_name.to_string(),
110 #validated_version.to_string(),
111 #description.to_string(),
112 #cache_ttl,
113 #is_streaming,
114 )
115 })
116}
117
118fn validate_api_name(name: &str) -> Result<String, syn::Error> {
121 let name = name.trim_matches('"').trim();
122
123 if name.is_empty() {
125 return Err(syn::Error::new(
126 proc_macro2::Span::call_site(),
127 "API name cannot be empty",
128 ));
129 }
130
131 if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
133 return Err(syn::Error::new(
134 proc_macro2::Span::call_site(),
135 format!("API name contains invalid characters: {}", name),
136 ));
137 }
138
139 if name.starts_with(|c: char| !c.is_alphabetic() && c != '_') {
141 return Err(syn::Error::new(
142 proc_macro2::Span::call_site(),
143 format!("API name must start with a letter or underscore: {}", name),
144 ));
145 }
146
147 if RESERVED_KEYWORDS.contains(&name) {
149 return Err(syn::Error::new(
150 proc_macro2::Span::call_site(),
151 format!("API name cannot be a Rust keyword: {}", name),
152 ));
153 }
154
155 Ok(name.to_string())
156}
157
158fn validate_version(version: &str) -> Result<String, syn::Error> {
161 let version = version.trim_matches('"').trim();
162
163 if version.is_empty() {
165 return Err(syn::Error::new(
166 proc_macro2::Span::call_site(),
167 "API version cannot be empty",
168 ));
169 }
170
171 if !version
173 .chars()
174 .all(|c| c.is_alphanumeric() || c == '.' || c == '-')
175 {
176 let invalid_chars: Vec<char> = version
177 .chars()
178 .filter(|c| !c.is_alphanumeric() && *c != '.' && *c != '-')
179 .collect();
180 return Err(syn::Error::new(
181 proc_macro2::Span::call_site(),
182 format!(
183 "API version contains invalid characters: {}",
184 invalid_chars.iter().collect::<String>()
185 ),
186 ));
187 }
188
189 Ok(version.to_string())
190}
191
/// Identifiers rejected as API names by `validate_api_name`.
///
/// The list mostly holds Rust keywords. It also includes `of` and `is`, plus
/// the prelude names `Some`/`None`/`Ok`/`Err`, which are not keywords.
/// NOTE(review): the list is not exhaustive (e.g. `box`, `yield` are absent);
/// confirm whether it should track the full keyword list.
const RESERVED_KEYWORDS: &[&str] = &[
    // control flow
    "match", "if", "else", "loop", "while", "for", "break", "continue",
    // item declarations and visibility
    "fn", "struct", "enum", "impl", "trait", "pub", "mod", "use", "const", "static",
    // bindings and paths
    "let", "mut", "ref", "self", "super", "crate",
    // expressions and literals
    "return", "true", "false", "async", "await", "dyn", "unsafe", "extern", "type",
    "where", "move", "as", "in",
    // non-keywords also rejected
    "of", "is", "Some", "None", "Ok", "Err",
];
199
200const DEFAULT_CACHE_TTL: u64 = 300;
202
203fn parse_service_api_args(args: TokenStream2) -> ServiceApiArgs {
205 let pairs = parse_kv_pairs(args)?;
206
207 let mut name = None;
208 let mut version = None;
209 let mut description = None;
210 let mut path = None;
211 let mut method = None;
212 let mut tool_name = None;
213 let mut stream = None;
214 let mut cache_ttl = None;
215 let mut ws_path = None;
216 let mut grpc_method = None;
217
218 for (key, value) in pairs {
219 match key.as_str() {
220 "name" => name = Some(value),
221 "version" => version = Some(value),
222 "description" => description = Some(value),
223 "path" => path = Some(value),
224 "method" => method = Some(value),
225 "tool_name" => tool_name = Some(value),
226 "stream" => {
227 stream = Some(value.parse::<bool>().map_err(|_| {
228 syn::Error::new(
229 proc_macro2::Span::call_site(),
230 format!("Invalid boolean value for 'stream': {}", value),
231 )
232 })?)
233 }
234 "cache_ttl" => {
235 cache_ttl = Some(
236 value
237 .parse::<u64>()
238 .map_err(|_| {
239 syn::Error::new(
240 proc_macro2::Span::call_site(),
241 format!(
242 "Invalid cache TTL value (must be a positive integer): {}",
243 value
244 ),
245 )
246 })?
247 .max(DEFAULT_CACHE_TTL),
248 )
249 }
250 "ws_path" => ws_path = Some(value),
251 "grpc_method" => grpc_method = Some(value),
252 _ => {
253 return Err(syn::Error::new(
254 proc_macro2::Span::call_site(),
255 format!("Unknown attribute: {}", key),
256 ))
257 }
258 }
259 }
260
261 let name = name.ok_or_else(|| {
262 syn::Error::new(
263 proc_macro2::Span::call_site(),
264 "Missing required attribute: name",
265 )
266 })?;
267 let version = version.ok_or_else(|| {
268 syn::Error::new(
269 proc_macro2::Span::call_site(),
270 "Missing required attribute: version",
271 )
272 })?;
273
274 Ok((
275 name,
276 version,
277 description,
278 path,
279 method,
280 tool_name,
281 stream,
282 cache_ttl,
283 ws_path,
284 grpc_method,
285 ))
286}
287
288fn parse_service_module_args(args: TokenStream2) -> Result<String, syn::Error> {
290 let pairs = parse_kv_pairs(args)?;
291
292 let mut prefix = None;
293
294 for (key, value) in pairs {
295 match key.as_str() {
296 "prefix" => prefix = Some(value),
297 _ => {
298 return Err(syn::Error::new(
299 proc_macro2::Span::call_site(),
300 format!("Unknown attribute: {}", key),
301 ))
302 }
303 }
304 }
305
306 prefix.ok_or_else(|| {
307 syn::Error::new(
308 proc_macro2::Span::call_site(),
309 "Missing required attribute: prefix",
310 )
311 })
312}
313
/// Where an HTTP handler argument is extracted from.
#[derive(Debug, Clone)]
enum ParamKind {
    /// A segment of the route path.
    Path,
    /// The query string.
    Query,
    /// A request header.
    Header,
    /// A cookie.
    Cookie,
    /// A form-encoded body.
    Form,
    /// A JSON body.
    Body,
}

/// Renders the lowercase name used in `#[param(kind = "...")]`.
impl std::fmt::Display for ParamKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Path => "path",
            Self::Query => "query",
            Self::Header => "header",
            Self::Cookie => "cookie",
            Self::Form => "form",
            Self::Body => "body",
        };
        f.write_str(label)
    }
}
336
337#[derive(Debug, Clone)]
339#[allow(dead_code)]
340struct ParamInfo {
341 name: String,
343 ty: syn::Type,
345 param_kind: ParamKind,
347 is_option: bool,
349 is_vec: bool,
351 inner_type: String,
353 explicit_annotation: Option<ParamKind>,
355}
356
357impl ParamInfo {
358 fn from_arg(
359 arg: &FnArg,
360 path_params: &[String],
361 http_method: Option<&str>,
362 body_params: &[String],
363 ) -> Option<Self> {
364 let pat_type = match arg {
365 FnArg::Receiver(_) => return None,
366 FnArg::Typed(pat_type) => pat_type,
367 };
368
369 let pat = &*pat_type.pat;
370 if let Pat::Ident(pat_ident) = pat {
371 let name = pat_ident.ident.to_string();
372
373 let ty = (*pat_type.ty).clone();
375
376 let ty_str = quote! { #ty }.to_string();
377 let ty_str_trimmed = ty_str.trim().to_string();
378
379 let explicit_annotation = Self::extract_param_annotation(pat_type);
381
382 let param_kind = if let Some(ref kind) = explicit_annotation {
384 kind.clone()
385 } else if path_params.contains(&name) {
386 ParamKind::Path
387 } else if ty_str_trimmed.starts_with("Option<") {
388 let inner = &ty_str_trimmed[7..ty_str_trimmed.len() - 1];
390 if inner.starts_with("HeaderMap") || inner.starts_with("HeaderValue") {
391 ParamKind::Header
392 } else {
393 ParamKind::Query
394 }
395 } else if http_method.map(|m| m.to_uppercase()) == Some("GET".to_string()) {
396 ParamKind::Query
397 } else if body_params.contains(&name) {
398 ParamKind::Body
401 } else {
402 ParamKind::Body
403 };
404
405 let (is_option, is_vec, inner_type) = if ty_str_trimmed.starts_with("Option<") {
406 let inner = &ty_str_trimmed[7..ty_str_trimmed.len() - 1];
407 (true, false, inner.to_string())
408 } else if ty_str_trimmed.starts_with("Vec<") {
409 let inner = &ty_str_trimmed[4..ty_str_trimmed.len() - 1];
410 (false, true, inner.to_string())
411 } else {
412 (false, false, ty_str_trimmed.clone())
413 };
414
415 Some(Self {
416 name,
417 ty,
418 param_kind,
419 is_option,
420 is_vec,
421 inner_type,
422 explicit_annotation,
423 })
424 } else {
425 None
426 }
427 }
428
429 fn extract_param_annotation(pat_type: &syn::PatType) -> Option<ParamKind> {
431 for attr in &pat_type.attrs {
432 if attr.path().is_ident("param") {
433 if let Ok(meta) = attr.parse_args_with(
435 syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
436 ) {
437 for meta_item in meta {
438 if let syn::Meta::NameValue(name_value) = meta_item {
439 if name_value.path.is_ident("kind") {
440 if let syn::Expr::Lit(syn::ExprLit {
441 lit: syn::Lit::Str(lit_str),
442 ..
443 }) = &name_value.value
444 {
445 return match lit_str.value().as_str() {
446 "path" => Some(ParamKind::Path),
447 "query" => Some(ParamKind::Query),
448 "header" => Some(ParamKind::Header),
449 "cookie" => Some(ParamKind::Cookie),
450 "form" => Some(ParamKind::Form),
451 "body" => Some(ParamKind::Body),
452 _ => None,
453 };
454 }
455 }
456 }
457 }
458 }
459 }
460 }
461 None
462 }
463
464 fn to_json_schema(&self) -> String {
466 let param_type = if self.is_option {
467 format!(
468 "{{\"type\":[\"null\",{}]}}",
469 self.inner_type_to_json_schema()
470 )
471 } else if self.is_vec {
472 format!(
473 "{{\"type\":\"array\",\"items\":{}}}",
474 self.inner_type_to_json_schema()
475 )
476 } else {
477 format!("{{\"type\":{}}}", self.inner_type_to_json_schema())
478 };
479 format!("\"{}\":{}", self.name, param_type)
480 }
481
482 fn inner_type_to_json_schema(&self) -> String {
483 match self.inner_type.as_str() {
484 "i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64" | "u128"
485 | "f32" | "f64" => "\"number\"".to_string(),
486 "bool" => "\"boolean\"".to_string(),
487 "String" | "&str" => "\"string\"".to_string(),
488 _ => "\"object\"".to_string(),
489 }
490 }
491}
492
/// Extracts parameter names from a route template, in order of appearance.
///
/// Recognised forms:
/// * colon style: `/users/:id`;
/// * brace style: `/users/{id}`;
/// * brace catch-all: `/files/{*rest}`, reported as `rest` so it matches the
///   function argument that binds it.
///
/// Segments that start with neither `:` nor `{` are ignored. An empty
/// template yields an empty list.
fn extract_path_params(path: &str) -> Vec<String> {
    path.split('/')
        .filter(|segment| segment.starts_with(':') || segment.starts_with('{'))
        .map(|segment| {
            segment
                .trim_start_matches(':')
                .trim_start_matches('{')
                // `{*rest}` is a catch-all whose bound name is `rest`.
                .trim_start_matches('*')
                .trim_end_matches('}')
                .to_string()
        })
        .collect()
}
508
509#[proc_macro_attribute]
510pub fn service_api(args: TokenStream, input: TokenStream) -> TokenStream {
511 let args = match parse_service_api_args(args.into()) {
512 Ok(args) => args,
513 Err(e) => return e.into_compile_error().into(),
514 };
515 let input = parse_macro_input!(input as ItemFn);
516
517 let (
518 name,
519 version,
520 description,
521 path,
522 method,
523 tool_name,
524 stream,
525 cache_ttl,
526 ws_path,
527 grpc_method,
528 ) = args;
529 let fn_name = &input.sig.ident;
530 let _fn_vis = &input.vis; let return_type = &input.sig.output;
532
533 let path_params = path
535 .as_ref()
536 .map(|p| extract_path_params(p))
537 .unwrap_or_default();
538
539 let all_param_names: Vec<String> = input
541 .sig
542 .inputs
543 .iter()
544 .filter_map(|arg| {
545 if let FnArg::Typed(pat_type) = arg {
546 if let Pat::Ident(pat_ident) = &*pat_type.pat {
547 return Some(pat_ident.ident.to_string());
548 }
549 }
550 None
551 })
552 .collect();
553
554 let body_param_names: Vec<String> = all_param_names
556 .iter()
557 .filter(|name| !path_params.contains(name))
558 .cloned()
559 .collect();
560
561 let params: Vec<ParamInfo> = input
563 .sig
564 .inputs
565 .iter()
566 .filter_map(|arg| {
567 ParamInfo::from_arg(arg, &path_params, method.as_deref(), &body_param_names)
568 })
569 .collect();
570
571 let has_params = !params.is_empty();
573
574 let param_patterns: Vec<_> = params
576 .iter()
577 .map(|p| {
578 let name_ident = syn::Ident::new(&p.name, proc_macro2::Span::call_site());
579 let ty = &p.ty;
580 match p.param_kind {
581 ParamKind::Path => quote! { #name_ident: sdforge::axum::extract::Path<#ty> },
582 ParamKind::Query => quote! { #name_ident: sdforge::axum::extract::Query<#ty> },
583 ParamKind::Header => {
584 quote! { #name_ident: sdforge::axum::extract::TypedHeader<#ty> }
585 }
586 ParamKind::Cookie => quote! { #name_ident: sdforge::axum::extract::Cookie },
587 ParamKind::Form => quote! { #name_ident: sdforge::axum::extract::Form<#ty> },
588 ParamKind::Body => quote! { #name_ident: sdforge::axum::extract::Json<#ty> },
589 }
590 })
591 .collect();
592
593 let _param_unwraps: Vec<_> = params .iter()
597 .map(|p| {
598 let name_ident = syn::Ident::new(&p.name, proc_macro2::Span::call_site());
599 quote! { let #name_ident = #name_ident.0; }
601 })
602 .collect();
603
604 let param_names: Vec<_> = params
605 .iter()
606 .map(|p| syn::Ident::new(&p.name, proc_macro2::Span::call_site()))
607 .collect();
608
609 let param_types: Vec<_> = params.iter().map(|p| &p.ty).collect();
611
612 let mcp_schema_props: Vec<String> = params.iter().map(|p| p.to_json_schema()).collect();
614 let mcp_schema_required: Vec<String> = params
615 .iter()
616 .filter(|p| !p.is_option)
617 .map(|p| format!("\"{}\"", p.name))
618 .collect();
619
620 let mcp_properties_json = if mcp_schema_props.is_empty() {
622 quote! { serde_json::json!({}) }
623 } else {
624 let props_vec: Vec<TokenStream2> = mcp_schema_props
625 .iter()
626 .map(|s| s.parse().expect("valid JSON property"))
627 .collect();
628 quote! { serde_json::json!({ #(#props_vec),* }) }
629 };
630
631 let mcp_required_json = if mcp_schema_required.is_empty() {
633 quote! { serde_json::json!([]) }
634 } else {
635 quote! { serde_json::json!([#(#mcp_schema_required),*]) }
636 };
637
638 let fn_name_str = fn_name.to_string();
640 let _handler_name = syn::Ident::new(
641 &format!("__axiom_http_handler_{}", fn_name_str),
643 proc_macro2::Span::call_site(),
644 );
645
646 let register_fn_name = syn::Ident::new(
648 &format!("__axiom_register_{}", fn_name_str),
649 proc_macro2::Span::call_site(),
650 );
651
652 let path_str = path.as_ref().cloned().unwrap_or_default();
654 let axum_path = path_str
657 .split('/')
658 .map(|segment| {
659 if let Some(stripped) = segment.strip_prefix(':') {
660 format!("{{{}}}", stripped)
661 } else {
662 segment.to_string()
663 }
664 })
665 .collect::<Vec<_>>()
666 .join("/");
667 let http_path = format!("/api/{}{}", version, axum_path);
668
669 let http_method_upper = method.as_ref().unwrap_or(&"GET".to_string()).to_uppercase();
671 let http_method_lower = http_method_upper.to_lowercase();
672
673 let cache_ttl_expr = match &cache_ttl {
675 Some(ttl) => quote! { Some(#ttl) },
676 None => quote! { None },
677 };
678
679 let description_literal = description.as_deref().unwrap_or(&name);
681
682 let is_streaming = stream.unwrap_or(false);
684
685 let http_code = if path.is_some() && method.is_some() {
686 let streaming_metadata = match api_metadata_tokens(
688 quote! { #name },
689 quote! { #version },
690 quote! { #description_literal },
691 quote! { None },
692 quote! { true },
693 ) {
694 Ok(tokens) => tokens,
695 Err(e) => return e.into_compile_error().into(),
696 };
697
698 let non_streaming_metadata = match api_metadata_tokens(
699 quote! { #name },
700 quote! { #version },
701 quote! { #description_literal },
702 quote! { None },
703 quote! { false },
704 ) {
705 Ok(tokens) => tokens,
706 Err(e) => return e.into_compile_error().into(),
707 };
708
709 let route_creation = if is_streaming {
711 quote! {
712 fn #register_fn_name() -> sdforge::http::HttpRoute {
713 sdforge::http::HttpRoute {
714 path: #http_path.to_string(),
715 handler: {
716 let mut router = sdforge::axum::routing::MethodRouter::new();
717 router = router.get(#(#param_patterns),* | #(#param_names.0),* | {
718 async move {
719 use sdforge::prelude::*;
720 match #fn_name(#(#param_names.0),*).await {
721 Ok(_stream) => {
722 let body = sdforge::axum::body::Body::from_streaming_bytes(
723 tokio_stream::iter(vec![])
724 );
725 let response: sdforge::axum::response::Response = (
726 [(sdforge::axum::http::header::CONTENT_TYPE, "text/event-stream")],
727 body
728 ).into_response();
729 response
730 }
731 Err(e) => e.into_response(),
732 }
733 }
734 });
735 router
736 },
737 metadata: #streaming_metadata,
738 module_prefix: None,
739 }
740 }
741 }
742 } else {
743 let is_result = match return_type {
744 syn::ReturnType::Type(_, ty) => {
745 matches!(ty.as_ref(), syn::Type::Path(syn::TypePath { qself: None, path: syn::Path { segments, .. } }) if segments.iter().any(|s| s.ident == "Result"))
746 }
747 syn::ReturnType::Default => false,
748 };
749
750 let handler_closure = if is_result {
751 quote! {
752 |#(#param_patterns),*| {
753 async move {
754 use sdforge::prelude::*;
755 match #fn_name(#(#param_names.0),*).await {
756 Ok(value) => sdforge::axum::extract::Json(value).into_response(),
757 Err(e) => e.into_response(),
758 }
759 }
760 }
761 }
762 } else {
763 quote! {
764 |#(#param_patterns),*| {
765 async move {
766 use sdforge::prelude::*;
767 let result = #fn_name(#(#param_names.0),*).await;
768 sdforge::axum::extract::Json(result).into_response()
769 }
770 }
771 }
772 };
773
774 quote! {
775 fn #register_fn_name() -> sdforge::http::HttpRoute {
776 sdforge::http::HttpRoute {
777 path: #http_path.to_string(),
778 handler: {
779 let mut router = sdforge::axum::routing::MethodRouter::new();
780 match #http_method_lower.as_ref() {
781 "get" => router = router.get(#handler_closure),
782 "post" => router = router.post(#handler_closure),
783 "put" => router = router.put(#handler_closure),
784 "delete" => router = router.delete(#handler_closure),
785 "patch" => router = router.patch(#handler_closure),
786 "head" => router = router.head(#handler_closure),
787 "options" => router = router.options(#handler_closure),
788 _ => router = router.get(#handler_closure),
789 }
790 router
791 },
792 metadata: #non_streaming_metadata,
793 module_prefix: None,
794 }
795 }
796 }
797 };
798
799 quote! {
801 #route_creation
802 sdforge::inventory::submit!(sdforge::http::RouteRegistration {
803 name: #name,
804 version: #version,
805 register_fn: #register_fn_name,
806 });
807 }
808 } else {
809 quote! {}
810 };
811
812 let grpc_metadata = match api_metadata_tokens(
814 quote! { #name },
815 quote! { #version },
816 quote! { #description_literal },
817 quote! { #cache_ttl_expr },
818 quote! { false },
819 ) {
820 Ok(tokens) => tokens,
821 Err(e) => return e.into_compile_error().into(),
822 };
823
824 let mcp_code = if let Some(ref tool_name) = tool_name {
825 let mcp_call_logic = if has_params {
826 quote! {
827 #[derive(serde::Deserialize)]
828 struct Params {
829 #(pub #param_names: #param_types),*
830 }
831
832 let params: Params = match input {
833 Some(v) => serde_json::from_value(v)
834 .map_err(|e| anyhow::anyhow!("Failed to parse input: {}", e))?,
835 None => {
836 return Err(anyhow::anyhow!("Missing input parameters"));
837 }
838 };
839
840 let result = #fn_name(#(params.#param_names),*).await;
841 Ok(result)
842 }
843 } else {
844 quote! {
845 let result = #fn_name().await;
846 Ok(result)
847 }
848 };
849
850 let mcp_tool_name = tool_name;
851 let mcp_tool_description = description.as_ref().unwrap_or(&name);
852 let mcp_struct_name = syn::Ident::new(
853 &format!("{}McpTool", fn_name),
854 proc_macro2::Span::call_site(),
855 );
856 let mcp_create_fn_name = syn::Ident::new(
857 &format!("__create_{}_mcp_tool", fn_name),
858 proc_macro2::Span::call_site(),
859 );
860
861 quote! {
862 #[cfg(feature = "mcp")]
863 #[derive(Debug)]
864 struct #mcp_struct_name;
865
866 #[cfg(feature = "mcp")]
867 impl #mcp_struct_name {
868 fn create() -> std::sync::Arc<dyn sdforge::mcp_sdk::tools::Tool> {
869 std::sync::Arc::new(Self) as std::sync::Arc<dyn sdforge::mcp_sdk::tools::Tool>
870 }
871 }
872
873 #[cfg(feature = "mcp")]
874 impl sdforge::mcp_sdk::tools::Tool for #mcp_struct_name {
875 fn name(&self) -> String {
876 #mcp_tool_name.to_string()
877 }
878
879 fn description(&self) -> String {
880 #mcp_tool_description.to_string()
881 }
882
883 fn input_schema(&self) -> serde_json::Value {
884 serde_json::json!({
885 "type": "object",
886 "properties": #mcp_properties_json,
887 "required": #mcp_required_json
888 })
889 }
890
891 fn call(&self, input: Option<serde_json::Value>) -> anyhow::Result<sdforge::mcp_sdk::types::CallToolResponse> {
892 use sdforge::prelude::*;
893 use tokio::runtime::Runtime;
894
895 let rt = Runtime::new().map_err(|e| anyhow::anyhow!("Failed to create runtime: {}", e))?;
896 let inner_result: Result<Result<_, ApiError>, anyhow::Error> = rt.block_on(async {
897 #mcp_call_logic
898 });
899 let result = inner_result?;
900
901 match result {
902 Ok(response) => {
903 let response_json = serde_json::to_value(response)
904 .map_err(|e| anyhow::anyhow!("Failed to serialize response: {}", e))?;
905 Ok(sdforge::mcp_sdk::types::CallToolResponse {
906 content: vec![sdforge::mcp_sdk::types::ToolResponseContent::Text {
907 text: serde_json::to_string(&response_json)
908 .map_err(|e| anyhow::anyhow!("Failed to stringify response: {}", e))?,
909 }],
910 is_error: Some(false),
911 meta: None,
912 })
913 }
914 Err(e) => {
915 let error_json = serde_json::to_value(e)
916 .map_err(|e| anyhow::anyhow!("Failed to serialize error: {}", e))
917 .unwrap_or_else(|_| {
918 serde_json::json!({
919 "success": false,
920 "error": {
921 "code": "UNKNOWN_ERROR",
922 "message": "An unknown error occurred"
923 }
924 })
925 });
926 Ok(sdforge::mcp_sdk::types::CallToolResponse {
927 content: vec![sdforge::mcp_sdk::types::ToolResponseContent::Text {
928 text: serde_json::to_string(&error_json)
929 .map_err(|e| anyhow::anyhow!("Failed to stringify error: {}", e))?,
930 }],
931 is_error: Some(true),
932 meta: None,
933 })
934 }
935 }
936 }
937 }
938
939 #[cfg(feature = "mcp")]
940 fn #mcp_create_fn_name() -> std::sync::Arc<dyn sdforge::mcp_sdk::tools::Tool> {
941 #mcp_struct_name::create()
942 }
943
944 #[cfg(feature = "mcp")]
945 sdforge::inventory::submit!(sdforge::mcp::McpToolRegistration {
946 name: #mcp_tool_name,
947 version: #version,
948 description: #mcp_tool_description,
949 create_fn: #mcp_create_fn_name,
950 });
951 }
952 } else {
953 quote! {}
954 };
955
956 let ws_code = if ws_path.is_some() {
957 quote! {
958 #[cfg(feature = "websocket")]
959 sdforge::inventory::submit!(sdforge::websocket::WebSocketRoute {
960 path: #ws_path.unwrap().to_string(),
961 handler: #fn_name,
962 });
963 }
964 } else {
965 quote! {}
966 };
967
968 let grpc_code = if grpc_method.is_some() {
969 quote! {
970 #[cfg(feature = "grpc")]
971 sdforge::inventory::submit!(sdforge::grpc::GrpcRoute {
972 service_name: #name.to_string(),
973 metadata: #grpc_metadata,
974 });
975 }
976 } else {
977 quote! {}
978 };
979
980 let generated = quote! {
981 #input
982 #http_code
983 #mcp_code
984 #ws_code
985 #grpc_code
986 };
987
988 generated.into()
989}
990
991#[proc_macro_attribute]
992pub fn service_module(args: TokenStream, input: TokenStream) -> TokenStream {
993 let prefix = match parse_service_module_args(args.into()) {
994 Ok(prefix) => prefix,
995 Err(e) => return e.into_compile_error().into(),
996 };
997 let input = parse_macro_input!(input as ItemMod);
998
999 let prefix_const = quote! {
1001 pub const MODULE_PREFIX: &str = #prefix;
1002 };
1003
1004 let prefix_helper = quote! {
1006 #[inline]
1007 pub fn apply_prefix(path: &str) -> String {
1008 if path.starts_with('/') {
1009 format!("{}{}", MODULE_PREFIX, path)
1010 } else {
1011 format!("{}{}", MODULE_PREFIX, path)
1012 }
1013 }
1014 };
1015
1016 let expanded = quote! {
1017 #input
1018
1019 #prefix_const
1020 #prefix_helper
1021 };
1022
1023 expanded.into()
1024}
1025
1026#[proc_macro]
1027pub fn test_macro(input: TokenStream) -> TokenStream {
1028 let input = parse_macro_input!(input as ItemFn);
1029
1030 let fn_name = &input.sig.ident;
1031
1032 let expanded = quote! {
1033 #input
1034
1035 #[cfg(test)]
1036 mod #fn_name {
1037 use super::*;
1038
1039 #[test]
1040 fn test_generated() {
1041 println!("Test macro generated for: {}", stringify!(#fn_name));
1042 }
1043 }
1044 };
1045
1046 expanded.into()
1047}
1048
/// Unit tests for the attribute-argument parsers.
#[cfg(test)]
mod macro_parsing_tests {
    use super::*;

    /// Builds an owned `(key, value)` pair for comparisons.
    fn pair(key: &str, value: &str) -> (String, String) {
        (key.to_string(), value.to_string())
    }

    #[test]
    fn test_parse_kv_pairs_simple() {
        let parsed = parse_kv_pairs(quote! { name = "test" }).unwrap();
        assert_eq!(parsed, vec![pair("name", "test")]);
    }

    #[test]
    fn test_parse_kv_pairs_multiple() {
        let parsed = parse_kv_pairs(quote! { name = "test", version = "v1" }).unwrap();
        assert_eq!(parsed, vec![pair("name", "test"), pair("version", "v1")]);
    }

    #[test]
    fn test_parse_service_api_args_required() {
        let (api_name, api_version, ..) =
            parse_service_api_args(quote! { name = "test", version = "v1" }).unwrap();
        assert_eq!(api_name, "test");
        assert_eq!(api_version, "v1");
    }

    #[test]
    fn test_parse_service_module_args() {
        let module_prefix = parse_service_module_args(quote! { prefix = "/api/v1" }).unwrap();
        assert_eq!(module_prefix, "/api/v1");
    }
}