1use crate::{ToolError, ToolResult};
4use parking_lot::RwLock;
5use std::collections::HashSet;
6use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone)]
12pub struct Permissions {
13 pub network: NetworkPermission,
14 pub filesystem: FileSystemPermission,
15 pub max_memory_bytes: Option<usize>,
16 pub max_cpu_time: Option<Duration>,
17 pub allowed_syscalls: Option<HashSet<String>>,
18}
19
20impl Default for Permissions {
21 fn default() -> Self {
22 Self {
23 network: NetworkPermission::Deny,
24 filesystem: FileSystemPermission::Deny,
25 max_memory_bytes: Some(100_000_000), max_cpu_time: Some(Duration::from_secs(30)),
27 allowed_syscalls: None,
28 }
29 }
30}
31
32impl Permissions {
33 pub fn unrestricted() -> Self {
34 Self {
35 network: NetworkPermission::Allow,
36 filesystem: FileSystemPermission::Allow,
37 max_memory_bytes: None,
38 max_cpu_time: None,
39 allowed_syscalls: None,
40 }
41 }
42
43 pub fn builder() -> PermissionsBuilder {
44 PermissionsBuilder::new()
45 }
46
47 pub fn check_network_access(&self, host: &str) -> ToolResult<()> {
48 match &self.network {
49 NetworkPermission::Allow => Ok(()),
50 NetworkPermission::Deny => Err(ToolError::permission_denied("Network access denied")),
51 NetworkPermission::AllowList(allowed) => {
52 if allowed.iter().any(|pattern| matches_pattern(host, pattern)) {
53 Ok(())
54 } else {
55 Err(ToolError::permission_denied(format!(
56 "Network access to {} not allowed",
57 host
58 )))
59 }
60 }
61 NetworkPermission::DenyList(denied) => {
62 if denied.iter().any(|pattern| matches_pattern(host, pattern)) {
63 Err(ToolError::permission_denied(format!(
64 "Network access to {} explicitly denied",
65 host
66 )))
67 } else {
68 Ok(())
69 }
70 }
71 }
72 }
73
74 pub fn check_file_access(&self, path: &std::path::Path) -> ToolResult<()> {
75 match &self.filesystem {
76 FileSystemPermission::Allow => Ok(()),
77 FileSystemPermission::Deny => {
78 Err(ToolError::permission_denied("Filesystem access denied"))
79 }
80 FileSystemPermission::ReadOnly(paths) => {
81 if paths.iter().any(|allowed| path.starts_with(allowed)) {
82 Ok(())
83 } else {
84 Err(ToolError::permission_denied(format!(
85 "Filesystem access to {:?} not allowed",
86 path
87 )))
88 }
89 }
90 FileSystemPermission::AllowList(paths) => {
91 if paths.iter().any(|allowed| path.starts_with(allowed)) {
92 Ok(())
93 } else {
94 Err(ToolError::permission_denied(format!(
95 "Filesystem access to {:?} not allowed",
96 path
97 )))
98 }
99 }
100 }
101 }
102}
103
/// Network policy applied by `Permissions::check_network_access`.
#[derive(Debug, Clone)]
pub enum NetworkPermission {
    /// Every host may be contacted.
    Allow,
    /// No host may be contacted.
    Deny,
    /// Only hosts matching one of these (possibly `*`-wildcarded) patterns.
    AllowList(Vec<String>),
    /// Every host except those matching one of these patterns.
    DenyList(Vec<String>),
}
111
/// Filesystem policy applied by `Permissions::check_file_access`.
#[derive(Debug, Clone)]
pub enum FileSystemPermission {
    /// Any path may be accessed.
    Allow,
    /// No path may be accessed.
    Deny,
    /// Paths under these roots are permitted (intended as read-only access).
    ReadOnly(Vec<::std::path::PathBuf>),
    /// Paths under these roots are permitted.
    AllowList(Vec<::std::path::PathBuf>),
}
119
120pub struct PermissionsBuilder {
121 permissions: Permissions,
122}
123
124impl PermissionsBuilder {
125 pub fn new() -> Self {
126 Self {
127 permissions: Permissions::default(),
128 }
129 }
130
131 pub fn allow_network(mut self) -> Self {
132 self.permissions.network = NetworkPermission::Allow;
133 self
134 }
135
136 pub fn allow_network_hosts(mut self, hosts: Vec<String>) -> Self {
137 self.permissions.network = NetworkPermission::AllowList(hosts);
138 self
139 }
140
141 pub fn deny_network(mut self) -> Self {
142 self.permissions.network = NetworkPermission::Deny;
143 self
144 }
145
146 pub fn allow_filesystem(mut self) -> Self {
147 self.permissions.filesystem = FileSystemPermission::Allow;
148 self
149 }
150
151 pub fn allow_filesystem_paths(mut self, paths: Vec<std::path::PathBuf>) -> Self {
152 self.permissions.filesystem = FileSystemPermission::AllowList(paths);
153 self
154 }
155
156 pub fn readonly_filesystem(mut self, paths: Vec<std::path::PathBuf>) -> Self {
157 self.permissions.filesystem = FileSystemPermission::ReadOnly(paths);
158 self
159 }
160
161 pub fn max_memory(mut self, bytes: usize) -> Self {
162 self.permissions.max_memory_bytes = Some(bytes);
163 self
164 }
165
166 pub fn max_cpu_time(mut self, duration: Duration) -> Self {
167 self.permissions.max_cpu_time = Some(duration);
168 self
169 }
170
171 pub fn build(self) -> Permissions {
172 self.permissions
173 }
174}
175
176impl Default for PermissionsBuilder {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182pub struct ResourceTracker {
184 permissions: Arc<Permissions>,
185 memory_used: AtomicUsize,
186 start_time: Instant,
187 cpu_time: Arc<RwLock<Duration>>,
188}
189
190impl ResourceTracker {
191 pub fn new(permissions: Permissions) -> Self {
192 Self {
193 permissions: Arc::new(permissions),
194 memory_used: AtomicUsize::new(0),
195 start_time: Instant::now(),
196 cpu_time: Arc::new(RwLock::new(Duration::ZERO)),
197 }
198 }
199
200 pub fn track_memory_allocation(&self, bytes: usize) -> ToolResult<()> {
201 let new_total = self.memory_used.fetch_add(bytes, Ordering::Relaxed) + bytes;
202
203 if let Some(max) = self.permissions.max_memory_bytes {
204 if new_total > max {
205 return Err(ToolError::ResourceLimitExceeded(format!(
206 "Memory limit exceeded: {} > {}",
207 new_total, max
208 )));
209 }
210 }
211
212 Ok(())
213 }
214
215 pub fn track_memory_deallocation(&self, bytes: usize) {
216 self.memory_used.fetch_sub(bytes, Ordering::Relaxed);
217 }
218
219 pub fn check_cpu_time(&self) -> ToolResult<()> {
220 if let Some(max_time) = self.permissions.max_cpu_time {
221 let elapsed = self.start_time.elapsed();
222 if elapsed > max_time {
223 return Err(ToolError::ResourceLimitExceeded(format!(
224 "CPU time limit exceeded: {:?} > {:?}",
225 elapsed, max_time
226 )));
227 }
228 }
229 Ok(())
230 }
231
232 pub fn memory_usage(&self) -> usize {
233 self.memory_used.load(Ordering::Relaxed)
234 }
235
236 pub fn elapsed_time(&self) -> Duration {
237 self.start_time.elapsed()
238 }
239}
240
241pub struct RateLimiter {
243 tokens: Arc<AtomicU64>,
244 max_tokens: u64,
245 refill_rate: u64,
246 last_refill: Arc<RwLock<Instant>>,
247 refill_interval: Duration,
248}
249
250impl RateLimiter {
251 pub fn new(max_requests: u64, duration: Duration) -> Self {
252 Self {
253 tokens: Arc::new(AtomicU64::new(max_requests)),
254 max_tokens: max_requests,
255 refill_rate: max_requests,
256 last_refill: Arc::new(RwLock::new(Instant::now())),
257 refill_interval: duration,
258 }
259 }
260
261 pub fn per_second(requests: u64) -> Self {
262 Self::new(requests, Duration::from_secs(1))
263 }
264
265 pub fn per_minute(requests: u64) -> Self {
266 Self::new(requests, Duration::from_secs(60))
267 }
268
269 pub async fn acquire(&self) -> ToolResult<()> {
270 loop {
271 self.refill();
272
273 let current = self.tokens.load(Ordering::Relaxed);
274 if current > 0 {
275 if self
276 .tokens
277 .compare_exchange(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
278 .is_ok()
279 {
280 return Ok(());
281 }
282 } else {
283 tokio::time::sleep(Duration::from_millis(10)).await;
285 }
286 }
287 }
288
289 fn refill(&self) {
290 let mut last_refill = self.last_refill.write();
291 let elapsed = last_refill.elapsed();
292
293 if elapsed >= self.refill_interval {
294 let periods = elapsed.as_secs_f64() / self.refill_interval.as_secs_f64();
295 let tokens_to_add = (self.refill_rate as f64 * periods) as u64;
296
297 let current = self.tokens.load(Ordering::Relaxed);
298 let new_tokens = (current + tokens_to_add).min(self.max_tokens);
299 self.tokens.store(new_tokens, Ordering::Relaxed);
300
301 *last_refill = Instant::now();
302 }
303 }
304}
305
/// Glob-style match where `*` stands for any run of characters, including
/// an empty one.
///
/// No other character is special, and the comparison is exact
/// (case-sensitive). A pattern without `*` must equal `text`.
///
/// The pattern is split into a literal prefix (before the first `*`), a
/// literal suffix (after the last `*`), and middle segments. The prefix is
/// matched at the start. Each middle segment is matched at its leftmost
/// position in the unconsumed text. The suffix must end what is left. This
/// keeps the suffix from overlapping text already consumed by the prefix, so
/// `"aba"` does not match `"ab*ba"`.
fn matches_pattern(text: &str, pattern: &str) -> bool {
    let (head, tail) = match pattern.split_once('*') {
        Some(split) => split,
        None => return text == pattern,
    };
    // Without a second `*` the whole tail is the suffix.
    let (middle, suffix) = tail.rsplit_once('*').unwrap_or(("", tail));

    let mut rest = match text.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };

    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }

    rest.ends_with(suffix)
}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Wildcard patterns accept and reject the expected hostnames.
    #[test]
    fn test_pattern_matching() {
        let host = "api.example.com";
        let accepted = ["*.example.com", "api.*", "*"];
        for pattern in accepted.iter() {
            assert!(matches_pattern(host, pattern));
        }
        assert!(!matches_pattern(host, "*.other.com"));
    }

    /// Allocations succeed until the 1000-byte ceiling would be crossed.
    #[test]
    fn test_resource_tracker() {
        let tracker = ResourceTracker::new(Permissions::builder().max_memory(1000).build());
        let outcomes: Vec<bool> = [500, 400, 200]
            .iter()
            .map(|&bytes| tracker.track_memory_allocation(bytes).is_ok())
            .collect();
        assert_eq!(outcomes, vec![true, true, false]);
    }

    /// A fresh two-per-second limiter hands out its two initial tokens.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::per_second(2);
        for _ in 0..2 {
            assert!(limiter.acquire().await.is_ok());
        }
    }
}