scirs2_text/huggingface_compat/pipelines/
fill_mask.rs

//! Fill mask pipeline implementation for masked language modeling
//!
//! This module provides a lightweight, rule-based implementation of the fill-mask
//! task: candidate words for a `[MASK]` token are scored with simple heuristics
//! based on the neighbouring words.
5
6use super::FillMaskResult;
7use crate::error::{Result, TextError};
8
/// Rule-based fill-mask pipeline for masked language modeling.
///
/// The pipeline is a unit struct: it holds no model weights or configuration, so it is
/// free to construct and to clone.
#[derive(Debug, Clone)]
pub struct FillMaskPipeline;
12
13impl Default for FillMaskPipeline {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl FillMaskPipeline {
20    /// Create new fill mask pipeline
21    pub fn new() -> Self {
22        Self
23    }
24
25    /// Fill masked tokens in text
26    pub fn fill_mask(&self, text: &str) -> Result<Vec<FillMaskResult>> {
27        // Improved mask filling using context analysis
28        if !text.contains("[MASK]") {
29            return Err(TextError::InvalidInput("No [MASK] token found".to_string()));
30        }
31
32        // Analyze context around mask
33        let words: Vec<&str> = text.split_whitespace().collect();
34        let mask_index = words.iter().position(|&w| w == "[MASK]").unwrap_or(0);
35
36        // Get context words
37        let left_context: Vec<&str> = if mask_index > 0 {
38            words[..mask_index].iter().rev().take(3).copied().collect()
39        } else {
40            vec![]
41        };
42
43        let right_context: Vec<&str> = if mask_index < words.len() - 1 {
44            words[mask_index + 1..].iter().take(3).copied().collect()
45        } else {
46            vec![]
47        };
48
49        // Generate contextually appropriate candidates
50        let mut candidates = Vec::new();
51
52        // Common words with context-based scoring
53        let common_words = vec![
54            ("the", 0.85),
55            ("a", 0.75),
56            ("an", 0.65),
57            ("is", 0.80),
58            ("was", 0.75),
59            ("are", 0.70),
60            ("will", 0.68),
61            ("can", 0.72),
62            ("would", 0.70),
63            ("should", 0.65),
64            ("very", 0.60),
65            ("more", 0.68),
66            ("most", 0.65),
67            ("good", 0.60),
68            ("great", 0.58),
69            ("important", 0.55),
70            ("significant", 0.52),
71            ("major", 0.50),
72        ];
73
74        for (word, base_score) in common_words {
75            // Adjust score based on context
76            let mut score = base_score;
77
78            // Boost score if word fits grammatical context
79            if !left_context.is_empty() {
80                let prev_word = left_context[0];
81                if (prev_word == "a" || prev_word == "an") && word.starts_with(char::is_alphabetic)
82                {
83                    score *= 0.3; // Reduce score for articles after articles
84                } else if prev_word.ends_with("ly") && (word == "good" || word == "important") {
85                    score *= 1.2; // Boost adjectives after adverbs
86                }
87            }
88
89            if !right_context.is_empty() {
90                let next_word = right_context[0];
91                if word == "a" && next_word.starts_with(|c: char| "aeiou".contains(c)) {
92                    score *= 0.2; // Heavily penalize "a" before vowels
93                } else if word == "an" && !next_word.starts_with(|c: char| "aeiou".contains(c)) {
94                    score *= 0.2; // Heavily penalize "an" before consonants
95                }
96            }
97
98            candidates.push(FillMaskResult {
99                token_str: word.to_string(),
100                sequence: text.replace("[MASK]", word),
101                score,
102                token: candidates.len() + 1,
103            });
104        }
105
106        // Sort by score and return top candidates
107        candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
108        Ok(candidates.into_iter().take(5).collect())
109    }
110}