// toktrie/svob.rs

1use std::{
2    fmt::Debug,
3    hash::Hash,
4    ops::{Index, RangeInclusive},
5};
6
/// Token identifier type used by the token-oriented accessors
/// ([`SimpleVob::allow_token`], [`SimpleVob::is_allowed`], ...).
pub type TokenId = u32;

/// A simple bit vector backed by 32-bit words.
///
/// Bit `i` is stored in word `i / 32` at bit position `i % 32` (least
/// significant bit first). `size` is the logical number of bits. The backing
/// storage may contain more words than strictly required, and the whole-word
/// operations (`num_set`, `is_zero`, `iter`, ...) look at every stored word,
/// so bits at positions `>= size` are kept at zero by the constructors.
#[derive(Clone)]
pub struct SimpleVob {
    data: Vec<u32>,
    size: usize,
}

impl Hash for SimpleVob {
    /// Feeds the logical length followed by the backing words into `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.size.hash(state);
        self.data.hash(state);
    }
}

impl PartialEq for SimpleVob {
    /// Two vobs are equal when both their logical length and their backing
    /// words (including any spare words) are equal.
    fn eq(&self, other: &Self) -> bool {
        self.size == other.size && self.data == other.data
    }
}

impl Eq for SimpleVob {}

impl Debug for SimpleVob {
    /// Prints only the logical length, not the bit contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SimpleVob")
            .field("len", &self.len())
            .finish()
    }
}

impl Default for SimpleVob {
    /// Same as [`SimpleVob::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl From<SimpleVob> for Vec<u32> {
    /// Consumes the vob and returns its backing words.
    fn from(val: SimpleVob) -> Self {
        val.data
    }
}

/// Number of bits stored per backing word.
const BITS: usize = 32;

/// Calls `f(base + k)` for every set bit `k` of `word`, in ascending order.
#[inline(always)]
fn for_each_set_bit(mut word: u32, base: usize, mut f: impl FnMut(usize)) {
    while word != 0 {
        f(base + word.trailing_zeros() as usize);
        // Clear the lowest set bit; `word != 0`, so the subtraction cannot underflow.
        word &= word - 1;
    }
}

impl SimpleVob {
    /// Creates an empty vob with no backing storage.
    pub fn new() -> Self {
        Self {
            data: Vec::new(),
            size: 0,
        }
    }

    /// Builds a vob of length `bits.len()` where bit `i` equals `bits[i]`.
    pub fn from_slice(bits: &[bool]) -> Self {
        let mut vob = Self::alloc(bits.len());
        for (idx, &bit) in bits.iter().enumerate() {
            vob.set(idx, bit);
        }
        vob
    }

    /// Creates a vob of `size` bits, all cleared.
    pub fn alloc(size: usize) -> Self {
        let mut vob = Self::new();
        vob.resize(size);
        vob
    }

    /// Creates a vob of `size` bits, all set.
    pub fn alloc_ones(size: usize) -> Self {
        let mut vob = Self::alloc(size);
        vob.set_all(true);
        vob
    }

    /// Creates a cleared vob of logical length `size` whose storage is large
    /// enough for `capacity` bits.
    ///
    /// # Panics
    /// Panics if `size > capacity`.
    pub fn alloc_with_capacity(size: usize, capacity: usize) -> Self {
        assert!(size <= capacity);
        let mut vob = Self::alloc(capacity);
        vob.size = size;
        vob
    }

    /// Logical number of bits.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` when the logical length is zero.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Number of set bits across all backing words.
    pub fn num_set(&self) -> usize {
        self.data.iter().map(|word| word.count_ones() as usize).sum()
    }

    /// Zeroes every stored bit at a position `>= self.size`.
    fn clear_excessive_bits(&mut self) {
        let first_word = self.size / BITS;
        let used_bits = self.size % BITS;
        // `get_mut` yields `None` when there is no storage at or past `first_word`.
        if let Some(tail) = self.data.get_mut(first_word..) {
            if let Some((partial, spare)) = tail.split_first_mut() {
                // Keep only the low `used_bits` bits of the boundary word
                // (none at all when `size` is word-aligned).
                *partial &= (1u32 << used_bits) - 1;
                spare.fill(0);
            }
        }
    }

    /// Renders the first `len()` bits as a string of `'1'`/`'0'`, bit 0 first.
    pub fn to_bin_string(&self) -> String {
        (0..self.size)
            .map(|idx| if self.get(idx) { '1' } else { '0' })
            .collect()
    }

    /// Returns a copy with every bit below `len()` inverted; bits beyond
    /// `len()` are cleared in the result.
    pub fn negated(&self) -> Self {
        let mut flipped = Self {
            data: self.data.iter().map(|word| !word).collect(),
            size: self.size,
        };
        flipped.clear_excessive_bits();
        flipped
    }

    /// Raw pointer to the first backing word.
    pub fn as_ptr(&self) -> *const u32 {
        self.data.as_ptr()
    }

    /// The backing words, including any spare words beyond `len()` bits.
    pub fn as_slice(&self) -> &[u32] {
        &self.data
    }

    /// Calls `f(i)` for every set bit `i < len()`, in ascending order.
    #[inline(always)]
    pub fn iter_set_entries(&self, mut f: impl FnMut(usize)) {
        let full_words = self.size / BITS;
        for (word_idx, &word) in self.data[..full_words].iter().enumerate() {
            for_each_set_bit(word, word_idx * BITS, &mut f);
        }
        // Bits of the trailing, partially used word.
        for idx in (full_words * BITS)..self.size {
            if self.get(idx) {
                f(idx);
            }
        }
    }

    /// Calls `f(i)` for every cleared bit `i < len()`, in ascending order.
    #[inline(always)]
    pub fn iter_unset_entries(&self, mut f: impl FnMut(usize)) {
        let full_words = self.size / BITS;
        for (word_idx, &word) in self.data[..full_words].iter().enumerate() {
            // The cleared bits of `word` are exactly the set bits of `!word`.
            for_each_set_bit(!word, word_idx * BITS, &mut f);
        }
        // Bits of the trailing, partially used word.
        for idx in (full_words * BITS)..self.size {
            if !self.get(idx) {
                f(idx);
            }
        }
    }

    /// Calls `f(value, i)` for every bit `i < len()`, in ascending order.
    #[inline(always)]
    pub fn iter_entries(&self, mut f: impl FnMut(bool, usize)) {
        let full_words = self.size / BITS;
        for (word_idx, &word) in self.data[..full_words].iter().enumerate() {
            let base = word_idx * BITS;
            // Fast paths for the two common cases: all-clear and all-set words.
            if word == 0 {
                for bit in 0..BITS {
                    f(false, base + bit);
                }
            } else if word == u32::MAX {
                for bit in 0..BITS {
                    f(true, base + bit);
                }
            } else {
                for bit in 0..BITS {
                    f((word >> bit) & 1 != 0, base + bit);
                }
            }
        }
        // Bits of the trailing, partially used word.
        for idx in (full_words * BITS)..self.size {
            f(self.get(idx), idx);
        }
    }

    /// Fills `buf` with the leading `buf.len()` bytes of the backing words,
    /// each word in native byte order.
    ///
    /// # Panics
    /// Panics if `buf` is longer than the backing storage in bytes.
    pub fn write_to(&self, buf: &mut [u8]) {
        const WORD_BYTES: usize = BITS / 8;
        assert!(buf.len() <= self.data.len() * WORD_BYTES);
        for (chunk, word) in buf.chunks_mut(WORD_BYTES).zip(&self.data) {
            let bytes = word.to_ne_bytes();
            // The last chunk may be shorter than a full word.
            let take = chunk.len();
            chunk.copy_from_slice(&bytes[..take]);
        }
    }

    /// Sets the bit for `tok`. See [`SimpleVob::set`].
    #[inline(always)]
    pub fn allow_token(&mut self, tok: TokenId) {
        self.set(tok as usize, true)
    }

    /// Clears the bit for `tok`. See [`SimpleVob::set`].
    #[inline(always)]
    pub fn disallow_token(&mut self, tok: TokenId) {
        self.set(tok as usize, false)
    }

    /// Sets bit `idx` to `val`.
    ///
    /// Only the backing storage is bounds-checked; `idx` is not compared
    /// against `len()`.
    ///
    /// # Panics
    /// Panics if word `idx / 32` is outside the backing storage.
    #[inline(always)]
    pub fn set(&mut self, idx: usize, val: bool) {
        let mask = 1u32 << (idx % BITS);
        let word = &mut self.data[idx / BITS];
        if val {
            *word |= mask;
        } else {
            *word &= !mask;
        }
    }

    /// Sets every bit in the inclusive `range`. A range whose start is greater
    /// than its end is a no-op.
    ///
    /// # Panics
    /// Panics if the range is non-empty and its end is `>= len()`.
    pub fn allow_range(&mut self, range: RangeInclusive<TokenId>) {
        let start = *range.start() as usize;
        let end = *range.end() as usize;
        // Check emptiness first so an empty range never trips the bounds assert.
        if start > end {
            return;
        }
        assert!(end < self.size);
        let first_word = start / BITS;
        let last_word = end / BITS;
        // Bits `start % 32 ..= 31` of the first word.
        let low_mask = u32::MAX << (start % BITS);
        // Bits `0 ..= end % 32` of the last word.
        let high_mask = u32::MAX >> (BITS - 1 - end % BITS);
        if first_word == last_word {
            self.data[first_word] |= low_mask & high_mask;
        } else {
            self.data[first_word] |= low_mask;
            self.data[first_word + 1..last_word].fill(u32::MAX);
            self.data[last_word] |= high_mask;
        }
    }

    /// Changes the logical length to `size`, growing the storage to
    /// `size / 32 + 1` words (new words are zero). When the length shrinks
    /// without dropping words, bits at positions `>= size` are cleared.
    ///
    /// # Panics
    /// Panics if `size / 32 + 1` is smaller than the current number of words.
    pub fn resize(&mut self, size: usize) {
        let words = size / BITS + 1;
        assert!(words >= self.data.len());
        self.data.resize(words, 0);
        let shrinking = size < self.size;
        self.size = size;
        if shrinking {
            // Do not leave stale set bits beyond the new length.
            self.clear_excessive_bits();
        }
    }

    /// Returns bit `idx`.
    ///
    /// Only the backing storage is bounds-checked; `idx` is not compared
    /// against `len()`.
    ///
    /// # Panics
    /// Panics if word `idx / 32` is outside the backing storage.
    #[inline(always)]
    pub fn get(&self, idx: usize) -> bool {
        (self.data[idx / BITS] >> (idx % BITS)) & 1 != 0
    }

    /// Returns the bit for `tok`. See [`SimpleVob::get`].
    #[inline(always)]
    pub fn is_allowed(&self, tok: TokenId) -> bool {
        self.get(tok as usize)
    }

    /// Sets every bit below `len()` to `val`; bits beyond `len()` end up cleared.
    pub fn set_all(&mut self, val: bool) {
        let fill = if val { u32::MAX } else { 0 };
        self.data.fill(fill);
        if val {
            self.clear_excessive_bits();
        }
    }

    /// Writes `0.0` into `logits[i]` for every set bit `i`; other entries are
    /// left untouched.
    ///
    /// # Panics
    /// Panics if a set bit has an index `>= logits.len()`.
    pub fn apply_to(&self, logits: &mut [f32]) {
        for (word_idx, &word) in self.data.iter().enumerate() {
            for_each_set_bit(word, word_idx * BITS, |idx| logits[idx] = 0.0);
        }
    }

    /// Iterator over the indices of set bits, in ascending order.
    pub fn iter(&self) -> SimpleVobIter<'_> {
        SimpleVobIter { vob: self, idx: 0 }
    }

    /// Overwrites the contents with those of `other`.
    ///
    /// # Panics
    /// Panics if the logical lengths differ, or if the two backing storages
    /// have a different number of words.
    pub fn set_from(&mut self, other: &SimpleVob) {
        assert_eq!(self.size, other.size);
        self.data.copy_from_slice(&other.data);
    }

    /// `self |= other`, word by word over the words both vobs have.
    ///
    /// # Panics
    /// Panics if `other` is logically longer than `self`.
    pub fn or(&mut self, other: &SimpleVob) {
        assert!(self.size >= other.size);
        for (dst, src) in self.data.iter_mut().zip(&other.data) {
            *dst |= *src;
        }
    }

    /// Drops all-zero words from the end of the storage. If any word was
    /// dropped, the logical length becomes `32 * remaining_words`.
    pub fn trim_trailing_zeros(&mut self) {
        let keep = self
            .data
            .iter()
            .rposition(|&word| word != 0)
            .map_or(0, |last_nonzero| last_nonzero + 1);
        if keep != self.data.len() {
            self.data.truncate(keep);
            self.size = keep * BITS;
        }
    }

    /// self |= other & !minus
    ///
    /// # Panics
    /// Panics if the three logical lengths are not all equal.
    pub fn or_minus(&mut self, other: &SimpleVob, minus: &SimpleVob) {
        assert_eq!(self.size, other.size);
        assert_eq!(self.size, minus.size);
        let sources = other.data.iter().zip(&minus.data);
        for (dst, (add, exclude)) in self.data.iter_mut().zip(sources) {
            *dst |= *add & !*exclude;
        }
    }

    /// `self &= other`.
    ///
    /// # Panics
    /// Panics if the logical lengths differ.
    pub fn and(&mut self, other: &SimpleVob) {
        assert_eq!(self.size, other.size);
        for (dst, src) in self.data.iter_mut().zip(&other.data) {
            *dst &= *src;
        }
    }

    /// Returns `true` when no stored bit is set.
    pub fn is_zero(&self) -> bool {
        self.data.iter().all(|&word| word == 0)
    }

    /// Returns `true` when `self & other` has no bit set.
    ///
    /// # Panics
    /// Panics if the logical lengths differ.
    pub fn and_is_zero(&self, other: &SimpleVob) -> bool {
        assert_eq!(self.size, other.size);
        self.data
            .iter()
            .zip(&other.data)
            .all(|(a, b)| a & b == 0)
    }

    /// `self &= !other`.
    ///
    /// # Panics
    /// Panics if the logical lengths differ.
    pub fn sub(&mut self, other: &SimpleVob) {
        assert_eq!(self.size, other.size);
        for (dst, src) in self.data.iter_mut().zip(&other.data) {
            *dst &= !*src;
        }
    }

    /// Index of the lowest bit set in both `self` and `other`, if any.
    ///
    /// # Panics
    /// Panics if the logical lengths differ.
    pub fn first_bit_set_here_and_in(&self, other: &SimpleVob) -> Option<usize> {
        assert_eq!(self.size, other.size);
        self.data
            .iter()
            .zip(&other.data)
            .enumerate()
            .find_map(|(word_idx, (a, b))| {
                let common = a & b;
                (common != 0).then(|| word_idx * BITS + common.trailing_zeros() as usize)
            })
    }

    /// Index of the lowest set bit, if any.
    pub fn first_bit_set(&self) -> Option<usize> {
        self.data
            .iter()
            .enumerate()
            .find(|(_, &word)| word != 0)
            .map(|(word_idx, word)| word_idx * BITS + word.trailing_zeros() as usize)
    }

    /// Indices of all set bits below `len()`, ascending, as `u32`.
    pub fn to_list(&self) -> Vec<u32> {
        let mut indices = Vec::new();
        self.iter_set_entries(|idx| indices.push(idx as u32));
        indices
    }
}

/// Iterator over the indices of the set bits of a [`SimpleVob`], ascending.
///
/// It scans every backing word of the vob, not just the first `len()` bits.
pub struct SimpleVobIter<'a> {
    vob: &'a SimpleVob,
    /// Index of the next bit to examine.
    idx: usize,
}

impl Iterator for SimpleVobIter<'_> {
    type Item = u32;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let words = &self.vob.data;
        let mut word_idx = self.idx / BITS;
        let mut bit_off = self.idx % BITS;
        while let Some(&word) = words.get(word_idx) {
            // Discard the bits that were already visited in this word.
            let remaining = word >> bit_off;
            if remaining != 0 {
                let found = word_idx * BITS + bit_off + remaining.trailing_zeros() as usize;
                self.idx = found + 1;
                return Some(found as u32);
            }
            bit_off = 0;
            word_idx += 1;
        }
        None
    }
}

impl Index<usize> for SimpleVob {
    type Output = bool;

    /// Returns a reference to `true`/`false` for bit `index`.
    ///
    /// # Panics
    /// Panics if word `index / 32` is outside the backing storage.
    fn index(&self, index: usize) -> &Self::Output {
        // Use the `usize` accessor directly so large indices are never
        // truncated to `u32` and silently wrapped.
        if self.get(index) {
            &true
        } else {
            &false
        }
    }
}