Skip to main content

traitclaw_macros/
lib.rs

1//! Proc macros for the `TraitClaw` AI Agent Framework.
2//!
3//! # Usage
4//!
5//! ```rust,ignore
6//! use traitclaw::Tool;
7//! use schemars::JsonSchema;
8//! use serde::Deserialize;
9//!
10//! #[derive(Tool, Deserialize, JsonSchema)]
11//! #[tool(description = "Search the web for information")]
12//! struct WebSearch {
13//!     /// The search query
14//!     query: String,
15//! }
16//!
17//! impl WebSearch {
18//!     async fn execute(&self) -> traitclaw_core::Result<serde_json::Value> {
19//!         Ok(serde_json::json!({"results": []}))
20//!     }
21//! }
22//! ```
23
24use proc_macro::TokenStream;
25use proc_macro2::TokenStream as TokenStream2;
26use quote::quote;
27use syn::{parse_macro_input, Data, DeriveInput};
28
29/// Derive the `ErasedTool` implementation boilerplate for a struct.
30///
31/// The struct itself acts as the tool's Input type and MUST derive
32/// `serde::Deserialize` and `schemars::JsonSchema`.
33///
34/// The user MUST provide an inherent `execute(&self) -> Result<serde_json::Value>` method.
35///
36/// # Attributes
37///
38/// - `#[tool(description = "...")]` — tool description (required)
39/// - `#[tool(name = "...")]` — override tool name (optional, defaults to snake_case)
40#[proc_macro_derive(Tool, attributes(tool))]
41pub fn derive_tool(input: TokenStream) -> TokenStream {
42    let input = parse_macro_input!(input as DeriveInput);
43    match expand_tool(input) {
44        Ok(ts) => ts.into(),
45        Err(e) => e.to_compile_error().into(),
46    }
47}
48
49fn expand_tool(input: DeriveInput) -> syn::Result<TokenStream2> {
50    // Validate struct
51    match &input.data {
52        Data::Struct(_) => {}
53        _ => {
54            return Err(syn::Error::new_spanned(
55                &input.ident,
56                "#[derive(Tool)] can only be applied to structs",
57            ));
58        }
59    }
60
61    let struct_name = &input.ident;
62
63    // Parse #[tool(...)] attributes
64    let mut description: Option<String> = None;
65    let mut name_override: Option<String> = None;
66
67    for attr in &input.attrs {
68        if !attr.path().is_ident("tool") {
69            continue;
70        }
71        attr.parse_nested_meta(|meta| {
72            if meta.path.is_ident("description") {
73                let value: syn::LitStr = meta.value()?.parse()?;
74                description = Some(value.value());
75            } else if meta.path.is_ident("name") {
76                let value: syn::LitStr = meta.value()?.parse()?;
77                name_override = Some(value.value());
78            }
79            Ok(())
80        })?;
81    }
82
83    let description = description.unwrap_or_else(|| to_title_case(&struct_name.to_string()));
84    let tool_name = name_override.unwrap_or_else(|| to_snake_case(&struct_name.to_string()));
85
86    let expanded = quote! {
87        // Inherent helper methods for static access
88        impl #struct_name {
89            /// Returns the statically known tool name.
90            pub fn tool_name() -> &'static str {
91                #tool_name
92            }
93
94            /// Returns the statically known tool description.
95            pub fn tool_description() -> &'static str {
96                #description
97            }
98
99            /// Generate the [`traitclaw_core::ToolSchema`] for this tool.
100            pub fn tool_schema() -> traitclaw_core::ToolSchema {
101                let schema = schemars::schema_for!(#struct_name);
102                traitclaw_core::ToolSchema {
103                    name: #tool_name.to_string(),
104                    description: #description.to_string(),
105                    parameters: serde_json::to_value(schema)
106                        .unwrap_or_else(|_| serde_json::Value::Object(Default::default())),
107                }
108            }
109        }
110
111        // ErasedTool impl — the struct IS the input type.
112        // The user must provide:
113        //   `async fn execute(&self) -> traitclaw_core::Result<serde_json::Value>`
114        #[async_trait::async_trait]
115        impl traitclaw_core::ErasedTool for #struct_name {
116            fn name(&self) -> &str {
117                #tool_name
118            }
119
120            fn description(&self) -> &str {
121                #description
122            }
123
124            fn schema(&self) -> traitclaw_core::ToolSchema {
125                #struct_name::tool_schema()
126            }
127
128            async fn execute_json(
129                &self,
130                input: serde_json::Value,
131            ) -> traitclaw_core::Result<serde_json::Value> {
132                let typed: #struct_name = serde_json::from_value(input)
133                    .map_err(|e| traitclaw_core::Error::tool_execution(
134                        #tool_name,
135                        format!("Invalid input: {e}"),
136                    ))?;
137                typed.execute().await
138            }
139        }
140    };
141
142    Ok(expanded)
143}
144
/// Convert a `PascalCase` identifier to `snake_case`.
///
/// An uppercase character starts a new word, and is preceded by `_`, when:
/// - it follows a lowercase letter or digit (`WebSearch` → `web_search`,
///   `Tool2Go` → `tool2_go`); or
/// - it ends an acronym run and is followed by a lowercase letter
///   (`HTTPClient` → `http_client`).
///
/// Other consecutive uppercase letters stay in one word (`GetURL` → `get_url`).
/// No extra `_` is inserted after an existing non-alphanumeric character such as `_`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len());
    for (i, &c) in chars.iter().enumerate() {
        if i > 0 && c.is_uppercase() {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
            let starts_word = prev.is_lowercase()
                || prev.is_numeric()
                || (prev.is_uppercase() && next_is_lower);
            if starts_word {
                result.push('_');
            }
        }
        result.extend(c.to_lowercase());
    }
    result
}
156
/// Convert a `PascalCase` identifier to `Title Case` (space-separated words).
///
/// An uppercase character starts a new word, and is preceded by a space, when:
/// - it follows a lowercase letter or digit (`WebSearch` → `Web Search`,
///   `Tool2Go` → `Tool2 Go`); or
/// - it ends an acronym run and is followed by a lowercase letter
///   (`HTTPClient` → `HTTP Client`).
///
/// Other consecutive uppercase letters stay in one word (`GetURL` → `Get URL`).
/// Character case is preserved as written.
fn to_title_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len());
    for (i, &c) in chars.iter().enumerate() {
        if i > 0 && c.is_uppercase() {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
            let starts_word = prev.is_lowercase()
                || prev.is_numeric()
                || (prev.is_uppercase() && next_is_lower);
            if starts_word {
                result.push(' ');
            }
        }
        result.push(c);
    }
    result
}