1use std::{
2 fmt::Debug,
3 hash::Hash,
4 ops::{Index, RangeInclusive},
5};
6
/// Token identifier; also used as a bit index into `SimpleVob`.
pub type TokenId = u32;
8
/// A simple bit-vector ("vector of bits") backed by `u32` words,
/// used as a set of allowed token ids.
#[derive(Clone)]
pub struct SimpleVob {
    // Backing storage; bit `i` lives in word `i / 32`, position `i % 32`.
    data: Vec<u32>,
    // Logical number of bits. `data` may hold more bits than `size`;
    // the excess bits are kept cleared (see `clear_excessive_bits`).
    size: usize,
}
14
15impl Hash for SimpleVob {
16 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
17 self.size.hash(state);
18 self.data.hash(state);
19 }
20}
21
22impl PartialEq for SimpleVob {
23 fn eq(&self, other: &Self) -> bool {
24 self.size == other.size && self.data == other.data
25 }
26}
27
28impl Eq for SimpleVob {}
29
30impl Debug for SimpleVob {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("SimpleVob")
33 .field("len", &self.len())
34 .finish()
35 }
36}
37
38impl Default for SimpleVob {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl From<SimpleVob> for Vec<u32> {
45 fn from(val: SimpleVob) -> Self {
46 val.data
47 }
48}
49
// Number of bits per backing word (`u32`).
const BITS: usize = 32;
51
52impl SimpleVob {
53 pub fn new() -> Self {
54 Self {
55 data: Vec::new(),
56 size: 0,
57 }
58 }
59
60 pub fn from_slice(bits: &[bool]) -> Self {
61 let mut r = Self::alloc(bits.len());
62 for (idx, b) in bits.iter().enumerate() {
63 r.set(idx, *b);
64 }
65 r
66 }
67
68 pub fn alloc(size: usize) -> Self {
69 let mut r = Self::new();
70 r.resize(size);
71 r
72 }
73
74 pub fn alloc_ones(size: usize) -> Self {
75 let mut r = Self::alloc(size);
76 r.set_all(true);
77 r
78 }
79
80 pub fn alloc_with_capacity(size: usize, capacity: usize) -> Self {
81 let mut r = Self::new();
82 assert!(size <= capacity);
83 r.resize(capacity);
84 r.size = size;
85 r
86 }
87
88 pub fn len(&self) -> usize {
89 self.size
90 }
91
92 pub fn is_empty(&self) -> bool {
93 self.size == 0
94 }
95
96 pub fn num_set(&self) -> usize {
97 self.data.iter().map(|x| x.count_ones() as usize).sum()
98 }
99
100 fn clear_excessive_bits(&mut self) {
101 for i in self.size..(self.data.len() * 32) {
102 self.disallow_token(i as TokenId);
104 }
105 }
106
107 pub fn to_bin_string(&self) -> String {
108 let mut s = String::new();
109 for i in 0..self.size {
110 s.push(if self.is_allowed(i as TokenId) {
111 '1'
112 } else {
113 '0'
114 });
115 }
116 s
117 }
118
119 pub fn negated(&self) -> Self {
120 let mut r = Self::new();
121 r.data = self.data.iter().map(|x| !x).collect();
122 r.size = self.size;
123 r.clear_excessive_bits();
124 r
125 }
126
    /// Raw pointer to the first backing word (e.g. for FFI use).
    pub fn as_ptr(&self) -> *const u32 {
        self.data.as_ptr()
    }

    /// The backing words as a slice.
    pub fn as_slice(&self) -> &[u32] {
        &self.data
    }
134
    /// Calls `f` with the index of every set bit, in increasing order.
    ///
    /// Whole words are fast-pathed: all-zero words are skipped and
    /// all-ones words avoid the per-bit test; the final partial word
    /// is handled bit-by-bit.
    #[inline(always)]
    pub fn iter_set_entries(&self, mut f: impl FnMut(usize)) {
        let numelts = self.size;
        // Number of whole 32-bit words fully inside the logical size.
        let max_len = numelts / 32;
        for (idx, &d) in self.as_slice()[..max_len].iter().enumerate() {
            if d == 0 {
                // Whole word clear: nothing to report.
                continue;
            } else if d == u32::MAX {
                // Whole word set: report all 32 positions.
                for bit in 0..32 {
                    f(idx * 32 + bit);
                }
            } else {
                for bit in 0..32 {
                    if d & (1 << bit) != 0 {
                        f(idx * 32 + bit);
                    }
                }
            }
        }
        // Remaining bits in the final partial word.
        for idx in (max_len * 32)..numelts {
            if self.is_allowed(idx as TokenId) {
                f(idx);
            }
        }
    }
162
    /// Calls `f` with the index of every clear (unset) bit, in
    /// increasing order; the mirror image of `iter_set_entries`.
    #[inline(always)]
    pub fn iter_unset_entries(&self, mut f: impl FnMut(usize)) {
        let numelts = self.size;
        // Number of whole 32-bit words fully inside the logical size.
        let max_len = numelts / 32;
        for (idx, &d) in self.as_slice()[..max_len].iter().enumerate() {
            if d == 0 {
                // Whole word clear: report all 32 positions.
                for bit in 0..32 {
                    f(idx * 32 + bit);
                }
            } else if d == u32::MAX {
                // Whole word set: nothing to report.
                continue;
            } else {
                for bit in 0..32 {
                    if d & (1 << bit) == 0 {
                        f(idx * 32 + bit);
                    }
                }
            }
        }
        // Remaining bits in the final partial word.
        for idx in (max_len * 32)..numelts {
            if !self.is_allowed(idx as TokenId) {
                f(idx);
            }
        }
    }
190
    /// Calls `f(bit_value, index)` for every logical bit, in order,
    /// with whole-word fast paths for all-zero and all-ones words.
    #[inline(always)]
    pub fn iter_entries(&self, mut f: impl FnMut(bool, usize)) {
        let numelts = self.size;
        // Number of whole 32-bit words fully inside the logical size.
        let max_len = numelts / 32;
        for (idx, &d) in self.as_slice()[..max_len].iter().enumerate() {
            if d == 0 {
                // Whole word clear.
                for bit in 0..32 {
                    f(false, idx * 32 + bit);
                }
            } else if d == u32::MAX {
                // Whole word set.
                for bit in 0..32 {
                    f(true, idx * 32 + bit);
                }
            } else {
                for bit in 0..32 {
                    f(d & (1 << bit) != 0, idx * 32 + bit);
                }
            }
        }
        // Remaining bits in the final partial word.
        for idx in (max_len * 32)..numelts {
            f(self.is_allowed(idx as TokenId), idx);
        }
    }
216
    /// Copies the backing words into `buf` as raw bytes (native
    /// endianness); `buf` may be shorter than the full storage.
    ///
    /// # Panics
    /// Panics when `buf` is longer than the available storage bytes.
    pub fn write_to(&self, buf: &mut [u8]) {
        assert!(buf.len() <= self.data.len() * (BITS / 8));
        buf.copy_from_slice(&bytemuck::cast_slice(&self.data)[..buf.len()]);
    }
221
    /// Marks token `tok` as allowed (sets its bit).
    #[inline(always)]
    pub fn allow_token(&mut self, tok: TokenId) {
        self.set(tok as usize, true)
    }

    /// Marks token `tok` as disallowed (clears its bit).
    #[inline(always)]
    pub fn disallow_token(&mut self, tok: TokenId) {
        self.set(tok as usize, false)
    }
231
232 #[inline(always)]
233 pub fn set(&mut self, idx: usize, val: bool) {
234 let byte_idx = idx / BITS;
235 let bit_idx = idx % BITS;
236 if val {
237 self.data[byte_idx] |= 1 << bit_idx;
238 } else {
239 self.data[byte_idx] &= !(1 << bit_idx);
240 }
241 }
242
    /// Sets every bit in the inclusive token range to allowed.
    ///
    /// # Panics
    /// Panics when the range end is not below `self.size`.
    pub fn allow_range(&mut self, range: RangeInclusive<TokenId>) {
        assert!(*range.end() < self.size as TokenId);
        let start = *range.start() as usize;
        let end = *range.end() as usize;
        if start > end {
            // Empty inclusive range (e.g. 5..=4): nothing to do.
            return;
        }
        // Word indices covering the range, plus masks selecting the
        // partial bit spans inside the first and last word.
        let start_word = start / BITS;
        let end_word = end / BITS;
        let start_mask = !0u32 << (start % BITS); // bits >= start within its word
        let end_bit = end % BITS;
        let end_mask = !0u32 >> (BITS - 1 - end_bit); // bits <= end within its word
        if start_word == end_word {
            // Range lives in a single word: intersect both masks.
            let mask = start_mask & end_mask;
            self.data[start_word] |= mask;
        } else {
            self.data[start_word] |= start_mask;
            // Fully covered interior words are set wholesale.
            for w in (start_word + 1)..end_word {
                self.data[w] = !0u32;
            }
            self.data[end_word] |= end_mask;
        }
    }
266
    /// Grows the vector to hold `size` bits; shrinking is not supported.
    ///
    /// Note: allocates `size / BITS + 1` words, i.e. always at least one
    /// spare word beyond the requested size (other code, e.g. `write_to`
    /// callers, may rely on this extra slack).
    ///
    /// # Panics
    /// Panics when the new word count would be smaller than the current one.
    pub fn resize(&mut self, size: usize) {
        let new_size = size / BITS + 1;
        assert!(new_size >= self.data.len());
        self.data.resize(new_size, 0);
        self.size = size;
    }
273
274 #[inline(always)]
275 pub fn get(&self, idx: usize) -> bool {
276 let byte_idx = idx / 32;
277 let bit_idx = idx % 32;
278 (self.data[byte_idx] & (1 << bit_idx)) != 0
279 }
280
281 #[inline(always)]
282 pub fn is_allowed(&self, tok: TokenId) -> bool {
283 self.get(tok as usize)
284 }
285
286 pub fn set_all(&mut self, val: bool) {
287 let bits = if val { !0 } else { 0 };
288 self.data.iter_mut().for_each(|x| *x = bits);
289 if val {
290 self.clear_excessive_bits();
291 }
292 }
293
    /// Zeroes `logits[i]` for every set bit `i`.
    ///
    /// NOTE(review): indexes `logits` directly, so it assumes
    /// `logits.len()` covers every set bit — confirm at call sites.
    pub fn apply_to(&self, logits: &mut [f32]) {
        for (idx, v) in self.data.iter().enumerate() {
            if *v == 0 {
                // Fast path: nothing set in this word.
                continue;
            }
            let idx = idx * BITS;
            for bit_idx in 0..BITS {
                if v & (1 << bit_idx) != 0 {
                    logits[idx + bit_idx] = 0.0;
                }
            }
        }
    }
307
    /// Iterator over the indices of the set bits, in increasing order.
    pub fn iter(&self) -> SimpleVobIter {
        SimpleVobIter { vob: self, idx: 0 }
    }
311
312 pub fn set_from(&mut self, other: &SimpleVob) {
313 assert_eq!(self.size, other.size);
314 self.data.copy_from_slice(&other.data);
315 }
316
317 pub fn or(&mut self, other: &SimpleVob) {
318 assert!(self.size >= other.size);
319 for (idx, v) in self.data.iter_mut().zip(other.data.iter()) {
320 *idx |= *v;
321 }
322 }
323
324 pub fn trim_trailing_zeros(&mut self) {
325 let mut idx = self.data.len();
326 while idx > 0 && self.data[idx - 1] == 0 {
327 idx -= 1;
328 }
329 if self.data.len() != idx {
330 self.data.truncate(idx);
331 self.size = self.data.len() * BITS;
332 }
333 }
334
335 pub fn or_minus(&mut self, other: &SimpleVob, minus: &SimpleVob) {
337 assert_eq!(self.size, other.size);
338 assert_eq!(self.size, minus.size);
339 for ((slf, oth), mn) in self
340 .data
341 .iter_mut()
342 .zip(other.data.iter())
343 .zip(minus.data.iter())
344 {
345 *slf |= *oth & !*mn;
346 }
347 }
348
349 pub fn and(&mut self, other: &SimpleVob) {
350 assert_eq!(self.size, other.size);
351 for (idx, v) in self.data.iter_mut().zip(other.data.iter()) {
352 *idx &= *v;
353 }
354 }
355
356 pub fn is_zero(&self) -> bool {
357 self.data.iter().all(|x| *x == 0)
358 }
359
360 pub fn and_is_zero(&self, other: &SimpleVob) -> bool {
361 assert_eq!(self.size, other.size);
362 self.data
363 .iter()
364 .zip(other.data.iter())
365 .all(|(a, b)| *a & *b == 0)
366 }
367
368 pub fn sub(&mut self, other: &SimpleVob) {
369 assert_eq!(self.size, other.size);
370 for (idx, v) in self.data.iter_mut().zip(other.data.iter()) {
371 *idx &= !*v;
372 }
373 }
374
375 pub fn first_bit_set_here_and_in(&self, other: &SimpleVob) -> Option<usize> {
376 assert_eq!(self.size, other.size);
377 for (idx, (a, b)) in self.data.iter().zip(other.data.iter()).enumerate() {
378 let v = *a & *b;
379 if v != 0 {
380 return Some(idx * BITS + v.trailing_zeros() as usize);
381 }
382 }
383 None
384 }
385
386 pub fn first_bit_set(&self) -> Option<usize> {
387 for (idx, v) in self.data.iter().enumerate() {
388 if *v != 0 {
389 return Some(idx * BITS + v.trailing_zeros() as usize);
390 }
391 }
392 None
393 }
394
395 pub fn to_list(&self) -> Vec<u32> {
396 let mut r = Vec::new();
397 self.iter_set_entries(|x| r.push(x as u32));
398 r
399 }
400}
401
/// Iterator over the set-bit indices of a [`SimpleVob`].
pub struct SimpleVobIter<'a> {
    vob: &'a SimpleVob,
    // Next bit position to examine.
    idx: usize,
}
406
impl Iterator for SimpleVobIter<'_> {
    type Item = u32;

    /// Returns the next set-bit index at or after `self.idx`.
    ///
    /// Scans word by word: the current word is shifted right past the
    /// already-consumed bits, and `trailing_zeros` locates the next
    /// set bit without a per-bit loop.
    ///
    /// NOTE(review): iteration covers the whole backing storage, not
    /// just `size` bits; it relies on the invariant that bits past
    /// `size` are kept cleared (see `clear_excessive_bits`).
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let mut bitoff = self.idx % BITS;
        let mut dataoff = self.idx / BITS;
        let data = &self.vob.data;
        while dataoff < data.len() {
            // Drop the bits below `bitoff`; a nonzero result means a
            // set bit remains in this word.
            let d = data[dataoff] >> bitoff;
            if d != 0 {
                let idx = dataoff * BITS + d.trailing_zeros() as usize + bitoff;
                self.idx = idx + 1;
                return Some(idx as u32);
            }
            bitoff = 0;
            dataoff += 1;
        }
        None
    }
}
428
429impl Index<usize> for SimpleVob {
430 type Output = bool;
431
432 fn index(&self, index: usize) -> &Self::Output {
433 if self.is_allowed(index as TokenId) {
434 &true
435 } else {
436 &false
437 }
438 }
439}