1use crate::{Result, WaxError};
2
/// Incrementally turns a sequence of token ids into text.
///
/// Tokens are buffered as they arrive and text is handed out in pieces by
/// `next_token`; whatever is still buffered at the end can be fetched with
/// `decode_rest`.
pub struct TokenOutputStream {
    /// Tokenizer used for every decode call.
    tokenizer: tokenizers::Tokenizer,
    /// Every token id received so far, in arrival order.
    tokens: Vec<u32>,
    /// Index into `tokens` where the decode window starts.
    prev_index: usize,
    /// Index into `tokens` marking the end of the part of the window whose
    /// text has already been returned to the caller.
    current_index: usize,
}
9
10impl TokenOutputStream {
11 pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
12 Self {
13 tokenizer,
14 tokens: Vec::new(),
15 prev_index: 0,
16 current_index: 0,
17 }
18 }
19
20 pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
21 let prev_text = if self.tokens.is_empty() {
22 String::new()
23 } else {
24 self.decode(&self.tokens[self.prev_index..self.current_index])?
25 };
26
27 self.tokens.push(token);
28 let text = self.decode(&self.tokens[self.prev_index..])?;
29 if text.len() > prev_text.len() && text.chars().last().is_some_and(char::is_alphanumeric) {
30 let (_, delta) = text.split_at(prev_text.len());
31 self.prev_index = self.current_index;
32 self.current_index = self.tokens.len();
33 Ok(Some(delta.to_string()))
34 } else {
35 Ok(None)
36 }
37 }
38
39 pub fn decode_rest(&self) -> Result<Option<String>> {
40 let prev_text = if self.tokens.is_empty() {
41 String::new()
42 } else {
43 self.decode(&self.tokens[self.prev_index..self.current_index])?
44 };
45 let text = self.decode(&self.tokens[self.prev_index..])?;
46 if text.len() > prev_text.len() {
47 let (_, delta) = text.split_at(prev_text.len());
48 Ok(Some(delta.to_string()))
49 } else {
50 Ok(None)
51 }
52 }
53
54 fn decode(&self, tokens: &[u32]) -> Result<String> {
55 self.tokenizer
56 .decode(tokens, true)
57 .map_err(WaxError::tokenizer)
58 }
59}
60
#[cfg(test)]
mod tests {
    use ahash::AHashMap;
    use tokenizers::{models::wordlevel::WordLevel, Tokenizer};

    use super::TokenOutputStream;

    /// Builds a word-level tokenizer over a four-entry vocabulary:
    /// `Hello` = 0, `world` = 1, `!` = 2 and the unknown token `[UNK]` = 3.
    fn tokenizer() -> Tokenizer {
        // Ids are assigned by position in the word list.
        let vocab: AHashMap<_, _> = ["Hello", "world", "!", "[UNK]"]
            .into_iter()
            .zip(0..)
            .map(|(word, id)| (word.to_string(), id))
            .collect();

        let word_level = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();

        Tokenizer::new(word_level)
    }

    /// Alphanumeric tokens are emitted as they arrive; the trailing `!` is
    /// withheld by `next_token` and only surfaces through `decode_rest`.
    #[test]
    fn streams_alphanumeric_tokens_and_flushes_punctuation_at_end() {
        let mut stream = TokenOutputStream::new(tokenizer());

        // (token id fed to the stream, text expected back from `next_token`)
        let steps = [(0, Some("Hello")), (1, Some(" world")), (2, None)];
        for (token, expected) in steps {
            assert_eq!(stream.next_token(token).unwrap().as_deref(), expected);
        }

        assert_eq!(stream.decode_rest().unwrap().as_deref(), Some(" !"));
    }
}