// post-increment: yields the old value of $e, then adds 1. Defined up here
// because macro_rules! macros must appear textually before their first use
// (it is invoked in TokenDict::pairs below).
macro_rules! post_inc{
 ($e:expr)=>{{
  let old=$e;
  $e+=1;
  old
 }};
}

impl AsMut<Self> for TokenDict{
 fn as_mut(&mut self)->&mut Self{self}
}
impl AsRef<Self> for TokenDict{
 fn as_ref(&self)->&Self{self}
}
impl Default for TokenDict{
 fn default()->Self{std::iter::empty::<&[u8]>().collect()}
}
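// Ids 0..=255 always denote the single-byte tokens in SINGLE_TOKENS; ids >=256
// index the learned `tokens` vec at `id-256`. The fold/rfold impls below split
// `range` at 256 with `min`/`saturating_sub` so each half folds over a plain slice.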
impl DoubleEndedIterator for DictIntoIter{
 fn next_back(&mut self)->Option<Self::Item>{
  self.range.next_back().map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
 }
 fn nth_back(&mut self,n:usize)->Option<Self::Item>{
  self.range.nth_back(n).map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
 }
 fn rfold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
  let (range,tokens)=(self.range,self.tokens);
  let (start,stop)=(range.start,range.end);

  init=tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().cloned().rfold(init,&mut f);
  SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().cloned().rfold(init,f)
 }
}
impl ExactSizeIterator for DictIntoIter{
 fn len(&self)->usize{self.range.len()}
}
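// Indexing by u32 yields the token's bytes; indexing by usize yields the Token
// itself. Both follow the same split: below 256 is a single byte, otherwise a
// lookup in `tokens` at `ix-256`.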
impl Index<u32> for TokenDict{
 fn index(&self,ix:u32)->&Self::Output{
  let ix=ix as usize;
  if ix<256{&SINGLE_BYTES[ix..ix+1]}else{&self.tokens[ix-256]}
 }
 type Output=[u8];
}
impl Index<usize> for TokenDict{
 fn index(&self,ix:usize)->&Self::Output{
  if ix<256{&SINGLE_TOKENS[ix]}else{&self.tokens[ix-256]}
 }
 type Output=Token;
}
impl IntoIterator for TokenDict{
 fn into_iter(self)->Self::IntoIter{
  DictIntoIter{range:0..self.len(),tokens:self.tokens}
 }
 type IntoIter=DictIntoIter;
 type Item=Token;
}
impl Iterator for DictIntoIter{
 fn count(self)->usize{self.range.count()}
 fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
  let (range,tokens)=(self.range,self.tokens);
  let (start,stop)=(range.start,range.end);

  init=SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().cloned().fold(init,&mut f);
  tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().cloned().fold(init,f)
 }
 fn last(mut self)->Option<Self::Item>{self.next_back()}
 fn next(&mut self)->Option<Self::Item>{
  self.range.next().map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
 }
 fn nth(&mut self,n:usize)->Option<Self::Item>{
  self.range.nth(n).map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
 }
 fn size_hint(&self)->(usize,Option<usize>){self.range.size_hint()}
 type Item=u32;
}
#[cfg(feature="serial")]
impl Serialize for TokenDict{
 fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{
  // the 256 single-byte tokens are serialized too; FromIterator filters them back out on deserialize
  let data:Vec<Vec<u8>>=self.iter().map(|t|t.to_vec()).collect();
  data.serialize(serializer)
 }
}
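// Round-trip example (a minimal sketch; the exact ids depend on insertion
// order, but detokenization always restores the original bytes):
//
//     let dict:TokenDict=["he","llo"].into_iter().collect();
//     let ids:Vec<u32>=dict.string_to_tokens("hello"); // [256,257]
//     assert_eq!(dict.tokens_to_string(&ids),"hello");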
impl TokenDict{
 pub fn detokenize<I:IntoIterator>(&self,tokens:I)->Detokenization<I::IntoIter> where I::Item:Val<u32>{
  Detokenization{inner:tokens.into_iter().fuse(),maxtokenlen:self.maxtokenlen,position:1,tokenid:0,tokens:self.tokens.clone()}
 }
 pub fn detoken_iter<I:IntoIterator>(&self,tokens:I)->impl Iterator<Item=Token> where I::Item:Val<u32>{
  let tokenizer=self.clone();
  tokens.into_iter().map(move|id|tokenizer[id.val() as usize].clone())
 }
 pub fn detokenize_str<I:IntoIterator>(&self,tokens:I)->impl Iterator<Item=char> where I::Item:Val<u32>{
  UTF8CharIter::from(self.detokenize(tokens)).map(|r|r.unwrap_or(char::REPLACEMENT_CHARACTER))
 }
 pub fn detokenize_string<I:IntoIterator>(&self,tokens:I)->String where I::Item:Val<u32>{self.detokenize_str(tokens).collect()}
 pub fn frequencies<I:IntoIterator,O:Into<Option<Vec<usize>>>>(&self,data:I,freq:O)->Vec<usize> where I::Item:Val<u8>{
  let mut freq=freq.into().unwrap_or_default();
  if freq.len()<self.len(){freq.resize(self.len(),0)}

  for t in self.tokenize(data){freq[t as usize]+=1}
  freq
 }
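 // Looks up the id of a multi-byte token: the first byte selects the trie, the
 // remaining bytes are the key. Returns None for tokens of length <=1, whose
 // id is the byte value itself.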
 pub fn get_id(&self,token:&[u8])->Option<u32>{self.ids[*token.first()? as usize].get(token.iter().copied().skip(1)).copied()}
 pub fn iter(&self)->DictIter<'_>{
  DictIter{range:0..self.len(),tokens:&self.tokens}
 }
 pub fn len(&self)->usize{self.tokens.len()+256}
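 // Counts how often each adjacent token pair occurs in `data`, the statistic a
 // BPE-style trainer maximizes when picking the next merge. Pairs not yet in
 // the dict get fresh provisional ids starting at self.len().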
 pub fn pairs<I:IntoIterator,O:Into<Option<HashMap<Token,usize>>>>(&self,data:I,freq:O)->HashMap<Token,usize> where I::Item:Val<u8>{
  let mut freq=freq.into().unwrap_or_default();
  let mut nexttokenid=freq.keys().map(|t|t.id()+1).chain([self.len() as u32]).max().unwrap();
  let mut previous:Option<u32>=None;
  let mut temp:Vec<u8>=Vec::new();

  self.tokenize(data).for_each(|id|{
   if let Some(previous)=previous{
    temp.clear();
    temp.extend(self[previous as usize].clone());
    temp.extend(self[id as usize].clone());

    if let Some(f)=freq.get_mut(temp.as_slice()){
     *f+=1
    }else{
     let newid=if let Some(id)=self.get_id(&temp){id}else{post_inc!(nexttokenid)};
     let token=Token::new(newid,Some(Arc::from(temp.as_slice())));

     freq.insert(token,1);
    }
   }
   previous=Some(id);
  });
  freq
 }
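 // A minimal training loop built on `pairs` and `push` (a hypothetical driver,
 // not part of this crate; assumes a `corpus:Vec<u8>`, a merge budget `merges`,
 // and that Token derefs to its bytes, as the rest of this file relies on):
 //
 //     let mut dict=TokenDict::default();
 //     for _ in 0..merges{
 //         let freq=dict.pairs(corpus.iter().copied(),None);
 //         match freq.into_iter().max_by_key(|&(_,n)|n){
 //             Some((best,_))=>dict.push(&best[..]),
 //             None=>break,
 //         }
 //     }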
 pub fn push<A:AsRef<[u8]>>(&mut self,token:A){self.extend(Some(token))}
 pub fn string_to_tokens<S:?Sized+AsRef<str>>(&self,input:&S)->Vec<u32>{self.tokenize(input.as_ref().as_bytes()).collect()}
 pub fn token_iter<I:IntoIterator>(&self,bytes:I)->impl Iterator<Item=Token> where I::Item:Val<u8>{
  let tokenizer=self.clone();
  self.tokenize(bytes).map(move|id|tokenizer[id as usize].clone())
 }
 pub fn tokenize<I:IntoIterator>(&self,bytes:I)->Tokenization<I::IntoIter> where I::Item:Val<u8>{
  Tokenization{ids:self.ids.clone(),inner:bytes.into_iter().fuse(),state:VecDeque::with_capacity(self.maxtokenlen)}
 }
 pub fn tokenize_str<'a,S:?Sized+AsRef<str>>(&self,input:&'a S)->Tokenization<SliceIter<'a,u8>>{self.tokenize(input.as_ref().as_bytes())}
 pub fn tokenize_string(&self,input:String)->Tokenization<VecIntoIter<u8>>{self.tokenize(Vec::from(input))}
 pub fn tokens_to_string<V:?Sized+AsRef<[u32]>>(&self,input:&V)->String{String::from_utf8_lossy(&self.detokenize(input.as_ref()).collect::<Vec<u8>>()).to_string()}
}
#[cfg(feature="serial")]
impl<'a> Deserialize<'a> for TokenDict{
 fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{
  let data:Vec<Vec<u8>>=Deserialize::deserialize(deserializer)?;
  Ok(data.into_iter().collect())
 }
}
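// DictIter is the borrowing counterpart of DictIntoIter and uses the same
// split of ids at 256.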
impl<'a> DoubleEndedIterator for DictIter<'a>{
 fn next_back(&mut self)->Option<Self::Item>{
  self.range.next_back().map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
 }
 fn nth_back(&mut self,n:usize)->Option<Self::Item>{
  self.range.nth_back(n).map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
 }
 fn rfold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
  let (range,tokens)=(self.range,self.tokens);
  let (start,stop)=(range.start,range.end);

  init=tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().rfold(init,&mut f);
  SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().rfold(init,f)
 }
}
impl<'a> ExactSizeIterator for DictIter<'a>{
 fn len(&self)->usize{self.range.len()}
}
impl<'a> Iterator for DictIter<'a>{
 fn count(self)->usize{self.range.count()}
 fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
  let (range,tokens)=(self.range,self.tokens);
  let (start,stop)=(range.start,range.end);

  init=SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().fold(init,&mut f);
  tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().fold(init,f)
 }
 fn last(mut self)->Option<Self::Item>{self.next_back()}
 fn next(&mut self)->Option<Self::Item>{
  self.range.next().map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
 }
 fn nth(&mut self,n:usize)->Option<Self::Item>{
  self.range.nth(n).map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
 }
 fn size_hint(&self)->(usize,Option<usize>){self.range.size_hint()}
 type Item=&'a Token;
}
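// Extend and FromIterator both ignore candidates of length <=1: every single
// byte is already a token (ids 0..=255), so only longer sequences are stored.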
impl<A:AsRef<[u8]>> Extend<A> for TokenDict{
 fn extend<I:IntoIterator<Item=A>>(&mut self,iter:I){
  let (ids,tokens)=(Arc::make_mut(&mut self.ids),Arc::make_mut(&mut self.tokens));
  let maxtokenlen=&mut self.maxtokenlen;

  iter.into_iter().filter(|a|a.as_ref().len()>1).for_each(|a|{
   let id=u32::try_from(tokens.len()+256).unwrap();
   let token:Arc<[u8]>=Arc::from(a.as_ref());

   ids[token[0] as usize].insert(token.iter().copied().skip(1),id);
   *maxtokenlen=(*maxtokenlen).max(token.len());
   tokens.push(Token::new(id,Some(token)))
  });
 }
}
impl<A:AsRef<[u8]>> FromIterator<A> for TokenDict{
 fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{
  let mut maxtokenlen=1;
  let mut ids:[Trie<_,_>;256]=std::array::from_fn(|_|Trie::new());
  let tokens:Vec<Token>=iter.into_iter().filter(|t|t.as_ref().len()>1).enumerate().map(|(n,t)|{
   let id=u32::try_from(n+256).unwrap();
   let token:Arc<[u8]>=Arc::from(t.as_ref());

   ids[token[0] as usize].insert(token.iter().copied().skip(1),id);
   maxtokenlen=maxtokenlen.max(token.len());

   Token::new(id,Some(token))
  }).collect();

  let (ids,tokens)=(Arc::new(ids),Arc::new(tokens));
  Self{ids,maxtokenlen,tokens}
 }
}
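// Detokenization streams the bytes of each token id. `tokenid`/`position`
// track the token currently being emitted so iteration can pause mid-token;
// `position` starts at 1 (with tokenid 0) so a fresh iterator emits nothing
// until the first id is pulled from `inner`.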
impl<I:Iterator> Iterator for Detokenization<I> where I::Item:Val<u32>{
 fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
  let Self{inner,position,tokenid,tokens,..}=self;

  // flush any bytes still pending from the partially emitted current token
  if tokenid<256{if position==0{init=f(init,tokenid as u8)}}else{init=tokens[tokenid as usize-256].iter().skip(position).fold(init,|acc,&b|f(acc,b))}
  inner.map(Val::val).fold(init,|acc,tokenid|if tokenid<256{f(acc,tokenid as u8)}else{tokens[tokenid as usize-256].iter().fold(acc,|acc,&b|f(acc,b))})
 }
 fn next(&mut self)->Option<u8>{
  let (inner,position)=(&mut self.inner,&mut self.position);
  let tokenid=&mut self.tokenid;
  let tokens=&self.tokens;

  if let Some(b)=if *tokenid<256{(*position==0).then_some(*tokenid as u8)}else{tokens[*tokenid as usize-256].get(*position).map(|&b|b)}{
   *position+=1;
   b
  }else{
   *position=1;
   *tokenid=inner.map(Val::val).next()?;
   if *tokenid<256{*tokenid as u8}else{tokens[*tokenid as usize-256][0]}
  }.into()
 }
 fn size_hint(&self)->(usize,Option<usize>){
  let (lowertokens,uppertokens)=self.inner.size_hint();
  // account for bytes still pending from the partially emitted current token
  let pending=if self.tokenid<256{usize::from(self.position==0)}else{self.tokens[self.tokenid as usize-256].len().saturating_sub(self.position)};

  (lowertokens+pending,uppertokens.map(|h|h*self.maxtokenlen+pending))
 }
 type Item=u8;
}
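// Tokenization is greedy longest-prefix matching: buffer up to maxtokenlen
// bytes, match the longest known token via the first byte's trie, and fall
// back to emitting a single byte (id = byte value) when nothing matches.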
impl<I:Iterator> Iterator for Tokenization<I> where I::Item:Val<u8>{
 fn next(&mut self)->Option<u32>{
  let (inner,state)=(&mut self.inner,&mut self.state);
  let ids=&self.ids;
  state.extend(inner.map(Val::val).take(state.capacity()-state.len()));
  if state.is_empty(){return None}

  // longest-prefix match on the buffered bytes; fall back to a single byte
  let (tokenlen,&tokenid)=match ids[state[0] as usize].find_longest_prefix_len(state.iter().copied().skip(1)).filter(|(tokenlen,_tokenid)|*tokenlen>0){
   Some(t)=>t,
   None=>return Some(state.pop_front().unwrap() as u32),
  };
  state.drain(..tokenlen+1);
  Some(tokenid)
 }
 fn size_hint(&self)->(usize,Option<usize>){
  let (lowerbytes,upperbytes)=self.inner.size_hint();
  let maxtoken=self.state.capacity();
  let statelen=self.state.len();

  ((lowerbytes+statelen).div_ceil(maxtoken),upperbytes.map(|b|b+statelen))
 }
 type Item=u32;
}

#[cfg(test)]
mod tests{
 #[test]
 fn tokenizer_iter(){
  let tokenizer:TokenDict=["aa","bb","cc"].into_iter().collect();
  let t2:TokenDict=tokenizer.iter().collect();
  assert_eq!(tokenizer.tokenize_str("ccaabb").collect::<Vec<_>>(),t2.tokenize_str("ccaabb").collect::<Vec<_>>());
 }
 #[test]
 fn bytes_only(){
  let teststring="oishsoghohhduihahdufghud";
  let tokenizer=TokenDict::default();
  let tokens:Vec<u32>=tokenizer.tokenize_str(teststring).collect();
  let detokens:Vec<u8>=tokenizer.detokenize(tokens).collect();

  assert_eq!(detokens.as_slice(),teststring.as_bytes());
 }
 #[test]
 fn there_are_tokens_yay(){
  let teststring="there are tokens! yay";
  let tokenizer:TokenDict=["there","are","tokens","yay"].into_iter().collect();
  let tokens:Vec<u32>=tokenizer.tokenize(teststring.bytes()).collect();
  let detokens:Vec<u8>=tokenizer.detokenize(&tokens).collect();

  assert_eq!(tokens.len(),8);
  assert_eq!(detokens.as_slice(),teststring.as_bytes());
 }
 #[test]
 fn default_dict_detokenize_empty(){
  let dict=TokenDict::default();
  let out:Vec<u8>=dict.detokenize(Vec::<u32>::new()).collect();

  assert!(out.is_empty(),"detokenizing an empty input should yield no bytes");
 }
 use super::*;
}

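// Statically allocated tokens for ids 0..=255, one per byte value, spelled out
// by hand: Token is not Copy, so the array cannot easily be built in a const loop.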
pub(crate) const SINGLE_TOKENS:&[Token;256]=&[
 Token::single(0),Token::single(1),Token::single(2),Token::single(3),Token::single(4),Token::single(5),Token::single(6),Token::single(7),Token::single(8),Token::single(9),Token::single(10),Token::single(11),Token::single(12),Token::single(13),Token::single(14),Token::single(15),
 Token::single(16),Token::single(17),Token::single(18),Token::single(19),Token::single(20),Token::single(21),Token::single(22),Token::single(23),Token::single(24),Token::single(25),Token::single(26),Token::single(27),Token::single(28),Token::single(29),Token::single(30),Token::single(31),
 Token::single(32),Token::single(33),Token::single(34),Token::single(35),Token::single(36),Token::single(37),Token::single(38),Token::single(39),Token::single(40),Token::single(41),Token::single(42),Token::single(43),Token::single(44),Token::single(45),Token::single(46),Token::single(47),
 Token::single(48),Token::single(49),Token::single(50),Token::single(51),Token::single(52),Token::single(53),Token::single(54),Token::single(55),Token::single(56),Token::single(57),Token::single(58),Token::single(59),Token::single(60),Token::single(61),Token::single(62),Token::single(63),
 Token::single(64),Token::single(65),Token::single(66),Token::single(67),Token::single(68),Token::single(69),Token::single(70),Token::single(71),Token::single(72),Token::single(73),Token::single(74),Token::single(75),Token::single(76),Token::single(77),Token::single(78),Token::single(79),
 Token::single(80),Token::single(81),Token::single(82),Token::single(83),Token::single(84),Token::single(85),Token::single(86),Token::single(87),Token::single(88),Token::single(89),Token::single(90),Token::single(91),Token::single(92),Token::single(93),Token::single(94),Token::single(95),
 Token::single(96),Token::single(97),Token::single(98),Token::single(99),Token::single(100),Token::single(101),Token::single(102),Token::single(103),Token::single(104),Token::single(105),Token::single(106),Token::single(107),Token::single(108),Token::single(109),Token::single(110),Token::single(111),
 Token::single(112),Token::single(113),Token::single(114),Token::single(115),Token::single(116),Token::single(117),Token::single(118),Token::single(119),Token::single(120),Token::single(121),Token::single(122),Token::single(123),Token::single(124),Token::single(125),Token::single(126),Token::single(127),
 Token::single(128),Token::single(129),Token::single(130),Token::single(131),Token::single(132),Token::single(133),Token::single(134),Token::single(135),Token::single(136),Token::single(137),Token::single(138),Token::single(139),Token::single(140),Token::single(141),Token::single(142),Token::single(143),
 Token::single(144),Token::single(145),Token::single(146),Token::single(147),Token::single(148),Token::single(149),Token::single(150),Token::single(151),Token::single(152),Token::single(153),Token::single(154),Token::single(155),Token::single(156),Token::single(157),Token::single(158),Token::single(159),
 Token::single(160),Token::single(161),Token::single(162),Token::single(163),Token::single(164),Token::single(165),Token::single(166),Token::single(167),Token::single(168),Token::single(169),Token::single(170),Token::single(171),Token::single(172),Token::single(173),Token::single(174),Token::single(175),
 Token::single(176),Token::single(177),Token::single(178),Token::single(179),Token::single(180),Token::single(181),Token::single(182),Token::single(183),Token::single(184),Token::single(185),Token::single(186),Token::single(187),Token::single(188),Token::single(189),Token::single(190),Token::single(191),
 Token::single(192),Token::single(193),Token::single(194),Token::single(195),Token::single(196),Token::single(197),Token::single(198),Token::single(199),Token::single(200),Token::single(201),Token::single(202),Token::single(203),Token::single(204),Token::single(205),Token::single(206),Token::single(207),
 Token::single(208),Token::single(209),Token::single(210),Token::single(211),Token::single(212),Token::single(213),Token::single(214),Token::single(215),Token::single(216),Token::single(217),Token::single(218),Token::single(219),Token::single(220),Token::single(221),Token::single(222),Token::single(223),
 Token::single(224),Token::single(225),Token::single(226),Token::single(227),Token::single(228),Token::single(229),Token::single(230),Token::single(231),Token::single(232),Token::single(233),Token::single(234),Token::single(235),Token::single(236),Token::single(237),Token::single(238),Token::single(239),
 Token::single(240),Token::single(241),Token::single(242),Token::single(243),Token::single(244),Token::single(245),Token::single(246),Token::single(247),Token::single(248),Token::single(249),Token::single(250),Token::single(251),Token::single(252),Token::single(253),Token::single(254),Token::single(255),
];
#[derive(Clone,Debug)]
pub struct Detokenization<I:Iterator> where I::Item:Val<u32>{inner:Fuse<I>,maxtokenlen:usize,position:usize,tokenid:u32,tokens:Arc<Vec<Token>>}
#[derive(Clone,Debug)]
pub struct DictIter<'a>{range:Range<usize>,tokens:&'a [Token]}
#[derive(Clone,Debug)]
pub struct DictIntoIter{range:Range<usize>,tokens:Arc<Vec<Token>>}
#[derive(Clone,Debug)]
pub struct TokenDict{ids:Arc<[Trie<u8,u32>;256]>,maxtokenlen:usize,tokens:Arc<Vec<Token>>}
#[derive(Clone,Debug)]
pub struct Tokenization<I:Iterator> where I::Item:Val<u8>{ids:Arc<[Trie<u8,u32>;256]>,inner:Fuse<I>,state:VecDeque<u8>}

use crate::{Token,UTF8CharIter,Val,token::SINGLE_BYTES};
use ptrie::Trie;
#[cfg(feature="serial")]
use serde::{Deserialize,Deserializer,Serialize,Serializer};
use std::{
 collections::{HashMap,VecDeque},iter::{Extend,Fuse},ops::{Index,Range},slice::Iter as SliceIter,sync::Arc,vec::IntoIter as VecIntoIter,
};