synth_claw/validation/
dedup.rs1use crate::config::DedupeStrategy;
2use crate::generation::GenerationResult;
3use std::collections::HashSet;
4
/// Strategy used to drop duplicate generation results.
///
/// The variant selects how two results are judged to be "the same":
///
/// * [`Deduplicator::Exact`] – contents must be byte-for-byte equal.
/// * [`Deduplicator::Normalized`] – contents are compared after
///   lowercasing and collapsing runs of whitespace.
/// * [`Deduplicator::Jaccard`] – contents are compared by the Jaccard
///   similarity of their word n-gram sets.
#[derive(Debug, Clone, PartialEq)]
pub enum Deduplicator {
    /// Remove results whose content is exactly equal to an earlier one.
    Exact,
    /// Remove results whose normalized content equals an earlier one.
    Normalized,
    /// Remove results that are near-duplicates of an earlier kept result.
    Jaccard {
        /// Number of consecutive words per n-gram.
        n: usize,
        /// Similarity at or above which a result counts as a duplicate.
        threshold: f32,
    },
}
10
11impl Default for Deduplicator {
12 fn default() -> Self {
13 Self::Normalized
14 }
15}
16
17impl From<&DedupeStrategy> for Deduplicator {
18 fn from(s: &DedupeStrategy) -> Self {
19 match s {
20 DedupeStrategy::Exact => Self::Exact,
21 DedupeStrategy::Normalized => Self::Normalized,
22 DedupeStrategy::Jaccard => Self::Jaccard {
23 n: 2,
24 threshold: 0.7,
25 },
26 }
27 }
28}
29
30impl Deduplicator {
31 pub fn dedupe(&self, results: Vec<GenerationResult>) -> Vec<GenerationResult> {
32 match self {
33 Self::Exact => {
34 let mut seen = HashSet::new();
35 results
36 .into_iter()
37 .filter(|r| seen.insert(r.content.clone()))
38 .collect()
39 }
40 Self::Normalized => {
41 let mut seen = HashSet::new();
42 results
43 .into_iter()
44 .filter(|r| seen.insert(normalize(&r.content)))
45 .collect()
46 }
47 Self::Jaccard { n, threshold } => {
48 let mut kept = vec![];
49 let mut ngrams_list: Vec<HashSet<String>> = vec![];
50
51 for r in results {
52 let ng = ngrams(&r.content, *n);
53 if !ngrams_list
54 .iter()
55 .any(|existing| jaccard(&ng, existing) >= *threshold)
56 {
57 ngrams_list.push(ng);
58 kept.push(r);
59 }
60 }
61 kept
62 }
63 }
64 }
65}
66
/// Returns `s` lowercased, with leading/trailing whitespace removed and
/// every internal run of whitespace replaced by a single space.
fn normalize(s: &str) -> String {
    let lowered = s.to_lowercase();
    let mut out = String::with_capacity(lowered.len());
    for token in lowered.split_whitespace() {
        // Tokens are never empty, so `out` is empty only before the first one.
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(token);
    }
    out
}
73
/// Returns the set of lowercased word n-grams of `s`.
///
/// Words are the whitespace-separated tokens of `s`; each n-gram is `n`
/// consecutive words joined by a single space.
///
/// * An `n` of `0` is treated as `1`, because `slice::windows` panics on a
///   zero window size.
/// * When `s` has fewer than `n` words, the result is a single-element set
///   holding all of its words joined by single spaces and lowercased, so
///   short texts are compared whitespace-insensitively just like long ones.
///   For empty or whitespace-only input that element is the empty string.
fn ngrams(s: &str, n: usize) -> HashSet<String> {
    let n = n.max(1);
    let words: Vec<&str> = s.split_whitespace().collect();
    if words.len() < n {
        return HashSet::from([words.join(" ").to_lowercase()]);
    }
    words
        .windows(n)
        .map(|window| window.join(" ").to_lowercase())
        .collect()
}
84
/// Jaccard similarity `|a ∩ b| / |a ∪ b|` of two string sets.
///
/// Two empty sets are defined to have similarity `1.0`.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let shared = a.intersection(b).count();
    // |a ∪ b| = |a| + |b| - |a ∩ b|
    let combined = a.len() + b.len() - shared;
    shared as f32 / combined as f32
}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a result that carries only `content`; all other fields are
    /// set to `None` / `0`.
    fn result_with(content: &str) -> GenerationResult {
        GenerationResult {
            content: content.to_string(),
            source_index: None,
            category: None,
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Converts a list of content strings into results.
    fn results_from(contents: &[&str]) -> Vec<GenerationResult> {
        contents.iter().map(|c| result_with(c)).collect()
    }

    #[test]
    fn test_exact() {
        // The second "a" is an exact repeat of the first.
        let deduped = Deduplicator::Exact.dedupe(results_from(&["a", "a", "b"]));
        assert_eq!(deduped.len(), 2);
    }

    #[test]
    fn test_normalized() {
        // The first three differ only in case and surrounding whitespace.
        let inputs = results_from(&["Hello", " hello ", "HELLO", "other"]);
        let deduped = Deduplicator::Normalized.dedupe(inputs);
        assert_eq!(deduped.len(), 2);
    }

    #[test]
    fn test_jaccard() {
        let dedup = Deduplicator::Jaccard {
            n: 2,
            threshold: 0.5,
        };
        // The first two sentences differ only in their last word.
        let inputs = results_from(&[
            "the quick brown fox jumps over the lazy dog",
            "the quick brown fox jumps over the lazy cat",
            "completely different text here",
        ]);
        let deduped = dedup.dedupe(inputs);
        assert_eq!(deduped.len(), 2);
    }
}
133}