Skip to main content

synth_claw/validation/
dedup.rs

1use crate::config::DedupeStrategy;
2use crate::generation::GenerationResult;
3use std::collections::HashSet;
4
/// Strategy used to decide when two generation results count as duplicates.
///
/// In every mode the first occurrence is kept and later duplicates are dropped.
#[derive(Debug, Clone, PartialEq)]
pub enum Deduplicator {
    /// Two results are duplicates only when their content strings are identical.
    Exact,
    /// Two results are duplicates when their content matches after lowercasing
    /// and collapsing runs of whitespace into single spaces.
    Normalized,
    /// Two results are duplicates when the Jaccard similarity of their word
    /// n-gram sets is at least `threshold`.
    Jaccard {
        /// Number of consecutive words per n-gram.
        n: usize,
        /// Similarity at or above which a result is treated as a duplicate.
        threshold: f32,
    },
}
10
11impl Default for Deduplicator {
12    fn default() -> Self {
13        Self::Normalized
14    }
15}
16
17impl From<&DedupeStrategy> for Deduplicator {
18    fn from(s: &DedupeStrategy) -> Self {
19        match s {
20            DedupeStrategy::Exact => Self::Exact,
21            DedupeStrategy::Normalized => Self::Normalized,
22            DedupeStrategy::Jaccard => Self::Jaccard {
23                n: 2,
24                threshold: 0.7,
25            },
26        }
27    }
28}
29
30impl Deduplicator {
31    pub fn dedupe(&self, results: Vec<GenerationResult>) -> Vec<GenerationResult> {
32        match self {
33            Self::Exact => {
34                let mut seen = HashSet::new();
35                results
36                    .into_iter()
37                    .filter(|r| seen.insert(r.content.clone()))
38                    .collect()
39            }
40            Self::Normalized => {
41                let mut seen = HashSet::new();
42                results
43                    .into_iter()
44                    .filter(|r| seen.insert(normalize(&r.content)))
45                    .collect()
46            }
47            Self::Jaccard { n, threshold } => {
48                let mut kept = vec![];
49                let mut ngrams_list: Vec<HashSet<String>> = vec![];
50
51                for r in results {
52                    let ng = ngrams(&r.content, *n);
53                    if !ngrams_list
54                        .iter()
55                        .any(|existing| jaccard(&ng, existing) >= *threshold)
56                    {
57                        ngrams_list.push(ng);
58                        kept.push(r);
59                    }
60                }
61                kept
62            }
63        }
64    }
65}
66
/// Lowercases `s` and rewrites it as its whitespace-separated words joined by
/// single spaces, which also strips leading and trailing whitespace.
fn normalize(s: &str) -> String {
    let lowered = s.to_lowercase();
    let mut out = String::with_capacity(lowered.len());
    for word in lowered.split_whitespace() {
        // Words are never empty, so `out` is empty only before the first one.
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
73
/// Returns the set of lowercased word n-grams of `s`.
///
/// Words are split on whitespace and each n-gram is `n` consecutive words
/// joined by a single space. An `n` of 0 is treated as 1, because
/// `slice::windows` panics on a zero window size. When `s` has fewer than `n`
/// words, the result is a single entry holding all of its words joined by
/// single spaces, so short inputs ignore whitespace differences just like
/// longer ones do.
fn ngrams(s: &str, n: usize) -> HashSet<String> {
    // Clamp to unigrams: `windows(0)` would panic.
    let n = n.max(1);
    let words: Vec<&str> = s.split_whitespace().collect();
    if words.len() < n {
        // Join the split words rather than reusing `s`, so surrounding or
        // repeated whitespace does not produce a distinct entry.
        return HashSet::from([words.join(" ").to_lowercase()]);
    }
    words
        .windows(n)
        .map(|window| window.join(" ").to_lowercase())
        .collect()
}
84
/// Jaccard similarity of two sets: the size of their intersection divided by
/// the size of their union. Two empty sets are defined to have similarity 1.0.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
    let shared = a.intersection(b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let total = a.len() + b.len() - shared;
    if total == 0 {
        // The union is empty only when both inputs are empty.
        1.0
    } else {
        shared as f32 / total as f32
    }
}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GenerationResult` whose content is `text`; the remaining
    /// fields are set to `None` / `0`.
    fn result_with(text: &str) -> GenerationResult {
        GenerationResult {
            content: text.to_owned(),
            source_index: None,
            category: None,
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Builds one result per entry of `texts`, in order.
    fn results(texts: &[&str]) -> Vec<GenerationResult> {
        texts.iter().copied().map(result_with).collect()
    }

    #[test]
    fn test_exact() {
        // The two identical "a" entries collapse into one.
        let deduped = Deduplicator::Exact.dedupe(results(&["a", "a", "b"]));
        assert_eq!(deduped.len(), 2);
    }

    #[test]
    fn test_normalized() {
        // The three spellings of "hello" differ only in case and whitespace.
        let inputs = results(&["Hello", "  hello  ", "HELLO", "other"]);
        let deduped = Deduplicator::Normalized.dedupe(inputs);
        assert_eq!(deduped.len(), 2);
    }

    #[test]
    fn test_jaccard() {
        let dedup = Deduplicator::Jaccard {
            n: 2,
            threshold: 0.5,
        };
        // The first two sentences differ only in their last word.
        let inputs = results(&[
            "the quick brown fox jumps over the lazy dog",
            "the quick brown fox jumps over the lazy cat",
            "completely different text here",
        ]);
        let deduped = dedup.dedupe(inputs);
        assert_eq!(deduped.len(), 2);
    }
}
133}