//! Client-side policy engine (`tracevault_core/policy_engine.rs`).

1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6    policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10    pub fn new(policies: Vec<PolicyRule>) -> Self {
11        Self { policies }
12    }
13
14    pub fn with_defaults() -> Self {
15        Self::new(default_policies())
16    }
17
18    pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19        self.policies
20            .iter()
21            .filter(|p| p.enabled)
22            .map(|p| evaluate_policy(p, trace))
23            .collect()
24    }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28    let (result, details) = match &policy.condition {
29        PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30        PolicyCondition::AiPercentageThreshold { .. } => {
31            // AI percentage is now computed server-side from commit_attributions.
32            // The client-side policy engine cannot evaluate this.
33            (
34                EvalResult::Pass,
35                "AI percentage evaluated server-side".into(),
36            )
37        }
38        PolicyCondition::ModelAllowlist { allowed_models } => {
39            eval_model_allowlist(trace, allowed_models)
40        }
41        PolicyCondition::SensitivePathPattern { .. } => {
42            // Sensitive path checking requires attribution data, now server-side only.
43            (
44                EvalResult::Pass,
45                "Sensitive path review evaluated server-side".into(),
46            )
47        }
48        PolicyCondition::RequiredToolCall { tool_names } => {
49            eval_required_tool_call(trace, tool_names)
50        }
51        PolicyCondition::TokenBudget {
52            max_tokens,
53            max_cost_usd,
54        } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
55        PolicyCondition::ConditionalToolCall {
56            tool_name,
57            min_count,
58            when_files_match: _,
59        } => eval_conditional_tool_call(trace, tool_name, *min_count),
60    };
61
62    PolicyEvaluation {
63        policy: policy.clone(),
64        result,
65        details,
66    }
67}
68
69fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
70    let has_session = !trace.session.session_id.is_empty();
71    let has_model = trace.model.is_some();
72
73    if has_session && has_model {
74        (EvalResult::Pass, "Trace is complete".into())
75    } else {
76        let mut missing = vec![];
77        if !has_session {
78            missing.push("session");
79        }
80        if !has_model {
81            missing.push("model");
82        }
83        (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
84    }
85}
86
87fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
88    if allowed.is_empty() {
89        return (EvalResult::Pass, "No allowlist configured".into());
90    }
91    match &trace.model {
92        Some(model) if allowed.iter().any(|a| model.contains(a)) => {
93            (EvalResult::Pass, format!("Model {model} is allowed"))
94        }
95        Some(model) => (
96            EvalResult::Fail,
97            format!("Model {model} is not in allowlist: {}", allowed.join(", ")),
98        ),
99        None => (EvalResult::Fail, "No model specified in trace".into()),
100    }
101}
102
103fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
104    let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
105    let missing: Vec<_> = required
106        .iter()
107        .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
108        .collect();
109
110    if missing.is_empty() {
111        (EvalResult::Pass, "All required tools were used".into())
112    } else {
113        (
114            EvalResult::Warn,
115            format!(
116                "Missing required tools: {}",
117                missing
118                    .iter()
119                    .map(|s| s.as_str())
120                    .collect::<Vec<_>>()
121                    .join(", ")
122            ),
123        )
124    }
125}
126
127fn eval_token_budget(
128    trace: &TraceRecord,
129    max_tokens: Option<u64>,
130    max_cost: Option<f64>,
131) -> (EvalResult, String) {
132    let tokens = trace.session.token_usage.total_tokens;
133    let cost = trace.session.token_usage.estimated_cost_usd;
134
135    let token_over = max_tokens.is_some_and(|max| tokens > max);
136    let cost_over = max_cost.is_some_and(|max| cost > max);
137
138    if token_over || cost_over {
139        (
140            EvalResult::Warn,
141            format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
142        )
143    } else {
144        (
145            EvalResult::Pass,
146            format!("Within budget: {tokens} tokens (${cost:.2})"),
147        )
148    }
149}
150
151fn eval_conditional_tool_call(
152    trace: &TraceRecord,
153    tool_name: &str,
154    min_count: Option<u32>,
155) -> (EvalResult, String) {
156    let min = min_count.unwrap_or(1) as usize;
157    let count = trace
158        .session
159        .tools_used
160        .iter()
161        .filter(|t| t.name.contains(tool_name))
162        .count();
163
164    if count >= min {
165        (
166            EvalResult::Pass,
167            format!(
168                "Tool '{}' called {} time(s) (required >= {})",
169                tool_name, count, min
170            ),
171        )
172    } else {
173        (
174            EvalResult::Fail,
175            format!(
176                "Tool '{}' called {} time(s) (required >= {})",
177                tool_name, count, min
178            ),
179        )
180    }
181}
182
183fn default_policies() -> Vec<PolicyRule> {
184    vec![
185        PolicyRule {
186            id: Uuid::nil(),
187            org_id: None,
188            name: "Trace completeness".into(),
189            description: "Every AI commit must have complete trace data".into(),
190            condition: PolicyCondition::TraceCompleteness,
191            action: PolicyAction::Warn,
192            severity: PolicySeverity::Medium,
193            enabled: true,
194        },
195        PolicyRule {
196            id: Uuid::nil(),
197            org_id: None,
198            name: "AI percentage threshold".into(),
199            description: "Warn when AI-generated code exceeds 90%".into(),
200            condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
201            action: PolicyAction::Warn,
202            severity: PolicySeverity::Medium,
203            enabled: true,
204        },
205        PolicyRule {
206            id: Uuid::nil(),
207            org_id: None,
208            name: "Model allowlist".into(),
209            description: "Only approved models may generate code".into(),
210            condition: PolicyCondition::ModelAllowlist {
211                allowed_models: vec![
212                    "anthropic/claude".into(),
213                    "openai/gpt".into(),
214                    "google/gemini".into(),
215                ],
216            },
217            action: PolicyAction::BlockMerge,
218            severity: PolicySeverity::High,
219            enabled: true,
220        },
221        PolicyRule {
222            id: Uuid::nil(),
223            org_id: None,
224            name: "Sensitive path review".into(),
225            description: "AI code in sensitive paths requires review".into(),
226            condition: PolicyCondition::SensitivePathPattern {
227                patterns: vec![
228                    "payments".into(),
229                    "auth".into(),
230                    "security".into(),
231                    "crypto".into(),
232                ],
233            },
234            action: PolicyAction::RequireReview,
235            severity: PolicySeverity::High,
236            enabled: true,
237        },
238        PolicyRule {
239            id: Uuid::nil(),
240            org_id: None,
241            name: "Required tool call".into(),
242            description: "Trace must show that tests were run".into(),
243            condition: PolicyCondition::RequiredToolCall { tool_names: vec![] },
244            action: PolicyAction::Warn,
245            severity: PolicySeverity::Low,
246            enabled: true,
247        },
248        PolicyRule {
249            id: Uuid::nil(),
250            org_id: None,
251            name: "Token budget".into(),
252            description: "Warn when token usage exceeds budget".into(),
253            condition: PolicyCondition::TokenBudget {
254                max_tokens: Some(500_000),
255                max_cost_usd: Some(50.0),
256            },
257            action: PolicyAction::Warn,
258            severity: PolicySeverity::Medium,
259            enabled: true,
260        },
261    ]
262}