Skip to main content

tracevault_core/
policy_engine.rs

1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6    policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10    pub fn new(policies: Vec<PolicyRule>) -> Self {
11        Self { policies }
12    }
13
14    pub fn with_defaults() -> Self {
15        Self::new(default_policies())
16    }
17
18    pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19        self.policies
20            .iter()
21            .filter(|p| p.enabled)
22            .map(|p| evaluate_policy(p, trace))
23            .collect()
24    }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28    let (result, details) = match &policy.condition {
29        PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30        PolicyCondition::AiPercentageThreshold { threshold } => {
31            eval_ai_percentage(trace, *threshold)
32        }
33        PolicyCondition::ModelAllowlist { allowed_models } => {
34            eval_model_allowlist(trace, allowed_models)
35        }
36        PolicyCondition::SensitivePathPattern { patterns } => {
37            eval_sensitive_paths(trace, patterns)
38        }
39        PolicyCondition::RequiredToolCall { tool_names } => {
40            eval_required_tool_call(trace, tool_names)
41        }
42        PolicyCondition::TokenBudget {
43            max_tokens,
44            max_cost_usd,
45        } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
46        PolicyCondition::ConditionalToolCall {
47            tool_name,
48            min_count,
49            when_files_match: _,
50        } => eval_conditional_tool_call(trace, tool_name, *min_count),
51    };
52
53    PolicyEvaluation {
54        policy: policy.clone(),
55        result,
56        details,
57    }
58}
59
60fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
61    let has_session = !trace.session.session_id.is_empty();
62    let has_model = trace.model.is_some();
63    let has_attribution = !trace.attribution.files.is_empty()
64        || trace.attribution.summary.total_lines_added == 0;
65
66    if has_session && has_model && has_attribution {
67        (EvalResult::Pass, "Trace is complete".into())
68    } else {
69        let mut missing = vec![];
70        if !has_session {
71            missing.push("session");
72        }
73        if !has_model {
74            missing.push("model");
75        }
76        if !has_attribution {
77            missing.push("attribution");
78        }
79        (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
80    }
81}
82
83fn eval_ai_percentage(trace: &TraceRecord, threshold: f32) -> (EvalResult, String) {
84    let pct = trace.attribution.summary.ai_percentage;
85    if pct > threshold {
86        (
87            EvalResult::Warn,
88            format!("AI percentage {pct:.1}% exceeds threshold {threshold:.1}%"),
89        )
90    } else {
91        (
92            EvalResult::Pass,
93            format!("AI percentage {pct:.1}% within threshold"),
94        )
95    }
96}
97
98fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
99    if allowed.is_empty() {
100        return (EvalResult::Pass, "No allowlist configured".into());
101    }
102    match &trace.model {
103        Some(model) if allowed.iter().any(|a| model.contains(a)) => {
104            (EvalResult::Pass, format!("Model {model} is allowed"))
105        }
106        Some(model) => (
107            EvalResult::Fail,
108            format!(
109                "Model {model} is not in allowlist: {}",
110                allowed.join(", ")
111            ),
112        ),
113        None => (EvalResult::Fail, "No model specified in trace".into()),
114    }
115}
116
117fn eval_sensitive_paths(trace: &TraceRecord, patterns: &[String]) -> (EvalResult, String) {
118    let matched: Vec<_> = trace
119        .attribution
120        .files
121        .iter()
122        .filter(|f| patterns.iter().any(|p| f.path.contains(p)) && !f.ai_lines.is_empty())
123        .map(|f| f.path.clone())
124        .collect();
125
126    if matched.is_empty() {
127        (EvalResult::Pass, "No sensitive paths with AI code".into())
128    } else {
129        (
130            EvalResult::Warn,
131            format!("AI code in sensitive paths: {}", matched.join(", ")),
132        )
133    }
134}
135
136fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
137    let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
138    let missing: Vec<_> = required
139        .iter()
140        .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
141        .collect();
142
143    if missing.is_empty() {
144        (EvalResult::Pass, "All required tools were used".into())
145    } else {
146        (
147            EvalResult::Warn,
148            format!(
149                "Missing required tools: {}",
150                missing
151                    .iter()
152                    .map(|s| s.as_str())
153                    .collect::<Vec<_>>()
154                    .join(", ")
155            ),
156        )
157    }
158}
159
160fn eval_token_budget(
161    trace: &TraceRecord,
162    max_tokens: Option<u64>,
163    max_cost: Option<f64>,
164) -> (EvalResult, String) {
165    let tokens = trace.session.token_usage.total_tokens;
166    let cost = trace.session.token_usage.estimated_cost_usd;
167
168    let token_over = max_tokens.is_some_and(|max| tokens > max);
169    let cost_over = max_cost.is_some_and(|max| cost > max);
170
171    if token_over || cost_over {
172        (
173            EvalResult::Warn,
174            format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
175        )
176    } else {
177        (
178            EvalResult::Pass,
179            format!("Within budget: {tokens} tokens (${cost:.2})"),
180        )
181    }
182}
183
184fn eval_conditional_tool_call(
185    trace: &TraceRecord,
186    tool_name: &str,
187    min_count: Option<u32>,
188) -> (EvalResult, String) {
189    let min = min_count.unwrap_or(1) as usize;
190    let count = trace
191        .session
192        .tools_used
193        .iter()
194        .filter(|t| t.name.contains(tool_name))
195        .count();
196
197    if count >= min {
198        (
199            EvalResult::Pass,
200            format!("Tool '{}' called {} time(s) (required >= {})", tool_name, count, min),
201        )
202    } else {
203        (
204            EvalResult::Fail,
205            format!("Tool '{}' called {} time(s) (required >= {})", tool_name, count, min),
206        )
207    }
208}
209
210fn default_policies() -> Vec<PolicyRule> {
211    vec![
212        PolicyRule {
213            id: Uuid::nil(),
214            org_id: None,
215            name: "Trace completeness".into(),
216            description: "Every AI commit must have complete trace data".into(),
217            condition: PolicyCondition::TraceCompleteness,
218            action: PolicyAction::Warn,
219            severity: PolicySeverity::Medium,
220            enabled: true,
221        },
222        PolicyRule {
223            id: Uuid::nil(),
224            org_id: None,
225            name: "AI percentage threshold".into(),
226            description: "Warn when AI-generated code exceeds 90%".into(),
227            condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
228            action: PolicyAction::Warn,
229            severity: PolicySeverity::Medium,
230            enabled: true,
231        },
232        PolicyRule {
233            id: Uuid::nil(),
234            org_id: None,
235            name: "Model allowlist".into(),
236            description: "Only approved models may generate code".into(),
237            condition: PolicyCondition::ModelAllowlist {
238                allowed_models: vec![
239                    "anthropic/claude".into(),
240                    "openai/gpt".into(),
241                    "google/gemini".into(),
242                ],
243            },
244            action: PolicyAction::BlockMerge,
245            severity: PolicySeverity::High,
246            enabled: true,
247        },
248        PolicyRule {
249            id: Uuid::nil(),
250            org_id: None,
251            name: "Sensitive path review".into(),
252            description: "AI code in sensitive paths requires review".into(),
253            condition: PolicyCondition::SensitivePathPattern {
254                patterns: vec![
255                    "payments".into(),
256                    "auth".into(),
257                    "security".into(),
258                    "crypto".into(),
259                ],
260            },
261            action: PolicyAction::RequireReview,
262            severity: PolicySeverity::High,
263            enabled: true,
264        },
265        PolicyRule {
266            id: Uuid::nil(),
267            org_id: None,
268            name: "Required tool call".into(),
269            description: "Trace must show that tests were run".into(),
270            condition: PolicyCondition::RequiredToolCall {
271                tool_names: vec![],
272            },
273            action: PolicyAction::Warn,
274            severity: PolicySeverity::Low,
275            enabled: true,
276        },
277        PolicyRule {
278            id: Uuid::nil(),
279            org_id: None,
280            name: "Token budget".into(),
281            description: "Warn when token usage exceeds budget".into(),
282            condition: PolicyCondition::TokenBudget {
283                max_tokens: Some(500_000),
284                max_cost_usd: Some(50.0),
285            },
286            action: PolicyAction::Warn,
287            severity: PolicySeverity::Medium,
288            enabled: true,
289        },
290    ]
291}