1use crate::{
2 normalizer::Range, Encoding, NormalizedString, OffsetReferential, Offsets, Result, Token,
3};
4use std::collections::HashMap;
5
/// Selects how offsets are expressed when extracted from a
/// `PreTokenizedString` (see `into_encoding` / `get_splits`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OffsetType {
    /// Offsets are byte positions in the string.
    Byte,
    /// Byte offsets are converted to character (code point) indices
    /// using `BytesToCharOffsetConverter`.
    Char,
    /// No conversion is applied: `get_splits` returns the raw byte offsets,
    /// and `into_encoding` does not track per-token offsets at all.
    None,
}
13
/// One fragment of a `PreTokenizedString`: a normalized slice of the
/// original string, plus the tokens produced from it once tokenized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Split {
    // The normalized string for this fragment.
    normalized: NormalizedString,
    // `None` until `PreTokenizedString::tokenize` fills it in. Splits that
    // already carry tokens are skipped by later `split`/`normalize` passes.
    tokens: Option<Vec<Token>>,
}
28
29impl From<NormalizedString> for Split {
30 fn from(n: NormalizedString) -> Self {
31 Self {
32 normalized: n,
33 tokens: None,
34 }
35 }
36}
37
38impl From<(NormalizedString, Option<Vec<Token>>)> for Split {
39 fn from(f: (NormalizedString, Option<Vec<Token>>)) -> Self {
40 Self {
41 normalized: f.0,
42 tokens: f.1,
43 }
44 }
45}
46
/// Holds a string in the middle of the pre-tokenization pipeline: the
/// original text plus the current list of `Split`s derived from it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreTokenizedString {
    // The original string, kept around so byte offsets can later be mapped
    // back to character positions.
    original: String,
    // Current splits; starts as a single split covering the whole string
    // (see the `From<NormalizedString>` impl below).
    splits: Vec<Split>,
}
58
59impl PreTokenizedString {
60 pub fn split<F, U, R>(&mut self, mut split_fn: F) -> Result<()>
73 where
74 F: FnMut(usize, NormalizedString) -> Result<U>,
75 U: IntoIterator<Item = R>,
76 R: Into<Split>,
77 {
78 let mut new_splits = Vec::with_capacity(self.splits.len());
80 for (i, original_split) in self.splits.drain(..).enumerate() {
81 if original_split.tokens.is_some() {
82 new_splits.push(original_split);
83 continue;
84 }
85
86 new_splits.extend(
87 split_fn(i, original_split.normalized)?
88 .into_iter()
89 .filter_map(|split| {
90 let split: Split = split.into();
91 if split.normalized.is_empty() {
92 None
93 } else {
94 Some(split)
95 }
96 }),
97 );
98 }
99 self.splits = new_splits;
100
101 Ok(())
102 }
103
104 pub fn normalize<F>(&mut self, normalize: F) -> Result<()>
107 where
108 F: Fn(&mut NormalizedString) -> Result<()>,
109 {
110 for split in self.splits.iter_mut().filter(|s| s.tokens.is_none()) {
111 normalize(&mut split.normalized)?;
112 }
113 Ok(())
114 }
115
116 pub fn tokenize<F>(&mut self, tokenize: F) -> Result<()>
119 where
120 F: Fn(&NormalizedString) -> Result<Vec<Token>>,
121 {
122 for split in self.splits.iter_mut().filter(|s| s.tokens.is_none()) {
123 split.tokens = Some(tokenize(&split.normalized)?);
124 }
125
126 Ok(())
127 }
128
129 pub fn into_encoding(
137 self,
138 word_idx: Option<u32>,
139 type_id: u32,
140 offset_type: OffsetType,
141 ) -> Result<Encoding> {
142 if self.splits.is_empty() {
143 Ok(Encoding::default())
144 } else if !self.splits.iter().all(|split| split.tokens.is_some()) {
145 Err("Split has not been tokenized, call `PreTokenizedString::tokenize` first".into())
146 } else {
147 let offset_converter = match offset_type {
148 OffsetType::Char => Some(BytesToCharOffsetConverter::new(&self.original)),
149 OffsetType::Byte => None,
150 OffsetType::None => {
151 let tokens = self
152 .splits
153 .into_iter()
154 .flat_map(|split| {
155 split.tokens.unwrap().into_iter().map(|token| {
156 (token.id, String::with_capacity(0), (0, 0), None, 0)
158 })
159 })
160 .collect();
161 return Ok(tokens);
162 }
163 };
164
165 Ok(self
166 .splits
167 .into_iter()
168 .enumerate()
169 .flat_map(|(idx, split)| {
170 let normalized = split.normalized;
171 let offsets = normalized.offsets_original();
172 let offset_converter = &offset_converter;
173
174 split.tokens.unwrap().into_iter().map(move |token| {
175 let mut offsets = normalized
176 .convert_offsets(Range::Normalized(token.offsets.0..token.offsets.1))
177 .map_or(token.offsets, |range| {
178 (offsets.0 + range.start, offsets.0 + range.end)
179 });
180
181 if let Some(converter) = offset_converter {
183 offsets = converter.convert(offsets).unwrap_or(offsets);
184 }
185
186 (
187 token.id,
188 token.value,
189 offsets,
190 if word_idx.is_some() {
191 word_idx
192 } else {
193 Some(idx as u32)
194 },
195 type_id,
196 )
197 })
198 })
199 .collect())
200 }
201 }
202
203 pub fn get_splits(
207 &self,
208 offset_ref: OffsetReferential,
209 offset_type: OffsetType,
210 ) -> Vec<(&str, Offsets, &Option<Vec<Token>>)> {
211 let offset_converter = match offset_type {
212 OffsetType::Char => Some(BytesToCharOffsetConverter::new(&self.original)),
213 OffsetType::Byte => None,
214 OffsetType::None => None,
215 };
216
217 let mut offset = 0;
218 self.splits
219 .iter()
220 .map(|split| {
221 let mut offsets = match offset_ref {
222 OffsetReferential::Original => split.normalized.offsets_original(),
223 OffsetReferential::Normalized => {
224 let len = split.normalized.len();
225 offset += len;
226 (offset - len, offset)
227 }
228 };
229
230 if let Some(ref converter) = offset_converter {
232 offsets = converter.convert(offsets).unwrap_or(offsets);
233 }
234
235 (split.normalized.get(), offsets, &split.tokens)
236 })
237 .collect()
238 }
239}
240
241impl From<NormalizedString> for PreTokenizedString {
242 fn from(s: NormalizedString) -> Self {
243 Self {
244 original: s.get_original().to_owned(),
245 splits: vec![Split {
246 normalized: s,
247 tokens: None,
248 }],
249 }
250 }
251}
252
253impl From<&str> for PreTokenizedString {
254 fn from(s: &str) -> Self {
255 let normalized: NormalizedString = s.into();
256 normalized.into()
257 }
258}
259
260impl From<String> for PreTokenizedString {
261 fn from(s: String) -> Self {
262 let normalized: NormalizedString = s.into();
263 normalized.into()
264 }
265}
266
/// Converts byte offsets into character (code point) offsets for a given
/// string (built once per string in `into_encoding` / `get_splits`).
struct BytesToCharOffsetConverter {
    // Maps every byte position of the source string to the index of the
    // character that contains that byte.
    map: HashMap<usize, usize>,
}
270
271impl BytesToCharOffsetConverter {
272 pub fn new(sequence: &str) -> Self {
273 Self {
274 map: sequence
275 .char_indices()
276 .enumerate()
277 .flat_map(|(i, (b, c))| {
278 let mut n = 0;
279 std::iter::repeat_with(move || {
280 let o = (b + n, i);
281 n += 1;
282 o
283 })
284 .take(c.len_utf8())
285 })
286 .collect(),
287 }
288 }
289
290 pub fn convert(&self, offsets: Offsets) -> Option<Offsets> {
291 match (self.map.get(&offsets.0), self.map.get(&offsets.1)) {
292 (Some(start), Some(end)) => Some((*start, *end)),
293 (Some(start), None) => {
295 let last = self.map.get(&(offsets.1 - 1)).copied().unwrap_or(start + 1);
297 Some((*start, last + 1))
298 }
299 _ => None,
300 }
301 }
302}