tokenizers/decoders/
byte_fallback.rs

1use crate::tokenizer::{Decoder, Result};
2use monostate::MustBe;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Deserialize, Clone, Debug, Serialize, Default)]
7/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
8/// to pure bytes, and attempts to make them into a string. If the tokens
9/// cannot be decoded you will get � instead for each inconvertable byte token
10#[non_exhaustive]
11pub struct ByteFallback {
12    #[serde(rename = "type")]
13    type_: MustBe!("ByteFallback"),
14}
15
16impl ByteFallback {
17    pub fn new() -> Self {
18        Self {
19            type_: MustBe!("ByteFallback"),
20        }
21    }
22}
23
24impl Decoder for ByteFallback {
25    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
26        let mut new_tokens: Vec<String> = vec![];
27        let mut previous_byte_tokens: Vec<u8> = vec![];
28
29        for token in tokens {
30            let bytes = if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
31                if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) {
32                    Some(byte)
33                } else {
34                    None
35                }
36            } else {
37                None
38            };
39            if let Some(bytes) = bytes {
40                previous_byte_tokens.push(bytes);
41            } else {
42                if !previous_byte_tokens.is_empty() {
43                    if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
44                        new_tokens.push(string);
45                    } else {
46                        for _ in 0..previous_byte_tokens.len() {
47                            new_tokens.push("�".into());
48                        }
49                    }
50                    previous_byte_tokens.clear();
51                }
52                new_tokens.push(token);
53            }
54        }
55        if !previous_byte_tokens.is_empty() {
56            if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
57                new_tokens.push(string);
58            } else {
59                for _ in 0..previous_byte_tokens.len() {
60                    new_tokens.push("�".into());
61                }
62            }
63        }
64
65        Ok(new_tokens)
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `decode_chain`: each case pairs the input tokens
    /// with the pieces the decoder is expected to produce.
    #[test]
    fn decode() {
        let decoder = ByteFallback::new();
        let cases: Vec<(Vec<&str>, Vec<&str>)> = vec![
            (vec!["Hey", "friend!"], vec!["Hey", "friend!"]),
            (vec!["<0x61>"], vec!["a"]),
            (vec!["<0xE5>"], vec!["�"]),
            (vec!["<0xE5>", "<0x8f>"], vec!["�", "�"]),
            // 叫
            (vec!["<0xE5>", "<0x8f>", "<0xab>"], vec!["叫"]),
            (vec!["<0xE5>", "<0x8f>", "<0xab>", "a"], vec!["叫", "a"]),
            (vec!["<0xE5>", "<0x8f>", "a"], vec!["�", "�", "a"]),
        ];

        for (input, expected) in cases {
            let tokens: Vec<String> = input.iter().map(|t| t.to_string()).collect();
            let decoded = decoder.decode_chain(tokens).unwrap();
            assert_eq!(decoded, expected, "input: {:?}", input);
        }
    }
}