steer_core/app/
validation.rs1use crate::config::LlmConfigProvider;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use steer_tools::ToolCall;
5use steer_tools::tools::bash::BASH_TOOL_NAME;
6use tokio_util::sync::CancellationToken;
7
8#[derive(Debug)]
9pub struct ValidationContext {
10 pub cancellation_token: CancellationToken,
11 pub llm_config_provider: LlmConfigProvider,
12}
13
/// Outcome of validating a single tool call.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so results
/// can be duplicated and compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` when the tool call may proceed.
    pub allowed: bool,
    /// Human-readable explanation of the decision. The bash validator in this
    /// file sets it when a command is rejected and leaves it `None` otherwise.
    pub reason: Option<String>,
    /// Flag indicating the call needs explicit user approval.
    ///
    /// NOTE(review): the validator in this file always sets this to `false`;
    /// how callers act on `true` is not visible here.
    pub requires_user_approval: bool,
}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ValidationError {
23 #[error("Invalid parameters: {0}")]
24 InvalidParams(String),
25 #[error("IO error: {0}")]
26 Io(#[from] std::io::Error),
27 #[error("Other error: {0}")]
28 Other(String),
29}
30
31#[async_trait]
32pub trait ToolValidator: Send + Sync {
33 fn tool_name(&self) -> &'static str;
34
35 async fn validate(
36 &self,
37 tool_call: &ToolCall,
38 context: &ValidationContext,
39 ) -> Result<ValidationResult, ValidationError>;
40}
41
42pub struct ValidatorRegistry {
43 validators: HashMap<String, Box<dyn ToolValidator>>,
44}
45
46impl Default for ValidatorRegistry {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl ValidatorRegistry {
53 pub fn new() -> Self {
54 let mut validators = HashMap::new();
55
56 validators.insert(
58 BASH_TOOL_NAME.to_string(),
59 Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
60 );
61
62 Self { validators }
63 }
64
65 pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
66 self.validators.get(tool_name).map(|v| v.as_ref())
67 }
68}
69
70use regex::Regex;
72use serde::Deserialize;
73
74#[derive(Debug, Clone, Deserialize)]
75pub struct BashParams {
76 pub command: String,
77 pub timeout: Option<u64>,
78}
79
/// Validator for bash tool calls; rejects commands that start with a banned
/// program name.
///
/// The type is stateless, so it derives `Debug`, `Clone` and `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct BashValidator;
81
82impl Default for BashValidator {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl BashValidator {
89 pub fn new() -> Self {
90 Self
91 }
92
93 fn is_banned_command(command: &str) -> bool {
95 static BANNED_COMMAND_REGEXES: std::sync::LazyLock<Vec<Regex>> =
96 std::sync::LazyLock::new(|| {
97 let banned_commands = [
98 "curl",
100 "wget",
101 "nc",
102 "telnet",
103 "ssh",
104 "scp",
105 "ftp",
106 "sftp",
107 "lynx",
109 "w3m",
110 "links",
111 "elinks",
112 "httpie",
113 "xh",
114 "http-prompt",
115 "chrome",
116 "firefox",
117 "safari",
118 "edge",
119 "opera",
120 "chromium",
121 "axel",
123 "aria2c",
124 "alias",
126 "unalias",
127 "exec",
128 "source",
129 ".",
130 "history",
131 "sudo",
133 "su",
134 "chown",
135 "chmod",
136 "useradd",
137 "userdel",
138 "groupadd",
139 "groupdel",
140 "vi",
142 "vim",
143 "nano",
144 "pico",
145 "emacs",
146 "ed",
147 ];
148 banned_commands
149 .iter()
150 .filter_map(|cmd| {
151 let pattern = format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd));
152 match Regex::new(&pattern) {
153 Ok(regex) => Some(regex),
154 Err(err) => {
155 tracing::error!(
156 target: "tools::bash",
157 command = %cmd,
158 error = %err,
159 "Failed to compile banned command regex"
160 );
161 None
162 }
163 }
164 })
165 .collect()
166 });
167
168 BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
169 }
170}
171
172#[async_trait]
173impl ToolValidator for BashValidator {
174 fn tool_name(&self) -> &'static str {
175 BASH_TOOL_NAME
176 }
177
178 async fn validate(
179 &self,
180 tool_call: &ToolCall,
181 _context: &ValidationContext,
182 ) -> Result<ValidationResult, ValidationError> {
183 let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
184 .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
185
186 if Self::is_banned_command(¶ms.command) {
188 return Ok(ValidationResult {
189 allowed: false,
190 reason: Some(format!(
191 "Command '{}' is disallowed for security reasons",
192 params.command
193 )),
194 requires_user_approval: false,
195 });
196 }
197
198 Ok(ValidationResult {
200 allowed: true,
201 reason: None,
202 requires_user_approval: false,
203 })
204 }
205}