tokenizers/decoders/
strip.rs

1use crate::tokenizer::{Decoder, Result};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Deserialize, Clone, Debug, Serialize, Default)]
6/// Strip is a simple trick which converts tokens looking like `<0x61>`
7/// to pure bytes, and attempts to make them into a string. If the tokens
8/// cannot be decoded you will get � instead for each inconvertable byte token
9#[serde(tag = "type")]
10#[non_exhaustive]
11pub struct Strip {
12    pub content: char,
13    pub start: usize,
14    pub stop: usize,
15}
16
17impl Strip {
18    pub fn new(content: char, start: usize, stop: usize) -> Self {
19        Self {
20            content,
21            start,
22            stop,
23        }
24    }
25}
26
27impl Decoder for Strip {
28    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
29        Ok(tokens
30            .into_iter()
31            .map(|token| {
32                let chars: Vec<char> = token.chars().collect();
33
34                let mut start_cut = 0;
35                for (i, &c) in chars.iter().enumerate().take(self.start) {
36                    if c == self.content {
37                        start_cut = i + 1;
38                        continue;
39                    } else {
40                        break;
41                    }
42                }
43
44                let mut stop_cut = chars.len();
45                for i in 0..self.stop {
46                    let index = chars.len() - i - 1;
47                    if chars[index] == self.content {
48                        stop_cut = index;
49                        continue;
50                    } else {
51                        break;
52                    }
53                }
54
55                let new_token: String = chars[start_cut..stop_cut].iter().collect();
56                new_token
57            })
58            .collect())
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `tokens` with `decoder` and unwraps the result.
    fn decode_all(decoder: &Strip, tokens: &[&str]) -> Vec<String> {
        let owned = tokens.iter().map(|&t| t.to_owned()).collect();
        decoder.decode_chain(owned).unwrap()
    }

    #[test]
    fn decode() {
        // At most one leading 'H' is removed; other tokens pass through.
        let leading = Strip::new('H', 1, 0);
        assert_eq!(
            decode_all(&leading, &["Hey", " friend!", "HHH"]),
            vec!["ey", " friend!", "HH"]
        );

        // At most one trailing 'y' is removed.
        let trailing = Strip::new('y', 0, 1);
        assert_eq!(
            decode_all(&trailing, &["Hey", " friend!"]),
            vec!["He", " friend!"]
        );
    }
}