scirs2_text/spelling/ngram.rs

//! N-gram language model for text processing and spelling correction
//!
//! This module provides an n-gram language model implementation that can be used
//! for context-aware spelling correction, text generation, and other natural language
//! processing tasks.
//!
//! # Key Components
//!
//! - `NGramModel`: A language model that supports unigrams, bigrams, and trigrams
//!
//! # Example
//!
//! ```
//! use scirs2_text::spelling::NGramModel;
//!
//! # fn main() {
//! // Create a new trigram language model
//! let mut model = NGramModel::new(3);
//!
//! // Train the model with some text
//! model.addtext("The quick brown fox jumps over the lazy dog.");
//! model.addtext("Programming languages like Python and Rust are popular.");
//!
//! // Get the probability of a word given its context
//! let context = vec!["quick".to_string(), "brown".to_string()];
//! let prob = model.probability("fox", &context);
//!
//! // Words that appeared in the training text get higher probability
//! assert!(prob > model.probability("cat", &context));
//! # }
//! ```

use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

use crate::error::{Result, TextError};

/// N-gram language model for statistical spelling correction
#[derive(Clone)]
pub struct NGramModel {
    /// Unigram counts
    unigrams: HashMap<String, usize>,
    /// Bigram counts
    bigrams: HashMap<(String, String), usize>,
    /// Trigram counts
    trigrams: HashMap<(String, String, String), usize>,
    /// Total number of words in the training data
    total_words: usize,
    /// Order of the n-gram model
    order: usize,
    /// Start-of-sentence token
    start_token: String,
    /// End-of-sentence token
    end_token: String,
}

impl NGramModel {
    /// Create a new n-gram model with the specified order (clamped to 1..=3)
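    ///
    /// # Examples
    ///
    /// A minimal sketch of the clamping behavior:
    ///
    /// ```
    /// use scirs2_text::spelling::NGramModel;
    ///
    /// let _trigram = NGramModel::new(3);
    /// let _clamped = NGramModel::new(5); // falls back to order 3 with a warning
    /// ```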
    pub fn new(order: usize) -> Self {
        if order > 3 {
            // Warn but limit the order to 3
            eprintln!("Warning: NGramModel only supports orders up to 3. Using order=3.");
        }

        Self {
            unigrams: HashMap::new(),
            bigrams: HashMap::new(),
            trigrams: HashMap::new(),
            total_words: 0,
            order: order.clamp(1, 3),
            start_token: "<s>".to_string(),
            end_token: "</s>".to_string(),
        }
    }

    /// Add text to the language model
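    ///
    /// # Examples
    ///
    /// A small sketch: tokens are lowercased and stripped of surrounding
    /// punctuation before counting.
    ///
    /// ```
    /// use scirs2_text::spelling::NGramModel;
    ///
    /// let mut model = NGramModel::new(2);
    /// model.addtext("Dogs bark. Dogs sleep.");
    /// assert_eq!(model.word_frequency("dogs"), 2);
    /// assert_eq!(model.total_words(), 4);
    /// ```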
    pub fn addtext(&mut self, text: &str) {
        // Lowercase tokens but keep punctuation for now, so sentence
        // boundaries can still be detected below.
        let words: Vec<String> = text
            .split_whitespace()
            .map(|s| s.to_lowercase())
            .collect();

        if words.is_empty() {
            return;
        }

        // Process sentences (separated by terminal punctuation)
        let mut current_sentence = Vec::new();

        for word in words {
            // Check if this token ends a sentence
            let is_end = word.ends_with('.') || word.ends_with('?') || word.ends_with('!');

            // Strip surrounding punctuation and add the word to the current sentence
            let clean_word = word
                .trim_matches(|c: char| !c.is_alphanumeric())
                .to_string();
            if !clean_word.is_empty() {
                current_sentence.push(clean_word);
                self.total_words += 1;
            }

            // Process the sentence if we're at its end
            if is_end && !current_sentence.is_empty() {
                self.process_sentence(&current_sentence);
                current_sentence.clear();
            }
        }

        // Process any remaining words as a final sentence
        if !current_sentence.is_empty() {
            self.process_sentence(&current_sentence);
        }
    }

    /// Process a single sentence and add its n-grams to the language model
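    ///
    /// For example, the sentence `["the", "dog"]` is padded to
    /// `["<s>", "the", "dog", "</s>"]`, yielding the bigrams
    /// `(<s>, the)`, `(the, dog)`, and `(dog, </s>)`.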
    fn process_sentence(&mut self, sentence: &[String]) {
        // Add start and end tokens
        let mut words = Vec::with_capacity(sentence.len() + 2);
        words.push(self.start_token.clone());
        words.extend(sentence.iter().cloned());
        words.push(self.end_token.clone());

        // Update unigram counts
        for word in &words {
            *self.unigrams.entry(word.clone()).or_insert(0) += 1;
        }

        // Update bigram counts if order >= 2
        if self.order >= 2 {
            for i in 0..words.len() - 1 {
                let bigram = (words[i].clone(), words[i + 1].clone());
                *self.bigrams.entry(bigram).or_insert(0) += 1;
            }
        }

        // Update trigram counts if order >= 3
        if self.order >= 3 {
            for i in 0..words.len() - 2 {
                let trigram = (words[i].clone(), words[i + 1].clone(), words[i + 2].clone());
                *self.trigrams.entry(trigram).or_insert(0) += 1;
            }
        }
    }

    /// Add a corpus file to the language model
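    ///
    /// # Examples
    ///
    /// A sketch assuming a plain-text corpus at a hypothetical path
    /// `corpus.txt`, with one or more sentences per line:
    ///
    /// ```no_run
    /// use scirs2_text::spelling::NGramModel;
    ///
    /// let mut model = NGramModel::new(3);
    /// // Errors (e.g. a missing file) are reported through the crate's Result type
    /// let _ = model.add_corpus_file("corpus.txt");
    /// ```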
    pub fn add_corpus_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        let file = File::open(path)
            .map_err(|e| TextError::IoError(format!("Failed to open corpus file: {e}")))?;

        let reader = BufReader::new(file);

        for line in reader.lines() {
            let line = line.map_err(|e| {
                TextError::IoError(format!("Failed to read line from corpus file: {e}"))
            })?;

            // Skip empty lines
            if line.trim().is_empty() {
                continue;
            }

            self.addtext(&line);
        }

        Ok(())
    }

    /// Estimate the probability of a word given its context
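    ///
    /// # Examples
    ///
    /// A minimal sketch comparing an observed continuation with an unobserved
    /// one:
    ///
    /// ```
    /// use scirs2_text::spelling::NGramModel;
    ///
    /// let mut model = NGramModel::new(3);
    /// model.addtext("The quick brown fox jumps over the lazy dog.");
    ///
    /// let context = vec!["quick".to_string(), "brown".to_string()];
    /// assert!(model.probability("fox", &context) > model.probability("cat", &context));
    /// ```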
    pub fn probability(&self, word: &str, context: &[String]) -> f64 {
        match self.order {
            1 => self.unigram_probability(word),
            2 => self.bigram_probability(word, context),
            3 => self.trigram_probability(word, context),
            _ => self.unigram_probability(word), // Unreachable: order is clamped to 1..=3
        }
    }

    /// Calculate unigram probability P(word)
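    ///
    /// With vocabulary size `V` and `N` total training words, the smoothed
    /// estimate is `P(w) = (count(w) + 1) / (N + V)`, so unseen words still
    /// receive a small non-zero probability.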
    pub fn unigram_probability(&self, word: &str) -> f64 {
        let word_count = self.unigrams.get(word).copied().unwrap_or(0);

        // Add-one smoothing (Laplace smoothing)
        let vocabulary_size = self.unigrams.len();
        (word_count as f64 + 1.0) / (self.total_words as f64 + vocabulary_size as f64)
    }

    /// Calculate bigram probability P(word | previous)
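    ///
    /// The smoothed estimate is
    /// `P(w | prev) = (count(prev, w) + 1) / (count(prev) + V)`.
    /// If `prev` was never observed, the model backs off to the unigram
    /// estimate.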
    pub fn bigram_probability(&self, word: &str, context: &[String]) -> f64 {
        if context.is_empty() {
            return self.unigram_probability(word);
        }

        let previous = &context[context.len() - 1];

        let bigram_count = self
            .bigrams
            .get(&(previous.clone(), word.to_string()))
            .copied()
            .unwrap_or(0);

        let previous_count = self.unigrams.get(previous).copied().unwrap_or(0);

        if previous_count == 0 {
            return self.unigram_probability(word);
        }

        // Add-one smoothing
        let vocabulary_size = self.unigrams.len();
        (bigram_count as f64 + 1.0) / (previous_count as f64 + vocabulary_size as f64)
    }

    /// Calculate trigram probability P(word | previous1, previous2)
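    ///
    /// The smoothed estimate is
    /// `P(w | p1, p2) = (count(p1, p2, w) + 1) / (count(p1, p2) + V)`.
    /// If the context bigram `(p1, p2)` was never observed, the model backs
    /// off to the bigram estimate conditioned on `p2` alone.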
    pub fn trigram_probability(&self, word: &str, context: &[String]) -> f64 {
        if context.len() < 2 {
            return self.bigram_probability(word, context);
        }

        let previous1 = &context[context.len() - 2];
        let previous2 = &context[context.len() - 1];

        let trigram_count = self
            .trigrams
            .get(&(previous1.clone(), previous2.clone(), word.to_string()))
            .copied()
            .unwrap_or(0);

        let bigram_count = self
            .bigrams
            .get(&(previous1.clone(), previous2.clone()))
            .copied()
            .unwrap_or(0);

        if bigram_count == 0 {
            return self.bigram_probability(word, &[previous2.clone()]);
        }

        // Add-one smoothing
        let vocabulary_size = self.unigrams.len();
        (trigram_count as f64 + 1.0) / (bigram_count as f64 + vocabulary_size as f64)
    }

    /// Calculate perplexity on a test text
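    ///
    /// Perplexity is computed as `2^(-(1/N) * sum(log2 P(w_i | context_i)))`
    /// over the `N` test words; lower values mean the model fits the text
    /// better.
    ///
    /// # Examples
    ///
    /// A minimal sketch: text seen during training should score a lower
    /// perplexity than unrelated text.
    ///
    /// ```
    /// use scirs2_text::spelling::NGramModel;
    ///
    /// let mut model = NGramModel::new(2);
    /// model.addtext("the quick brown fox.");
    ///
    /// let seen = model.perplexity("the quick brown fox");
    /// let unseen = model.perplexity("purple elephants fly backwards");
    /// assert!(seen < unseen);
    /// ```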
    pub fn perplexity(&self, text: &str) -> f64 {
        let words: Vec<String> = text
            .split_whitespace()
            .map(|s| {
                s.trim_matches(|c: char| !c.is_alphanumeric())
                    .to_lowercase()
            })
            .filter(|s| !s.is_empty())
            .collect();

        if words.is_empty() {
            return f64::INFINITY;
        }

        let mut log_prob_sum = 0.0;
        let mut context = Vec::new();

        for word in words.iter() {
            let prob = self.probability(word, &context);
            log_prob_sum += (prob + 1e-10).log2(); // Add a small epsilon to avoid log(0)

            // Update the context window for the next word, keeping only the
            // words the model can actually condition on
            context.push(word.clone());
            if context.len() >= self.order {
                context.remove(0);
            }
        }

        // Perplexity = 2^(-average log probability)
        2.0f64.powf(-log_prob_sum / words.len() as f64)
    }

    /// Get the vocabulary size
    pub fn vocabulary_size(&self) -> usize {
        self.unigrams.len()
    }

    /// Get the total number of words processed
    pub fn total_words(&self) -> usize {
        self.total_words
    }

    /// Get the frequency of a word
    pub fn word_frequency(&self, word: &str) -> usize {
        self.unigrams.get(word).copied().unwrap_or(0)
    }

    /// Generate potential single-edit typos for a word, ranked by their frequency in the model
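    ///
    /// # Examples
    ///
    /// A minimal sketch: candidates are ranked by their frequency in the
    /// trained model, so known words surface first.
    ///
    /// ```
    /// use scirs2_text::spelling::NGramModel;
    ///
    /// let mut model = NGramModel::new(1);
    /// model.addtext("the cat sat on the mat.");
    ///
    /// let candidates = model.generate_typos("cat", 5);
    /// assert_eq!(candidates.len(), 5);
    /// assert!(candidates.contains(&"mat".to_string()));
    /// ```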
    pub fn generate_typos(&self, word: &str, numtypos: usize) -> Vec<String> {
        let mut typos = HashSet::new();
        let word = word.to_lowercase();
        let chars: Vec<char> = word.chars().collect();

        // Deletion errors (removing one character)
        for i in 0..chars.len() {
            let mut new_word = String::new();
            for (j, &c) in chars.iter().enumerate() {
                if j != i {
                    new_word.push(c);
                }
            }
            typos.insert(new_word);
        }

        // Transposition errors (swapping adjacent characters);
        // saturating_sub avoids underflow for empty input
        for i in 0..chars.len().saturating_sub(1) {
            let mut new_chars = chars.clone();
            new_chars.swap(i, i + 1);
            typos.insert(new_chars.iter().collect());
        }

        // Insertion errors (adding one character)
        for i in 0..=chars.len() {
            for c in 'a'..='z' {
                let mut new_chars = chars.clone();
                new_chars.insert(i, c);
                typos.insert(new_chars.iter().collect());
            }
        }

        // Replacement errors (changing one character)
        for i in 0..chars.len() {
            for c in 'a'..='z' {
                if chars[i] != c {
                    let mut new_chars = chars.clone();
                    new_chars[i] = c;
                    typos.insert(new_chars.iter().collect());
                }
            }
        }

        // Convert to a Vec and rank by frequency
        let mut typos_vec: Vec<_> = typos.into_iter().collect();

        // Sort by word frequency in the model, most frequent first
        typos_vec.sort_by(|a, b| {
            let freq_a = self.word_frequency(a);
            let freq_b = self.word_frequency(b);
            freq_b.cmp(&freq_a)
        });

        // Limit to the requested number of candidates
        typos_vec.truncate(numtypos);

        typos_vec
    }
}

impl std::fmt::Debug for NGramModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NGramModel")
            .field("order", &self.order)
            .field("vocabulary_size", &self.vocabulary_size())
            .field("total_words", &self.total_words)
            .field("unigrams", &format!("<{} entries>", self.unigrams.len()))
            .field("bigrams", &format!("<{} entries>", self.bigrams.len()))
            .field("trigrams", &format!("<{} entries>", self.trigrams.len()))
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ngram_model_basics() {
        let mut model = NGramModel::new(3);

        // Add some training data
        model.addtext("The quick brown fox jumps over the lazy dog.");

        // Test unigram probabilities
        let p_the = model.unigram_probability("the");
        let p_quick = model.unigram_probability("quick");
        let p_unknown = model.unigram_probability("unknown");

        // "the" should be more frequent than "quick"
        assert!(p_the > p_quick);

        // Unknown words should have non-zero probability due to smoothing
        assert!(p_unknown > 0.0);

        // Test bigram probabilities
        let p_quick_given_the = model.bigram_probability("quick", &["the".to_string()]);
        let p_brown_given_quick = model.bigram_probability("brown", &["quick".to_string()]);

        // These specific bigrams should exist in the training data
        assert!(p_quick_given_the > 0.0);
        assert!(p_brown_given_quick > 0.0);

        // Test the trigram model
        let p_fox_given_quick_brown =
            model.trigram_probability("fox", &["quick".to_string(), "brown".to_string()]);

        // This specific trigram should exist in the training data
        assert!(p_fox_given_quick_brown > 0.0);
    }
417}