1use async_trait::async_trait;
2use std::collections::HashMap;
3use steer_tools::tools::bash::BASH_TOOL_NAME;
4use tokio_util::sync::CancellationToken;
5
6use crate::api::ToolCall;
7use crate::config::LlmConfigProvider;
8
/// Ambient state handed to every [`ToolValidator::validate`] call.
#[derive(Debug)]
pub struct ValidationContext {
    /// Cloned and forwarded to the command filter by `BashValidator`
    /// (presumably so an in-flight filter call can be cancelled — TODO confirm).
    pub cancellation_token: CancellationToken,
    /// Passed by reference to `crate::tools::command_filter::is_command_allowed`
    /// (NOTE(review): presumably supplies the LLM configuration the filter uses — confirm).
    pub llm_config_provider: LlmConfigProvider,
}
14
/// Outcome of validating a single tool call.
///
/// A rejected call is still a *successful* validation: validators report it as
/// `allowed: false` with an explanatory `reason` rather than as a [`ValidationError`].
/// `Clone`/`PartialEq`/`Eq` are derived so results can be stored, compared and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Whether the tool call may proceed.
    pub allowed: bool,
    /// Human-readable explanation; `BashValidator` sets it whenever it rejects a call.
    pub reason: Option<String>,
    /// Whether the caller should additionally ask the user to approve the call
    /// (`BashValidator` always leaves this `false`).
    pub requires_user_approval: bool,
}
21
/// Failures that prevent a validator from reaching an allow/deny decision.
///
/// Note that in `BashValidator` a command that is merely *rejected* is reported as
/// `Ok(ValidationResult { allowed: false, .. })`, not through this type.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// The tool call's parameters could not be deserialized into the validator's
    /// parameter type; carries the deserializer's message.
    #[error("Invalid parameters: {0}")]
    InvalidParams(String),
    /// The external command filter failed; carries the stringified underlying error.
    #[error("Filter error: {0}")]
    FilterError(String),
    /// An I/O error; `#[from]` lets `?` convert `std::io::Error` automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Catch-all for failures that fit no other variant.
    #[error("Other error: {0}")]
    Other(String),
}
33
/// A validation policy for calls to one specific tool.
///
/// Implementations are stored in [`ValidatorRegistry`], keyed by tool name. The
/// `Send + Sync` bound allows a validator to be shared across threads and async tasks.
#[async_trait]
pub trait ToolValidator: Send + Sync {
    /// Name of the tool this validator is responsible for.
    fn tool_name(&self) -> &'static str;

    /// Decides whether `tool_call` may proceed.
    ///
    /// Returns `Ok` carrying the allow/deny decision, or `Err` when no decision could be
    /// reached (e.g. malformed parameters or a failing external filter).
    async fn validate(
        &self,
        tool_call: &ToolCall,
        context: &ValidationContext,
    ) -> Result<ValidationResult, ValidationError>;
}
44
/// Lookup table from tool name to the [`ToolValidator`] responsible for that tool.
pub struct ValidatorRegistry {
    /// Validators keyed by tool name (e.g. `BASH_TOOL_NAME`).
    validators: HashMap<String, Box<dyn ToolValidator>>,
}
48
49impl Default for ValidatorRegistry {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl ValidatorRegistry {
56 pub fn new() -> Self {
57 let mut validators = HashMap::new();
58
59 validators.insert(
61 BASH_TOOL_NAME.to_string(),
62 Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
63 );
64
65 Self { validators }
66 }
67
68 pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
69 self.validators.get(tool_name).map(|v| v.as_ref())
70 }
71}
72
73use once_cell::sync::Lazy;
75use regex::Regex;
76use serde::Deserialize;
77
/// Parameters of a bash tool call, deserialized from the call's JSON parameters.
#[derive(Debug, Clone, Deserialize)]
pub struct BashParams {
    /// The shell command line to run; this is the string the validator inspects.
    pub command: String,
    /// Optional timeout; not inspected by the validator.
    /// NOTE(review): units are not visible here — confirm with the bash tool.
    pub timeout: Option<u64>,
}
83
/// Validator for the bash tool: rejects banned programs, then consults the command filter.
pub struct BashValidator;
85
86impl Default for BashValidator {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl BashValidator {
93 pub fn new() -> Self {
94 Self
95 }
96
97 fn is_banned_command(&self, command: &str) -> bool {
99 static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
100 let banned_commands = [
101 "curl",
103 "wget",
104 "nc",
105 "telnet",
106 "ssh",
107 "scp",
108 "ftp",
109 "sftp",
110 "lynx",
112 "w3m",
113 "links",
114 "elinks",
115 "httpie",
116 "xh",
117 "http-prompt",
118 "chrome",
119 "firefox",
120 "safari",
121 "edge",
122 "opera",
123 "chromium",
124 "axel",
126 "aria2c",
127 "alias",
129 "unalias",
130 "exec",
131 "source",
132 ".",
133 "history",
134 "sudo",
136 "su",
137 "chown",
138 "chmod",
139 "useradd",
140 "userdel",
141 "groupadd",
142 "groupdel",
143 "vi",
145 "vim",
146 "nano",
147 "pico",
148 "emacs",
149 "ed",
150 ];
151 banned_commands
152 .iter()
153 .map(|cmd| {
154 Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
155 .expect("Failed to compile banned command regex")
156 })
157 .collect()
158 });
159
160 BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
161 }
162
163 async fn is_command_allowed(
165 &self,
166 command: &str,
167 context: &ValidationContext,
168 ) -> Result<bool, ValidationError> {
169 crate::tools::command_filter::is_command_allowed(
170 command,
171 &context.llm_config_provider,
172 context.cancellation_token.clone(),
173 )
174 .await
175 .map_err(|e| ValidationError::FilterError(e.to_string()))
176 }
177}
178
179#[async_trait]
180impl ToolValidator for BashValidator {
181 fn tool_name(&self) -> &'static str {
182 BASH_TOOL_NAME
183 }
184
185 async fn validate(
186 &self,
187 tool_call: &ToolCall,
188 context: &ValidationContext,
189 ) -> Result<ValidationResult, ValidationError> {
190 let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
191 .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
192
193 if self.is_banned_command(¶ms.command) {
195 return Ok(ValidationResult {
196 allowed: false,
197 reason: Some(format!(
198 "Command '{}' is disallowed for security reasons",
199 params.command
200 )),
201 requires_user_approval: false,
202 });
203 }
204
205 let is_allowed = self.is_command_allowed(¶ms.command, context).await?;
207
208 if !is_allowed {
209 return Ok(ValidationResult {
210 allowed: false,
211 reason: Some(format!(
212 "Command '{}' was blocked by the command filter",
213 params.command
214 )),
215 requires_user_approval: false,
216 });
217 }
218
219 Ok(ValidationResult {
221 allowed: true,
222 reason: None,
223 requires_user_approval: false,
224 })
225 }
226}