Skip to main content

scirs2_text/language_models/
mod.rs

//! Statistical language models: Unigram, Bigram, N-gram, and perplexity evaluation.
//!
//! This module provides from-scratch implementations of classical statistical
//! language models:
//!
//! - `UnigramLM` – unsmoothed maximum-likelihood unigram model
//! - `BigramLM` – bigram model with Laplace (add-one) smoothing
//! - `NgramLM` – arbitrary-order model with Kneser-Ney-style smoothing
//!   (absolute discounting interpolated with the continuation unigram)
//! - `PerplexityEval` – perplexity computation for any `NgramLM`

11use std::collections::{HashMap, HashSet};
12
13use crate::error::{Result, TextError};
14
15// ---------------------------------------------------------------------------
16// Tokenisation helper
17// ---------------------------------------------------------------------------
18
/// Break `text` into lowercase tokens consisting only of alphabetic characters.
///
/// Every non-alphabetic character (whitespace, digits, punctuation) acts as a
/// separator; empty fragments produced by adjacent separators are discarded.
fn simple_tokenize(text: &str) -> Vec<String> {
    let is_separator = |ch: char| !ch.is_alphabetic();
    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
26
27// ---------------------------------------------------------------------------
28// UnigramLM
29// ---------------------------------------------------------------------------
30
31/// A maximum-likelihood unigram language model.
32#[derive(Debug, Clone)]
33pub struct UnigramLM {
34    /// Word → probability (MLE, no smoothing).
35    pub probs: HashMap<String, f64>,
36    /// Known vocabulary.
37    pub vocab: HashSet<String>,
38}
39
40impl UnigramLM {
41    /// Train a unigram model from a corpus of sentences.
42    pub fn train(sentences: &[Vec<String>]) -> Result<UnigramLM> {
43        let mut counts: HashMap<String, usize> = HashMap::new();
44        let mut total = 0usize;
45        for sent in sentences {
46            for w in sent {
47                *counts.entry(w.clone()).or_insert(0) += 1;
48                total += 1;
49            }
50        }
51        if total == 0 {
52            return Err(TextError::InvalidInput("Empty corpus".to_string()));
53        }
54        let vocab: HashSet<String> = counts.keys().cloned().collect();
55        let probs = counts
56            .into_iter()
57            .map(|(k, c)| (k, c as f64 / total as f64))
58            .collect();
59        Ok(UnigramLM { probs, vocab })
60    }
61
62    /// Probability of a word.
63    ///
64    /// Returns `0.0` for out-of-vocabulary words.
65    pub fn probability(&self, word: &str) -> f64 {
66        self.probs.get(word).copied().unwrap_or(0.0)
67    }
68
69    /// Log-probability of a word (returns `f64::NEG_INFINITY` for OOV).
70    pub fn log_probability(&self, word: &str) -> f64 {
71        let p = self.probability(word);
72        if p <= 0.0 {
73            f64::NEG_INFINITY
74        } else {
75            p.ln()
76        }
77    }
78}
79
80// ---------------------------------------------------------------------------
81// BigramLM
82// ---------------------------------------------------------------------------
83
84/// A bigram language model with Laplace smoothing.
85#[derive(Debug, Clone)]
86pub struct BigramLM {
87    /// (prev, curr) → P(curr | prev)
88    pub probs: HashMap<(String, String), f64>,
89    /// Unigram probabilities (smoothed) for back-off.
90    pub unigrams: HashMap<String, f64>,
91    /// Vocabulary size used for Laplace denominator.
92    vocab_size: usize,
93}
94
95impl BigramLM {
96    /// Train a bigram model from a corpus of sentences.
97    ///
98    /// Laplace smoothing is applied with `k = 1` (add-one).
99    pub fn train(sentences: &[Vec<String>]) -> Result<BigramLM> {
100        let mut uni_counts: HashMap<String, usize> = HashMap::new();
101        let mut bi_counts: HashMap<(String, String), usize> = HashMap::new();
102
103        // Collect <s> and </s> boundaries
104        const START: &str = "<s>";
105        const END: &str = "</s>";
106
107        for sent in sentences {
108            if sent.is_empty() {
109                continue;
110            }
111            let padded: Vec<&str> = std::iter::once(START)
112                .chain(sent.iter().map(String::as_str))
113                .chain(std::iter::once(END))
114                .collect();
115            for i in 0..padded.len() - 1 {
116                *uni_counts.entry(padded[i].to_string()).or_insert(0) += 1;
117                *bi_counts
118                    .entry((padded[i].to_string(), padded[i + 1].to_string()))
119                    .or_insert(0) += 1;
120            }
121            *uni_counts.entry(END.to_string()).or_insert(0) += 1;
122        }
123
124        let vocab_size = uni_counts.len();
125        if vocab_size == 0 {
126            return Err(TextError::InvalidInput("Empty corpus".to_string()));
127        }
128
129        // Laplace-smoothed unigrams
130        let total_uni: usize = uni_counts.values().sum();
131        let unigrams: HashMap<String, f64> = uni_counts
132            .iter()
133            .map(|(w, &c)| {
134                let p = (c as f64 + 1.0) / (total_uni as f64 + vocab_size as f64);
135                (w.clone(), p)
136            })
137            .collect();
138
139        // Laplace-smoothed bigrams
140        let mut probs: HashMap<(String, String), f64> = HashMap::new();
141        // Collect all known bigrams
142        for ((prev, curr), &c) in &bi_counts {
143            let prev_count = uni_counts.get(prev).copied().unwrap_or(0) as f64;
144            let p = (c as f64 + 1.0) / (prev_count + vocab_size as f64);
145            probs.insert((prev.clone(), curr.clone()), p);
146        }
147
148        Ok(BigramLM {
149            probs,
150            unigrams,
151            vocab_size,
152        })
153    }
154
155    /// P(curr | prev) with Laplace smoothing back-off.
156    pub fn probability(&self, prev: &str, curr: &str) -> f64 {
157        self.probs
158            .get(&(prev.to_string(), curr.to_string()))
159            .copied()
160            .unwrap_or_else(|| 1.0 / (self.vocab_size as f64 + 1.0))
161    }
162}
163
164// ---------------------------------------------------------------------------
165// NgramLM  (Kneser-Ney smoothing)
166// ---------------------------------------------------------------------------
167
168/// An N-gram language model with Kneser-Ney smoothing.
169///
170/// Uses absolute discounting with discount `d = 0.75` and a recursive
171/// lower-order back-off.
172#[derive(Debug, Clone)]
173pub struct NgramLM {
174    /// N-gram order.
175    pub n: usize,
176    /// Full-context n-gram counts.
177    pub counts: HashMap<Vec<String>, usize>,
178    /// Context (n-1 gram) counts.
179    pub context_counts: HashMap<Vec<String>, usize>,
180    /// Kneser-Ney continuation counts for lower-order back-off.
181    continuation_counts: HashMap<String, usize>,
182    /// Number of distinct bigrams in the corpus (for KN base).
183    n_bigrams: usize,
184    /// Discount parameter.
185    discount: f64,
186}
187
188impl NgramLM {
189    /// Train an N-gram model from a corpus of sentences.
190    pub fn train(n: usize, sentences: &[Vec<String>]) -> Result<NgramLM> {
191        if n < 1 {
192            return Err(TextError::InvalidInput("n must be >= 1".to_string()));
193        }
194        const START: &str = "<s>";
195        const END: &str = "</s>";
196
197        let mut counts: HashMap<Vec<String>, usize> = HashMap::new();
198        let mut context_counts: HashMap<Vec<String>, usize> = HashMap::new();
199        let mut continuation_counts: HashMap<String, usize> = HashMap::new();
200        let mut bigram_set: HashSet<(String, String)> = HashSet::new();
201
202        for sent in sentences {
203            if sent.is_empty() {
204                continue;
205            }
206            // Pad with (n-1) start tokens and 1 end token
207            let mut padded: Vec<String> = (0..n - 1).map(|_| START.to_string()).collect();
208            padded.extend(sent.iter().cloned());
209            padded.push(END.to_string());
210
211            for i in 0..padded.len().saturating_sub(n - 1) {
212                let ngram: Vec<String> = padded[i..i + n].to_vec();
213                let context: Vec<String> = padded[i..i + n - 1].to_vec();
214                *counts.entry(ngram).or_insert(0) += 1;
215                if n > 1 {
216                    *context_counts.entry(context).or_insert(0) += 1;
217                }
218            }
219
220            // Continuation counts for KN: unique left-contexts for each word
221            for i in 1..padded.len() {
222                bigram_set.insert((padded[i - 1].clone(), padded[i].clone()));
223                *continuation_counts.entry(padded[i].clone()).or_insert(0) += 0;
224                // ensure entry exists
225            }
226        }
227
228        // Count unique left-contexts per word
229        for (_, curr) in &bigram_set {
230            *continuation_counts.entry(curr.clone()).or_insert(0) += 1;
231        }
232        let n_bigrams = bigram_set.len();
233
234        Ok(NgramLM {
235            n,
236            counts,
237            context_counts,
238            continuation_counts,
239            n_bigrams,
240            discount: 0.75,
241        })
242    }
243
244    /// P(word | context) using Kneser-Ney smoothing.
245    ///
246    /// For unigrams (n=1) this reduces to Kneser-Ney continuation probability.
247    pub fn probability(&self, word: &str, context: &[&str]) -> f64 {
248        self.kn_probability(word, context)
249    }
250
251    fn kn_probability(&self, word: &str, context: &[&str]) -> f64 {
252        if self.n == 1 {
253            return self.kn_unigram(word);
254        }
255
256        // Use the last (n-1) words of context
257        let used_ctx: Vec<String> = if context.len() >= self.n - 1 {
258            context[context.len() - (self.n - 1)..]
259                .iter()
260                .map(|s| s.to_string())
261                .collect()
262        } else {
263            context.iter().map(|s| s.to_string()).collect()
264        };
265
266        let ngram: Vec<String> = used_ctx
267            .iter()
268            .cloned()
269            .chain(std::iter::once(word.to_string()))
270            .collect();
271
272        let c = self.counts.get(&ngram).copied().unwrap_or(0) as f64;
273        let c_ctx = self.context_counts.get(&used_ctx).copied().unwrap_or(0) as f64;
274
275        if c_ctx == 0.0 {
276            return self.kn_unigram(word);
277        }
278
279        // Count of types that follow `used_ctx` (for lambda)
280        let types_after_ctx = self
281            .counts
282            .iter()
283            .filter(|(k, &v)| v > 0 && k.len() == self.n && k[..self.n - 1] == used_ctx[..])
284            .count() as f64;
285
286        let lambda = self.discount * types_after_ctx / c_ctx;
287        let first_term = (c - self.discount).max(0.0) / c_ctx;
288        first_term + lambda * self.kn_unigram(word)
289    }
290
291    fn kn_unigram(&self, word: &str) -> f64 {
292        let c_w = self.continuation_counts.get(word).copied().unwrap_or(0) as f64;
293        if self.n_bigrams == 0 {
294            return 1e-10;
295        }
296        (c_w / self.n_bigrams as f64).max(1e-10)
297    }
298
299    /// Log-probability of `word` given `context`.
300    pub fn log_probability(&self, word: &str, context: &[&str]) -> f64 {
301        let p = self.probability(word, context);
302        if p <= 0.0 {
303            f64::NEG_INFINITY
304        } else {
305            p.ln()
306        }
307    }
308}
309
310// ---------------------------------------------------------------------------
311// PerplexityEval
312// ---------------------------------------------------------------------------
313
314/// Perplexity evaluation for an `NgramLM`.
315pub struct PerplexityEval;
316
317impl PerplexityEval {
318    /// Compute per-token perplexity over `test_sentences`.
319    ///
320    /// PP = exp( -1/N * Σ log P(w_i | w_{i-n+1}..w_{i-1}) )
321    pub fn compute(lm: &NgramLM, test_sentences: &[Vec<String>]) -> Result<f64> {
322        let mut log_prob_sum = 0.0f64;
323        let mut token_count = 0usize;
324
325        const START: &str = "<s>";
326
327        for sent in test_sentences {
328            if sent.is_empty() {
329                continue;
330            }
331            // Pad with start tokens
332            let mut padded: Vec<String> = (0..lm.n - 1).map(|_| START.to_string()).collect();
333            padded.extend(sent.iter().cloned());
334
335            for i in lm.n - 1..padded.len() {
336                let word = &padded[i];
337                let ctx_start = i.saturating_sub(lm.n - 1);
338                let context: Vec<&str> = padded[ctx_start..i].iter().map(String::as_str).collect();
339                let lp = lm.log_probability(word, &context);
340                if lp.is_finite() {
341                    log_prob_sum += lp;
342                } else {
343                    // Penalty for completely unknown n-gram
344                    log_prob_sum += (1e-10_f64).ln();
345                }
346                token_count += 1;
347            }
348        }
349
350        if token_count == 0 {
351            return Err(TextError::InvalidInput(
352                "No tokens in test sentences".to_string(),
353            ));
354        }
355
356        let avg_log_prob = log_prob_sum / token_count as f64;
357        Ok((-avg_log_prob).exp())
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Small toy corpus shared by the tests.
    fn corpus() -> Vec<Vec<String>> {
        vec![
            simple_tokenize("the cat sat on the mat"),
            simple_tokenize("the dog ran over the hill"),
            simple_tokenize("a cat and a dog played"),
            simple_tokenize("the cat chased the dog"),
            simple_tokenize("the mat was on the floor"),
        ]
    }

    #[test]
    fn test_unigram_probabilities_sum_to_one() {
        let lm = UnigramLM::train(&corpus()).expect("train failed");
        let total: f64 = lm.probs.values().sum();
        assert!((total - 1.0).abs() < 1e-9, "sum = {}", total);
    }

    #[test]
    fn test_unigram_known_word() {
        let lm = UnigramLM::train(&corpus()).expect("train");
        assert!(lm.probability("cat") > 0.0);
    }

    #[test]
    fn test_unigram_oov() {
        let lm = UnigramLM::train(&corpus()).expect("train");
        assert_eq!(lm.probability("xyzzy"), 0.0);
    }

    #[test]
    fn test_unigram_oov_log_probability() {
        let lm = UnigramLM::train(&corpus()).expect("train");
        assert_eq!(lm.log_probability("xyzzy"), f64::NEG_INFINITY);
    }

    #[test]
    fn test_unigram_empty_corpus_error() {
        assert!(UnigramLM::train(&[]).is_err());
        assert!(UnigramLM::train(&[Vec::new()]).is_err());
    }

    #[test]
    fn test_bigram_probability_positive() {
        let lm = BigramLM::train(&corpus()).expect("train");
        let p = lm.probability("the", "cat");
        assert!(p > 0.0 && p <= 1.0, "p = {}", p);
    }

    #[test]
    fn test_bigram_unseen_is_smoothed() {
        let lm = BigramLM::train(&corpus()).expect("train");
        let p = lm.probability("cat", "airplane");
        assert!(p > 0.0, "Laplace smoothed probability should be > 0");
    }

    #[test]
    fn test_bigram_empty_corpus_error() {
        assert!(BigramLM::train(&[Vec::new()]).is_err());
    }

    #[test]
    fn test_ngram_probability_trigram() {
        let lm = NgramLM::train(3, &corpus()).expect("train");
        let p = lm.probability("cat", &["<s>", "the"]);
        assert!(p > 0.0, "p = {}", p);
    }

    #[test]
    fn test_ngram_probability_unseen() {
        let lm = NgramLM::train(2, &corpus()).expect("train");
        let p = lm.probability("airplane", &["the"]);
        // KN back-off should return a small but positive value
        assert!(p > 0.0, "KN probability should be > 0 even for OOV");
    }

    #[test]
    fn test_ngram_zero_order_error() {
        assert!(NgramLM::train(0, &corpus()).is_err());
    }

    #[test]
    fn test_ngram_distribution_sums_to_one() {
        // For a seen context, P(· | ctx) must sum to 1 over every token that
        // can follow something in training: all corpus words plus `</s>`.
        let train = corpus();
        let lm = NgramLM::train(2, &train).expect("train");
        let mut vocab: HashSet<&str> = train.iter().flatten().map(String::as_str).collect();
        vocab.insert("</s>");
        let total: f64 = vocab.iter().map(|w| lm.probability(w, &["the"])).sum();
        assert!((total - 1.0).abs() < 1e-9, "sum = {}", total);
    }

    #[test]
    fn test_perplexity_finite() {
        let train = corpus();
        let lm = NgramLM::train(2, &train).expect("train");
        let test_data = vec![simple_tokenize("the cat sat")];
        let pp = PerplexityEval::compute(&lm, &test_data).expect("perplexity");
        assert!(pp.is_finite() && pp > 1.0, "pp = {}", pp);
    }

    #[test]
    fn test_perplexity_lower_on_train_than_random() {
        let train = corpus();
        let lm = NgramLM::train(2, &train).expect("train");

        let train_pp = PerplexityEval::compute(&lm, &train[..2]).expect("train perplexity");
        let random_pp = PerplexityEval::compute(&lm, &[simple_tokenize("xyzzy blorp quux flerb")])
            .expect("random perplexity");

        assert!(
            train_pp <= random_pp,
            "train pp {} should be <= random pp {}",
            train_pp,
            random_pp
        );
    }

    #[test]
    fn test_perplexity_empty_error() {
        let lm = NgramLM::train(2, &corpus()).expect("train");
        let result = PerplexityEval::compute(&lm, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_unigram_log_probability() {
        let lm = UnigramLM::train(&corpus()).expect("train");
        let lp = lm.log_probability("cat");
        assert!(lp < 0.0 && lp.is_finite());
    }
}