Skip to main content

scirs2_text/evaluation/
perplexity.rs

1//! Perplexity-based language-model evaluation.
2//!
//! Provides the [`LanguageModelLike`] trait, [`perplexity_evaluate`],
//! [`PerplexityReport`], and the [`load_token_corpus`] file loader.
//! Implementations of `LanguageModelLike` are included for the existing
//! [`crate::language_models::NgramLM`] and [`crate::language_model::NgramModel`].
6//!
7//! ## Example
8//!
9//! ```rust,ignore
10//! use scirs2_text::evaluation::perplexity::{
11//!     LanguageModelLike, PerplexityReport, perplexity_evaluate,
12//! };
13//!
14//! struct UniformModel { vocab: usize }
15//! impl LanguageModelLike for UniformModel {
16//!     fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
17//!         if tokens.is_empty() { return None; }
18//!         Some(tokens.len() as f64 * -(self.vocab as f64).ln())
19//!     }
20//!     fn vocabulary_size(&self) -> usize { self.vocab }
21//! }
22//! let model = UniformModel { vocab: 100 };
23//! let corpus = vec![vec!["a", "b", "c"]];
24//! let report = perplexity_evaluate(&model, &corpus).unwrap();
25//! assert!((report.corpus_perplexity - 100.0).abs() < 1e-6);
26//! ```
27
28use crate::error::{Result, TextError};
29use std::path::Path;
30
/// Common interface for language models that can score a token sequence.
///
/// Tokens are passed as plain string slices, so word-level and
/// character-level models can both implement the trait without a
/// vocabulary-bridge adapter.
pub trait LanguageModelLike {
    /// Total log-probability of the sequence: the sum over `t` of
    /// `log p(tokens[t] | tokens[0..t])`.
    ///
    /// Implementations return `None` for an empty sequence.
    fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64>;

    /// Number of entries in the model's vocabulary; intended for tests and
    /// diagnostics only.
    fn vocabulary_size(&self) -> usize;
}
45
/// Outcome of a [`perplexity_evaluate`] run.
#[derive(Debug, Clone)]
pub struct PerplexityReport {
    /// Corpus-level perplexity, `exp(-total_log_prob / total_tokens)`.
    pub corpus_perplexity: f64,
    /// One perplexity per input sentence, in corpus order. `NaN` marks a
    /// sentence that was empty or that the model could not score.
    pub per_sentence_perplexity: Vec<f64>,
    /// Number of tokens in the sentences that were actually scored.
    pub total_tokens: usize,
    /// Sum of the sequence log-probabilities of the scored sentences.
    pub total_log_prob: f64,
}
58
59/// Compute corpus perplexity for `model` over a pre-tokenized corpus.
60///
61/// `corpus` is a slice of sentences; each sentence is a `Vec<&str>` of string tokens.
62///
63/// # Errors
64///
65/// - [`TextError::InvalidInput`] when `corpus` is empty.
66/// - [`TextError::InvalidInput`] when all sentences are empty (zero tokens total).
67pub fn perplexity_evaluate(
68    model: &dyn LanguageModelLike,
69    corpus: &[Vec<&str>],
70) -> Result<PerplexityReport> {
71    if corpus.is_empty() {
72        return Err(TextError::InvalidInput("corpus is empty".into()));
73    }
74
75    let mut total_log_prob = 0.0f64;
76    let mut total_tokens = 0usize;
77    let mut per_sentence = Vec::with_capacity(corpus.len());
78
79    for sentence in corpus {
80        if sentence.is_empty() {
81            per_sentence.push(f64::NAN);
82            continue;
83        }
84        match model.log_prob_sequence(sentence) {
85            Some(lp) => {
86                let n = sentence.len();
87                let ppl = (-lp / n as f64).exp();
88                per_sentence.push(ppl);
89                total_log_prob += lp;
90                total_tokens += n;
91            }
92            None => {
93                per_sentence.push(f64::NAN);
94            }
95        }
96    }
97
98    if total_tokens == 0 {
99        return Err(TextError::InvalidInput(
100            "no tokens found in corpus (all sentences are empty)".into(),
101        ));
102    }
103
104    let corpus_ppl = (-total_log_prob / total_tokens as f64).exp();
105
106    Ok(PerplexityReport {
107        corpus_perplexity: corpus_ppl,
108        per_sentence_perplexity: per_sentence,
109        total_tokens,
110        total_log_prob,
111    })
112}
113
114/// Load a pre-tokenized corpus from a plain text file.
115///
116/// Each line is one sentence; tokens are whitespace-separated.
117/// Lines that tokenize to zero tokens are skipped.
118///
119/// # Errors
120///
121/// Returns [`TextError::IoError`] if the file cannot be opened or read.
122pub fn load_token_corpus(path: impl AsRef<Path>) -> Result<Vec<Vec<String>>> {
123    use std::fs::File;
124    use std::io::{BufRead, BufReader};
125
126    let file = File::open(path.as_ref()).map_err(|e| TextError::IoError(e.to_string()))?;
127    let reader = BufReader::new(file);
128    let mut result = Vec::new();
129
130    for line in reader.lines() {
131        let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
132        let tokens: Vec<String> = line.split_whitespace().map(str::to_owned).collect();
133        if !tokens.is_empty() {
134            result.push(tokens);
135        }
136    }
137
138    Ok(result)
139}
140
// ---------------------------------------------------------------------------
// `LanguageModelLike` implementations for the crate's n-gram models
// ---------------------------------------------------------------------------
144
145/// Implement `LanguageModelLike` for `NgramLM` from `language_models` module.
146///
147/// `NgramLM` uses string-based context windows with Kneser-Ney smoothing.
148impl LanguageModelLike for crate::language_models::NgramLM {
149    fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
150        if tokens.is_empty() {
151            return None;
152        }
153        let n = self.n;
154        let mut log_prob = 0.0f64;
155
156        for i in 0..tokens.len() {
157            let ctx_start = if i >= n - 1 { i + 1 - n } else { 0 };
158            let context: Vec<&str> = tokens[ctx_start..i].to_vec();
159            let word = tokens[i];
160            let p = self.probability(word, &context);
161            log_prob += if p <= 0.0 { 1e-10_f64.ln() } else { p.ln() };
162        }
163
164        Some(log_prob)
165    }
166
167    fn vocabulary_size(&self) -> usize {
168        // Derive vocab size from the set of unique terminal words in n-gram keys.
169        // Each n-gram key Vec<String> has the word at index n-1.
170        let mut vocab = std::collections::HashSet::new();
171        for key in self.counts.keys() {
172            if let Some(word) = key.last() {
173                vocab.insert(word.as_str());
174            }
175        }
176        vocab.len().max(1)
177    }
178}
179
180/// Implement `LanguageModelLike` for the wired [`crate::language_model::NgramModel`].
181///
182/// Delegates to `NgramModel::probability` with the preceding `n-1` tokens as context.
183impl LanguageModelLike for crate::language_model::NgramModel {
184    fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
185        if tokens.is_empty() {
186            return None;
187        }
188        let n = self.order();
189        let mut log_prob = 0.0f64;
190
191        for i in 0..tokens.len() {
192            let ctx_start = if i >= n - 1 { i + 1 - n } else { 0 };
193            let context: Vec<&str> = tokens[ctx_start..i].to_vec();
194            let word = tokens[i];
195            // `probability` can return Err only for context-length mismatch;
196            // pad context to exactly n-1 with <START> tokens when needed.
197            let padded_ctx: Vec<&str> = if context.len() < n - 1 {
198                let needed = n - 1 - context.len();
199                let mut v: Vec<&str> = vec!["<START>"; needed];
200                v.extend_from_slice(&context);
201                v
202            } else {
203                context
204            };
205            let p = self.probability(&padded_ctx, word).unwrap_or(1e-10);
206            log_prob += if p <= 0.0 { 1e-10_f64.ln() } else { p.ln() };
207        }
208
209        Some(log_prob)
210    }
211
212    fn vocabulary_size(&self) -> usize {
213        self.vocabulary_size()
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that assigns probability `1 / vocab` to every token.
    struct UniformModel {
        /// Vocabulary size V.
        vocab: usize,
    }

    impl LanguageModelLike for UniformModel {
        fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
            match tokens.len() {
                0 => None,
                count => Some(count as f64 * -(self.vocab as f64).ln()),
            }
        }

        fn vocabulary_size(&self) -> usize {
            self.vocab
        }
    }

    /// Test double that assigns probability 1.0 to every token.
    struct PerfectModel;

    impl LanguageModelLike for PerfectModel {
        fn log_prob_sequence(&self, tokens: &[&str]) -> Option<f64> {
            match tokens {
                [] => None,
                // ln(1.0) = 0 for every token, so the sum is 0 as well.
                _ => Some(0.0_f64),
            }
        }

        fn vocabulary_size(&self) -> usize {
            1
        }
    }

    /// A uniform model over V words has perplexity exactly V.
    #[test]
    fn perplexity_uniform_model_equals_vocab_size() {
        let corpus = vec![vec!["a", "b", "c", "d", "e"]];
        let report = perplexity_evaluate(&UniformModel { vocab: 100 }, &corpus).expect("evaluate");
        let delta = (report.corpus_perplexity - 100.0).abs();
        assert!(
            delta < 1e-6,
            "expected 100.0, got {}",
            report.corpus_perplexity
        );
    }

    /// A model that is always certain has perplexity 1.
    #[test]
    fn perplexity_of_perfect_predictor_is_one() {
        let corpus = vec![vec!["a", "b", "c"]];
        let report = perplexity_evaluate(&PerfectModel, &corpus).expect("evaluate");
        let delta = (report.corpus_perplexity - 1.0).abs();
        assert!(
            delta < 1e-9,
            "expected 1.0, got {}",
            report.corpus_perplexity
        );
    }

    /// Token counts and log-probabilities are summed across sentences.
    #[test]
    fn perplexity_corpus_aggregates_token_log_probs() {
        // Sentences of 3 and 2 tokens: 5 tokens in total.
        let corpus = vec![vec!["a", "b", "c"], vec!["d", "e"]];
        let report = perplexity_evaluate(&UniformModel { vocab: 10 }, &corpus).expect("evaluate");
        assert_eq!(report.total_tokens, 5);

        // Expected total: 5 * ln(1/10).
        let expected_lp = 5.0 * -(10.0f64).ln();
        let delta = (report.total_log_prob - expected_lp).abs();
        assert!(
            delta < 1e-9,
            "expected total_log_prob {expected_lp}, got {}",
            report.total_log_prob
        );
    }

    /// A corpus with no sentences is rejected.
    #[test]
    fn perplexity_empty_corpus_returns_error() {
        let outcome = perplexity_evaluate(&UniformModel { vocab: 10 }, &[]);
        assert!(outcome.is_err());
    }

    /// Every scored sentence gets a finite, strictly positive perplexity.
    #[test]
    fn perplexity_per_sentence_are_positive() {
        let corpus = vec![vec!["a"], vec!["b", "c"]];
        let report = perplexity_evaluate(&UniformModel { vocab: 5 }, &corpus).expect("evaluate");
        for ppl in report.per_sentence_perplexity.iter().copied() {
            assert!(ppl > 0.0 && ppl.is_finite(), "per-sentence ppl = {ppl}");
        }
    }

    /// A corpus made only of empty sentences has zero tokens and is rejected.
    #[test]
    fn perplexity_all_empty_sentences_returns_error() {
        let corpus: Vec<Vec<&str>> = vec![Vec::new(), Vec::new()];
        let outcome = perplexity_evaluate(&UniformModel { vocab: 5 }, &corpus);
        assert!(outcome.is_err());
    }
}