Skip to main content

tuillem_plugin/
lib.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5use tokio::io::AsyncWriteExt;
6use tuillem_config::ToolConfig;
7
8// ---------------------------------------------------------------------------
9// Errors
10// ---------------------------------------------------------------------------
11
12#[derive(Debug, thiserror::Error)]
13pub enum PluginError {
14    #[error("Tool not found: {0}")]
15    NotFound(String),
16
17    #[error("Execution error: {0}")]
18    Execution(String),
19
20    #[error("Tool timed out after {0:?}")]
21    Timeout(Duration),
22
23    #[error("IO error: {0}")]
24    Io(#[from] std::io::Error),
25
26    #[error("JSON error: {0}")]
27    Json(#[from] serde_json::Error),
28}
29
30// ---------------------------------------------------------------------------
31// ToolInput / ToolOutput
32// ---------------------------------------------------------------------------
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ToolInput {
36    pub name: String,
37    pub input: serde_json::Value,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ToolOutput {
42    pub output: Option<String>,
43    pub error: Option<String>,
44}
45
46// ---------------------------------------------------------------------------
47// PluginHost
48// ---------------------------------------------------------------------------
49
50pub struct PluginHost {
51    tools: HashMap<String, ToolConfig>,
52}
53
54impl PluginHost {
55    pub fn new(tools: Vec<ToolConfig>) -> Self {
56        let tools = tools.into_iter().map(|t| (t.name.clone(), t)).collect();
57        Self { tools }
58    }
59
60    pub fn list_tools(&self) -> Vec<&ToolConfig> {
61        self.tools.values().collect()
62    }
63
64    pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
65        self.tools.get(name)
66    }
67
68    pub fn requires_confirmation(&self, name: &str) -> bool {
69        self.tools.get(name).map(|t| t.confirm).unwrap_or(false)
70    }
71
72    pub async fn invoke(
73        &self,
74        name: &str,
75        input: serde_json::Value,
76    ) -> Result<ToolOutput, PluginError> {
77        let tool = self
78            .tools
79            .get(name)
80            .ok_or_else(|| PluginError::NotFound(name.to_string()))?;
81
82        let timeout = parse_duration(&tool.timeout);
83
84        // Split command into program + args
85        let parts: Vec<&str> = tool.command.split_whitespace().collect();
86        let (program, args) = parts
87            .split_first()
88            .ok_or_else(|| PluginError::Execution("Empty command".to_string()))?;
89
90        let mut cmd = tokio::process::Command::new(program);
91        cmd.args(args);
92        cmd.stdin(std::process::Stdio::piped());
93        cmd.stdout(std::process::Stdio::piped());
94        cmd.stderr(std::process::Stdio::piped());
95
96        // Set environment variables
97        for (k, v) in &tool.env {
98            cmd.env(k, v);
99        }
100
101        let mut child = cmd.spawn()?;
102
103        // Write JSON to stdin
104        let tool_input = ToolInput {
105            name: name.to_string(),
106            input,
107        };
108        let input_json = serde_json::to_string(&tool_input)?;
109
110        if let Some(mut stdin) = child.stdin.take() {
111            stdin.write_all(input_json.as_bytes()).await?;
112            // Drop stdin to signal EOF
113            drop(stdin);
114        }
115
116        // Wait with timeout
117        let result = tokio::time::timeout(timeout, child.wait_with_output()).await;
118
119        match result {
120            Ok(Ok(output)) => {
121                let stdout = String::from_utf8_lossy(&output.stdout).to_string();
122                let stderr = String::from_utf8_lossy(&output.stderr).to_string();
123
124                // Try to parse stdout as ToolOutput JSON, fall back to raw stdout.
125                // Only accept it if at least one of output/error is present.
126                if let Ok(tool_output) = serde_json::from_str::<ToolOutput>(&stdout)
127                    && (tool_output.output.is_some() || tool_output.error.is_some())
128                {
129                    return Ok(tool_output);
130                }
131                {
132                    Ok(ToolOutput {
133                        output: if stdout.is_empty() {
134                            None
135                        } else {
136                            Some(stdout)
137                        },
138                        error: if stderr.is_empty() {
139                            None
140                        } else {
141                            Some(stderr)
142                        },
143                    })
144                }
145            }
146            Ok(Err(e)) => Err(PluginError::Io(e)),
147            Err(_) => Err(PluginError::Timeout(timeout)),
148        }
149    }
150}
151
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Parses a human-friendly duration string such as `"30s"`, `"2m"`, `"1h"`, `"500ms"` or `"45"`.
///
/// Parsing rules:
/// - A bare number is interpreted as seconds.
/// - Whitespace around the value, and between the number and its unit, is ignored.
/// - Malformed input (empty, negative, fractional, unknown unit) falls back to 30 seconds,
///   so a bad config value never disables the timeout.
/// - Very large values saturate instead of overflowing.
pub fn parse_duration(s: &str) -> Duration {
    const DEFAULT: Duration = Duration::from_secs(30);

    let s = s.trim();
    // Order matters: "ms" must be tried before the single-letter "m" and "s" suffixes.
    let (digits, millis_per_unit): (&str, u64) = if let Some(n) = s.strip_suffix("ms") {
        (n, 1)
    } else if let Some(n) = s.strip_suffix('s') {
        (n, 1_000)
    } else if let Some(n) = s.strip_suffix('m') {
        (n, 60_000)
    } else if let Some(n) = s.strip_suffix('h') {
        (n, 3_600_000)
    } else {
        (s, 1_000)
    };

    match digits.trim_end().parse::<u64>() {
        Ok(n) => Duration::from_millis(n.saturating_mul(millis_per_unit)),
        Err(_) => DEFAULT,
    }
}
166
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal `ToolConfig` for tests; the description is derived from the name.
    fn make_tool(name: &str, command: &str, timeout: &str, confirm: bool) -> ToolConfig {
        ToolConfig {
            name: name.into(),
            description: format!("{name} tool"),
            command: command.into(),
            input_schema: None,
            timeout: timeout.into(),
            confirm,
            env: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_invoke_tool() {
        // `cat` echoes its stdin, so the serialized request comes back as raw stdout.
        let host = PluginHost::new(vec![make_tool("cat_tool", "cat", "10s", false)]);

        let result = host
            .invoke("cat_tool", serde_json::json!({"message": "hello"}))
            .await;
        assert!(result.is_ok(), "invoke should succeed: {result:?}");

        let Some(text) = result.unwrap().output else {
            panic!("output should be present");
        };
        assert!(text.contains("hello"), "output should contain 'hello': {text}");
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let host = PluginHost::new(Vec::new());
        let result = host.invoke("nonexistent", serde_json::json!({})).await;
        assert!(matches!(
            &result,
            Err(PluginError::NotFound(name)) if name == "nonexistent"
        ));
    }

    #[test]
    fn test_requires_confirmation() {
        let host = PluginHost::new(vec![make_tool("dangerous", "rm -rf", "10s", true)]);
        assert!(host.requires_confirmation("dangerous"));
        assert!(!host.requires_confirmation("nonexistent"));
    }

    #[test]
    fn test_list_tools() {
        let host = PluginHost::new(
            [("a", "echo"), ("b", "cat"), ("c", "ls")]
                .into_iter()
                .map(|(name, command)| make_tool(name, command, "10s", false))
                .collect(),
        );
        assert_eq!(host.list_tools().len(), 3);
    }

    #[test]
    fn test_parse_duration() {
        for (text, secs) in [("30s", 30), ("2m", 120), ("45", 45)] {
            assert_eq!(parse_duration(text), Duration::from_secs(secs), "input {text:?}");
        }
    }

    #[tokio::test]
    async fn test_timeout() {
        let host = PluginHost::new(vec![make_tool("sleeper", "sleep 60", "1s", false)]);

        let result = host.invoke("sleeper", serde_json::json!({})).await;
        assert!(matches!(
            result,
            Err(PluginError::Timeout(d)) if d == Duration::from_secs(1)
        ));
    }
}