Skip to main content

synth_claw/validation/
validators.rs

1use super::{ValidationResult, Validator};
2use crate::generation::GenerationResult;
3use regex::Regex;
4use std::collections::HashMap;
5
6pub struct MinLength(pub usize);
7pub struct MaxLength(pub usize);
8
9impl Validator for MinLength {
10    fn validate(&self, r: &GenerationResult) -> ValidationResult {
11        if r.content.trim().len() >= self.0 {
12            ValidationResult::valid()
13        } else {
14            ValidationResult::invalid(format!(
15                "too short: {} < {}",
16                r.content.trim().len(),
17                self.0
18            ))
19        }
20    }
21}
22
23impl Validator for MaxLength {
24    fn validate(&self, r: &GenerationResult) -> ValidationResult {
25        if r.content.len() <= self.0 {
26            ValidationResult::valid()
27        } else {
28            ValidationResult::invalid(format!("too long: {} > {}", r.content.len(), self.0))
29        }
30    }
31}
32
/// Accepts content that parses as any JSON value (after stripping a Markdown
/// code fence, if one is present).
pub struct Json;

/// Accepts content that parses as a JSON object containing every required
/// top-level key.
pub struct JsonSchema {
    required: Vec<String>,
}

impl JsonSchema {
    /// Builds a schema requiring each name in `fields` to be present as a
    /// top-level key. Only presence is checked, not the values' types.
    pub fn require(fields: &[&str]) -> Self {
        let required = fields.iter().copied().map(String::from).collect();
        JsonSchema { required }
    }
}
45
46impl Validator for Json {
47    fn validate(&self, r: &GenerationResult) -> ValidationResult {
48        match serde_json::from_str::<serde_json::Value>(&extract_json(&r.content)) {
49            Ok(_) => ValidationResult::valid(),
50            Err(e) => ValidationResult::invalid(format!("invalid json: {}", e)),
51        }
52    }
53}
54
55impl Validator for JsonSchema {
56    fn validate(&self, r: &GenerationResult) -> ValidationResult {
57        let v: serde_json::Value = match serde_json::from_str(&extract_json(&r.content)) {
58            Ok(v) => v,
59            Err(e) => return ValidationResult::invalid(format!("invalid json: {}", e)),
60        };
61
62        let obj = match v.as_object() {
63            Some(o) => o,
64            None => return ValidationResult::invalid("expected json object"),
65        };
66
67        let missing: Vec<_> = self
68            .required
69            .iter()
70            .filter(|f| !obj.contains_key(*f))
71            .collect();
72        if missing.is_empty() {
73            ValidationResult::valid()
74        } else {
75            ValidationResult::invalid(format!("missing fields: {:?}", missing))
76        }
77    }
78}
79
80pub struct Blocklist(Vec<(Regex, &'static str)>);
81
82impl Blocklist {
83    pub fn llm_artifacts() -> Self {
84        let patterns = [
85            (r"(?i)^(sure|certainly|of course)[,!]?\s", "filler"),
86            (r"(?i)^here('s| is)", "here is"),
87            (r"(?i)^I('d| would) be happy to", "politeness"),
88            (r"(?i)as an AI", "ai mention"),
89            (r"(?i)I cannot|I can't|I'm unable", "refusal"),
90        ];
91        Self(
92            patterns
93                .iter()
94                .filter_map(|(p, r)| Some((Regex::new(p).ok()?, *r)))
95                .collect(),
96        )
97    }
98}
99
100impl Validator for Blocklist {
101    fn validate(&self, r: &GenerationResult) -> ValidationResult {
102        for (re, reason) in &self.0 {
103            if re.is_match(&r.content) {
104                return ValidationResult::invalid(format!("blocked: {}", reason));
105            }
106        }
107        ValidationResult::valid()
108    }
109}
110
/// Rejects text dominated by repeated word n-grams.
///
/// `ngram_size` is the number of consecutive words per n-gram, and
/// `max_ratio` is the largest allowed fraction of n-gram windows that repeat
/// an earlier window.
pub struct Repetition {
    pub max_ratio: f32,
    pub ngram_size: usize,
}

impl Default for Repetition {
    /// Word trigrams, allowing up to half of the windows to be repeats.
    fn default() -> Self {
        Repetition {
            ngram_size: 3,
            max_ratio: 0.5,
        }
    }
}
124
125impl Validator for Repetition {
126    fn validate(&self, r: &GenerationResult) -> ValidationResult {
127        let words: Vec<_> = r.content.split_whitespace().collect();
128        if words.len() < self.ngram_size * 2 {
129            return ValidationResult::valid();
130        }
131
132        let mut counts: HashMap<String, usize> = HashMap::new();
133        for w in words.windows(self.ngram_size) {
134            *counts.entry(w.join(" ").to_lowercase()).or_default() += 1;
135        }
136
137        let total = words.len() - self.ngram_size + 1;
138        let repeated: usize = counts.values().filter(|&&c| c > 1).map(|c| c - 1).sum();
139        let ratio = repeated as f32 / total as f32;
140
141        if ratio <= self.max_ratio {
142            ValidationResult::valid()
143        } else {
144            ValidationResult::invalid(format!(
145                "repetitive: {:.0}% > {:.0}%",
146                ratio * 100.0,
147                self.max_ratio * 100.0
148            ))
149        }
150    }
151}
152
153pub struct Custom<F>(pub F);
154
155impl<F: Fn(&GenerationResult) -> ValidationResult + Send + Sync> Validator for Custom<F> {
156    fn validate(&self, r: &GenerationResult) -> ValidationResult {
157        self.0(r)
158    }
159}
160
/// Pulls the JSON payload out of an LLM response.
///
/// Resolution order:
/// 1. the body of the first fence opened with ```` ```json ````;
/// 2. otherwise the body of the first generic ```` ``` ```` fence, with any
///    info string on the opening-fence line (e.g. a language tag) removed;
/// 3. otherwise the whole input.
///
/// An unterminated fence is skipped and the next rule is tried. The result is
/// always trimmed of surrounding whitespace.
fn extract_json(content: &str) -> String {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let s = content.trim();

    if let Some(start) = s.find(JSON_FENCE) {
        let after = &s[start + JSON_FENCE.len()..];
        if let Some(end) = after.find(FENCE) {
            return after[..end].trim().to_string();
        }
    }

    if let Some(start) = s.find(FENCE) {
        let after = &s[start + FENCE.len()..];
        if let Some(end) = after.find(FENCE) {
            let inner = &after[..end];
            // Only the remainder of the opening-fence line is an info string.
            // The previous implementation trimmed first and then dropped the
            // first line. That discarded the payload itself for a tag-less
            // fence like "```\n{...}\n```" and for one-line "```{...}```".
            let body = match inner.split_once('\n') {
                Some((_info, body)) => body,
                None => inner,
            };
            return body.trim().to_string();
        }
    }

    s.to_string()
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `text` in a `GenerationResult` with no metadata and zero token counts.
    fn sample(text: &str) -> GenerationResult {
        GenerationResult {
            content: text.to_owned(),
            source_index: None,
            category: None,
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Runs `validator` against `text` and reports whether it passed.
    fn passes(validator: &impl Validator, text: &str) -> bool {
        validator.validate(&sample(text)).is_valid
    }

    #[test]
    fn test_length() {
        assert!(!passes(&MinLength(10), "short"));
        assert!(passes(&MinLength(5), "hello"));
        assert!(passes(&MaxLength(10), "short"));
        assert!(!passes(&MaxLength(5), "too long"));
    }

    #[test]
    fn test_json() {
        assert!(passes(&Json, r#"{"a":1}"#));
        assert!(!passes(&Json, "not json"));
        assert!(passes(&Json, "```json\n{\"a\":1}\n```"));
    }

    #[test]
    fn test_schema() {
        let schema = JsonSchema::require(&["a", "b"]);
        assert!(passes(&schema, r#"{"a":1,"b":2}"#));
        assert!(!passes(&schema, r#"{"a":1}"#));
    }

    #[test]
    fn test_blocklist() {
        let blocklist = Blocklist::llm_artifacts();
        assert!(!passes(&blocklist, "Sure! Here you go"));
        assert!(!passes(&blocklist, "As an AI, I"));
        assert!(passes(&blocklist, "Normal text"));
    }

    #[test]
    fn test_repetition() {
        let repetition = Repetition {
            max_ratio: 0.3,
            ngram_size: 2,
        };
        assert!(!passes(&repetition, "the cat the cat the cat the cat"));
        assert!(passes(&repetition, "the quick brown fox jumps"));
    }
}