text_score/
rouge.rs

//! Implementations of text metrics for use in various ML/DL fields.
//!
//! For now, this module provides a ROUGE-N score computed over whitespace-separated
//! tokens (see `str::split_whitespace`).
//!
4use std::collections::HashMap;
5use std::cmp::{min, max};
6use anyhow::{Result, Error};
7pub use crate::commons::{Score, f1, precision, recall};
8
9
10
/// Creates n-grams from a list of tokens.
///
/// Slides a window of `n` consecutive tokens over `tokens` and counts how often each
/// window (n-gram) occurs.
///
/// ### Arguments
///
/// * `tokens` - A vector of string slices representing individual tokens.
/// * `n` - The size of the n-grams to be created.
///
/// ### Returns
///
/// A `HashMap` where keys are n-grams (vectors of `n` string slices, in input order) and
/// values are the number of times each n-gram occurs in `tokens`.
///
/// The map is empty when `n == 0` or when `tokens` holds fewer than `n` tokens, since no
/// window of that size exists.
///
/// ### Examples
///
/// ```
/// use text_score::rouge::create_ngrams;
///
/// let tokens = vec!["this", "is", "this", "is"];
/// let ngrams = create_ngrams(tokens, 2);
///
/// assert_eq!(ngrams.len(), 2);
/// assert_eq!(ngrams[&vec!["this", "is"]], 2);
/// assert_eq!(ngrams[&vec!["is", "this"]], 1);
///
/// // Too few tokens for the requested size: no n-grams, and no panic.
/// assert!(create_ngrams(vec!["short"], 3).is_empty());
/// ```
pub fn create_ngrams(tokens: Vec<&str>, n: usize) -> HashMap<Vec<&str>, u32> {
    let mut ngrams: HashMap<Vec<&str>, u32> = HashMap::new();

    // `slice::windows` panics on a zero size, and an empty n-gram carries no information.
    if n == 0 {
        return ngrams;
    }

    // `windows` yields nothing when `tokens.len() < n`, so short inputs give an empty map
    // instead of underflowing `tokens.len() - n + 1`.
    for window in tokens.windows(n) {
        *ngrams.entry(window.to_vec()).or_insert(0) += 1;
    }
    ngrams
}
55
56/// Computes precision, recall, and F1 score based on n-grams.
57///
58/// Given two HashMaps representing the n-grams of predicted and target sequences,
59/// this function calculates precision, recall, and F1 score for the prediction.
60///
61/// ### Arguments
62///
63/// * `predicted_ngrams` - A HashMap containing n-grams and their counts for the predicted sequence.
64/// * `target_ngrams` - A HashMap containing n-grams and their counts for the target (reference) sequence.
65///
66/// ### Returns
67///
68/// A `Score` struct containing precision, recall, and F1 score for the prediction based on n-grams.
69///
70/// ### Examples
71///
72/// ```
73/// use std::collections::{HashMap, hash_map};
74/// use text_score::rouge::{ngram_based_score, Score}; // Replace with the actual module name
75///
76/// let predicted_ngrams = hashmap! { vec!["this", "is"] => 2, vec!["is", "an"] => 1 };
77/// let target_ngrams = hashmap! { vec!["this", "is"] => 3, vec!["is", "an"] => 2 };
78///
79/// let score = ngram_based_score(predicted_ngrams, target_ngrams);
80/// println!("Precision: {}", score.precision); // Accessing precision field
81/// println!("Recall: {}", score.recall);       // Accessing recall field
82/// println!("F1 Score: {}", score.f1);         // Accessing f1 field
83/// ```
84///
85/// # Note
86///
87/// - The function iterates through the target n-grams and computes the intersection count
88///   with the predicted n-grams to calculate precision, recall, and F1 score.
89/// - Precision and recall are calculated using the standard formulas, and F1 score is computed
90///   using the `f1` function defined in the module.
91/// - The resulting scores are returned in a `Score` struct.
92pub fn ngram_based_score(predicted_ngrams:HashMap<Vec<&str>, u32>, target_ngrams:HashMap<Vec<&str>, u32>) -> Score{
93    let mut intersection_ngrams_count: u32=0;
94    let target_ngrams_count:u32 = target_ngrams.values().map(|&v| v).sum();
95    let prediction_ngrams_count:u32= predicted_ngrams.values().map(|&v| v).sum();
96
97    for (ngram, target_cnt) in target_ngrams.iter(){
98        intersection_ngrams_count += min(target_cnt, predicted_ngrams.get(ngram).unwrap_or(&0));
99
100    }
101    let p:f32 = intersection_ngrams_count as f32/ max(prediction_ngrams_count, 1) as f32;
102    let r:f32 = intersection_ngrams_count as f32/ max(target_ngrams_count, 1) as f32;
103    let f:f32 = f1(p, r);
104
105    return Score{precision:p, recall:r, f1:f};
106}
107
108
109/// Computes ROUGE scores based on n-grams for a given input and reference text.
110///
111/// ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is a metric commonly used
112/// in natural language processing to evaluate the quality of text summaries or translations.
113/// This function calculates precision, recall, and F1 score based on n-grams for the provided input
114/// text and reference text.
115///
116/// ### Arguments
117///
118/// * `input` - The input text to be evaluated.
119/// * `reference` - The reference text, considered as the ground truth or gold standard.
120/// * `n` - The size of n-grams to be used in the evaluation.
121///
122/// ### Returns
123///
124/// A `Result` containing a `Score` struct if successful, or an error message if `n` is less than 1.
125///
126/// ### Examples
127///
128/// ```
129/// use text_score::rouge::{rouge_n, Score}; // Replace with the actual module name
130///
131/// let input_text = "This is a sample sentence for evaluation.";
132/// let reference_text = "This is a sample sentence for testing.";
133/// let n = 2;
134///
135/// match rouge_n(input_text, reference_text, n) {
136///     Ok(score) => {
137///         println!("Precision: {}", score.precision); // Accessing precision field
138///         println!("Recall: {}", score.recall);       // Accessing recall field
139///         println!("F1 Score: {}", score.f1);         // Accessing f1 field
140///     }
141///     Err(err) => println!("Error: {}", err),
142/// }
143/// ```
144///
145/// # Note
146///
147/// - The function checks if the specified `n` is greater than or equal to 1. If not, it returns an error.
148/// - The input and reference texts are tokenized into words, and n-grams are created using the `create_ngrams` function.
149/// - The n-gram based scores are then calculated using the `ngram_based_score` function.
150/// - The resulting scores are returned in a `Score` struct if the operation is successful.
151pub fn rouge_n(input:&str, reference: &str, n:usize) -> Result<Score>{
152    if n < 1 {
153        return Err(Error::msg("n should be >= 1"));
154    }
155
156    let input_words = input.split_whitespace().collect();
157    let reference_words = reference.split_whitespace().collect();
158
159    // create n-grams
160    let mut input_ngrams = create_ngrams(input_words, n);
161    let mut reference_ngrams = create_ngrams(reference_words, n);
162
163    // get n-gram based f1 score
164    Ok(ngram_based_score(input_ngrams, reference_ngrams))
165}