1use serde::{Deserialize, Serialize};
4use std::path::Path;
5
6#[derive(Debug)]
8pub enum GateError {
9 IoError(std::io::Error),
10 TomlError(toml::de::Error),
11 InvalidPointer(String),
12 TypeMismatch { expected: String, actual: String },
13 InvalidOperator { op: String, value_type: String },
14 MissingField { name: String, field: String },
15}
16
17impl std::fmt::Display for GateError {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 Self::IoError(e) => write!(f, "Failed to read policy file: {e}"),
21 Self::TomlError(e) => write!(f, "Failed to parse policy TOML: {e}"),
22 Self::InvalidPointer(p) => write!(f, "Invalid JSON pointer: {p}"),
23 Self::TypeMismatch { expected, actual } => {
24 write!(f, "Type mismatch: expected {expected}, got {actual}")
25 }
26 Self::InvalidOperator { op, value_type } => {
27 write!(f, "Invalid operator '{op}' for type '{value_type}'")
28 }
29 Self::MissingField { name, field } => {
30 write!(f, "Rule '{name}' missing required field: {field}")
31 }
32 }
33 }
34}
35
36impl std::error::Error for GateError {
37 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
38 match self {
39 Self::IoError(e) => Some(e),
40 Self::TomlError(e) => Some(e),
41 _ => None,
42 }
43 }
44}
45
46impl From<std::io::Error> for GateError {
47 fn from(err: std::io::Error) -> Self {
48 Self::IoError(err)
49 }
50}
51
52impl From<toml::de::Error> for GateError {
53 fn from(err: toml::de::Error) -> Self {
54 Self::TomlError(err)
55 }
56}
57
58#[derive(Debug, Clone, Default, Serialize, Deserialize)]
60#[serde(default)]
61pub struct PolicyConfig {
62 pub rules: Vec<PolicyRule>,
64
65 #[serde(default)]
67 pub fail_fast: bool,
68
69 #[serde(default)]
71 pub allow_missing: bool,
72}
73
74impl PolicyConfig {
75 pub fn from_toml(s: &str) -> Result<Self, GateError> {
98 Ok(toml::from_str(s)?)
99 }
100
101 pub fn from_file(path: &Path) -> Result<Self, GateError> {
103 let content = std::fs::read_to_string(path)?;
104 Self::from_toml(&content)
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct PolicyRule {
111 pub name: String,
113
114 pub pointer: String,
116
117 pub op: RuleOperator,
119
120 #[serde(default)]
122 pub value: Option<serde_json::Value>,
123
124 #[serde(default)]
126 pub values: Option<Vec<serde_json::Value>>,
127
128 #[serde(default)]
130 pub negate: bool,
131
132 #[serde(default)]
134 pub level: RuleLevel,
135
136 #[serde(default)]
138 pub message: Option<String>,
139}
140
141#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
143#[serde(rename_all = "snake_case")]
144pub enum RuleOperator {
145 Gt,
147 Gte,
149 Lt,
151 Lte,
153 #[default]
155 Eq,
156 Ne,
158 In,
160 Contains,
162 Exists,
164}
165
166impl std::fmt::Display for RuleOperator {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 match self {
169 RuleOperator::Gt => write!(f, ">"),
170 RuleOperator::Gte => write!(f, ">="),
171 RuleOperator::Lt => write!(f, "<"),
172 RuleOperator::Lte => write!(f, "<="),
173 RuleOperator::Eq => write!(f, "=="),
174 RuleOperator::Ne => write!(f, "!="),
175 RuleOperator::In => write!(f, "in"),
176 RuleOperator::Contains => write!(f, "contains"),
177 RuleOperator::Exists => write!(f, "exists"),
178 }
179 }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
184#[serde(rename_all = "lowercase")]
185pub enum RuleLevel {
186 Warn,
188 #[default]
190 Error,
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct GateResult {
196 pub passed: bool,
198
199 pub rule_results: Vec<RuleResult>,
201
202 pub errors: usize,
204
205 pub warnings: usize,
207}
208
209impl GateResult {
210 pub fn from_results(rule_results: Vec<RuleResult>) -> Self {
212 let errors = rule_results
213 .iter()
214 .filter(|r| !r.passed && r.level == RuleLevel::Error)
215 .count();
216 let warnings = rule_results
217 .iter()
218 .filter(|r| !r.passed && r.level == RuleLevel::Warn)
219 .count();
220 let passed = errors == 0;
221
222 Self {
223 passed,
224 rule_results,
225 errors,
226 warnings,
227 }
228 }
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct RuleResult {
234 pub name: String,
236
237 pub passed: bool,
239
240 pub level: RuleLevel,
242
243 pub actual: Option<serde_json::Value>,
245
246 pub expected: String,
248
249 pub message: Option<String>,
251}
252
253#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct RatchetRule {
260 pub pointer: String,
262
263 #[serde(default)]
266 pub max_increase_pct: Option<f64>,
267
268 #[serde(default)]
271 pub max_value: Option<f64>,
272
273 #[serde(default)]
275 pub level: RuleLevel,
276
277 #[serde(default)]
279 pub description: Option<String>,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct RatchetResult {
285 pub rule: RatchetRule,
287
288 pub passed: bool,
290
291 pub baseline_value: Option<f64>,
293
294 pub current_value: f64,
296
297 pub change_pct: Option<f64>,
299
300 pub message: String,
302}
303
304#[derive(Debug, Clone, Default, Serialize, Deserialize)]
306#[serde(default)]
307pub struct RatchetConfig {
308 pub rules: Vec<RatchetRule>,
310
311 #[serde(default)]
313 pub fail_fast: bool,
314
315 #[serde(default)]
317 pub allow_missing_baseline: bool,
318
319 #[serde(default)]
321 pub allow_missing_current: bool,
322}
323
324impl RatchetConfig {
325 pub fn from_toml(s: &str) -> Result<Self, GateError> {
327 Ok(toml::from_str(s)?)
328 }
329
330 pub fn from_file(path: &Path) -> Result<Self, GateError> {
332 let content = std::fs::read_to_string(path)?;
333 Self::from_toml(&content)
334 }
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct RatchetGateResult {
340 pub passed: bool,
342
343 pub ratchet_results: Vec<RatchetResult>,
345
346 pub errors: usize,
348
349 pub warnings: usize,
351}
352
353impl RatchetGateResult {
354 pub fn from_results(ratchet_results: Vec<RatchetResult>) -> Self {
356 let errors = ratchet_results
357 .iter()
358 .filter(|r| !r.passed && r.rule.level == RuleLevel::Error)
359 .count();
360 let warnings = ratchet_results
361 .iter()
362 .filter(|r| !r.passed && r.rule.level == RuleLevel::Warn)
363 .count();
364 let passed = errors == 0;
365
366 Self {
367 passed,
368 ratchet_results,
369 errors,
370 warnings,
371 }
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unique temp-file path; the PID keeps concurrent test processes
    /// from colliding and the nanosecond timestamp keeps repeated runs apart.
    fn unique_temp_path(prefix: &str) -> std::path::PathBuf {
        use std::time::{SystemTime, UNIX_EPOCH};

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}-{}-{nanos}.toml", std::process::id()))
    }

    #[test]
    fn test_parse_policy() {
        let toml = r#"
fail_fast = true
allow_missing = false

[[rules]]
name = "max_tokens"
pointer = "/derived/totals/tokens"
op = "lte"
value = 500000
level = "error"
message = "Too many tokens"

[[rules]]
name = "has_license"
pointer = "/license/effective"
op = "exists"
level = "warn"
"#;
        let policy = PolicyConfig::from_toml(toml).unwrap();
        assert!(policy.fail_fast);
        assert!(!policy.allow_missing);
        assert_eq!(policy.rules.len(), 2);
        assert_eq!(policy.rules[0].name, "max_tokens");
        assert_eq!(policy.rules[0].op, RuleOperator::Lte);
        assert_eq!(policy.rules[0].level, RuleLevel::Error);
        assert_eq!(policy.rules[0].message.as_deref(), Some("Too many tokens"));
        assert_eq!(policy.rules[1].op, RuleOperator::Exists);
        assert_eq!(policy.rules[1].level, RuleLevel::Warn);
    }

    #[test]
    fn test_empty_policy_uses_defaults() {
        let policy = PolicyConfig::from_toml("").unwrap();
        assert!(policy.rules.is_empty());
        assert!(!policy.fail_fast);
        assert!(!policy.allow_missing);
    }

    #[test]
    fn test_rule_optional_fields_default() {
        let toml = r#"
[[rules]]
name = "r"
pointer = "/a"
op = "gt"
"#;
        let policy = PolicyConfig::from_toml(toml).unwrap();
        let rule = &policy.rules[0];
        assert!(rule.value.is_none());
        assert!(rule.values.is_none());
        assert!(!rule.negate);
        assert_eq!(rule.level, RuleLevel::Error);
        assert!(rule.message.is_none());
    }

    #[test]
    fn test_invalid_policy_is_toml_error() {
        let err = PolicyConfig::from_toml("rules = 5").unwrap_err();
        assert!(matches!(err, GateError::TomlError(_)));
    }

    #[test]
    fn test_gate_result() {
        let results = vec![
            RuleResult {
                name: "rule1".into(),
                passed: true,
                level: RuleLevel::Error,
                actual: None,
                expected: "test".into(),
                message: None,
            },
            RuleResult {
                name: "rule2".into(),
                passed: false,
                level: RuleLevel::Warn,
                actual: None,
                expected: "test".into(),
                message: Some("Warning".into()),
            },
        ];

        let gate = GateResult::from_results(results);
        assert!(gate.passed);
        assert_eq!(gate.errors, 0);
        assert_eq!(gate.warnings, 1);
    }

    #[test]
    fn test_policy_from_file() {
        let toml = r#"
fail_fast = true
allow_missing = false

[[rules]]
name = "max_tokens"
pointer = "/derived/totals/tokens"
op = "lte"
value = 500000
level = "error"
"#;

        let path = unique_temp_path("tokmd-gate-policy");
        std::fs::write(&path, toml).unwrap();

        // Remove the file before unwrapping so a load failure does not leave
        // the temp file behind.
        let loaded = PolicyConfig::from_file(&path);
        let _ = std::fs::remove_file(&path);
        let policy = loaded.unwrap();

        assert!(policy.fail_fast);
        assert_eq!(policy.rules.len(), 1);
        assert_eq!(policy.rules[0].name, "max_tokens");
        assert_eq!(policy.rules[0].op, RuleOperator::Lte);
    }

    #[test]
    fn test_policy_from_missing_file_is_io_error() {
        let path = unique_temp_path("tokmd-gate-missing");
        let _ = std::fs::remove_file(&path);
        let err = PolicyConfig::from_file(&path).unwrap_err();
        assert!(matches!(err, GateError::IoError(_)));
    }

    #[test]
    fn test_rule_operator_display() {
        assert_eq!(RuleOperator::Gt.to_string(), ">");
        assert_eq!(RuleOperator::Gte.to_string(), ">=");
        assert_eq!(RuleOperator::Lt.to_string(), "<");
        assert_eq!(RuleOperator::Lte.to_string(), "<=");
        assert_eq!(RuleOperator::Eq.to_string(), "==");
        assert_eq!(RuleOperator::Ne.to_string(), "!=");
        assert_eq!(RuleOperator::In.to_string(), "in");
        assert_eq!(RuleOperator::Contains.to_string(), "contains");
        assert_eq!(RuleOperator::Exists.to_string(), "exists");
    }

    #[test]
    fn test_gate_result_counts_only_failed_rules() {
        let results = vec![
            RuleResult {
                name: "passed_warn".into(),
                passed: true,
                level: RuleLevel::Warn,
                actual: None,
                expected: "x".into(),
                message: None,
            },
            RuleResult {
                name: "failed_warn".into(),
                passed: false,
                level: RuleLevel::Warn,
                actual: None,
                expected: "x".into(),
                message: Some("warn".into()),
            },
        ];

        let gate = GateResult::from_results(results);
        assert!(gate.passed);
        assert_eq!(gate.errors, 0);
        assert_eq!(gate.warnings, 1);
    }

    #[test]
    fn test_parse_ratchet() {
        let toml = r#"
allow_missing_baseline = true

[[rules]]
pointer = "/complexity/avg_cyclomatic"
max_increase_pct = 10.0
level = "warn"
description = "Complexity must not grow quickly"
"#;
        let config = RatchetConfig::from_toml(toml).unwrap();
        assert!(!config.fail_fast);
        assert!(config.allow_missing_baseline);
        assert!(!config.allow_missing_current);
        assert_eq!(config.rules.len(), 1);

        let rule = &config.rules[0];
        assert_eq!(rule.pointer, "/complexity/avg_cyclomatic");
        assert_eq!(rule.max_increase_pct, Some(10.0));
        assert_eq!(rule.max_value, None);
        assert_eq!(rule.level, RuleLevel::Warn);
    }

    #[test]
    fn test_ratchet_gate_result() {
        let make = |passed: bool, level: RuleLevel| RatchetResult {
            rule: RatchetRule {
                pointer: "/x".into(),
                max_increase_pct: Some(5.0),
                max_value: None,
                level,
                description: None,
            },
            passed,
            baseline_value: Some(1.0),
            current_value: 2.0,
            change_pct: Some(100.0),
            message: String::new(),
        };

        let gate = RatchetGateResult::from_results(vec![
            make(false, RuleLevel::Error),
            make(false, RuleLevel::Warn),
            make(true, RuleLevel::Error),
        ]);
        assert!(!gate.passed);
        assert_eq!(gate.errors, 1);
        assert_eq!(gate.warnings, 1);
        assert_eq!(gate.ratchet_results.len(), 3);
    }
}