tools_core/
lib.rs

1#![deny(unsafe_code)]
2
3use core::fmt;
4use std::{borrow::Cow, collections::HashMap, sync::Arc};
5
6use futures::{FutureExt, future::BoxFuture};
7use once_cell::sync::Lazy;
8use serde::{Deserialize, Deserializer, Serialize, Serializer, de::DeserializeOwned};
9use serde_json::{self, Value, to_string_pretty};
10
11// Re-export once_cell
12pub use once_cell;
13
14// ============================================================================
15// TOOL SCHEMA TRAIT AND IMPLEMENTATIONS
16// ============================================================================
17
18/// Trait for types that can generate a JSON Schema representation of themselves.
19pub trait ToolSchema {
20    fn schema() -> Value;
21}
22
23// Macro for implementing ToolSchema for primitive types with caching
24macro_rules! prim {
25    ($ty:ty, $name:expr) => {
26        impl ToolSchema for $ty {
27            fn schema() -> Value {
28                static SCHEMA: Lazy<Value> = Lazy::new(|| serde_json::json!({ "type": $name }));
29                SCHEMA.clone()
30            }
31        }
32    };
33}
34
35prim!(bool, "boolean");
36prim!(i8, "integer");
37prim!(i16, "integer");
38prim!(i32, "integer");
39prim!(i64, "integer");
40prim!(i128, "integer");
41prim!(isize, "integer");
42prim!(u8, "integer");
43prim!(u16, "integer");
44prim!(u32, "integer");
45prim!(u64, "integer");
46prim!(u128, "integer");
47prim!(usize, "integer");
48prim!(f32, "number");
49prim!(f64, "number");
50
51impl ToolSchema for &'_ str {
52    fn schema() -> Value {
53        static SCHEMA: Lazy<Value> = Lazy::new(|| serde_json::json!({ "type": "string" }));
54        SCHEMA.clone()
55    }
56}
57
58impl ToolSchema for str {
59    fn schema() -> Value {
60        static SCHEMA: Lazy<Value> = Lazy::new(|| serde_json::json!({ "type": "string" }));
61        SCHEMA.clone()
62    }
63}
64
65impl ToolSchema for String {
66    fn schema() -> Value {
67        static SCHEMA: Lazy<Value> = Lazy::new(|| serde_json::json!({ "type": "string" }));
68        SCHEMA.clone()
69    }
70}
71
72impl ToolSchema for () {
73    fn schema() -> Value {
74        static SCHEMA: Lazy<Value> = Lazy::new(|| serde_json::json!({ "type": "null" }));
75        SCHEMA.clone()
76    }
77}
78
79impl<T: ToolSchema> ToolSchema for Option<T> {
80    fn schema() -> Value {
81        // Note: For generic types, we can't use static caching since each T creates a different type
82        // The derived implementations will handle caching for concrete types
83        serde_json::json!({
84            "anyOf": [
85                T::schema(),
86                { "type": "null" }
87            ]
88        })
89    }
90}
91
92impl<T: ToolSchema> ToolSchema for Vec<T> {
93    fn schema() -> Value {
94        // Note: For generic types, we can't use static caching since each T creates a different type
95        // The derived implementations will handle caching for concrete types
96        serde_json::json!({
97            "type": "array",
98            "items": T::schema()
99        })
100    }
101}
102
103impl<T: ToolSchema> ToolSchema for HashMap<String, T> {
104    fn schema() -> Value {
105        // Note: For generic types, we can't use static caching since each T creates a different type
106        // The derived implementations will handle caching for concrete types
107        serde_json::json!({
108            "type": "object",
109            "additionalProperties": T::schema()
110        })
111    }
112}
113
114// Tuple implementations
115macro_rules! impl_tuples {
116    ($($len:expr => ($($n:tt $name:ident)+))+) => {
117        $(
118            impl<$($name: ToolSchema),+> ToolSchema for ($($name,)+) {
119                fn schema() -> Value {
120                    // Note: For generic tuples, we can't use static caching since each combination
121                    // of types creates a different tuple type. The derived implementations will
122                    // handle caching for concrete tuple types.
123                    serde_json::json!({
124                        "type": "array",
125                        "prefixItems": [$($name::schema()),+],
126                        "minItems": $len,
127                        "maxItems": $len
128                    })
129                }
130            }
131        )+
132    }
133}
134
135impl_tuples! {
136    1 => (0 T0)
137    2 => (0 T0 1 T1)
138    3 => (0 T0 1 T1 2 T2)
139    4 => (0 T0 1 T1 2 T2 3 T3)
140    5 => (0 T0 1 T1 2 T2 3 T3 4 T4)
141    6 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5)
142    7 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6)
143    8 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7)
144    9 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8)
145    10 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9)
146    11 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10)
147    12 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11)
148    13 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12)
149    14 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13)
150    15 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14)
151    16 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15)
152    17 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16)
153    18 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17)
154    19 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18)
155    20 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19)
156    21 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20)
157    22 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21)
158    23 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22)
159    24 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23)
160    25 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23 24 T24)
161    26 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23 24 T24 25 T25)
162    27 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23 24 T24 25 T25 26 T26)
163    28 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23 24 T24 25 T25 26 T26 27 T27)
164    29 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23 24 T24 25 T25 26 T26 27 T27 28 T28)
165    30 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15 16 T16 17 T17 18 T18 19 T19 20 T20 21 T21 22 T22 23 T23 24 T24 25 T25 26 T26 27 T27 28 T28 29 T29)
166}
167
168// ============================================================================
169// ERROR TYPES
170// ============================================================================
171
172/// Errors that can occur during tool operations
173#[derive(Debug, thiserror::Error)]
174pub enum ToolError {
175    #[error("Tool function '{name}' not found")]
176    FunctionNotFound { name: Cow<'static, str> },
177
178    #[error("Tool function '{name}' is already registered")]
179    AlreadyRegistered { name: &'static str },
180
181    #[error("Deserialization error: {0}")]
182    Deserialize(#[from] DeserializationError),
183
184    #[error("JSON serialization error: {0}")]
185    Serialization(#[from] serde_json::Error),
186
187    #[error("Runtime error: {0}")]
188    Runtime(String),
189}
190
191/// Specific deserialization errors
192#[derive(Debug, thiserror::Error)]
193#[error("Failed to deserialize JSON: {source}")]
194pub struct DeserializationError {
195    #[source]
196    pub source: serde_json::Error,
197}
198
199impl From<serde_json::Error> for DeserializationError {
200    fn from(err: serde_json::Error) -> Self {
201        DeserializationError { source: err }
202    }
203}
204
205// ============================================================================
206// CORE MODELS
207// ============================================================================
208
209/// Represents a function call with name and arguments
210#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
211pub struct FunctionCall {
212    pub id: Option<CallId>,
213    pub name: String,
214    pub arguments: Value,
215}
216
217impl FunctionCall {
218    pub fn new(name: String, arguments: Value) -> FunctionCall {
219        FunctionCall {
220            id: Some(CallId::new()),
221            name,
222            arguments,
223        }
224    }
225}
226
227#[derive(Debug, Clone, PartialEq, Eq, Hash)]
228pub struct CallId(uuid::Uuid);
229
230impl CallId {
231    pub fn new() -> CallId {
232        CallId(uuid::Uuid::new_v4())
233    }
234}
235
236impl Default for CallId {
237    fn default() -> Self {
238        CallId::new()
239    }
240}
241
242impl<'de> Deserialize<'de> for CallId {
243    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
244    where
245        D: Deserializer<'de>,
246    {
247        let s = String::deserialize(deserializer)?;
248        let uuid = uuid::Uuid::parse_str(&s).map_err(serde::de::Error::custom)?;
249        Ok(CallId(uuid))
250    }
251}
252
253impl Serialize for CallId {
254    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
255    where
256        S: Serializer,
257    {
258        serializer.serialize_str(&self.0.to_string())
259    }
260}
261
262impl fmt::Display for CallId {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        write!(f, "{}", self.0)
265    }
266}
267
268/// Represents a function response with name and arguments
269#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
270pub struct FunctionResponse {
271    pub id: Option<CallId>,
272    pub name: String,
273    pub result: Value,
274}
275
276impl fmt::Display for FunctionResponse {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        let id_str = self
279            .id
280            .as_ref()
281            .map(|id| id.to_string())
282            .unwrap_or_else(|| "<none>".to_string());
283
284        let pretty_result =
285            to_string_pretty(&self.result).unwrap_or_else(|_| "<invalid json>".to_string());
286
287        write!(
288            f,
289            "FunctionResponse {{\n  id: {},\n  name: \"{}\",\n  result: {}\n}}",
290            id_str,
291            self.name,
292            pretty_result.replace("\n", "\n  ") // indent JSON
293        )
294    }
295}
296
297/// Function signature for tools
298pub type ToolFunc = dyn Fn(Value) -> BoxFuture<'static, Result<Value, ToolError>> + Send + Sync;
299
300/// Metadata about a tool function
301#[derive(Debug, Clone)]
302pub struct ToolMetadata {
303    pub name: &'static str,
304    pub description: &'static str,
305}
306
307/// Runtime type signature information
308#[derive(Debug, Clone)]
309pub struct TypeSignature {
310    pub input_type: &'static str,
311    pub output_type: &'static str,
312}
313
314/// Tool registration for inventory collection
315pub struct ToolRegistration {
316    pub name: &'static str,
317    pub doc: &'static str,
318    pub f: fn(Value) -> BoxFuture<'static, Result<Value, ToolError>>,
319    pub param_schema: fn() -> Value,
320}
321
322impl ToolRegistration {
323    pub const fn new(
324        name: &'static str,
325        doc: &'static str,
326        f: fn(Value) -> BoxFuture<'static, Result<Value, ToolError>>,
327        param_schema: fn() -> Value,
328    ) -> Self {
329        Self {
330            name,
331            doc,
332            f,
333            param_schema,
334        }
335    }
336}
337
338/// Represents a tool that can be called
339#[derive(Debug, Clone)]
340pub struct Tool {
341    pub metadata: ToolMetadata,
342    pub signature: TypeSignature,
343}
344
345// ============================================================================
346// SCHEMA GENERATION
347// ============================================================================
348
349/// Function declaration for LLM consumption
350#[derive(Debug, Clone, PartialEq, Eq, Serialize, serde::Deserialize)]
351pub struct FunctionDecl<'a> {
352    #[serde(borrow)]
353    pub name: &'a str,
354    #[serde(borrow)]
355    pub description: &'a str,
356    pub parameters: Value,
357}
358
359impl<'a> FunctionDecl<'a> {
360    pub fn new(name: &'a str, description: &'a str, parameters: Value) -> Self {
361        Self {
362            name,
363            description,
364            parameters,
365        }
366    }
367}
368
369// ============================================================================
370// TOOL COLLECTION
371// ============================================================================
372
373fn schema_value<T: ToolSchema>() -> Result<Value, ToolError> {
374    Ok(T::schema())
375}
376
377#[derive(Default, Clone)]
378pub struct ToolCollection {
379    funcs: HashMap<&'static str, Arc<ToolFunc>>,
380    descriptions: HashMap<&'static str, &'static str>,
381    signatures: HashMap<&'static str, TypeSignature>,
382    declarations: HashMap<&'static str, FunctionDecl<'static>>,
383}
384
385impl ToolCollection {
386    pub fn new() -> Self {
387        Self::default()
388    }
389
390    pub fn register<I, O, F, Fut>(
391        &mut self,
392        name: &'static str,
393        desc: &'static str,
394        func: F,
395    ) -> Result<&mut Self, ToolError>
396    where
397        I: 'static + DeserializeOwned + Serialize + Send + ToolSchema,
398        O: 'static + Serialize + Send + ToolSchema,
399        F: Fn(I) -> Fut + Send + Sync + 'static,
400        Fut: std::future::Future<Output = O> + Send + 'static,
401    {
402        if self.funcs.contains_key(name) {
403            return Err(ToolError::AlreadyRegistered { name });
404        }
405
406        self.descriptions.insert(name, desc);
407
408        self.declarations
409            .insert(name, FunctionDecl::new(name, desc, schema_value::<I>()?));
410
411        let func_arc: Arc<F> = Arc::new(func);
412        self.funcs.insert(
413            name,
414            Arc::new(
415                move |raw: Value| -> BoxFuture<'static, Result<Value, ToolError>> {
416                    let func = func_arc.clone();
417                    async move {
418                        let input: I =
419                            serde_json::from_value(raw).map_err(DeserializationError::from)?;
420                        let output: O = (func)(input).await;
421                        serde_json::to_value(output).map_err(|e| ToolError::Runtime(e.to_string()))
422                    }
423                    .boxed()
424                },
425            ),
426        );
427
428        Ok(self)
429    }
430
431    pub async fn call(&self, call: FunctionCall) -> Result<FunctionResponse, ToolError> {
432        let FunctionCall {
433            id,
434            name,
435            arguments,
436        } = call;
437        let async_func = self
438            .funcs
439            .get(name.as_str())
440            .ok_or(ToolError::FunctionNotFound {
441                name: Cow::Owned(name.clone()),
442            })?;
443
444        let result = async_func(arguments).await?;
445        Ok(FunctionResponse { id, name, result })
446    }
447
448    pub fn unregister(&mut self, name: &str) -> Result<(), ToolError> {
449        if self.funcs.remove(name).is_none() {
450            return Err(ToolError::FunctionNotFound {
451                name: Cow::Owned(name.to_string()),
452            });
453        }
454        self.descriptions.remove(name);
455        self.signatures.remove(name);
456        self.declarations.remove(name);
457        Ok(())
458    }
459
460    pub fn descriptions(&self) -> impl Iterator<Item = (&'static str, &'static str)> + '_ {
461        self.descriptions.iter().map(|(k, v)| (*k, *v))
462    }
463
464    pub fn collect_tools() -> Self {
465        let mut hub = Self::new();
466
467        for reg in inventory::iter::<ToolRegistration> {
468            hub.descriptions.insert(reg.name, reg.doc);
469            hub.funcs.insert(reg.name, Arc::new(reg.f));
470
471            hub.declarations.insert(
472                reg.name,
473                FunctionDecl::new(reg.name, reg.doc, (reg.param_schema)()),
474            );
475        }
476
477        hub
478    }
479
480    pub fn json(&self) -> Result<Value, ToolError> {
481        let list: Vec<&FunctionDecl> = self.declarations.values().collect();
482        Ok(serde_json::to_value(list)?)
483    }
484}
485
486inventory::collect!(ToolRegistration);
487
488// ============================================================================
489// TESTS
490// ============================================================================
491
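// Schema tests commented out due to a circular dependency with the derive macro.
// NOTE(review): only the `#[derive(ToolSchema)]` cases below (newtypes, `UserProfile`)
// actually need the derive; the primitive/Option/Vec/tuple/HashMap cases do not.
// The newtype expectations use a one-element `prefixItems` array. serde_json, however,
// encodes a newtype struct as its bare inner value by default. Confirm the derive's
// output matches what `serde_json::from_value` accepts before re-enabling these tests.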
493// #[cfg(test)]
494// mod schema_tests {
495//     use super::*;
496//     use serde_json::json;
497
498//     #[test]
499//     fn test_primitive_schemas() {
500//         assert_eq!(bool::schema(), json!({"type": "boolean"}));
501//         assert_eq!(i32::schema(), json!({"type": "integer"}));
502//         assert_eq!(f64::schema(), json!({"type": "number"}));
503//         assert_eq!(String::schema(), json!({"type": "string"}));
504//         assert_eq!(<()>::schema(), json!({"type": "null"}));
505//     }
506
507//     #[test]
508//     fn test_option_schema() {
509//         assert_eq!(
510//             <Option<i32>>::schema(),
511//             json!({
512//                 "anyOf": [
513//                     {"type": "integer"},
514//                     {"type": "null"}
515//                 ]
516//             })
517//         );
518//     }
519
520//     #[test]
521//     fn test_vec_schema() {
522//         assert_eq!(
523//             <Vec<String>>::schema(),
524//             json!({"type": "array", "items": {"type": "string"}})
525//         );
526//     }
527
528//     #[test]
529//     fn test_tuple_schemas() {
530//         assert_eq!(
531//             <(i32,)>::schema(),
532//             json!({
533//                 "type": "array",
534//                 "prefixItems": [{"type": "integer"}],
535//                 "minItems": 1,
536//                 "maxItems": 1
537//             })
538//         );
539
540//         assert_eq!(
541//             <(i32, String)>::schema(),
542//             json!({
543//                 "type": "array",
544//                 "prefixItems": [{"type": "integer"}, {"type": "string"}],
545//                 "minItems": 2,
546//                 "maxItems": 2
547//             })
548//         );
549//     }
550
551//     #[test]
552//     fn test_hashmap_schema() {
553//         assert_eq!(
554//             <HashMap<String, i32>>::schema(),
555//             json!({
556//                 "type": "object",
557//                 "additionalProperties": {"type": "integer"}
558//             })
559//         );
560//     }
561
562//     #[derive(serde::Serialize, serde::Deserialize, ToolSchema)]
563//     struct UserId(u64);
564
565//     #[derive(serde::Serialize, serde::Deserialize, ToolSchema)]
566//     struct Email(String);
567
568//     #[derive(serde::Serialize, serde::Deserialize, ToolSchema)]
569//     struct Temperature(f64);
570
571//     #[derive(serde::Serialize, serde::Deserialize, ToolSchema)]
572//     struct Count(usize);
573
574//     #[test]
575//     fn test_newtype_schemas() {
576//         assert_eq!(
577//             UserId::schema(),
578//             json!({
579//                 "type": "array",
580//                 "prefixItems": [{"type": "integer"}],
581//                 "minItems": 1,
582//                 "maxItems": 1
583//             })
584//         );
585
586//         assert_eq!(
587//             Email::schema(),
588//             json!({
589//                 "type": "array",
590//                 "prefixItems": [{"type": "string"}],
591//                 "minItems": 1,
592//                 "maxItems": 1
593//             })
594//         );
595
596//         assert_eq!(
597//             Temperature::schema(),
598//             json!({
599//                 "type": "array",
600//                 "prefixItems": [{"type": "number"}],
601//                 "minItems": 1,
602//                 "maxItems": 1
603//             })
604//         );
605
606//         assert_eq!(
607//             Count::schema(),
608//             json!({
609//                 "type": "array",
610//                 "prefixItems": [{"type": "integer"}],
611//                 "minItems": 1,
612//                 "maxItems": 1
613//             })
614//         );
615//     }
616
617//     #[derive(serde::Serialize, serde::Deserialize, ToolSchema)]
618//     struct UserProfile {
619//         id: UserId,
620//         email: Email,
621//         name: String,
622//         age: Option<u32>,
623//     }
624
625//     #[test]
626//     fn test_newtype_in_struct() {
627//         let expected = json!({
628//             "type": "object",
629//             "properties": {
630//                 "id": {"type": "array", "prefixItems": [{"type": "integer"}], "minItems": 1, "maxItems": 1},
631//                 "email": {"type": "array", "prefixItems": [{"type": "string"}], "minItems": 1, "maxItems": 1},
632//                 "name": {"type": "string"},
633//                 "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}
634//             },
635//             "required": ["id", "email", "name"]
636//         });
637
638//         assert_eq!(UserProfile::schema(), expected);
639//     }
640// }
641
#[cfg(test)]
mod tool_tests {
    use super::*;
    use serde::Deserialize;
    use serde_json::{self, json};

    fn add<T: std::ops::Add<Output = T> + Copy>(a: T, b: T) -> T {
        a + b
    }
    fn concat<T: std::fmt::Display>(a: T, b: T) -> String {
        format!("{}{}", a, b)
    }
    fn noop() {}
    // async fn async_foo() {}

    // Kept for the `complex_args` test, which stays disabled until `ToolSchema` can be
    // derived in this crate's own tests.
    #[allow(dead_code)]
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct SomeArgs {
        a: i32,
        b: i32,
    }
    // fn using_args(_a: SomeArgs) {}

    /// Shorthand for a `FunctionCall` with a freshly generated id.
    fn fc(name: &str, args: serde_json::Value) -> FunctionCall {
        FunctionCall::new(name.to_string(), args)
    }

    #[tokio::test]
    async fn test_collection() {
        let mut collection = ToolCollection::default();

        collection
            .register("add", "Adds two values", |t: (i32, i32)| async move {
                add(t.0, t.1)
            })
            .unwrap();
        collection
            .register(
                "concat",
                "Concatenates two strings",
                |t: (String, String)| async move { concat(t.0, t.1) },
            )
            .unwrap();
        collection
            .register("noop", "Does nothing", |_t: ()| async move { noop() })
            .unwrap();
        // Complex args test commented out due to ToolSchema derive requirement
        // collection
        //     .register(
        //         "complex_args",
        //         "Uses complex args",
        //         |t: SomeArgs| async move { using_args(t) },
        //     )
        //     .unwrap();

        assert_eq!(
            collection
                .call(fc("add", json!([1, 2])))
                .await
                .unwrap()
                .result,
            json!(3)
        );
        assert_eq!(
            collection
                .call(fc("concat", json!(["hello", "world"])))
                .await
                .unwrap()
                .result,
            json!("helloworld")
        );
        assert_eq!(
            collection
                .call(fc("noop", json!(null)))
                .await
                .unwrap()
                .result,
            json!(null)
        );
        // Complex args test commented out due to ToolSchema derive requirement
        // assert_eq!(
        //     collection
        //         .call(fc("complex_args", json!({ "a": 1, "b": 2 })))
        //         .await
        //         .unwrap(),
        //     json!(null)
        // );
    }

    #[tokio::test]
    async fn test_boolean_function() {
        let mut col = ToolCollection::default();
        col.register(
            "is_even",
            "Checks even",
            |t: (i32,)| async move { t.0 % 2 == 0 },
        )
        .unwrap();

        assert_eq!(
            col.call(fc("is_even", json!([4]))).await.unwrap().result,
            json!(true)
        );
        assert_eq!(
            col.call(fc("is_even", json!([3]))).await.unwrap().result,
            json!(false)
        );
    }

    // Complex return test commented out due to ToolSchema derive requirement
    // #[derive(Serialize, Deserialize, Debug, PartialEq, ToolSchema)]
    // struct Point {
    //     x: i32,
    //     y: i32,
    // }

    // #[tokio::test]
    // async fn test_complex_return() {
    //     let mut col = ToolCollection::default();
    //     col.register(
    //         "create_point",
    //         "Creates a point",
    //         |t: (i32, i32)| async move { Point { x: t.0, y: t.1 } },
    //     )
    //     .unwrap();

    //     assert_eq!(
    //         col.call(fc("create_point", json!([10, 20]))).await.unwrap(),
    //         json!({ "x": 10, "y": 20 })
    //     );
    // }

    #[tokio::test]
    async fn test_invalid_function_name() {
        let mut col = ToolCollection::default();
        col.register("dummy", "does nothing", |_: ()| async {})
            .unwrap();

        let err = col.call(fc("ghost", json!([]))).await.unwrap_err();
        assert!(matches!(err, ToolError::FunctionNotFound { .. }));
    }

    #[tokio::test]
    async fn test_deserialization_error() {
        let mut col = ToolCollection::default();
        col.register("subtract", "Sub two numbers", |t: (i32, i32)| async move {
            t.0 - t.1
        })
        .unwrap();

        let err = col
            .call(fc("subtract", json!(["a", "b"]))) // bad types → error
            .await
            .unwrap_err();

        assert!(matches!(err, ToolError::Deserialize(_)));
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let mut col = ToolCollection::default();
        col.register("dup", "first", |_: ()| async {}).unwrap();

        // `ToolCollection` is not `Debug`, so go through `.err()` instead of `unwrap_err`.
        let err = col
            .register("dup", "second", |_: ()| async {})
            .err()
            .expect("registering a taken name must fail");
        assert!(matches!(err, ToolError::AlreadyRegistered { name: "dup" }));

        // The first registration must be left intact.
        let descriptions: HashMap<_, _> = col.descriptions().collect();
        assert_eq!(descriptions.get("dup"), Some(&"first"));
    }

    #[tokio::test]
    async fn test_unregister_removes_tool() {
        let mut col = ToolCollection::default();
        col.register("double", "Doubles a value", |t: (i32,)| async move {
            t.0 * 2
        })
        .unwrap();
        assert_eq!(
            col.call(fc("double", json!([21]))).await.unwrap().result,
            json!(42)
        );

        col.unregister("double").unwrap();
        let err = col.call(fc("double", json!([21]))).await.unwrap_err();
        assert!(matches!(err, ToolError::FunctionNotFound { .. }));
        assert_eq!(col.descriptions().count(), 0);
        assert_eq!(col.json().unwrap(), json!([]));

        // Removing an unknown tool is reported, not silently ignored.
        assert!(matches!(
            col.unregister("double"),
            Err(ToolError::FunctionNotFound { .. })
        ));
    }

    #[test]
    fn test_json_declarations() {
        let mut col = ToolCollection::default();
        col.register("add", "Adds two values", |t: (i32, i32)| async move {
            add(t.0, t.1)
        })
        .unwrap();

        assert_eq!(
            col.json().unwrap(),
            json!([{
                "name": "add",
                "description": "Adds two values",
                "parameters": {
                    "type": "array",
                    "prefixItems": [{ "type": "integer" }, { "type": "integer" }],
                    "minItems": 2,
                    "maxItems": 2
                }
            }])
        );
    }

    #[tokio::test]
    async fn test_response_echoes_call_id() {
        let mut col = ToolCollection::default();
        col.register("noop", "Does nothing", |_: ()| async {}).unwrap();

        let call = fc("noop", json!(null));
        let id = call.id.clone();
        assert!(id.is_some());

        let response = col.call(call).await.unwrap();
        assert_eq!(response.id, id);
        assert_eq!(response.name, "noop");
    }
}
799
// Schema tests for primitive and std container types, plus timing benchmarks.
//
// The timing-based tests are `#[ignore]`d by default because wall-clock thresholds are
// unreliable on loaded CI machines and in unoptimized builds. Run them explicitly with
// `cargo test -- --ignored`.
#[cfg(test)]
mod performance_tests {
    use super::*;
    use serde_json::json;
    use std::time::Instant;

    #[test]
    fn test_schema_caching_primitives() {
        // Repeated calls must return equal schemas. Each call hands back a clone of the
        // cached value, so this checks consistency rather than pointer identity.
        let schema1 = String::schema();
        let schema2 = String::schema();
        assert_eq!(schema1, schema2);

        let int_schema1 = i32::schema();
        let int_schema2 = i32::schema();
        assert_eq!(int_schema1, int_schema2);

        let bool_schema1 = bool::schema();
        let bool_schema2 = bool::schema();
        assert_eq!(bool_schema1, bool_schema2);
    }

    #[test]
    #[ignore = "timing-sensitive; run with `cargo test -- --ignored`"]
    fn test_schema_performance_primitive() {
        // Warm up the cache
        let _ = String::schema();

        let start = Instant::now();
        for _ in 0..1000 {
            let _ = String::schema();
        }
        let cached_duration = start.elapsed();

        // Cached calls should be very fast (< 10ms for 1000 calls)
        assert!(
            cached_duration.as_millis() < 10,
            "Cached schema calls took too long: {:?}",
            cached_duration
        );
    }

    #[test]
    #[ignore = "timing-sensitive; run with `cargo test -- --ignored`"]
    fn test_schema_performance_multiple_primitives() {
        let _ = f64::schema(); // Warm up

        let start = Instant::now();
        for _ in 0..1000 {
            let _ = f64::schema();
            let _ = u64::schema();
            let _ = bool::schema();
        }
        let cached_duration = start.elapsed();

        assert!(
            cached_duration.as_millis() < 20,
            "Cached primitive schema calls took too long: {:?}",
            cached_duration
        );
    }

    #[test]
    fn test_primitive_schema_content_correctness() {
        assert_eq!(String::schema(), json!({ "type": "string" }));
        assert_eq!(<&str>::schema(), json!({ "type": "string" }));
        assert_eq!(i32::schema(), json!({ "type": "integer" }));
        assert_eq!(u64::schema(), json!({ "type": "integer" }));
        assert_eq!(f64::schema(), json!({ "type": "number" }));
        assert_eq!(bool::schema(), json!({ "type": "boolean" }));
        assert_eq!(<()>::schema(), json!({ "type": "null" }));
    }

    #[test]
    fn test_container_schemas() {
        // These impls have no other live coverage (the schema_tests module is disabled).
        assert_eq!(
            <Option<i32>>::schema(),
            json!({ "anyOf": [{ "type": "integer" }, { "type": "null" }] })
        );
        assert_eq!(
            <Vec<String>>::schema(),
            json!({ "type": "array", "items": { "type": "string" } })
        );
        assert_eq!(
            <HashMap<String, i32>>::schema(),
            json!({ "type": "object", "additionalProperties": { "type": "integer" } })
        );
        assert_eq!(
            <(i32,)>::schema(),
            json!({
                "type": "array",
                "prefixItems": [{ "type": "integer" }],
                "minItems": 1,
                "maxItems": 1
            })
        );
        assert_eq!(
            <(i32, String)>::schema(),
            json!({
                "type": "array",
                "prefixItems": [{ "type": "integer" }, { "type": "string" }],
                "minItems": 2,
                "maxItems": 2
            })
        );
    }

    #[test]
    fn test_concurrent_schema_access() {
        use std::thread;

        let handles: Vec<_> = (0..10)
            .map(|_| {
                thread::spawn(|| {
                    // Each thread reads primitive schemas repeatedly.
                    for _ in 0..100 {
                        let _ = String::schema();
                        let _ = i32::schema();
                        let _ = bool::schema();
                        let _ = f64::schema();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // The schema must still be correct after concurrent access.
        let schema = String::schema();
        assert_eq!(schema["type"], "string");
    }

    #[test]
    fn test_unit_type_caching() {
        let unit_type_schema1 = <()>::schema();
        let unit_type_schema2 = <()>::schema();
        assert_eq!(unit_type_schema1, unit_type_schema2);
        assert_eq!(unit_type_schema1["type"], "null");
    }

    #[test]
    #[ignore = "benchmark; run with `cargo test -- --ignored --nocapture`"]
    fn benchmark_primitive_schema_generation() {
        const ITERATIONS: usize = 10_000;

        let start = Instant::now();
        for _ in 0..ITERATIONS {
            let _ = String::schema();
        }
        let string_duration = start.elapsed();

        let start = Instant::now();
        for _ in 0..ITERATIONS {
            let _ = i32::schema();
        }
        let int_duration = start.elapsed();

        let start = Instant::now();
        for _ in 0..ITERATIONS {
            let _ = bool::schema();
        }
        let bool_duration = start.elapsed();

        println!("Primitive schema generation performance (cached):");
        println!("  String ({} calls): {:?}", ITERATIONS, string_duration);
        println!("  Integer ({} calls): {:?}", ITERATIONS, int_duration);
        println!("  Boolean ({} calls): {:?}", ITERATIONS, bool_duration);

        assert!(string_duration.as_millis() < 100);
        assert!(int_duration.as_millis() < 100);
        assert!(bool_duration.as_millis() < 100);
    }
}