scirs2_text/huggingface_compat/pipelines/
fill_mask.rs1use super::FillMaskResult;
7use crate::error::{Result, TextError};
8
/// Rule-based fill-mask pipeline.
///
/// Predicts replacements for a `[MASK]` placeholder by scoring a fixed table of
/// common English words with simple context heuristics. It holds no state, so it
/// is a cheap, freely copyable unit value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FillMaskPipeline;
12
13impl Default for FillMaskPipeline {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl FillMaskPipeline {
20 pub fn new() -> Self {
22 Self
23 }
24
25 pub fn fill_mask(&self, text: &str) -> Result<Vec<FillMaskResult>> {
27 if !text.contains("[MASK]") {
29 return Err(TextError::InvalidInput("No [MASK] token found".to_string()));
30 }
31
32 let words: Vec<&str> = text.split_whitespace().collect();
34 let mask_index = words.iter().position(|&w| w == "[MASK]").unwrap_or(0);
35
36 let left_context: Vec<&str> = if mask_index > 0 {
38 words[..mask_index].iter().rev().take(3).copied().collect()
39 } else {
40 vec![]
41 };
42
43 let right_context: Vec<&str> = if mask_index < words.len() - 1 {
44 words[mask_index + 1..].iter().take(3).copied().collect()
45 } else {
46 vec![]
47 };
48
49 let mut candidates = Vec::new();
51
52 let common_words = vec![
54 ("the", 0.85),
55 ("a", 0.75),
56 ("an", 0.65),
57 ("is", 0.80),
58 ("was", 0.75),
59 ("are", 0.70),
60 ("will", 0.68),
61 ("can", 0.72),
62 ("would", 0.70),
63 ("should", 0.65),
64 ("very", 0.60),
65 ("more", 0.68),
66 ("most", 0.65),
67 ("good", 0.60),
68 ("great", 0.58),
69 ("important", 0.55),
70 ("significant", 0.52),
71 ("major", 0.50),
72 ];
73
74 for (word, base_score) in common_words {
75 let mut score = base_score;
77
78 if !left_context.is_empty() {
80 let prev_word = left_context[0];
81 if (prev_word == "a" || prev_word == "an") && word.starts_with(char::is_alphabetic)
82 {
83 score *= 0.3; } else if prev_word.ends_with("ly") && (word == "good" || word == "important") {
85 score *= 1.2; }
87 }
88
89 if !right_context.is_empty() {
90 let next_word = right_context[0];
91 if word == "a" && next_word.starts_with(|c: char| "aeiou".contains(c)) {
92 score *= 0.2; } else if word == "an" && !next_word.starts_with(|c: char| "aeiou".contains(c)) {
94 score *= 0.2; }
96 }
97
98 candidates.push(FillMaskResult {
99 token_str: word.to_string(),
100 sequence: text.replace("[MASK]", word),
101 score,
102 token: candidates.len() + 1,
103 });
104 }
105
106 candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
108 Ok(candidates.into_iter().take(5).collect())
109 }
110}