toolfront_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::Parse, parse_macro_input, Ident, LitStr};
4
5struct ApiEndpoint {
6    client_name: Ident,
7    path: LitStr,
8    method: LitStr,
9    params_type: syn::Type,
10    response_type: syn::Type,
11    error_type: syn::Type,
12}
13
14impl Parse for ApiEndpoint {
15    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16        let client_name: Ident = input.parse()?;
17        input.parse::<syn::Token![,]>()?;
18        let path: LitStr = input.parse()?;
19        input.parse::<syn::Token![,]>()?;
20        let method: LitStr = input.parse()?;
21        input.parse::<syn::Token![,]>()?;
22        let params_type: syn::Type = input.parse()?;
23        input.parse::<syn::Token![,]>()?;
24        let response_type: syn::Type = input.parse()?;
25        input.parse::<syn::Token![,]>()?;
26        let error_type: syn::Type = input.parse()?;
27
28        Ok(ApiEndpoint {
29            client_name,
30            path,
31            method,
32            params_type,
33            response_type,
34            error_type,
35        })
36    }
37}
38
39/// Generate a type-safe API client for an OpenAPI endpoint. This macro is designed to work
40/// within an agent-based API automation system that uses RAG (Retrieval Augmented Generation)
41/// to find and execute relevant API endpoints.
42///
43/// # System Overview
44///
45/// The typical workflow:
46/// 1. OpenAPI specs are downloaded and stored in the database
47/// 2. Endpoints are extracted and embedded for RAG retrieval
48/// 3. When a natural language task arrives, relevant endpoints are retrieved
49/// 4. This macro generates type-safe clients for those endpoints
50///
51/// # Usage Example
52///
53/// ```rust
54/// use serde::{Serialize, Deserialize};
55/// use pgvector::Vector;
56/// use uuid::Uuid;
57///
58/// // Define your custom error type
59/// #[derive(Debug, thiserror::Error)]
60/// pub enum AgentError {
61///     #[error("API request failed: {0}")]
62///     Request(#[from] reqwest::Error),
63///     #[error("JSON error: {0}")]
64///     Json(#[from] serde_json::Error),
65///     // ... other error variants as needed
66/// }
67///
68/// // Define your request and response types
69/// #[derive(Debug, Serialize)]
70/// struct SearchUsersParams {
71///     query: String,
72///     max_results: i32,
73///     include_inactive: bool,
74/// }
75///
76/// #[derive(Debug, Deserialize)]
77/// struct UserSearchResponse {
78///     users: Vec<User>,
79///     total_count: i32,
80///     page_token: Option<String>,
81/// }
82///
83/// // Generate the client with your custom error type
84/// generate_client!(
85///     UserSearchClient,                // Name for the generated client
86///     "/api/v1/users/search",         // Endpoint path
87///     "POST",                         // HTTP method
88///     SearchUsersParams,              // Parameters type
89///     UserSearchResponse,             // Response type
90///     AgentError                      // Your custom error type
91/// );
92///
93/// // Example usage in an agent system
94/// struct Agent {
95///     openai: OpenAIClient,
96///     db: PgPool,
97/// }
98///
99/// impl Agent {
100///     async fn execute_task(&self, task: &str) -> Result<serde_json::Value, AgentError> {
101///         // Find relevant endpoint using RAG
102///         let endpoint = find_relevant_endpoint(&self.db, task).await?;
103///         
104///         // Generate parameters using LLM
105///         let params = self.generate_parameters(task).await?;
106///         
107///         // Execute the API call using our generated client
108///         let client = UserSearchClient::new("https://api.example.com".to_string());
109///         let response = client.execute(params).await?;
110///         
111///         Ok(serde_json::to_value(response)?)
112///     }
113/// }
114/// ```
115///
116/// # Parameters
117///
118/// * `client_name`: The name of the generated client struct
119/// * `path`: The endpoint path template (e.g., "/users/{id}/posts")
120/// * `method`: The HTTP method as a string (e.g., "GET", "POST")
121/// * `params_type`: The request parameters type (must implement Serialize)
122/// * `response_type`: The response type (must implement Deserialize)
123/// * `error_type`: Your custom error type that implements From<reqwest::Error> and From<serde_json::Error>
124///
125/// # Generated Client
126///
127/// The macro generates a client struct with:
128/// - Constructor for base URL configuration
129/// - Type-safe execute method that handles:
130///   - Path parameter substitution
131///   - Request body serialization
132///   - Response deserialization
133///   - Error conversion to your custom type
134///
135/// # Error Handling
136///
137/// The generated client returns `Result<T, E>` where E is your custom error type.
138/// Your error type must implement:
139/// ```rust
140/// impl From<reqwest::Error> for YourErrorType { ... }
141/// impl From<serde_json::Error> for YourErrorType { ... }
142/// ```
143///
144/// Common error cases that will be converted to your error type:
145/// - URL construction failures
146/// - Network errors from reqwest
147/// - Non-200 HTTP responses
148/// - JSON serialization/deserialization errors
149#[proc_macro]
150pub fn generate_client(input: TokenStream) -> TokenStream {
151    let ApiEndpoint {
152        client_name,
153        path,
154        method,
155        params_type,
156        response_type,
157        error_type,
158    } = parse_macro_input!(input as ApiEndpoint);
159
160    let generated = quote! {
161        pub struct #client_name {
162            base_url: String,
163            client: reqwest::Client,
164        }
165
166        impl #client_name {
167            pub fn new(base_url: String) -> Self {
168                Self {
169                    base_url,
170                    client: reqwest::Client::new(),
171                }
172            }
173
174            pub async fn execute(
175                &self,
176                params: #params_type
177            ) -> Result<#response_type, #error_type> {
178                use reqwest::Method;
179                use serde_json::Value;
180                use std::convert::TryFrom;
181
182                // Handle path parameter substitution
183                let mut url = format!("{}{}", self.base_url, #path);
184                let params_json = serde_json::to_value(&params)
185                    .map_err(|e| std::convert::Into::into(e))?;
186
187                if let Value::Object(obj) = &params_json {
188                    for (key, value) in obj {
189                        let pattern = format!("{{{}}}", key);
190                        if url.contains(&pattern) {
191                            url = url.replace(&pattern, &value.as_str().unwrap_or_default());
192                        }
193                    }
194                }
195
196                let method = Method::from_bytes(#method.as_bytes())
197                    .map_err(|_| std::convert::Into::into(
198                        reqwest::Error::from(std::io::Error::new(
199                            std::io::ErrorKind::InvalidInput,
200                            "Invalid HTTP method"
201                        ))
202                    ))?;
203
204                let response = self.client
205                    .request(method, &url)
206                    .json(&params)
207                    .send()
208                    .await
209                    .map_err(std::convert::Into::into)?;
210
211                if !response.status().is_success() {
212                    return Err(std::convert::Into::into(
213                        response.error_for_status().unwrap_err()
214                    ));
215                }
216
217                response.json::<#response_type>()
218                    .await
219                    .map_err(std::convert::Into::into)
220            }
221        }
222    };
223
224    generated.into()
225}