1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6 policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10 pub fn new(policies: Vec<PolicyRule>) -> Self {
11 Self { policies }
12 }
13
14 pub fn with_defaults() -> Self {
15 Self::new(default_policies())
16 }
17
18 pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19 self.policies
20 .iter()
21 .filter(|p| p.enabled)
22 .map(|p| evaluate_policy(p, trace))
23 .collect()
24 }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28 let (result, details) = match &policy.condition {
29 PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30 PolicyCondition::AiPercentageThreshold { threshold } => {
31 eval_ai_percentage(trace, *threshold)
32 }
33 PolicyCondition::ModelAllowlist { allowed_models } => {
34 eval_model_allowlist(trace, allowed_models)
35 }
36 PolicyCondition::SensitivePathPattern { patterns } => eval_sensitive_paths(trace, patterns),
37 PolicyCondition::RequiredToolCall { tool_names } => {
38 eval_required_tool_call(trace, tool_names)
39 }
40 PolicyCondition::TokenBudget {
41 max_tokens,
42 max_cost_usd,
43 } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
44 PolicyCondition::ConditionalToolCall {
45 tool_name,
46 min_count,
47 when_files_match: _,
48 } => eval_conditional_tool_call(trace, tool_name, *min_count),
49 };
50
51 PolicyEvaluation {
52 policy: policy.clone(),
53 result,
54 details,
55 }
56}
57
58fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
59 let has_session = !trace.session.session_id.is_empty();
60 let has_model = trace.model.is_some();
61 let has_attribution =
62 !trace.attribution.files.is_empty() || trace.attribution.summary.total_lines_added == 0;
63
64 if has_session && has_model && has_attribution {
65 (EvalResult::Pass, "Trace is complete".into())
66 } else {
67 let mut missing = vec![];
68 if !has_session {
69 missing.push("session");
70 }
71 if !has_model {
72 missing.push("model");
73 }
74 if !has_attribution {
75 missing.push("attribution");
76 }
77 (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
78 }
79}
80
81fn eval_ai_percentage(trace: &TraceRecord, threshold: f32) -> (EvalResult, String) {
82 let pct = trace.attribution.summary.ai_percentage;
83 if pct > threshold {
84 (
85 EvalResult::Warn,
86 format!("AI percentage {pct:.1}% exceeds threshold {threshold:.1}%"),
87 )
88 } else {
89 (
90 EvalResult::Pass,
91 format!("AI percentage {pct:.1}% within threshold"),
92 )
93 }
94}
95
96fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
97 if allowed.is_empty() {
98 return (EvalResult::Pass, "No allowlist configured".into());
99 }
100 match &trace.model {
101 Some(model) if allowed.iter().any(|a| model.contains(a)) => {
102 (EvalResult::Pass, format!("Model {model} is allowed"))
103 }
104 Some(model) => (
105 EvalResult::Fail,
106 format!("Model {model} is not in allowlist: {}", allowed.join(", ")),
107 ),
108 None => (EvalResult::Fail, "No model specified in trace".into()),
109 }
110}
111
112fn eval_sensitive_paths(trace: &TraceRecord, patterns: &[String]) -> (EvalResult, String) {
113 let matched: Vec<_> = trace
114 .attribution
115 .files
116 .iter()
117 .filter(|f| patterns.iter().any(|p| f.path.contains(p)) && !f.ai_lines.is_empty())
118 .map(|f| f.path.clone())
119 .collect();
120
121 if matched.is_empty() {
122 (EvalResult::Pass, "No sensitive paths with AI code".into())
123 } else {
124 (
125 EvalResult::Warn,
126 format!("AI code in sensitive paths: {}", matched.join(", ")),
127 )
128 }
129}
130
131fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
132 let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
133 let missing: Vec<_> = required
134 .iter()
135 .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
136 .collect();
137
138 if missing.is_empty() {
139 (EvalResult::Pass, "All required tools were used".into())
140 } else {
141 (
142 EvalResult::Warn,
143 format!(
144 "Missing required tools: {}",
145 missing
146 .iter()
147 .map(|s| s.as_str())
148 .collect::<Vec<_>>()
149 .join(", ")
150 ),
151 )
152 }
153}
154
155fn eval_token_budget(
156 trace: &TraceRecord,
157 max_tokens: Option<u64>,
158 max_cost: Option<f64>,
159) -> (EvalResult, String) {
160 let tokens = trace.session.token_usage.total_tokens;
161 let cost = trace.session.token_usage.estimated_cost_usd;
162
163 let token_over = max_tokens.is_some_and(|max| tokens > max);
164 let cost_over = max_cost.is_some_and(|max| cost > max);
165
166 if token_over || cost_over {
167 (
168 EvalResult::Warn,
169 format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
170 )
171 } else {
172 (
173 EvalResult::Pass,
174 format!("Within budget: {tokens} tokens (${cost:.2})"),
175 )
176 }
177}
178
179fn eval_conditional_tool_call(
180 trace: &TraceRecord,
181 tool_name: &str,
182 min_count: Option<u32>,
183) -> (EvalResult, String) {
184 let min = min_count.unwrap_or(1) as usize;
185 let count = trace
186 .session
187 .tools_used
188 .iter()
189 .filter(|t| t.name.contains(tool_name))
190 .count();
191
192 if count >= min {
193 (
194 EvalResult::Pass,
195 format!(
196 "Tool '{}' called {} time(s) (required >= {})",
197 tool_name, count, min
198 ),
199 )
200 } else {
201 (
202 EvalResult::Fail,
203 format!(
204 "Tool '{}' called {} time(s) (required >= {})",
205 tool_name, count, min
206 ),
207 )
208 }
209}
210
211fn default_policies() -> Vec<PolicyRule> {
212 vec![
213 PolicyRule {
214 id: Uuid::nil(),
215 org_id: None,
216 name: "Trace completeness".into(),
217 description: "Every AI commit must have complete trace data".into(),
218 condition: PolicyCondition::TraceCompleteness,
219 action: PolicyAction::Warn,
220 severity: PolicySeverity::Medium,
221 enabled: true,
222 },
223 PolicyRule {
224 id: Uuid::nil(),
225 org_id: None,
226 name: "AI percentage threshold".into(),
227 description: "Warn when AI-generated code exceeds 90%".into(),
228 condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
229 action: PolicyAction::Warn,
230 severity: PolicySeverity::Medium,
231 enabled: true,
232 },
233 PolicyRule {
234 id: Uuid::nil(),
235 org_id: None,
236 name: "Model allowlist".into(),
237 description: "Only approved models may generate code".into(),
238 condition: PolicyCondition::ModelAllowlist {
239 allowed_models: vec![
240 "anthropic/claude".into(),
241 "openai/gpt".into(),
242 "google/gemini".into(),
243 ],
244 },
245 action: PolicyAction::BlockMerge,
246 severity: PolicySeverity::High,
247 enabled: true,
248 },
249 PolicyRule {
250 id: Uuid::nil(),
251 org_id: None,
252 name: "Sensitive path review".into(),
253 description: "AI code in sensitive paths requires review".into(),
254 condition: PolicyCondition::SensitivePathPattern {
255 patterns: vec![
256 "payments".into(),
257 "auth".into(),
258 "security".into(),
259 "crypto".into(),
260 ],
261 },
262 action: PolicyAction::RequireReview,
263 severity: PolicySeverity::High,
264 enabled: true,
265 },
266 PolicyRule {
267 id: Uuid::nil(),
268 org_id: None,
269 name: "Required tool call".into(),
270 description: "Trace must show that tests were run".into(),
271 condition: PolicyCondition::RequiredToolCall { tool_names: vec![] },
272 action: PolicyAction::Warn,
273 severity: PolicySeverity::Low,
274 enabled: true,
275 },
276 PolicyRule {
277 id: Uuid::nil(),
278 org_id: None,
279 name: "Token budget".into(),
280 description: "Warn when token usage exceeds budget".into(),
281 condition: PolicyCondition::TokenBudget {
282 max_tokens: Some(500_000),
283 max_cost_usd: Some(50.0),
284 },
285 action: PolicyAction::Warn,
286 severity: PolicySeverity::Medium,
287 enabled: true,
288 },
289 ]
290}