shodh_memory/embeddings/
keywords.rs1use std::collections::HashSet;
20use yake_rust::{get_n_best, Config, StopWords};
21
/// Tuning knobs for [`KeywordExtractor`]; each field is forwarded to the YAKE extractor.
#[derive(Debug, Clone, PartialEq)]
pub struct KeywordConfig {
    /// Maximum number of keywords returned by a single extraction.
    pub max_keywords: usize,
    /// Maximum n-gram size of candidate keywords (YAKE's `ngrams`).
    pub ngrams: usize,
    /// Minimum number of characters of a candidate keyword (YAKE's `minimum_chars`).
    pub min_length: usize,
    /// Language code used to pick a predefined stop-word list.
    ///
    /// Codes without a predefined list fall back to English (`"en"`), and to an empty
    /// stop-word list if that is unavailable as well.
    pub language: String,
    /// Similarity threshold for YAKE's duplicate removal (`deduplication_threshold`).
    /// Duplicate removal is always enabled.
    pub dedup_threshold: f64,
}

impl Default for KeywordConfig {
    /// Up to 10 keywords of at most two words and at least three characters each,
    /// English stop words, and a deduplication threshold of 0.9.
    fn default() -> Self {
        Self {
            max_keywords: 10,
            ngrams: 2,
            min_length: 3,
            language: "en".to_string(),
            dedup_threshold: 0.9,
        }
    }
}
48
/// A single keyword extracted from a piece of text.
#[derive(Debug, Clone, PartialEq)]
pub struct Keyword {
    /// Keyword text as reported by YAKE (a single word or an n-gram).
    pub text: String,
    /// Raw YAKE score; lower scores denote more relevant keywords.
    pub score: f64,
    /// Relevance derived from `score` as `1 / (1 + score)`; higher is more relevant.
    pub importance: f32,
}
59
60pub struct KeywordExtractor {
62 config: KeywordConfig,
63 stopwords: StopWords,
64}
65
66impl KeywordExtractor {
67 pub fn new() -> Self {
69 Self::with_config(KeywordConfig::default())
70 }
71
72 pub fn with_config(config: KeywordConfig) -> Self {
74 let stopwords = StopWords::predefined(&config.language)
76 .or_else(|| StopWords::predefined("en"))
77 .unwrap_or_else(|| StopWords::custom(HashSet::new()));
78 Self { config, stopwords }
79 }
80
81 pub fn extract(&self, text: &str) -> Vec<Keyword> {
83 if text.trim().is_empty() {
84 return Vec::new();
85 }
86
87 let punctuation: HashSet<char> = [
89 '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';',
90 '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~',
91 ]
92 .into_iter()
93 .collect();
94
95 let yake_config = Config {
96 ngrams: self.config.ngrams,
97 punctuation,
98 remove_duplicates: true,
99 deduplication_threshold: self.config.dedup_threshold,
100 minimum_chars: self.config.min_length,
101 ..Config::default()
102 };
103
104 let results = get_n_best(
105 self.config.max_keywords,
106 text,
107 &self.stopwords,
108 &yake_config,
109 );
110
111 let mut keywords: Vec<Keyword> = results
115 .into_iter()
116 .map(|item| {
117 let importance = (1.0 / (1.0 + item.score)) as f32;
120 Keyword {
121 text: item.keyword, score: item.score,
123 importance,
124 }
125 })
126 .collect();
127
128 keywords.sort_by(|a, b| b.importance.total_cmp(&a.importance));
130 keywords
131 }
132
133 pub fn extract_texts(&self, text: &str) -> Vec<String> {
135 self.extract(text).into_iter().map(|k| k.text).collect()
136 }
137
138 pub fn extract_filtered(&self, text: &str, min_importance: f32) -> Vec<Keyword> {
140 self.extract(text)
141 .into_iter()
142 .filter(|k| k.importance >= min_importance)
143 .collect()
144 }
145}
146
147impl Default for KeywordExtractor {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_basic() {
        let extractor = KeywordExtractor::new();
        let text = "Caroline painted a beautiful sunrise over the lake yesterday morning.";
        let keywords = extractor.extract(text);

        assert!(!keywords.is_empty());

        let texts: Vec<&str> = keywords.iter().map(|k| k.text.as_str()).collect();
        assert!(
            texts.contains(&"sunrise") || texts.contains(&"beautiful sunrise"),
            "Should extract 'sunrise': {texts:?}"
        );
    }

    #[test]
    fn test_extract_texts() {
        let extractor = KeywordExtractor::new();
        let text = "The quick brown fox jumps over the lazy dog near the river.";
        let texts = extractor.extract_texts(text);

        assert!(!texts.is_empty());
        // Keyword text is expected to come back lowercased.
        for t in &texts {
            assert_eq!(t.to_lowercase(), *t);
        }
    }

    #[test]
    fn test_empty_text() {
        let extractor = KeywordExtractor::new();
        let keywords = extractor.extract("");
        assert!(keywords.is_empty());
    }

    #[test]
    fn test_whitespace_only_text() {
        let extractor = KeywordExtractor::new();
        assert!(extractor.extract("   \n\t  ").is_empty());
        assert!(extractor.extract_texts(" ").is_empty());
        assert!(extractor.extract_filtered("\n", 0.0).is_empty());
    }

    #[test]
    fn test_importance_ordering() {
        let extractor = KeywordExtractor::new();
        let text =
            "Machine learning and artificial intelligence are transforming computer science.";
        let keywords = extractor.extract(text);

        for pair in keywords.windows(2) {
            assert!(
                pair[0].importance >= pair[1].importance,
                "keywords not sorted by importance: {keywords:?}"
            );
        }
    }

    #[test]
    fn test_importance_derived_from_score() {
        let extractor = KeywordExtractor::new();
        let text = "Rust ownership rules prevent data races at compile time.";
        for k in extractor.extract(text) {
            let expected = (1.0 / (1.0 + k.score)) as f32;
            assert!((k.importance - expected).abs() <= f32::EPSILON, "{k:?}");
        }
    }

    #[test]
    fn test_filter_by_importance() {
        let extractor = KeywordExtractor::new();
        let text = "The conference discussed various topics including climate change and renewable energy.";
        let filtered = extractor.extract_filtered(text, 0.5);

        for k in filtered {
            assert!(k.importance >= 0.5);
        }
    }

    #[test]
    fn test_custom_config() {
        let config = KeywordConfig {
            max_keywords: 5,
            ngrams: 3,
            min_length: 4,
            ..Default::default()
        };
        let extractor = KeywordExtractor::with_config(config);
        let text = "Natural language processing enables computers to understand human language.";
        let keywords = extractor.extract(text);

        assert!(keywords.len() <= 5);
        for k in &keywords {
            assert!(k.text.chars().count() >= 4);
        }
    }

    #[test]
    fn test_unknown_language_falls_back_to_english() {
        let config = KeywordConfig {
            language: "not-a-language".to_string(),
            ..Default::default()
        };
        let extractor = KeywordExtractor::with_config(config);
        let text = "Caroline painted a beautiful sunrise over the lake yesterday morning.";
        let texts = extractor.extract_texts(text);

        assert!(!texts.is_empty());
        // English stop words must still be in effect after the fallback.
        assert!(!texts.iter().any(|t| t == "the"), "stop word leaked: {texts:?}");
    }
}