toolfront_macro/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::Parse, parse_macro_input, Ident, LitStr};
4
5struct ApiEndpoint {
6 client_name: Ident,
7 path: LitStr,
8 method: LitStr,
9 params_type: syn::Type,
10 response_type: syn::Type,
11 error_type: syn::Type,
12}
13
14impl Parse for ApiEndpoint {
15 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16 let client_name: Ident = input.parse()?;
17 input.parse::<syn::Token![,]>()?;
18 let path: LitStr = input.parse()?;
19 input.parse::<syn::Token![,]>()?;
20 let method: LitStr = input.parse()?;
21 input.parse::<syn::Token![,]>()?;
22 let params_type: syn::Type = input.parse()?;
23 input.parse::<syn::Token![,]>()?;
24 let response_type: syn::Type = input.parse()?;
25 input.parse::<syn::Token![,]>()?;
26 let error_type: syn::Type = input.parse()?;
27
28 Ok(ApiEndpoint {
29 client_name,
30 path,
31 method,
32 params_type,
33 response_type,
34 error_type,
35 })
36 }
37}
38
39/// Generate a type-safe API client for an OpenAPI endpoint. This macro is designed to work
40/// within an agent-based API automation system that uses RAG (Retrieval Augmented Generation)
41/// to find and execute relevant API endpoints.
42///
43/// # System Overview
44///
45/// The typical workflow:
46/// 1. OpenAPI specs are downloaded and stored in the database
47/// 2. Endpoints are extracted and embedded for RAG retrieval
48/// 3. When a natural language task arrives, relevant endpoints are retrieved
49/// 4. This macro generates type-safe clients for those endpoints
50///
51/// # Usage Example
52///
53/// ```rust
54/// use serde::{Serialize, Deserialize};
55/// use pgvector::Vector;
56/// use uuid::Uuid;
57///
58/// // Define your custom error type
59/// #[derive(Debug, thiserror::Error)]
60/// pub enum AgentError {
61/// #[error("API request failed: {0}")]
62/// Request(#[from] reqwest::Error),
63/// #[error("JSON error: {0}")]
64/// Json(#[from] serde_json::Error),
65/// // ... other error variants as needed
66/// }
67///
68/// // Define your request and response types
69/// #[derive(Debug, Serialize)]
70/// struct SearchUsersParams {
71/// query: String,
72/// max_results: i32,
73/// include_inactive: bool,
74/// }
75///
76/// #[derive(Debug, Deserialize)]
77/// struct UserSearchResponse {
78/// users: Vec<User>,
79/// total_count: i32,
80/// page_token: Option<String>,
81/// }
82///
83/// // Generate the client with your custom error type
84/// generate_client!(
85/// UserSearchClient, // Name for the generated client
86/// "/api/v1/users/search", // Endpoint path
87/// "POST", // HTTP method
88/// SearchUsersParams, // Parameters type
89/// UserSearchResponse, // Response type
90/// AgentError // Your custom error type
91/// );
92///
93/// // Example usage in an agent system
94/// struct Agent {
95/// openai: OpenAIClient,
96/// db: PgPool,
97/// }
98///
99/// impl Agent {
100/// async fn execute_task(&self, task: &str) -> Result<serde_json::Value, AgentError> {
101/// // Find relevant endpoint using RAG
102/// let endpoint = find_relevant_endpoint(&self.db, task).await?;
103///
104/// // Generate parameters using LLM
105/// let params = self.generate_parameters(task).await?;
106///
107/// // Execute the API call using our generated client
108/// let client = UserSearchClient::new("https://api.example.com".to_string());
109/// let response = client.execute(params).await?;
110///
111/// Ok(serde_json::to_value(response)?)
112/// }
113/// }
114/// ```
115///
116/// # Parameters
117///
118/// * `client_name`: The name of the generated client struct
119/// * `path`: The endpoint path template (e.g., "/users/{id}/posts")
120/// * `method`: The HTTP method as a string (e.g., "GET", "POST")
121/// * `params_type`: The request parameters type (must implement Serialize)
122/// * `response_type`: The response type (must implement Deserialize)
123/// * `error_type`: Your custom error type that implements From<reqwest::Error> and From<serde_json::Error>
124///
125/// # Generated Client
126///
127/// The macro generates a client struct with:
128/// - Constructor for base URL configuration
129/// - Type-safe execute method that handles:
130/// - Path parameter substitution
131/// - Request body serialization
132/// - Response deserialization
133/// - Error conversion to your custom type
134///
135/// # Error Handling
136///
137/// The generated client returns `Result<T, E>` where E is your custom error type.
138/// Your error type must implement:
139/// ```rust
140/// impl From<reqwest::Error> for YourErrorType { ... }
141/// impl From<serde_json::Error> for YourErrorType { ... }
142/// ```
143///
144/// Common error cases that will be converted to your error type:
145/// - URL construction failures
146/// - Network errors from reqwest
147/// - Non-200 HTTP responses
148/// - JSON serialization/deserialization errors
149#[proc_macro]
150pub fn generate_client(input: TokenStream) -> TokenStream {
151 let ApiEndpoint {
152 client_name,
153 path,
154 method,
155 params_type,
156 response_type,
157 error_type,
158 } = parse_macro_input!(input as ApiEndpoint);
159
160 let generated = quote! {
161 pub struct #client_name {
162 base_url: String,
163 client: reqwest::Client,
164 }
165
166 impl #client_name {
167 pub fn new(base_url: String) -> Self {
168 Self {
169 base_url,
170 client: reqwest::Client::new(),
171 }
172 }
173
174 pub async fn execute(
175 &self,
176 params: #params_type
177 ) -> Result<#response_type, #error_type> {
178 use reqwest::Method;
179 use serde_json::Value;
180 use std::convert::TryFrom;
181
182 // Handle path parameter substitution
183 let mut url = format!("{}{}", self.base_url, #path);
184 let params_json = serde_json::to_value(¶ms)
185 .map_err(|e| std::convert::Into::into(e))?;
186
187 if let Value::Object(obj) = ¶ms_json {
188 for (key, value) in obj {
189 let pattern = format!("{{{}}}", key);
190 if url.contains(&pattern) {
191 url = url.replace(&pattern, &value.as_str().unwrap_or_default());
192 }
193 }
194 }
195
196 let method = Method::from_bytes(#method.as_bytes())
197 .map_err(|_| std::convert::Into::into(
198 reqwest::Error::from(std::io::Error::new(
199 std::io::ErrorKind::InvalidInput,
200 "Invalid HTTP method"
201 ))
202 ))?;
203
204 let response = self.client
205 .request(method, &url)
206 .json(¶ms)
207 .send()
208 .await
209 .map_err(std::convert::Into::into)?;
210
211 if !response.status().is_success() {
212 return Err(std::convert::Into::into(
213 response.error_for_status().unwrap_err()
214 ));
215 }
216
217 response.json::<#response_type>()
218 .await
219 .map_err(std::convert::Into::into)
220 }
221 }
222 };
223
224 generated.into()
225}