scirs2_text/spelling/statistical.rs

//! Statistical spelling correction using language models and error models
//!
//! This module provides statistical spelling correction that combines dictionary-based
//! approaches with language models and error models for context-aware corrections.
//!
//! # Key Components
//!
//! - `StatisticalCorrector`: Main implementation of statistical spelling correction
//! - `StatisticalCorrectorConfig`: Configuration options for the statistical corrector
//!
//! # Example
//!
//! ```
//! use scirs2_text::spelling::{StatisticalCorrector, SpellingCorrector};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a statistical spelling corrector
//! let mut corrector = StatisticalCorrector::default();
//!
//! // Directly add the words we want to test with
//! corrector.add_word("received", 1000);
//! corrector.add_word("message", 1000);
//! corrector.add_word("meeting", 1000);
//!
//! // Correct individual misspelled words
//! assert_eq!(corrector.correct("recieved")?, "received");
//! assert_eq!(corrector.correct("mesage")?, "message");
//!
//! // For text correction, just verify it runs without errors
//! let text = "I recieved your mesage about the meeting.";
//! let corrected = corrector.correcttext(text)?;
//! assert!(!corrected.is_empty());
//! # Ok(())
//! # }
//! ```

use crate::error::Result;
use crate::string_metrics::{DamerauLevenshteinMetric, StringMetric};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

use super::dictionary::DictionaryCorrector;
use super::error_model::ErrorModel;
use super::ngram::NGramModel;
use super::SpellingCorrector;

/// Configuration for the statistical spelling corrector
#[derive(Debug, Clone)]
pub struct StatisticalCorrectorConfig {
    /// Maximum edit distance to consider for corrections
    pub max_edit_distance: usize,
    /// Whether to use case-sensitive matching
    pub case_sensitive: bool,
    /// Maximum number of suggestions to consider
    pub max_suggestions: usize,
    /// Minimum word frequency to consider for suggestions
    pub min_frequency: usize,
    /// N-gram order for language model (1, 2, or 3)
    pub ngram_order: usize,
    /// Weighting factor for language model scores (0.0-1.0)
    pub language_model_weight: f64,
    /// Weighting factor for edit distance scores (0.0-1.0)
    pub edit_distance_weight: f64,
    /// Whether to use contextual information for correction
    pub use_context: bool,
    /// Context window size (in words) for contextual correction
    pub context_window: usize,
    /// Maximum number of candidate words to consider for each position
    pub max_candidates: usize,
}

impl Default for StatisticalCorrectorConfig {
    fn default() -> Self {
        Self {
            max_edit_distance: 2,
            case_sensitive: false,
            max_suggestions: 5,
            min_frequency: 1,
            ngram_order: 3,
            language_model_weight: 0.7,
            edit_distance_weight: 0.3,
            use_context: true,
            context_window: 2,
            max_candidates: 5,
        }
    }
}

/// Statistical spelling corrector
pub struct StatisticalCorrector {
    /// Dictionary of words and their frequencies
    dictionary: HashMap<String, usize>,
    /// Configuration for the corrector
    config: StatisticalCorrectorConfig,
    /// Metric to use for string similarity
    metric: Arc<dyn StringMetric + Send + Sync>,
    /// Language model for context-aware correction
    language_model: NGramModel,
    /// Error model for the noisy channel model
    error_model: ErrorModel,
}

impl Clone for StatisticalCorrector {
    fn clone(&self) -> Self {
        Self {
            dictionary: self.dictionary.clone(),
            config: self.config.clone(),
            metric: self.metric.clone(),
            language_model: self.language_model.clone(),
            error_model: self.error_model.clone(),
        }
    }
}

impl std::fmt::Debug for StatisticalCorrector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StatisticalCorrector")
            .field("dictionary", &{
                let dict_len = self.dictionary.len();
                format!("<{dict_len} words>")
            })
            .field("config", &self.config)
            .field("metric", &"<StringMetric>")
            .field("language_model", &self.language_model)
            .field("error_model", &self.error_model)
            .finish()
    }
}

impl Default for StatisticalCorrector {
    fn default() -> Self {
        // Start with a base dictionary corrector
        let dict_corrector = DictionaryCorrector::default();

        // Create the language model
        let mut language_model = NGramModel::new(3);

        // Add some sample text to bootstrap the language model
        let sampletexts = [
            "The quick brown fox jumps over the lazy dog.",
            "She sells seashells by the seashore.",
            "How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
            "To be or not to be, that is the question.",
            "Four score and seven years ago our fathers brought forth on this continent a new nation.",
            "Ask not what your country can do for you, ask what you can do for your country.",
            "That's one small step for man, one giant leap for mankind.",
            "I have a dream that one day this nation will rise up and live out the true meaning of its creed.",
            "The only thing we have to fear is fear itself.",
            "We hold these truths to be self-evident, that all men are created equal.",
            // Add more sample texts to improve the language model
        ];

        for text in &sampletexts {
            language_model.addtext(text);
        }

        Self {
            dictionary: dict_corrector.dictionary,
            config: StatisticalCorrectorConfig::default(),
            metric: Arc::new(DamerauLevenshteinMetric::new()),
            language_model,
            error_model: ErrorModel::default(),
        }
    }
}

impl StatisticalCorrector {
    /// Create a new statistical spelling corrector with the given configuration
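    ///
    /// # Example
    ///
    /// A minimal sketch of building a corrector from a custom configuration; it
    /// assumes `StatisticalCorrectorConfig` is re-exported from
    /// `scirs2_text::spelling` alongside `StatisticalCorrector`.
    ///
    /// ```
    /// use scirs2_text::spelling::{SpellingCorrector, StatisticalCorrector, StatisticalCorrectorConfig};
    ///
    /// let config = StatisticalCorrectorConfig {
    ///     max_edit_distance: 1,
    ///     use_context: false,
    ///     ..Default::default()
    /// };
    /// let corrector = StatisticalCorrector::new(config);
    /// assert!(corrector.correct("hello").is_ok());
    /// ```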
    pub fn new(config: StatisticalCorrectorConfig) -> Self {
        Self {
            config,
            ..Default::default()
        }
    }

    /// Create a statistical corrector from a base dictionary corrector
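    ///
    /// # Example
    ///
    /// A brief sketch that reuses an existing dictionary; it assumes
    /// `DictionaryCorrector` is re-exported from `scirs2_text::spelling`.
    ///
    /// ```
    /// use scirs2_text::spelling::{DictionaryCorrector, SpellingCorrector, StatisticalCorrector};
    ///
    /// let dict = DictionaryCorrector::default();
    /// let corrector = StatisticalCorrector::from_dictionary_corrector(&dict);
    ///
    /// // The statistical corrector starts from the same dictionary entries.
    /// assert!(corrector.correct("recieve").is_ok());
    /// ```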
    pub fn from_dictionary_corrector(dictcorrector: &DictionaryCorrector) -> Self {
        let config = StatisticalCorrectorConfig {
            max_edit_distance: dictcorrector.config.max_edit_distance,
            case_sensitive: dictcorrector.config.case_sensitive,
            max_suggestions: dictcorrector.config.max_suggestions,
            min_frequency: dictcorrector.config.min_frequency,
            ..StatisticalCorrectorConfig::default()
        };

        Self {
            dictionary: dictcorrector.dictionary.clone(),
            config,
            metric: dictcorrector.metric.clone(),
            language_model: NGramModel::new(3),
            error_model: ErrorModel::default(),
        }
    }

    /// Add a corpus file to train the language model
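    ///
    /// # Example
    ///
    /// A sketch of training the language model from a plain-text corpus file;
    /// the path shown is hypothetical, so the example is not executed.
    ///
    /// ```no_run
    /// use scirs2_text::spelling::StatisticalCorrector;
    ///
    /// let mut corrector = StatisticalCorrector::default();
    /// corrector
    ///     .add_corpus_file("path/to/corpus.txt")
    ///     .expect("corpus file should be readable");
    /// ```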
    pub fn add_corpus_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        self.language_model.add_corpus_file(path)
    }

    /// Add text to train the language model
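    ///
    /// # Example
    ///
    /// A small sketch showing how raw text feeds the n-gram language model.
    ///
    /// ```
    /// use scirs2_text::spelling::StatisticalCorrector;
    ///
    /// let mut corrector = StatisticalCorrector::default();
    /// corrector.add_trainingtext("I received your message about the meeting tomorrow.");
    /// assert!(corrector.vocabulary_size() > 0);
    /// ```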
    pub fn add_trainingtext(&mut self, text: &str) {
        self.language_model.addtext(text);
    }

    /// Set the language model
    pub fn set_language_model(&mut self, model: NGramModel) {
        self.language_model = model;
    }

    /// Set the error model
    pub fn set_error_model(&mut self, model: ErrorModel) {
        self.error_model = model;
    }

    /// Set the string metric to use for similarity calculations
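    ///
    /// # Example
    ///
    /// A sketch swapping in a different similarity metric; it assumes
    /// `DamerauLevenshteinMetric` is publicly available from
    /// `scirs2_text::string_metrics`.
    ///
    /// ```
    /// use scirs2_text::spelling::StatisticalCorrector;
    /// use scirs2_text::string_metrics::DamerauLevenshteinMetric;
    ///
    /// let mut corrector = StatisticalCorrector::default();
    /// corrector.set_metric(DamerauLevenshteinMetric::new());
    /// ```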
    pub fn set_metric<M: StringMetric + Send + Sync + 'static>(&mut self, metric: M) {
        self.metric = Arc::new(metric);
    }

    /// Set the configuration
    pub fn set_config(&mut self, config: StatisticalCorrectorConfig) {
        self.config = config;
    }

    /// Get possible corrections for a word given its context
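    ///
    /// Candidates within `max_edit_distance` of the input are ranked by a
    /// combined score: `edit_distance_weight * 1 / (1 + distance)` plus
    /// `language_model_weight` times the language-model probability of the
    /// candidate scaled by the error-model probability of the observed typo.
    /// Higher scores rank first, and at most `max_suggestions` candidates are kept.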
    fn get_contextual_corrections(&self, word: &str, context: &[String]) -> Vec<(String, f64)> {
        // If the word is correct, return it with high probability
        if self.is_correct(word) {
            return vec![(word.to_string(), 1.0)];
        }

        // Get edit-distance based candidates
        let word_to_check = if !self.config.case_sensitive {
            word.to_lowercase()
        } else {
            word.to_string()
        };

        let mut candidates: Vec<(String, f64)> = Vec::new();

        // Calculate candidates based on edit distance
        for (dict_word, frequency) in &self.dictionary {
            if *frequency < self.config.min_frequency {
                continue;
            }

            let dict_word_normalized = if !self.config.case_sensitive {
                dict_word.to_lowercase()
            } else {
                dict_word.clone()
            };

            // Skip words that are too different in length
            if dict_word_normalized.len() > word_to_check.len() + self.config.max_edit_distance
                || dict_word_normalized.len() + self.config.max_edit_distance < word_to_check.len()
            {
                continue;
            }

            // Calculate edit distance
            if let Ok(distance) = self.metric.distance(&word_to_check, &dict_word_normalized) {
                // Convert to usize and check if it's within the threshold
                let distance_usize = distance.round() as usize;
                if distance_usize <= self.config.max_edit_distance {
                    // Edit-distance score: 1 / (1 + distance), so closer matches score higher
                    let edit_score = 1.0 / (1.0 + distance);

                    // Language model score
                    let lm_score = if self.config.use_context {
                        self.language_model.probability(dict_word, context)
                    } else {
                        self.language_model.unigram_probability(dict_word)
                    };

                    // Error model score
                    let error_score = self
                        .error_model
                        .error_probability(&word_to_check, &dict_word_normalized);

                    // Combined score: weighted edit-distance score plus weighted noisy-channel term
                    let combined_score = (self.config.edit_distance_weight * edit_score)
                        + (self.config.language_model_weight * lm_score * error_score);

                    candidates.push((dict_word.clone(), combined_score));
                }
            }
        }

        // Sort by combined score (higher is better)
        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Limit to max_suggestions
        candidates.truncate(self.config.max_suggestions);

        candidates
    }

    /// Correct a sentence using a context-aware approach
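    ///
    /// # Example
    ///
    /// A minimal sketch; the added word keeps the output predictable, but the
    /// exact correction still depends on the bundled dictionary and models.
    ///
    /// ```
    /// use scirs2_text::spelling::StatisticalCorrector;
    ///
    /// let mut corrector = StatisticalCorrector::default();
    /// corrector.add_word("received", 1000);
    ///
    /// let corrected = corrector.correct_sentence("I recieved your note").unwrap();
    /// assert!(!corrected.is_empty());
    /// ```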
    pub fn correct_sentence(&self, sentence: &str) -> Result<String> {
        let words: Vec<String> = sentence
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()).to_string())
            .filter(|s| !s.is_empty())
            .collect();

        if words.is_empty() {
            return Ok(sentence.to_string());
        }

        // If we're not using context, correct each word individually
        if !self.config.use_context {
            let mut result = sentence.to_string();

            for word in &words {
                if !self.is_correct(word) {
                    if let Ok(correction) = self.correct(word) {
                        if correction != *word {
                            // Replace the word in the result
                            result = result.replace(word, &correction);
                        }
                    }
                }
            }

            return Ok(result);
        }

        // Context-aware correction using beam search
        let context_window = self.config.context_window;
        let max_candidates = self.config.max_candidates;

        // Initialize beam search
        // Each beam state contains (partial sentence, score, context)
        let mut beams: Vec<(Vec<String>, f64, Vec<String>)> = vec![(Vec::new(), 0.0, Vec::new())];

        // Process each word
        for word in &words {
            let mut new_beams = Vec::new();

            for (partial, score, context) in beams {
                // Get correction candidates for this word
                let candidates = self.get_contextual_corrections(word, &context);

                // Add each candidate to create new beams
                for (candidate, candidate_score) in candidates.iter().take(max_candidates) {
                    let mut new_partial = partial.clone();
                    new_partial.push(candidate.clone());

                    let mut new_context = context.clone();
                    new_context.push(candidate.clone());
                    if new_context.len() > context_window {
                        new_context.remove(0);
                    }

                    let new_score = score + candidate_score;
                    new_beams.push((new_partial, new_score, new_context));
                }
            }

            // Prune beams to keep only the best ones
            new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            new_beams.truncate(max_candidates);

            beams = new_beams;
        }

        // Get the best beam
        if let Some((best_sentence, _, _)) = beams.first() {
            // Reconstruct the sentence
            let mut result = sentence.to_string();

            // Replace each word with its correction
            for (i, original) in words.iter().enumerate() {
                if i < best_sentence.len() && original != &best_sentence[i] {
                    result = result.replace(original, &best_sentence[i]);
                }
            }

            Ok(result)
        } else {
            // Fallback to the original sentence
            Ok(sentence.to_string())
        }
    }

    /// Add a word to the dictionary
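    ///
    /// # Example
    ///
    /// A quick sketch; "scirs" is a made-up word used only for illustration.
    ///
    /// ```
    /// use scirs2_text::spelling::{SpellingCorrector, StatisticalCorrector};
    ///
    /// let mut corrector = StatisticalCorrector::default();
    /// corrector.add_word("scirs", 10);
    /// assert!(corrector.is_correct("scirs"));
    ///
    /// corrector.remove_word("scirs");
    /// assert!(!corrector.is_correct("scirs"));
    /// ```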
    pub fn add_word(&mut self, word: &str, frequency: usize) {
        self.dictionary.insert(word.to_string(), frequency);
    }

    /// Remove a word from the dictionary
    pub fn remove_word(&mut self, word: &str) {
        self.dictionary.remove(word);
    }

    /// Get the total number of words in the dictionary
    pub fn dictionary_size(&self) -> usize {
        self.dictionary.len()
    }

    /// Get the vocabulary size of the language model
    pub fn vocabulary_size(&self) -> usize {
        self.language_model.vocabulary_size()
    }
}

impl SpellingCorrector for StatisticalCorrector {
    fn correct(&self, word: &str) -> Result<String> {
        // If the word is already correct, return it as is
        if self.is_correct(word) {
            return Ok(word.to_string());
        }

        // Get suggestions and return the best one
        let suggestions = self.get_suggestions(word, 1)?;

        if suggestions.is_empty() {
            // No suggestions found, return the original word
            Ok(word.to_string())
        } else {
            // Return the best suggestion
            Ok(suggestions[0].clone())
        }
    }

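    /// Returns ranked correction candidates for a (possibly misspelled) word.
    ///
    /// A small sketch mirroring the module-level example:
    ///
    /// ```
    /// use scirs2_text::spelling::{SpellingCorrector, StatisticalCorrector};
    ///
    /// let mut corrector = StatisticalCorrector::default();
    /// corrector.add_word("received", 1000);
    ///
    /// let suggestions = corrector.get_suggestions("recieved", 3).unwrap();
    /// assert!(suggestions.len() <= 3);
    /// assert!(suggestions.contains(&"received".to_string()));
    /// ```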
    fn get_suggestions(&self, word: &str, limit: usize) -> Result<Vec<String>> {
        // If the word is already correct, return it as the only suggestion
        if self.is_correct(word) {
            return Ok(vec![word.to_string()]);
        }

        // Get contextual corrections (with empty context)
        let candidates = self.get_contextual_corrections(word, &[]);

        // Extract just the words
        let suggestions = candidates
            .into_iter()
            .map(|(word, _)| word)
            .take(limit)
            .collect();

        Ok(suggestions)
    }

    fn is_correct(&self, word: &str) -> bool {
        if self.config.case_sensitive {
            self.dictionary.contains_key(word)
        } else {
            self.dictionary
                .keys()
                .any(|dict_word| dict_word.to_lowercase() == word.to_lowercase())
        }
    }

    // Override the default implementation for more context-aware correction
    fn correcttext(&self, text: &str) -> Result<String> {
        // Split the text into sentences
        let sentences: Vec<&str> = text
            .split(['.', '?', '!'])
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .collect();

        if sentences.is_empty() {
            return Ok(text.to_string());
        }

        let mut result = text.to_string();

        // Process each sentence
        for sentence in sentences {
            if sentence.trim().is_empty() {
                continue;
            }

            let corrected_sentence = self.correct_sentence(sentence)?;
            if corrected_sentence != sentence {
                // Replace the sentence in the text
                result = result.replace(sentence, &corrected_sentence);
            }
        }

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_statistical_corrector_basic() {
        let mut corrector = StatisticalCorrector::default();

        // Add some training text to improve the language model
        corrector.add_trainingtext("The quick brown fox jumps over the lazy dog.");
        corrector.add_trainingtext("Programming languages like Python and Rust are popular.");
        corrector.add_trainingtext("I received your message about the meeting tomorrow.");

        // Add specific words to ensure consistent behavior in tests
        corrector.add_word("received", 100);
        corrector.add_word("message", 100);
        corrector.add_word("meeting", 100);
        corrector.add_word("tomorrow", 100);

        // Test basic word correction
        assert_eq!(corrector.correct("recieved").unwrap(), "received");
        assert_eq!(corrector.correct("mesage").unwrap(), "message");

        // Test text correction
        let text = "I recieved your mesage about the meating tommorow.";
        let corrected = corrector.correcttext(text).unwrap();

        // Check each word individually
        assert!(corrected.contains("received"));
        assert!(corrected.contains("message"));
        assert!(corrected.contains("meeting"));
        assert!(corrected.contains("tomorrow"));
    }

    #[test]
    fn test_statistical_corrector_context_aware() {
        let mut corrector = StatisticalCorrector::default();

        // Add training text for context
        corrector.add_trainingtext("I went to the bank to deposit money.");
        corrector.add_trainingtext("The river bank was muddy after the rain.");
        corrector.add_trainingtext("I need to address the issues in the meeting.");
        corrector.add_trainingtext("What is your home address?");

        // Add explicit words for consistent testing
        corrector.add_word("bank", 100);
        corrector.add_word("deposit", 100);
        corrector.add_word("money", 100);
        corrector.add_word("river", 100);
        corrector.add_word("muddy", 100);
        corrector.add_word("rain", 100);

        // Test context-aware correction
        let text1 = "I went to the bnk to deposit money.";
        let text2 = "The river bnk was muddy after the rain.";

        let corrected1 = corrector.correcttext(text1).unwrap();
        let corrected2 = corrector.correcttext(text2).unwrap();

        // Both should correct "bnk" to "bank" regardless of context
        assert!(corrected1.contains("bank"));
        assert!(corrected2.contains("bank"));
    }

    #[test]
    fn test_from_dictionary_corrector() {
        let dict_corrector = DictionaryCorrector::default();
        let stat_corrector = StatisticalCorrector::from_dictionary_corrector(&dict_corrector);

        // Both correctors should have the same dictionary
        assert_eq!(
            dict_corrector.dictionary_size(),
            stat_corrector.dictionary_size()
        );

        // Since the StatisticalCorrector uses context and language models, it may
        // produce different corrections than the DictionaryCorrector.
        // For this test, just verify that both correctors can handle the input.
        let word = "recieve";
        assert!(dict_corrector.correct(word).is_ok());
        assert!(stat_corrector.correct(word).is_ok());

        // Test is_correct works the same for both
        assert_eq!(
            dict_corrector.is_correct("receive"),
            stat_corrector.is_correct("receive")
        );
    }
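
    // An extra sanity check (not part of the original suite): the suggestion
    // limit passed to `get_suggestions` is respected, and a word added to the
    // dictionary shows up among the candidates for a close misspelling.
    #[test]
    fn test_get_suggestions_limit() {
        let mut corrector = StatisticalCorrector::default();
        corrector.add_word("received", 1000);

        let suggestions = corrector.get_suggestions("recieved", 2).unwrap();
        assert!(suggestions.len() <= 2);
        assert!(suggestions.contains(&"received".to_string()));
    }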
}