1use crate::error::Result;
38use crate::string_metrics::{DamerauLevenshteinMetric, StringMetric};
39use std::collections::HashMap;
40use std::path::Path;
41use std::sync::Arc;
42
43use super::dictionary::DictionaryCorrector;
44use super::error_model::ErrorModel;
45use super::ngram::NGramModel;
46use super::SpellingCorrector;
47
/// Tuning parameters read by [`StatisticalCorrector`].
#[derive(Clone, Debug)]
pub struct StatisticalCorrectorConfig {
    /// Largest (rounded) metric distance at which a dictionary word is still
    /// offered as a candidate; also bounds the tolerated length difference
    /// between the input word and a dictionary word.
    pub max_edit_distance: usize,
    /// When `false`, the input word and dictionary words are lower-cased
    /// before they are compared.
    pub case_sensitive: bool,
    /// Upper bound on the number of scored candidates kept per word.
    pub max_suggestions: usize,
    /// Dictionary entries whose stored frequency is below this value are
    /// skipped during candidate generation.
    pub min_frequency: usize,
    /// Intended n-gram order of the language model.
    // NOTE(review): the constructors in this file build `NGramModel::new(3)`
    // and never read this field — confirm whether it is consumed elsewhere.
    pub ngram_order: usize,
    /// Weight of the language-model term (multiplied by the error-model
    /// probability) in a candidate's combined score.
    pub language_model_weight: f64,
    /// Weight of the edit-distance term in a candidate's combined score.
    pub edit_distance_weight: f64,
    /// Selects the context-conditioned language-model probability (and the
    /// beam search in sentence correction) over the unigram probability.
    pub use_context: bool,
    /// Maximum number of preceding corrected words kept as context during
    /// sentence correction.
    pub context_window: usize,
    /// Per-word candidate cap and beam width used by sentence correction.
    pub max_candidates: usize,
}
72
73impl Default for StatisticalCorrectorConfig {
74 fn default() -> Self {
75 Self {
76 max_edit_distance: 2,
77 case_sensitive: false,
78 max_suggestions: 5,
79 min_frequency: 1,
80 ngram_order: 3,
81 language_model_weight: 0.7,
82 edit_distance_weight: 0.3,
83 use_context: true,
84 context_window: 2,
85 max_candidates: 5,
86 }
87 }
88}
89
90pub struct StatisticalCorrector {
92 dictionary: HashMap<String, usize>,
94 config: StatisticalCorrectorConfig,
96 metric: Arc<dyn StringMetric + Send + Sync>,
98 language_model: NGramModel,
100 error_model: ErrorModel,
102}
103
104impl Clone for StatisticalCorrector {
105 fn clone(&self) -> Self {
106 Self {
107 dictionary: self.dictionary.clone(),
108 config: self.config.clone(),
109 metric: self.metric.clone(),
110 language_model: self.language_model.clone(),
111 error_model: self.error_model.clone(),
112 }
113 }
114}
115
116impl std::fmt::Debug for StatisticalCorrector {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("StatisticalCorrector")
119 .field("dictionary", &{
120 let dict_len = self.dictionary.len();
121 format!("<{dict_len} words>")
122 })
123 .field("config", &self.config)
124 .field("metric", &"<StringMetric>")
125 .field("language_model", &self.language_model)
126 .field("error_model", &self.error_model)
127 .finish()
128 }
129}
130
131impl Default for StatisticalCorrector {
132 fn default() -> Self {
133 let dict_corrector = DictionaryCorrector::default();
135
136 let mut language_model = NGramModel::new(3);
138
139 let sampletexts = [
141 "The quick brown fox jumps over the lazy dog.",
142 "She sells seashells by the seashore.",
143 "How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
144 "To be or not to be, that is the question.",
145 "Four score and seven years ago our fathers brought forth on this continent a new nation.",
146 "Ask not what your country can do for you, ask what you can do for your country.",
147 "That's one small step for man, one giant leap for mankind.",
148 "I have a dream that one day this nation will rise up and live out the true meaning of its creed.",
149 "The only thing we have to fear is fear itself.",
150 "We hold these truths to be self-evident, that all men are created equal.",
151 ];
153
154 for text in &sampletexts {
155 language_model.addtext(text);
156 }
157
158 Self {
159 dictionary: dict_corrector.dictionary,
160 config: StatisticalCorrectorConfig::default(),
161 metric: Arc::new(DamerauLevenshteinMetric::new()),
162 language_model,
163 error_model: ErrorModel::default(),
164 }
165 }
166}
167
168impl StatisticalCorrector {
169 pub fn new(config: StatisticalCorrectorConfig) -> Self {
171 Self {
172 config,
173 ..Default::default()
174 }
175 }
176
177 pub fn from_dictionary_corrector(dictcorrector: &DictionaryCorrector) -> Self {
179 let config = StatisticalCorrectorConfig {
180 max_edit_distance: dictcorrector.config.max_edit_distance,
181 case_sensitive: dictcorrector.config.case_sensitive,
182 max_suggestions: dictcorrector.config.max_suggestions,
183 min_frequency: dictcorrector.config.min_frequency,
184 ..StatisticalCorrectorConfig::default()
185 };
186
187 Self {
188 dictionary: dictcorrector.dictionary.clone(),
189 config,
190 metric: dictcorrector.metric.clone(),
191 language_model: NGramModel::new(3),
192 error_model: ErrorModel::default(),
193 }
194 }
195
196 pub fn add_corpus_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
198 self.language_model.add_corpus_file(path)
199 }
200
201 pub fn add_trainingtext(&mut self, text: &str) {
203 self.language_model.addtext(text);
204 }
205
206 pub fn set_language_model(&mut self, model: NGramModel) {
208 self.language_model = model;
209 }
210
211 pub fn set_error_model(&mut self, model: ErrorModel) {
213 self.error_model = model;
214 }
215
216 pub fn set_metric<M: StringMetric + Send + Sync + 'static>(&mut self, metric: M) {
218 self.metric = Arc::new(metric);
219 }
220
221 pub fn set_config(&mut self, config: StatisticalCorrectorConfig) {
223 self.config = config;
224 }
225
226 fn get_contextual_corrections(&self, word: &str, context: &[String]) -> Vec<(String, f64)> {
228 if self.is_correct(word) {
230 return vec![(word.to_string(), 1.0)];
231 }
232
233 let word_to_check = if !self.config.case_sensitive {
235 word.to_lowercase()
236 } else {
237 word.to_string()
238 };
239
240 let mut candidates: Vec<(String, f64)> = Vec::new();
241
242 for (dict_word, frequency) in &self.dictionary {
244 if *frequency < self.config.min_frequency {
245 continue;
246 }
247
248 let dict_word_normalized = if !self.config.case_sensitive {
249 dict_word.to_lowercase()
250 } else {
251 dict_word.clone()
252 };
253
254 if dict_word_normalized.len() > word_to_check.len() + self.config.max_edit_distance
256 || dict_word_normalized.len() + self.config.max_edit_distance < word_to_check.len()
257 {
258 continue;
259 }
260
261 if let Ok(distance) = self.metric.distance(&word_to_check, &dict_word_normalized) {
263 let distance_usize = distance.round() as usize;
265 if distance_usize <= self.config.max_edit_distance {
266 let edit_score = 1.0 / (1.0 + distance);
268
269 let lm_score = if self.config.use_context {
271 self.language_model.probability(dict_word, context)
272 } else {
273 self.language_model.unigram_probability(dict_word)
274 };
275
276 let error_score = self
278 .error_model
279 .error_probability(&word_to_check, &dict_word_normalized);
280
281 let combined_score = (self.config.edit_distance_weight * edit_score)
283 + (self.config.language_model_weight * lm_score * error_score);
284
285 candidates.push((dict_word.clone(), combined_score));
286 }
287 }
288 }
289
290 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
292
293 candidates.truncate(self.config.max_suggestions);
295
296 candidates
297 }
298
299 pub fn correct_sentence(&self, sentence: &str) -> Result<String> {
301 let words: Vec<String> = sentence
302 .split_whitespace()
303 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()).to_string())
304 .filter(|s| !s.is_empty())
305 .collect();
306
307 if words.is_empty() {
308 return Ok(sentence.to_string());
309 }
310
311 if !self.config.use_context {
313 let mut result = sentence.to_string();
314
315 for word in &words {
316 if !self.is_correct(word) {
317 if let Ok(correction) = self.correct(word) {
318 if correction != *word {
319 result = result.replace(word, &correction);
321 }
322 }
323 }
324 }
325
326 return Ok(result);
327 }
328
329 let context_window = self.config.context_window;
331 let max_candidates = self.config.max_candidates;
332
333 let mut beams: Vec<(Vec<String>, f64, Vec<String>)> = vec![(Vec::new(), 0.0, Vec::new())];
336
337 for word in &words {
339 let mut new_beams = Vec::new();
340
341 for (partial, score, context) in beams {
342 let candidates = self.get_contextual_corrections(word, &context);
344
345 for (candidate, candidate_score) in candidates.iter().take(max_candidates) {
347 let mut new_partial = partial.clone();
348 new_partial.push(candidate.clone());
349
350 let mut new_context = context.clone();
351 new_context.push(candidate.clone());
352 if new_context.len() > context_window {
353 new_context.remove(0);
354 }
355
356 let new_score = score + candidate_score;
357 new_beams.push((new_partial, new_score, new_context));
358 }
359 }
360
361 new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
363 new_beams.truncate(max_candidates);
364
365 beams = new_beams;
366 }
367
368 if let Some((best_sentence, _, _)) = beams.first() {
370 let mut result = sentence.to_string();
372
373 for (i, original) in words.iter().enumerate() {
375 if i < best_sentence.len() && original != &best_sentence[i] {
376 result = result.replace(original, &best_sentence[i]);
377 }
378 }
379
380 Ok(result)
381 } else {
382 Ok(sentence.to_string())
384 }
385 }
386
387 pub fn add_word(&mut self, word: &str, frequency: usize) {
389 self.dictionary.insert(word.to_string(), frequency);
390 }
391
392 pub fn remove_word(&mut self, word: &str) {
394 self.dictionary.remove(word);
395 }
396
397 pub fn dictionary_size(&self) -> usize {
399 self.dictionary.len()
400 }
401
402 pub fn vocabulary_size(&self) -> usize {
404 self.language_model.vocabulary_size()
405 }
406}
407
408impl SpellingCorrector for StatisticalCorrector {
409 fn correct(&self, word: &str) -> Result<String> {
410 if self.is_correct(word) {
412 return Ok(word.to_string());
413 }
414
415 let suggestions = self.get_suggestions(word, 1)?;
417
418 if suggestions.is_empty() {
419 Ok(word.to_string())
421 } else {
422 Ok(suggestions[0].clone())
424 }
425 }
426
427 fn get_suggestions(&self, word: &str, limit: usize) -> Result<Vec<String>> {
428 if self.is_correct(word) {
430 return Ok(vec![word.to_string()]);
431 }
432
433 let candidates = self.get_contextual_corrections(word, &[]);
435
436 let suggestions = candidates
438 .into_iter()
439 .map(|(word, _)| word)
440 .take(limit)
441 .collect();
442
443 Ok(suggestions)
444 }
445
446 fn is_correct(&self, word: &str) -> bool {
447 if self.config.case_sensitive {
448 self.dictionary.contains_key(word)
449 } else {
450 self.dictionary
451 .keys()
452 .any(|dict_word| dict_word.to_lowercase() == word.to_lowercase())
453 }
454 }
455
456 fn correcttext(&self, text: &str) -> Result<String> {
458 let sentences: Vec<&str> = text
460 .split(['.', '?', '!'])
461 .map(|s| s.trim())
462 .filter(|s| !s.is_empty())
463 .collect();
464
465 if sentences.is_empty() {
466 return Ok(text.to_string());
467 }
468
469 let mut result = text.to_string();
470
471 for sentence in sentences {
473 if sentence.trim().is_empty() {
474 continue;
475 }
476
477 let corrected_sentence = self.correct_sentence(sentence)?;
478 if corrected_sentence != sentence {
479 result = result.replace(sentence, &corrected_sentence);
481 }
482 }
483
484 Ok(result)
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds every word in `words` to the corrector's dictionary with a
    /// frequency of 100.
    fn register_words(corrector: &mut StatisticalCorrector, words: &[&str]) {
        for word in words {
            corrector.add_word(word, 100);
        }
    }

    /// Single misspelled words and a whole misspelled text are corrected once
    /// the intended words are known to the dictionary.
    #[test]
    fn test_statistical_corrector_basic() {
        let mut corrector = StatisticalCorrector::default();

        let training = [
            "The quick brown fox jumps over the lazy dog.",
            "Programming languages like Python and Rust are popular.",
            "I received your message about the meeting tomorrow.",
        ];
        for line in training.iter() {
            corrector.add_trainingtext(line);
        }

        let expected_words = ["received", "message", "meeting", "tomorrow"];
        register_words(&mut corrector, &expected_words);

        assert_eq!(corrector.correct("recieved").unwrap(), "received");
        assert_eq!(corrector.correct("mesage").unwrap(), "message");

        let corrected = corrector
            .correcttext("I recieved your mesage about the meating tommorow.")
            .unwrap();

        for &expected in expected_words.iter() {
            assert!(corrected.contains(expected));
        }
    }

    /// The same misspelling is fixed in two sentences that use the word in
    /// different contexts.
    #[test]
    fn test_statistical_corrector_context_aware() {
        let mut corrector = StatisticalCorrector::default();

        let training = [
            "I went to the bank to deposit money.",
            "The river bank was muddy after the rain.",
            "I need to address the issues in the meeting.",
            "What is your home address?",
        ];
        for line in training.iter() {
            corrector.add_trainingtext(line);
        }

        register_words(
            &mut corrector,
            &["bank", "deposit", "money", "river", "muddy", "rain"],
        );

        let misspelled = [
            "I went to the bnk to deposit money.",
            "The river bnk was muddy after the rain.",
        ];
        let corrected1 = corrector.correcttext(misspelled[0]).unwrap();
        let corrected2 = corrector.correcttext(misspelled[1]).unwrap();

        assert!(corrected1.contains("bank"));
        assert!(corrected2.contains("bank"));
    }

    /// A statistical corrector built from a dictionary corrector shares its
    /// dictionary size and agrees with it on what counts as correct.
    #[test]
    fn test_from_dictionary_corrector() {
        let dict_corrector = DictionaryCorrector::default();
        let stat_corrector = StatisticalCorrector::from_dictionary_corrector(&dict_corrector);

        let dict_size = dict_corrector.dictionary_size();
        let stat_size = stat_corrector.dictionary_size();
        assert_eq!(dict_size, stat_size);

        let word = "recieve";
        assert!(dict_corrector.correct(word).is_ok());
        assert!(stat_corrector.correct(word).is_ok());

        let dict_verdict = dict_corrector.is_correct("receive");
        let stat_verdict = stat_corrector.is_correct("receive");
        assert_eq!(dict_verdict, stat_verdict);
    }
}