Skip to main content

shodh_memory/embeddings/
keywords.rs

//! Statistical Keyword Extraction using YAKE
//!
//! Extracts salient keywords from text using statistical features:
//! - Position in text (earlier = more important)
//! - Word frequency and distribution across sentences
//! - Capitalization patterns
//! - Word length and structure
//!
//! Unlike NER, which only extracts named entities (Person, Org, Location, Misc),
//! keyword extraction captures any semantically important terms, including:
//! - Common nouns: "sunrise", "painting", "lake"
//! - Verbs/actions: "painted", "visited", "bought"
//! - Adjectives: "beautiful", "expensive", "favorite"
//!
//! This is critical for graph traversal in multi-hop reasoning, where
//! query terms like "sunrise" need to match graph nodes but aren't
//! named entities detectable by NER.

19use std::collections::HashSet;
20use yake_rust::{get_n_best, Config, StopWords};
21
/// Configuration for keyword extraction.
#[derive(Debug, Clone)]
pub struct KeywordConfig {
    /// Maximum number of keywords to extract per call.
    pub max_keywords: usize,
    /// Maximum n-gram size (1 = unigrams, 2 = bigrams, 3 = trigrams).
    pub ngrams: usize,
    /// Minimum keyword length in characters (forwarded to YAKE's `minimum_chars`).
    pub min_length: usize,
    /// Language code for the predefined stopword list (e.g. "en").
    /// Unknown languages fall back to English, then to an empty set.
    pub language: String,
    /// Deduplication threshold (0.0-1.0, higher = stricter dedup).
    pub dedup_threshold: f64,
}
36
37impl Default for KeywordConfig {
38    fn default() -> Self {
39        Self {
40            max_keywords: 10,
41            ngrams: 2,
42            min_length: 3,
43            language: "en".to_string(),
44            dedup_threshold: 0.9,
45        }
46    }
47}
48
/// A keyword extracted from text, together with its importance score.
#[derive(Debug, Clone)]
pub struct Keyword {
    /// The keyword text as returned by YAKE (already normalized/lowercased).
    pub text: String,
    /// Raw YAKE score (lower = more important, typically 0.0-1.0).
    pub score: f64,
    /// Normalized importance derived as `1 / (1 + score)`
    /// (0.0-1.0, higher = more important).
    pub importance: f32,
}
59
/// Keyword extractor using the YAKE algorithm.
pub struct KeywordExtractor {
    // Extraction parameters (n-gram size, dedup threshold, limits).
    config: KeywordConfig,
    // Stopword set resolved once at construction from `config.language`.
    stopwords: StopWords,
}
65
66impl KeywordExtractor {
67    /// Create a new keyword extractor with default config
68    pub fn new() -> Self {
69        Self::with_config(KeywordConfig::default())
70    }
71
72    /// Create a new keyword extractor with custom config
73    pub fn with_config(config: KeywordConfig) -> Self {
74        // StopWords::predefined returns Option, fallback to empty set if not found
75        let stopwords = StopWords::predefined(&config.language)
76            .or_else(|| StopWords::predefined("en"))
77            .unwrap_or_else(|| StopWords::custom(HashSet::new()));
78        Self { config, stopwords }
79    }
80
81    /// Extract keywords from text
82    pub fn extract(&self, text: &str) -> Vec<Keyword> {
83        if text.trim().is_empty() {
84            return Vec::new();
85        }
86
87        // Standard punctuation set
88        let punctuation: HashSet<char> = [
89            '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';',
90            '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~',
91        ]
92        .into_iter()
93        .collect();
94
95        let yake_config = Config {
96            ngrams: self.config.ngrams,
97            punctuation,
98            remove_duplicates: true,
99            deduplication_threshold: self.config.dedup_threshold,
100            minimum_chars: self.config.min_length,
101            ..Config::default()
102        };
103
104        let results = get_n_best(
105            self.config.max_keywords,
106            text,
107            &self.stopwords,
108            &yake_config,
109        );
110
111        // Convert YAKE results to Keywords
112        // YAKE score: lower = better, typically 0.0-0.5 for important keywords
113        // We invert this to importance: higher = better
114        let mut keywords: Vec<Keyword> = results
115            .into_iter()
116            .map(|item| {
117                // Convert YAKE score (lower=better) to importance (higher=better)
118                // Use sigmoid-like transformation: importance = 1 / (1 + score)
119                let importance = (1.0 / (1.0 + item.score)) as f32;
120                Keyword {
121                    text: item.keyword, // Already lowercased
122                    score: item.score,
123                    importance,
124                }
125            })
126            .collect();
127
128        // Sort by importance descending
129        keywords.sort_by(|a, b| b.importance.total_cmp(&a.importance));
130        keywords
131    }
132
133    /// Extract keyword texts only (for graph node creation)
134    pub fn extract_texts(&self, text: &str) -> Vec<String> {
135        self.extract(text).into_iter().map(|k| k.text).collect()
136    }
137
138    /// Extract keywords with minimum importance threshold
139    pub fn extract_filtered(&self, text: &str, min_importance: f32) -> Vec<Keyword> {
140        self.extract(text)
141            .into_iter()
142            .filter(|k| k.importance >= min_importance)
143            .collect()
144    }
145}
146
147impl Default for KeywordExtractor {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_basic() {
        let text = "Caroline painted a beautiful sunrise over the lake yesterday morning.";
        let keywords = KeywordExtractor::new().extract(text);

        // Extraction should yield at least one keyword.
        assert!(!keywords.is_empty());

        // "sunrise" should appear, either alone or inside a bigram.
        let texts: Vec<&str> = keywords.iter().map(|k| k.text.as_str()).collect();
        let found = texts.contains(&"sunrise") || texts.contains(&"beautiful sunrise");
        assert!(found, "Should extract 'sunrise': {texts:?}");
    }

    #[test]
    fn test_extract_texts() {
        let extractor = KeywordExtractor::default();
        let texts = extractor
            .extract_texts("The quick brown fox jumps over the lazy dog near the river.");

        assert!(!texts.is_empty());
        // Every extracted keyword should already be lowercase.
        assert!(texts.iter().all(|t| *t == t.to_lowercase()));
    }

    #[test]
    fn test_empty_text() {
        assert!(KeywordExtractor::new().extract("").is_empty());
    }

    #[test]
    fn test_importance_ordering() {
        let text =
            "Machine learning and artificial intelligence are transforming computer science.";
        let keywords = KeywordExtractor::new().extract(text);

        // Results must be sorted by importance, descending.
        assert!(keywords
            .windows(2)
            .all(|pair| pair[0].importance >= pair[1].importance));
    }

    #[test]
    fn test_filter_by_importance() {
        let extractor = KeywordExtractor::new();
        let text = "The conference discussed various topics including climate change and renewable energy.";

        // Everything surviving the filter must meet the threshold.
        for keyword in extractor.extract_filtered(text, 0.5) {
            assert!(keyword.importance >= 0.5);
        }
    }

    #[test]
    fn test_custom_config() {
        let extractor = KeywordExtractor::with_config(KeywordConfig {
            max_keywords: 5,
            ngrams: 3,
            min_length: 4,
            ..Default::default()
        });
        let keywords = extractor
            .extract("Natural language processing enables computers to understand human language.");

        // Config limits must be respected.
        assert!(keywords.len() <= 5);
        assert!(keywords.iter().all(|k| k.text.chars().count() >= 4));
    }
}