steer_core/app/
validation.rs1use crate::config::LlmConfigProvider;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use steer_tools::ToolCall;
5use steer_tools::tools::bash::BASH_TOOL_NAME;
6use tokio_util::sync::CancellationToken;
7
8#[derive(Debug)]
9pub struct ValidationContext {
10 pub cancellation_token: CancellationToken,
11 pub llm_config_provider: LlmConfigProvider,
12}
13
/// Verdict produced by a [`ToolValidator`] for a single tool call.
#[derive(Debug)]
pub struct ValidationResult {
    /// `true` when the tool call is permitted to proceed.
    pub allowed: bool,
    /// Optional human-readable explanation; validators in this module set it
    /// when rejecting a call and leave it `None` when allowing one.
    pub reason: Option<String>,
    /// `true` when the call should be confirmed by the user before it runs.
    pub requires_user_approval: bool,
}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ValidationError {
23 #[error("Invalid parameters: {0}")]
24 InvalidParams(String),
25 #[error("IO error: {0}")]
26 Io(#[from] std::io::Error),
27 #[error("Other error: {0}")]
28 Other(String),
29}
30
31#[async_trait]
32pub trait ToolValidator: Send + Sync {
33 fn tool_name(&self) -> &'static str;
34
35 async fn validate(
36 &self,
37 tool_call: &ToolCall,
38 context: &ValidationContext,
39 ) -> Result<ValidationResult, ValidationError>;
40}
41
42pub struct ValidatorRegistry {
43 validators: HashMap<String, Box<dyn ToolValidator>>,
44}
45
46impl Default for ValidatorRegistry {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl ValidatorRegistry {
53 pub fn new() -> Self {
54 let mut validators = HashMap::new();
55
56 validators.insert(
58 BASH_TOOL_NAME.to_string(),
59 Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
60 );
61
62 Self { validators }
63 }
64
65 pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
66 self.validators.get(tool_name).map(|v| v.as_ref())
67 }
68}
69
70use once_cell::sync::Lazy;
72use regex::Regex;
73use serde::Deserialize;
74
75#[derive(Debug, Clone, Deserialize)]
76pub struct BashParams {
77 pub command: String,
78 pub timeout: Option<u64>,
79}
80
/// Stateless validator for the bash tool that rejects command lines invoking
/// a fixed denylist of programs.
pub struct BashValidator;
82
83impl Default for BashValidator {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl BashValidator {
90 pub fn new() -> Self {
91 Self
92 }
93
94 fn is_banned_command(&self, command: &str) -> bool {
96 static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
97 let banned_commands = [
98 "curl",
100 "wget",
101 "nc",
102 "telnet",
103 "ssh",
104 "scp",
105 "ftp",
106 "sftp",
107 "lynx",
109 "w3m",
110 "links",
111 "elinks",
112 "httpie",
113 "xh",
114 "http-prompt",
115 "chrome",
116 "firefox",
117 "safari",
118 "edge",
119 "opera",
120 "chromium",
121 "axel",
123 "aria2c",
124 "alias",
126 "unalias",
127 "exec",
128 "source",
129 ".",
130 "history",
131 "sudo",
133 "su",
134 "chown",
135 "chmod",
136 "useradd",
137 "userdel",
138 "groupadd",
139 "groupdel",
140 "vi",
142 "vim",
143 "nano",
144 "pico",
145 "emacs",
146 "ed",
147 ];
148 banned_commands
149 .iter()
150 .map(|cmd| {
151 Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
152 .expect("Failed to compile banned command regex")
153 })
154 .collect()
155 });
156
157 BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
158 }
159}
160
161#[async_trait]
162impl ToolValidator for BashValidator {
163 fn tool_name(&self) -> &'static str {
164 BASH_TOOL_NAME
165 }
166
167 async fn validate(
168 &self,
169 tool_call: &ToolCall,
170 _context: &ValidationContext,
171 ) -> Result<ValidationResult, ValidationError> {
172 let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
173 .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
174
175 if self.is_banned_command(¶ms.command) {
177 return Ok(ValidationResult {
178 allowed: false,
179 reason: Some(format!(
180 "Command '{}' is disallowed for security reasons",
181 params.command
182 )),
183 requires_user_approval: false,
184 });
185 }
186
187 Ok(ValidationResult {
189 allowed: true,
190 reason: None,
191 requires_user_approval: false,
192 })
193 }
194}