Skip to main content

task_graph_mcp/
gates.rs

//! Gate evaluation logic for workflow transitions.
//!
//! Gates are checklist items that must be satisfied before transitioning out of
//! a status or phase. A gate is satisfied when the task has an attachment with
//! a matching type (e.g., "gate/tests", "gate/commit").
6
7use crate::config::{GateDefinition, GateEnforcement};
8use crate::db::Database;
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12
13/// Result of evaluating a single gate.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct GateResult {
16    /// The attachment type that would satisfy this gate.
17    pub gate_type: String,
18    /// Enforcement level for this gate.
19    pub enforcement: GateEnforcement,
20    /// Human-readable description of what this gate requires.
21    pub description: String,
22    /// Whether the gate is satisfied (always false in unsatisfied_gates list).
23    pub satisfied: bool,
24}
25
26/// Aggregated result of evaluating all gates for a transition.
27#[derive(Debug, Serialize, Deserialize)]
28pub struct GateCheckResult {
29    /// Overall status: "pass", "warn", or "fail"
30    pub status: String,
31    /// Only unsatisfied gates are included
32    pub unsatisfied_gates: Vec<GateResult>,
33}
34
35/// Evaluate gates for a task against its attachments.
36/// Returns only unsatisfied gates.
37///
38/// # Arguments
39/// * `db` - Database handle for fetching attachments
40/// * `task_id` - The task ID to check gates for
41/// * `gates` - List of gate definitions to evaluate
42///
43/// # Returns
44/// A `GateCheckResult` with:
45/// - `status`: "pass" if all gates satisfied, "warn" if only warn-level gates unsatisfied,
46///   "fail" if any reject-level gates are unsatisfied
47/// - `unsatisfied_gates`: List of gates that are not satisfied
48pub fn evaluate_gates(
49    db: &Database,
50    task_id: &str,
51    gates: &[GateDefinition],
52) -> Result<GateCheckResult> {
53    // Get all attachment types for this task
54    let attachments = db.get_attachments(task_id)?;
55    let attachment_types: HashSet<String> = attachments
56        .iter()
57        .map(|a| a.attachment_type.clone())
58        .collect();
59
60    let mut unsatisfied_gates = Vec::new();
61    let mut has_reject = false;
62    let mut has_warn = false;
63
64    for gate in gates {
65        let satisfied = attachment_types.contains(&gate.gate_type);
66
67        if !satisfied {
68            match gate.enforcement {
69                GateEnforcement::Reject => has_reject = true,
70                GateEnforcement::Warn => has_warn = true,
71                GateEnforcement::Allow => {} // Still include in results but doesn't affect status
72            }
73
74            unsatisfied_gates.push(GateResult {
75                gate_type: gate.gate_type.clone(),
76                enforcement: gate.enforcement,
77                description: gate.description.clone(),
78                satisfied: false,
79            });
80        }
81        // Satisfied gates are omitted from results per spec
82    }
83
84    let status = if has_reject {
85        "fail".to_string()
86    } else if has_warn {
87        "warn".to_string()
88    } else {
89        "pass".to_string()
90    };
91
92    Ok(GateCheckResult {
93        status,
94        unsatisfied_gates,
95    })
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// A result constructed with no unsatisfied gates keeps its "pass" status.
    #[test]
    fn test_gate_check_result_status_pass() {
        let unsatisfied_gates = Vec::new();
        let status = String::from("pass");
        let result = GateCheckResult {
            status,
            unsatisfied_gates,
        };

        assert_eq!(result.status, "pass");
    }

    /// Serializing a gate result to JSON includes the gate type and the
    /// enforcement level text.
    #[test]
    fn test_gate_result_serialization() {
        let gate = GateResult {
            gate_type: String::from("gate/tests"),
            enforcement: GateEnforcement::Warn,
            description: String::from("Tests must pass"),
            satisfied: false,
        };

        let json = serde_json::to_string(&gate).unwrap();

        for expected in ["gate/tests", "warn"].iter() {
            assert!(json.contains(*expected));
        }
    }
}
121        let json = serde_json::to_string(&gate).unwrap();
122        assert!(json.contains("gate/tests"));
123        assert!(json.contains("warn"));
124    }
125}