use std::collections::{HashMap, HashSet};

use crate::utils::SysRegex;
use serde::{Deserialize, Serialize};

use crate::tokenizer::{
    Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
    SplitDelimiterBehavior,
};
use crate::utils::macro_rules_attribute;
/// Converts bytes to unicode characters.
/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
pub(crate) fn bytes_char() -> HashMap<u8, char> {
    let mut bs: Vec<u8> = vec![];
    bs.extend(b'!'..=b'~');
    bs.extend(b'\xA1'..=b'\xAC');
    bs.extend(b'\xAE'..=b'\xFF');

    let mut cs: Vec<u32> = bs.iter().map(|i| *i as u32).collect();
    let mut n = 0;

    for b in 0..=255u8 {
        if !bs.contains(&b) {
            bs.push(b);
            cs.push(u32::pow(2, 8) + n);
            n += 1;
        }
    }

    // Safety: cs contains all the values from bs (between 0 and 255), plus remapped
    // values of the form 2^8 + n with n between 0 and 255. Both ranges are valid
    // Unicode scalar values, so from_u32_unchecked cannot produce an invalid char.
    bs.into_iter()
        .zip(cs)
        .map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
        .collect()
}
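
// Sketch of the resulting mapping (added for illustration): printable bytes map to
// themselves, while the remaining 68 bytes are shifted into U+0100..=U+0143. This is
// why a leading space renders as 'Ġ' (0x100 + 0x20) and a newline as 'Ċ'
// (0x100 + 0x0A) in byte-level tokens:
//
//     let table = bytes_char();
//     assert_eq!(table[&b'a'], 'a');
//     assert_eq!(table[&b' '], '\u{0120}'); // 'Ġ'
//     assert_eq!(table[&b'\n'], '\u{010A}'); // 'Ċ'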

lazy_static! {
    /// Regex that matches exactly one token.
    /// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L98
    static ref RE: SysRegex = SysRegex::new(
        r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    )
    .unwrap();
    static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
    static ref CHAR_BYTES: HashMap<char, u8> =
        bytes_char().into_iter().map(|(b, c)| (c, b)).collect();
}
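
// Rough splitting sketch (added for illustration; the exact behavior is exercised in
// the tests below): the regex isolates contractions, letter runs, digit runs,
// punctuation runs, and whitespace, each keeping at most one leading space:
//
//     "Hello there!" -> ["Hello", " there", "!"]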

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
/// Provides all the necessary steps to handle the BPE tokenization at the byte-level.
/// Takes care of all the required processing steps to transform a UTF-8 string as
/// needed before and after the BPE model does its job.
#[non_exhaustive]
pub struct ByteLevel {
    /// Whether to add a leading space to the first word. This lets us treat the
    /// leading word just like any other word.
    pub add_prefix_space: bool,
    /// Whether the post-processing step should trim offsets to avoid including
    /// whitespaces.
    pub trim_offsets: bool,

    /// Whether to use the standard GPT-2 regex for whitespace splitting.
    /// Set it to `false` to make the pre-tokenization skip the regex.
    #[serde(default = "default_true")]
    pub use_regex: bool,
}

fn default_true() -> bool {
    true
}

impl Default for ByteLevel {
    fn default() -> Self {
        Self {
            add_prefix_space: true,
            trim_offsets: true,
            use_regex: true,
        }
    }
}

impl ByteLevel {
    pub fn new(add_prefix_space: bool, trim_offsets: bool, use_regex: bool) -> Self {
        Self {
            add_prefix_space,
            trim_offsets,
            use_regex,
        }
    }

    /// Returns the alphabet used by this PreTokenizer.
    /// Since ByteLevel works, as its name suggests, at the byte level, it encodes
    /// each byte value to a unique visible character, for a total of 256 characters.
    pub fn alphabet() -> HashSet<char> {
        BYTES_CHAR.values().copied().collect()
    }

    #[must_use]
    pub fn add_prefix_space(mut self, v: bool) -> Self {
        self.add_prefix_space = v;
        self
    }

    #[must_use]
    pub fn trim_offsets(mut self, v: bool) -> Self {
        self.trim_offsets = v;
        self
    }

    #[must_use]
    pub fn use_regex(mut self, v: bool) -> Self {
        self.use_regex = v;
        self
    }
}
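
// Typical construction (a usage sketch, added for illustration):
//
//     let byte_level = ByteLevel::default()
//         .add_prefix_space(false)
//         .trim_offsets(true);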

impl PreTokenizer for ByteLevel {
    /// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode
    /// characters into their byte-level counterpart. It also splits the input
    /// according to the configured regex.
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        let re_ref: &SysRegex = &RE;
        pretokenized.split(|_, mut normalized| {
            if self.add_prefix_space && !normalized.get().starts_with(' ') {
                normalized.prepend(" ");
            }
            if self.use_regex {
                normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
            } else {
                Ok(vec![normalized])
            }
        })?;
        pretokenized.normalize(|normalized| {
            let s = normalized.get();
            let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len());
            let mut i = 0;
            for cur_char in s.chars() {
                let size = cur_char.len_utf8();
                let bytes = s[i..i + size].as_bytes();
                i += size;
                // Replace each char with one visible character per byte. The second
                // tuple element records the change in length: 0 for the first byte
                // (which replaces the original char), 1 for each additional byte.
                transformations.extend(
                    bytes
                        .iter()
                        .enumerate()
                        .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))),
                );
            }
            normalized.transform(transformations, 0);
            Ok(())
        })
    }
}

impl Decoder for ByteLevel {
    /// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level
    /// characters back to their unicode counterpart, before merging everything into
    /// a single String. This decoder consumes all the tokens and merges them in one
    /// step, because a single decoded token might be a byte sequence that is not
    /// representable as a String on its own.
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        let toks = tokens
            .into_iter()
            .flat_map(|t| {
                t.chars()
                    .try_fold(vec![], |mut acc, c| {
                        CHAR_BYTES.get(&c).map(|b| {
                            acc.push(*b);
                            acc
                        })
                    })
                    // If any char is not part of the byte-level alphabet (e.g. an
                    // added special token), keep the token's raw bytes as-is.
                    .unwrap_or_else(|| t.as_bytes().to_vec())
            })
            .collect::<Vec<u8>>();
        Ok(vec![String::from_utf8_lossy(&toks).to_string()])
    }
}
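
// Round-trip sketch (added for illustration): for tokens produced by this
// pre-tokenizer, `decode_chain` reconstructs the original text:
//
//     let out = ByteLevel::default()
//         .decode_chain(vec!["Hello".into(), "Ġworld".into()])
//         .unwrap();
//     assert_eq!(out, vec!["Hello world"]);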

impl PostProcessor for ByteLevel {
    /// As a `PostProcessor`, `ByteLevel` is in charge of trimming the offsets if
    /// necessary.
    fn added_tokens(&self, _is_pair: bool) -> usize {
        0
    }

    fn process_encodings(
        &self,
        mut encodings: Vec<Encoding>,
        _add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        if self.trim_offsets {
            for encoding in encodings.iter_mut() {
                process_offsets(encoding, self.add_prefix_space);
                encoding
                    .get_overflowing_mut()
                    .iter_mut()
                    .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
            }
        }
        for (i, encoding) in encodings.iter_mut().enumerate() {
            encoding.set_sequence_id(i);
        }
        Ok(encodings)
    }
}

pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
    encoding.process_tokens_with_offsets_mut(|(i, (token, offsets))| {
        let mut leading_spaces = token
            .chars()
            .take_while(|c| *c == BYTES_CHAR[&b' '] || c.is_whitespace())
            .count();
        let trailing_spaces = token
            .chars()
            .rev()
            .take_while(|c| *c == BYTES_CHAR[&b' '] || c.is_whitespace())
            .count();

        if leading_spaces > 0 || trailing_spaces > 0 {
            if leading_spaces > 0 {
                // A token counts as first if it is at index 0, or if the input was
                // pre-tokenized and its offsets start at the beginning of the string.
                let is_first = i == 0 || offsets.0 == 0;
                if is_first && add_prefix_space && leading_spaces == 1 {
                    // When processing the first token with `add_prefix_space`, the
                    // single leading space was added by us, so we should not trim
                    // it away from the offsets.
                    leading_spaces = 0;
                }
                offsets.0 = std::cmp::min(offsets.0 + leading_spaces, offsets.1);
            }
            if trailing_spaces > 0 && offsets.1 >= trailing_spaces {
                offsets.1 = std::cmp::max(offsets.1 - trailing_spaces, offsets.0);
            }
        }
    });
}
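
// Worked example (added for illustration, mirroring `processor_trims_offsets` below):
// the token "ĠĠHello" starts with two byte-level spaces, so its offsets (11, 18) are
// narrowed to cover "Hello" only:
//
//     offsets.0 = min(11 + 2, 18) = 13  ->  (13, 18)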

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::{
        Decoder, Encoding, OffsetReferential, OffsetType, PostProcessor, PreTokenizedString,
        PreTokenizer,
    };
    use std::iter::FromIterator;

    #[test]
    fn pre_tokenization() {
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hello", (0, 5)),
                ("Ġmy", (5, 8)),
                ("Ġfriend", (8, 15)),
                (",", (15, 16)),
                ("Ġhow", (16, 20)),
                ("Ġis", (20, 23)),
                ("Ġyour", (23, 28)),
                ("Ġday", (28, 32)),
                ("Ġgoing", (32, 38)),
                ("?", (38, 39))
            ]
        );
    }

    #[test]
    fn pre_tokenization_no_regex() {
        let bytelevel = ByteLevel::default().use_regex(false);
        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("ĠHelloĠmyĠfriend,ĠhowĠisĠyourĠdayĠgoing?", (0, 39))]
        );
    }

    #[test]
    fn decoding() {
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        assert_eq!(
            bytelevel
                .decode_chain(
                    vec![
                        "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing",
                        "?"
                    ]
                    .into_iter()
                    .map(|s| s.into())
                    .collect::<Vec<String>>()
                )
                .unwrap(),
            vec!["Hello my friend, how is your day going?"]
        );
    }

    #[test]
    fn add_prefix_space() {
        let bytelevel = ByteLevel::default().add_prefix_space(true);
        for s in &[
            " Hello my friend, how is your day going?",
            "Hello my friend, how is your day going?",
        ] {
            let mut pretokenized = PreTokenizedString::from(*s);
            bytelevel.pre_tokenize(&mut pretokenized).unwrap();
            assert_eq!(
                pretokenized
                    .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                    .into_iter()
                    .map(|(s, o, _)| (s, o))
                    .collect::<Vec<_>>(),
                vec![
                    ("ĠHello", (0, 7)),
                    ("Ġmy", (7, 11)),
                    ("Ġfriend", (11, 19)),
                    (",", (19, 20)),
                    ("Ġhow", (20, 25)),
                    ("Ġis", (25, 29)),
                    ("Ġyour", (29, 35)),
                    ("Ġday", (35, 40)),
                    ("Ġgoing", (40, 47)),
                    ("?", (47, 48))
                ]
            );
        }
    }

    #[test]
    fn decode_works_on_separated_tokens() {
        let samples = vec![
            "A Nuskhuri abbreviation of იესუ ქრისტე ( iesu kriste ) \" Jesus Christ \"",
            "An equal number have descenders , like p or q in English \
             : გ , დ , ე , ვ , კ , ლ , ჟ , ტ , უ , ფ , ღ , ყ , ც",
        ];

        let bytelevel = ByteLevel::default().add_prefix_space(false);
        for sample in samples {
            let mut pretokenized = PreTokenizedString::from(sample);
            bytelevel.pre_tokenize(&mut pretokenized).unwrap();
            let separated_tokens = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .iter()
                .flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
                .collect::<Vec<_>>();
            assert_eq!(
                sample,
                bytelevel.decode_chain(separated_tokens).unwrap().join("")
            );
        }
    }

    #[test]
    fn handling_of_newlines() {
        let mut pretokenized = PreTokenizedString::from("Hello there\nHello there");
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hello", (0, 5)),
                ("Ġthere", (5, 11)),
                ("Ċ", (11, 12)),
                ("Hello", (12, 17)),
                ("Ġthere", (17, 23))
            ]
        );
    }

    #[test]
    fn handling_of_multiple_whitespaces() {
        let mut pretokenized = PreTokenizedString::from("Hello there       dear");
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hello", (0, 5)),
                ("Ġthere", (5, 11)),
                ("ĠĠĠĠĠĠ", (11, 17)),
                ("Ġdear", (17, 22))
            ]
        );
    }

    #[test]
    fn offsets_when_char_split_up() {
        let input = "i⭢j";
        let mut pretokenized = PreTokenizedString::from(input);
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("i", (0, 1)), ("âŃ¢", (1, 4)), ("j", (4, 5))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("i", (0, 1)), ("âŃ¢", (1, 7)), ("j", (7, 8))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(_, o, _)| &input[o.0..o.1])
                .collect::<Vec<_>>(),
            vec!["i", "⭢", "j"]
        );
    }

    #[test]
    fn processor_trims_offsets_pre_tokenized() {
        // If the user uses `is_pretokenized=True`, we might get offsets that start
        // at the beginning of the string but belong to tokens that are NOT first:
        // their single leading space must be kept.
        let mut encoding = Encoding::new(
            vec![0; 5],
            vec![],
            vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()],
            vec![],
            vec![(0, 1), (1, 4), (0, 1), (1, 4)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );
        process_offsets(&mut encoding, true);
        assert_eq!(
            encoding,
            Encoding::new(
                vec![0; 5],
                vec![],
                vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()],
                vec![],
                vec![(0, 1), (1, 4), (0, 1), (1, 4)],
                vec![],
                vec![],
                vec![],
                HashMap::new(),
            )
        );
    }

    #[test]
    fn processor_trims_offsets() {
        let start = Encoding::new(
            vec![0; 5],
            vec![],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );
        let expected = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5)]),
        );

        let bytelevel = ByteLevel::default().trim_offsets(true);
        assert_eq!(
            expected,
            bytelevel.process(start.clone(), None, false).unwrap()
        );

        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
            ],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
        );
        assert_eq!(
            pair_expected,
            bytelevel
                .process(start.clone(), Some(start), false)
                .unwrap()
        );
    }

    #[test]
    fn decode_unknown_characters() {
        let byte_level = ByteLevel::default();
        assert_eq!(
            byte_level
                .decode_chain(vec![
                    "Hello".into(),
                    "Ġthere".into(),
                    "Ġdear".into(),
                    "Ġfriend!".into(),
                    "Ġ".into(),
                    "[PA D]".into()
                ])
                .unwrap(),
            vec!["Hello there dear friend! [PA D]"]
        );
    }

    #[test]
    fn deserialization() {
        // `use_regex` is optional and defaults to `true` when missing.
        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false}"#,
        )
        .unwrap();
        assert!(byte_level.use_regex);

        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#,
        )
        .unwrap();
        assert!(byte_level.use_regex);

        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
        )
        .unwrap();
        assert!(!byte_level.use_regex);
    }
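
    // Added sketch (not part of the original test suite): basic invariants of the
    // byte-to-char table assumed throughout this file.
    #[test]
    fn bytes_char_invariants() {
        let table = bytes_char();
        // Every possible byte gets a distinct, visible character.
        assert_eq!(table.len(), 256);
        assert_eq!(ByteLevel::alphabet().len(), 256);
        // Printable ASCII maps to itself; control/whitespace bytes are remapped.
        assert_eq!(table[&b'a'], 'a');
        assert_eq!(table[&b' '], '\u{0120}'); // 'Ġ'
        assert_eq!(table[&b'\n'], '\u{010A}'); // 'Ċ'
    }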
}