1use crate::TRonError;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::RwLock;
7
/// Outcome of evaluating an agent's policy for a single tool call.
///
/// Produced by `PolicyEngine::check`, which evaluates deny patterns before
/// allow patterns.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PolicyResult {
    /// No deny pattern matched and at least one allow pattern matched the tool.
    Allow,
    /// A deny pattern matched the tool; the string is a human-readable reason.
    Deny(String),
    /// No policy is configured for the requesting agent.
    UnknownAgent,
    /// The agent has a policy, but none of its deny or allow patterns match the tool.
    UnknownTool,
}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct AgentPolicy {
21 #[serde(default)]
22 pub allow: Vec<String>,
23 #[serde(default)]
24 pub deny: Vec<String>,
25 #[serde(default)]
26 pub rate_limit: Option<RateLimitPolicy>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct RateLimitPolicy {
32 pub calls_per_minute: u64,
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct PolicyConfig {
38 #[serde(default)]
39 pub agent: HashMap<String, AgentPolicy>,
40}
41
42pub struct PolicyEngine {
43 config: RwLock<PolicyConfig>,
44}
45
46impl Default for PolicyEngine {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl PolicyEngine {
53 pub fn new() -> Self {
54 Self {
55 config: RwLock::new(PolicyConfig::default()),
56 }
57 }
58
59 #[must_use]
61 pub fn config(&self) -> PolicyConfig {
62 self.config
63 .read()
64 .unwrap_or_else(|poisoned| poisoned.into_inner())
65 .clone()
66 }
67
68 pub fn load_toml(&self, toml_str: &str) -> Result<(), TRonError> {
70 let config: PolicyConfig =
71 toml::from_str(toml_str).map_err(|e| TRonError::PolicyConfig(e.to_string()))?;
72 let mut guard = self
73 .config
74 .write()
75 .unwrap_or_else(|poisoned| poisoned.into_inner());
76 *guard = config;
77 tracing::info!("policy reloaded");
78 Ok(())
79 }
80
81 #[must_use]
83 pub fn check(&self, agent_id: &str, tool_name: &str) -> PolicyResult {
84 let config = self
85 .config
86 .read()
87 .unwrap_or_else(|poisoned| poisoned.into_inner());
88
89 let policy = match config.agent.get(agent_id) {
90 Some(p) => p,
91 None => return PolicyResult::UnknownAgent,
92 };
93
94 for pattern in &policy.deny {
96 if matches_glob(pattern, tool_name) {
97 return PolicyResult::Deny(format!(
98 "tool '{tool_name}' denied by policy for agent '{agent_id}'"
99 ));
100 }
101 }
102
103 for pattern in &policy.allow {
105 if matches_glob(pattern, tool_name) {
106 return PolicyResult::Allow;
107 }
108 }
109
110 PolicyResult::UnknownTool
112 }
113
114 pub fn grant(&self, agent_id: &str, pattern: &str) {
116 let mut config = self
117 .config
118 .write()
119 .unwrap_or_else(|poisoned| poisoned.into_inner());
120 let policy = config
121 .agent
122 .entry(agent_id.to_string())
123 .or_insert_with(|| AgentPolicy {
124 allow: vec![],
125 deny: vec![],
126 rate_limit: None,
127 });
128 policy.allow.push(pattern.to_string());
129 }
130
131 pub fn revoke(&self, agent_id: &str, pattern: &str) {
133 let mut config = self
134 .config
135 .write()
136 .unwrap_or_else(|poisoned| poisoned.into_inner());
137 let policy = config
138 .agent
139 .entry(agent_id.to_string())
140 .or_insert_with(|| AgentPolicy {
141 allow: vec![],
142 deny: vec![],
143 rate_limit: None,
144 });
145 policy.deny.push(pattern.to_string());
146 }
147}
148
/// Reports whether tool `name` matches the glob `pattern`.
///
/// Only a single trailing `*` acts as a wildcard, meaning "any suffix":
/// - `"*"` matches every name, including the empty one.
/// - `"prefix*"` matches names that start with `prefix`.
/// - Every other `*` is an ordinary character, so `"*_delete"` only matches
///   the literal text `*_delete`.
/// - A pattern without a trailing `*` must equal `name` exactly.
fn matches_glob(pattern: &str, name: &str) -> bool {
    match pattern.strip_suffix('*') {
        // A bare "*" leaves an empty prefix, which every name starts with.
        Some(prefix) => name.starts_with(prefix),
        None => name == pattern,
    }
}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn glob_wildcard() {
        assert!(matches_glob("*", "anything"));
        assert!(matches_glob("tarang_*", "tarang_probe"));
        assert!(matches_glob("tarang_*", "tarang_analyze"));
        assert!(!matches_glob("tarang_*", "rasa_edit"));
        assert!(matches_glob("aegis_quarantine", "aegis_quarantine"));
        assert!(!matches_glob("aegis_quarantine", "aegis_scan"));
    }

    #[test]
    fn policy_deny_wins() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        engine.revoke("agent-1", "tarang_delete");

        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("agent-1", "tarang_delete"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn unknown_agent() {
        let engine = PolicyEngine::new();
        assert!(matches!(
            engine.check("nobody", "any_tool"),
            PolicyResult::UnknownAgent
        ));
    }

    #[test]
    fn load_toml_policy() {
        let engine = PolicyEngine::new();
        let toml = r#"
[agent."web-agent"]
allow = ["tarang_*", "rasa_*"]
deny = ["aegis_*"]
"#;
        engine.load_toml(toml).unwrap();
        assert!(matches!(
            engine.check("web-agent", "tarang_probe"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("web-agent", "aegis_scan"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn unknown_tool_for_known_agent() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        assert!(matches!(
            engine.check("agent-1", "rasa_edit"),
            PolicyResult::UnknownTool
        ));
    }

    #[test]
    fn malformed_toml_error() {
        let engine = PolicyEngine::new();
        let result = engine.load_toml("this is not valid toml {{{}}}");
        assert!(result.is_err());
    }

    #[test]
    fn deny_only_policy() {
        let engine = PolicyEngine::new();
        let toml = r#"
[agent."lockdown"]
deny = ["*"]
"#;
        engine.load_toml(toml).unwrap();
        assert!(matches!(
            engine.check("lockdown", "anything"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn allow_only_policy() {
        let engine = PolicyEngine::new();
        let toml = r#"
[agent."open"]
allow = ["*"]
"#;
        engine.load_toml(toml).unwrap();
        assert!(matches!(
            engine.check("open", "anything"),
            PolicyResult::Allow
        ));
    }

    #[test]
    fn reload_policy_replaces_previous() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::Allow
        ));

        // An empty document is a valid, empty policy and wipes runtime grants.
        engine.load_toml("").unwrap();
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::UnknownAgent
        ));
    }

    #[test]
    fn multiple_agents_in_policy() {
        let engine = PolicyEngine::new();
        let toml = r#"
[agent."reader"]
allow = ["tarang_*"]

[agent."admin"]
allow = ["*"]
deny = ["ark_remove"]
"#;
        engine.load_toml(toml).unwrap();
        assert!(matches!(
            engine.check("reader", "tarang_probe"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("reader", "aegis_scan"),
            PolicyResult::UnknownTool
        ));
        assert!(matches!(
            engine.check("admin", "aegis_scan"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("admin", "ark_remove"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn empty_pattern_no_match() {
        assert!(!matches_glob("", "anything"));
        assert!(matches_glob("", ""));
    }

    #[test]
    fn glob_star_suffix_only() {
        // Leading wildcards are not supported; the '*' is matched literally.
        assert!(!matches_glob("*_delete", "tarang_delete"));
    }

    #[test]
    fn glob_trailing_star_edge_cases() {
        assert!(matches_glob("*", ""));
        assert!(matches_glob("tarang_*", "tarang_"));
        assert!(!matches_glob("tarang_*", "tarang"));
        assert!(matches_glob("*_delete", "*_delete"));
    }

    #[test]
    fn rate_limit_parsed_from_toml() {
        let engine = PolicyEngine::new();
        let toml = r#"
[agent."limited"]
allow = ["*"]
[agent."limited".rate_limit]
calls_per_minute = 10

[agent."unlimited"]
allow = ["*"]
"#;
        engine.load_toml(toml).unwrap();
        let config = engine.config();
        let limited = config.agent.get("limited").unwrap();
        assert_eq!(limited.rate_limit.as_ref().unwrap().calls_per_minute, 10);
        let unlimited = config.agent.get("unlimited").unwrap();
        assert!(unlimited.rate_limit.is_none());
    }

    #[test]
    fn config_snapshot() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        let config = engine.config();
        assert!(config.agent.contains_key("agent-1"));
        assert_eq!(config.agent["agent-1"].allow, vec!["tarang_*"]);
    }

    #[test]
    fn config_snapshot_is_detached() {
        let engine = PolicyEngine::new();
        let snapshot = engine.config();
        engine.grant("agent-1", "tarang_*");
        assert!(snapshot.agent.is_empty());
        assert!(engine.config().agent.contains_key("agent-1"));
    }

    #[test]
    fn revoke_creates_agent_with_deny_entry() {
        let engine = PolicyEngine::new();
        engine.revoke("new-agent", "ark_remove");
        assert!(matches!(
            engine.check("new-agent", "ark_remove"),
            PolicyResult::Deny(_)
        ));
        // The agent is now known, so undenied tools are UnknownTool, not UnknownAgent.
        assert!(matches!(
            engine.check("new-agent", "tarang_probe"),
            PolicyResult::UnknownTool
        ));
    }

    #[test]
    fn deny_reason_names_tool_and_agent() {
        let engine = PolicyEngine::new();
        engine.revoke("agent-1", "ark_*");
        if let PolicyResult::Deny(reason) = engine.check("agent-1", "ark_remove") {
            assert!(reason.contains("ark_remove"));
            assert!(reason.contains("agent-1"));
        } else {
            panic!("expected PolicyResult::Deny");
        }
    }

    #[test]
    fn malformed_reload_keeps_previous_policy() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        assert!(engine.load_toml("not = [valid").is_err());
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::Allow
        ));
    }
}