1use crate::error::{Result, TextError};
35
/// Result of a METEOR evaluation of one hypothesis against one reference.
///
/// For valid configurations every ratio lies in `[0, 1]`. When both inputs are
/// empty, `score`, `precision`, `recall` and `f_mean` are reported as `1.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct MeteorScore {
    /// Final score: `f_mean * (1 - penalty)`.
    pub score: f64,
    /// Fraction of hypothesis tokens that were aligned (`matches / hypothesis length`).
    pub precision: f64,
    /// Fraction of reference tokens that were aligned (`matches / reference length`).
    pub recall: f64,
    /// Weighted harmonic mean `P * R / (alpha * P + (1 - alpha) * R)`.
    pub f_mean: f64,
    /// Fragmentation penalty `gamma * (chunks / matches)^beta`, clamped to `[0, 1]`.
    pub penalty: f64,
    /// Number of maximal runs of aligned tokens that are contiguous in both sentences.
    pub chunks: usize,
    /// Number of aligned hypothesis/reference token pairs.
    pub matches: usize,
}
54
/// Tunable parameters for `meteor_score` and `meteor_score_multi`.
#[derive(Debug, Clone, PartialEq)]
pub struct MeteorConfig {
    /// Weight in the harmonic mean `P * R / (alpha * P + (1 - alpha) * R)`.
    /// Must lie strictly inside `(0, 1)`; values near 1 make the mean
    /// recall-dominated.
    pub alpha: f64,
    /// Exponent applied to the fragmentation ratio `chunks / matches`.
    pub beta: f64,
    /// Scale of the fragmentation penalty `gamma * (chunks / matches)^beta`.
    pub gamma: f64,
    /// Enable the second alignment stage, which matches words by a
    /// suffix-stripping stem after exact matching.
    pub use_stemming: bool,
    /// Enable the third alignment stage, which matches words whose normalised
    /// edit distance is at most `approximate_threshold`.
    pub use_approximate: bool,
    /// Largest `edit distance / length of the longer word` accepted by the
    /// approximate stage.
    pub approximate_threshold: f64,
}
79
80impl Default for MeteorConfig {
81 fn default() -> Self {
82 Self {
83 alpha: 0.9,
84 beta: 3.0,
85 gamma: 0.5,
86 use_stemming: true,
87 use_approximate: true,
88 approximate_threshold: 0.4,
89 }
90 }
91}
92
/// Reduce `word` to a crude stem: lower-case it, then strip the first suffix
/// from a fixed list that still leaves a stem of at least three characters.
///
/// If no listed suffix applies, a single trailing `s` (but not `ss`) is
/// dropped. Words of three characters or fewer are only lower-cased.
///
/// All lengths are measured in characters rather than bytes, so non-ASCII
/// words get the same minimum-stem guarantee as ASCII ones.
fn simple_stem(word: &str) -> String {
    const MIN_STEM_CHARS: usize = 3;
    // Tried in order; longer suffixes are listed before shorter suffixes they
    // end with, so e.g. "ings" is preferred over "ing".
    const SUFFIXES: &[&str] = &[
        "ational", "tional", "ences", "ances", "ments", "ously", "ively", "ation", "ness", "ment",
        "able", "ible", "ting", "ally", "ence", "ance", "ings", "ized", "ling", "ful", "ous",
        "ive", "ize", "ing", "ies", "ied", "ion", "ers", "est", "ess", "ism", "ist", "ity", "ble",
        "ent", "ant", "ary", "ery", "ory", "al", "ly", "er", "ed", "en", "es", "ty",
    ];

    let w = word.to_lowercase();
    let char_len = w.chars().count();

    if char_len <= MIN_STEM_CHARS {
        return w;
    }

    for &suffix in SUFFIXES {
        if let Some(stem) = w.strip_suffix(suffix) {
            // Suffixes are ASCII, so their byte length equals their char length,
            // and `w` ending with one guarantees `char_len >= suffix.len()`.
            if char_len - suffix.len() >= MIN_STEM_CHARS {
                return stem.to_string();
            }
        }
    }

    // Plural fallback; `char_len > MIN_STEM_CHARS` already guarantees the
    // remaining stem keeps at least three characters.
    if let Some(stem) = w.strip_suffix('s') {
        if !stem.ends_with('s') {
            return stem.to_string();
        }
    }

    w
}
127
/// Levenshtein distance between `a` and `b`, counted over Unicode scalar
/// values (`char`s) with unit cost for insertion, deletion and substitution.
///
/// Uses two rolling rows of the dynamic-programming table, so memory is
/// proportional to the length of `b`.
fn edit_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // Against an empty string the distance is just the other string's length.
    if source.is_empty() || target.is_empty() {
        return source.len().max(target.len());
    }

    // `row[j]` holds the distance between the source prefix processed so far
    // and `target[..j]`; it starts as the distance from the empty prefix.
    let mut row: Vec<usize> = (0..=target.len()).collect();
    let mut next = vec![0usize; target.len() + 1];

    for (i, &sc) in source.iter().enumerate() {
        next[0] = i + 1;
        for (j, &tc) in target.iter().enumerate() {
            let substitution = row[j] + usize::from(sc != tc);
            let deletion = row[j + 1] + 1;
            let insertion = next[j] + 1;
            next[j + 1] = deletion.min(insertion).min(substitution);
        }
        std::mem::swap(&mut row, &mut next);
    }

    row[target.len()]
}
164
/// One aligned token pair: position `hyp_idx` in the hypothesis is matched to
/// position `ref_idx` in the reference.
///
/// It is just two indices, so it is `Copy` and cheaply comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Alignment {
    /// Index of the aligned token in the hypothesis.
    hyp_idx: usize,
    /// Index of the aligned token in the reference.
    ref_idx: usize,
}
173
174fn build_alignment(
185 hypothesis: &[&str],
186 reference: &[&str],
187 config: &MeteorConfig,
188) -> Vec<Alignment> {
189 let hyp_lower: Vec<String> = hypothesis.iter().map(|w| w.to_lowercase()).collect();
190 let ref_lower: Vec<String> = reference.iter().map(|w| w.to_lowercase()).collect();
191
192 let mut hyp_matched = vec![false; hypothesis.len()];
193 let mut ref_matched = vec![false; reference.len()];
194 let mut alignments: Vec<Alignment> = Vec::new();
195
196 stage_match(
198 &hyp_lower,
199 &ref_lower,
200 &mut hyp_matched,
201 &mut ref_matched,
202 &mut alignments,
203 |h, r| h == r,
204 );
205
206 if config.use_stemming {
208 let hyp_stems: Vec<String> = hyp_lower.iter().map(|w| simple_stem(w)).collect();
209 let ref_stems: Vec<String> = ref_lower.iter().map(|w| simple_stem(w)).collect();
210
211 stage_match(
212 &hyp_stems,
213 &ref_stems,
214 &mut hyp_matched,
215 &mut ref_matched,
216 &mut alignments,
217 |h, r| h == r,
218 );
219 }
220
221 if config.use_approximate {
223 let threshold = config.approximate_threshold;
224 stage_match(
225 &hyp_lower,
226 &ref_lower,
227 &mut hyp_matched,
228 &mut ref_matched,
229 &mut alignments,
230 |h, r| {
231 let max_len = h.len().max(r.len());
232 if max_len == 0 {
233 return true;
234 }
235 let dist = edit_distance(h, r);
236 (dist as f64 / max_len as f64) <= threshold
237 },
238 );
239 }
240
241 alignments
242}
243
244fn stage_match<F>(
250 hyp_forms: &[String],
251 ref_forms: &[String],
252 hyp_matched: &mut [bool],
253 ref_matched: &mut [bool],
254 alignments: &mut Vec<Alignment>,
255 matches: F,
256) where
257 F: Fn(&str, &str) -> bool,
258{
259 for (h_idx, h_form) in hyp_forms.iter().enumerate() {
260 if hyp_matched[h_idx] {
261 continue;
262 }
263
264 let mut best_r_idx: Option<usize> = None;
265 let mut best_dist = usize::MAX;
266
267 for (r_idx, r_form) in ref_forms.iter().enumerate() {
268 if ref_matched[r_idx] {
269 continue;
270 }
271 if matches(h_form, r_form) {
272 let dist = h_idx.abs_diff(r_idx);
273 if dist < best_dist {
274 best_dist = dist;
275 best_r_idx = Some(r_idx);
276 }
277 }
278 }
279
280 if let Some(r_idx) = best_r_idx {
281 hyp_matched[h_idx] = true;
282 ref_matched[r_idx] = true;
283 alignments.push(Alignment {
284 hyp_idx: h_idx,
285 ref_idx: r_idx,
286 });
287 }
288 }
289}
290
291fn count_chunks(alignments: &[Alignment]) -> usize {
296 if alignments.is_empty() {
297 return 0;
298 }
299
300 let mut sorted = alignments.to_vec();
301 sorted.sort_by_key(|a| a.hyp_idx);
302
303 let mut chunks = 1usize;
304 for i in 1..sorted.len() {
305 let hyp_contiguous = sorted[i].hyp_idx == sorted[i - 1].hyp_idx + 1;
308 let ref_contiguous = sorted[i].ref_idx == sorted[i - 1].ref_idx + 1;
309 if !hyp_contiguous || !ref_contiguous {
310 chunks += 1;
311 }
312 }
313
314 chunks
315}
316
317pub fn meteor_score(
333 hypothesis: &[&str],
334 reference: &[&str],
335 config: &MeteorConfig,
336) -> Result<MeteorScore> {
337 if config.alpha <= 0.0 || config.alpha >= 1.0 {
338 return Err(TextError::InvalidInput(format!(
339 "Alpha must be in (0, 1), got {}",
340 config.alpha
341 )));
342 }
343
344 let hyp_len = hypothesis.len();
345 let ref_len = reference.len();
346
347 if hyp_len == 0 && ref_len == 0 {
349 return Ok(MeteorScore {
350 score: 1.0,
351 precision: 1.0,
352 recall: 1.0,
353 f_mean: 1.0,
354 penalty: 0.0,
355 chunks: 0,
356 matches: 0,
357 });
358 }
359 if hyp_len == 0 || ref_len == 0 {
360 return Ok(MeteorScore {
361 score: 0.0,
362 precision: 0.0,
363 recall: 0.0,
364 f_mean: 0.0,
365 penalty: 0.0,
366 chunks: 0,
367 matches: 0,
368 });
369 }
370
371 let alignments = build_alignment(hypothesis, reference, config);
373 let matches = alignments.len();
374
375 if matches == 0 {
376 return Ok(MeteorScore {
377 score: 0.0,
378 precision: 0.0,
379 recall: 0.0,
380 f_mean: 0.0,
381 penalty: 0.0,
382 chunks: 0,
383 matches: 0,
384 });
385 }
386
387 let precision = matches as f64 / hyp_len as f64;
389 let recall = matches as f64 / ref_len as f64;
390
391 let alpha = config.alpha;
393 let f_mean = (precision * recall) / (alpha * precision + (1.0 - alpha) * recall);
394
395 let chunks = count_chunks(&alignments);
397 let frag = chunks as f64 / matches as f64;
398 let penalty = config.gamma * frag.powf(config.beta);
399
400 let penalty = penalty.clamp(0.0, 1.0);
402
403 let score = f_mean * (1.0 - penalty);
404
405 Ok(MeteorScore {
406 score,
407 precision,
408 recall,
409 f_mean,
410 penalty,
411 chunks,
412 matches,
413 })
414}
415
416pub fn meteor_score_multi(
428 hypothesis: &[&str],
429 references: &[Vec<&str>],
430 config: &MeteorConfig,
431) -> Result<MeteorScore> {
432 if references.is_empty() {
433 return Err(TextError::InvalidInput(
434 "References must not be empty".to_string(),
435 ));
436 }
437
438 let mut best: Option<MeteorScore> = None;
439 for reference in references {
440 let score = meteor_score(hypothesis, reference, config)?;
441 if best.is_none() || score.score > best.as_ref().map_or(0.0, |b| b.score) {
442 best = Some(score);
443 }
444 }
445
446 best.ok_or_else(|| TextError::InvalidInput("No references provided".to_string()))
448}
449
/// Unit tests for the METEOR scorer and its private helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exact_match_score() {
        let hypothesis = vec!["the", "cat", "is", "on", "the", "mat"];
        let reference = vec!["the", "cat", "is", "on", "the", "mat"];
        let config = MeteorConfig::default();
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");

        assert!(
            (result.precision - 1.0).abs() < 1e-9,
            "Precision should be 1.0"
        );
        assert!((result.recall - 1.0).abs() < 1e-9, "Recall should be 1.0");
        assert!(result.score > 0.9, "Perfect match should score high");
    }

    #[test]
    fn test_no_match_score() {
        let hypothesis = vec!["a", "b", "c"];
        let reference = vec!["x", "y", "z"];
        let config = MeteorConfig {
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");
        assert!(result.score.abs() < 1e-9, "No match should score 0.0");
    }

    #[test]
    fn test_partial_match_with_stemming() {
        let hypothesis = vec!["the", "cats", "sitting", "on", "the", "mats"];
        let reference = vec!["the", "cat", "sat", "on", "the", "mat"];
        let config = MeteorConfig {
            use_stemming: true,
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");

        assert!(
            result.matches >= 3,
            "Should have at least exact matches: got {}",
            result.matches
        );
        assert!(
            result.score > 0.0,
            "Partial match should give positive score"
        );
    }

    #[test]
    fn test_fragmentation_penalty() {
        let hypothesis = vec!["mat", "the", "on", "sat", "cat", "the"];
        let reference = vec!["the", "cat", "sat", "on", "the", "mat"];
        let config = MeteorConfig {
            use_stemming: false,
            use_approximate: false,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");

        assert!(result.chunks > 1, "Scrambled order should produce chunks");
        assert!(
            result.penalty > 0.0,
            "Should have fragmentation penalty: {}",
            result.penalty
        );
    }

    #[test]
    fn test_approximate_matching() {
        let hypothesis = vec!["colour", "neighbours"];
        let reference = vec!["color", "neighbors"];
        let config = MeteorConfig {
            use_stemming: false,
            use_approximate: true,
            approximate_threshold: 0.4,
            ..Default::default()
        };
        let result = meteor_score(&hypothesis, &reference, &config).expect("should compute");

        assert_eq!(result.matches, 2, "Both should match approximately");
        assert!(result.score > 0.0);
    }

    #[test]
    fn test_invalid_alpha() {
        let result = meteor_score(
            &["a"],
            &["a"],
            &MeteorConfig {
                alpha: 0.0,
                ..Default::default()
            },
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_multi_reference() {
        let hypothesis = vec!["the", "cat", "sat"];
        let references = vec![vec!["a", "dog", "ran"], vec!["the", "cat", "sat"]];
        let config = MeteorConfig::default();
        let result = meteor_score_multi(&hypothesis, &references, &config).expect("should compute");
        assert!(
            result.score > 0.8,
            "Should match second reference well: {}",
            result.score
        );
    }

    #[test]
    fn test_multi_reference_requires_references() {
        let references: Vec<Vec<&str>> = Vec::new();
        let result = meteor_score_multi(&["a"], &references, &MeteorConfig::default());
        assert!(result.is_err(), "Empty reference list should be rejected");
    }

    #[test]
    fn test_empty_inputs() {
        let config = MeteorConfig::default();
        let empty: [&str; 0] = [];

        // Two empty sentences are treated as identical.
        let both = meteor_score(&empty, &empty, &config).expect("should compute");
        assert!((both.score - 1.0).abs() < 1e-9, "Both empty should score 1.0");
        assert_eq!(both.matches, 0);

        // Only one side empty: nothing can be aligned.
        let one = meteor_score(&["a"], &empty, &config).expect("should compute");
        assert!(one.score.abs() < 1e-9, "One empty side should score 0.0");
        let other = meteor_score(&empty, &["a"], &config).expect("should compute");
        assert!(other.score.abs() < 1e-9, "One empty side should score 0.0");
    }

    #[test]
    fn test_simple_stem() {
        assert_eq!(simple_stem("running"), "runn");
        assert_eq!(simple_stem("cats"), "cat");
        assert_eq!(simple_stem("happiness"), "happi");
        assert_eq!(simple_stem("the"), "the");
    }

    #[test]
    fn test_edit_distance() {
        assert_eq!(edit_distance("kitten", "sitting"), 3);
        assert_eq!(edit_distance("", "abc"), 3);
        assert_eq!(edit_distance("abc", ""), 3);
        assert_eq!(edit_distance("same", "same"), 0);
        // Counted per character, so a multi-byte substitution costs 1.
        assert_eq!(edit_distance("über", "uber"), 1);
    }

    #[test]
    fn test_count_chunks() {
        let to_alignments = |pairs: &[(usize, usize)]| -> Vec<Alignment> {
            pairs
                .iter()
                .map(|&(hyp_idx, ref_idx)| Alignment { hyp_idx, ref_idx })
                .collect()
        };

        assert_eq!(count_chunks(&[]), 0);
        assert_eq!(count_chunks(&to_alignments(&[(0, 0), (1, 1), (2, 2)])), 1);
        // Input order does not matter; alignments are ordered by hypothesis index.
        assert_eq!(count_chunks(&to_alignments(&[(2, 2), (0, 0), (1, 1)])), 1);
        assert_eq!(count_chunks(&to_alignments(&[(0, 2), (1, 0), (2, 1)])), 2);
        assert_eq!(count_chunks(&to_alignments(&[(0, 2), (1, 1), (2, 0)])), 3);
    }
}