tool_useful/
security.rs

1//! Security primitives: permissions, resource limits, and sandboxing.
2
3use crate::{ToolError, ToolResult};
4use parking_lot::RwLock;
5use std::collections::HashSet;
6use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10/// Permission system for tools
11#[derive(Debug, Clone)]
12pub struct Permissions {
13    pub network: NetworkPermission,
14    pub filesystem: FileSystemPermission,
15    pub max_memory_bytes: Option<usize>,
16    pub max_cpu_time: Option<Duration>,
17    pub allowed_syscalls: Option<HashSet<String>>,
18}
19
20impl Default for Permissions {
21    fn default() -> Self {
22        Self {
23            network: NetworkPermission::Deny,
24            filesystem: FileSystemPermission::Deny,
25            max_memory_bytes: Some(100_000_000), // 100MB default
26            max_cpu_time: Some(Duration::from_secs(30)),
27            allowed_syscalls: None,
28        }
29    }
30}
31
32impl Permissions {
33    pub fn unrestricted() -> Self {
34        Self {
35            network: NetworkPermission::Allow,
36            filesystem: FileSystemPermission::Allow,
37            max_memory_bytes: None,
38            max_cpu_time: None,
39            allowed_syscalls: None,
40        }
41    }
42
43    pub fn builder() -> PermissionsBuilder {
44        PermissionsBuilder::new()
45    }
46
47    pub fn check_network_access(&self, host: &str) -> ToolResult<()> {
48        match &self.network {
49            NetworkPermission::Allow => Ok(()),
50            NetworkPermission::Deny => Err(ToolError::permission_denied("Network access denied")),
51            NetworkPermission::AllowList(allowed) => {
52                if allowed.iter().any(|pattern| matches_pattern(host, pattern)) {
53                    Ok(())
54                } else {
55                    Err(ToolError::permission_denied(format!(
56                        "Network access to {} not allowed",
57                        host
58                    )))
59                }
60            }
61            NetworkPermission::DenyList(denied) => {
62                if denied.iter().any(|pattern| matches_pattern(host, pattern)) {
63                    Err(ToolError::permission_denied(format!(
64                        "Network access to {} explicitly denied",
65                        host
66                    )))
67                } else {
68                    Ok(())
69                }
70            }
71        }
72    }
73
74    pub fn check_file_access(&self, path: &std::path::Path) -> ToolResult<()> {
75        match &self.filesystem {
76            FileSystemPermission::Allow => Ok(()),
77            FileSystemPermission::Deny => {
78                Err(ToolError::permission_denied("Filesystem access denied"))
79            }
80            FileSystemPermission::ReadOnly(paths) => {
81                if paths.iter().any(|allowed| path.starts_with(allowed)) {
82                    Ok(())
83                } else {
84                    Err(ToolError::permission_denied(format!(
85                        "Filesystem access to {:?} not allowed",
86                        path
87                    )))
88                }
89            }
90            FileSystemPermission::AllowList(paths) => {
91                if paths.iter().any(|allowed| path.starts_with(allowed)) {
92                    Ok(())
93                } else {
94                    Err(ToolError::permission_denied(format!(
95                        "Filesystem access to {:?} not allowed",
96                        path
97                    )))
98                }
99            }
100        }
101    }
102}
103
/// Network access policy, checked by `Permissions::check_network_access`.
///
/// Host patterns may use `*` as a wildcard (for example `*.example.com`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkPermission {
    /// Any host may be contacted.
    Allow,
    /// No host may be contacted.
    Deny,
    /// Only hosts matching one of these patterns may be contacted.
    AllowList(Vec<String>),
    /// Every host except those matching one of these patterns may be contacted.
    DenyList(Vec<String>),
}
111
/// Filesystem access policy, checked by `Permissions::check_file_access`.
///
/// Listed paths are roots: access is granted to anything beneath them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileSystemPermission {
    /// Any path may be accessed.
    Allow,
    /// No path may be accessed.
    Deny,
    /// Paths beneath these roots may be accessed; intended to be read-only.
    ReadOnly(Vec<std::path::PathBuf>),
    /// Paths beneath these roots may be accessed.
    AllowList(Vec<std::path::PathBuf>),
}
119
120pub struct PermissionsBuilder {
121    permissions: Permissions,
122}
123
124impl PermissionsBuilder {
125    pub fn new() -> Self {
126        Self {
127            permissions: Permissions::default(),
128        }
129    }
130
131    pub fn allow_network(mut self) -> Self {
132        self.permissions.network = NetworkPermission::Allow;
133        self
134    }
135
136    pub fn allow_network_hosts(mut self, hosts: Vec<String>) -> Self {
137        self.permissions.network = NetworkPermission::AllowList(hosts);
138        self
139    }
140
141    pub fn deny_network(mut self) -> Self {
142        self.permissions.network = NetworkPermission::Deny;
143        self
144    }
145
146    pub fn allow_filesystem(mut self) -> Self {
147        self.permissions.filesystem = FileSystemPermission::Allow;
148        self
149    }
150
151    pub fn allow_filesystem_paths(mut self, paths: Vec<std::path::PathBuf>) -> Self {
152        self.permissions.filesystem = FileSystemPermission::AllowList(paths);
153        self
154    }
155
156    pub fn readonly_filesystem(mut self, paths: Vec<std::path::PathBuf>) -> Self {
157        self.permissions.filesystem = FileSystemPermission::ReadOnly(paths);
158        self
159    }
160
161    pub fn max_memory(mut self, bytes: usize) -> Self {
162        self.permissions.max_memory_bytes = Some(bytes);
163        self
164    }
165
166    pub fn max_cpu_time(mut self, duration: Duration) -> Self {
167        self.permissions.max_cpu_time = Some(duration);
168        self
169    }
170
171    pub fn build(self) -> Permissions {
172        self.permissions
173    }
174}
175
176impl Default for PermissionsBuilder {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182/// Resource tracker for monitoring and enforcing limits
183pub struct ResourceTracker {
184    permissions: Arc<Permissions>,
185    memory_used: AtomicUsize,
186    start_time: Instant,
187    cpu_time: Arc<RwLock<Duration>>,
188}
189
190impl ResourceTracker {
191    pub fn new(permissions: Permissions) -> Self {
192        Self {
193            permissions: Arc::new(permissions),
194            memory_used: AtomicUsize::new(0),
195            start_time: Instant::now(),
196            cpu_time: Arc::new(RwLock::new(Duration::ZERO)),
197        }
198    }
199
200    pub fn track_memory_allocation(&self, bytes: usize) -> ToolResult<()> {
201        let new_total = self.memory_used.fetch_add(bytes, Ordering::Relaxed) + bytes;
202
203        if let Some(max) = self.permissions.max_memory_bytes {
204            if new_total > max {
205                return Err(ToolError::ResourceLimitExceeded(format!(
206                    "Memory limit exceeded: {} > {}",
207                    new_total, max
208                )));
209            }
210        }
211
212        Ok(())
213    }
214
215    pub fn track_memory_deallocation(&self, bytes: usize) {
216        self.memory_used.fetch_sub(bytes, Ordering::Relaxed);
217    }
218
219    pub fn check_cpu_time(&self) -> ToolResult<()> {
220        if let Some(max_time) = self.permissions.max_cpu_time {
221            let elapsed = self.start_time.elapsed();
222            if elapsed > max_time {
223                return Err(ToolError::ResourceLimitExceeded(format!(
224                    "CPU time limit exceeded: {:?} > {:?}",
225                    elapsed, max_time
226                )));
227            }
228        }
229        Ok(())
230    }
231
232    pub fn memory_usage(&self) -> usize {
233        self.memory_used.load(Ordering::Relaxed)
234    }
235
236    pub fn elapsed_time(&self) -> Duration {
237        self.start_time.elapsed()
238    }
239}
240
241/// Rate limiter for API calls or tool executions
242pub struct RateLimiter {
243    tokens: Arc<AtomicU64>,
244    max_tokens: u64,
245    refill_rate: u64,
246    last_refill: Arc<RwLock<Instant>>,
247    refill_interval: Duration,
248}
249
250impl RateLimiter {
251    pub fn new(max_requests: u64, duration: Duration) -> Self {
252        Self {
253            tokens: Arc::new(AtomicU64::new(max_requests)),
254            max_tokens: max_requests,
255            refill_rate: max_requests,
256            last_refill: Arc::new(RwLock::new(Instant::now())),
257            refill_interval: duration,
258        }
259    }
260
261    pub fn per_second(requests: u64) -> Self {
262        Self::new(requests, Duration::from_secs(1))
263    }
264
265    pub fn per_minute(requests: u64) -> Self {
266        Self::new(requests, Duration::from_secs(60))
267    }
268
269    pub async fn acquire(&self) -> ToolResult<()> {
270        loop {
271            self.refill();
272
273            let current = self.tokens.load(Ordering::Relaxed);
274            if current > 0 {
275                if self
276                    .tokens
277                    .compare_exchange(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
278                    .is_ok()
279                {
280                    return Ok(());
281                }
282            } else {
283                // Wait a bit before retrying
284                tokio::time::sleep(Duration::from_millis(10)).await;
285            }
286        }
287    }
288
289    fn refill(&self) {
290        let mut last_refill = self.last_refill.write();
291        let elapsed = last_refill.elapsed();
292
293        if elapsed >= self.refill_interval {
294            let periods = elapsed.as_secs_f64() / self.refill_interval.as_secs_f64();
295            let tokens_to_add = (self.refill_rate as f64 * periods) as u64;
296
297            let current = self.tokens.load(Ordering::Relaxed);
298            let new_tokens = (current + tokens_to_add).min(self.max_tokens);
299            self.tokens.store(new_tokens, Ordering::Relaxed);
300
301            *last_refill = Instant::now();
302        }
303    }
304}
305
/// Matches `text` against a glob `pattern` in which `*` matches any run of characters.
///
/// A pattern without `*` must equal `text` exactly. Otherwise:
/// - the text must start with the part before the first `*`;
/// - the middle parts must appear in order after it;
/// - the text must end with the part after the last `*`.
///
/// Each part consumes the text it matches, so the prefix, middle parts, and suffix
/// never overlap. For example, `a*a` does not match `a`.
fn matches_pattern(text: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return text == pattern;
    }

    let parts: Vec<&str> = pattern.split('*').collect();
    let (first, middle, last) = match parts.as_slice() {
        [first, middle @ .., last] => (*first, middle, *last),
        // Unreachable: a pattern containing '*' always splits into at least two parts.
        _ => return false,
    };

    let mut remaining = match text.strip_prefix(first) {
        Some(rest) => rest,
        None => return false,
    };

    for &part in middle {
        if part.is_empty() {
            continue;
        }
        // The leftmost occurrence leaves the most text for the later parts, so it is
        // always the right choice for `*`-only globs.
        match remaining.find(part) {
            Some(index) => remaining = &remaining[index + part.len()..],
            None => return false,
        }
    }

    remaining.ends_with(last)
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    #[test]
    fn test_pattern_matching() {
        assert!(matches_pattern("api.example.com", "*.example.com"));
        assert!(matches_pattern("api.example.com", "api.*"));
        assert!(matches_pattern("api.example.com", "*"));
        assert!(!matches_pattern("api.example.com", "*.other.com"));
        assert!(matches_pattern("example.com", "example.com"));
        assert!(!matches_pattern("example.org", "example.com"));
    }

    #[test]
    fn test_network_permissions() {
        let allow_list = Permissions::builder()
            .allow_network_hosts(vec!["*.example.com".to_string()])
            .build();
        assert!(allow_list.check_network_access("api.example.com").is_ok());
        assert!(allow_list.check_network_access("evil.com").is_err());

        assert!(Permissions::default().check_network_access("api.example.com").is_err());
        assert!(Permissions::unrestricted().check_network_access("anything").is_ok());
    }

    #[test]
    fn test_filesystem_permissions() {
        let perms = Permissions::builder()
            .allow_filesystem_paths(vec![PathBuf::from("/tmp/sandbox")])
            .build();
        assert!(perms.check_file_access(Path::new("/tmp/sandbox/file.txt")).is_ok());
        assert!(perms.check_file_access(Path::new("/etc/passwd")).is_err());
        // Roots are matched per path component, not as string prefixes.
        assert!(perms.check_file_access(Path::new("/tmp/sandbox2/file.txt")).is_err());

        assert!(Permissions::default().check_file_access(Path::new("/tmp")).is_err());
    }

    #[test]
    fn test_resource_tracker() {
        let permissions = Permissions::builder().max_memory(1000).build();

        let tracker = ResourceTracker::new(permissions);

        assert!(tracker.track_memory_allocation(500).is_ok());
        assert!(tracker.track_memory_allocation(400).is_ok());
        assert_eq!(tracker.memory_usage(), 900);
        assert!(tracker.track_memory_allocation(200).is_err()); // Exceeds limit
    }

    #[test]
    fn test_memory_deallocation() {
        let tracker = ResourceTracker::new(Permissions::default());
        assert!(tracker.track_memory_allocation(500).is_ok());
        tracker.track_memory_deallocation(200);
        assert_eq!(tracker.memory_usage(), 300);
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::per_second(2);

        assert!(limiter.acquire().await.is_ok());
        assert!(limiter.acquire().await.is_ok());

        // The bucket is now empty, so a third request must not complete until the
        // next refill, roughly one second away.
        let third = tokio::time::timeout(Duration::from_millis(50), limiter.acquire()).await;
        assert!(third.is_err(), "third request should wait for a refill");
    }
}