Skip to main content

synth_claw/validation/
mod.rs

1mod dedup;
2mod validators;
3
4pub use dedup::*;
5pub use validators::*;
6
7use crate::config::ValidationConfig;
8use crate::generation::GenerationResult;
9
/// Outcome of running one or more validators over a generation result.
///
/// `is_valid` and `errors` are public and can be set independently; the
/// constructors below keep them consistent (valid ⇔ no errors).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` when no validator rejected the result.
    pub is_valid: bool,
    /// Reasons for rejection, in the order they were merged in.
    pub errors: Vec<String>,
}

/// The default result is a passing one, identical to [`ValidationResult::valid`].
///
/// A derived `Default` would produce `is_valid: false` with no errors: an
/// invalid result that gives no reason. Used as the seed of a merge, it would
/// reject everything.
impl Default for ValidationResult {
    fn default() -> Self {
        Self::valid()
    }
}

impl ValidationResult {
    /// A passing result with no errors.
    #[must_use]
    pub fn valid() -> Self {
        Self {
            is_valid: true,
            errors: vec![],
        }
    }

    /// A failing result carrying a single error message.
    #[must_use]
    pub fn invalid(error: impl Into<String>) -> Self {
        Self {
            is_valid: false,
            errors: vec![error.into()],
        }
    }

    /// Folds `other` into `self`.
    ///
    /// If `other` is invalid, `self` becomes invalid and `other`'s errors are
    /// appended after any existing ones. A valid `other` leaves `self`
    /// untouched; any errors it happens to carry are not copied.
    pub fn merge(&mut self, other: Self) {
        if !other.is_valid {
            self.is_valid = false;
            self.errors.extend(other.errors);
        }
    }
}
38
39pub trait Validator: Send + Sync {
40    fn validate(&self, result: &GenerationResult) -> ValidationResult;
41}
42
43#[derive(Default)]
44pub struct ValidationPipeline {
45    validators: Vec<Box<dyn Validator>>,
46}
47
48impl ValidationPipeline {
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    pub fn from_config(config: &ValidationConfig) -> Self {
54        let mut p = Self::new();
55        if let Some(n) = config.min_length {
56            p = p.add(MinLength(n));
57        }
58        if let Some(n) = config.max_length {
59            p = p.add(MaxLength(n));
60        }
61        if config.json {
62            p = p.add(Json);
63        }
64        if let Some(fields) = &config.json_schema {
65            let fields: Vec<&str> = fields.iter().map(|s| s.as_str()).collect();
66            p = p.add(JsonSchema::require(&fields));
67        }
68        if config.blocklist {
69            p = p.add(Blocklist::llm_artifacts());
70        }
71        if config.repetition {
72            p = p.add(Repetition::default());
73        }
74        p
75    }
76
77    pub fn add<V: Validator + 'static>(mut self, v: V) -> Self {
78        self.validators.push(Box::new(v));
79        self
80    }
81
82    pub fn validate(&self, result: &GenerationResult) -> ValidationResult {
83        self.validators
84            .iter()
85            .fold(ValidationResult::valid(), |mut acc, v| {
86                acc.merge(v.validate(result));
87                acc
88            })
89    }
90
91    pub fn filter(
92        &self,
93        results: Vec<GenerationResult>,
94    ) -> (
95        Vec<GenerationResult>,
96        Vec<(GenerationResult, ValidationResult)>,
97    ) {
98        let (mut valid, mut invalid) = (vec![], vec![]);
99        for r in results {
100            let v = self.validate(&r);
101            if v.is_valid {
102                valid.push(r);
103            } else {
104                invalid.push((r, v));
105            }
106        }
107        (valid, invalid)
108    }
109}
110
/// Counters summarising one validation (and optional deduplication) run.
///
/// As populated by `validate_and_dedupe`,
/// `total == passed + failed + duplicates_removed`.
#[derive(Clone, Debug, Default)]
pub struct ValidationStats {
    /// Number of results handed in.
    pub total: usize,
    /// Number of results kept after validation and deduplication.
    pub passed: usize,
    /// Number of results rejected by the validation pipeline.
    pub failed: usize,
    /// Number of validated results dropped as duplicates.
    pub duplicates_removed: usize,
}
118
119pub struct ValidatedResults {
120    pub results: Vec<GenerationResult>,
121    pub stats: ValidationStats,
122    pub rejected: Vec<(GenerationResult, ValidationResult)>,
123}
124
125pub fn validate_and_dedupe(
126    results: Vec<GenerationResult>,
127    pipeline: &ValidationPipeline,
128    dedup: Option<&Deduplicator>,
129) -> ValidatedResults {
130    let total = results.len();
131    let (valid, rejected) = pipeline.filter(results);
132    let failed = rejected.len();
133
134    let (results, duplicates_removed) = match dedup {
135        Some(d) => {
136            let before = valid.len();
137            let deduped = d.dedupe(valid);
138            let removed = before - deduped.len();
139            (deduped, removed)
140        }
141        None => (valid, 0),
142    };
143
144    ValidatedResults {
145        stats: ValidationStats {
146            total,
147            passed: results.len(),
148            failed,
149            duplicates_removed,
150        },
151        results,
152        rejected,
153    }
154}