sdforge_macros/
lib.rs

1// Copyright (c) 2026 Kirky.X
2//! Axiom procedural macros
3//!
4//! This crate provides procedural macros for the Axiom framework.
5
6#![doc(html_root_url = "https://docs.rs/sdforge-macros/0.1.0")]
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::quote;
11use syn::{parse_macro_input, FnArg, ItemFn, ItemMod, Pat};
12
13/// Type alias for service_api arguments parsing result
14type ServiceApiArgs = Result<
15    (
16        String,
17        String,
18        Option<String>,
19        Option<String>,
20        Option<String>,
21        Option<String>,
22        Option<bool>,
23        Option<u64>,
24        Option<String>,
25        Option<String>,
26    ),
27    syn::Error,
28>;
29
30/// Parse key=value pairs from token stream
31/// Preserves original string-based parsing for compatibility
32fn parse_kv_pairs(args: TokenStream2) -> Result<Vec<(String, String)>, syn::Error> {
33    let args_str = args.to_string();
34    let mut pairs = Vec::new();
35
36    let mut chars = args_str.chars().peekable();
37    while let Some(&c) = chars.peek() {
38        if c.is_whitespace() || c == ',' {
39            chars.next();
40            continue;
41        }
42
43        let mut key = String::new();
44        while let Some(&c) = chars.peek() {
45            if c == '=' || c.is_whitespace() {
46                break;
47            }
48            key.push(c);
49            chars.next();
50        }
51
52        while let Some(&c) = chars.peek() {
53            if c == '=' {
54                chars.next();
55                break;
56            }
57            chars.next();
58        }
59
60        while let Some(&c) = chars.peek() {
61            if c.is_whitespace() {
62                chars.next();
63            } else {
64                break;
65            }
66        }
67
68        let mut value = String::new();
69        if let Some(&'"') = chars.peek() {
70            chars.next();
71            for c in chars.by_ref() {
72                if c == '"' {
73                    break;
74                }
75                value.push(c);
76            }
77        }
78
79        if !key.is_empty() && !value.is_empty() {
80            pairs.push((key, value));
81        }
82
83        if chars.peek().is_none() {
84            break;
85        }
86    }
87
88    Ok(pairs)
89}
90
91/// Generate ApiMetadata TokenStream for service API
92/// Accepts TokenStream2 parameters to work within quote! macro
93#[allow(dead_code)]
94#[inline]
95fn api_metadata_tokens(
96    name: TokenStream2,
97    version: TokenStream2,
98    description: TokenStream2,
99    cache_ttl: TokenStream2,
100    is_streaming: TokenStream2,
101) -> Result<TokenStream2, syn::Error> {
102    // Validate and sanitize inputs at compile time to prevent code injection
103    // These validations will cause compilation to fail if inputs are invalid
104    let validated_name = validate_api_name(&name.to_string())?;
105    let validated_version = validate_version(&version.to_string())?;
106
107    Ok(quote! {
108        sdforge::core::ApiMetadata::new(
109            #validated_name.to_string(),
110            #validated_version.to_string(),
111            #description.to_string(),
112            #cache_ttl,
113            #is_streaming,
114        )
115    })
116}
117
118/// Validate API name to prevent code injection
119/// API names must be valid Rust identifiers (alphanumeric + underscores, starting with letter)
120fn validate_api_name(name: &str) -> Result<String, syn::Error> {
121    let name = name.trim_matches('"').trim();
122
123    // Check for empty name
124    if name.is_empty() {
125        return Err(syn::Error::new(
126            proc_macro2::Span::call_site(),
127            "API name cannot be empty",
128        ));
129    }
130
131    // Check for invalid characters (allow alphanumeric and underscores)
132    if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
133        return Err(syn::Error::new(
134            proc_macro2::Span::call_site(),
135            format!("API name contains invalid characters: {}", name),
136        ));
137    }
138
139    // Check that name starts with a letter (valid Rust identifier)
140    if name.starts_with(|c: char| !c.is_alphabetic() && c != '_') {
141        return Err(syn::Error::new(
142            proc_macro2::Span::call_site(),
143            format!("API name must start with a letter or underscore: {}", name),
144        ));
145    }
146
147    // Check for reserved Rust keywords
148    if RESERVED_KEYWORDS.contains(&name) {
149        return Err(syn::Error::new(
150            proc_macro2::Span::call_site(),
151            format!("API name cannot be a Rust keyword: {}", name),
152        ));
153    }
154
155    Ok(name.to_string())
156}
157
158/// Validate version string to prevent code injection
159/// Version strings should match common patterns like "v1", "1.0", "v1.2.3"
160fn validate_version(version: &str) -> Result<String, syn::Error> {
161    let version = version.trim_matches('"').trim();
162
163    // Check for empty version
164    if version.is_empty() {
165        return Err(syn::Error::new(
166            proc_macro2::Span::call_site(),
167            "API version cannot be empty",
168        ));
169    }
170
171    // Version should only contain alphanumeric characters, dots, and optionally a 'v' prefix
172    if !version
173        .chars()
174        .all(|c| c.is_alphanumeric() || c == '.' || c == '-')
175    {
176        let invalid_chars: Vec<char> = version
177            .chars()
178            .filter(|c| !c.is_alphanumeric() && *c != '.' && *c != '-')
179            .collect();
180        return Err(syn::Error::new(
181            proc_macro2::Span::call_site(),
182            format!(
183                "API version contains invalid characters: {}",
184                invalid_chars.iter().collect::<String>()
185            ),
186        ));
187    }
188
189    Ok(version.to_string())
190}
191
/// Words that `validate_api_name` refuses to accept as an API name.
///
/// Mostly Rust keywords, plus the prelude names `Some`, `None`, `Ok` and `Err` and the
/// non-keywords `of` and `is`. Only membership matters; the order is irrelevant.
const RESERVED_KEYWORDS: &[&str] = &[
    // control flow
    "match", "if", "else", "loop", "while", "for", "break", "continue",
    // items and declarations
    "fn", "struct", "enum", "impl", "trait", "pub", "mod", "use", "const", "static",
    // bindings
    "let", "mut", "ref",
    // path roots
    "self", "super", "crate",
    // other keywords and literals
    "return", "true", "false", "async", "await", "dyn", "unsafe", "extern", "type", "where",
    "move", "as", "in",
    // not keywords, but rejected as well
    "of", "is",
    // prelude variant names
    "Some", "None", "Ok", "Err",
];
199
/// Default cache TTL, in seconds: five minutes.
const DEFAULT_CACHE_TTL: u64 = 5 * 60;
202
203/// Parse service_api attributes
204fn parse_service_api_args(args: TokenStream2) -> ServiceApiArgs {
205    let pairs = parse_kv_pairs(args)?;
206
207    let mut name = None;
208    let mut version = None;
209    let mut description = None;
210    let mut path = None;
211    let mut method = None;
212    let mut tool_name = None;
213    let mut stream = None;
214    let mut cache_ttl = None;
215    let mut ws_path = None;
216    let mut grpc_method = None;
217
218    for (key, value) in pairs {
219        match key.as_str() {
220            "name" => name = Some(value),
221            "version" => version = Some(value),
222            "description" => description = Some(value),
223            "path" => path = Some(value),
224            "method" => method = Some(value),
225            "tool_name" => tool_name = Some(value),
226            "stream" => {
227                stream = Some(value.parse::<bool>().map_err(|_| {
228                    syn::Error::new(
229                        proc_macro2::Span::call_site(),
230                        format!("Invalid boolean value for 'stream': {}", value),
231                    )
232                })?)
233            }
234            "cache_ttl" => {
235                cache_ttl = Some(
236                    value
237                        .parse::<u64>()
238                        .map_err(|_| {
239                            syn::Error::new(
240                                proc_macro2::Span::call_site(),
241                                format!(
242                                    "Invalid cache TTL value (must be a positive integer): {}",
243                                    value
244                                ),
245                            )
246                        })?
247                        .max(DEFAULT_CACHE_TTL),
248                )
249            }
250            "ws_path" => ws_path = Some(value),
251            "grpc_method" => grpc_method = Some(value),
252            _ => {
253                return Err(syn::Error::new(
254                    proc_macro2::Span::call_site(),
255                    format!("Unknown attribute: {}", key),
256                ))
257            }
258        }
259    }
260
261    let name = name.ok_or_else(|| {
262        syn::Error::new(
263            proc_macro2::Span::call_site(),
264            "Missing required attribute: name",
265        )
266    })?;
267    let version = version.ok_or_else(|| {
268        syn::Error::new(
269            proc_macro2::Span::call_site(),
270            "Missing required attribute: version",
271        )
272    })?;
273
274    Ok((
275        name,
276        version,
277        description,
278        path,
279        method,
280        tool_name,
281        stream,
282        cache_ttl,
283        ws_path,
284        grpc_method,
285    ))
286}
287
288/// Parse service_module attributes
289fn parse_service_module_args(args: TokenStream2) -> Result<String, syn::Error> {
290    let pairs = parse_kv_pairs(args)?;
291
292    let mut prefix = None;
293
294    for (key, value) in pairs {
295        match key.as_str() {
296            "prefix" => prefix = Some(value),
297            _ => {
298                return Err(syn::Error::new(
299                    proc_macro2::Span::call_site(),
300                    format!("Unknown attribute: {}", key),
301                ))
302            }
303        }
304    }
305
306    prefix.ok_or_else(|| {
307        syn::Error::new(
308            proc_macro2::Span::call_site(),
309            "Missing required attribute: prefix",
310        )
311    })
312}
313
/// Where a generated HTTP handler takes a parameter from.
#[derive(Debug, Clone)]
enum ParamKind {
    /// URL path segment (axum `Path` extractor).
    Path,
    /// Query string (axum `Query` extractor).
    Query,
    /// Request header (`TypedHeader` extractor).
    Header,
    /// Cookie.
    Cookie,
    /// URL-encoded form body (axum `Form` extractor).
    Form,
    /// JSON body (axum `Json` extractor).
    Body,
}

impl std::fmt::Display for ParamKind {
    /// Writes the lowercase keyword, i.e. the same spelling accepted by
    /// `#[param(kind = "...")]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Path => "path",
            Self::Query => "query",
            Self::Header => "header",
            Self::Cookie => "cookie",
            Self::Form => "form",
            Self::Body => "body",
        };
        f.write_str(keyword)
    }
}
336
337/// Extract parameter info from function arguments
338#[derive(Debug, Clone)]
339#[allow(dead_code)]
340struct ParamInfo {
341    /// Parameter name (identifier)
342    name: String,
343    /// Parameter type
344    ty: syn::Type,
345    /// Extraction kind
346    param_kind: ParamKind,
347    /// Whether the parameter is Option<T>
348    is_option: bool,
349    /// Whether the parameter is Vec<T>
350    is_vec: bool,
351    /// The inner type for Option or Vec (as string for comparison)
352    inner_type: String,
353    /// Explicit parameter annotation (if any)
354    explicit_annotation: Option<ParamKind>,
355}
356
357impl ParamInfo {
358    fn from_arg(
359        arg: &FnArg,
360        path_params: &[String],
361        http_method: Option<&str>,
362        body_params: &[String],
363    ) -> Option<Self> {
364        let pat_type = match arg {
365            FnArg::Receiver(_) => return None,
366            FnArg::Typed(pat_type) => pat_type,
367        };
368
369        let pat = &*pat_type.pat;
370        if let Pat::Ident(pat_ident) = pat {
371            let name = pat_ident.ident.to_string();
372
373            // Clone the type from the typed pattern
374            let ty = (*pat_type.ty).clone();
375
376            let ty_str = quote! { #ty }.to_string();
377            let ty_str_trimmed = ty_str.trim().to_string();
378
379            // Check for explicit #[param(kind = "...")] attribute
380            let explicit_annotation = Self::extract_param_annotation(pat_type);
381
382            // Determine extraction kind based on explicit annotation first, then path parameters, then type inference
383            let param_kind = if let Some(ref kind) = explicit_annotation {
384                kind.clone()
385            } else if path_params.contains(&name) {
386                ParamKind::Path
387            } else if ty_str_trimmed.starts_with("Option<") {
388                // Check if it's Option<HeaderMap<...>> or similar
389                let inner = &ty_str_trimmed[7..ty_str_trimmed.len() - 1];
390                if inner.starts_with("HeaderMap") || inner.starts_with("HeaderValue") {
391                    ParamKind::Header
392                } else {
393                    ParamKind::Query
394                }
395            } else if http_method.map(|m| m.to_uppercase()) == Some("GET".to_string()) {
396                ParamKind::Query
397            } else if body_params.contains(&name) {
398                // Always use Json extractor for body parameters
399                // Form extractor is only for form-urlencoded requests
400                ParamKind::Body
401            } else {
402                ParamKind::Body
403            };
404
405            let (is_option, is_vec, inner_type) = if ty_str_trimmed.starts_with("Option<") {
406                let inner = &ty_str_trimmed[7..ty_str_trimmed.len() - 1];
407                (true, false, inner.to_string())
408            } else if ty_str_trimmed.starts_with("Vec<") {
409                let inner = &ty_str_trimmed[4..ty_str_trimmed.len() - 1];
410                (false, true, inner.to_string())
411            } else {
412                (false, false, ty_str_trimmed.clone())
413            };
414
415            Some(Self {
416                name,
417                ty,
418                param_kind,
419                is_option,
420                is_vec,
421                inner_type,
422                explicit_annotation,
423            })
424        } else {
425            None
426        }
427    }
428
429    /// Extract explicit #[param(kind = "...")] attribute from function argument
430    fn extract_param_annotation(pat_type: &syn::PatType) -> Option<ParamKind> {
431        for attr in &pat_type.attrs {
432            if attr.path().is_ident("param") {
433                // Parse the attribute: #[param(kind = "path")]
434                if let Ok(meta) = attr.parse_args_with(
435                    syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
436                ) {
437                    for meta_item in meta {
438                        if let syn::Meta::NameValue(name_value) = meta_item {
439                            if name_value.path.is_ident("kind") {
440                                if let syn::Expr::Lit(syn::ExprLit {
441                                    lit: syn::Lit::Str(lit_str),
442                                    ..
443                                }) = &name_value.value
444                                {
445                                    return match lit_str.value().as_str() {
446                                        "path" => Some(ParamKind::Path),
447                                        "query" => Some(ParamKind::Query),
448                                        "header" => Some(ParamKind::Header),
449                                        "cookie" => Some(ParamKind::Cookie),
450                                        "form" => Some(ParamKind::Form),
451                                        "body" => Some(ParamKind::Body),
452                                        _ => None,
453                                    };
454                                }
455                            }
456                        }
457                    }
458                }
459            }
460        }
461        None
462    }
463
464    /// Convert parameter to JSON schema property
465    fn to_json_schema(&self) -> String {
466        let param_type = if self.is_option {
467            format!(
468                "{{\"type\":[\"null\",{}]}}",
469                self.inner_type_to_json_schema()
470            )
471        } else if self.is_vec {
472            format!(
473                "{{\"type\":\"array\",\"items\":{}}}",
474                self.inner_type_to_json_schema()
475            )
476        } else {
477            format!("{{\"type\":{}}}", self.inner_type_to_json_schema())
478        };
479        format!("\"{}\":{}", self.name, param_type)
480    }
481
482    fn inner_type_to_json_schema(&self) -> String {
483        match self.inner_type.as_str() {
484            "i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64" | "u128"
485            | "f32" | "f64" => "\"number\"".to_string(),
486            "bool" => "\"boolean\"".to_string(),
487            "String" | "&str" => "\"string\"".to_string(),
488            _ => "\"object\"".to_string(),
489        }
490    }
491}
492
/// Extract the names of the path parameters declared in a route path.
///
/// Both the `:name` and the `{name}` placeholder styles are recognised. For axum catch-all
/// captures written `{*name}`, the leading `*` is not part of the capture name, so it is
/// stripped; this lets a function argument called `name` be matched as a path parameter.
/// Segments that start with neither `:` nor `{` are ignored.
///
/// ```text
/// "/users/:id/files/{*rest}"  ->  ["id", "rest"]
/// ```
fn extract_path_params(path: &str) -> Vec<String> {
    path.split('/')
        .filter(|segment| segment.starts_with(':') || segment.starts_with('{'))
        .map(|segment| {
            segment
                .trim_start_matches(':')
                .trim_start_matches('{')
                .trim_start_matches('*')
                .trim_end_matches('}')
                .to_string()
        })
        .collect()
}
508
509#[proc_macro_attribute]
510pub fn service_api(args: TokenStream, input: TokenStream) -> TokenStream {
511    let args = match parse_service_api_args(args.into()) {
512        Ok(args) => args,
513        Err(e) => return e.into_compile_error().into(),
514    };
515    let input = parse_macro_input!(input as ItemFn);
516
517    let (
518        name,
519        version,
520        description,
521        path,
522        method,
523        tool_name,
524        stream,
525        cache_ttl,
526        ws_path,
527        grpc_method,
528    ) = args;
529    let fn_name = &input.sig.ident;
530    let _fn_vis = &input.vis; // Currently unused but kept for future use
531    let return_type = &input.sig.output;
532
533    // Extract path parameters from path string
534    let path_params = path
535        .as_ref()
536        .map(|p| extract_path_params(p))
537        .unwrap_or_default();
538
539    // Collect all parameter names first to determine if we need Form extractor
540    let all_param_names: Vec<String> = input
541        .sig
542        .inputs
543        .iter()
544        .filter_map(|arg| {
545            if let FnArg::Typed(pat_type) = arg {
546                if let Pat::Ident(pat_ident) = &*pat_type.pat {
547                    return Some(pat_ident.ident.to_string());
548                }
549            }
550            None
551        })
552        .collect();
553
554    // Filter to get body params (non-path, non-Option params)
555    let body_param_names: Vec<String> = all_param_names
556        .iter()
557        .filter(|name| !path_params.contains(name))
558        .cloned()
559        .collect();
560
561    // Extract function parameters
562    let params: Vec<ParamInfo> = input
563        .sig
564        .inputs
565        .iter()
566        .filter_map(|arg| {
567            ParamInfo::from_arg(arg, &path_params, method.as_deref(), &body_param_names)
568        })
569        .collect();
570
571    // Check if there are any parameters
572    let has_params = !params.is_empty();
573
574    // Build parameter patterns based on type
575    let param_patterns: Vec<_> = params
576        .iter()
577        .map(|p| {
578            let name_ident = syn::Ident::new(&p.name, proc_macro2::Span::call_site());
579            let ty = &p.ty;
580            match p.param_kind {
581                ParamKind::Path => quote! { #name_ident: sdforge::axum::extract::Path<#ty> },
582                ParamKind::Query => quote! { #name_ident: sdforge::axum::extract::Query<#ty> },
583                ParamKind::Header => {
584                    quote! { #name_ident: sdforge::axum::extract::TypedHeader<#ty> }
585                }
586                ParamKind::Cookie => quote! { #name_ident: sdforge::axum::extract::Cookie },
587                ParamKind::Form => quote! { #name_ident: sdforge::axum::extract::Form<#ty> },
588                ParamKind::Body => quote! { #name_ident: sdforge::axum::extract::Json<#ty> },
589            }
590        })
591        .collect();
592
593    // Build parameter unwrapping logic
594    // All parameter types use the same unwrapping pattern: extract .0 field
595    let _param_unwraps: Vec<_> = params // Currently unused but kept for future use
596        .iter()
597        .map(|p| {
598            let name_ident = syn::Ident::new(&p.name, proc_macro2::Span::call_site());
599            // All parameter kinds use identical unwrapping: extract first element
600            quote! { let #name_ident = #name_ident.0; }
601        })
602        .collect();
603
604    let param_names: Vec<_> = params
605        .iter()
606        .map(|p| syn::Ident::new(&p.name, proc_macro2::Span::call_site()))
607        .collect();
608
609    // Collect parameter types for MCP tool struct generation
610    let param_types: Vec<_> = params.iter().map(|p| &p.ty).collect();
611
612    // Build MCP input schema
613    let mcp_schema_props: Vec<String> = params.iter().map(|p| p.to_json_schema()).collect();
614    let mcp_schema_required: Vec<String> = params
615        .iter()
616        .filter(|p| !p.is_option)
617        .map(|p| format!("\"{}\"", p.name))
618        .collect();
619
620    // Pre-compute properties JSON to avoid macro nesting issues
621    let mcp_properties_json = if mcp_schema_props.is_empty() {
622        quote! { serde_json::json!({}) }
623    } else {
624        let props_vec: Vec<TokenStream2> = mcp_schema_props
625            .iter()
626            .map(|s| s.parse().expect("valid JSON property"))
627            .collect();
628        quote! { serde_json::json!({ #(#props_vec),* }) }
629    };
630
631    // Pre-compute required array JSON
632    let mcp_required_json = if mcp_schema_required.is_empty() {
633        quote! { serde_json::json!([]) }
634    } else {
635        quote! { serde_json::json!([#(#mcp_schema_required),*]) }
636    };
637
638    // Generate unique handler name to avoid conflicts
639    let fn_name_str = fn_name.to_string();
640    let _handler_name = syn::Ident::new(
641        // Currently unused but kept for future use
642        &format!("__axiom_http_handler_{}", fn_name_str),
643        proc_macro2::Span::call_site(),
644    );
645
646    // Generate unique route registration function name
647    let register_fn_name = syn::Ident::new(
648        &format!("__axiom_register_{}", fn_name_str),
649        proc_macro2::Span::call_site(),
650    );
651
652    // Build HTTP path with version, converting :param to {param} for axum 0.8
653    let path_str = path.as_ref().cloned().unwrap_or_default();
654    // Convert :param to {param} for axum 0.8 compatibility
655    // e.g., "/users/:id" -> "/users/{id}"
656    let axum_path = path_str
657        .split('/')
658        .map(|segment| {
659            if let Some(stripped) = segment.strip_prefix(':') {
660                format!("{{{}}}", stripped)
661            } else {
662                segment.to_string()
663            }
664        })
665        .collect::<Vec<_>>()
666        .join("/");
667    let http_path = format!("/api/{}{}", version, axum_path);
668
669    // Build HTTP method
670    let http_method_upper = method.as_ref().unwrap_or(&"GET".to_string()).to_uppercase();
671    let http_method_lower = http_method_upper.to_lowercase();
672
673    // Convert cache_ttl to a proper expression for the quote macro
674    let cache_ttl_expr = match &cache_ttl {
675        Some(ttl) => quote! { Some(#ttl) },
676        None => quote! { None },
677    };
678
679    // Build description expression
680    let description_literal = description.as_deref().unwrap_or(&name);
681
682    // Generate HTTP code
683    let is_streaming = stream.unwrap_or(false);
684
685    let http_code = if path.is_some() && method.is_some() {
686        // Generate metadata tokens before the quote block
687        let streaming_metadata = match api_metadata_tokens(
688            quote! { #name },
689            quote! { #version },
690            quote! { #description_literal },
691            quote! { None },
692            quote! { true },
693        ) {
694            Ok(tokens) => tokens,
695            Err(e) => return e.into_compile_error().into(),
696        };
697
698        let non_streaming_metadata = match api_metadata_tokens(
699            quote! { #name },
700            quote! { #version },
701            quote! { #description_literal },
702            quote! { None },
703            quote! { false },
704        ) {
705            Ok(tokens) => tokens,
706            Err(e) => return e.into_compile_error().into(),
707        };
708
709        // Generate route creation function with inline handler closure
710        let route_creation = if is_streaming {
711            quote! {
712                fn #register_fn_name() -> sdforge::http::HttpRoute {
713                    sdforge::http::HttpRoute {
714                        path: #http_path.to_string(),
715                        handler: {
716                            let mut router = sdforge::axum::routing::MethodRouter::new();
717                            router = router.get(#(#param_patterns),* | #(#param_names.0),* | {
718                                async move {
719                                    use sdforge::prelude::*;
720                                    match #fn_name(#(#param_names.0),*).await {
721                                        Ok(_stream) => {
722                                            let body = sdforge::axum::body::Body::from_streaming_bytes(
723                                                tokio_stream::iter(vec![])
724                                            );
725                                            let response: sdforge::axum::response::Response = (
726                                                [(sdforge::axum::http::header::CONTENT_TYPE, "text/event-stream")],
727                                                body
728                                            ).into_response();
729                                            response
730                                        }
731                                        Err(e) => e.into_response(),
732                                    }
733                                }
734                            });
735                            router
736                        },
737                        metadata: #streaming_metadata,
738                        module_prefix: None,
739                    }
740                }
741            }
742        } else {
743            let is_result = match return_type {
744                syn::ReturnType::Type(_, ty) => {
745                    matches!(ty.as_ref(), syn::Type::Path(syn::TypePath { qself: None, path: syn::Path { segments, .. } }) if segments.iter().any(|s| s.ident == "Result"))
746                }
747                syn::ReturnType::Default => false,
748            };
749
750            let handler_closure = if is_result {
751                quote! {
752                    |#(#param_patterns),*| {
753                        async move {
754                            use sdforge::prelude::*;
755                            match #fn_name(#(#param_names.0),*).await {
756                                Ok(value) => sdforge::axum::extract::Json(value).into_response(),
757                                Err(e) => e.into_response(),
758                            }
759                        }
760                    }
761                }
762            } else {
763                quote! {
764                    |#(#param_patterns),*| {
765                        async move {
766                            use sdforge::prelude::*;
767                            let result = #fn_name(#(#param_names.0),*).await;
768                            sdforge::axum::extract::Json(result).into_response()
769                        }
770                    }
771                }
772            };
773
774            quote! {
775                fn #register_fn_name() -> sdforge::http::HttpRoute {
776                    sdforge::http::HttpRoute {
777                        path: #http_path.to_string(),
778                        handler: {
779                            let mut router = sdforge::axum::routing::MethodRouter::new();
780                            match #http_method_lower.as_ref() {
781                                "get" => router = router.get(#handler_closure),
782                                "post" => router = router.post(#handler_closure),
783                                "put" => router = router.put(#handler_closure),
784                                "delete" => router = router.delete(#handler_closure),
785                                "patch" => router = router.patch(#handler_closure),
786                                "head" => router = router.head(#handler_closure),
787                                "options" => router = router.options(#handler_closure),
788                                _ => router = router.get(#handler_closure),
789                            }
790                            router
791                        },
792                        metadata: #non_streaming_metadata,
793                        module_prefix: None,
794                    }
795                }
796            }
797        };
798
799        // Combine route creation function and registration
800        quote! {
801            #route_creation
802            sdforge::inventory::submit!(sdforge::http::RouteRegistration {
803                name: #name,
804                version: #version,
805                register_fn: #register_fn_name,
806            });
807        }
808    } else {
809        quote! {}
810    };
811
812    // Generate gRPC metadata tokens before the quote block
813    let grpc_metadata = match api_metadata_tokens(
814        quote! { #name },
815        quote! { #version },
816        quote! { #description_literal },
817        quote! { #cache_ttl_expr },
818        quote! { false },
819    ) {
820        Ok(tokens) => tokens,
821        Err(e) => return e.into_compile_error().into(),
822    };
823
824    let mcp_code = if let Some(ref tool_name) = tool_name {
825        let mcp_call_logic = if has_params {
826            quote! {
827                #[derive(serde::Deserialize)]
828                struct Params {
829                    #(pub #param_names: #param_types),*
830                }
831
832                let params: Params = match input {
833                    Some(v) => serde_json::from_value(v)
834                        .map_err(|e| anyhow::anyhow!("Failed to parse input: {}", e))?,
835                    None => {
836                        return Err(anyhow::anyhow!("Missing input parameters"));
837                    }
838                };
839
840                let result = #fn_name(#(params.#param_names),*).await;
841                Ok(result)
842            }
843        } else {
844            quote! {
845                let result = #fn_name().await;
846                Ok(result)
847            }
848        };
849
850        let mcp_tool_name = tool_name;
851        let mcp_tool_description = description.as_ref().unwrap_or(&name);
852        let mcp_struct_name = syn::Ident::new(
853            &format!("{}McpTool", fn_name),
854            proc_macro2::Span::call_site(),
855        );
856        let mcp_create_fn_name = syn::Ident::new(
857            &format!("__create_{}_mcp_tool", fn_name),
858            proc_macro2::Span::call_site(),
859        );
860
861        quote! {
862            #[cfg(feature = "mcp")]
863            #[derive(Debug)]
864            struct #mcp_struct_name;
865
866            #[cfg(feature = "mcp")]
867            impl #mcp_struct_name {
868                fn create() -> std::sync::Arc<dyn sdforge::mcp_sdk::tools::Tool> {
869                    std::sync::Arc::new(Self) as std::sync::Arc<dyn sdforge::mcp_sdk::tools::Tool>
870                }
871            }
872
873            #[cfg(feature = "mcp")]
874            impl sdforge::mcp_sdk::tools::Tool for #mcp_struct_name {
875                fn name(&self) -> String {
876                    #mcp_tool_name.to_string()
877                }
878
879                fn description(&self) -> String {
880                    #mcp_tool_description.to_string()
881                }
882
883                fn input_schema(&self) -> serde_json::Value {
884                    serde_json::json!({
885                        "type": "object",
886                        "properties": #mcp_properties_json,
887                        "required": #mcp_required_json
888                    })
889                }
890
891                fn call(&self, input: Option<serde_json::Value>) -> anyhow::Result<sdforge::mcp_sdk::types::CallToolResponse> {
892                    use sdforge::prelude::*;
893                    use tokio::runtime::Runtime;
894
895                    let rt = Runtime::new().map_err(|e| anyhow::anyhow!("Failed to create runtime: {}", e))?;
896                    let inner_result: Result<Result<_, ApiError>, anyhow::Error> = rt.block_on(async {
897                        #mcp_call_logic
898                    });
899                    let result = inner_result?;
900
901                    match result {
902                        Ok(response) => {
903                            let response_json = serde_json::to_value(response)
904                                .map_err(|e| anyhow::anyhow!("Failed to serialize response: {}", e))?;
905                            Ok(sdforge::mcp_sdk::types::CallToolResponse {
906                                content: vec![sdforge::mcp_sdk::types::ToolResponseContent::Text {
907                                    text: serde_json::to_string(&response_json)
908                                        .map_err(|e| anyhow::anyhow!("Failed to stringify response: {}", e))?,
909                                }],
910                                is_error: Some(false),
911                                meta: None,
912                            })
913                        }
914                        Err(e) => {
915                            let error_json = serde_json::to_value(e)
916                                .map_err(|e| anyhow::anyhow!("Failed to serialize error: {}", e))
917                                .unwrap_or_else(|_| {
918                                    serde_json::json!({
919                                        "success": false,
920                                        "error": {
921                                            "code": "UNKNOWN_ERROR",
922                                            "message": "An unknown error occurred"
923                                        }
924                                    })
925                                });
926                            Ok(sdforge::mcp_sdk::types::CallToolResponse {
927                                content: vec![sdforge::mcp_sdk::types::ToolResponseContent::Text {
928                                    text: serde_json::to_string(&error_json)
929                                        .map_err(|e| anyhow::anyhow!("Failed to stringify error: {}", e))?,
930                                }],
931                                is_error: Some(true),
932                                meta: None,
933                            })
934                        }
935                    }
936                }
937            }
938
939            #[cfg(feature = "mcp")]
940            fn #mcp_create_fn_name() -> std::sync::Arc<dyn sdforge::mcp_sdk::tools::Tool> {
941                #mcp_struct_name::create()
942            }
943
944            #[cfg(feature = "mcp")]
945            sdforge::inventory::submit!(sdforge::mcp::McpToolRegistration {
946                name: #mcp_tool_name,
947                version: #version,
948                description: #mcp_tool_description,
949                create_fn: #mcp_create_fn_name,
950            });
951        }
952    } else {
953        quote! {}
954    };
955
956    let ws_code = if ws_path.is_some() {
957        quote! {
958            #[cfg(feature = "websocket")]
959            sdforge::inventory::submit!(sdforge::websocket::WebSocketRoute {
960                path: #ws_path.unwrap().to_string(),
961                handler: #fn_name,
962            });
963        }
964    } else {
965        quote! {}
966    };
967
968    let grpc_code = if grpc_method.is_some() {
969        quote! {
970            #[cfg(feature = "grpc")]
971            sdforge::inventory::submit!(sdforge::grpc::GrpcRoute {
972                service_name: #name.to_string(),
973                metadata: #grpc_metadata,
974            });
975        }
976    } else {
977        quote! {}
978    };
979
980    let generated = quote! {
981        #input
982        #http_code
983        #mcp_code
984        #ws_code
985        #grpc_code
986    };
987
988    generated.into()
989}
990
991#[proc_macro_attribute]
992pub fn service_module(args: TokenStream, input: TokenStream) -> TokenStream {
993    let prefix = match parse_service_module_args(args.into()) {
994        Ok(prefix) => prefix,
995        Err(e) => return e.into_compile_error().into(),
996    };
997    let input = parse_macro_input!(input as ItemMod);
998
999    // Generate a constant for the module prefix
1000    let prefix_const = quote! {
1001        pub const MODULE_PREFIX: &str = #prefix;
1002    };
1003
1004    // Generate a helper function that applies the prefix
1005    let prefix_helper = quote! {
1006        #[inline]
1007        pub fn apply_prefix(path: &str) -> String {
1008            if path.starts_with('/') {
1009                format!("{}{}", MODULE_PREFIX, path)
1010            } else {
1011                format!("{}{}", MODULE_PREFIX, path)
1012            }
1013        }
1014    };
1015
1016    let expanded = quote! {
1017        #input
1018
1019        #prefix_const
1020        #prefix_helper
1021    };
1022
1023    expanded.into()
1024}
1025
1026#[proc_macro]
1027pub fn test_macro(input: TokenStream) -> TokenStream {
1028    let input = parse_macro_input!(input as ItemFn);
1029
1030    let fn_name = &input.sig.ident;
1031
1032    let expanded = quote! {
1033        #input
1034
1035        #[cfg(test)]
1036        mod #fn_name {
1037            use super::*;
1038
1039            #[test]
1040            fn test_generated() {
1041                println!("Test macro generated for: {}", stringify!(#fn_name));
1042            }
1043        }
1044    };
1045
1046    expanded.into()
1047}
1048
#[cfg(test)]
mod macro_parsing_tests {
    //! Unit tests for the argument parsers behind the attribute macros.
    //!
    //! The parsers are driven directly with `proc_macro2` token streams built
    //! by `quote!`, so no real macro expansion is needed.

    use super::*;

    /// Builds an owned `(key, value)` pair in the shape `parse_kv_pairs` yields.
    fn pair(key: &str, value: &str) -> (String, String) {
        (key.to_owned(), value.to_owned())
    }

    #[test]
    fn test_parse_kv_pairs_simple() {
        let tokens: TokenStream2 = quote! { name = "test" };
        let parsed = parse_kv_pairs(tokens).expect("a single key/value pair should parse");
        let expected = vec![pair("name", "test")];
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_kv_pairs_multiple() {
        let tokens: TokenStream2 = quote! { name = "test", version = "v1" };
        let parsed = parse_kv_pairs(tokens).expect("comma-separated pairs should parse");
        let expected = vec![pair("name", "test"), pair("version", "v1")];
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_service_api_args_required() {
        let tokens: TokenStream2 = quote! { name = "test", version = "v1" };
        let (name, version, ..) =
            parse_service_api_args(tokens).expect("name and version alone should be accepted");
        assert_eq!(name, "test");
        assert_eq!(version, "v1");
    }

    #[test]
    fn test_parse_service_module_args() {
        let tokens: TokenStream2 = quote! { prefix = "/api/v1" };
        let prefix =
            parse_service_module_args(tokens).expect("a prefix argument should parse");
        assert_eq!(prefix, "/api/v1");
    }
}