Skip to main content

wax_core/
token_stream.rs

1use crate::{Result, WaxError};
2
3pub struct TokenOutputStream {
4    tokenizer: tokenizers::Tokenizer,
5    tokens: Vec<u32>,
6    prev_index: usize,
7    current_index: usize,
8}
9
10impl TokenOutputStream {
11    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
12        Self {
13            tokenizer,
14            tokens: Vec::new(),
15            prev_index: 0,
16            current_index: 0,
17        }
18    }
19
20    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
21        let prev_text = if self.tokens.is_empty() {
22            String::new()
23        } else {
24            self.decode(&self.tokens[self.prev_index..self.current_index])?
25        };
26
27        self.tokens.push(token);
28        let text = self.decode(&self.tokens[self.prev_index..])?;
29        if text.len() > prev_text.len() && text.chars().last().is_some_and(char::is_alphanumeric) {
30            let (_, delta) = text.split_at(prev_text.len());
31            self.prev_index = self.current_index;
32            self.current_index = self.tokens.len();
33            Ok(Some(delta.to_string()))
34        } else {
35            Ok(None)
36        }
37    }
38
39    pub fn decode_rest(&self) -> Result<Option<String>> {
40        let prev_text = if self.tokens.is_empty() {
41            String::new()
42        } else {
43            self.decode(&self.tokens[self.prev_index..self.current_index])?
44        };
45        let text = self.decode(&self.tokens[self.prev_index..])?;
46        if text.len() > prev_text.len() {
47            let (_, delta) = text.split_at(prev_text.len());
48            Ok(Some(delta.to_string()))
49        } else {
50            Ok(None)
51        }
52    }
53
54    fn decode(&self, tokens: &[u32]) -> Result<String> {
55        self.tokenizer
56            .decode(tokens, true)
57            .map_err(WaxError::tokenizer)
58    }
59}
60
#[cfg(test)]
mod tests {
    use ahash::AHashMap;
    use tokenizers::{models::wordlevel::WordLevel, Tokenizer};

    use super::TokenOutputStream;

    /// Builds a word-level tokenizer whose vocabulary is
    /// `Hello = 0`, `world = 1`, `! = 2`, `[UNK] = 3`.
    fn tokenizer() -> Tokenizer {
        let words = ["Hello", "world", "!", "[UNK]"];
        let vocab: AHashMap<String, u32> = words
            .iter()
            .zip(0u32..)
            .map(|(word, id)| (word.to_string(), id))
            .collect();

        let builder = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string());
        Tokenizer::new(builder.build().unwrap())
    }

    #[test]
    fn streams_alphanumeric_tokens_and_flushes_punctuation_at_end() {
        let mut stream = TokenOutputStream::new(tokenizer());

        // Each step: the token pushed and the text expected back from it.
        let steps = [(0, Some("Hello")), (1, Some(" world")), (2, None)];
        for (token, expected) in steps {
            let emitted = stream.next_token(token).unwrap();
            assert_eq!(emitted.as_deref(), expected);
        }

        // The held-back punctuation only shows up when flushing the rest.
        let rest = stream.decode_rest().unwrap();
        assert_eq!(rest.as_deref(), Some(" !"));
    }
}
92}