Skip to main content

tracevault_core/
policy_engine.rs

1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6    policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10    pub fn new(policies: Vec<PolicyRule>) -> Self {
11        Self { policies }
12    }
13
14    pub fn with_defaults() -> Self {
15        Self::new(default_policies())
16    }
17
18    pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19        self.policies
20            .iter()
21            .filter(|p| p.enabled)
22            .map(|p| evaluate_policy(p, trace))
23            .collect()
24    }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28    let (result, details) = match &policy.condition {
29        PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30        PolicyCondition::AiPercentageThreshold { threshold } => {
31            eval_ai_percentage(trace, *threshold)
32        }
33        PolicyCondition::ModelAllowlist { allowed_models } => {
34            eval_model_allowlist(trace, allowed_models)
35        }
36        PolicyCondition::SensitivePathPattern { patterns } => eval_sensitive_paths(trace, patterns),
37        PolicyCondition::RequiredToolCall { tool_names } => {
38            eval_required_tool_call(trace, tool_names)
39        }
40        PolicyCondition::TokenBudget {
41            max_tokens,
42            max_cost_usd,
43        } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
44        PolicyCondition::ConditionalToolCall {
45            tool_name,
46            min_count,
47            when_files_match: _,
48        } => eval_conditional_tool_call(trace, tool_name, *min_count),
49    };
50
51    PolicyEvaluation {
52        policy: policy.clone(),
53        result,
54        details,
55    }
56}
57
58fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
59    let has_session = !trace.session.session_id.is_empty();
60    let has_model = trace.model.is_some();
61    let has_attribution =
62        !trace.attribution.files.is_empty() || trace.attribution.summary.total_lines_added == 0;
63
64    if has_session && has_model && has_attribution {
65        (EvalResult::Pass, "Trace is complete".into())
66    } else {
67        let mut missing = vec![];
68        if !has_session {
69            missing.push("session");
70        }
71        if !has_model {
72            missing.push("model");
73        }
74        if !has_attribution {
75            missing.push("attribution");
76        }
77        (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
78    }
79}
80
81fn eval_ai_percentage(trace: &TraceRecord, threshold: f32) -> (EvalResult, String) {
82    let pct = trace.attribution.summary.ai_percentage;
83    if pct > threshold {
84        (
85            EvalResult::Warn,
86            format!("AI percentage {pct:.1}% exceeds threshold {threshold:.1}%"),
87        )
88    } else {
89        (
90            EvalResult::Pass,
91            format!("AI percentage {pct:.1}% within threshold"),
92        )
93    }
94}
95
96fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
97    if allowed.is_empty() {
98        return (EvalResult::Pass, "No allowlist configured".into());
99    }
100    match &trace.model {
101        Some(model) if allowed.iter().any(|a| model.contains(a)) => {
102            (EvalResult::Pass, format!("Model {model} is allowed"))
103        }
104        Some(model) => (
105            EvalResult::Fail,
106            format!("Model {model} is not in allowlist: {}", allowed.join(", ")),
107        ),
108        None => (EvalResult::Fail, "No model specified in trace".into()),
109    }
110}
111
112fn eval_sensitive_paths(trace: &TraceRecord, patterns: &[String]) -> (EvalResult, String) {
113    let matched: Vec<_> = trace
114        .attribution
115        .files
116        .iter()
117        .filter(|f| patterns.iter().any(|p| f.path.contains(p)) && !f.ai_lines.is_empty())
118        .map(|f| f.path.clone())
119        .collect();
120
121    if matched.is_empty() {
122        (EvalResult::Pass, "No sensitive paths with AI code".into())
123    } else {
124        (
125            EvalResult::Warn,
126            format!("AI code in sensitive paths: {}", matched.join(", ")),
127        )
128    }
129}
130
131fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
132    let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
133    let missing: Vec<_> = required
134        .iter()
135        .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
136        .collect();
137
138    if missing.is_empty() {
139        (EvalResult::Pass, "All required tools were used".into())
140    } else {
141        (
142            EvalResult::Warn,
143            format!(
144                "Missing required tools: {}",
145                missing
146                    .iter()
147                    .map(|s| s.as_str())
148                    .collect::<Vec<_>>()
149                    .join(", ")
150            ),
151        )
152    }
153}
154
155fn eval_token_budget(
156    trace: &TraceRecord,
157    max_tokens: Option<u64>,
158    max_cost: Option<f64>,
159) -> (EvalResult, String) {
160    let tokens = trace.session.token_usage.total_tokens;
161    let cost = trace.session.token_usage.estimated_cost_usd;
162
163    let token_over = max_tokens.is_some_and(|max| tokens > max);
164    let cost_over = max_cost.is_some_and(|max| cost > max);
165
166    if token_over || cost_over {
167        (
168            EvalResult::Warn,
169            format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
170        )
171    } else {
172        (
173            EvalResult::Pass,
174            format!("Within budget: {tokens} tokens (${cost:.2})"),
175        )
176    }
177}
178
179fn eval_conditional_tool_call(
180    trace: &TraceRecord,
181    tool_name: &str,
182    min_count: Option<u32>,
183) -> (EvalResult, String) {
184    let min = min_count.unwrap_or(1) as usize;
185    let count = trace
186        .session
187        .tools_used
188        .iter()
189        .filter(|t| t.name.contains(tool_name))
190        .count();
191
192    if count >= min {
193        (
194            EvalResult::Pass,
195            format!(
196                "Tool '{}' called {} time(s) (required >= {})",
197                tool_name, count, min
198            ),
199        )
200    } else {
201        (
202            EvalResult::Fail,
203            format!(
204                "Tool '{}' called {} time(s) (required >= {})",
205                tool_name, count, min
206            ),
207        )
208    }
209}
210
211fn default_policies() -> Vec<PolicyRule> {
212    vec![
213        PolicyRule {
214            id: Uuid::nil(),
215            org_id: None,
216            name: "Trace completeness".into(),
217            description: "Every AI commit must have complete trace data".into(),
218            condition: PolicyCondition::TraceCompleteness,
219            action: PolicyAction::Warn,
220            severity: PolicySeverity::Medium,
221            enabled: true,
222        },
223        PolicyRule {
224            id: Uuid::nil(),
225            org_id: None,
226            name: "AI percentage threshold".into(),
227            description: "Warn when AI-generated code exceeds 90%".into(),
228            condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
229            action: PolicyAction::Warn,
230            severity: PolicySeverity::Medium,
231            enabled: true,
232        },
233        PolicyRule {
234            id: Uuid::nil(),
235            org_id: None,
236            name: "Model allowlist".into(),
237            description: "Only approved models may generate code".into(),
238            condition: PolicyCondition::ModelAllowlist {
239                allowed_models: vec![
240                    "anthropic/claude".into(),
241                    "openai/gpt".into(),
242                    "google/gemini".into(),
243                ],
244            },
245            action: PolicyAction::BlockMerge,
246            severity: PolicySeverity::High,
247            enabled: true,
248        },
249        PolicyRule {
250            id: Uuid::nil(),
251            org_id: None,
252            name: "Sensitive path review".into(),
253            description: "AI code in sensitive paths requires review".into(),
254            condition: PolicyCondition::SensitivePathPattern {
255                patterns: vec![
256                    "payments".into(),
257                    "auth".into(),
258                    "security".into(),
259                    "crypto".into(),
260                ],
261            },
262            action: PolicyAction::RequireReview,
263            severity: PolicySeverity::High,
264            enabled: true,
265        },
266        PolicyRule {
267            id: Uuid::nil(),
268            org_id: None,
269            name: "Required tool call".into(),
270            description: "Trace must show that tests were run".into(),
271            condition: PolicyCondition::RequiredToolCall { tool_names: vec![] },
272            action: PolicyAction::Warn,
273            severity: PolicySeverity::Low,
274            enabled: true,
275        },
276        PolicyRule {
277            id: Uuid::nil(),
278            org_id: None,
279            name: "Token budget".into(),
280            description: "Warn when token usage exceeds budget".into(),
281            condition: PolicyCondition::TokenBudget {
282                max_tokens: Some(500_000),
283                max_cost_usd: Some(50.0),
284            },
285            action: PolicyAction::Warn,
286            severity: PolicySeverity::Medium,
287            enabled: true,
288        },
289    ]
290}