// steer_core/app/validation.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use steer_tools::tools::bash::BASH_TOOL_NAME;
4use tokio_util::sync::CancellationToken;
5
6use crate::api::ToolCall;
7use crate::config::LlmConfigProvider;
8
9#[derive(Debug)]
10pub struct ValidationContext {
11    pub cancellation_token: CancellationToken,
12    pub llm_config_provider: LlmConfigProvider,
13}
14
/// Outcome of a validation that ran to completion.
///
/// A call that is rejected is still an `Ok(ValidationResult)` with
/// `allowed == false`; `Err(ValidationError)` is reserved for validation
/// that could not be carried out at all.
#[derive(Debug)]
pub struct ValidationResult {
    /// Whether the tool call may proceed.
    pub allowed: bool,
    /// Human-readable explanation, populated by the bash validator when it rejects a call.
    pub reason: Option<String>,
    /// Whether the call should be gated on explicit user approval.
    pub requires_user_approval: bool,
}
21
22#[derive(Debug, thiserror::Error)]
23pub enum ValidationError {
24    #[error("Invalid parameters: {0}")]
25    InvalidParams(String),
26    #[error("Filter error: {0}")]
27    FilterError(String),
28    #[error("IO error: {0}")]
29    Io(#[from] std::io::Error),
30    #[error("Other error: {0}")]
31    Other(String),
32}
33
34#[async_trait]
35pub trait ToolValidator: Send + Sync {
36    fn tool_name(&self) -> &'static str;
37
38    async fn validate(
39        &self,
40        tool_call: &ToolCall,
41        context: &ValidationContext,
42    ) -> Result<ValidationResult, ValidationError>;
43}
44
45pub struct ValidatorRegistry {
46    validators: HashMap<String, Box<dyn ToolValidator>>,
47}
48
49impl Default for ValidatorRegistry {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl ValidatorRegistry {
56    pub fn new() -> Self {
57        let mut validators = HashMap::new();
58
59        // Register tool-specific validators
60        validators.insert(
61            BASH_TOOL_NAME.to_string(),
62            Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
63        );
64
65        Self { validators }
66    }
67
68    pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
69        self.validators.get(tool_name).map(|v| v.as_ref())
70    }
71}
72
73// Bash validator implementation
74use once_cell::sync::Lazy;
75use regex::Regex;
76use serde::Deserialize;
77
78#[derive(Debug, Clone, Deserialize)]
79pub struct BashParams {
80    pub command: String,
81    pub timeout: Option<u64>,
82}
83
84pub struct BashValidator;
85
86impl Default for BashValidator {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl BashValidator {
93    pub fn new() -> Self {
94        Self
95    }
96
97    /// Check if a command is banned (basic, fast check) - matches src/tools/bash.rs
98    fn is_banned_command(&self, command: &str) -> bool {
99        static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
100            let banned_commands = [
101                // Network tools
102                "curl",
103                "wget",
104                "nc",
105                "telnet",
106                "ssh",
107                "scp",
108                "ftp",
109                "sftp",
110                // Web browsers/clients
111                "lynx",
112                "w3m",
113                "links",
114                "elinks",
115                "httpie",
116                "xh",
117                "http-prompt",
118                "chrome",
119                "firefox",
120                "safari",
121                "edge",
122                "opera",
123                "chromium",
124                // Download managers
125                "axel",
126                "aria2c",
127                // Shell utilities that might be risky if misused
128                "alias",
129                "unalias",
130                "exec",
131                "source",
132                ".",
133                "history",
134                // Potentially dangerous system modification tools
135                "sudo",
136                "su",
137                "chown",
138                "chmod",
139                "useradd",
140                "userdel",
141                "groupadd",
142                "groupdel",
143                // File editors (could be used to modify sensitive files)
144                "vi",
145                "vim",
146                "nano",
147                "pico",
148                "emacs",
149                "ed",
150            ];
151            banned_commands
152                .iter()
153                .map(|cmd| {
154                    Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
155                        .expect("Failed to compile banned command regex")
156                })
157                .collect()
158        });
159
160        BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
161    }
162
163    /// Advanced command filtering using existing command_filter module
164    async fn is_command_allowed(
165        &self,
166        command: &str,
167        context: &ValidationContext,
168    ) -> Result<bool, ValidationError> {
169        crate::tools::command_filter::is_command_allowed(
170            command,
171            &context.llm_config_provider,
172            context.cancellation_token.clone(),
173        )
174        .await
175        .map_err(|e| ValidationError::FilterError(e.to_string()))
176    }
177}
178
179#[async_trait]
180impl ToolValidator for BashValidator {
181    fn tool_name(&self) -> &'static str {
182        BASH_TOOL_NAME
183    }
184
185    async fn validate(
186        &self,
187        tool_call: &ToolCall,
188        context: &ValidationContext,
189    ) -> Result<ValidationResult, ValidationError> {
190        let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
191            .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
192
193        // First check basic banned commands (fast path)
194        if self.is_banned_command(&params.command) {
195            return Ok(ValidationResult {
196                allowed: false,
197                reason: Some(format!(
198                    "Command '{}' is disallowed for security reasons",
199                    params.command
200                )),
201                requires_user_approval: false,
202            });
203        }
204
205        // Advanced filter check using LLM (slower, more thorough)
206        let is_allowed = self.is_command_allowed(&params.command, context).await?;
207
208        if !is_allowed {
209            return Ok(ValidationResult {
210                allowed: false,
211                reason: Some(format!(
212                    "Command '{}' was blocked by the command filter",
213                    params.command
214                )),
215                requires_user_approval: false,
216            });
217        }
218
219        // Command passed all checks
220        Ok(ValidationResult {
221            allowed: true,
222            reason: None,
223            requires_user_approval: false,
224        })
225    }
226}