Skip to main content

sh_layer3/builtin_tools/
shell.rs

1//! # Shell Tools
2//!
3//! Shell 执行工具集。
4
5use crate::builtin_tools::BuiltinTool;
6use crate::types::{Layer3Result, ToolCategory};
7use async_trait::async_trait;
8use std::process::Stdio;
9use std::time::Duration;
10use tokio::process::Command;
11use tokio::time::timeout;
12
/// Bash Tool - Execute shell commands
///
/// Stateless, zero-sized handle for the `bash` builtin tool. The tool's
/// behavior is provided by its `BuiltinTool` implementation, so the handle can
/// be copied or default-constructed freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct BashTool;
15
16#[async_trait]
17impl BuiltinTool for BashTool {
18    fn name(&self) -> &str {
19        "bash"
20    }
21
22    fn description(&self) -> &str {
23        "Execute a bash shell command with timeout."
24    }
25
26    fn parameters_schema(&self) -> serde_json::Value {
27        serde_json::json!({
28            "type": "object",
29            "properties": {
30                "command": {
31                    "type": "string",
32                    "description": "The bash command to execute"
33                },
34                "timeout": {
35                    "type": "integer",
36                    "description": "Optional: timeout in milliseconds (default: 30000)"
37                },
38                "working_dir": {
39                    "type": "string",
40                    "description": "Optional: working directory for the command"
41                }
42            },
43            "required": ["command"]
44        })
45    }
46
47    fn category(&self) -> ToolCategory {
48        ToolCategory::Shell
49    }
50
51    fn is_dangerous(&self) -> bool {
52        true
53    }
54
55    fn requires_confirmation(&self) -> bool {
56        true
57    }
58
59    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
60        let command = args["command"]
61            .as_str()
62            .ok_or_else(|| anyhow::anyhow!("Missing command parameter"))?;
63
64        let timeout_ms = args["timeout"].as_u64().unwrap_or(30000);
65        let working_dir = args["working_dir"].as_str().map(|s| s.to_string());
66
67        // Build the command
68        #[cfg(windows)]
69        let mut cmd = Command::new("cmd");
70        #[cfg(windows)]
71        cmd.args(["/C", command]);
72
73        #[cfg(not(windows))]
74        let mut cmd = Command::new("sh");
75        #[cfg(not(windows))]
76        cmd.args(["-c", command]);
77
78        // Set working directory
79        if let Some(dir) = working_dir {
80            cmd.current_dir(dir);
81        }
82
83        // Configure stdio
84        cmd.stdout(Stdio::piped());
85        cmd.stderr(Stdio::piped());
86
87        // Execute with timeout
88        let timeout_duration = Duration::from_millis(timeout_ms);
89
90        let output = timeout(timeout_duration, cmd.output())
91            .await
92            .map_err(|_| anyhow::anyhow!("Command timed out after {}ms", timeout_ms))?
93            .map_err(|e| anyhow::anyhow!("Failed to execute command: {}", e))?;
94
95        // Process output
96        let stdout = String::from_utf8_lossy(&output.stdout);
97        let stderr = String::from_utf8_lossy(&output.stderr);
98
99        if output.status.success() {
100            Ok(stdout.trim().to_string())
101        } else {
102            let exit_code = output.status.code().unwrap_or(-1);
103            let mut error_msg = format!("Exit code: {}", exit_code);
104            if !stderr.is_empty() {
105                error_msg.push_str(&format!("\nError: {}", stderr.trim()));
106            }
107            if !stdout.is_empty() {
108                error_msg.push_str(&format!("\nOutput: {}", stdout.trim()));
109            }
110            Err(anyhow::anyhow!(error_msg))
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_bash_tool_dangerous() {
        let tool = BashTool;
        assert!(tool.is_dangerous());
        assert!(tool.requires_confirmation());
    }

    #[tokio::test]
    async fn test_bash_execute_success() {
        let tool = BashTool;
        // `echo hello` is valid under both `cmd /C` and `sh -c`.
        let output = tool
            .execute(json!({"command": "echo hello"}))
            .await
            .expect("`echo hello` should succeed");
        assert!(output.contains("hello"));
    }

    #[tokio::test]
    async fn test_bash_execute_failure() {
        let tool = BashTool;
        // `exit 1` is valid under both `cmd /C` and `sh -c`.
        let err = tool
            .execute(json!({"command": "exit 1"}))
            .await
            .expect_err("`exit 1` should be reported as an error");
        assert!(err.to_string().contains("Exit code: 1"));
    }

    #[cfg(not(windows))]
    #[tokio::test]
    async fn test_bash_failure_reports_stderr_and_stdout() {
        let tool = BashTool;
        let err = tool
            .execute(json!({"command": "echo out; echo oops >&2; exit 3"}))
            .await
            .expect_err("non-zero exit should be reported as an error");
        let msg = err.to_string();
        assert!(msg.contains("Exit code: 3"));
        assert!(msg.contains("Error: oops"));
        assert!(msg.contains("Output: out"));
    }

    #[tokio::test]
    async fn test_bash_execute_timeout() {
        let tool = BashTool;

        #[cfg(windows)]
        let command = "ping -n 10 localhost";
        #[cfg(not(windows))]
        let command = "sleep 10";

        let err = tool
            .execute(json!({"command": command, "timeout": 100}))
            .await
            .expect_err("command should exceed the 100ms timeout");
        assert!(err.to_string().contains("timed out"));
    }

    #[tokio::test]
    async fn test_bash_working_directory() {
        let tool = BashTool;
        let temp_dir = std::env::temp_dir();

        #[cfg(windows)]
        let command = "cd";
        #[cfg(not(windows))]
        let command = "pwd";

        let output = tool
            .execute(json!({"command": command, "working_dir": temp_dir.to_str()}))
            .await
            .expect("printing the working directory should succeed");

        // Compare canonical paths. `temp_dir()` may have a trailing separator
        // or a symlinked prefix (macOS `/var` -> `/private/var`) that the
        // shell's report of its current directory does not have.
        let reported = std::fs::canonicalize(&output).expect("shell printed an existing path");
        let expected = std::fs::canonicalize(&temp_dir).expect("temp dir should exist");
        assert_eq!(reported, expected);
    }

    #[tokio::test]
    async fn test_bash_missing_working_directory() {
        let tool = BashTool;
        let missing = std::env::temp_dir().join("bash_tool_test_dir_that_does_not_exist");
        let result = tool
            .execute(json!({"command": "echo hi", "working_dir": missing.to_str()}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bash_missing_command() {
        let tool = BashTool;
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing command"));
    }
}