1use std::collections::HashSet;
53use std::fmt;
54use std::sync::Arc;
55
56#[derive(Clone)]
65pub enum AllowedSet {
66 All,
68
69 Bitmap(Arc<AllowedBitmap>),
71
72 SortedVec(Arc<Vec<u64>>),
74
75 HashSet(Arc<HashSet<u64>>),
77
78 None,
80}
81
82impl AllowedSet {
83 pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
85 if bitmap.is_empty() {
86 Self::None
87 } else if bitmap.is_all() {
88 Self::All
89 } else {
90 Self::Bitmap(Arc::new(bitmap))
91 }
92 }
93
94 pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
96 if ids.is_empty() {
97 return Self::None;
98 }
99 ids.sort_unstable();
100 ids.dedup();
101 Self::SortedVec(Arc::new(ids))
102 }
103
104 pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
106 let set: HashSet<u64> = ids.into_iter().collect();
107 if set.is_empty() {
108 Self::None
109 } else {
110 Self::HashSet(Arc::new(set))
111 }
112 }
113
114 #[inline]
118 pub fn contains(&self, doc_id: u64) -> bool {
119 match self {
120 Self::All => true,
121 Self::Bitmap(bm) => bm.contains(doc_id),
122 Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
123 Self::HashSet(set) => set.contains(&doc_id),
124 Self::None => false,
125 }
126 }
127
128 pub fn is_empty(&self) -> bool {
130 matches!(self, Self::None)
131 }
132
133 pub fn is_all(&self) -> bool {
135 matches!(self, Self::All)
136 }
137
138 pub fn cardinality(&self) -> Option<usize> {
142 match self {
143 Self::All => None,
144 Self::Bitmap(bm) => Some(bm.count()),
145 Self::SortedVec(vec) => Some(vec.len()),
146 Self::HashSet(set) => Some(set.len()),
147 Self::None => Some(0),
148 }
149 }
150
151 pub fn selectivity(&self, universe_size: usize) -> f64 {
155 if universe_size == 0 {
156 return 0.0;
157 }
158 match self {
159 Self::All => 1.0,
160 Self::None => 0.0,
161 other => other
162 .cardinality()
163 .map(|c| c as f64 / universe_size as f64)
164 .unwrap_or(1.0),
165 }
166 }
167
168 pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
170 match (self, other) {
171 (Self::All, x) | (x, Self::All) => x.clone(),
173 (Self::None, _) | (_, Self::None) => Self::None,
174
175 (Self::SortedVec(a), Self::SortedVec(b)) => {
177 let result = sorted_vec_intersect(a, b);
178 Self::from_sorted_vec(result)
179 }
180 (Self::HashSet(a), Self::HashSet(b)) => {
181 let result: HashSet<_> = a.intersection(b).copied().collect();
182 if result.is_empty() {
183 Self::None
184 } else {
185 Self::HashSet(Arc::new(result))
186 }
187 }
188 (Self::Bitmap(a), Self::Bitmap(b)) => {
189 let result = a.intersect(b);
190 Self::from_bitmap(result)
191 }
192
193 (a, b) => {
195 let set_a: HashSet<u64> = a.iter().collect();
196 let set_b: HashSet<u64> = b.iter().collect();
197 let result: HashSet<_> = set_a.intersection(&set_b).copied().collect();
198 if result.is_empty() {
199 Self::None
200 } else {
201 Self::HashSet(Arc::new(result))
202 }
203 }
204 }
205 }
206
207 pub fn union(&self, other: &AllowedSet) -> AllowedSet {
209 match (self, other) {
210 (Self::All, _) | (_, Self::All) => Self::All,
211 (Self::None, x) | (x, Self::None) => x.clone(),
212
213 (Self::HashSet(a), Self::HashSet(b)) => {
214 let result: HashSet<_> = a.union(b).copied().collect();
215 Self::HashSet(Arc::new(result))
216 }
217
218 (a, b) => {
220 let mut result: HashSet<u64> = a.iter().collect();
221 result.extend(b.iter());
222 Self::HashSet(Arc::new(result))
223 }
224 }
225 }
226
227 pub fn iter(&self) -> AllowedSetIter<'_> {
231 match self {
232 Self::All => AllowedSetIter::Empty,
233 Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
234 Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
235 Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
236 Self::None => AllowedSetIter::Empty,
237 }
238 }
239
240 pub fn to_vec(&self) -> Vec<u64> {
242 self.iter().collect()
243 }
244}
245
246impl fmt::Debug for AllowedSet {
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 match self {
249 Self::All => write!(f, "AllowedSet::All"),
250 Self::None => write!(f, "AllowedSet::None"),
251 Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
252 Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
253 Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
254 }
255 }
256}
257
258impl Default for AllowedSet {
259 fn default() -> Self {
260 Self::All
261 }
262}
263
/// Merge-style intersection of two ascending slices.
///
/// Both inputs must be sorted ascending. Each matching pair of elements is
/// emitted once, so duplicates appear as many times as they occur in both
/// inputs.
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let (mut left, mut right) = (a, b);

    // Repeatedly compare the heads and drop whichever side is behind.
    while let (Some(&x), Some(&y)) = (left.first(), right.first()) {
        if x < y {
            left = &left[1..];
        } else if x > y {
            right = &right[1..];
        } else {
            common.push(x);
            left = &left[1..];
            right = &right[1..];
        }
    }

    common
}
284
285pub enum AllowedSetIter<'a> {
291 Empty,
292 Bitmap(BitmapIter<'a>),
293 SortedVec(std::slice::Iter<'a, u64>),
294 HashSet(std::collections::hash_set::Iter<'a, u64>),
295}
296
297impl<'a> Iterator for AllowedSetIter<'a> {
298 type Item = u64;
299
300 fn next(&mut self) -> Option<Self::Item> {
301 match self {
302 Self::Empty => None,
303 Self::Bitmap(iter) => iter.next(),
304 Self::SortedVec(iter) => iter.next().copied(),
305 Self::HashSet(iter) => iter.next().copied(),
306 }
307 }
308
309 fn size_hint(&self) -> (usize, Option<usize>) {
310 match self {
311 Self::Empty => (0, Some(0)),
312 Self::Bitmap(iter) => iter.size_hint(),
313 Self::SortedVec(iter) => iter.size_hint(),
314 Self::HashSet(iter) => iter.size_hint(),
315 }
316 }
317}
318
/// A growable bitmap of allowed document ids.
///
/// Id `id` is stored as bit `id % 64` of `words[id / 64]`. The number of set
/// bits is tracked incrementally so [`AllowedBitmap::count`] is O(1).
pub struct AllowedBitmap {
    /// Backing storage, 64 ids per word.
    words: Vec<u64>,
    /// Number of set bits across `words`.
    count: usize,
    /// Set by [`AllowedBitmap::all`]; reset as soon as any id is cleared.
    all: bool,
}

impl AllowedBitmap {
    /// Creates an empty bitmap.
    pub fn new() -> Self {
        Self {
            words: Vec::new(),
            count: 0,
            all: false,
        }
    }

    /// Creates a bitmap containing exactly the ids `0..=max_id`, flagged as
    /// [`is_all`](AllowedBitmap::is_all).
    pub fn all(max_id: u64) -> Self {
        let word_count = (max_id / 64) as usize + 1;
        let mut words = vec![u64::MAX; word_count];
        // Only the low `max_id % 64 + 1` bits of the last word are in range.
        // Leaving the rest set would make ids above `max_id` look present and
        // would desynchronise `count` from the real number of set bits.
        let tail_bits = max_id % 64 + 1;
        if tail_bits < 64 {
            words[word_count - 1] = (1u64 << tail_bits) - 1;
        }
        Self {
            words,
            count: max_id as usize + 1,
            all: true,
        }
    }

    /// Creates a bitmap containing every id in `ids`. Order does not matter
    /// and duplicates are allowed.
    pub fn from_ids(ids: &[u64]) -> Self {
        let max_id = match ids.iter().max() {
            Some(&max_id) => max_id,
            None => return Self::new(),
        };

        let mut words = vec![0u64; (max_id / 64) as usize + 1];
        for &id in ids {
            words[(id / 64) as usize] |= 1u64 << (id % 64);
        }

        // Count the bits actually set: `ids.len()` would over-count duplicates.
        let count = count_set_bits(&words);
        Self {
            words,
            count,
            all: false,
        }
    }

    /// Adds `id`, growing the storage if needed. Setting an id that is already
    /// present is a no-op.
    pub fn set(&mut self, id: u64) {
        let word_idx = (id / 64) as usize;
        if word_idx >= self.words.len() {
            self.words.resize(word_idx + 1, 0);
        }

        let mask = 1u64 << (id % 64);
        if (self.words[word_idx] & mask) == 0 {
            self.words[word_idx] |= mask;
            self.count += 1;
        }
    }

    /// Removes `id`. Clearing an absent id, including one beyond the
    /// allocated range, is a no-op.
    pub fn clear(&mut self, id: u64) {
        let mask = 1u64 << (id % 64);
        if let Some(word) = self.words.get_mut((id / 64) as usize) {
            if (*word & mask) != 0 {
                *word &= !mask;
                self.count -= 1;
                // With an id removed, the bitmap no longer covers everything.
                self.all = false;
            }
        }
    }

    /// Returns `true` if `id` is set.
    #[inline]
    pub fn contains(&self, id: u64) -> bool {
        match self.words.get((id / 64) as usize) {
            Some(word) => (word & (1u64 << (id % 64))) != 0,
            None => false,
        }
    }

    /// Number of ids currently set.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Returns `true` when no id is set.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns `true` if this bitmap came from [`AllowedBitmap::all`] and no
    /// id has been cleared since.
    pub fn is_all(&self) -> bool {
        self.all
    }

    /// Returns the ids present in both bitmaps. The result is never flagged
    /// as `all`.
    pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
        // Words past the shorter bitmap would be ANDed with zero, so `zip`
        // stopping early loses nothing.
        let words: Vec<u64> = self
            .words
            .iter()
            .zip(&other.words)
            .map(|(a, b)| a & b)
            .collect();
        let count = count_set_bits(&words);
        AllowedBitmap {
            words,
            count,
            all: false,
        }
    }

    /// Returns the ids present in either bitmap. The result is never flagged
    /// as `all`.
    pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
        let (longer, shorter) = if self.words.len() >= other.words.len() {
            (&self.words, &other.words)
        } else {
            (&other.words, &self.words)
        };

        let mut words = longer.clone();
        for (word, extra) in words.iter_mut().zip(shorter) {
            *word |= extra;
        }

        let count = count_set_bits(&words);
        AllowedBitmap {
            words,
            count,
            all: false,
        }
    }

    /// Iterates the set ids in ascending order, skipping zero bits via
    /// `trailing_zeros`.
    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(&self.words, self.count)
    }

    /// Iterates the set ids in ascending order by testing every bit position.
    /// Simpler, but slower than [`AllowedBitmap::iter`].
    pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
            let base = word_idx as u64 * 64;
            (0..64u64)
                .filter(move |&bit| (word & (1u64 << bit)) != 0)
                .map(move |bit| base + bit)
        })
    }
}

impl Default for AllowedBitmap {
    fn default() -> Self {
        Self::new()
    }
}

/// Total number of set bits across `words`.
fn count_set_bits(words: &[u64]) -> usize {
    words.iter().map(|w| w.count_ones() as usize).sum()
}

/// Iterator over the set bits of an [`AllowedBitmap`], yielding ids in
/// ascending order.
pub struct BitmapIter<'a> {
    /// Words being scanned.
    words: &'a [u64],
    /// Index of the word currently being scanned.
    word_idx: usize,
    /// Next bit position (always `< 64`) to examine in `words[word_idx]`.
    bit_offset: u64,
    /// Ids still to be yielded; also reported by `size_hint`.
    remaining: usize,
}

impl<'a> Iterator for BitmapIter<'a> {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }

        while let Some(&word) = self.words.get(self.word_idx) {
            // `bit_offset < 64` always holds, so the shift cannot overflow.
            let pending = word >> self.bit_offset;
            if pending == 0 {
                self.word_idx += 1;
                self.bit_offset = 0;
                continue;
            }

            let bit_pos = self.bit_offset + u64::from(pending.trailing_zeros());
            // The id is the bit's absolute position. Compute it before the
            // cursor is advanced past this bit (and possibly this word).
            let id = self.word_idx as u64 * 64 + bit_pos;

            if bit_pos == 63 {
                self.word_idx += 1;
                self.bit_offset = 0;
            } else {
                self.bit_offset = bit_pos + 1;
            }
            self.remaining -= 1;
            return Some(id);
        }

        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<'a> BitmapIter<'a> {
    /// Starts iterating `words` from id 0. `count` must equal the number of
    /// set bits in `words`.
    fn new(words: &'a [u64], count: usize) -> Self {
        Self {
            words,
            word_idx: 0,
            bit_offset: 0,
            remaining: count,
        }
    }
}
572
573pub trait CandidateGate {
581 type Query;
583
584 type Result;
586
587 type Error;
589
590 fn execute_with_gate(
598 &self,
599 query: &Self::Query,
600 allowed_set: &AllowedSet,
601 ) -> Result<Self::Result, Self::Error>;
602
603 fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
605 if selectivity >= 0.1 {
606 ExecutionStrategy::FilterDuringSearch
607 } else if selectivity >= 0.001 {
608 ExecutionStrategy::ScanAllowedIds
609 } else {
610 ExecutionStrategy::LinearScan
611 }
612 }
613}
614
/// Strategy a [`CandidateGate`] may use to apply an [`AllowedSet`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutionStrategy {
    /// Returned by the default `CandidateGate::strategy_for_selectivity` for
    /// selectivity `>= 0.1`.
    FilterDuringSearch,

    /// Returned by the default `CandidateGate::strategy_for_selectivity` for
    /// selectivity in `[0.001, 0.1)`.
    ScanAllowedIds,

    /// Returned by the default `CandidateGate::strategy_for_selectivity` for
    /// selectivity below `0.001` (or NaN).
    LinearScan,

    /// Never returned by the default strategy; presumably tells the caller to
    /// refuse the query — confirm with implementors.
    Reject,
}
630
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allowed_set_contains() {
        // `All` admits every id; `None` admits none.
        let everything = AllowedSet::All;
        assert!(everything.contains(0));
        assert!(everything.contains(1_000_000));
        assert!(!AllowedSet::None.contains(0));

        let sorted = AllowedSet::from_sorted_vec(vec![1, 3, 5, 7, 9]);
        for present in [1, 5] {
            assert!(sorted.contains(present), "sorted set should contain {}", present);
        }
        for absent in [2, 10] {
            assert!(!sorted.contains(absent), "sorted set should not contain {}", absent);
        }

        let hashed = AllowedSet::from_iter([1, 3, 5, 7, 9]);
        for present in [1, 5] {
            assert!(hashed.contains(present), "hash set should contain {}", present);
        }
        assert!(!hashed.contains(2));
    }

    #[test]
    fn test_allowed_set_selectivity() {
        let five = AllowedSet::from_sorted_vec((1..=5).collect());

        assert_eq!(five.selectivity(100), 0.05);
        assert_eq!(five.selectivity(10), 0.5);

        assert_eq!(AllowedSet::All.selectivity(100), 1.0);
        assert_eq!(AllowedSet::None.selectivity(100), 0.0);
    }

    #[test]
    fn test_allowed_set_intersection() {
        let left = AllowedSet::from_sorted_vec((1..=5).collect());
        let right = AllowedSet::from_sorted_vec((3..=7).collect());

        let shared = left.intersect(&right);
        assert_eq!(shared.cardinality(), Some(3));
        for id in 3..=5 {
            assert!(shared.contains(id), "intersection should contain {}", id);
        }
        assert!(!shared.contains(1));
        assert!(!shared.contains(7));
    }

    #[test]
    fn test_bitmap_basic() {
        let ids = [0, 5, 64, 100];
        let mut bitmap = AllowedBitmap::new();
        for &id in &ids {
            bitmap.set(id);
        }

        for &id in &ids {
            assert!(bitmap.contains(id), "bitmap should contain {}", id);
        }
        assert!(!bitmap.contains(1));
        assert!(!bitmap.contains(63));

        assert_eq!(bitmap.count(), ids.len());
    }

    #[test]
    fn test_bitmap_from_ids() {
        let ids = [1, 5, 10, 100, 1000];
        let bitmap = AllowedBitmap::from_ids(&ids);

        for &id in &ids {
            assert!(bitmap.contains(id), "bitmap should contain {}", id);
        }
        assert!(!bitmap.contains(0));
        assert!(!bitmap.contains(50));
    }

    #[test]
    fn test_bitmap_intersection() {
        let shared = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5])
            .intersect(&AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]));

        assert_eq!(shared.count(), 3);
        for id in 3..=5 {
            assert!(shared.contains(id), "intersection should contain {}", id);
        }
    }

    #[test]
    fn test_execution_strategy() {
        struct DummyGate;
        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();
            fn execute_with_gate(&self, _: &(), _: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let cases = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for (selectivity, expected) in cases {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}