synth_claw/validation/
mod.rs1mod dedup;
2mod validators;
3
4pub use dedup::*;
5pub use validators::*;
6
7use crate::config::ValidationConfig;
8use crate::generation::GenerationResult;
9
/// Outcome of running one or more validators against a generation result.
///
/// `is_valid` becomes `false` as soon as any merged check fails. `errors`
/// collects the messages from every failing check, in the order they were merged.
///
/// The [`Default`] value equals [`ValidationResult::valid`], which makes it the
/// identity element for [`ValidationResult::merge`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` when no merged check reported a failure.
    pub is_valid: bool,
    /// Human-readable messages from every failing check.
    pub errors: Vec<String>,
}

impl Default for ValidationResult {
    /// Returns a passing result with no errors.
    ///
    /// A derived `Default` would produce `is_valid: false` with no error
    /// message. That "invalid for no reason" state poisons any fold seeded
    /// with `ValidationResult::default()`.
    fn default() -> Self {
        Self::valid()
    }
}

impl ValidationResult {
    /// A passing result with no errors.
    #[must_use]
    pub fn valid() -> Self {
        Self {
            is_valid: true,
            errors: vec![],
        }
    }

    /// A failing result carrying a single error message.
    #[must_use]
    pub fn invalid(error: impl Into<String>) -> Self {
        Self {
            is_valid: false,
            errors: vec![error.into()],
        }
    }

    /// Folds `other` into `self`.
    ///
    /// If `other` failed, `self` becomes invalid and `other`'s errors are
    /// appended after any existing ones. A passing `other` leaves `self`
    /// unchanged.
    pub fn merge(&mut self, other: Self) {
        if !other.is_valid {
            self.is_valid = false;
            self.errors.extend(other.errors);
        }
    }
}
38
39pub trait Validator: Send + Sync {
40 fn validate(&self, result: &GenerationResult) -> ValidationResult;
41}
42
43#[derive(Default)]
44pub struct ValidationPipeline {
45 validators: Vec<Box<dyn Validator>>,
46}
47
48impl ValidationPipeline {
49 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn from_config(config: &ValidationConfig) -> Self {
54 let mut p = Self::new();
55 if let Some(n) = config.min_length {
56 p = p.add(MinLength(n));
57 }
58 if let Some(n) = config.max_length {
59 p = p.add(MaxLength(n));
60 }
61 if config.json {
62 p = p.add(Json);
63 }
64 if let Some(fields) = &config.json_schema {
65 let fields: Vec<&str> = fields.iter().map(|s| s.as_str()).collect();
66 p = p.add(JsonSchema::require(&fields));
67 }
68 if config.blocklist {
69 p = p.add(Blocklist::llm_artifacts());
70 }
71 if config.repetition {
72 p = p.add(Repetition::default());
73 }
74 p
75 }
76
77 pub fn add<V: Validator + 'static>(mut self, v: V) -> Self {
78 self.validators.push(Box::new(v));
79 self
80 }
81
82 pub fn validate(&self, result: &GenerationResult) -> ValidationResult {
83 self.validators
84 .iter()
85 .fold(ValidationResult::valid(), |mut acc, v| {
86 acc.merge(v.validate(result));
87 acc
88 })
89 }
90
91 pub fn filter(
92 &self,
93 results: Vec<GenerationResult>,
94 ) -> (
95 Vec<GenerationResult>,
96 Vec<(GenerationResult, ValidationResult)>,
97 ) {
98 let (mut valid, mut invalid) = (vec![], vec![]);
99 for r in results {
100 let v = self.validate(&r);
101 if v.is_valid {
102 valid.push(r);
103 } else {
104 invalid.push((r, v));
105 }
106 }
107 (valid, invalid)
108 }
109}
110
/// Counters summarising one validation-and-deduplication run.
///
/// [`validate_and_dedupe`] fills these in.
#[derive(Clone, Debug, Default)]
pub struct ValidationStats {
    /// Number of results fed into the run.
    pub total: usize,
    /// Number of results kept after validation and (optional) deduplication.
    pub passed: usize,
    /// Number of results rejected by the validation pipeline.
    pub failed: usize,
    /// Number of validated results dropped as duplicates.
    pub duplicates_removed: usize,
}
118
119pub struct ValidatedResults {
120 pub results: Vec<GenerationResult>,
121 pub stats: ValidationStats,
122 pub rejected: Vec<(GenerationResult, ValidationResult)>,
123}
124
125pub fn validate_and_dedupe(
126 results: Vec<GenerationResult>,
127 pipeline: &ValidationPipeline,
128 dedup: Option<&Deduplicator>,
129) -> ValidatedResults {
130 let total = results.len();
131 let (valid, rejected) = pipeline.filter(results);
132 let failed = rejected.len();
133
134 let (results, duplicates_removed) = match dedup {
135 Some(d) => {
136 let before = valid.len();
137 let deduped = d.dedupe(valid);
138 let removed = before - deduped.len();
139 (deduped, removed)
140 }
141 None => (valid, 0),
142 };
143
144 ValidatedResults {
145 stats: ValidationStats {
146 total,
147 passed: results.len(),
148 failed,
149 duplicates_removed,
150 },
151 results,
152 rejected,
153 }
154}