steer_core/app/
validation.rs

1use crate::config::LlmConfigProvider;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use steer_tools::ToolCall;
5use steer_tools::tools::bash::BASH_TOOL_NAME;
6use tokio_util::sync::CancellationToken;
7
8#[derive(Debug)]
9pub struct ValidationContext {
10    pub cancellation_token: CancellationToken,
11    pub llm_config_provider: LlmConfigProvider,
12}
13
/// Outcome of validating a single tool call.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so that
/// results can be stored, copied and compared directly (for example with
/// `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` when the tool call may proceed, `false` when it was rejected.
    pub allowed: bool,
    /// Human-readable explanation; the bash validator sets it when it rejects a
    /// command and leaves it `None` when the command is allowed.
    pub reason: Option<String>,
    /// Whether the user must confirm the call before it runs.
    ///
    /// NOTE(review): every result built in this file sets this to `false`;
    /// confirm the intended semantics with the consumers of this type.
    pub requires_user_approval: bool,
}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ValidationError {
23    #[error("Invalid parameters: {0}")]
24    InvalidParams(String),
25    #[error("IO error: {0}")]
26    Io(#[from] std::io::Error),
27    #[error("Other error: {0}")]
28    Other(String),
29}
30
31#[async_trait]
32pub trait ToolValidator: Send + Sync {
33    fn tool_name(&self) -> &'static str;
34
35    async fn validate(
36        &self,
37        tool_call: &ToolCall,
38        context: &ValidationContext,
39    ) -> Result<ValidationResult, ValidationError>;
40}
41
42pub struct ValidatorRegistry {
43    validators: HashMap<String, Box<dyn ToolValidator>>,
44}
45
46impl Default for ValidatorRegistry {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl ValidatorRegistry {
53    pub fn new() -> Self {
54        let mut validators = HashMap::new();
55
56        // Register tool-specific validators
57        validators.insert(
58            BASH_TOOL_NAME.to_string(),
59            Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
60        );
61
62        Self { validators }
63    }
64
65    pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
66        self.validators.get(tool_name).map(|v| v.as_ref())
67    }
68}
69
70// Bash validator implementation
71use once_cell::sync::Lazy;
72use regex::Regex;
73use serde::Deserialize;
74
75#[derive(Debug, Clone, Deserialize)]
76pub struct BashParams {
77    pub command: String,
78    pub timeout: Option<u64>,
79}
80
/// Stateless validator for bash tool calls; rejects commands that start with a
/// banned executable.
pub struct BashValidator;
82
83impl Default for BashValidator {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl BashValidator {
90    pub fn new() -> Self {
91        Self
92    }
93
94    /// Check if a command is banned (basic, fast check) - matches src/tools/bash.rs
95    fn is_banned_command(&self, command: &str) -> bool {
96        static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
97            let banned_commands = [
98                // Network tools
99                "curl",
100                "wget",
101                "nc",
102                "telnet",
103                "ssh",
104                "scp",
105                "ftp",
106                "sftp",
107                // Web browsers/clients
108                "lynx",
109                "w3m",
110                "links",
111                "elinks",
112                "httpie",
113                "xh",
114                "http-prompt",
115                "chrome",
116                "firefox",
117                "safari",
118                "edge",
119                "opera",
120                "chromium",
121                // Download managers
122                "axel",
123                "aria2c",
124                // Shell utilities that might be risky if misused
125                "alias",
126                "unalias",
127                "exec",
128                "source",
129                ".",
130                "history",
131                // Potentially dangerous system modification tools
132                "sudo",
133                "su",
134                "chown",
135                "chmod",
136                "useradd",
137                "userdel",
138                "groupadd",
139                "groupdel",
140                // File editors (could be used to modify sensitive files)
141                "vi",
142                "vim",
143                "nano",
144                "pico",
145                "emacs",
146                "ed",
147            ];
148            banned_commands
149                .iter()
150                .map(|cmd| {
151                    Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
152                        .expect("Failed to compile banned command regex")
153                })
154                .collect()
155        });
156
157        BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
158    }
159}
160
161#[async_trait]
162impl ToolValidator for BashValidator {
163    fn tool_name(&self) -> &'static str {
164        BASH_TOOL_NAME
165    }
166
167    async fn validate(
168        &self,
169        tool_call: &ToolCall,
170        _context: &ValidationContext,
171    ) -> Result<ValidationResult, ValidationError> {
172        let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
173            .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
174
175        // First check basic banned commands (fast path)
176        if self.is_banned_command(&params.command) {
177            return Ok(ValidationResult {
178                allowed: false,
179                reason: Some(format!(
180                    "Command '{}' is disallowed for security reasons",
181                    params.command
182                )),
183                requires_user_approval: false,
184            });
185        }
186
187        // Command passed all checks
188        Ok(ValidationResult {
189            allowed: true,
190            reason: None,
191            requires_user_approval: false,
192        })
193    }
194}