Skip to main content

steer_core/app/
validation.rs

1use crate::config::LlmConfigProvider;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use steer_tools::ToolCall;
5use steer_tools::tools::bash::BASH_TOOL_NAME;
6use tokio_util::sync::CancellationToken;
7
8#[derive(Debug)]
9pub struct ValidationContext {
10    pub cancellation_token: CancellationToken,
11    pub llm_config_provider: LlmConfigProvider,
12}
13
/// Decision produced by a [`ToolValidator`] for a single tool call.
///
/// A rejection is expressed as `allowed: false` with a `reason`; it is still a
/// successful validation (an `Ok` result), as opposed to a `ValidationError`.
///
/// Derives `Clone`/`PartialEq`/`Eq` so results can be copied, compared, and
/// asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Whether the tool call may proceed.
    pub allowed: bool,
    /// Human-readable explanation (e.g. why a call was rejected), if any.
    pub reason: Option<String>,
    /// Whether the caller should obtain user approval before proceeding.
    pub requires_user_approval: bool,
}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ValidationError {
23    #[error("Invalid parameters: {0}")]
24    InvalidParams(String),
25    #[error("IO error: {0}")]
26    Io(#[from] std::io::Error),
27    #[error("Other error: {0}")]
28    Other(String),
29}
30
31#[async_trait]
32pub trait ToolValidator: Send + Sync {
33    fn tool_name(&self) -> &'static str;
34
35    async fn validate(
36        &self,
37        tool_call: &ToolCall,
38        context: &ValidationContext,
39    ) -> Result<ValidationResult, ValidationError>;
40}
41
42pub struct ValidatorRegistry {
43    validators: HashMap<String, Box<dyn ToolValidator>>,
44}
45
46impl Default for ValidatorRegistry {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl ValidatorRegistry {
53    pub fn new() -> Self {
54        let mut validators = HashMap::new();
55
56        // Register tool-specific validators
57        validators.insert(
58            BASH_TOOL_NAME.to_string(),
59            Box::new(BashValidator::new()) as Box<dyn ToolValidator>,
60        );
61
62        Self { validators }
63    }
64
65    pub fn get_validator(&self, tool_name: &str) -> Option<&dyn ToolValidator> {
66        self.validators.get(tool_name).map(|v| v.as_ref())
67    }
68}
69
70// Bash validator implementation
71use regex::Regex;
72use serde::Deserialize;
73
74#[derive(Debug, Clone, Deserialize)]
75pub struct BashParams {
76    pub command: String,
77    pub timeout: Option<u64>,
78}
79
80pub struct BashValidator;
81
82impl Default for BashValidator {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl BashValidator {
89    pub fn new() -> Self {
90        Self
91    }
92
93    /// Check if a command is banned (basic, fast check) - matches src/tools/bash.rs
94    fn is_banned_command(command: &str) -> bool {
95        static BANNED_COMMAND_REGEXES: std::sync::LazyLock<Vec<Regex>> =
96            std::sync::LazyLock::new(|| {
97                let banned_commands = [
98                    // Network tools
99                    "curl",
100                    "wget",
101                    "nc",
102                    "telnet",
103                    "ssh",
104                    "scp",
105                    "ftp",
106                    "sftp",
107                    // Web browsers/clients
108                    "lynx",
109                    "w3m",
110                    "links",
111                    "elinks",
112                    "httpie",
113                    "xh",
114                    "http-prompt",
115                    "chrome",
116                    "firefox",
117                    "safari",
118                    "edge",
119                    "opera",
120                    "chromium",
121                    // Download managers
122                    "axel",
123                    "aria2c",
124                    // Shell utilities that might be risky if misused
125                    "alias",
126                    "unalias",
127                    "exec",
128                    "source",
129                    ".",
130                    "history",
131                    // Potentially dangerous system modification tools
132                    "sudo",
133                    "su",
134                    "chown",
135                    "chmod",
136                    "useradd",
137                    "userdel",
138                    "groupadd",
139                    "groupdel",
140                    // File editors (could be used to modify sensitive files)
141                    "vi",
142                    "vim",
143                    "nano",
144                    "pico",
145                    "emacs",
146                    "ed",
147                ];
148                banned_commands
149                    .iter()
150                    .filter_map(|cmd| {
151                        let pattern = format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd));
152                        match Regex::new(&pattern) {
153                            Ok(regex) => Some(regex),
154                            Err(err) => {
155                                tracing::error!(
156                                    target: "tools::bash",
157                                    command = %cmd,
158                                    error = %err,
159                                    "Failed to compile banned command regex"
160                                );
161                                None
162                            }
163                        }
164                    })
165                    .collect()
166            });
167
168        BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
169    }
170}
171
172#[async_trait]
173impl ToolValidator for BashValidator {
174    fn tool_name(&self) -> &'static str {
175        BASH_TOOL_NAME
176    }
177
178    async fn validate(
179        &self,
180        tool_call: &ToolCall,
181        _context: &ValidationContext,
182    ) -> Result<ValidationResult, ValidationError> {
183        let params: BashParams = serde_json::from_value(tool_call.parameters.clone())
184            .map_err(|e| ValidationError::InvalidParams(e.to_string()))?;
185
186        // First check basic banned commands (fast path)
187        if Self::is_banned_command(&params.command) {
188            return Ok(ValidationResult {
189                allowed: false,
190                reason: Some(format!(
191                    "Command '{}' is disallowed for security reasons",
192                    params.command
193                )),
194                requires_user_approval: false,
195            });
196        }
197
198        // Command passed all checks
199        Ok(ValidationResult {
200            allowed: true,
201            reason: None,
202            requires_user_approval: false,
203        })
204    }
205}