Skip to main content

zai_rs/toolkits/
core.rs

1//! Core traits and types with enhanced type safety
2
3use std::{
4    borrow::Cow,
5    collections::{HashMap, hash_map::DefaultHasher},
6    hash::{Hash, Hasher},
7    sync::Arc,
8};
9
10use async_trait::async_trait;
11use jsonschema;
12use once_cell::sync::Lazy;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15
16use crate::toolkits::error::{ToolResult, error_context};
17
18/// Type-erased tool trait for dynamic dispatch
19#[async_trait]
20pub trait DynTool: Send + Sync {
21    /// Get the tool's metadata
22    fn metadata(&self) -> &ToolMetadata;
23
24    /// Execute with JSON input/output
25    async fn execute_json(&self, input: serde_json::Value) -> ToolResult<serde_json::Value>;
26
27    /// Get input schema
28    fn input_schema(&self) -> serde_json::Value;
29
30    /// Get the tool name
31    fn name(&self) -> &str {
32        &self.metadata().name
33    }
34
35    /// Clone the tool as a boxed trait object
36    fn clone_box(&self) -> Box<dyn DynTool>;
37}
38
39/// Global schema cache for compiled JSON schemas
40static SCHEMA_CACHE: Lazy<RwLock<HashMap<u64, Arc<jsonschema::Validator>>>> =
41    Lazy::new(|| RwLock::new(HashMap::new()));
42
43/// Enhanced tool metadata with better type information and memory optimization
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ToolMetadata {
46    /// Tool name (must be unique)
47    pub name: Cow<'static, str>,
48
49    /// Tool description
50    pub description: Cow<'static, str>,
51
52    /// Tool version
53    pub version: Cow<'static, str>,
54
55    /// Tool author
56    pub author: Option<Cow<'static, str>>,
57
58    /// Tool tags for categorization
59    pub tags: Vec<Cow<'static, str>>,
60
61    /// Whether the tool is enabled
62    pub enabled: bool,
63
64    /// Additional metadata
65    pub metadata: HashMap<Cow<'static, str>, serde_json::Value>,
66}
67
68impl ToolMetadata {
69    /// Create new metadata with validation
70    pub fn new(name: impl Into<String>, description: impl Into<String>) -> ToolResult<Self> {
71        let name = name.into();
72        let description = description.into();
73
74        // Validate tool name
75        if name.trim().is_empty() {
76            return Err(error_context().invalid_parameters("Tool name cannot be empty"));
77        }
78        if name.contains(|c: char| !c.is_alphanumeric() && c != '_') {
79            return Err(error_context()
80                .invalid_parameters("Tool name must be alphanumeric with underscores only"));
81        }
82
83        Ok(Self {
84            name: Cow::Owned(name),
85            description: Cow::Owned(description),
86            version: Cow::Borrowed("1.0.0"),
87            author: None,
88            tags: Vec::new(),
89            enabled: true,
90            metadata: HashMap::new(),
91        })
92    }
93
94    /// Builder pattern methods
95    pub fn version(mut self, version: impl Into<Cow<'static, str>>) -> Self {
96        self.version = version.into();
97        self
98    }
99
100    pub fn author(mut self, author: impl Into<Cow<'static, str>>) -> Self {
101        self.author = Some(author.into());
102        self
103    }
104
105    pub fn tags<T: Into<Cow<'static, str>>>(mut self, tags: impl IntoIterator<Item = T>) -> Self {
106        self.tags = tags.into_iter().map(Into::into).collect();
107        self
108    }
109
110    pub fn enabled(mut self, enabled: bool) -> Self {
111        self.enabled = enabled;
112        self
113    }
114
115    pub fn with_metadata(
116        mut self,
117        key: impl Into<Cow<'static, str>>,
118        value: serde_json::Value,
119    ) -> Self {
120        self.metadata.insert(key.into(), value);
121        self
122    }
123}
124
125/// Helper functions for type conversions (avoiding orphan rule issues)
126pub mod conversions {
127    use crate::toolkits::error::{ToolResult, error_context};
128
129    /// Convert a value to JSON
130    pub fn to_json<T: serde::Serialize>(value: T) -> ToolResult<serde_json::Value> {
131        serde_json::to_value(value).map_err(|e| error_context().serialization_error(e))
132    }
133
134    /// Extract string from JSON value
135    pub fn from_json_string(value: serde_json::Value) -> ToolResult<String> {
136        match value {
137            serde_json::Value::String(s) => Ok(s),
138            _ => Err(error_context().invalid_parameters("Expected string value")),
139        }
140    }
141
142    /// Extract i32 from JSON value
143    pub fn from_json_i32(value: serde_json::Value) -> ToolResult<i32> {
144        match value {
145            serde_json::Value::Number(n) => n
146                .as_i64()
147                .and_then(|i| i.try_into().ok())
148                .ok_or_else(|| error_context().invalid_parameters("Expected i32 value")),
149            _ => Err(error_context().invalid_parameters("Expected number value")),
150        }
151    }
152
153    /// Extract f64 from JSON value
154    pub fn from_json_f64(value: serde_json::Value) -> ToolResult<f64> {
155        match value {
156            serde_json::Value::Number(n) => n
157                .as_f64()
158                .ok_or_else(|| error_context().invalid_parameters("Expected f64 value")),
159            _ => Err(error_context().invalid_parameters("Expected number value")),
160        }
161    }
162
163    /// Extract bool from JSON value
164    pub fn from_json_bool(value: serde_json::Value) -> ToolResult<bool> {
165        match value {
166            serde_json::Value::Bool(b) => Ok(b),
167            _ => Err(error_context().invalid_parameters("Expected boolean value")),
168        }
169    }
170}
171
172// -----------------------------
173// Single-struct dynamic FunctionTool
174// -----------------------------
175
176/// Type alias for the complex handler type to reduce complexity warnings
177type ToolHandler = std::sync::Arc<
178    dyn Fn(
179            serde_json::Value,
180        ) -> std::pin::Pin<
181            Box<
182                dyn std::future::Future<
183                        Output = crate::toolkits::error::ToolResult<serde_json::Value>,
184                    > + Send,
185            >,
186        > + Send
187        + Sync,
188>;
189
190/// A single-struct tool that carries metadata, JSON schema, and an async
191/// handler
192pub struct FunctionTool {
193    metadata: ToolMetadata,
194    input_schema: serde_json::Value,
195    compiled_schema: Arc<jsonschema::Validator>,
196    handler: ToolHandler,
197}
198
199impl Clone for FunctionTool {
200    fn clone(&self) -> Self {
201        Self {
202            metadata: self.metadata.clone(),
203            input_schema: self.input_schema.clone(),
204            compiled_schema: Arc::clone(&self.compiled_schema),
205            handler: self.handler.clone(),
206        }
207    }
208}
209
210impl FunctionTool {
211    pub fn builder(name: impl Into<String>, description: impl Into<String>) -> FunctionToolBuilder {
212        FunctionToolBuilder::new(name, description)
213    }
214    /// Convenience: build a FunctionTool directly from a full JSON schema and a
215    /// handler
216    pub fn from_schema<F, Fut>(
217        name: impl Into<String>,
218        description: impl Into<String>,
219        schema: serde_json::Value,
220        f: F,
221    ) -> crate::toolkits::error::ToolResult<FunctionTool>
222    where
223        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
224        Fut: std::future::Future<Output = crate::toolkits::error::ToolResult<serde_json::Value>>
225            + Send
226            + 'static,
227    {
228        Self::builder(name, description)
229            .schema(schema)
230            .handler(f)
231            .build()
232    }
233    /// Build a FunctionTool from a full JSON spec (supports two shapes):
234    /// 1) {"name":..., "description":..., "parameters": {...}}
235    /// 2) {"type":"function", "function": {"name":..., "description":...,
236    ///    "parameters": {...}}}
237    pub fn from_function_spec<F, Fut>(
238        spec: serde_json::Value,
239        f: F,
240    ) -> crate::toolkits::error::ToolResult<FunctionTool>
241    where
242        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
243        Fut: std::future::Future<Output = crate::toolkits::error::ToolResult<serde_json::Value>>
244            + Send
245            + 'static,
246    {
247        let (name, description, parameters) = parse_function_spec_details(&spec)?;
248        let mut builder = Self::builder(name, description);
249        if let Some(p) = parameters {
250            builder = builder.schema(p);
251        }
252        builder.handler(f).build()
253    }
254
255    /// Read a JSON function spec from a file and build a FunctionTool.
256    pub fn from_function_spec_file<F, Fut>(
257        path: impl AsRef<std::path::Path>,
258        f: F,
259    ) -> crate::toolkits::error::ToolResult<FunctionTool>
260    where
261        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
262        Fut: std::future::Future<Output = crate::toolkits::error::ToolResult<serde_json::Value>>
263            + Send
264            + 'static,
265    {
266        let content = std::fs::read_to_string(path).map_err(|e| {
267            error_context().invalid_parameters(format!("Failed to read spec file: {}", e))
268        })?;
269        let spec: serde_json::Value = serde_json::from_str(&content)
270            .map_err(|e| error_context().invalid_parameters(format!("Invalid JSON: {}", e)))?;
271        Self::from_function_spec(spec, f)
272    }
273}
274
275/// Compile JSON schema with caching for better performance
276fn compile_schema_cached(schema: &serde_json::Value) -> ToolResult<Arc<jsonschema::Validator>> {
277    let mut hasher = DefaultHasher::new();
278    schema.to_string().hash(&mut hasher);
279    let hash = hasher.finish();
280
281    // Check cache first
282    {
283        let cache = SCHEMA_CACHE.read();
284        if let Some(cached) = cache.get(&hash) {
285            return Ok(Arc::clone(cached));
286        }
287    }
288
289    // Compile and cache
290    let validator = jsonschema::validator_for(schema).map_err(|e| {
291        error_context().schema_validation(format!("Failed to compile schema: {}", e))
292    })?;
293
294    let validator = Arc::new(validator);
295
296    {
297        let mut cache = SCHEMA_CACHE.write();
298        cache.insert(hash, Arc::clone(&validator));
299    }
300
301    Ok(validator)
302}
303
304/// (internal) Parses the name, description, and parameters from a JSON function
305/// spec.
306pub(crate) fn parse_function_spec_details(
307    spec: &serde_json::Value,
308) -> crate::toolkits::error::ToolResult<(String, String, Option<serde_json::Value>)> {
309    use serde_json::Value;
310    let obj = match spec {
311        Value::Object(map) => map,
312        _ => return Err(error_context().invalid_parameters("Function spec must be a JSON object")),
313    };
314    // Shape 2 with outer {type:function, function:{...}}
315    let (name, desc, params) = if obj.get("type").and_then(|v| v.as_str()) == Some("function") {
316        let f = obj
317            .get("function")
318            .and_then(|v| v.as_object())
319            .ok_or_else(|| error_context().invalid_parameters("Missing 'function' object"))?;
320        let name = f
321            .get("name")
322            .and_then(|v| v.as_str())
323            .ok_or_else(|| error_context().invalid_parameters("Missing function.name"))?
324            .to_string();
325        let desc = f
326            .get("description")
327            .and_then(|v| v.as_str())
328            .unwrap_or("")
329            .to_string();
330        let params = f.get("parameters").cloned();
331        (name, desc, params)
332    } else {
333        // Shape 1 inner {name, description, parameters}
334        let name = obj
335            .get("name")
336            .and_then(|v| v.as_str())
337            .ok_or_else(|| error_context().invalid_parameters("Missing name"))?
338            .to_string();
339        let desc = obj
340            .get("description")
341            .and_then(|v| v.as_str())
342            .unwrap_or("")
343            .to_string();
344        let params = obj.get("parameters").cloned();
345        (name, desc, params)
346    };
347    Ok((name, desc, params))
348}
349
350/// Builder for FunctionTool
351pub struct FunctionToolBuilder {
352    metadata: ToolMetadata,
353    input_schema: Option<serde_json::Value>,
354    // Optional staged schema pieces for convenience building when schema() is omitted or for
355    // merging
356    staged_properties: Option<serde_json::Map<String, serde_json::Value>>,
357    staged_required: Vec<String>,
358    handler: Option<ToolHandler>,
359}
360
361impl FunctionToolBuilder {
362    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
363        Self {
364            metadata: ToolMetadata::new(name, description).unwrap_or_else(|_| ToolMetadata {
365                name: Cow::Borrowed("unknown"),
366                description: Cow::Borrowed("unknown"),
367                version: Cow::Borrowed("1.0.0"),
368                author: None,
369                tags: Vec::new(),
370                enabled: true,
371                metadata: HashMap::new(),
372            }),
373            input_schema: None,
374            staged_properties: None,
375            staged_required: Vec::new(),
376            handler: None,
377        }
378    }
379
380    pub fn schema(mut self, schema: serde_json::Value) -> Self {
381        self.input_schema = Some(schema);
382        self
383    }
384
385    pub fn metadata(mut self, f: impl FnOnce(ToolMetadata) -> ToolMetadata) -> Self {
386        self.metadata = f(self.metadata);
387        self
388    }
389
390    pub fn handler<F, Fut>(mut self, f: F) -> Self
391    where
392        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
393        Fut: std::future::Future<Output = crate::toolkits::error::ToolResult<serde_json::Value>>
394            + Send
395            + 'static,
396    {
397        let wrapped = move |args: serde_json::Value| -> std::pin::Pin<
398            Box<
399                dyn std::future::Future<
400                        Output = crate::toolkits::error::ToolResult<serde_json::Value>,
401                    > + Send,
402            >,
403        > { Box::pin(f(args)) };
404        self.handler = Some(std::sync::Arc::new(wrapped));
405        self
406    }
407
408    /// Chain API: add one property to the schema. If `schema(json!(...))` is
409    /// also provided, the property will be merged into its `properties`
410    /// object.
411    pub fn property(mut self, name: impl Into<String>, schema: serde_json::Value) -> Self {
412        let name = name.into();
413        let entry = self
414            .staged_properties
415            .get_or_insert_with(serde_json::Map::new);
416        entry.insert(name, schema);
417        self
418    }
419
420    /// Chain API: mark a property as required. Will be merged with any provided
421    /// schema's `required`.
422    pub fn required(mut self, name: impl Into<String>) -> Self {
423        self.staged_required.push(name.into());
424        self
425    }
426
427    pub fn build(mut self) -> crate::toolkits::error::ToolResult<FunctionTool> {
428        let handler = self
429            .handler
430            .ok_or_else(|| error_context().invalid_parameters("FunctionTool handler not set"))?;
431        // Start with provided schema or an empty object to fill
432        let mut schema = self
433            .input_schema
434            .take()
435            .unwrap_or_else(|| serde_json::json!({}));
436
437        // If schema is an object, we can augment it; otherwise leave it as-is
438        if let serde_json::Value::Object(ref mut obj) = schema {
439            // Ensure required base shape
440            obj.entry("type")
441                .or_insert(serde_json::Value::String("object".to_string()));
442            obj.entry("additionalProperties")
443                .or_insert(serde_json::Value::Bool(false));
444
445            // Merge staged properties (if any)
446            if let Some(staged) = self.staged_properties.take() {
447                let props = obj
448                    .entry("properties")
449                    .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
450                if let serde_json::Value::Object(props_obj) = props {
451                    for (k, v) in staged {
452                        props_obj.insert(k, v);
453                    }
454                }
455            }
456            // Merge staged required (if any), de-duplicated
457            if !self.staged_required.is_empty() {
458                use std::collections::BTreeSet;
459                let mut set: BTreeSet<String> = obj
460                    .get("required")
461                    .and_then(|v| v.as_array())
462                    .map(|arr| {
463                        arr.iter()
464                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
465                            .collect()
466                    })
467                    .unwrap_or_default();
468                for r in self.staged_required.into_iter() {
469                    set.insert(r);
470                }
471                obj.insert(
472                    "required".to_string(),
473                    serde_json::Value::Array(
474                        set.into_iter().map(serde_json::Value::String).collect(),
475                    ),
476                );
477            }
478        } else {
479            // If schema is not an object and also not provided, enforce default
480            // But since we only hit here when schema is not an object (provided
481            // by user), we leave it.
482        }
483
484        // If user provided nothing and schema is empty object, ensure defaults
485        if let serde_json::Value::Object(ref mut obj) = schema {
486            obj.entry("type")
487                .or_insert(serde_json::Value::String("object".to_string()));
488            obj.entry("additionalProperties")
489                .or_insert(serde_json::Value::Bool(false));
490            // Ensure properties exists when we staged some but merging didn't set (edge
491            // case)
492            if obj.get("properties").is_none() {
493                obj.insert(
494                    "properties".to_string(),
495                    serde_json::Value::Object(serde_json::Map::new()),
496                );
497            }
498        }
499
500        let compiled_schema = compile_schema_cached(&schema).map_err(|e| {
501            error_context()
502                .with_tool(self.metadata.name.clone())
503                .schema_validation(format!("Failed to compile schema: {}", e))
504        })?;
505
506        Ok(FunctionTool {
507            metadata: self.metadata,
508            input_schema: schema,
509            compiled_schema,
510            handler,
511        })
512    }
513}
514
515#[async_trait]
516impl DynTool for FunctionTool {
517    fn metadata(&self) -> &ToolMetadata {
518        &self.metadata
519    }
520
521    async fn execute_json(&self, input: serde_json::Value) -> ToolResult<serde_json::Value> {
522        // Validate the input against the compiled schema
523        if let Err(validation_error) = self.compiled_schema.validate(&input) {
524            return Err(error_context()
525                .with_tool(self.name())
526                .invalid_parameters(format!("Input validation failed: {}", validation_error)));
527        }
528
529        // If validation passes, execute the handler
530        (self.handler)(input).await
531    }
532
533    fn input_schema(&self) -> serde_json::Value {
534        self.input_schema.clone()
535    }
536
537    fn clone_box(&self) -> Box<dyn DynTool> {
538        Box::new(self.clone())
539    }
540}
541
/// Unit tests for metadata validation, JSON conversions, the
/// `FunctionToolBuilder` schema assembly and function-spec parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::toolkits::ToolError;

    #[test]
    fn test_tool_metadata_new() {
        let metadata = ToolMetadata::new("test_tool", "A test tool").unwrap();
        assert_eq!(metadata.name, "test_tool");
        assert_eq!(metadata.description, "A test tool");
        assert_eq!(metadata.version, "1.0.0");
        assert!(metadata.enabled);
    }

    #[test]
    fn test_tool_metadata_invalid_name_empty() {
        let result = ToolMetadata::new("", "A test tool");
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidParameters { .. } => {},
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_tool_metadata_invalid_name_special_chars() {
        let result = ToolMetadata::new("test-tool!", "A test tool");
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidParameters { .. } => {},
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_tool_metadata_builder() {
        let metadata = ToolMetadata::new("test_tool", "A test tool")
            .unwrap()
            .version("2.0.0")
            .author("Test Author")
            .tags(vec!["tag1", "tag2"])
            .enabled(false);

        assert_eq!(metadata.version, "2.0.0");
        assert_eq!(metadata.author, Some(Cow::Borrowed("Test Author")));
        assert_eq!(metadata.tags.len(), 2);
        assert!(!metadata.enabled);
    }

    #[test]
    fn test_conversions_to_json() {
        let value = conversions::to_json(42).unwrap();
        assert_eq!(value, 42);
    }

    #[test]
    fn test_conversions_from_json_string() {
        let value = serde_json::Value::String("hello".to_string());
        let result = conversions::from_json_string(value).unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_conversions_from_json_string_invalid() {
        let value = serde_json::Value::Number(42.into());
        let result = conversions::from_json_string(value);
        assert!(result.is_err());
    }

    #[test]
    fn test_conversions_from_json_i32() {
        let value = serde_json::Value::Number(42.into());
        let result = conversions::from_json_i32(value).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_conversions_from_json_i32_out_of_range() {
        // One past i32::MAX must be rejected, not truncated.
        let value = serde_json::json!(i64::from(i32::MAX) + 1);
        assert!(conversions::from_json_i32(value).is_err());
    }

    #[test]
    fn test_conversions_from_json_f64() {
        let value = serde_json::json!(3.5);
        let result = conversions::from_json_f64(value).unwrap();
        assert_eq!(result, 3.5);
    }

    #[test]
    fn test_conversions_from_json_bool() {
        let value = serde_json::Value::Bool(true);
        let result = conversions::from_json_bool(value).unwrap();
        assert!(result);
    }

    #[test]
    fn test_function_tool_builder() {
        let tool = FunctionTool::builder("test_tool", "A test tool")
            .property("param1", serde_json::json!({"type": "string"}))
            .property("param2", serde_json::json!({"type": "number"}))
            .required("param1")
            .handler(|_args| async move { Ok(serde_json::json!({"result": "ok"})) })
            .build();

        assert!(tool.is_ok());
        let tool = tool.unwrap();
        assert_eq!(tool.name(), "test_tool");
    }

    #[test]
    fn test_function_tool_builder_schema_shape() {
        // With no base schema, the builder fills in the object defaults.
        let tool = FunctionTool::builder("test_tool", "A test tool")
            .property("param1", serde_json::json!({"type": "string"}))
            .required("param1")
            .handler(|_args| async move { Ok(serde_json::json!({"result": "ok"})) })
            .build()
            .unwrap();

        assert_eq!(
            tool.input_schema(),
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {"param1": {"type": "string"}},
                "required": ["param1"]
            })
        );
    }

    #[test]
    fn test_function_tool_builder_merges_required_with_schema() {
        let tool = FunctionTool::builder("test_tool", "A test tool")
            .schema(serde_json::json!({
                "type": "object",
                "properties": {"b": {"type": "string"}},
                "required": ["b"]
            }))
            .property("a", serde_json::json!({"type": "number"}))
            .required("a")
            .required("b")
            .handler(|_args| async move { Ok(serde_json::json!(null)) })
            .build()
            .unwrap();

        let schema = tool.input_schema();
        // Merged, de-duplicated and sorted.
        assert_eq!(schema["required"], serde_json::json!(["a", "b"]));
        assert_eq!(schema["properties"]["a"], serde_json::json!({"type": "number"}));
        assert_eq!(schema["properties"]["b"], serde_json::json!({"type": "string"}));
        assert_eq!(schema["additionalProperties"], serde_json::json!(false));
    }

    #[test]
    fn test_function_tool_build_without_handler_fails() {
        let result = FunctionTool::builder("test_tool", "A test tool").build();
        assert!(result.is_err());
    }

    #[test]
    fn test_function_tool_clone() {
        let tool1 = FunctionTool::builder("test_tool", "A test tool")
            .property("param1", serde_json::json!({"type": "string"}))
            .required("param1")
            .handler(|_args| async move { Ok(serde_json::json!({"result": "ok"})) })
            .build()
            .unwrap();

        let tool2 = tool1.clone();
        assert_eq!(tool1.name(), tool2.name());
        assert_eq!(tool1.input_schema(), tool2.input_schema());
    }

    #[test]
    fn test_from_function_spec_shape2() {
        let spec = serde_json::json!({
            "type": "function",
            "function": {
                "name": "spec_tool",
                "description": "From a spec",
                "parameters": {
                    "type": "object",
                    "properties": {"q": {"type": "string"}}
                }
            }
        });

        let tool = FunctionTool::from_function_spec(spec, |_args| async move {
            Ok(serde_json::json!({"ok": true}))
        })
        .unwrap();

        assert_eq!(tool.name(), "spec_tool");
        assert_eq!(tool.metadata().description, "From a spec");
        assert_eq!(
            tool.input_schema()["properties"]["q"],
            serde_json::json!({"type": "string"})
        );
    }

    #[test]
    fn test_parse_function_spec_shape1() {
        let spec = serde_json::json!({
            "name": "test_tool",
            "description": "A test tool",
            "parameters": {
                "type": "object",
                "properties": {
                    "param1": {"type": "string"}
                }
            }
        });

        let (name, description, parameters) = parse_function_spec_details(&spec).unwrap();
        assert_eq!(name, "test_tool");
        assert_eq!(description, "A test tool");
        assert!(parameters.is_some());
    }

    #[test]
    fn test_parse_function_spec_shape2() {
        let spec = serde_json::json!({
            "type": "function",
            "function": {
                "name": "test_tool",
                "description": "A test tool",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "param1": {"type": "string"}
                    }
                }
            }
        });

        let (name, description, parameters) = parse_function_spec_details(&spec).unwrap();
        assert_eq!(name, "test_tool");
        assert_eq!(description, "A test tool");
        assert!(parameters.is_some());
    }

    #[test]
    fn test_parse_function_spec_invalid() {
        let spec = serde_json::Value::String("invalid".to_string());
        let result = parse_function_spec_details(&spec);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_function_spec_missing_name() {
        let spec = serde_json::json!({"description": "no name"});
        assert!(parse_function_spec_details(&spec).is_err());
    }

    #[test]
    fn test_parse_function_spec_missing_function_object() {
        let spec = serde_json::json!({"type": "function"});
        assert!(parse_function_spec_details(&spec).is_err());
    }
}