1use crate::schema::to_gemini_parameters;
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23
24#[derive(Debug, Clone)]
26pub struct ToolDef {
27 pub name: String,
29 pub description: String,
31 pub parameters: Value,
33}
34
35pub fn tool<T: JsonSchema + DeserializeOwned>(name: &str, description: &str) -> ToolDef {
40 ToolDef {
41 name: name.to_string(),
42 description: description.to_string(),
43 parameters: to_gemini_parameters::<T>(),
44 }
45}
46
47impl ToolDef {
48 pub fn to_gemini(&self) -> Value {
50 serde_json::json!({
51 "name": self.name,
52 "description": self.description,
53 "parameters": self.parameters,
54 })
55 }
56
57 pub fn to_openai(&self) -> Value {
59 serde_json::json!({
60 "type": "function",
61 "function": {
62 "name": self.name,
63 "description": self.description,
64 "parameters": self.parameters,
65 }
66 })
67 }
68
69 pub fn parse_args<T: DeserializeOwned>(&self, args: &Value) -> Result<T, serde_json::Error> {
71 serde_json::from_value(args.clone())
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Small argument type with one required and one optional field, used to
    /// exercise schema generation and argument parsing.
    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct MockTool {
        input_path: String,
        quality: Option<f64>,
    }

    /// The flat payload exposes name, description and a per-field schema.
    #[test]
    fn tool_generates_gemini_format() {
        let payload = tool::<MockTool>("mock_tool", "A mock tool").to_gemini();

        assert_eq!(payload["name"], "mock_tool");
        assert_eq!(payload["description"], "A mock tool");
        assert!(payload["parameters"]["properties"]["input_path"].is_object());
    }

    /// The wrapped payload is tagged as a function and nests the name.
    #[test]
    fn tool_generates_openai_format() {
        let payload = tool::<MockTool>("mock_tool", "A mock tool").to_openai();

        assert_eq!(payload["type"], "function");
        assert_eq!(payload["function"]["name"], "mock_tool");
    }

    /// A well-formed argument object round-trips into the typed struct.
    #[test]
    fn parse_args_works() {
        let definition = tool::<MockTool>("mock_tool", "test");
        let raw_args = serde_json::json!({"input_path": "/video.mp4", "quality": 0.8});

        let MockTool {
            input_path,
            quality,
        } = definition.parse_args(&raw_args).unwrap();

        assert_eq!(input_path, "/video.mp4");
        assert_eq!(quality, Some(0.8));
    }
}