// scirs2_text/spelling/ngram.rs
use std::collections::{HashMap, HashSet};
34use std::fs::File;
35use std::io::{BufRead, BufReader};
36use std::path::Path;
37
38use crate::error::{Result, TextError};
39
/// An n-gram language model (up to trigrams) with add-one (Laplace)
/// smoothing, used for spelling-related probability estimation and ranking.
#[derive(Clone)]
pub struct NGramModel {
    /// Token -> occurrence count (includes the sentence boundary tokens).
    unigrams: HashMap<String, usize>,
    /// (w1, w2) -> occurrence count.
    bigrams: HashMap<(String, String), usize>,
    /// (w1, w2, w3) -> occurrence count.
    trigrams: HashMap<(String, String, String), usize>,
    /// Total cleaned words ingested (boundary tokens are NOT counted here).
    total_words: usize,
    /// Model order, clamped to 1..=3 by `new`.
    order: usize,
    /// Marker prepended to each sentence ("<s>").
    start_token: String,
    /// Marker appended to each sentence ("</s>").
    end_token: String,
}
58
59impl NGramModel {
60 pub fn new(order: usize) -> Self {
62 if order > 3 {
63 eprintln!("Warning: NGramModel only supports orders up to 3. Using order=3.");
65 }
66
67 Self {
68 unigrams: HashMap::new(),
69 bigrams: HashMap::new(),
70 trigrams: HashMap::new(),
71 total_words: 0,
72 order: order.clamp(1, 3),
73 start_token: "<s>".to_string(),
74 end_token: "</s>".to_string(),
75 }
76 }
77
78 pub fn addtext(&mut self, text: &str) {
80 let words: Vec<String> = text
81 .split_whitespace()
82 .map(|s| {
83 s.trim_matches(|c: char| !c.is_alphanumeric())
84 .to_lowercase()
85 })
86 .filter(|s| !s.is_empty())
87 .collect();
88
89 if words.is_empty() {
90 return;
91 }
92
93 let mut current_sentence = Vec::new();
95
96 for word in words {
97 let is_end = word.ends_with('.') || word.ends_with('?') || word.ends_with('!');
99
100 let clean_word = word
102 .trim_matches(|c: char| !c.is_alphanumeric())
103 .to_string();
104 if !clean_word.is_empty() {
105 current_sentence.push(clean_word);
106 self.total_words += 1;
107 }
108
109 if is_end && !current_sentence.is_empty() {
111 self.process_sentence(¤t_sentence);
112 current_sentence.clear();
113 }
114 }
115
116 if !current_sentence.is_empty() {
118 self.process_sentence(¤t_sentence);
119 }
120 }
121
122 fn process_sentence(&mut self, sentence: &[String]) {
124 let mut words = Vec::with_capacity(sentence.len() + 2);
126 words.push(self.start_token.clone());
127 words.extend(sentence.iter().cloned());
128 words.push(self.end_token.clone());
129
130 for word in &words {
132 *self.unigrams.entry(word.clone()).or_insert(0) += 1;
133 }
134
135 if self.order >= 2 {
137 for i in 0..words.len() - 1 {
138 let bigram = (words[i].clone(), words[i + 1].clone());
139 *self.bigrams.entry(bigram).or_insert(0) += 1;
140 }
141 }
142
143 if self.order >= 3 {
145 for i in 0..words.len() - 2 {
146 let trigram = (words[i].clone(), words[i + 1].clone(), words[i + 2].clone());
147 *self.trigrams.entry(trigram).or_insert(0) += 1;
148 }
149 }
150 }
151
152 pub fn add_corpus_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
154 let file = File::open(path)
155 .map_err(|e| TextError::IoError(format!("Failed to open corpus file: {e}")))?;
156
157 let reader = BufReader::new(file);
158
159 for line in reader.lines() {
160 let line = line.map_err(|e| {
161 TextError::IoError(format!("Failed to read line from corpus file: {e}"))
162 })?;
163
164 if line.trim().is_empty() {
166 continue;
167 }
168
169 self.addtext(&line);
170 }
171
172 Ok(())
173 }
174
175 pub fn probability(&self, word: &str, context: &[String]) -> f64 {
177 match self.order {
178 1 => self.unigram_probability(word),
179 2 => self.bigram_probability(word, context),
180 3 => self.trigram_probability(word, context),
181 _ => self.unigram_probability(word), }
183 }
184
185 pub fn unigram_probability(&self, word: &str) -> f64 {
187 let word_count = self.unigrams.get(word).copied().unwrap_or(0);
188
189 let vocabulary_size = self.unigrams.len();
191 (word_count as f64 + 1.0) / (self.total_words as f64 + vocabulary_size as f64)
192 }
193
194 pub fn bigram_probability(&self, word: &str, context: &[String]) -> f64 {
196 if context.is_empty() {
197 return self.unigram_probability(word);
198 }
199
200 let previous = &context[context.len() - 1];
201
202 let bigram_count = self
203 .bigrams
204 .get(&(previous.clone(), word.to_string()))
205 .copied()
206 .unwrap_or(0);
207
208 let previous_count = self.unigrams.get(previous).copied().unwrap_or(0);
209
210 if previous_count == 0 {
211 return self.unigram_probability(word);
212 }
213
214 let vocabulary_size = self.unigrams.len();
216 (bigram_count as f64 + 1.0) / (previous_count as f64 + vocabulary_size as f64)
217 }
218
219 pub fn trigram_probability(&self, word: &str, context: &[String]) -> f64 {
221 if context.len() < 2 {
222 return self.bigram_probability(word, context);
223 }
224
225 let previous1 = &context[context.len() - 2];
226 let previous2 = &context[context.len() - 1];
227
228 let trigram_count = self
229 .trigrams
230 .get(&(previous1.clone(), previous2.clone(), word.to_string()))
231 .copied()
232 .unwrap_or(0);
233
234 let bigram_count = self
235 .bigrams
236 .get(&(previous1.clone(), previous2.clone()))
237 .copied()
238 .unwrap_or(0);
239
240 if bigram_count == 0 {
241 return self.bigram_probability(word, &[previous2.clone()]);
242 }
243
244 let vocabulary_size = self.unigrams.len();
246 (trigram_count as f64 + 1.0) / (bigram_count as f64 + vocabulary_size as f64)
247 }
248
249 pub fn perplexity(&self, text: &str) -> f64 {
251 let words: Vec<String> = text
252 .split_whitespace()
253 .map(|s| {
254 s.trim_matches(|c: char| !c.is_alphanumeric())
255 .to_lowercase()
256 })
257 .filter(|s| !s.is_empty())
258 .collect();
259
260 if words.is_empty() {
261 return f64::INFINITY;
262 }
263
264 let mut log_prob_sum = 0.0;
265 let mut context = Vec::new();
266
267 for word in words.iter() {
268 let prob = self.probability(word, &context);
269 log_prob_sum += (prob + 1e-10).log2(); context.push(word.clone());
273 if context.len() > self.order {
274 context.remove(0);
275 }
276 }
277
278 2.0f64.powf(-log_prob_sum / words.len() as f64)
280 }
281
    /// Number of distinct tokens observed (including the boundary tokens).
    pub fn vocabulary_size(&self) -> usize {
        self.unigrams.len()
    }
286
    /// Total number of cleaned words ingested (boundary tokens excluded).
    pub fn total_words(&self) -> usize {
        self.total_words
    }
291
    /// Raw corpus frequency of `word`; 0 if the word was never seen.
    pub fn word_frequency(&self, word: &str) -> usize {
        self.unigrams.get(word).copied().unwrap_or(0)
    }
296
297 pub fn generate_typos(&self, word: &str, numtypos: usize) -> Vec<String> {
299 let mut typos = HashSet::new();
300 let word = word.to_lowercase();
301 let chars: Vec<char> = word.chars().collect();
302
303 for i in 0..chars.len() {
305 let mut new_word = String::new();
306 for (j, &c) in chars.iter().enumerate() {
307 if j != i {
308 new_word.push(c);
309 }
310 }
311 typos.insert(new_word);
312 }
313
314 for i in 0..chars.len() - 1 {
316 let mut new_chars = chars.clone();
317 new_chars.swap(i, i + 1);
318 typos.insert(new_chars.iter().collect());
319 }
320
321 for i in 0..=chars.len() {
323 for c in 'a'..='z' {
324 let mut new_chars = chars.clone();
325 new_chars.insert(i, c);
326 typos.insert(new_chars.iter().collect());
327 }
328 }
329
330 for i in 0..chars.len() {
332 for c in 'a'..='z' {
333 if chars[i] != c {
334 let mut new_chars = chars.clone();
335 new_chars[i] = c;
336 typos.insert(new_chars.iter().collect());
337 }
338 }
339 }
340
341 let mut typos_vec: Vec<_> = typos.into_iter().collect();
343
344 typos_vec.sort_by(|a, b| {
346 let freq_a = self.word_frequency(a);
347 let freq_b = self.word_frequency(b);
348 freq_b.cmp(&freq_a) });
350
351 typos_vec.truncate(numtypos);
353
354 typos_vec
355 }
356}
357
358impl std::fmt::Debug for NGramModel {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.debug_struct("NGramModel")
361 .field("order", &self.order)
362 .field("vocabulary_size", &self.vocabulary_size())
363 .field("total_words", &self.total_words)
364 .field("unigrams", &{
365 let unigram_len = self.unigrams.len();
366 format!("<{unigram_len} entries>")
367 })
368 .field("bigrams", &{
369 let bigram_len = self.bigrams.len();
370 format!("<{bigram_len} entries>")
371 })
372 .field("trigrams", &{
373 let trigram_len = self.trigrams.len();
374 format!("<{trigram_len} entries>")
375 })
376 .finish()
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: train a trigram model on one sentence and check that the
    /// smoothed probability estimates behave sensibly.
    #[test]
    fn test_ngram_model_basics() {
        let mut model = NGramModel::new(3);

        model.addtext("The quick brown fox jumps over the lazy dog.");

        // "the" appears twice in the sentence, "quick" once.
        let p_the = model.unigram_probability("the");
        let p_quick = model.unigram_probability("quick");
        let p_unknown = model.unigram_probability("unknown");

        assert!(p_the > p_quick);

        // Laplace smoothing must give unseen words a non-zero probability.
        assert!(p_unknown > 0.0);

        // Observed bigrams get positive conditional probabilities.
        let p_quick_given_the = model.bigram_probability("quick", &["the".to_string()]);
        let p_brown_given_quick = model.bigram_probability("brown", &["quick".to_string()]);

        assert!(p_quick_given_the > 0.0);
        assert!(p_brown_given_quick > 0.0);

        // Observed trigram context also yields a positive probability.
        let p_fox_given_quick_brown =
            model.trigram_probability("fox", &["quick".to_string(), "brown".to_string()]);

        assert!(p_fox_given_quick_brown > 0.0);
    }
}