1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6 policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10 pub fn new(policies: Vec<PolicyRule>) -> Self {
11 Self { policies }
12 }
13
14 pub fn with_defaults() -> Self {
15 Self::new(default_policies())
16 }
17
18 pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19 self.policies
20 .iter()
21 .filter(|p| p.enabled)
22 .map(|p| evaluate_policy(p, trace))
23 .collect()
24 }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28 let (result, details) = match &policy.condition {
29 PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30 PolicyCondition::AiPercentageThreshold { .. } => {
31 (
34 EvalResult::Pass,
35 "AI percentage evaluated server-side".into(),
36 )
37 }
38 PolicyCondition::ModelAllowlist { allowed_models } => {
39 eval_model_allowlist(trace, allowed_models)
40 }
41 PolicyCondition::SensitivePathPattern { .. } => {
42 (
44 EvalResult::Pass,
45 "Sensitive path review evaluated server-side".into(),
46 )
47 }
48 PolicyCondition::RequiredToolCall { tool_names } => {
49 eval_required_tool_call(trace, tool_names)
50 }
51 PolicyCondition::TokenBudget {
52 max_tokens,
53 max_cost_usd,
54 } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
55 PolicyCondition::ConditionalToolCall {
56 tool_name,
57 min_count,
58 when_files_match: _,
59 } => eval_conditional_tool_call(trace, tool_name, *min_count),
60 };
61
62 PolicyEvaluation {
63 policy: policy.clone(),
64 result,
65 details,
66 }
67}
68
69fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
70 let has_session = !trace.session.session_id.is_empty();
71 let has_model = trace.model.is_some();
72
73 if has_session && has_model {
74 (EvalResult::Pass, "Trace is complete".into())
75 } else {
76 let mut missing = vec![];
77 if !has_session {
78 missing.push("session");
79 }
80 if !has_model {
81 missing.push("model");
82 }
83 (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
84 }
85}
86
87fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
88 if allowed.is_empty() {
89 return (EvalResult::Pass, "No allowlist configured".into());
90 }
91 match &trace.model {
92 Some(model) if allowed.iter().any(|a| model.contains(a)) => {
93 (EvalResult::Pass, format!("Model {model} is allowed"))
94 }
95 Some(model) => (
96 EvalResult::Fail,
97 format!("Model {model} is not in allowlist: {}", allowed.join(", ")),
98 ),
99 None => (EvalResult::Fail, "No model specified in trace".into()),
100 }
101}
102
103fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
104 let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
105 let missing: Vec<_> = required
106 .iter()
107 .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
108 .collect();
109
110 if missing.is_empty() {
111 (EvalResult::Pass, "All required tools were used".into())
112 } else {
113 (
114 EvalResult::Warn,
115 format!(
116 "Missing required tools: {}",
117 missing
118 .iter()
119 .map(|s| s.as_str())
120 .collect::<Vec<_>>()
121 .join(", ")
122 ),
123 )
124 }
125}
126
127fn eval_token_budget(
128 trace: &TraceRecord,
129 max_tokens: Option<u64>,
130 max_cost: Option<f64>,
131) -> (EvalResult, String) {
132 let tokens = trace.session.token_usage.total_tokens;
133 let cost = trace.session.token_usage.estimated_cost_usd;
134
135 let token_over = max_tokens.is_some_and(|max| tokens > max);
136 let cost_over = max_cost.is_some_and(|max| cost > max);
137
138 if token_over || cost_over {
139 (
140 EvalResult::Warn,
141 format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
142 )
143 } else {
144 (
145 EvalResult::Pass,
146 format!("Within budget: {tokens} tokens (${cost:.2})"),
147 )
148 }
149}
150
151fn eval_conditional_tool_call(
152 trace: &TraceRecord,
153 tool_name: &str,
154 min_count: Option<u32>,
155) -> (EvalResult, String) {
156 let min = min_count.unwrap_or(1) as usize;
157 let count = trace
158 .session
159 .tools_used
160 .iter()
161 .filter(|t| t.name.contains(tool_name))
162 .count();
163
164 if count >= min {
165 (
166 EvalResult::Pass,
167 format!(
168 "Tool '{}' called {} time(s) (required >= {})",
169 tool_name, count, min
170 ),
171 )
172 } else {
173 (
174 EvalResult::Fail,
175 format!(
176 "Tool '{}' called {} time(s) (required >= {})",
177 tool_name, count, min
178 ),
179 )
180 }
181}
182
183fn default_policies() -> Vec<PolicyRule> {
184 vec![
185 PolicyRule {
186 id: Uuid::nil(),
187 org_id: None,
188 name: "Trace completeness".into(),
189 description: "Every AI commit must have complete trace data".into(),
190 condition: PolicyCondition::TraceCompleteness,
191 action: PolicyAction::Warn,
192 severity: PolicySeverity::Medium,
193 enabled: true,
194 },
195 PolicyRule {
196 id: Uuid::nil(),
197 org_id: None,
198 name: "AI percentage threshold".into(),
199 description: "Warn when AI-generated code exceeds 90%".into(),
200 condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
201 action: PolicyAction::Warn,
202 severity: PolicySeverity::Medium,
203 enabled: true,
204 },
205 PolicyRule {
206 id: Uuid::nil(),
207 org_id: None,
208 name: "Model allowlist".into(),
209 description: "Only approved models may generate code".into(),
210 condition: PolicyCondition::ModelAllowlist {
211 allowed_models: vec![
212 "anthropic/claude".into(),
213 "openai/gpt".into(),
214 "google/gemini".into(),
215 ],
216 },
217 action: PolicyAction::BlockMerge,
218 severity: PolicySeverity::High,
219 enabled: true,
220 },
221 PolicyRule {
222 id: Uuid::nil(),
223 org_id: None,
224 name: "Sensitive path review".into(),
225 description: "AI code in sensitive paths requires review".into(),
226 condition: PolicyCondition::SensitivePathPattern {
227 patterns: vec![
228 "payments".into(),
229 "auth".into(),
230 "security".into(),
231 "crypto".into(),
232 ],
233 },
234 action: PolicyAction::RequireReview,
235 severity: PolicySeverity::High,
236 enabled: true,
237 },
238 PolicyRule {
239 id: Uuid::nil(),
240 org_id: None,
241 name: "Required tool call".into(),
242 description: "Trace must show that tests were run".into(),
243 condition: PolicyCondition::RequiredToolCall { tool_names: vec![] },
244 action: PolicyAction::Warn,
245 severity: PolicySeverity::Low,
246 enabled: true,
247 },
248 PolicyRule {
249 id: Uuid::nil(),
250 org_id: None,
251 name: "Token budget".into(),
252 description: "Warn when token usage exceeds budget".into(),
253 condition: PolicyCondition::TokenBudget {
254 max_tokens: Some(500_000),
255 max_cost_usd: Some(50.0),
256 },
257 action: PolicyAction::Warn,
258 severity: PolicySeverity::Medium,
259 enabled: true,
260 },
261 ]
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_usage::TokenUsage;
    use crate::trace::{Session, ToolCall};
    use chrono::Utc;

    /// Builds a minimal trace with the given session id and model, no tool calls
    /// and default token usage.
    fn make_trace(session_id: &str, model: Option<&str>) -> TraceRecord {
        TraceRecord {
            id: Uuid::nil(),
            repo_id: "repo".into(),
            commit_sha: "abc123".into(),
            branch: None,
            author: "dev".into(),
            created_at: Utc::now(),
            model: model.map(String::from),
            tool: "claude-code".into(),
            tool_version: None,
            session: Session {
                session_id: session_id.into(),
                started_at: Utc::now(),
                ended_at: None,
                prompts: vec![],
                responses: vec![],
                token_usage: TokenUsage::default(),
                tools_used: vec![],
            },
            agent_trace: None,
            signature: None,
        }
    }

    /// Builds a complete trace whose session called the named tools, in order.
    fn make_trace_with_tools(tools: Vec<&str>) -> TraceRecord {
        let mut trace = make_trace("sess-1", Some("anthropic/claude-3"));
        trace.session.tools_used = tools
            .into_iter()
            .map(|name| ToolCall {
                name: name.into(),
                input_summary: String::new(),
                timestamp: Utc::now(),
            })
            .collect();
        trace
    }

    /// Builds a complete trace with the given token total and estimated cost.
    fn make_trace_with_tokens(total_tokens: u64, cost: f64) -> TraceRecord {
        let mut trace = make_trace("sess-1", Some("anthropic/claude-3"));
        trace.session.token_usage.total_tokens = total_tokens;
        trace.session.token_usage.estimated_cost_usd = cost;
        trace
    }

    /// Builds a low-severity warning rule with the given name, condition and flag.
    fn make_rule(name: &str, condition: PolicyCondition, enabled: bool) -> PolicyRule {
        PolicyRule {
            id: Uuid::nil(),
            org_id: None,
            name: name.into(),
            description: String::new(),
            condition,
            action: PolicyAction::Warn,
            severity: PolicySeverity::Low,
            enabled,
        }
    }

    // --- trace completeness ---

    #[test]
    fn completeness_pass_when_both_present() {
        let trace = make_trace("sess-1", Some("claude-3"));
        let (result, _) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn completeness_warn_missing_session_id() {
        let trace = make_trace("", Some("claude-3"));
        let (result, details) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("session"));
    }

    #[test]
    fn completeness_warn_missing_model() {
        let trace = make_trace("sess-1", None);
        let (result, details) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("model"));
    }

    #[test]
    fn completeness_warn_both_missing() {
        let trace = make_trace("", None);
        let (result, details) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("session"));
        assert!(details.contains("model"));
    }

    // --- model allowlist ---

    #[test]
    fn model_allowlist_pass_empty_list() {
        let trace = make_trace("sess-1", Some("anything"));
        let (result, _) = eval_model_allowlist(&trace, &[]);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn model_allowlist_pass_match() {
        let trace = make_trace("sess-1", Some("anthropic/claude-3"));
        let allowed = ["anthropic/claude".to_string(), "openai/gpt".to_string()];
        let (result, _) = eval_model_allowlist(&trace, &allowed);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn model_allowlist_fail_no_match() {
        let trace = make_trace("sess-1", Some("unknown/model"));
        let allowed = ["anthropic/claude".to_string(), "openai/gpt".to_string()];
        let (result, _) = eval_model_allowlist(&trace, &allowed);
        assert_eq!(result, EvalResult::Fail);
    }

    #[test]
    fn model_allowlist_fail_missing_model() {
        let trace = make_trace("sess-1", None);
        let allowed = ["anthropic/claude".to_string()];
        let (result, _) = eval_model_allowlist(&trace, &allowed);
        assert_eq!(result, EvalResult::Fail);
    }

    // --- required tool call ---

    #[test]
    fn required_tool_call_pass_all_present() {
        let trace = make_trace_with_tools(vec!["cargo test", "cargo clippy"]);
        let required = ["cargo test".to_string(), "cargo clippy".to_string()];
        let (result, _) = eval_required_tool_call(&trace, &required);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn required_tool_call_warn_missing() {
        let trace = make_trace_with_tools(vec!["cargo test"]);
        let required = ["cargo test".to_string(), "cargo clippy".to_string()];
        let (result, details) = eval_required_tool_call(&trace, &required);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("cargo clippy"));
    }

    // --- token budget ---

    #[test]
    fn token_budget_pass_under_limits() {
        let trace = make_trace_with_tokens(1000, 1.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn token_budget_pass_at_exact_limits() {
        let trace = make_trace_with_tokens(5000, 10.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn token_budget_warn_over_max_tokens() {
        let trace = make_trace_with_tokens(10_000, 1.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Warn);
    }

    #[test]
    fn token_budget_warn_over_max_cost() {
        let trace = make_trace_with_tokens(1000, 20.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Warn);
    }

    #[test]
    fn token_budget_pass_both_none() {
        let trace = make_trace_with_tokens(999_999, 999.0);
        let (result, _) = eval_token_budget(&trace, None, None);
        assert_eq!(result, EvalResult::Pass);
    }

    // --- conditional tool call ---

    #[test]
    fn conditional_tool_call_pass_count_met() {
        let trace = make_trace_with_tools(vec!["cargo test", "cargo test"]);
        let (result, _) = eval_conditional_tool_call(&trace, "cargo test", Some(2));
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn conditional_tool_call_fail_count_not_met() {
        let trace = make_trace_with_tools(vec!["cargo test"]);
        let (result, _) = eval_conditional_tool_call(&trace, "cargo test", Some(3));
        assert_eq!(result, EvalResult::Fail);
    }

    #[test]
    fn conditional_tool_call_fail_absent() {
        let trace = make_trace_with_tools(vec!["cargo clippy"]);
        let (result, _) = eval_conditional_tool_call(&trace, "cargo test", Some(1));
        assert_eq!(result, EvalResult::Fail);
    }

    #[test]
    fn conditional_tool_call_min_count_defaults_to_one() {
        let called_once = make_trace_with_tools(vec!["cargo test"]);
        let (result, _) = eval_conditional_tool_call(&called_once, "cargo test", None);
        assert_eq!(result, EvalResult::Pass);

        let never_called = make_trace_with_tools(vec![]);
        let (result, _) = eval_conditional_tool_call(&never_called, "cargo test", None);
        assert_eq!(result, EvalResult::Fail);
    }

    // --- engine ---

    #[test]
    fn evaluate_skips_disabled_policies() {
        let engine = PolicyEngine::new(vec![
            make_rule("enabled", PolicyCondition::TraceCompleteness, true),
            make_rule("disabled", PolicyCondition::TraceCompleteness, false),
        ]);
        let trace = make_trace("sess-1", Some("claude-3"));
        let results = engine.evaluate(&trace);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].policy.name, "enabled");
    }

    #[test]
    fn evaluate_passes_server_side_conditions() {
        let engine = PolicyEngine::new(vec![
            make_rule(
                "ai",
                PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
                true,
            ),
            make_rule(
                "paths",
                PolicyCondition::SensitivePathPattern {
                    patterns: vec!["auth".into()],
                },
                true,
            ),
        ]);
        let results = engine.evaluate(&make_trace("sess-1", Some("anthropic/claude-3")));
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.result == EvalResult::Pass));
    }

    #[test]
    fn with_defaults_returns_six_enabled() {
        let engine = PolicyEngine::with_defaults();
        let trace = make_trace("sess-1", Some("anthropic/claude-3"));
        let results = engine.evaluate(&trace);
        assert_eq!(results.len(), 6);
        assert!(results.iter().all(|r| r.policy.enabled));
    }

    #[test]
    fn with_defaults_blocks_unlisted_model() {
        let engine = PolicyEngine::with_defaults();
        let results = engine.evaluate(&make_trace("sess-1", Some("unknown/model")));
        let allowlist = results
            .iter()
            .find(|r| r.policy.name == "Model allowlist")
            .expect("default rule set contains the model allowlist");
        assert_eq!(allowlist.result, EvalResult::Fail);
    }
}
488}