wonnx_preprocessing/text.rs

use std::borrow::Cow;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use thiserror::Error;
use tokenizers::{EncodeInput, Encoding, InputSequence, Tokenizer};
use wonnx::utils::Shape;

use crate::Tensor;

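/// Errors that can occur while preparing text input for a model.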
#[derive(Error, Debug)]
pub enum PreprocessingError {
    #[error("text tokenization error: {0}")]
    TextTokenizationError(#[from] Box<dyn std::error::Error + Sync + Send>),
}

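/// Wraps a Hugging Face `tokenizers` [`Tokenizer`] and converts text into model input tensors.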
pub struct TextTokenizer {
    pub tokenizer: Tokenizer,
}

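/// A single tokenizer [`Encoding`], with helpers for extracting model inputs
/// (token IDs, attention mask, segment IDs) and for interpreting
/// question-answering model outputs.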
#[derive(Debug)]
pub struct EncodedText {
    pub encoding: Encoding,
}

impl TextTokenizer {
    pub fn new(tokenizer: Tokenizer) -> TextTokenizer {
        TextTokenizer { tokenizer }
    }

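    /// Loads a tokenizer from a serialized tokenizer configuration file
    /// (e.g. a Hugging Face `tokenizer.json`).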
    pub fn from_config<P: AsRef<Path>>(path: P) -> Result<TextTokenizer, std::io::Error> {
        let tokenizer_config_file = File::open(path)?;
        let tokenizer_config_reader = BufReader::new(tokenizer_config_file);
        let tokenizer = serde_json::from_reader(tokenizer_config_reader)?;
        Ok(TextTokenizer::new(tokenizer))
    }

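    /// Tokenizes a question together with its context as a sequence pair. If
    /// the pair exceeds the tokenizer's maximum length, the overflowing
    /// segments are returned as additional encodings, with the primary
    /// encoding first.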
    pub fn tokenize_question_answer(
        &self,
        question: &str,
        context: &str,
    ) -> Result<Vec<EncodedText>, PreprocessingError> {
        let mut encoding = self
            .tokenizer
            .encode(
                EncodeInput::Dual(
                    InputSequence::Raw(Cow::from(question)),
                    InputSequence::Raw(Cow::from(context)),
                ),
                true,
            )
            .map_err(PreprocessingError::TextTokenizationError)?;

        let mut overflowing = encoding.take_overflowing();
        overflowing.insert(0, encoding);

        Ok(overflowing
            .into_iter()
            .map(|x| EncodedText { encoding: x })
            .collect())
    }

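    /// Tokenizes a single piece of text, adding the tokenizer's special tokens.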
    fn tokenize(&self, text: &str) -> Result<EncodedText, PreprocessingError> {
        let encoding = self
            .tokenizer
            .encode(
                EncodeInput::Single(InputSequence::Raw(Cow::from(text))),
                true,
            )
            .map_err(PreprocessingError::TextTokenizationError)?;
        Ok(EncodedText { encoding })
    }

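    /// Decodes the token IDs in `encoding` back into a string, skipping special tokens.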
    pub fn decode(&self, encoding: &EncodedText) -> Result<String, PreprocessingError> {
        let ids: Vec<u32> = encoding.get_tokens().iter().map(|x| *x as u32).collect();
        self.tokenizer
            .decode(&ids, true)
            .map_err(PreprocessingError::TextTokenizationError)
    }

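    /// Tokenizes `text` and returns its attention mask, padded (or truncated)
    /// to the last dimension of `shape`, as an f32 tensor.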
    pub fn get_mask_input_for(
        &self,
        text: &str,
        shape: &Shape,
    ) -> Result<Tensor, PreprocessingError> {
        let segment_length = shape.dim(shape.rank() - 1) as usize;
        let tokenized = self.tokenize(text)?;
        let mut tokens = tokenized.get_mask();
        tokens.resize(segment_length, 0);
        let data = ndarray::Array::from_iter(tokens.iter().map(|x| (*x) as f32)).into_dyn();
        Ok(Tensor::F32(data))
    }

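    /// Tokenizes `text` and returns its token IDs, padded (or truncated) to
    /// the last dimension of `shape`, as an f32 tensor.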
    pub fn get_input_for(&self, text: &str, shape: &Shape) -> Result<Tensor, PreprocessingError> {
        let segment_length = shape.dim(shape.rank() - 1) as usize;
        let tokenized = self.tokenize(text)?;
        let mut tokens = tokenized.get_tokens();
        tokens.resize(segment_length, 0);
        let data = ndarray::Array::from_iter(tokens.iter().map(|x| (*x) as f32)).into_dyn();
        Ok(Tensor::F32(data))
    }
}

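/// An answer span extracted by [`EncodedText::get_answer`]: the answer text,
/// the tokens it spans, and a score (the product of the start and end logits).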
#[derive(Debug)]
pub struct Answer {
    pub text: String,
    pub tokens: Vec<String>,
    pub score: f32,
}

impl EncodedText {
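    /// Returns the attention mask (1 for tokens, 0 for padding) as i64 values.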
    pub fn get_mask(&self) -> Vec<i64> {
        self.encoding
            .get_attention_mask()
            .iter()
            .map(|x| *x as i64)
            .collect()
    }

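    /// Returns the token IDs (not the token strings) as i64 values.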
    pub fn get_tokens(&self) -> Vec<i64> {
        self.encoding.get_ids().iter().map(|x| *x as i64).collect()
    }

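    /// Returns the segment (token type) IDs as i64 values.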
    pub fn get_segments(&self) -> Vec<i64> {
        self.encoding
            .get_type_ids()
            .iter()
            .map(|x| *x as i64)
            .collect()
    }

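    /// Selects an answer span from a question-answering model's start and end
    /// logits: picks the highest-scoring start token, then the highest-scoring
    /// end token at or after it (skipping special tokens), and maps the span
    /// back to `context` using the encoding's offsets.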
    pub fn get_answer(&self, start_output: &[f32], end_output: &[f32], context: &str) -> Answer {
        // Find the start token with the highest logit, skipping special tokens
        let mut best_start_logit = f32::MIN;
        let mut best_start_idx: usize = 0;

        let input_tokens = self.encoding.get_tokens();
        let special_tokens_mask = self.encoding.get_special_tokens_mask();

        for (start_idx, start_logit) in start_output.iter().enumerate() {
            if start_idx > input_tokens.len() - 1 {
                break;
            }

            if special_tokens_mask[start_idx] == 1 {
                continue;
            }

            if *start_logit > best_start_logit {
                best_start_logit = *start_logit;
                best_start_idx = start_idx;
            }
        }

        // Find the end token with the highest logit, at or after the chosen start
        let mut best_end_logit = f32::MIN;
        let mut best_end_idx = best_start_idx;
        for (end_idx, end_logit) in end_output[best_start_idx..].iter().enumerate() {
            if (end_idx + best_start_idx) > input_tokens.len() - 1 {
                break;
            }

            if special_tokens_mask[end_idx + best_start_idx] == 1 {
                continue;
            }

            if *end_logit > best_end_logit {
                best_end_logit = *end_logit;
                best_end_idx = end_idx + best_start_idx;
            }
        }

        log::debug!("start index: {} ({})", best_start_idx, best_start_logit);
        log::debug!("end index: {} ({})", best_end_idx, best_end_logit);

        let chars: Vec<char> = context.chars().collect();
        let offsets = self.encoding.get_offsets();
        log::debug!("offsets: {:?}", &offsets[best_start_idx..=best_end_idx]);

        let answer_tokens: Vec<String> =
            self.encoding.get_tokens()[best_start_idx..=best_end_idx].to_vec();

        // Map the token span back to a character range in the context
        let min_offset = offsets[best_start_idx..=best_end_idx]
            .iter()
            .map(|o| o.0)
            .min()
            .unwrap();
        let max_offset = offsets[best_start_idx..=best_end_idx]
            .iter()
            .map(|o| o.1)
            .max()
            .unwrap();
        assert!(min_offset <= max_offset);
        if max_offset > chars.len() - 1 {
            return Answer {
                text: "".to_string(),
                tokens: vec![],
                score: 0.0,
            };
        }

        let answer = chars[min_offset..max_offset].iter().collect::<String>();

        Answer {
            text: answer,
            tokens: answer_tokens,
            score: best_start_logit * best_end_logit,
        }
    }
}

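/// Reads all lines from the file at `path`. Panics if the file cannot be
/// opened or a line cannot be read.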
pub fn get_lines(path: &Path) -> Vec<String> {
    let file = BufReader::new(File::open(path).unwrap());
    file.lines().map(|line| line.unwrap()).collect()
}
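
// A minimal usage sketch, assuming a local `tokenizer.json` compatible with the
// `tokenizers` crate; the file name and the question/context strings are
// illustrative only, so the test is ignored by default.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // requires a tokenizer.json in the working directory
    fn tokenize_question_answer_produces_encodings() {
        let tokenizer = TextTokenizer::from_config("tokenizer.json").unwrap();
        let encodings = tokenizer
            .tokenize_question_answer(
                "Who designed Rust?",
                "Rust was originally designed by Graydon Hoare.",
            )
            .unwrap();
        // At least the primary encoding is returned; overflowing segments follow it.
        assert!(!encodings.is_empty());
        assert!(!encodings[0].get_tokens().is_empty());
    }
}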