scirs2_text/language_models/
mod.rs1use std::collections::{HashMap, HashSet};
12
13use crate::error::{Result, TextError};
14
/// Breaks `text` into lowercase tokens consisting solely of alphabetic characters.
///
/// Any character for which `char::is_alphabetic` is false acts as a separator;
/// empty pieces produced by adjacent separators are discarded.
fn simple_tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|ch: char| !ch.is_alphabetic()) {
        if piece.is_empty() {
            continue;
        }
        // Lowercase the whole piece at once (not char by char) so that
        // context-sensitive case mappings behave as `str::to_lowercase` defines.
        tokens.push(piece.to_lowercase());
    }
    tokens
}
26
27#[derive(Debug, Clone)]
33pub struct UnigramLM {
34 pub probs: HashMap<String, f64>,
36 pub vocab: HashSet<String>,
38}
39
40impl UnigramLM {
41 pub fn train(sentences: &[Vec<String>]) -> Result<UnigramLM> {
43 let mut counts: HashMap<String, usize> = HashMap::new();
44 let mut total = 0usize;
45 for sent in sentences {
46 for w in sent {
47 *counts.entry(w.clone()).or_insert(0) += 1;
48 total += 1;
49 }
50 }
51 if total == 0 {
52 return Err(TextError::InvalidInput("Empty corpus".to_string()));
53 }
54 let vocab: HashSet<String> = counts.keys().cloned().collect();
55 let probs = counts
56 .into_iter()
57 .map(|(k, c)| (k, c as f64 / total as f64))
58 .collect();
59 Ok(UnigramLM { probs, vocab })
60 }
61
62 pub fn probability(&self, word: &str) -> f64 {
66 self.probs.get(word).copied().unwrap_or(0.0)
67 }
68
69 pub fn log_probability(&self, word: &str) -> f64 {
71 let p = self.probability(word);
72 if p <= 0.0 {
73 f64::NEG_INFINITY
74 } else {
75 p.ln()
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
86pub struct BigramLM {
87 pub probs: HashMap<(String, String), f64>,
89 pub unigrams: HashMap<String, f64>,
91 vocab_size: usize,
93}
94
95impl BigramLM {
96 pub fn train(sentences: &[Vec<String>]) -> Result<BigramLM> {
100 let mut uni_counts: HashMap<String, usize> = HashMap::new();
101 let mut bi_counts: HashMap<(String, String), usize> = HashMap::new();
102
103 const START: &str = "<s>";
105 const END: &str = "</s>";
106
107 for sent in sentences {
108 if sent.is_empty() {
109 continue;
110 }
111 let padded: Vec<&str> = std::iter::once(START)
112 .chain(sent.iter().map(String::as_str))
113 .chain(std::iter::once(END))
114 .collect();
115 for i in 0..padded.len() - 1 {
116 *uni_counts.entry(padded[i].to_string()).or_insert(0) += 1;
117 *bi_counts
118 .entry((padded[i].to_string(), padded[i + 1].to_string()))
119 .or_insert(0) += 1;
120 }
121 *uni_counts.entry(END.to_string()).or_insert(0) += 1;
122 }
123
124 let vocab_size = uni_counts.len();
125 if vocab_size == 0 {
126 return Err(TextError::InvalidInput("Empty corpus".to_string()));
127 }
128
129 let total_uni: usize = uni_counts.values().sum();
131 let unigrams: HashMap<String, f64> = uni_counts
132 .iter()
133 .map(|(w, &c)| {
134 let p = (c as f64 + 1.0) / (total_uni as f64 + vocab_size as f64);
135 (w.clone(), p)
136 })
137 .collect();
138
139 let mut probs: HashMap<(String, String), f64> = HashMap::new();
141 for ((prev, curr), &c) in &bi_counts {
143 let prev_count = uni_counts.get(prev).copied().unwrap_or(0) as f64;
144 let p = (c as f64 + 1.0) / (prev_count + vocab_size as f64);
145 probs.insert((prev.clone(), curr.clone()), p);
146 }
147
148 Ok(BigramLM {
149 probs,
150 unigrams,
151 vocab_size,
152 })
153 }
154
155 pub fn probability(&self, prev: &str, curr: &str) -> f64 {
157 self.probs
158 .get(&(prev.to_string(), curr.to_string()))
159 .copied()
160 .unwrap_or_else(|| 1.0 / (self.vocab_size as f64 + 1.0))
161 }
162}
163
164#[derive(Debug, Clone)]
173pub struct NgramLM {
174 pub n: usize,
176 pub counts: HashMap<Vec<String>, usize>,
178 pub context_counts: HashMap<Vec<String>, usize>,
180 continuation_counts: HashMap<String, usize>,
182 n_bigrams: usize,
184 discount: f64,
186}
187
188impl NgramLM {
189 pub fn train(n: usize, sentences: &[Vec<String>]) -> Result<NgramLM> {
191 if n < 1 {
192 return Err(TextError::InvalidInput("n must be >= 1".to_string()));
193 }
194 const START: &str = "<s>";
195 const END: &str = "</s>";
196
197 let mut counts: HashMap<Vec<String>, usize> = HashMap::new();
198 let mut context_counts: HashMap<Vec<String>, usize> = HashMap::new();
199 let mut continuation_counts: HashMap<String, usize> = HashMap::new();
200 let mut bigram_set: HashSet<(String, String)> = HashSet::new();
201
202 for sent in sentences {
203 if sent.is_empty() {
204 continue;
205 }
206 let mut padded: Vec<String> = (0..n - 1).map(|_| START.to_string()).collect();
208 padded.extend(sent.iter().cloned());
209 padded.push(END.to_string());
210
211 for i in 0..padded.len().saturating_sub(n - 1) {
212 let ngram: Vec<String> = padded[i..i + n].to_vec();
213 let context: Vec<String> = padded[i..i + n - 1].to_vec();
214 *counts.entry(ngram).or_insert(0) += 1;
215 if n > 1 {
216 *context_counts.entry(context).or_insert(0) += 1;
217 }
218 }
219
220 for i in 1..padded.len() {
222 bigram_set.insert((padded[i - 1].clone(), padded[i].clone()));
223 *continuation_counts.entry(padded[i].clone()).or_insert(0) += 0;
224 }
226 }
227
228 for (_, curr) in &bigram_set {
230 *continuation_counts.entry(curr.clone()).or_insert(0) += 1;
231 }
232 let n_bigrams = bigram_set.len();
233
234 Ok(NgramLM {
235 n,
236 counts,
237 context_counts,
238 continuation_counts,
239 n_bigrams,
240 discount: 0.75,
241 })
242 }
243
244 pub fn probability(&self, word: &str, context: &[&str]) -> f64 {
248 self.kn_probability(word, context)
249 }
250
251 fn kn_probability(&self, word: &str, context: &[&str]) -> f64 {
252 if self.n == 1 {
253 return self.kn_unigram(word);
254 }
255
256 let used_ctx: Vec<String> = if context.len() >= self.n - 1 {
258 context[context.len() - (self.n - 1)..]
259 .iter()
260 .map(|s| s.to_string())
261 .collect()
262 } else {
263 context.iter().map(|s| s.to_string()).collect()
264 };
265
266 let ngram: Vec<String> = used_ctx
267 .iter()
268 .cloned()
269 .chain(std::iter::once(word.to_string()))
270 .collect();
271
272 let c = self.counts.get(&ngram).copied().unwrap_or(0) as f64;
273 let c_ctx = self.context_counts.get(&used_ctx).copied().unwrap_or(0) as f64;
274
275 if c_ctx == 0.0 {
276 return self.kn_unigram(word);
277 }
278
279 let types_after_ctx = self
281 .counts
282 .iter()
283 .filter(|(k, &v)| v > 0 && k.len() == self.n && k[..self.n - 1] == used_ctx[..])
284 .count() as f64;
285
286 let lambda = self.discount * types_after_ctx / c_ctx;
287 let first_term = (c - self.discount).max(0.0) / c_ctx;
288 first_term + lambda * self.kn_unigram(word)
289 }
290
291 fn kn_unigram(&self, word: &str) -> f64 {
292 let c_w = self.continuation_counts.get(word).copied().unwrap_or(0) as f64;
293 if self.n_bigrams == 0 {
294 return 1e-10;
295 }
296 (c_w / self.n_bigrams as f64).max(1e-10)
297 }
298
299 pub fn log_probability(&self, word: &str, context: &[&str]) -> f64 {
301 let p = self.probability(word, context);
302 if p <= 0.0 {
303 f64::NEG_INFINITY
304 } else {
305 p.ln()
306 }
307 }
308}
309
310pub struct PerplexityEval;
316
317impl PerplexityEval {
318 pub fn compute(lm: &NgramLM, test_sentences: &[Vec<String>]) -> Result<f64> {
322 let mut log_prob_sum = 0.0f64;
323 let mut token_count = 0usize;
324
325 const START: &str = "<s>";
326
327 for sent in test_sentences {
328 if sent.is_empty() {
329 continue;
330 }
331 let mut padded: Vec<String> = (0..lm.n - 1).map(|_| START.to_string()).collect();
333 padded.extend(sent.iter().cloned());
334
335 for i in lm.n - 1..padded.len() {
336 let word = &padded[i];
337 let ctx_start = i.saturating_sub(lm.n - 1);
338 let context: Vec<&str> = padded[ctx_start..i].iter().map(String::as_str).collect();
339 let lp = lm.log_probability(word, &context);
340 if lp.is_finite() {
341 log_prob_sum += lp;
342 } else {
343 log_prob_sum += (1e-10_f64).ln();
345 }
346 token_count += 1;
347 }
348 }
349
350 if token_count == 0 {
351 return Err(TextError::InvalidInput(
352 "No tokens in test sentences".to_string(),
353 ));
354 }
355
356 let avg_log_prob = log_prob_sum / token_count as f64;
357 Ok((-avg_log_prob).exp())
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Five short sentences shared by every test in this module.
    fn corpus() -> Vec<Vec<String>> {
        [
            "the cat sat on the mat",
            "the dog ran over the hill",
            "a cat and a dog played",
            "the cat chased the dog",
            "the mat was on the floor",
        ]
        .iter()
        .map(|line| simple_tokenize(line))
        .collect()
    }

    // Relative frequencies over the whole corpus must form a distribution.
    #[test]
    fn test_unigram_probabilities_sum_to_one() {
        let model = UnigramLM::train(&corpus()).expect("train failed");
        let total: f64 = model.probs.values().sum();
        assert!((total - 1.0).abs() < 1e-9, "sum = {}", total);
    }

    // A word that occurs in the corpus has non-zero probability.
    #[test]
    fn test_unigram_known_word() {
        let model = UnigramLM::train(&corpus()).expect("train");
        let p_cat = model.probability("cat");
        assert!(p_cat > 0.0);
    }

    // The unigram model is unsmoothed, so unknown words get exactly zero.
    #[test]
    fn test_unigram_oov() {
        let model = UnigramLM::train(&corpus()).expect("train");
        assert_eq!(model.probability("xyzzy"), 0.0);
    }

    // A seen bigram yields a valid probability.
    #[test]
    fn test_bigram_probability_positive() {
        let model = BigramLM::train(&corpus()).expect("train");
        let p = model.probability("the", "cat");
        assert!(p > 0.0 && p <= 1.0, "p = {}", p);
    }

    // An unseen bigram still receives probability mass.
    #[test]
    fn test_bigram_unseen_is_smoothed() {
        let model = BigramLM::train(&corpus()).expect("train");
        let p = model.probability("cat", "airplane");
        assert!(p > 0.0, "Laplace smoothed probability should be > 0");
    }

    // Trigram lookup with a sentence-initial context.
    #[test]
    fn test_ngram_probability_trigram() {
        let model = NgramLM::train(3, &corpus()).expect("train");
        let context = ["<s>", "the"];
        let p = model.probability("cat", &context);
        assert!(p > 0.0, "p = {}", p);
    }

    // Out-of-vocabulary words stay strictly positive thanks to the floor.
    #[test]
    fn test_ngram_probability_unseen() {
        let model = NgramLM::train(2, &corpus()).expect("train");
        let p = model.probability("airplane", &["the"]);
        assert!(p > 0.0, "KN probability should be > 0 even for OOV");
    }

    // Perplexity of an in-domain sentence is finite and above one.
    #[test]
    fn test_perplexity_finite() {
        let training = corpus();
        let model = NgramLM::train(2, &training).expect("train");
        let held_out = vec![simple_tokenize("the cat sat")];
        let pp = PerplexityEval::compute(&model, &held_out).expect("perplexity");
        assert!(pp.is_finite() && pp > 1.0, "pp = {}", pp);
    }

    // Training sentences must not score worse than nonsense words.
    #[test]
    fn test_perplexity_lower_on_train_than_random() {
        let training = corpus();
        let model = NgramLM::train(2, &training).expect("train");

        let train_pp =
            PerplexityEval::compute(&model, &training[..2]).expect("train perplexity");
        let gibberish = [simple_tokenize("xyzzy blorp quux flerb")];
        let random_pp = PerplexityEval::compute(&model, &gibberish).expect("random perplexity");

        assert!(
            train_pp <= random_pp,
            "train pp {} should be <= random pp {}",
            train_pp,
            random_pp
        );
    }

    // With nothing to score, evaluation reports an error instead of a number.
    #[test]
    fn test_perplexity_empty_error() {
        let model = NgramLM::train(2, &corpus()).expect("train");
        let outcome = PerplexityEval::compute(&model, &[]);
        assert!(outcome.is_err());
    }

    // The log of a probability in (0, 1) is finite and negative.
    #[test]
    fn test_unigram_log_probability() {
        let model = UnigramLM::train(&corpus()).expect("train");
        let log_p = model.log_probability("cat");
        assert!(log_p < 0.0 && log_p.is_finite());
    }
}