wonnx_preprocessing/
text.rs

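//! Text preprocessing helpers: wraps the Hugging Face `tokenizers` crate to
//! turn text into wonnx model inputs, and to map question-answering model
//! outputs back onto answer spans in the original context.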
use std::borrow::Cow;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use thiserror::Error;
use tokenizers::{EncodeInput, Encoding, InputSequence, Tokenizer};
use wonnx::utils::Shape;

use crate::Tensor;

#[derive(Error, Debug)]
pub enum PreprocessingError {
    #[error("text tokenization error: {0}")]
    TextTokenizationError(#[from] Box<dyn std::error::Error + Sync + Send>),
}

/// Wraps a Hugging Face `tokenizers::Tokenizer` and converts text into model inputs.
pub struct TextTokenizer {
    pub tokenizer: Tokenizer,
}

/// A single tokenizer encoding (one segment of tokenized text).
#[derive(Debug)]
pub struct EncodedText {
    pub encoding: Encoding,
}

impl TextTokenizer {
    pub fn new(tokenizer: Tokenizer) -> TextTokenizer {
        TextTokenizer { tokenizer }
    }

    /// Loads a tokenizer from a `tokenizer.json`-style configuration file.
    pub fn from_config<P: AsRef<Path>>(path: P) -> Result<TextTokenizer, std::io::Error> {
        let tokenizer_config_file = File::open(path)?;
        let tokenizer_config_reader = BufReader::new(tokenizer_config_file);
        let tokenizer = serde_json::from_reader(tokenizer_config_reader)?;
        Ok(TextTokenizer::new(tokenizer))
    }

    /// Tokenizes a question/context pair. The first element of the returned
    /// vector is the primary encoding; any overflow segments follow it.
    pub fn tokenize_question_answer(
        &self,
        question: &str,
        context: &str,
    ) -> Result<Vec<EncodedText>, PreprocessingError> {
        // Encode question and context as a sequence pair, adding special tokens.
        let mut encoding = self
            .tokenizer
            .encode(
                EncodeInput::Dual(
                    InputSequence::Raw(Cow::from(question)),
                    InputSequence::Raw(Cow::from(context)),
                ),
                true,
            )
            .map_err(PreprocessingError::TextTokenizationError)?;

        // If the pair did not fit in a single segment, the tokenizer splits the
        // remainder into overflow encodings; return the primary encoding first.
        let mut overflowing = encoding.take_overflowing();
        overflowing.insert(0, encoding);

        Ok(overflowing
            .into_iter()
            .map(|x| EncodedText { encoding: x })
            .collect())
    }

    fn tokenize(&self, text: &str) -> Result<EncodedText, PreprocessingError> {
        let encoding = self
            .tokenizer
            .encode(
                EncodeInput::Single(InputSequence::Raw(Cow::from(text))),
                true,
            )
            .map_err(PreprocessingError::TextTokenizationError)?;
        Ok(EncodedText { encoding })
    }

    pub fn decode(&self, encoding: &EncodedText) -> Result<String, PreprocessingError> {
        // `EncodedText::get_tokens` returns the numeric token IDs (as i64);
        // convert them back to u32 for the tokenizer's decoder.
        let ids: Vec<u32> = encoding.get_tokens().iter().map(|x| *x as u32).collect();
        self.tokenizer
            .decode(&ids, true)
            .map_err(PreprocessingError::TextTokenizationError)
    }

    pub fn get_mask_input_for(
        &self,
        text: &str,
        shape: &Shape,
    ) -> Result<Tensor, PreprocessingError> {
        // Pad or truncate the attention mask to the last dimension of the
        // expected input shape, then hand it over as an f32 tensor.
        let segment_length = shape.dim(shape.rank() - 1) as usize;
        let tokenized = self.tokenize(text)?;
        let mut tokens = tokenized.get_mask();
        tokens.resize(segment_length, 0);
        let data = ndarray::Array::from_iter(tokens.iter().map(|x| (*x) as f32)).into_dyn();
        Ok(Tensor::F32(data))
    }

    pub fn get_input_for(&self, text: &str, shape: &Shape) -> Result<Tensor, PreprocessingError> {
        // Same as above, but for the token IDs themselves.
        let segment_length = shape.dim(shape.rank() - 1) as usize;
        let tokenized = self.tokenize(text)?;
        let mut tokens = tokenized.get_tokens();
        tokens.resize(segment_length, 0);
        let data = ndarray::Array::from_iter(tokens.iter().map(|x| (*x) as f32)).into_dyn();
        Ok(Tensor::F32(data))
    }
}

/// An answer span extracted from a question-answering model's output.
#[derive(Debug)]
pub struct Answer {
    pub text: String,
    pub tokens: Vec<String>,
    pub score: f32,
}

impl EncodedText {
    /// The attention mask as i64 values.
    pub fn get_mask(&self) -> Vec<i64> {
        self.encoding
            .get_attention_mask()
            .iter()
            .map(|x| *x as i64)
            .collect()
    }

    /// The token IDs as i64 values.
    pub fn get_tokens(&self) -> Vec<i64> {
        self.encoding.get_ids().iter().map(|x| *x as i64).collect()
    }

    /// The segment (token type) IDs as i64 values.
    pub fn get_segments(&self) -> Vec<i64> {
        self.encoding
            .get_type_ids()
            .iter()
            .map(|x| *x as i64)
            .collect()
    }

    /// Extracts the most likely answer span from the start and end logits
    /// produced by a question-answering model.
    pub fn get_answer(&self, start_output: &[f32], end_output: &[f32], context: &str) -> Answer {
        let mut best_start_logit = f32::MIN;
        let mut best_start_idx: usize = 0;

        let input_tokens = self.encoding.get_tokens();
        let special_tokens_mask = self.encoding.get_special_tokens_mask();

        for (start_idx, start_logit) in start_output.iter().enumerate() {
            if start_idx > input_tokens.len() - 1 {
                break;
            }

            // Skip special tokens such as [CLS], [SEP], [PAD]
            if special_tokens_mask[start_idx] == 1 {
                continue;
            }

            if *start_logit > best_start_logit {
                best_start_logit = *start_logit;
                best_start_idx = start_idx;
            }
        }

        // Find matching end
        let mut best_end_logit = f32::MIN;
        let mut best_end_idx = best_start_idx;
        for (end_idx, end_logit) in end_output[best_start_idx..].iter().enumerate() {
            if (end_idx + best_start_idx) > input_tokens.len() - 1 {
                break;
            }

            // Skip special tokens such as [CLS], [SEP], [PAD]
            if special_tokens_mask[end_idx + best_start_idx] == 1 {
                continue;
            }

            if *end_logit > best_end_logit {
                best_end_logit = *end_logit;
                best_end_idx = end_idx + best_start_idx;
            }
        }

        log::debug!("start index: {} ({})", best_start_idx, best_start_logit);
        log::debug!("end index: {} ({})", best_end_idx, best_end_logit);

        let chars: Vec<char> = context.chars().collect();
        let offsets = self.encoding.get_offsets();
        log::debug!("offsets: {:?}", &offsets[best_start_idx..=best_end_idx]);

        // The end index is inclusive, so include the last token of the span.
        let answer_tokens: Vec<String> =
            self.encoding.get_tokens()[best_start_idx..=best_end_idx].to_vec();

        let min_offset = offsets[best_start_idx..=best_end_idx]
            .iter()
            .map(|o| o.0)
            .min()
            .unwrap();
        let max_offset = offsets[best_start_idx..=best_end_idx]
            .iter()
            .map(|o| o.1)
            .max()
            .unwrap();
        assert!(min_offset <= max_offset);
        // `max_offset` is an exclusive end offset; bail out with an empty
        // answer when it points past the end of the context.
        if max_offset > chars.len() {
            return Answer {
                text: "".to_string(),
                tokens: vec![],
                score: 0.0,
            };
        }

        let answer = chars[min_offset..max_offset].iter().collect::<String>();

        Answer {
            text: answer,
            tokens: answer_tokens,
            score: best_start_logit * best_end_logit,
        }
    }
}

/// Reads all lines from the file at `path`, panicking if the file cannot be
/// opened or read.
pub fn get_lines(path: &Path) -> Vec<String> {
    let file = BufReader::new(File::open(path).unwrap());
    file.lines().map(|line| line.unwrap()).collect()
}
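
// Illustrative usage sketch: assumes a local `tokenizer.json` (hypothetical
// path) loadable by the `tokenizers` crate, and uses placeholder logits in
// place of a real question-answering model's start/end outputs.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    #[ignore = "requires a local tokenizer.json"]
    fn question_answer_roundtrip() {
        let context = "WONNX is an ONNX runtime based on wgpu, written in Rust.";

        // Load a tokenizer configuration (hypothetical path).
        let tokenizer = TextTokenizer::from_config("tokenizer.json").expect("tokenizer config");

        // Encode the question/context pair; overflow segments, if any, follow
        // the primary encoding.
        let segments = tokenizer
            .tokenize_question_answer("What is WONNX written in?", context)
            .expect("tokenization");
        let first = &segments[0];

        // Placeholder logits; a real pipeline would take these from the model.
        let len = first.get_tokens().len();
        let start_logits = vec![0.0f32; len];
        let end_logits = vec![0.0f32; len];

        let answer = first.get_answer(&start_logits, &end_logits, context);
        println!("answer: {:?} (score {})", answer.text, answer.score);
    }
}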