use monostate::MustBe;
use serde::{Deserialize, Serialize};

use crate::tokenizer::{Decoder, Result};

6#[derive(Deserialize, Clone, Debug, Serialize, Default)]
7#[non_exhaustive]
11pub struct ByteFallback {
12 #[serde(rename = "type")]
13 type_: MustBe!("ByteFallback"),
14}
15
16impl ByteFallback {
17 pub fn new() -> Self {
18 Self {
19 type_: MustBe!("ByteFallback"),
20 }
21 }
22}
23
24impl Decoder for ByteFallback {
25 fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
26 let mut new_tokens: Vec<String> = vec![];
27 let mut previous_byte_tokens: Vec<u8> = vec![];
28
29 for token in tokens {
30 let bytes = if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
31 if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) {
32 Some(byte)
33 } else {
34 None
35 }
36 } else {
37 None
38 };
39 if let Some(bytes) = bytes {
40 previous_byte_tokens.push(bytes);
41 } else {
42 if !previous_byte_tokens.is_empty() {
43 if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
44 new_tokens.push(string);
45 } else {
46 for _ in 0..previous_byte_tokens.len() {
47 new_tokens.push("�".into());
48 }
49 }
50 previous_byte_tokens.clear();
51 }
52 new_tokens.push(token);
53 }
54 }
55 if !previous_byte_tokens.is_empty() {
56 if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
57 new_tokens.push(string);
58 } else {
59 for _ in 0..previous_byte_tokens.len() {
60 new_tokens.push("�".into());
61 }
62 }
63 }
64
65 Ok(new_tokens)
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `input` through `decoder` and asserts the decoded pieces equal `expected`.
    fn assert_decodes(decoder: &ByteFallback, input: &[&str], expected: &[&str]) {
        let tokens: Vec<String> = input.iter().map(|&t| t.to_owned()).collect();
        let decoded = decoder.decode_chain(tokens).unwrap();
        assert_eq!(decoded, expected, "decoding {:?}", input);
    }

    #[test]
    fn decode() {
        let decoder = ByteFallback::new();

        // Ordinary tokens pass through untouched.
        assert_decodes(&decoder, &["Hey", "friend!"], &["Hey", "friend!"]);
        // A lone byte token that is valid UTF-8 by itself.
        assert_decodes(&decoder, &["<0x61>"], &["a"]);
        // Incomplete multi-byte sequences: one replacement char per byte token.
        assert_decodes(&decoder, &["<0xE5>"], &["�"]);
        assert_decodes(&decoder, &["<0xE5>", "<0x8f>"], &["�", "�"]);
        // Three byte tokens that together form a single character.
        assert_decodes(&decoder, &["<0xE5>", "<0x8f>", "<0xab>"], &["叫"]);
        assert_decodes(
            &decoder,
            &["<0xE5>", "<0x8f>", "<0xab>", "a"],
            &["叫", "a"],
        );
        // A regular token interrupts an incomplete sequence.
        assert_decodes(&decoder, &["<0xE5>", "<0x8f>", "a"], &["�", "�", "a"]);
    }
}