strands_agents_macros/
lib.rs1use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, Attribute, FnArg, ItemFn, Pat};
6
7#[proc_macro_attribute]
30pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
31 let input_fn = parse_macro_input!(item as ItemFn);
32
33 let fn_name = &input_fn.sig.ident;
34 let fn_name_str = fn_name.to_string();
35 let struct_name = format_ident!("{}Tool", to_pascal_case(&fn_name_str));
36
37 let description = extract_doc_comment(&input_fn.attrs);
38
39 let params: Vec<_> = input_fn
40 .sig
41 .inputs
42 .iter()
43 .filter_map(|arg| {
44 if let FnArg::Typed(pat_type) = arg {
45 if let Pat::Ident(pat_ident) = &*pat_type.pat {
46 let name = pat_ident.ident.to_string();
47 if name == "tool_context" || name == "agent" {
48 return None;
49 }
50 let ty = &pat_type.ty;
51 return Some((name, quote!(#ty).to_string()));
52 }
53 }
54 None
55 })
56 .collect();
57
58 let param_schemas: Vec<_> = params
59 .iter()
60 .map(|(name, _ty)| {
61 quote! {
62 properties.insert(
63 #name.to_string(),
64 serde_json::json!({
65 "type": "string",
66 "description": format!("Parameter {}", #name)
67 })
68 );
69 }
70 })
71 .collect();
72
73 let required_params: Vec<_> = params
74 .iter()
75 .filter(|(_, ty)| !ty.contains("Option"))
76 .map(|(name, _)| name.clone())
77 .collect();
78
79 let is_async = input_fn.sig.asyncness.is_some();
80
81 let param_names: Vec<_> = params
82 .iter()
83 .map(|(name, _)| format_ident!("{}", name))
84 .collect();
85
86 let execute_call = if is_async {
87 quote! { #fn_name(#(#param_names),*).await }
88 } else {
89 quote! { #fn_name(#(#param_names),*) }
90 };
91
92 let param_extractions: Vec<_> = params
93 .iter()
94 .map(|(name, ty)| {
95 let ident = format_ident!("{}", name);
96 if ty.contains("Option") {
97 quote! {
98 let #ident = tool_use.input.get(#name)
99 .and_then(|v| serde_json::from_value(v.clone()).ok());
100 }
101 } else {
102 quote! {
103 let #ident = tool_use.input.get(#name)
104 .and_then(|v| serde_json::from_value(v.clone()).ok())
105 .ok_or_else(|| format!("Missing required parameter: {}", #name))?;
106 }
107 }
108 })
109 .collect();
110
111 let expanded = quote! {
112 #input_fn
113
114 #[derive(Clone)]
115 pub struct #struct_name;
116
117 impl #struct_name {
118 pub fn new() -> Self { Self }
119 }
120
121 impl Default for #struct_name {
122 fn default() -> Self { Self::new() }
123 }
124
125 impl strands::tools::AgentTool for #struct_name {
126 fn tool_name(&self) -> &str {
127 #fn_name_str
128 }
129
130 fn tool_spec(&self) -> strands::types::tools::ToolSpec {
131 let mut properties = std::collections::HashMap::new();
132 #(#param_schemas)*
133
134 let required: Vec<String> = vec![#(#required_params.to_string()),*];
135
136 strands::types::tools::ToolSpec {
137 name: #fn_name_str.to_string(),
138 description: #description.to_string(),
139 input_schema: strands::types::tools::InputSchema {
140 json: serde_json::json!({
141 "type": "object",
142 "properties": properties,
143 "required": required
144 }),
145 },
146 output_schema: None,
147 }
148 }
149
150 fn tool_type(&self) -> &str {
151 "function"
152 }
153
154 fn stream(
155 &self,
156 tool_use: &strands::types::tools::ToolUse,
157 _invocation_state: &strands::tools::InvocationState,
158 ) -> strands::tools::ToolGenerator {
159 let tool_use = tool_use.clone();
160 Box::pin(async_stream::stream! {
161 let result: Result<String, String> = (|| async {
162 #(#param_extractions)*
163 let output = #execute_call;
164 Ok(output.to_string())
165 })().await;
166
167 let tool_result = match result {
168 Ok(text) => strands::types::tools::ToolResult::success(
169 &tool_use.tool_use_id,
170 text,
171 ),
172 Err(e) => strands::types::tools::ToolResult::error(
173 &tool_use.tool_use_id,
174 e,
175 ),
176 };
177 yield strands::tools::ToolEvent::Result(tool_result);
178 })
179 }
180 }
181 };
182
183 TokenStream::from(expanded)
184}
185
186fn extract_doc_comment(attrs: &[Attribute]) -> String {
187 let mut doc_lines = Vec::new();
188
189 for attr in attrs {
190 if attr.path().is_ident("doc") {
191 if let syn::Meta::NameValue(meta) = &attr.meta {
192 if let syn::Expr::Lit(expr_lit) = &meta.value {
193 if let syn::Lit::Str(lit_str) = &expr_lit.lit {
194 doc_lines.push(lit_str.value().trim().to_string());
195 }
196 }
197 }
198 }
199 }
200
201 let mut description = Vec::new();
202 for line in doc_lines {
203 let lower = line.to_lowercase();
204 if lower.starts_with("# arg") || lower.starts_with("args:") || lower.starts_with("arguments:") {
205 break;
206 }
207 description.push(line);
208 }
209
210 description.join(" ").trim().to_string()
211}
212
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`; the first character of every non-empty piece is
/// upper-cased (using full Unicode upper-casing, which may yield more than one
/// character) and the remainder of the piece is copied through unchanged.
/// Empty pieces produced by leading, trailing or repeated underscores
/// contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for piece in s.split('_') {
        let mut rest = piece.chars();
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}