1use std::collections::HashSet;
50use std::fmt;
51use std::sync::Arc;
52
53#[derive(Clone)]
62pub enum AllowedSet {
63 All,
65
66 Bitmap(Arc<AllowedBitmap>),
68
69 SortedVec(Arc<Vec<u64>>),
71
72 HashSet(Arc<HashSet<u64>>),
74
75 None,
77}
78
79impl AllowedSet {
80 pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
82 if bitmap.is_empty() {
83 Self::None
84 } else if bitmap.is_all() {
85 Self::All
86 } else {
87 Self::Bitmap(Arc::new(bitmap))
88 }
89 }
90
91 pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
93 if ids.is_empty() {
94 return Self::None;
95 }
96 ids.sort_unstable();
97 ids.dedup();
98 Self::SortedVec(Arc::new(ids))
99 }
100
101 pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
103 let set: HashSet<u64> = ids.into_iter().collect();
104 if set.is_empty() {
105 Self::None
106 } else {
107 Self::HashSet(Arc::new(set))
108 }
109 }
110
111 #[inline]
115 pub fn contains(&self, doc_id: u64) -> bool {
116 match self {
117 Self::All => true,
118 Self::Bitmap(bm) => bm.contains(doc_id),
119 Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
120 Self::HashSet(set) => set.contains(&doc_id),
121 Self::None => false,
122 }
123 }
124
125 pub fn is_empty(&self) -> bool {
127 matches!(self, Self::None)
128 }
129
130 pub fn is_all(&self) -> bool {
132 matches!(self, Self::All)
133 }
134
135 pub fn cardinality(&self) -> Option<usize> {
139 match self {
140 Self::All => None,
141 Self::Bitmap(bm) => Some(bm.count()),
142 Self::SortedVec(vec) => Some(vec.len()),
143 Self::HashSet(set) => Some(set.len()),
144 Self::None => Some(0),
145 }
146 }
147
148 pub fn selectivity(&self, universe_size: usize) -> f64 {
152 if universe_size == 0 {
153 return 0.0;
154 }
155 match self {
156 Self::All => 1.0,
157 Self::None => 0.0,
158 other => {
159 other.cardinality()
160 .map(|c| c as f64 / universe_size as f64)
161 .unwrap_or(1.0)
162 }
163 }
164 }
165
166 pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
168 match (self, other) {
169 (Self::All, x) | (x, Self::All) => x.clone(),
171 (Self::None, _) | (_, Self::None) => Self::None,
172
173 (Self::SortedVec(a), Self::SortedVec(b)) => {
175 let result = sorted_vec_intersect(a, b);
176 Self::from_sorted_vec(result)
177 }
178 (Self::HashSet(a), Self::HashSet(b)) => {
179 let result: HashSet<_> = a.intersection(b).copied().collect();
180 if result.is_empty() {
181 Self::None
182 } else {
183 Self::HashSet(Arc::new(result))
184 }
185 }
186 (Self::Bitmap(a), Self::Bitmap(b)) => {
187 let result = a.intersect(b);
188 Self::from_bitmap(result)
189 }
190
191 (a, b) => {
193 let set_a: HashSet<u64> = a.iter().collect();
194 let set_b: HashSet<u64> = b.iter().collect();
195 let result: HashSet<_> = set_a.intersection(&set_b).copied().collect();
196 if result.is_empty() {
197 Self::None
198 } else {
199 Self::HashSet(Arc::new(result))
200 }
201 }
202 }
203 }
204
205 pub fn union(&self, other: &AllowedSet) -> AllowedSet {
207 match (self, other) {
208 (Self::All, _) | (_, Self::All) => Self::All,
209 (Self::None, x) | (x, Self::None) => x.clone(),
210
211 (Self::HashSet(a), Self::HashSet(b)) => {
212 let result: HashSet<_> = a.union(b).copied().collect();
213 Self::HashSet(Arc::new(result))
214 }
215
216 (a, b) => {
218 let mut result: HashSet<u64> = a.iter().collect();
219 result.extend(b.iter());
220 Self::HashSet(Arc::new(result))
221 }
222 }
223 }
224
225 pub fn iter(&self) -> AllowedSetIter<'_> {
229 match self {
230 Self::All => AllowedSetIter::Empty,
231 Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
232 Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
233 Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
234 Self::None => AllowedSetIter::Empty,
235 }
236 }
237
238 pub fn to_vec(&self) -> Vec<u64> {
240 self.iter().collect()
241 }
242}
243
244impl fmt::Debug for AllowedSet {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 match self {
247 Self::All => write!(f, "AllowedSet::All"),
248 Self::None => write!(f, "AllowedSet::None"),
249 Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
250 Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
251 Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
252 }
253 }
254}
255
256impl Default for AllowedSet {
257 fn default() -> Self {
258 Self::All
259 }
260}
261
/// Merge-intersects two ascending slices.
///
/// Both inputs are consumed front to back: the smaller head is dropped, equal
/// heads are emitted once and dropped from both sides. The output preserves
/// the input order.
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::cmp::Ordering;

    // The intersection can never be longer than the shorter input.
    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let mut left = a;
    let mut right = b;

    while let (Some((&lhs, left_tail)), Some((&rhs, right_tail))) =
        (left.split_first(), right.split_first())
    {
        match lhs.cmp(&rhs) {
            Ordering::Less => left = left_tail,
            Ordering::Greater => right = right_tail,
            Ordering::Equal => {
                common.push(lhs);
                left = left_tail;
                right = right_tail;
            }
        }
    }

    common
}
282
283pub enum AllowedSetIter<'a> {
289 Empty,
290 Bitmap(BitmapIter<'a>),
291 SortedVec(std::slice::Iter<'a, u64>),
292 HashSet(std::collections::hash_set::Iter<'a, u64>),
293}
294
295impl<'a> Iterator for AllowedSetIter<'a> {
296 type Item = u64;
297
298 fn next(&mut self) -> Option<Self::Item> {
299 match self {
300 Self::Empty => None,
301 Self::Bitmap(iter) => iter.next(),
302 Self::SortedVec(iter) => iter.next().copied(),
303 Self::HashSet(iter) => iter.next().copied(),
304 }
305 }
306
307 fn size_hint(&self) -> (usize, Option<usize>) {
308 match self {
309 Self::Empty => (0, Some(0)),
310 Self::Bitmap(iter) => iter.size_hint(),
311 Self::SortedVec(iter) => iter.size_hint(),
312 Self::HashSet(iter) => iter.size_hint(),
313 }
314 }
315}
316
/// Growable bitmap of allowed ids.
///
/// Derives `Debug` and `Clone` so the type can be logged and copied like the
/// other id containers used by `AllowedSet`.
#[derive(Debug, Clone)]
pub struct AllowedBitmap {
    /// Bit storage: bit `b` of `words[w]` represents id `w * 64 + b`.
    words: Vec<u64>,
    /// Cached number of ids in the bitmap, maintained by the mutating methods.
    count: usize,
    /// Flag reported by `is_all()`; set when the bitmap is created as a full
    /// range.
    all: bool,
}
333
334impl AllowedBitmap {
335 pub fn new() -> Self {
337 Self {
338 words: Vec::new(),
339 count: 0,
340 all: false,
341 }
342 }
343
344 pub fn all(max_id: u64) -> Self {
346 let word_count = (max_id as usize / 64) + 1;
347 Self {
348 words: vec![u64::MAX; word_count],
349 count: max_id as usize + 1,
350 all: true,
351 }
352 }
353
354 pub fn from_ids(ids: &[u64]) -> Self {
356 if ids.is_empty() {
357 return Self::new();
358 }
359
360 let max_id = *ids.iter().max().unwrap();
361 let word_count = (max_id as usize / 64) + 1;
362 let mut words = vec![0u64; word_count];
363
364 for &id in ids {
365 let word_idx = id as usize / 64;
366 let bit_idx = id % 64;
367 words[word_idx] |= 1 << bit_idx;
368 }
369
370 Self {
371 words,
372 count: ids.len(),
373 all: false,
374 }
375 }
376
377 pub fn set(&mut self, id: u64) {
379 let word_idx = id as usize / 64;
380 let bit_idx = id % 64;
381
382 if word_idx >= self.words.len() {
384 self.words.resize(word_idx + 1, 0);
385 }
386
387 let old = self.words[word_idx];
388 self.words[word_idx] |= 1 << bit_idx;
389 if old != self.words[word_idx] {
390 self.count += 1;
391 }
392 }
393
394 pub fn clear(&mut self, id: u64) {
396 let word_idx = id as usize / 64;
397 if word_idx >= self.words.len() {
398 return;
399 }
400
401 let bit_idx = id % 64;
402 let old = self.words[word_idx];
403 self.words[word_idx] &= !(1 << bit_idx);
404 if old != self.words[word_idx] {
405 self.count -= 1;
406 }
407 }
408
409 #[inline]
411 pub fn contains(&self, id: u64) -> bool {
412 let word_idx = id as usize / 64;
413 if word_idx >= self.words.len() {
414 return false;
415 }
416 let bit_idx = id % 64;
417 (self.words[word_idx] & (1 << bit_idx)) != 0
418 }
419
420 pub fn count(&self) -> usize {
422 self.count
423 }
424
425 pub fn is_empty(&self) -> bool {
427 self.count == 0
428 }
429
430 pub fn is_all(&self) -> bool {
432 self.all
433 }
434
435 pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
437 let min_len = self.words.len().min(other.words.len());
438 let mut words = Vec::with_capacity(min_len);
439 let mut count = 0;
440
441 for i in 0..min_len {
442 let word = self.words[i] & other.words[i];
443 count += word.count_ones() as usize;
444 words.push(word);
445 }
446
447 AllowedBitmap {
448 words,
449 count,
450 all: false,
451 }
452 }
453
454 pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
456 let max_len = self.words.len().max(other.words.len());
457 let mut words = Vec::with_capacity(max_len);
458 let mut count = 0;
459
460 for i in 0..max_len {
461 let a = self.words.get(i).copied().unwrap_or(0);
462 let b = other.words.get(i).copied().unwrap_or(0);
463 let word = a | b;
464 count += word.count_ones() as usize;
465 words.push(word);
466 }
467
468 AllowedBitmap {
469 words,
470 count,
471 all: false,
472 }
473 }
474
475 pub fn iter(&self) -> BitmapIter<'_> {
477 BitmapIter {
478 words: &self.words,
479 word_idx: 0,
480 bit_offset: 0,
481 remaining: self.count,
482 }
483 }
484}
485
486impl Default for AllowedBitmap {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
/// Iterator over the positions of the set bits in a slice of bitmap words.
///
/// Bit `b` of `words[w]` is reported as id `w * 64 + b`; ids are produced in
/// ascending order.
pub struct BitmapIter<'a> {
    /// Words being scanned.
    words: &'a [u64],
    /// Index of the word currently being scanned.
    word_idx: usize,
    /// First bit of the current word that has not been examined yet
    /// (always below 64).
    bit_offset: u64,
    /// Number of ids still to be produced, as promised by the creator.
    remaining: usize,
}

impl<'a> Iterator for BitmapIter<'a> {
    type Item = u64;

    /// Returns the next set bit as an id, or `None` once `remaining` ids have
    /// been produced or the words are exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }

        while let Some(&word) = self.words.get(self.word_idx) {
            // Drop the bits that were already examined in this word.
            let pending = word >> self.bit_offset;

            if pending != 0 {
                let bit_pos = self.bit_offset + u64::from(pending.trailing_zeros());
                // Compute the id before the cursor moves, so advancing to the
                // next word cannot distort it.
                let id = self.word_idx as u64 * 64 + bit_pos;

                if bit_pos == 63 {
                    // Last bit of the word: continue with the next word.
                    self.word_idx += 1;
                    self.bit_offset = 0;
                } else {
                    self.bit_offset = bit_pos + 1;
                }

                self.remaining -= 1;
                return Some(id);
            }

            self.word_idx += 1;
            self.bit_offset = 0;
        }

        // The words ran out before the promised count was reached; keep
        // `size_hint` truthful from here on.
        self.remaining = 0;
        None
    }

    /// Exact when the creator passed the true number of set bits.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<'a> BitmapIter<'a> {
    /// Creates an iterator over `words` that stops after `count` ids.
    // Kept although nothing outside of tests calls it here: the bitmap builds
    // the struct literally.
    #[allow(dead_code)]
    fn new(words: &'a [u64], count: usize) -> Self {
        Self {
            words,
            word_idx: 0,
            bit_offset: 0,
            remaining: count,
        }
    }
}
550
551impl AllowedBitmap {
553 pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
555 self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
556 (0..64).filter_map(move |bit| {
557 if (word & (1 << bit)) != 0 {
558 Some(word_idx as u64 * 64 + bit as u64)
559 } else {
560 None
561 }
562 })
563 })
564 }
565}
566
567pub trait CandidateGate {
575 type Query;
577
578 type Result;
580
581 type Error;
583
584 fn execute_with_gate(
592 &self,
593 query: &Self::Query,
594 allowed_set: &AllowedSet,
595 ) -> Result<Self::Result, Self::Error>;
596
597 fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
599 if selectivity >= 0.1 {
600 ExecutionStrategy::FilterDuringSearch
601 } else if selectivity >= 0.001 {
602 ExecutionStrategy::ScanAllowedIds
603 } else {
604 ExecutionStrategy::LinearScan
605 }
606 }
607}
608
/// How a gated query is executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStrategy {
    /// Picked by the default selection for selectivity of at least `0.1`.
    FilterDuringSearch,

    /// Picked by the default selection for selectivity in `0.001..0.1`.
    ScanAllowedIds,

    /// Picked by the default selection for selectivity below `0.001`.
    LinearScan,

    /// Never produced by the default selection; presumably reserved for
    /// implementors that refuse a query (TODO confirm with callers).
    Reject,
}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Membership checks for every representation of `AllowedSet`.
    #[test]
    fn test_allowed_set_contains() {
        let everything = AllowedSet::All;
        for &id in &[0u64, 1000000] {
            assert!(everything.contains(id));
        }

        let nothing = AllowedSet::None;
        assert!(!nothing.contains(0));

        let sorted = AllowedSet::from_sorted_vec(vec![1, 3, 5, 7, 9]);
        for &present in &[1u64, 5] {
            assert!(sorted.contains(present));
        }
        for &absent in &[2u64, 10] {
            assert!(!sorted.contains(absent));
        }

        let hashed = AllowedSet::from_iter([1, 3, 5, 7, 9]);
        for &present in &[1u64, 5] {
            assert!(hashed.contains(present));
        }
        assert!(!hashed.contains(2));
    }

    /// Selectivity is cardinality divided by the universe size.
    #[test]
    fn test_allowed_set_selectivity() {
        let five_ids = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);

        assert_eq!(five_ids.selectivity(100), 0.05);
        assert_eq!(five_ids.selectivity(10), 0.5);

        assert_eq!(AllowedSet::All.selectivity(100), 1.0);
        assert_eq!(AllowedSet::None.selectivity(100), 0.0);
    }

    /// Intersecting two sorted vectors keeps only the shared ids.
    #[test]
    fn test_allowed_set_intersection() {
        let left = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        let right = AllowedSet::from_sorted_vec(vec![3, 4, 5, 6, 7]);

        let shared = left.intersect(&right);
        assert_eq!(shared.cardinality(), Some(3));
        for &present in &[3u64, 4, 5] {
            assert!(shared.contains(present));
        }
        for &absent in &[1u64, 7] {
            assert!(!shared.contains(absent));
        }
    }

    /// Setting bits across a word boundary and reading them back.
    #[test]
    fn test_bitmap_basic() {
        let mut bitmap = AllowedBitmap::new();
        for &id in &[0u64, 5, 64, 100] {
            bitmap.set(id);
        }

        for &present in &[0u64, 5, 64, 100] {
            assert!(bitmap.contains(present));
        }
        for &absent in &[1u64, 63] {
            assert!(!bitmap.contains(absent));
        }

        assert_eq!(bitmap.count(), 4);
    }

    /// Bulk construction contains exactly the given ids.
    #[test]
    fn test_bitmap_from_ids() {
        let ids = vec![1, 5, 10, 100, 1000];
        let bitmap = AllowedBitmap::from_ids(&ids);

        assert!(ids.iter().all(|&id| bitmap.contains(id)));
        assert!(!bitmap.contains(0));
        assert!(!bitmap.contains(50));
    }

    /// Word-wise intersection of two bitmaps.
    #[test]
    fn test_bitmap_intersection() {
        let left = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5]);
        let right = AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]);

        let shared = left.intersect(&right);
        assert_eq!(shared.count(), 3);
        for &present in &[3u64, 4, 5] {
            assert!(shared.contains(present));
        }
    }

    /// The default strategy selection at one point in each selectivity band.
    #[test]
    fn test_execution_strategy() {
        struct DummyGate;
        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();
            fn execute_with_gate(&self, _: &(), _: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let expectations = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for &(selectivity, expected) in &expectations {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}