1use crate::error::{Result, TextError};
51use crate::tokenize::{Tokenizer, WordTokenizer};
52use scirs2_core::random::prelude::*;
53use std::collections::HashMap;
54use std::fmt::Debug;
55
/// Smoothing strategy applied when estimating n-gram probabilities.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SmoothingMethod {
    /// No smoothing: unseen n-grams receive probability 0.
    None,
    /// Add-one (Laplace) smoothing over the full vocabulary.
    Laplace,
    /// Add-k smoothing with the given fractional pseudo-count `k`.
    AddK(f64),
    /// Simplified Kneser-Ney smoothing with the given absolute discount.
    KneserNey(f64),
}
68
/// Statistical n-gram language model with configurable smoothing.
pub struct NgramModel {
    // Order of the model (1 = unigram, 2 = bigram, ...); always >= 1.
    n: usize,
    // Smoothing method used by `probability`.
    smoothing: SmoothingMethod,
    // context (n-1 words) -> next word -> observed count.
    ngram_counts: HashMap<Vec<String>, HashMap<String, usize>>,
    // context (n-1 words) -> total count of all continuations seen after it.
    context_counts: HashMap<Vec<String>, usize>,
    // Sorted distinct tokens seen in training, including the
    // `<START>`/`<END>` sentinels inserted during padding.
    vocabulary: Vec<String>,
    // Total number of (context, word) events counted during training.
    total_words: usize,
    // Tokenizer used to split input texts into words.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
}
86
87impl Debug for NgramModel {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("NgramModel")
90 .field("n", &self.n)
91 .field("smoothing", &self.smoothing)
92 .field("vocabulary_size", &self.vocabulary.len())
93 .field("total_words", &self.total_words)
94 .finish()
95 }
96}
97
impl Clone for NgramModel {
    /// Manual `Clone`: `Box<dyn Tokenizer + Send + Sync>` cannot be cloned,
    /// so the copy always receives a fresh `WordTokenizer::default()`.
    ///
    /// NOTE(review): a model configured via `with_tokenizer` silently loses
    /// its custom tokenizer when cloned. Confirm this is acceptable for
    /// callers, or consider adding a `clone_box`-style method to the
    /// `Tokenizer` trait so clones preserve the tokenizer.
    fn clone(&self) -> Self {
        Self {
            n: self.n,
            smoothing: self.smoothing,
            // Learned state is deep-copied; only the tokenizer is reset.
            ngram_counts: self.ngram_counts.clone(),
            context_counts: self.context_counts.clone(),
            vocabulary: self.vocabulary.clone(),
            total_words: self.total_words,
            tokenizer: Box::new(WordTokenizer::default()),
        }
    }
}
111
112impl NgramModel {
    /// Creates an empty model of order `n` with the given smoothing method
    /// and the default `WordTokenizer`.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`; an n-gram model must have order of at least 1.
    pub fn new(n: usize, smoothing: SmoothingMethod) -> Self {
        if n == 0 {
            panic!("N-gram order must be at least 1");
        }

        Self {
            n,
            smoothing,
            ngram_counts: HashMap::new(),
            context_counts: HashMap::new(),
            vocabulary: Vec::new(),
            total_words: 0,
            tokenizer: Box::new(WordTokenizer::default()),
        }
    }
134
135 pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
137 self.tokenizer = tokenizer;
138 self
139 }
140
141 pub fn train(&mut self, texts: &[&str]) -> Result<()> {
143 if texts.is_empty() {
144 return Err(TextError::InvalidInput(
145 "No texts provided for training".into(),
146 ));
147 }
148
149 self.ngram_counts.clear();
151 self.context_counts.clear();
152 self.vocabulary.clear();
153 self.total_words = 0;
154
155 let mut vocab_set = std::collections::HashSet::new();
157
158 for &text in texts {
159 let tokens = self.tokenizer.tokenize(text)?;
160
161 let mut augmented_tokens = vec!["<START>".to_string(); self.n - 1];
163 augmented_tokens.extend(tokens);
164 augmented_tokens.push("<END>".to_string());
165
166 for token in &augmented_tokens {
168 vocab_set.insert(token.clone());
169 }
170
171 for i in (self.n - 1)..augmented_tokens.len() {
173 let context = augmented_tokens[i - (self.n - 1)..i].to_vec();
174 let word = &augmented_tokens[i];
175
176 *self
178 .ngram_counts
179 .entry(context.clone())
180 .or_default()
181 .entry(word.clone())
182 .or_insert(0) += 1;
183
184 *self.context_counts.entry(context).or_insert(0) += 1;
186
187 self.total_words += 1;
188 }
189 }
190
191 self.vocabulary = vocab_set.into_iter().collect();
192 self.vocabulary.sort();
193
194 Ok(())
195 }
196
197 pub fn probability(&self, context: &[&str], word: &str) -> Result<f64> {
208 if context.len() != self.n - 1 {
209 return Err(TextError::InvalidInput(format!(
210 "Context must have exactly {} words for {}-gram model",
211 self.n - 1,
212 self.n
213 )));
214 }
215
216 let context_vec: Vec<String> = context.iter().map(|s| s.to_string()).collect();
217 let vocab_size = self.vocabulary.len();
218
219 match self.smoothing {
220 SmoothingMethod::None => {
221 let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
223
224 if context_count == 0 {
225 return Ok(0.0);
226 }
227
228 let ngram_count = self
229 .ngram_counts
230 .get(&context_vec)
231 .and_then(|words| words.get(word))
232 .copied()
233 .unwrap_or(0);
234
235 Ok(ngram_count as f64 / context_count as f64)
236 }
237 SmoothingMethod::Laplace => {
238 let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
240
241 let ngram_count = self
242 .ngram_counts
243 .get(&context_vec)
244 .and_then(|words| words.get(word))
245 .copied()
246 .unwrap_or(0);
247
248 Ok((ngram_count + 1) as f64 / (context_count + vocab_size) as f64)
249 }
250 SmoothingMethod::AddK(k) => {
251 let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
253
254 let ngram_count = self
255 .ngram_counts
256 .get(&context_vec)
257 .and_then(|words| words.get(word))
258 .copied()
259 .unwrap_or(0);
260
261 Ok((ngram_count as f64 + k) / (context_count as f64 + k * vocab_size as f64))
262 }
263 SmoothingMethod::KneserNey(discount) => {
264 let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
266
267 if context_count == 0 {
268 return Ok(1.0 / vocab_size as f64);
269 }
270
271 let ngram_count = self
272 .ngram_counts
273 .get(&context_vec)
274 .and_then(|words| words.get(word))
275 .copied()
276 .unwrap_or(0);
277
278 let adjusted_count = (ngram_count as f64 - discount).max(0.0);
279 let lambda = discount
280 * self
281 .ngram_counts
282 .get(&context_vec)
283 .map(|m| m.len())
284 .unwrap_or(0) as f64
285 / context_count as f64;
286
287 let continuation_prob = 1.0 / vocab_size as f64;
288
289 Ok(adjusted_count / context_count as f64 + lambda * continuation_prob)
290 }
291 }
292 }
293
294 pub fn perplexity(&self, texts: &[&str]) -> Result<f64> {
299 if texts.is_empty() {
300 return Err(TextError::InvalidInput("No test texts provided".into()));
301 }
302
303 let mut log_prob_sum = 0.0;
304 let mut word_count = 0;
305
306 for &text in texts {
307 let tokens = self.tokenizer.tokenize(text)?;
308
309 let mut augmented_tokens = vec!["<START>".to_string(); self.n - 1];
310 augmented_tokens.extend(tokens);
311 augmented_tokens.push("<END>".to_string());
312
313 for i in (self.n - 1)..augmented_tokens.len() {
314 let context: Vec<&str> = augmented_tokens[i - (self.n - 1)..i]
315 .iter()
316 .map(|s| s.as_str())
317 .collect();
318 let word = &augmented_tokens[i];
319
320 let prob = self.probability(&context, word)?;
321
322 if prob > 0.0 {
323 log_prob_sum += prob.ln();
324 word_count += 1;
325 } else {
326 log_prob_sum += f64::ln(1e-10);
328 word_count += 1;
329 }
330 }
331 }
332
333 if word_count == 0 {
334 return Ok(f64::INFINITY);
335 }
336
337 Ok((-log_prob_sum / word_count as f64).exp())
338 }
339
340 pub fn generate(&self, max_length: usize, start_context: Option<&str>) -> Result<String> {
351 let mut rng = scirs2_core::random::rng();
352 let mut generated = Vec::new();
353
354 let mut context: Vec<String> = if let Some(start) = start_context {
356 let tokens = self.tokenizer.tokenize(start)?;
357 if tokens.len() < self.n - 1 {
358 let mut ctx = vec!["<START>".to_string(); self.n - 1 - tokens.len()];
359 ctx.extend(tokens);
360 ctx
361 } else {
362 tokens.into_iter().rev().take(self.n - 1).rev().collect()
363 }
364 } else {
365 vec!["<START>".to_string(); self.n - 1]
366 };
367
368 for _ in 0..max_length {
370 let context_refs: Vec<&str> = context.iter().map(|s| s.as_str()).collect();
371
372 let candidates = match self.ngram_counts.get(&context) {
374 Some(words) => words,
375 None => {
376 break;
378 }
379 };
380
381 if candidates.is_empty() {
382 break;
383 }
384
385 let total: usize = candidates.values().sum();
387 let mut threshold = rng.random_range(0..total);
388 let mut next_word = String::new();
389
390 for (word, &count) in candidates {
391 if threshold < count {
392 next_word = word.clone();
393 break;
394 }
395 threshold -= count;
396 }
397
398 if next_word == "<END>" {
399 break;
400 }
401
402 if next_word != "<START>" {
403 generated.push(next_word.clone());
404 }
405
406 context.remove(0);
408 context.push(next_word);
409 }
410
411 Ok(generated.join(" "))
412 }
413
414 pub fn suggest_next(&self, context: &[&str], top_n: usize) -> Result<Vec<(String, f64)>> {
425 if context.len() != self.n - 1 {
426 return Err(TextError::InvalidInput(format!(
427 "Context must have exactly {} words",
428 self.n - 1
429 )));
430 }
431
432 let context_vec: Vec<String> = context.iter().map(|s| s.to_string()).collect();
433
434 let candidates = match self.ngram_counts.get(&context_vec) {
435 Some(words) => words,
436 None => {
437 return Ok(Vec::new());
438 }
439 };
440
441 let mut suggestions: Vec<(String, f64)> = candidates
442 .keys()
443 .map(|word| {
444 let prob = self.probability(context, word).unwrap_or(0.0);
445 (word.clone(), prob)
446 })
447 .collect();
448
449 suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
450
451 Ok(suggestions.into_iter().take(top_n).collect())
452 }
453
    /// Returns the order `n` of the model.
    pub fn order(&self) -> usize {
        self.n
    }
458
    /// Returns the number of distinct tokens seen during training,
    /// including the `<START>`/`<END>` sentinels.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    // Unigram model: empty context is valid, Laplace guarantees prob > 0.
    #[test]
    fn test_unigram_model() {
        let texts = vec!["the cat sat on the mat", "the dog sat on the log"];

        let mut model = NgramModel::new(1, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let prob = model
            .probability(&[], "the")
            .expect("Failed to get probability");
        assert!(prob > 0.0);
    }

    // Bigram model: both observed continuations of "the" score positively.
    #[test]
    fn test_bigram_model() {
        let texts = vec!["the cat sat", "the dog sat"];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let prob = model
            .probability(&["the"], "cat")
            .expect("Failed to get probability");
        assert!(prob > 0.0);

        let prob = model
            .probability(&["the"], "dog")
            .expect("Failed to get probability");
        assert!(prob > 0.0);
    }

    // Trigram model: a two-word context is required and scores positively.
    #[test]
    fn test_trigram_model() {
        let texts = vec!["the quick brown fox", "the quick red fox"];

        let mut model = NgramModel::new(3, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let prob = model
            .probability(&["the", "quick"], "brown")
            .expect("Failed to get probability");
        assert!(prob > 0.0);
    }

    // Smoothing must assign non-zero probability to unseen n-grams
    // ("the dog" never occurs in the training text).
    #[test]
    fn test_smoothing_methods() {
        let texts = vec!["the cat sat"];

        let mut model_laplace = NgramModel::new(2, SmoothingMethod::Laplace);
        model_laplace.train(&texts).expect("Training failed");

        let prob_laplace = model_laplace
            .probability(&["the"], "dog")
            .expect("Failed to get probability");
        assert!(
            prob_laplace > 0.0,
            "Laplace smoothing should give non-zero probability to unseen n-grams"
        );

        let mut model_addk = NgramModel::new(2, SmoothingMethod::AddK(0.5));
        model_addk.train(&texts).expect("Training failed");

        let prob_addk = model_addk
            .probability(&["the"], "dog")
            .expect("Failed to get probability");
        assert!(prob_addk > 0.0);
    }

    // Generation seeded with "the" should produce at least one word,
    // since "the" always has observed continuations in the training data.
    #[test]
    fn test_text_generation() {
        let texts = vec![
            "the quick brown fox jumps over the lazy dog",
            "the quick brown dog runs fast",
        ];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let generated = model.generate(10, Some("the")).expect("Generation failed");
        assert!(!generated.is_empty());
    }

    // Perplexity on a prefix of the training text is positive and finite
    // (Laplace smoothing keeps all event probabilities non-zero).
    #[test]
    fn test_perplexity() {
        let train_texts = vec!["the cat sat on the mat"];
        let test_texts = vec!["the cat sat"];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&train_texts).expect("Training failed");

        let perplexity = model
            .perplexity(&test_texts)
            .expect("Failed to calculate perplexity");
        assert!(perplexity > 0.0);
        assert!(perplexity.is_finite());
    }

    // Suggestions for "the" must include "cat", its most frequent follower.
    #[test]
    fn test_suggest_next() {
        let texts = vec!["the cat sat", "the cat ran", "the dog sat"];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let suggestions = model
            .suggest_next(&["the"], 3)
            .expect("Failed to get suggestions");

        assert!(!suggestions.is_empty());
        assert!(suggestions.iter().any(|(word, _)| word == "cat"));
    }
}
587}