Skip to main content

tracevault_core/
policy_engine.rs

1use crate::policy::*;
2use crate::trace::TraceRecord;
3use uuid::Uuid;
4
5pub struct PolicyEngine {
6    policies: Vec<PolicyRule>,
7}
8
9impl PolicyEngine {
10    pub fn new(policies: Vec<PolicyRule>) -> Self {
11        Self { policies }
12    }
13
14    pub fn with_defaults() -> Self {
15        Self::new(default_policies())
16    }
17
18    pub fn evaluate(&self, trace: &TraceRecord) -> Vec<PolicyEvaluation> {
19        self.policies
20            .iter()
21            .filter(|p| p.enabled)
22            .map(|p| evaluate_policy(p, trace))
23            .collect()
24    }
25}
26
27fn evaluate_policy(policy: &PolicyRule, trace: &TraceRecord) -> PolicyEvaluation {
28    let (result, details) = match &policy.condition {
29        PolicyCondition::TraceCompleteness => eval_trace_completeness(trace),
30        PolicyCondition::AiPercentageThreshold { .. } => {
31            // AI percentage is now computed server-side from commit_attributions.
32            // The client-side policy engine cannot evaluate this.
33            (
34                EvalResult::Pass,
35                "AI percentage evaluated server-side".into(),
36            )
37        }
38        PolicyCondition::ModelAllowlist { allowed_models } => {
39            eval_model_allowlist(trace, allowed_models)
40        }
41        PolicyCondition::SensitivePathPattern { .. } => {
42            // Sensitive path checking requires attribution data, now server-side only.
43            (
44                EvalResult::Pass,
45                "Sensitive path review evaluated server-side".into(),
46            )
47        }
48        PolicyCondition::RequiredToolCall { tool_names } => {
49            eval_required_tool_call(trace, tool_names)
50        }
51        PolicyCondition::TokenBudget {
52            max_tokens,
53            max_cost_usd,
54        } => eval_token_budget(trace, *max_tokens, *max_cost_usd),
55        PolicyCondition::ConditionalToolCall {
56            tool_name,
57            min_count,
58            when_files_match: _,
59        } => eval_conditional_tool_call(trace, tool_name, *min_count),
60    };
61
62    PolicyEvaluation {
63        policy: policy.clone(),
64        result,
65        details,
66    }
67}
68
69fn eval_trace_completeness(trace: &TraceRecord) -> (EvalResult, String) {
70    let has_session = !trace.session.session_id.is_empty();
71    let has_model = trace.model.is_some();
72
73    if has_session && has_model {
74        (EvalResult::Pass, "Trace is complete".into())
75    } else {
76        let mut missing = vec![];
77        if !has_session {
78            missing.push("session");
79        }
80        if !has_model {
81            missing.push("model");
82        }
83        (EvalResult::Warn, format!("Missing: {}", missing.join(", ")))
84    }
85}
86
87fn eval_model_allowlist(trace: &TraceRecord, allowed: &[String]) -> (EvalResult, String) {
88    if allowed.is_empty() {
89        return (EvalResult::Pass, "No allowlist configured".into());
90    }
91    match &trace.model {
92        Some(model) if allowed.iter().any(|a| model.contains(a)) => {
93            (EvalResult::Pass, format!("Model {model} is allowed"))
94        }
95        Some(model) => (
96            EvalResult::Fail,
97            format!("Model {model} is not in allowlist: {}", allowed.join(", ")),
98        ),
99        None => (EvalResult::Fail, "No model specified in trace".into()),
100    }
101}
102
103fn eval_required_tool_call(trace: &TraceRecord, required: &[String]) -> (EvalResult, String) {
104    let used: Vec<_> = trace.session.tools_used.iter().map(|t| &t.name).collect();
105    let missing: Vec<_> = required
106        .iter()
107        .filter(|r| !used.iter().any(|u| u.contains(r.as_str())))
108        .collect();
109
110    if missing.is_empty() {
111        (EvalResult::Pass, "All required tools were used".into())
112    } else {
113        (
114            EvalResult::Warn,
115            format!(
116                "Missing required tools: {}",
117                missing
118                    .iter()
119                    .map(|s| s.as_str())
120                    .collect::<Vec<_>>()
121                    .join(", ")
122            ),
123        )
124    }
125}
126
127fn eval_token_budget(
128    trace: &TraceRecord,
129    max_tokens: Option<u64>,
130    max_cost: Option<f64>,
131) -> (EvalResult, String) {
132    let tokens = trace.session.token_usage.total_tokens;
133    let cost = trace.session.token_usage.estimated_cost_usd;
134
135    let token_over = max_tokens.is_some_and(|max| tokens > max);
136    let cost_over = max_cost.is_some_and(|max| cost > max);
137
138    if token_over || cost_over {
139        (
140            EvalResult::Warn,
141            format!("Budget exceeded: {tokens} tokens (${cost:.2})"),
142        )
143    } else {
144        (
145            EvalResult::Pass,
146            format!("Within budget: {tokens} tokens (${cost:.2})"),
147        )
148    }
149}
150
151fn eval_conditional_tool_call(
152    trace: &TraceRecord,
153    tool_name: &str,
154    min_count: Option<u32>,
155) -> (EvalResult, String) {
156    let min = min_count.unwrap_or(1) as usize;
157    let count = trace
158        .session
159        .tools_used
160        .iter()
161        .filter(|t| t.name.contains(tool_name))
162        .count();
163
164    if count >= min {
165        (
166            EvalResult::Pass,
167            format!(
168                "Tool '{}' called {} time(s) (required >= {})",
169                tool_name, count, min
170            ),
171        )
172    } else {
173        (
174            EvalResult::Fail,
175            format!(
176                "Tool '{}' called {} time(s) (required >= {})",
177                tool_name, count, min
178            ),
179        )
180    }
181}
182
183fn default_policies() -> Vec<PolicyRule> {
184    vec![
185        PolicyRule {
186            id: Uuid::nil(),
187            org_id: None,
188            name: "Trace completeness".into(),
189            description: "Every AI commit must have complete trace data".into(),
190            condition: PolicyCondition::TraceCompleteness,
191            action: PolicyAction::Warn,
192            severity: PolicySeverity::Medium,
193            enabled: true,
194        },
195        PolicyRule {
196            id: Uuid::nil(),
197            org_id: None,
198            name: "AI percentage threshold".into(),
199            description: "Warn when AI-generated code exceeds 90%".into(),
200            condition: PolicyCondition::AiPercentageThreshold { threshold: 90.0 },
201            action: PolicyAction::Warn,
202            severity: PolicySeverity::Medium,
203            enabled: true,
204        },
205        PolicyRule {
206            id: Uuid::nil(),
207            org_id: None,
208            name: "Model allowlist".into(),
209            description: "Only approved models may generate code".into(),
210            condition: PolicyCondition::ModelAllowlist {
211                allowed_models: vec![
212                    "anthropic/claude".into(),
213                    "openai/gpt".into(),
214                    "google/gemini".into(),
215                ],
216            },
217            action: PolicyAction::BlockMerge,
218            severity: PolicySeverity::High,
219            enabled: true,
220        },
221        PolicyRule {
222            id: Uuid::nil(),
223            org_id: None,
224            name: "Sensitive path review".into(),
225            description: "AI code in sensitive paths requires review".into(),
226            condition: PolicyCondition::SensitivePathPattern {
227                patterns: vec![
228                    "payments".into(),
229                    "auth".into(),
230                    "security".into(),
231                    "crypto".into(),
232                ],
233            },
234            action: PolicyAction::RequireReview,
235            severity: PolicySeverity::High,
236            enabled: true,
237        },
238        PolicyRule {
239            id: Uuid::nil(),
240            org_id: None,
241            name: "Required tool call".into(),
242            description: "Trace must show that tests were run".into(),
243            condition: PolicyCondition::RequiredToolCall { tool_names: vec![] },
244            action: PolicyAction::Warn,
245            severity: PolicySeverity::Low,
246            enabled: true,
247        },
248        PolicyRule {
249            id: Uuid::nil(),
250            org_id: None,
251            name: "Token budget".into(),
252            description: "Warn when token usage exceeds budget".into(),
253            condition: PolicyCondition::TokenBudget {
254                max_tokens: Some(500_000),
255                max_cost_usd: Some(50.0),
256            },
257            action: PolicyAction::Warn,
258            severity: PolicySeverity::Medium,
259            enabled: true,
260        },
261    ]
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_usage::TokenUsage;
    use crate::trace::{Session, ToolCall};
    use chrono::Utc;

    /// Builds a minimal trace with the given session id and model, no tool
    /// calls, and default token usage.
    fn make_trace(session_id: &str, model: Option<&str>) -> TraceRecord {
        TraceRecord {
            id: Uuid::nil(),
            repo_id: "repo".into(),
            commit_sha: "abc123".into(),
            branch: None,
            author: "dev".into(),
            created_at: Utc::now(),
            model: model.map(String::from),
            tool: "claude-code".into(),
            tool_version: None,
            session: Session {
                session_id: session_id.into(),
                started_at: Utc::now(),
                ended_at: None,
                prompts: vec![],
                responses: vec![],
                token_usage: TokenUsage::default(),
                tools_used: vec![],
            },
            agent_trace: None,
            signature: None,
        }
    }

    /// Builds a complete trace whose session recorded one call per entry in
    /// `tools`, in order.
    fn make_trace_with_tools(tools: Vec<&str>) -> TraceRecord {
        let mut trace = make_trace("sess-1", Some("anthropic/claude-3"));
        trace.session.tools_used = tools
            .into_iter()
            .map(|name| ToolCall {
                name: name.into(),
                input_summary: String::new(),
                timestamp: Utc::now(),
            })
            .collect();
        trace
    }

    /// Builds a complete trace with the given token total and estimated cost.
    fn make_trace_with_tokens(total_tokens: u64, cost: f64) -> TraceRecord {
        let mut trace = make_trace("sess-1", Some("anthropic/claude-3"));
        trace.session.token_usage.total_tokens = total_tokens;
        trace.session.token_usage.estimated_cost_usd = cost;
        trace
    }

    /// Builds a rule for engine-level tests. Only `name`, `condition` and
    /// `enabled` vary; the other fields are fixed placeholders.
    fn make_rule(name: &str, condition: PolicyCondition, enabled: bool) -> PolicyRule {
        PolicyRule {
            id: Uuid::nil(),
            org_id: None,
            name: name.into(),
            description: String::new(),
            condition,
            action: PolicyAction::Warn,
            severity: PolicySeverity::Low,
            enabled,
        }
    }

    // --- TraceCompleteness ---
    // Incomplete traces produce a warning, not a failure.

    #[test]
    fn completeness_pass_when_both_present() {
        let trace = make_trace("sess-1", Some("claude-3"));
        let (result, _) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn completeness_warn_missing_session_id() {
        let trace = make_trace("", Some("claude-3"));
        let (result, details) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("session"));
    }

    #[test]
    fn completeness_warn_missing_model() {
        let trace = make_trace("sess-1", None);
        let (result, details) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("model"));
    }

    #[test]
    fn completeness_warn_both_missing() {
        let trace = make_trace("", None);
        let (result, details) = eval_trace_completeness(&trace);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("session"));
        assert!(details.contains("model"));
    }

    // --- ModelAllowlist ---

    #[test]
    fn model_allowlist_pass_empty_list() {
        let trace = make_trace("sess-1", Some("anything"));
        let (result, _) = eval_model_allowlist(&trace, &[]);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn model_allowlist_pass_match() {
        let trace = make_trace("sess-1", Some("anthropic/claude-3"));
        let allowed = vec!["anthropic/claude".into(), "openai/gpt".into()];
        let (result, _) = eval_model_allowlist(&trace, &allowed);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn model_allowlist_fail_no_match() {
        let trace = make_trace("sess-1", Some("unknown/model"));
        let allowed = vec!["anthropic/claude".into(), "openai/gpt".into()];
        let (result, details) = eval_model_allowlist(&trace, &allowed);
        assert_eq!(result, EvalResult::Fail);
        assert!(details.contains("openai/gpt"));
    }

    #[test]
    fn model_allowlist_fail_when_model_missing() {
        let trace = make_trace("sess-1", None);
        let allowed = vec!["anthropic/claude".into()];
        let (result, details) = eval_model_allowlist(&trace, &allowed);
        assert_eq!(result, EvalResult::Fail);
        assert!(details.contains("No model"));
    }

    // --- RequiredToolCall ---

    #[test]
    fn required_tool_call_pass_all_present() {
        let trace = make_trace_with_tools(vec!["cargo test", "cargo clippy"]);
        let required = vec!["cargo test".into(), "cargo clippy".into()];
        let (result, _) = eval_required_tool_call(&trace, &required);
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn required_tool_call_warn_missing() {
        let trace = make_trace_with_tools(vec!["cargo test"]);
        let required = vec!["cargo test".into(), "cargo clippy".into()];
        let (result, details) = eval_required_tool_call(&trace, &required);
        assert_eq!(result, EvalResult::Warn);
        assert!(details.contains("cargo clippy"));
    }

    #[test]
    fn required_tool_call_pass_when_nothing_required() {
        let trace = make_trace_with_tools(vec![]);
        let (result, _) = eval_required_tool_call(&trace, &[]);
        assert_eq!(result, EvalResult::Pass);
    }

    // --- TokenBudget ---

    #[test]
    fn token_budget_pass_under_limits() {
        let trace = make_trace_with_tokens(1000, 1.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn token_budget_pass_at_exact_limits() {
        // Limits are exclusive: reaching a cap exactly is still within budget.
        let trace = make_trace_with_tokens(5000, 10.0);
        let (result, details) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Pass);
        assert!(details.starts_with("Within budget"));
    }

    #[test]
    fn token_budget_warn_over_max_tokens() {
        let trace = make_trace_with_tokens(10_000, 1.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Warn);
    }

    #[test]
    fn token_budget_warn_over_max_cost() {
        let trace = make_trace_with_tokens(1000, 20.0);
        let (result, _) = eval_token_budget(&trace, Some(5000), Some(10.0));
        assert_eq!(result, EvalResult::Warn);
    }

    #[test]
    fn token_budget_pass_both_none() {
        let trace = make_trace_with_tokens(999_999, 999.0);
        let (result, _) = eval_token_budget(&trace, None, None);
        assert_eq!(result, EvalResult::Pass);
    }

    // --- ConditionalToolCall ---

    #[test]
    fn conditional_tool_call_pass_count_met() {
        let trace = make_trace_with_tools(vec!["cargo test", "cargo test"]);
        let (result, _) = eval_conditional_tool_call(&trace, "cargo test", Some(2));
        assert_eq!(result, EvalResult::Pass);
    }

    #[test]
    fn conditional_tool_call_fail_count_not_met() {
        let trace = make_trace_with_tools(vec!["cargo test"]);
        let (result, _) = eval_conditional_tool_call(&trace, "cargo test", Some(3));
        assert_eq!(result, EvalResult::Fail);
    }

    #[test]
    fn conditional_tool_call_fail_absent() {
        let trace = make_trace_with_tools(vec!["cargo clippy"]);
        let (result, _) = eval_conditional_tool_call(&trace, "cargo test", Some(1));
        assert_eq!(result, EvalResult::Fail);
    }

    #[test]
    fn conditional_tool_call_defaults_min_count_to_one() {
        let once = make_trace_with_tools(vec!["cargo test"]);
        let (result, _) = eval_conditional_tool_call(&once, "cargo test", None);
        assert_eq!(result, EvalResult::Pass);

        let never = make_trace_with_tools(vec![]);
        let (result, details) = eval_conditional_tool_call(&never, "cargo test", None);
        assert_eq!(result, EvalResult::Fail);
        assert!(details.contains("required >= 1"));
    }

    // --- PolicyEngine ---

    #[test]
    fn evaluate_skips_disabled_policies() {
        let engine = PolicyEngine::new(vec![
            make_rule("enabled", PolicyCondition::TraceCompleteness, true),
            make_rule("disabled", PolicyCondition::TraceCompleteness, false),
        ]);
        let trace = make_trace("sess-1", Some("claude-3"));
        let results = engine.evaluate(&trace);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].policy.name, "enabled");
    }

    #[test]
    fn evaluate_reports_attribution_policies_as_server_side() {
        let engine = PolicyEngine::new(vec![
            make_rule(
                "ai",
                PolicyCondition::AiPercentageThreshold { threshold: 0.0 },
                true,
            ),
            make_rule(
                "paths",
                PolicyCondition::SensitivePathPattern {
                    patterns: vec!["auth".into()],
                },
                true,
            ),
        ]);
        let results = engine.evaluate(&make_trace("", None));
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.result == EvalResult::Pass));
        assert!(results
            .iter()
            .all(|r| r.details.ends_with("evaluated server-side")));
    }

    #[test]
    fn with_defaults_returns_six_enabled() {
        let engine = PolicyEngine::with_defaults();
        let trace = make_trace("sess-1", Some("anthropic/claude-3"));
        let results = engine.evaluate(&trace);
        assert_eq!(results.len(), 6);
        assert!(results.iter().all(|r| r.policy.enabled));
    }
}