synth_claw/validation/
validators.rs1use super::{ValidationResult, Validator};
2use crate::generation::GenerationResult;
3use regex::Regex;
4use std::collections::HashMap;
5
/// Validator configured with a lower bound on content length.
///
/// The wrapped `usize` is the smallest acceptable length.
pub struct MinLength(pub usize);

/// Validator configured with an upper bound on content length.
///
/// The wrapped `usize` is the largest acceptable length.
pub struct MaxLength(pub usize);
8
9impl Validator for MinLength {
10 fn validate(&self, r: &GenerationResult) -> ValidationResult {
11 if r.content.trim().len() >= self.0 {
12 ValidationResult::valid()
13 } else {
14 ValidationResult::invalid(format!(
15 "too short: {} < {}",
16 r.content.trim().len(),
17 self.0
18 ))
19 }
20 }
21}
22
23impl Validator for MaxLength {
24 fn validate(&self, r: &GenerationResult) -> ValidationResult {
25 if r.content.len() <= self.0 {
26 ValidationResult::valid()
27 } else {
28 ValidationResult::invalid(format!("too long: {} > {}", r.content.len(), self.0))
29 }
30 }
31}
32
/// Marker validator for JSON well-formedness.
pub struct Json;

/// Validator holding the names of keys that a JSON object must contain.
pub struct JsonSchema {
    /// Key names that must be present.
    required: Vec<String>,
}

impl JsonSchema {
    /// Creates a schema that requires every name in `fields`.
    ///
    /// The names are copied into owned strings in the order given; an empty
    /// slice yields a schema with no required keys.
    pub fn require(fields: &[&str]) -> Self {
        let mut required = Vec::with_capacity(fields.len());
        for &name in fields {
            required.push(String::from(name));
        }
        JsonSchema { required }
    }
}
45
46impl Validator for Json {
47 fn validate(&self, r: &GenerationResult) -> ValidationResult {
48 match serde_json::from_str::<serde_json::Value>(&extract_json(&r.content)) {
49 Ok(_) => ValidationResult::valid(),
50 Err(e) => ValidationResult::invalid(format!("invalid json: {}", e)),
51 }
52 }
53}
54
55impl Validator for JsonSchema {
56 fn validate(&self, r: &GenerationResult) -> ValidationResult {
57 let v: serde_json::Value = match serde_json::from_str(&extract_json(&r.content)) {
58 Ok(v) => v,
59 Err(e) => return ValidationResult::invalid(format!("invalid json: {}", e)),
60 };
61
62 let obj = match v.as_object() {
63 Some(o) => o,
64 None => return ValidationResult::invalid("expected json object"),
65 };
66
67 let missing: Vec<_> = self
68 .required
69 .iter()
70 .filter(|f| !obj.contains_key(*f))
71 .collect();
72 if missing.is_empty() {
73 ValidationResult::valid()
74 } else {
75 ValidationResult::invalid(format!("missing fields: {:?}", missing))
76 }
77 }
78}
79
80pub struct Blocklist(Vec<(Regex, &'static str)>);
81
82impl Blocklist {
83 pub fn llm_artifacts() -> Self {
84 let patterns = [
85 (r"(?i)^(sure|certainly|of course)[,!]?\s", "filler"),
86 (r"(?i)^here('s| is)", "here is"),
87 (r"(?i)^I('d| would) be happy to", "politeness"),
88 (r"(?i)as an AI", "ai mention"),
89 (r"(?i)I cannot|I can't|I'm unable", "refusal"),
90 ];
91 Self(
92 patterns
93 .iter()
94 .filter_map(|(p, r)| Some((Regex::new(p).ok()?, *r)))
95 .collect(),
96 )
97 }
98}
99
100impl Validator for Blocklist {
101 fn validate(&self, r: &GenerationResult) -> ValidationResult {
102 for (re, reason) in &self.0 {
103 if re.is_match(&r.content) {
104 return ValidationResult::invalid(format!("blocked: {}", reason));
105 }
106 }
107 ValidationResult::valid()
108 }
109}
110
/// Configuration for the repeated n-gram validator.
pub struct Repetition {
    /// Highest repeated-n-gram ratio that is still accepted.
    pub max_ratio: f32,
    /// Number of consecutive words that make up one n-gram.
    pub ngram_size: usize,
}

impl Default for Repetition {
    /// Defaults to three-word n-grams with a maximum ratio of 0.5.
    fn default() -> Self {
        let max_ratio = 0.5;
        let ngram_size = 3;
        Repetition {
            max_ratio,
            ngram_size,
        }
    }
}
124
125impl Validator for Repetition {
126 fn validate(&self, r: &GenerationResult) -> ValidationResult {
127 let words: Vec<_> = r.content.split_whitespace().collect();
128 if words.len() < self.ngram_size * 2 {
129 return ValidationResult::valid();
130 }
131
132 let mut counts: HashMap<String, usize> = HashMap::new();
133 for w in words.windows(self.ngram_size) {
134 *counts.entry(w.join(" ").to_lowercase()).or_default() += 1;
135 }
136
137 let total = words.len() - self.ngram_size + 1;
138 let repeated: usize = counts.values().filter(|&&c| c > 1).map(|c| c - 1).sum();
139 let ratio = repeated as f32 / total as f32;
140
141 if ratio <= self.max_ratio {
142 ValidationResult::valid()
143 } else {
144 ValidationResult::invalid(format!(
145 "repetitive: {:.0}% > {:.0}%",
146 ratio * 100.0,
147 self.max_ratio * 100.0
148 ))
149 }
150 }
151}
152
153pub struct Custom<F>(pub F);
154
155impl<F: Fn(&GenerationResult) -> ValidationResult + Send + Sync> Validator for Custom<F> {
156 fn validate(&self, r: &GenerationResult) -> ValidationResult {
157 self.0(r)
158 }
159}
160
/// Returns the part of `content` most likely to be a JSON payload.
///
/// Tried in order:
/// 1. If a "```json" fence with a later closing "```" exists, the trimmed text
///    between them is returned.
/// 2. Otherwise, if any "```" fence with a later closing "```" exists, the rest
///    of the opening fence line (a language tag, possibly empty) is dropped and
///    the trimmed remainder is returned. When the fenced text has no line break
///    at all, the whole fenced text is returned trimmed.
/// 3. Otherwise the trimmed input is returned unchanged.
///
/// No JSON validation happens here; callers parse the returned string.
fn extract_json(content: &str) -> String {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let s = content.trim();

    if let Some(start) = s.find(JSON_FENCE) {
        let after = &s[start + JSON_FENCE.len()..];
        if let Some(end) = after.find(FENCE) {
            return after[..end].trim().to_string();
        }
    }

    if let Some(start) = s.find(FENCE) {
        let after = &s[start + FENCE.len()..];
        if let Some(end) = after.find(FENCE) {
            let inner = &after[..end];
            // Drop only the remainder of the opening fence line. Trimming first
            // and then skipping a line would discard the first payload line
            // whenever the fence has no language tag.
            let body = match inner.find('\n') {
                Some(newline) => &inner[newline + 1..],
                None => inner,
            };
            return body.trim().to_string();
        }
    }

    s.to_string()
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GenerationResult` carrying `text` as its content, with the
    /// remaining fields set to `None` or zero.
    fn sample(text: &str) -> GenerationResult {
        GenerationResult {
            content: text.to_string(),
            source_index: None,
            category: None,
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Length bounds are inclusive on both validators.
    #[test]
    fn test_length() {
        let short = sample("short");

        assert!(!MinLength(10).validate(&short).is_valid);
        assert!(MinLength(5).validate(&sample("hello")).is_valid);

        assert!(MaxLength(10).validate(&short).is_valid);
        assert!(!MaxLength(5).validate(&sample("too long")).is_valid);
    }

    /// Bare and fenced JSON are accepted; plain text is rejected.
    #[test]
    fn test_json() {
        let bare = sample(r#"{"a":1}"#);
        let plain = sample("not json");
        let fenced = sample("```json\n{\"a\":1}\n```");

        assert!(Json.validate(&bare).is_valid);
        assert!(!Json.validate(&plain).is_valid);
        assert!(Json.validate(&fenced).is_valid);
    }

    /// A required key that is absent makes the object invalid.
    #[test]
    fn test_schema() {
        let schema = JsonSchema::require(&["a", "b"]);
        let complete = sample(r#"{"a":1,"b":2}"#);
        let partial = sample(r#"{"a":1}"#);

        assert!(schema.validate(&complete).is_valid);
        assert!(!schema.validate(&partial).is_valid);
    }

    /// Filler openers and AI self-references are blocked; ordinary text passes.
    #[test]
    fn test_blocklist() {
        let blocklist = Blocklist::llm_artifacts();

        assert!(!blocklist.validate(&sample("Sure! Here you go")).is_valid);
        assert!(!blocklist.validate(&sample("As an AI, I")).is_valid);
        assert!(blocklist.validate(&sample("Normal text")).is_valid);
    }

    /// Heavily repeated bigrams exceed the ratio; distinct words do not.
    #[test]
    fn test_repetition() {
        let checker = Repetition {
            max_ratio: 0.3,
            ngram_size: 2,
        };
        let looping = sample("the cat the cat the cat the cat");
        let varied = sample("the quick brown fox jumps");

        assert!(!checker.validate(&looping).is_valid);
        assert!(checker.validate(&varied).is_valid);
    }
}