1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5use tokio::io::AsyncWriteExt;
6use tuillem_config::ToolConfig;
7
/// Errors that can occur while looking up or running a plugin tool.
#[derive(Debug, thiserror::Error)]
pub enum PluginError {
    /// No tool with the given name is registered in the host.
    #[error("Tool not found: {0}")]
    NotFound(String),

    /// The tool could not be executed for a non-I/O reason
    /// (for example, its configured command is empty).
    #[error("Execution error: {0}")]
    Execution(String),

    /// The tool did not finish within its configured timeout.
    /// Carries the timeout that was applied.
    #[error("Tool timed out after {0:?}")]
    Timeout(Duration),

    /// An I/O error occurred while spawning the process or talking to it.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The request payload could not be serialized to JSON.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
}
29
/// Request payload handed to a tool process.
///
/// It is serialized to JSON and written to the tool's stdin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInput {
    /// Name of the tool being invoked.
    pub name: String,
    /// Arbitrary JSON arguments for the tool.
    pub input: serde_json::Value,
}
39
/// Result of running a tool.
///
/// Either parsed from a JSON object the tool printed on stdout, or
/// synthesized from the tool's raw stdout/stderr text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOutput {
    /// Successful output of the tool, if any.
    pub output: Option<String>,
    /// Error text reported by the tool, if any.
    pub error: Option<String>,
}
45
/// Registry of configured tools that can be looked up and invoked by name.
pub struct PluginHost {
    /// Tool configurations keyed by tool name.
    tools: HashMap<String, ToolConfig>,
}
53
54impl PluginHost {
55 pub fn new(tools: Vec<ToolConfig>) -> Self {
56 let tools = tools.into_iter().map(|t| (t.name.clone(), t)).collect();
57 Self { tools }
58 }
59
60 pub fn list_tools(&self) -> Vec<&ToolConfig> {
61 self.tools.values().collect()
62 }
63
64 pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
65 self.tools.get(name)
66 }
67
68 pub fn requires_confirmation(&self, name: &str) -> bool {
69 self.tools.get(name).map(|t| t.confirm).unwrap_or(false)
70 }
71
72 pub async fn invoke(
73 &self,
74 name: &str,
75 input: serde_json::Value,
76 ) -> Result<ToolOutput, PluginError> {
77 let tool = self
78 .tools
79 .get(name)
80 .ok_or_else(|| PluginError::NotFound(name.to_string()))?;
81
82 let timeout = parse_duration(&tool.timeout);
83
84 let parts: Vec<&str> = tool.command.split_whitespace().collect();
86 let (program, args) = parts
87 .split_first()
88 .ok_or_else(|| PluginError::Execution("Empty command".to_string()))?;
89
90 let mut cmd = tokio::process::Command::new(program);
91 cmd.args(args);
92 cmd.stdin(std::process::Stdio::piped());
93 cmd.stdout(std::process::Stdio::piped());
94 cmd.stderr(std::process::Stdio::piped());
95
96 for (k, v) in &tool.env {
98 cmd.env(k, v);
99 }
100
101 let mut child = cmd.spawn()?;
102
103 let tool_input = ToolInput {
105 name: name.to_string(),
106 input,
107 };
108 let input_json = serde_json::to_string(&tool_input)?;
109
110 if let Some(mut stdin) = child.stdin.take() {
111 stdin.write_all(input_json.as_bytes()).await?;
112 drop(stdin);
114 }
115
116 let result = tokio::time::timeout(timeout, child.wait_with_output()).await;
118
119 match result {
120 Ok(Ok(output)) => {
121 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
122 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
123
124 if let Ok(tool_output) = serde_json::from_str::<ToolOutput>(&stdout)
127 && (tool_output.output.is_some() || tool_output.error.is_some())
128 {
129 return Ok(tool_output);
130 }
131 {
132 Ok(ToolOutput {
133 output: if stdout.is_empty() {
134 None
135 } else {
136 Some(stdout)
137 },
138 error: if stderr.is_empty() {
139 None
140 } else {
141 Some(stderr)
142 },
143 })
144 }
145 }
146 Ok(Err(e)) => Err(PluginError::Io(e)),
147 Err(_) => Err(PluginError::Timeout(timeout)),
148 }
149 }
150}
151
/// Parses a human-friendly timeout string into a [`Duration`].
///
/// Accepted forms (surrounding whitespace, and whitespace between the number
/// and its unit, is ignored):
///
/// - `"500ms"` — milliseconds
/// - `"30s"` — seconds
/// - `"2m"` — minutes
/// - `"1h"` — hours
/// - `"45"` — a bare number, interpreted as seconds
///
/// This function never fails. If the numeric part cannot be parsed, it falls
/// back to 30 seconds, except for the `m` suffix, which falls back to one
/// minute. Values too large to convert to seconds saturate instead of
/// overflowing.
pub fn parse_duration(s: &str) -> Duration {
    const DEFAULT: Duration = Duration::from_secs(30);

    let s = s.trim();
    let number = |text: &str| text.trim_end().parse::<u64>().ok();

    // "ms" must be checked before "s", otherwise "500ms" would be read as
    // the unparsable seconds value "500m".
    if let Some(millis) = s.strip_suffix("ms") {
        number(millis).map_or(DEFAULT, Duration::from_millis)
    } else if let Some(secs) = s.strip_suffix('s') {
        number(secs).map_or(DEFAULT, Duration::from_secs)
    } else if let Some(mins) = s.strip_suffix('m') {
        Duration::from_secs(number(mins).unwrap_or(1).saturating_mul(60))
    } else if let Some(hours) = s.strip_suffix('h') {
        number(hours).map_or(DEFAULT, |h| Duration::from_secs(h.saturating_mul(3600)))
    } else {
        number(s).map_or(DEFAULT, Duration::from_secs)
    }
}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal `ToolConfig` for tests, with no input schema and an
    /// empty environment.
    fn make_tool(name: &str, command: &str, timeout: &str, confirm: bool) -> ToolConfig {
        ToolConfig {
            name: name.to_string(),
            description: format!("{name} tool"),
            command: command.to_string(),
            input_schema: None,
            timeout: timeout.to_string(),
            confirm,
            env: HashMap::new(),
        }
    }

    // Relies on the `cat` binary, so it only runs on Unix-like systems.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_invoke_tool() {
        // `cat` echoes the JSON request it receives on stdin back to stdout.
        let tool = make_tool("cat_tool", "cat", "10s", false);
        let host = PluginHost::new(vec![tool]);

        let input = serde_json::json!({"message": "hello"});
        let result = host.invoke("cat_tool", input).await;
        assert!(result.is_ok(), "invoke should succeed: {result:?}");

        let output = result.unwrap();
        assert!(output.output.is_some(), "output should be present");
        let text = output.output.unwrap();
        assert!(
            text.contains("hello"),
            "output should contain 'hello': {text}"
        );
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let host = PluginHost::new(vec![]);
        let result = host.invoke("nonexistent", serde_json::json!({})).await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), PluginError::NotFound(name) if name == "nonexistent")
        );
    }

    #[tokio::test]
    async fn test_empty_command() {
        // A whitespace-only command has no program to run; this is rejected
        // before any process is spawned, so the test is platform-independent.
        let host = PluginHost::new(vec![make_tool("blank", "   ", "10s", false)]);
        let err = host
            .invoke("blank", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, PluginError::Execution(msg) if msg == "Empty command"));
    }

    #[test]
    fn test_requires_confirmation() {
        // The command is never executed here; only the `confirm` flag is read.
        let tool = make_tool("dangerous", "rm -rf", "10s", true);
        let host = PluginHost::new(vec![tool]);
        assert!(host.requires_confirmation("dangerous"));
        assert!(!host.requires_confirmation("nonexistent"));
    }

    #[test]
    fn test_list_tools() {
        let tools = vec![
            make_tool("a", "echo", "10s", false),
            make_tool("b", "cat", "10s", false),
            make_tool("c", "ls", "10s", false),
        ];
        let host = PluginHost::new(tools);
        assert_eq!(host.list_tools().len(), 3);

        // Listing order is unspecified, so compare sorted names.
        let mut names: Vec<&str> = host
            .list_tools()
            .into_iter()
            .map(|t| t.name.as_str())
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["a", "b", "c"]);
    }

    #[test]
    fn test_get_tool() {
        let host = PluginHost::new(vec![make_tool("a", "echo", "10s", false)]);
        assert_eq!(host.get_tool("a").map(|t| t.name.as_str()), Some("a"));
        assert!(host.get_tool("missing").is_none());
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(parse_duration("30s"), Duration::from_secs(30));
        assert_eq!(parse_duration("2m"), Duration::from_secs(120));
        assert_eq!(parse_duration("45"), Duration::from_secs(45));
    }

    #[test]
    fn test_parse_duration_fallbacks() {
        // Surrounding whitespace is ignored.
        assert_eq!(parse_duration("  5s  "), Duration::from_secs(5));
        // Unparsable input falls back to 30 seconds...
        assert_eq!(parse_duration(""), Duration::from_secs(30));
        assert_eq!(parse_duration("abc"), Duration::from_secs(30));
        // ...except for the minute suffix, which falls back to one minute.
        assert_eq!(parse_duration("xm"), Duration::from_secs(60));
    }

    // Relies on the `sleep` binary, so it only runs on Unix-like systems.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_timeout() {
        let tool = make_tool("sleeper", "sleep 60", "1s", false);
        let host = PluginHost::new(vec![tool]);

        let result = host.invoke("sleeper", serde_json::json!({})).await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), PluginError::Timeout(d) if d == Duration::from_secs(1))
        );
    }
}