1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6 policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10 pub fn new(policies: Vec<PolicyRule>) -> Self {
11 Self { policies }
12 }
13
14 pub fn with_defaults() -> Self {
15 Self::new(default_policies())
16 }
17
18 pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19 self.policies
20 .iter()
21 .filter(|p| p.enabled)
22 .map(|p| evaluate_policy(p, trace))
23 .collect()
24 }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28 let (result, details) = match &policy.condition {
29 PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30 PolicyCondition::AiPercentageThreshold { .. } => {
31 (
34 EvalResult::Pass,
35 "AI percentage evaluated server-side".into(),
36 )
37 }
38 PolicyCondition::ModelAllowlist { allowed_models } => {
39 eval_model_allowlist(trace, allowed_models)
40 }
41 PolicyCondition::SensitivePathPattern { .. } => {
42 (
44 EvalResult::Pass,
45 "Sensitive path review evaluated server-side".into(),
46 )
47 }
48 PolicyCondition::RequiredToolCall { tool_names } => {
49 eval_required_tool_call(trace, tool_names)
50 }
51 PolicyCondition::TokenBudget {
52 max_tokens,
53 max_cost_usd,
54 } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
55 PolicyCondition::ConditionalToolCall {
56 tool_name,
57 min_count,
58 when_files_match: _,
59 } => eval_conditional_tool_call(trace, tool_name, *min_count),
60 };
61
62 PolicyEvaluation {
63 policy: policy.clone(),
64 result,
65 details,
66 }
67}
68
69fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
70 let has_session = !trace.session.session_id.is_empty();
71 let has_model = trace.model.is_some();
72
73 if has_session && has_model {
74 (EvalResult::Pass, "Trace is complete".into())
75 } else {
76 let mut missing = vec![];
77 if !has_session {
78 missing.push("session");
79 }
80 if !has_model {
81 missing.push("model");
82 }
83 (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
84 }
85}
86
87fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
88 if allowed.is_empty() {
89 return (EvalResult::Pass, "No allowlist configured".into());
90 }
91 match &trace.model {
92 Some(model) if allowed.iter().any(|a| model.contains(a)) => {
93 (EvalResult::Pass, format!("Model {model} is allowed"))
94 }
95 Some(model) => (
96 EvalResult::Fail,
97 format!("Model {model} is not in allowlist: {}", allowed.join(", ")),
98 ),
99 None => (EvalResult::Fail, "No model specified in trace".into()),
100 }
101}
102
103fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
104 let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
105 let missing: Vec<_> = required
106 .iter()
107 .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
108 .collect();
109
110 if missing.is_empty() {
111 (EvalResult::Pass, "All required tools were used".into())
112 } else {
113 (
114 EvalResult::Warn,
115 format!(
116 "Missing required tools: {}",
117 missing
118 .iter()
119 .map(|s| s.as_str())
120 .collect::<Vec<_>>()
121 .join(", ")
122 ),
123 )
124 }
125}
126
127fn eval_token_budget(
128 trace: &TraceRecord,
129 max_tokens: Option<u64>,
130 max_cost: Option<f64>,
131) -> (EvalResult, String) {
132 let tokens = trace.session.token_usage.total_tokens;
133 let cost = trace.session.token_usage.estimated_cost_usd;
134
135 let token_over = max_tokens.is_some_and(|max| tokens > max);
136 let cost_over = max_cost.is_some_and(|max| cost > max);
137
138 if token_over || cost_over {
139 (
140 EvalResult::Warn,
141 format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
142 )
143 } else {
144 (
145 EvalResult::Pass,
146 format!("Within budget: {tokens} tokens (${cost:.2})"),
147 )
148 }
149}
150
151fn eval_conditional_tool_call(
152 trace: &TraceRecord,
153 tool_name: &str,
154 min_count: Option<u32>,
155) -> (EvalResult, String) {
156 let min = min_count.unwrap_or(1) as usize;
157 let count = trace
158 .session
159 .tools_used
160 .iter()
161 .filter(|t| t.name.contains(tool_name))
162 .count();
163
164 if count >= min {
165 (
166 EvalResult::Pass,
167 format!(
168 "Tool '{}' called {} time(s) (required >= {})",
169 tool_name, count, min
170 ),
171 )
172 } else {
173 (
174 EvalResult::Fail,
175 format!(
176 "Tool '{}' called {} time(s) (required >= {})",
177 tool_name, count, min
178 ),
179 )
180 }
181}
182
183fn default_policies() -> Vec<PolicyRule> {
184 vec![
185 PolicyRule {
186 id: Uuid::nil(),
187 org_id: None,
188 name: "Trace completeness".into(),
189 description: "Every AI commit must have complete trace data".into(),
190 condition: PolicyCondition::TraceCompleteness,
191 action: PolicyAction::Warn,
192 severity: PolicySeverity::Medium,
193 enabled: true,
194 },
195 PolicyRule {
196 id: Uuid::nil(),
197 org_id: None,
198 name: "AI percentage threshold".into(),
199 description: "Warn when AI-generated code exceeds 90%".into(),
200 condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
201 action: PolicyAction::Warn,
202 severity: PolicySeverity::Medium,
203 enabled: true,
204 },
205 PolicyRule {
206 id: Uuid::nil(),
207 org_id: None,
208 name: "Model allowlist".into(),
209 description: "Only approved models may generate code".into(),
210 condition: PolicyCondition::ModelAllowlist {
211 allowed_models: vec![
212 "anthropic/claude".into(),
213 "openai/gpt".into(),
214 "google/gemini".into(),
215 ],
216 },
217 action: PolicyAction::BlockMerge,
218 severity: PolicySeverity::High,
219 enabled: true,
220 },
221 PolicyRule {
222 id: Uuid::nil(),
223 org_id: None,
224 name: "Sensitive path review".into(),
225 description: "AI code in sensitive paths requires review".into(),
226 condition: PolicyCondition::SensitivePathPattern {
227 patterns: vec![
228 "payments".into(),
229 "auth".into(),
230 "security".into(),
231 "crypto".into(),
232 ],
233 },
234 action: PolicyAction::RequireReview,
235 severity: PolicySeverity::High,
236 enabled: true,
237 },
238 PolicyRule {
239 id: Uuid::nil(),
240 org_id: None,
241 name: "Required tool call".into(),
242 description: "Trace must show that tests were run".into(),
243 condition: PolicyCondition::RequiredToolCall { tool_names: vec![] },
244 action: PolicyAction::Warn,
245 severity: PolicySeverity::Low,
246 enabled: true,
247 },
248 PolicyRule {
249 id: Uuid::nil(),
250 org_id: None,
251 name: "Token budget".into(),
252 description: "Warn when token usage exceeds budget".into(),
253 condition: PolicyCondition::TokenBudget {
254 max_tokens: Some(500_000),
255 max_cost_usd: Some(50.0),
256 },
257 action: PolicyAction::Warn,
258 severity: PolicySeverity::Medium,
259 enabled: true,
260 },
261 ]
262}