1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6 policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10 pub fn new(policies: Vec<PolicyRule>) -> Self {
11 Self { policies }
12 }
13
14 pub fn with_defaults() -> Self {
15 Self::new(default_policies())
16 }
17
18 pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19 self.policies
20 .iter()
21 .filter(|p| p.enabled)
22 .map(|p| evaluate_policy(p, trace))
23 .collect()
24 }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28 let (result, details) = match &policy.condition {
29 PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30 PolicyCondition::AiPercentageThreshold { threshold } => {
31 eval_ai_percentage(trace, *threshold)
32 }
33 PolicyCondition::ModelAllowlist { allowed_models } => {
34 eval_model_allowlist(trace, allowed_models)
35 }
36 PolicyCondition::SensitivePathPattern { patterns } => {
37 eval_sensitive_paths(trace, patterns)
38 }
39 PolicyCondition::RequiredToolCall { tool_names } => {
40 eval_required_tool_call(trace, tool_names)
41 }
42 PolicyCondition::TokenBudget {
43 max_tokens,
44 max_cost_usd,
45 } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
46 PolicyCondition::ConditionalToolCall {
47 tool_name,
48 min_count,
49 when_files_match: _,
50 } => eval_conditional_tool_call(trace, tool_name, *min_count),
51 };
52
53 PolicyEvaluation {
54 policy: policy.clone(),
55 result,
56 details,
57 }
58}
59
60fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
61 let has_session = !trace.session.session_id.is_empty();
62 let has_model = trace.model.is_some();
63 let has_attribution = !trace.attribution.files.is_empty()
64 || trace.attribution.summary.total_lines_added == 0;
65
66 if has_session && has_model && has_attribution {
67 (EvalResult::Pass, "Trace is complete".into())
68 } else {
69 let mut missing = vec![];
70 if !has_session {
71 missing.push("session");
72 }
73 if !has_model {
74 missing.push("model");
75 }
76 if !has_attribution {
77 missing.push("attribution");
78 }
79 (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
80 }
81}
82
83fn eval_ai_percentage(trace: &TraceRecord, threshold: f32) -> (EvalResult, String) {
84 let pct = trace.attribution.summary.ai_percentage;
85 if pct > threshold {
86 (
87 EvalResult::Warn,
88 format!("AI percentage {pct:.1}% exceeds threshold {threshold:.1}%"),
89 )
90 } else {
91 (
92 EvalResult::Pass,
93 format!("AI percentage {pct:.1}% within threshold"),
94 )
95 }
96}
97
98fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
99 if allowed.is_empty() {
100 return (EvalResult::Pass, "No allowlist configured".into());
101 }
102 match &trace.model {
103 Some(model) if allowed.iter().any(|a| model.contains(a)) => {
104 (EvalResult::Pass, format!("Model {model} is allowed"))
105 }
106 Some(model) => (
107 EvalResult::Fail,
108 format!(
109 "Model {model} is not in allowlist: {}",
110 allowed.join(", ")
111 ),
112 ),
113 None => (EvalResult::Fail, "No model specified in trace".into()),
114 }
115}
116
117fn eval_sensitive_paths(trace: &TraceRecord, patterns: &[String]) -> (EvalResult, String) {
118 let matched: Vec<_> = trace
119 .attribution
120 .files
121 .iter()
122 .filter(|f| patterns.iter().any(|p| f.path.contains(p)) && !f.ai_lines.is_empty())
123 .map(|f| f.path.clone())
124 .collect();
125
126 if matched.is_empty() {
127 (EvalResult::Pass, "No sensitive paths with AI code".into())
128 } else {
129 (
130 EvalResult::Warn,
131 format!("AI code in sensitive paths: {}", matched.join(", ")),
132 )
133 }
134}
135
136fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
137 let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
138 let missing: Vec<_> = required
139 .iter()
140 .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
141 .collect();
142
143 if missing.is_empty() {
144 (EvalResult::Pass, "All required tools were used".into())
145 } else {
146 (
147 EvalResult::Warn,
148 format!(
149 "Missing required tools: {}",
150 missing
151 .iter()
152 .map(|s| s.as_str())
153 .collect::<Vec<_>>()
154 .join(", ")
155 ),
156 )
157 }
158}
159
160fn eval_token_budget(
161 trace: &TraceRecord,
162 max_tokens: Option<u64>,
163 max_cost: Option<f64>,
164) -> (EvalResult, String) {
165 let tokens = trace.session.token_usage.total_tokens;
166 let cost = trace.session.token_usage.estimated_cost_usd;
167
168 let token_over = max_tokens.is_some_and(|max| tokens > max);
169 let cost_over = max_cost.is_some_and(|max| cost > max);
170
171 if token_over || cost_over {
172 (
173 EvalResult::Warn,
174 format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
175 )
176 } else {
177 (
178 EvalResult::Pass,
179 format!("Within budget: {tokens} tokens (${cost:.2})"),
180 )
181 }
182}
183
184fn eval_conditional_tool_call(
185 trace: &TraceRecord,
186 tool_name: &str,
187 min_count: Option<u32>,
188) -> (EvalResult, String) {
189 let min = min_count.unwrap_or(1) as usize;
190 let count = trace
191 .session
192 .tools_used
193 .iter()
194 .filter(|t| t.name.contains(tool_name))
195 .count();
196
197 if count >= min {
198 (
199 EvalResult::Pass,
200 format!("Tool '{}' called {} time(s) (required >= {})", tool_name, count, min),
201 )
202 } else {
203 (
204 EvalResult::Fail,
205 format!("Tool '{}' called {} time(s) (required >= {})", tool_name, count, min),
206 )
207 }
208}
209
210fn default_policies() -> Vec<PolicyRule> {
211 vec![
212 PolicyRule {
213 id: Uuid::nil(),
214 org_id: None,
215 name: "Trace completeness".into(),
216 description: "Every AI commit must have complete trace data".into(),
217 condition: PolicyCondition::TraceCompleteness,
218 action: PolicyAction::Warn,
219 severity: PolicySeverity::Medium,
220 enabled: true,
221 },
222 PolicyRule {
223 id: Uuid::nil(),
224 org_id: None,
225 name: "AI percentage threshold".into(),
226 description: "Warn when AI-generated code exceeds 90%".into(),
227 condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
228 action: PolicyAction::Warn,
229 severity: PolicySeverity::Medium,
230 enabled: true,
231 },
232 PolicyRule {
233 id: Uuid::nil(),
234 org_id: None,
235 name: "Model allowlist".into(),
236 description: "Only approved models may generate code".into(),
237 condition: PolicyCondition::ModelAllowlist {
238 allowed_models: vec![
239 "anthropic/claude".into(),
240 "openai/gpt".into(),
241 "google/gemini".into(),
242 ],
243 },
244 action: PolicyAction::BlockMerge,
245 severity: PolicySeverity::High,
246 enabled: true,
247 },
248 PolicyRule {
249 id: Uuid::nil(),
250 org_id: None,
251 name: "Sensitive path review".into(),
252 description: "AI code in sensitive paths requires review".into(),
253 condition: PolicyCondition::SensitivePathPattern {
254 patterns: vec![
255 "payments".into(),
256 "auth".into(),
257 "security".into(),
258 "crypto".into(),
259 ],
260 },
261 action: PolicyAction::RequireReview,
262 severity: PolicySeverity::High,
263 enabled: true,
264 },
265 PolicyRule {
266 id: Uuid::nil(),
267 org_id: None,
268 name: "Required tool call".into(),
269 description: "Trace must show that tests were run".into(),
270 condition: PolicyCondition::RequiredToolCall {
271 tool_names: vec![],
272 },
273 action: PolicyAction::Warn,
274 severity: PolicySeverity::Low,
275 enabled: true,
276 },
277 PolicyRule {
278 id: Uuid::nil(),
279 org_id: None,
280 name: "Token budget".into(),
281 description: "Warn when token usage exceeds budget".into(),
282 condition: PolicyCondition::TokenBudget {
283 max_tokens: Some(500_000),
284 max_cost_usd: Some(50.0),
285 },
286 action: PolicyAction::Warn,
287 severity: PolicySeverity::Medium,
288 enabled: true,
289 },
290 ]
291}