steer_core/app/
validation.rs

1use crate::config::LlmConfigProvider;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use steer_tools::ToolCall;
5use steer_tools::tools::bash::BASH_TOOL_NAME;
6use tokio_util::sync::CancellationToken;
7
8#[derive(Debug)]
9pub struct ValidationContext {
10    pub cancellation_token: CancellationToken,
11    pub llm_config_provider: LlmConfigProvider,
12}
13
/// Verdict produced by a [`ToolValidator`] for one tool call.
#[derive(Debug)]
pub struct ValidationResult {
    /// `true` when the tool call passed validation and may proceed.
    pub allowed: bool,
    /// Human-readable explanation; the bash validator sets it only when it
    /// rejects a command and leaves it `None` on success.
    pub reason: Option<String>,
    /// Flag for callers that gate execution on explicit user consent.
    // NOTE(review): the bash validator in this file always sets this to
    // `false`; the exact contract for `true` is defined by callers — confirm.
    pub requires_user_approval: bool,
}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ValidationError {
23    #[error("Invalid parameters: {0}")]
24    InvalidParams(String),
25    #[error("Filter error: {0}")]
26    FilterError(String),
27    #[error("IO error: {0}")]
28    Io(#[from] std::io::Error),
29    #[error("Other error: {0}")]
30    Other(String),
31}
32
33#[async_trait]
34pub trait ToolValidator: Send + Sync {
35    fn tool_name(&self) -> &'static str;
36
37    async fn validate(
38        &self,
39        tool_call: &ToolCall,
40        context: &ValidationContext,
41    ) -> Result<ValidationResult, ValidationError>;
42}
43
44pub struct ValidatorRegistry {
45    validators: HashMap<String, Box<dyn ToolValidator>>,
46}
47
48impl Default for ValidatorRegistry {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl ValidatorRegistry {
55    pub fn new() -> Self {
56        let mut validators = HashMap::new();
57
58        // Register tool-specific validators
59        validators.insert(
60            BASH_TOOL_NAME.to_string(),
61            Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
62        );
63
64        Self { validators }
65    }
66
67    pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
68        self.validators.get(tool_name).map(|v| v.as_ref())
69    }
70}
71
72// Bash validator implementation
73use once_cell::sync::Lazy;
74use regex::Regex;
75use serde::Deserialize;
76
77#[derive(Debug, Clone, Deserialize)]
78pub struct BashParams {
79    pub command: String,
80    pub timeout: Option<u64>,
81}
82
/// Validator for bash tool calls.
///
/// Holds no state: it combines a static banned-command list with the
/// crate's command filter.
pub struct BashValidator;
84
85impl Default for BashValidator {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl BashValidator {
92    pub fn new() -> Self {
93        Self
94    }
95
96    /// Check if a command is banned (basic, fast check) - matches src/tools/bash.rs
97    fn is_banned_command(&self, command: &str) -> bool {
98        static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
99            let banned_commands = [
100                // Network tools
101                "curl",
102                "wget",
103                "nc",
104                "telnet",
105                "ssh",
106                "scp",
107                "ftp",
108                "sftp",
109                // Web browsers/clients
110                "lynx",
111                "w3m",
112                "links",
113                "elinks",
114                "httpie",
115                "xh",
116                "http-prompt",
117                "chrome",
118                "firefox",
119                "safari",
120                "edge",
121                "opera",
122                "chromium",
123                // Download managers
124                "axel",
125                "aria2c",
126                // Shell utilities that might be risky if misused
127                "alias",
128                "unalias",
129                "exec",
130                "source",
131                ".",
132                "history",
133                // Potentially dangerous system modification tools
134                "sudo",
135                "su",
136                "chown",
137                "chmod",
138                "useradd",
139                "userdel",
140                "groupadd",
141                "groupdel",
142                // File editors (could be used to modify sensitive files)
143                "vi",
144                "vim",
145                "nano",
146                "pico",
147                "emacs",
148                "ed",
149            ];
150            banned_commands
151                .iter()
152                .map(|cmd| {
153                    Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
154                        .expect("Failed to compile banned command regex")
155                })
156                .collect()
157        });
158
159        BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
160    }
161
162    /// Advanced command filtering using existing command_filter module
163    async fn is_command_allowed(
164        &self,
165        command: &str,
166        context: &ValidationContext,
167    ) -> Result<bool, ValidationError> {
168        crate::tools::command_filter::is_command_allowed(
169            command,
170            &context.llm_config_provider,
171            context.cancellation_token.clone(),
172        )
173        .await
174        .map_err(|e| ValidationError::FilterError(e.to_string()))
175    }
176}
177
178#[async_trait]
179impl ToolValidator for BashValidator {
180    fn tool_name(&self) -> &'static str {
181        BASH_TOOL_NAME
182    }
183
184    async fn validate(
185        &self,
186        tool_call: &ToolCall,
187        context: &ValidationContext,
188    ) -> Result<ValidationResult, ValidationError> {
189        let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
190            .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
191
192        // First check basic banned commands (fast path)
193        if self.is_banned_command(&params.command) {
194            return Ok(ValidationResult {
195                allowed: false,
196                reason: Some(format!(
197                    "Command '{}' is disallowed for security reasons",
198                    params.command
199                )),
200                requires_user_approval: false,
201            });
202        }
203
204        // Advanced filter check using LLM (slower, more thorough)
205        let is_allowed = self.is_command_allowed(&params.command, context).await?;
206
207        if !is_allowed {
208            return Ok(ValidationResult {
209                allowed: false,
210                reason: Some(format!(
211                    "Command '{}' was blocked by the command filter",
212                    params.command
213                )),
214                requires_user_approval: false,
215            });
216        }
217
218        // Command passed all checks
219        Ok(ValidationResult {
220            allowed: true,
221            reason: None,
222            requires_user_approval: false,
223        })
224    }
225}