Skip to main content

tokmd_gate/
types.rs

//! Policy, rule, ratchet, and gate-result type definitions.
2
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5
/// Errors from policy evaluation.
///
/// Covers both loading a policy/ratchet configuration (I/O and TOML parse
/// failures) and evaluating its rules. NOTE(review): the evaluation variants
/// are presumably raised by the rule evaluator, which is not in this file.
#[derive(Debug)]
pub enum GateError {
    /// Reading the configuration file failed (produced by `from_file`).
    IoError(std::io::Error),
    /// The configuration text was not valid TOML or did not match the schema.
    TomlError(toml::de::Error),
    /// A rule's JSON pointer could not be used; holds the offending pointer text.
    InvalidPointer(String),
    /// A value had a different type than the rule required (type names as text).
    TypeMismatch { expected: String, actual: String },
    /// Operator `op` cannot be applied to a value of type `value_type`.
    InvalidOperator { op: String, value_type: String },
    /// Rule `name` lacks the required field `field`.
    MissingField { name: String, field: String },
}
16
17impl std::fmt::Display for GateError {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        match self {
20            Self::IoError(e) => write!(f, "Failed to read policy file: {e}"),
21            Self::TomlError(e) => write!(f, "Failed to parse policy TOML: {e}"),
22            Self::InvalidPointer(p) => write!(f, "Invalid JSON pointer: {p}"),
23            Self::TypeMismatch { expected, actual } => {
24                write!(f, "Type mismatch: expected {expected}, got {actual}")
25            }
26            Self::InvalidOperator { op, value_type } => {
27                write!(f, "Invalid operator '{op}' for type '{value_type}'")
28            }
29            Self::MissingField { name, field } => {
30                write!(f, "Rule '{name}' missing required field: {field}")
31            }
32        }
33    }
34}
35
36impl std::error::Error for GateError {
37    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
38        match self {
39            Self::IoError(e) => Some(e),
40            Self::TomlError(e) => Some(e),
41            _ => None,
42        }
43    }
44}
45
46impl From<std::io::Error> for GateError {
47    fn from(err: std::io::Error) -> Self {
48        Self::IoError(err)
49    }
50}
51
52impl From<toml::de::Error> for GateError {
53    fn from(err: toml::de::Error) -> Self {
54        Self::TomlError(err)
55    }
56}
57
58/// Root policy configuration.
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
60#[serde(default)]
61pub struct PolicyConfig {
62    /// Policy rules to evaluate.
63    pub rules: Vec<PolicyRule>,
64
65    /// Stop evaluation on first error.
66    #[serde(default)]
67    pub fail_fast: bool,
68
69    /// Allow missing values (treat as pass) instead of error.
70    #[serde(default)]
71    pub allow_missing: bool,
72}
73
74impl PolicyConfig {
75    /// Parse policy from TOML string.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use tokmd_gate::PolicyConfig;
81    ///
82    /// let toml = r#"
83    /// fail_fast = false
84    /// allow_missing = true
85    ///
86    /// [[rules]]
87    /// name = "max_tokens"
88    /// pointer = "/tokens"
89    /// op = "lte"
90    /// value = 100000
91    /// "#;
92    ///
93    /// let policy = PolicyConfig::from_toml(toml).unwrap();
94    /// assert_eq!(policy.rules.len(), 1);
95    /// assert!(policy.allow_missing);
96    /// ```
97    pub fn from_toml(s: &str) -> Result<Self, GateError> {
98        Ok(toml::from_str(s)?)
99    }
100
101    /// Load policy from a TOML file.
102    pub fn from_file(path: &Path) -> Result<Self, GateError> {
103        let content = std::fs::read_to_string(path)?;
104        Self::from_toml(&content)
105    }
106}
107
/// A single policy rule.
///
/// Deserialized from one `[[rules]]` table of a policy TOML. The keys `name`,
/// `pointer` and `op` are required; every other key is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyRule {
    /// Human-readable name for the rule.
    pub name: String,

    /// JSON Pointer to the value to check (RFC 6901), e.g. `/derived/totals/tokens`.
    pub pointer: String,

    /// Comparison operator.
    pub op: RuleOperator,

    /// Single value for comparison (for >, <, ==, etc.).
    ///
    /// `None` when omitted. NOTE(review): whether a missing value becomes a
    /// `GateError::MissingField` is decided by the evaluator, not shown here.
    #[serde(default)]
    pub value: Option<serde_json::Value>,

    /// Multiple values for the "in" operator; `None` when omitted.
    #[serde(default)]
    pub values: Option<Vec<serde_json::Value>>,

    /// Negate the result (NOT); `false` when omitted.
    #[serde(default)]
    pub negate: bool,

    /// Rule severity level; [`RuleLevel::Error`] when omitted.
    #[serde(default)]
    pub level: RuleLevel,

    /// Custom failure message; `None` when omitted.
    #[serde(default)]
    pub message: Option<String>,
}
140
/// Comparison operators for rules.
///
/// Serialized in snake_case: `"gt"`, `"gte"`, `"lt"`, `"lte"`, `"eq"`, `"ne"`,
/// `"in"`, `"contains"` and `"exists"`. [`RuleOperator::Eq`] is the `Default`.
/// The `Display` form is the symbolic operator (`>`, `>=`, ...) instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RuleOperator {
    /// Greater than (>)
    Gt,
    /// Greater than or equal (>=)
    Gte,
    /// Less than (<)
    Lt,
    /// Less than or equal (<=)
    Lte,
    /// Equal (==); the default operator.
    #[default]
    Eq,
    /// Not equal (!=)
    Ne,
    /// Value is in list (compared against the rule's `values`)
    In,
    /// String/array contains value
    Contains,
    /// JSON pointer exists (value is present)
    Exists,
}
165
166impl std::fmt::Display for RuleOperator {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        match self {
169            RuleOperator::Gt => write!(f, ">"),
170            RuleOperator::Gte => write!(f, ">="),
171            RuleOperator::Lt => write!(f, "<"),
172            RuleOperator::Lte => write!(f, "<="),
173            RuleOperator::Eq => write!(f, "=="),
174            RuleOperator::Ne => write!(f, "!="),
175            RuleOperator::In => write!(f, "in"),
176            RuleOperator::Contains => write!(f, "contains"),
177            RuleOperator::Exists => write!(f, "exists"),
178        }
179    }
180}
181
/// Rule severity level.
///
/// Serialized in lowercase (`"warn"` / `"error"`). Rules that omit `level`
/// get the default, [`RuleLevel::Error`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum RuleLevel {
    /// Warning - a failure is counted in `warnings` and does not fail the gate.
    Warn,
    /// Error - a failure is counted in `errors` and fails the gate.
    #[default]
    Error,
}
192
/// Result of evaluating the entire policy.
///
/// Normally built with [`GateResult::from_results`], which derives the counts
/// and the pass/fail flag from `rule_results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateResult {
    /// Overall pass/fail (as computed by `from_results`: no error-level rule failed).
    pub passed: bool,

    /// Individual rule results, in the order they were supplied.
    pub rule_results: Vec<RuleResult>,

    /// Count of failed rules at [`RuleLevel::Error`].
    pub errors: usize,

    /// Count of failed rules at [`RuleLevel::Warn`].
    pub warnings: usize,
}
208
209impl GateResult {
210    /// Create a new gate result from rule results.
211    pub fn from_results(rule_results: Vec<RuleResult>) -> Self {
212        let errors = rule_results
213            .iter()
214            .filter(|r| !r.passed && r.level == RuleLevel::Error)
215            .count();
216        let warnings = rule_results
217            .iter()
218            .filter(|r| !r.passed && r.level == RuleLevel::Warn)
219            .count();
220        let passed = errors == 0;
221
222        Self {
223            passed,
224            rule_results,
225            errors,
226            warnings,
227        }
228    }
229}
230
/// Result of evaluating a single rule.
///
/// A failed result is counted as an error or a warning according to `level`
/// when aggregated by [`GateResult::from_results`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleResult {
    /// Rule name (presumably copied from the originating `PolicyRule::name`).
    pub name: String,

    /// Whether the rule passed.
    pub passed: bool,

    /// Rule level (error/warn).
    pub level: RuleLevel,

    /// Actual value found (if any).
    pub actual: Option<serde_json::Value>,

    /// Expected value or condition, rendered as text.
    pub expected: String,

    /// Failure message, if any.
    pub message: Option<String>,
}
252
/// Ratchet rule for gradual improvement.
///
/// Ratchet rules enforce that metrics don't regress beyond acceptable bounds
/// when compared to a baseline. This enables gradual quality improvement by
/// allowing teams to "ratchet" down thresholds over time.
///
/// Only `pointer` is a required key; both limits are optional.
/// NOTE(review): a rule that sets neither `max_increase_pct` nor `max_value`
/// appears to constrain nothing — confirm how the evaluator treats it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RatchetRule {
    /// JSON pointer to the metric (e.g., "/complexity/avg_cyclomatic").
    pub pointer: String,

    /// Maximum allowed increase percentage from baseline.
    /// For example, 10.0 means the current value can be at most 10% higher than baseline.
    /// `None` when omitted (presumably no relative limit is applied).
    #[serde(default)]
    pub max_increase_pct: Option<f64>,

    /// Maximum allowed absolute value.
    /// This acts as a hard ceiling regardless of baseline.
    /// `None` when omitted (presumably no ceiling is applied).
    #[serde(default)]
    pub max_value: Option<f64>,

    /// Rule severity level; [`RuleLevel::Error`] when omitted.
    #[serde(default)]
    pub level: RuleLevel,

    /// Human-readable description of the rule; optional.
    #[serde(default)]
    pub description: Option<String>,
}
281
/// Result of ratchet evaluation.
///
/// One value per evaluated [`RatchetRule`]. Failures are tallied by
/// [`RatchetGateResult::from_results`] according to `rule.level`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RatchetResult {
    /// The rule that was evaluated (owned by the result, including its level).
    pub rule: RatchetRule,

    /// Whether the ratchet check passed.
    pub passed: bool,

    /// Baseline value (if found).
    pub baseline_value: Option<f64>,

    /// Current value.
    pub current_value: f64,

    /// Percentage change from baseline (if baseline exists).
    /// NOTE(review): presumably `(current - baseline) / baseline * 100`, so a
    /// positive value means an increase — confirm in the evaluator.
    pub change_pct: Option<f64>,

    /// Human-readable message describing the result.
    pub message: String,
}
303
304/// Configuration for ratchet rules.
305#[derive(Debug, Clone, Default, Serialize, Deserialize)]
306#[serde(default)]
307pub struct RatchetConfig {
308    /// Ratchet rules to evaluate.
309    pub rules: Vec<RatchetRule>,
310
311    /// Stop evaluation on first error.
312    #[serde(default)]
313    pub fail_fast: bool,
314
315    /// Allow missing baseline values (treat as pass) instead of error.
316    #[serde(default)]
317    pub allow_missing_baseline: bool,
318
319    /// Allow missing current values (treat as pass) instead of error.
320    #[serde(default)]
321    pub allow_missing_current: bool,
322}
323
324impl RatchetConfig {
325    /// Parse ratchet config from TOML string.
326    pub fn from_toml(s: &str) -> Result<Self, GateError> {
327        Ok(toml::from_str(s)?)
328    }
329
330    /// Load ratchet config from a TOML file.
331    pub fn from_file(path: &Path) -> Result<Self, GateError> {
332        let content = std::fs::read_to_string(path)?;
333        Self::from_toml(&content)
334    }
335}
336
/// Overall result of ratchet evaluation.
///
/// Normally built with [`RatchetGateResult::from_results`], which derives the
/// counts and the pass/fail flag from each result's `rule.level`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RatchetGateResult {
    /// Overall pass/fail (as computed by `from_results`: no error-level ratchet failed).
    pub passed: bool,

    /// Individual ratchet results, in the order they were supplied.
    pub ratchet_results: Vec<RatchetResult>,

    /// Count of failed ratchets at [`RuleLevel::Error`].
    pub errors: usize,

    /// Count of failed ratchets at [`RuleLevel::Warn`].
    pub warnings: usize,
}
352
353impl RatchetGateResult {
354    /// Create a new ratchet gate result from ratchet results.
355    pub fn from_results(ratchet_results: Vec<RatchetResult>) -> Self {
356        let errors = ratchet_results
357            .iter()
358            .filter(|r| !r.passed && r.rule.level == RuleLevel::Error)
359            .count();
360        let warnings = ratchet_results
361            .iter()
362            .filter(|r| !r.passed && r.rule.level == RuleLevel::Warn)
363            .count();
364        let passed = errors == 0;
365
366        Self {
367            passed,
368            ratchet_results,
369            errors,
370            warnings,
371        }
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a temp-dir path unique to this process and moment.
    ///
    /// Concurrent `cargo test` invocations therefore don't collide on the same
    /// file.
    fn unique_temp_path(tag: &str) -> std::path::PathBuf {
        use std::time::{SystemTime, UNIX_EPOCH};

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!(
            "tokmd-gate-{tag}-{}-{nanos}.toml",
            std::process::id()
        ))
    }

    /// Builds a ratchet result with the given outcome and severity.
    fn ratchet_result(passed: bool, level: RuleLevel) -> RatchetResult {
        RatchetResult {
            rule: RatchetRule {
                pointer: "/complexity/avg_cyclomatic".into(),
                max_increase_pct: Some(10.0),
                max_value: None,
                level,
                description: None,
            },
            passed,
            baseline_value: Some(1.0),
            current_value: 1.5,
            change_pct: Some(50.0),
            message: "test".into(),
        }
    }

    #[test]
    fn test_parse_policy() {
        let toml = r#"
fail_fast = true
allow_missing = false

[[rules]]
name = "max_tokens"
pointer = "/derived/totals/tokens"
op = "lte"
value = 500000
level = "error"
message = "Too many tokens"

[[rules]]
name = "has_license"
pointer = "/license/effective"
op = "exists"
level = "warn"
"#;
        let policy = PolicyConfig::from_toml(toml).unwrap();
        assert!(policy.fail_fast);
        assert!(!policy.allow_missing);
        assert_eq!(policy.rules.len(), 2);
        assert_eq!(policy.rules[0].name, "max_tokens");
        assert_eq!(policy.rules[0].op, RuleOperator::Lte);
        assert_eq!(policy.rules[1].op, RuleOperator::Exists);
    }

    #[test]
    fn test_gate_result() {
        let results = vec![
            RuleResult {
                name: "rule1".into(),
                passed: true,
                level: RuleLevel::Error,
                actual: None,
                expected: "test".into(),
                message: None,
            },
            RuleResult {
                name: "rule2".into(),
                passed: false,
                level: RuleLevel::Warn,
                actual: None,
                expected: "test".into(),
                message: Some("Warning".into()),
            },
        ];

        let gate = GateResult::from_results(results);
        assert!(gate.passed); // Only warns, no errors
        assert_eq!(gate.errors, 0);
        assert_eq!(gate.warnings, 1);
    }

    #[test]
    fn test_policy_from_file() {
        // Kills mutant: PolicyConfig::from_file -> Ok(Default::default()).
        let toml = r#"
fail_fast = true
allow_missing = false

[[rules]]
name = "max_tokens"
pointer = "/derived/totals/tokens"
op = "lte"
value = 500000
level = "error"
"#;

        let path = unique_temp_path("policy");
        std::fs::write(&path, toml).unwrap();

        // Clean up before unwrapping so a parse failure doesn't leak the temp file.
        let loaded = PolicyConfig::from_file(&path);
        let _ = std::fs::remove_file(&path);
        let policy = loaded.unwrap();

        assert!(policy.fail_fast);
        assert_eq!(policy.rules.len(), 1);
        assert_eq!(policy.rules[0].name, "max_tokens");
        assert_eq!(policy.rules[0].op, RuleOperator::Lte);
    }

    #[test]
    fn test_policy_from_missing_file_is_io_error() {
        let path = unique_temp_path("missing");
        let err = PolicyConfig::from_file(&path).unwrap_err();
        assert!(matches!(err, GateError::IoError(_)));
    }

    #[test]
    fn test_empty_configs_use_defaults() {
        let policy = PolicyConfig::from_toml("").unwrap();
        assert!(policy.rules.is_empty());
        assert!(!policy.fail_fast);
        assert!(!policy.allow_missing);

        let ratchet = RatchetConfig::from_toml("").unwrap();
        assert!(ratchet.rules.is_empty());
        assert!(!ratchet.fail_fast);
        assert!(!ratchet.allow_missing_baseline);
        assert!(!ratchet.allow_missing_current);
    }

    #[test]
    fn test_rule_operator_display() {
        // Kills mutant in Display impl.
        assert_eq!(RuleOperator::Gt.to_string(), ">");
        assert_eq!(RuleOperator::Gte.to_string(), ">=");
        assert_eq!(RuleOperator::Lt.to_string(), "<");
        assert_eq!(RuleOperator::Lte.to_string(), "<=");
        assert_eq!(RuleOperator::Eq.to_string(), "==");
        assert_eq!(RuleOperator::Ne.to_string(), "!=");
        assert_eq!(RuleOperator::In.to_string(), "in");
        assert_eq!(RuleOperator::Contains.to_string(), "contains");
        assert_eq!(RuleOperator::Exists.to_string(), "exists");
    }

    #[test]
    fn test_serde_names_and_defaults() {
        let ops = [
            ("\"gt\"", RuleOperator::Gt),
            ("\"gte\"", RuleOperator::Gte),
            ("\"lt\"", RuleOperator::Lt),
            ("\"lte\"", RuleOperator::Lte),
            ("\"eq\"", RuleOperator::Eq),
            ("\"ne\"", RuleOperator::Ne),
            ("\"in\"", RuleOperator::In),
            ("\"contains\"", RuleOperator::Contains),
            ("\"exists\"", RuleOperator::Exists),
        ];
        for (json, op) in ops {
            assert_eq!(serde_json::to_string(&op).unwrap(), json);
            assert_eq!(serde_json::from_str::<RuleOperator>(json).unwrap(), op);
        }

        assert_eq!(serde_json::to_string(&RuleLevel::Warn).unwrap(), "\"warn\"");
        assert_eq!(serde_json::to_string(&RuleLevel::Error).unwrap(), "\"error\"");
        assert_eq!(RuleLevel::default(), RuleLevel::Error);
        assert_eq!(RuleOperator::default(), RuleOperator::Eq);
    }

    #[test]
    fn test_gate_error_display_and_source() {
        use std::error::Error as _;

        let io = GateError::from(std::io::Error::new(std::io::ErrorKind::NotFound, "gone"));
        assert_eq!(io.to_string(), "Failed to read policy file: gone");
        assert!(io.source().is_some());

        let parse = PolicyConfig::from_toml("rules = 5").unwrap_err();
        assert!(matches!(parse, GateError::TomlError(_)));
        assert!(parse.source().is_some());

        let missing = GateError::MissingField {
            name: "r".into(),
            field: "value".into(),
        };
        assert_eq!(missing.to_string(), "Rule 'r' missing required field: value");
        assert!(missing.source().is_none());
    }

    #[test]
    fn test_gate_result_counts_only_failed_rules() {
        // Kills `&&` -> `||` mutant in warning counting by including a passed WARN.
        let results = vec![
            RuleResult {
                name: "passed_warn".into(),
                passed: true,
                level: RuleLevel::Warn,
                actual: None,
                expected: "x".into(),
                message: None,
            },
            RuleResult {
                name: "failed_warn".into(),
                passed: false,
                level: RuleLevel::Warn,
                actual: None,
                expected: "x".into(),
                message: Some("warn".into()),
            },
        ];

        let gate = GateResult::from_results(results);
        assert!(gate.passed); // warns only
        assert_eq!(gate.errors, 0);
        assert_eq!(gate.warnings, 1);
    }

    #[test]
    fn test_gate_result_fails_on_error_level_failure() {
        let results = vec![RuleResult {
            name: "failed_error".into(),
            passed: false,
            level: RuleLevel::Error,
            actual: Some(serde_json::Value::from(600_000)),
            expected: "<= 500000".into(),
            message: Some("Too many tokens".into()),
        }];

        let gate = GateResult::from_results(results);
        assert!(!gate.passed);
        assert_eq!(gate.errors, 1);
        assert_eq!(gate.warnings, 0);
    }

    #[test]
    fn test_gate_result_empty_passes() {
        let gate = GateResult::from_results(Vec::new());
        assert!(gate.passed);
        assert_eq!(gate.errors, 0);
        assert_eq!(gate.warnings, 0);
        assert!(gate.rule_results.is_empty());
    }

    #[test]
    fn test_parse_ratchet_config() {
        let toml = r#"
fail_fast = true
allow_missing_baseline = true

[[rules]]
pointer = "/complexity/avg_cyclomatic"
max_increase_pct = 10.0
level = "warn"

[[rules]]
pointer = "/complexity/max_cyclomatic"
max_value = 25.0
description = "Hard ceiling"
"#;
        let config = RatchetConfig::from_toml(toml).unwrap();
        assert!(config.fail_fast);
        assert!(config.allow_missing_baseline);
        assert!(!config.allow_missing_current);
        assert_eq!(config.rules.len(), 2);

        assert_eq!(config.rules[0].pointer, "/complexity/avg_cyclomatic");
        assert_eq!(config.rules[0].max_increase_pct, Some(10.0));
        assert_eq!(config.rules[0].max_value, None);
        assert_eq!(config.rules[0].level, RuleLevel::Warn);

        assert_eq!(config.rules[1].max_increase_pct, None);
        assert_eq!(config.rules[1].max_value, Some(25.0));
        assert_eq!(config.rules[1].level, RuleLevel::Error);
        assert_eq!(config.rules[1].description.as_deref(), Some("Hard ceiling"));
    }

    #[test]
    fn test_ratchet_gate_result_counts() {
        let gate = RatchetGateResult::from_results(vec![
            ratchet_result(true, RuleLevel::Error),
            ratchet_result(false, RuleLevel::Warn),
            ratchet_result(true, RuleLevel::Warn),
        ]);
        assert!(gate.passed);
        assert_eq!(gate.errors, 0);
        assert_eq!(gate.warnings, 1);
        assert_eq!(gate.ratchet_results.len(), 3);

        let gate = RatchetGateResult::from_results(vec![ratchet_result(false, RuleLevel::Error)]);
        assert!(!gate.passed);
        assert_eq!(gate.errors, 1);
        assert_eq!(gate.warnings, 0);
    }
}