1use crate::config::LlmConfigProvider;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use steer_tools::ToolCall;
5use steer_tools::tools::bash::BASH_TOOL_NAME;
6use tokio_util::sync::CancellationToken;
7
/// Shared inputs handed to every [`ToolValidator::validate`] call.
#[derive(Debug)]
pub struct ValidationContext {
    /// Cancellation token that validators forward to long-running work
    /// (the bash validator passes a clone of it to the command filter).
    pub cancellation_token: CancellationToken,
    /// Configuration provider passed through to the command filter;
    /// presumably used there to reach an LLM — TODO confirm.
    pub llm_config_provider: LlmConfigProvider,
}
13
/// Verdict produced by a [`ToolValidator`] for a single tool call.
///
/// A rejected call is reported as `Ok(ValidationResult { allowed: false, .. })`;
/// `Err(ValidationError)` is reserved for failures of the validation itself.
/// `Clone`/`PartialEq`/`Eq` are derived so verdicts can be copied, cached and
/// compared (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Whether the tool call may proceed.
    pub allowed: bool,
    /// Human-readable explanation; `BashValidator` sets it only when rejecting.
    pub reason: Option<String>,
    /// Presumably signals that the user must confirm before the call runs
    /// (TODO confirm with the consumer); `BashValidator` always sets `false`.
    pub requires_user_approval: bool,
}
20
/// Failures that prevent a [`ToolValidator`] from reaching a verdict.
///
/// Validators in this file report a *rejected* tool call as
/// `Ok(ValidationResult { allowed: false, .. })`, not as one of these errors.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// The tool call's parameters could not be deserialized (e.g. into `BashParams`).
    #[error("Invalid parameters: {0}")]
    InvalidParams(String),
    /// The command filter itself failed; its error is stringified.
    #[error("Filter error: {0}")]
    FilterError(String),
    /// An I/O error; `#[from]` lets `?` convert `std::io::Error` automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Any other failure, described by the message.
    #[error("Other error: {0}")]
    Other(String),
}
32
/// A check that decides whether a [`ToolCall`] for one specific tool is allowed.
///
/// `Send + Sync` is required so validators can be stored as
/// `Box<dyn ToolValidator>` in a shared [`ValidatorRegistry`].
#[async_trait]
pub trait ToolValidator: Send + Sync {
    /// Name of the tool whose calls this validator checks (e.g. `BASH_TOOL_NAME`).
    fn tool_name(&self) -> &'static str;

    /// Inspects `tool_call` and returns a verdict.
    ///
    /// # Errors
    /// Returns a [`ValidationError`] only when validation could not be
    /// completed (bad parameters, filter failure, ...); a rejection is an
    /// `Ok` result with `allowed == false`.
    async fn validate(
        &self,
        tool_call: &ToolCall,
        context: &ValidationContext,
    ) -> Result<ValidationResult, ValidationError>;
}
43
/// Lookup table from tool name to the validator responsible for that tool.
pub struct ValidatorRegistry {
    /// Keyed by tool name; tools without an entry have no validator.
    validators: HashMap<String, Box<dyn ToolValidator>>,
}
47
48impl Default for ValidatorRegistry {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl ValidatorRegistry {
55 pub fn new() -> Self {
56 let mut validators = HashMap::new();
57
58 validators.insert(
60 BASH_TOOL_NAME.to_string(),
61 Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
62 );
63
64 Self { validators }
65 }
66
67 pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
68 self.validators.get(tool_name).map(|v| v.as_ref())
69 }
70}
71
72use once_cell::sync::Lazy;
74use regex::Regex;
75use serde::Deserialize;
76
/// Parameters of a bash tool call, deserialized from `ToolCall::parameters`.
#[derive(Debug, Clone, Deserialize)]
pub struct BashParams {
    /// The shell command line to validate.
    pub command: String,
    /// Optional timeout; not inspected by `BashValidator`. Units are not
    /// defined here — TODO confirm against the bash tool executor.
    pub timeout: Option<u64>,
}
82
/// Stateless validator for bash tool calls: rejects commands on a built-in
/// deny list, then defers to the project's command filter.
pub struct BashValidator;
84
85impl Default for BashValidator {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl BashValidator {
92 pub fn new() -> Self {
93 Self
94 }
95
96 fn is_banned_command(&self, command: &str) -> bool {
98 static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
99 let banned_commands = [
100 "curl",
102 "wget",
103 "nc",
104 "telnet",
105 "ssh",
106 "scp",
107 "ftp",
108 "sftp",
109 "lynx",
111 "w3m",
112 "links",
113 "elinks",
114 "httpie",
115 "xh",
116 "http-prompt",
117 "chrome",
118 "firefox",
119 "safari",
120 "edge",
121 "opera",
122 "chromium",
123 "axel",
125 "aria2c",
126 "alias",
128 "unalias",
129 "exec",
130 "source",
131 ".",
132 "history",
133 "sudo",
135 "su",
136 "chown",
137 "chmod",
138 "useradd",
139 "userdel",
140 "groupadd",
141 "groupdel",
142 "vi",
144 "vim",
145 "nano",
146 "pico",
147 "emacs",
148 "ed",
149 ];
150 banned_commands
151 .iter()
152 .map(|cmd| {
153 Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
154 .expect("Failed to compile banned command regex")
155 })
156 .collect()
157 });
158
159 BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
160 }
161
162 async fn is_command_allowed(
164 &self,
165 command: &str,
166 context: &ValidationContext,
167 ) -> Result<bool, ValidationError> {
168 crate::tools::command_filter::is_command_allowed(
169 command,
170 &context.llm_config_provider,
171 context.cancellation_token.clone(),
172 )
173 .await
174 .map_err(|e| ValidationError::FilterError(e.to_string()))
175 }
176}
177
178#[async_trait]
179impl ToolValidator for BashValidator {
180 fn tool_name(&self) -> &'static str {
181 BASH_TOOL_NAME
182 }
183
184 async fn validate(
185 &self,
186 tool_call: &ToolCall,
187 context: &ValidationContext,
188 ) -> Result<ValidationResult, ValidationError> {
189 let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
190 .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
191
192 if self.is_banned_command(¶ms.command) {
194 return Ok(ValidationResult {
195 allowed: false,
196 reason: Some(format!(
197 "Command '{}' is disallowed for security reasons",
198 params.command
199 )),
200 requires_user_approval: false,
201 });
202 }
203
204 let is_allowed = self.is_command_allowed(¶ms.command, context).await?;
206
207 if !is_allowed {
208 return Ok(ValidationResult {
209 allowed: false,
210 reason: Some(format!(
211 "Command '{}' was blocked by the command filter",
212 params.command
213 )),
214 requires_user_approval: false,
215 });
216 }
217
218 Ok(ValidationResult {
220 allowed: true,
221 reason: None,
222 requires_user_approval: false,
223 })
224 }
225}