Skip to main content

rvf_runtime/
membership.rs

1//! Membership filter for shared HNSW index traversal.
2//!
3//! Include mode (default): vector visible iff `filter.contains(id)`.
4//! Exclude mode: vector visible iff `!filter.contains(id)`.
5//!
6//! Empty filter in include mode = empty view (fail-safe).
7//!
8//! HNSW traversal integration:
9//! - Excluded nodes MAY be pushed onto exploration heap (routing waypoints)
10//! - Excluded nodes MUST NOT be pushed onto result heap
11//! - Excluded nodes DO NOT decrement `ef_remaining`
12
13use rvf_types::membership::{FilterMode, MembershipHeader, MEMBERSHIP_MAGIC};
14use rvf_types::{ErrorCode, RvfError};
15
16/// Membership filter backed by a dense bitmap.
17pub struct MembershipFilter {
18    /// Include or exclude mode.
19    mode: FilterMode,
20    /// Dense bit vector: one bit per vector ID.
21    bitmap: Vec<u64>,
22    /// Total vector count (capacity of the filter).
23    vector_count: u64,
24    /// Number of set bits (members).
25    member_count: u64,
26    /// Generation counter for optimistic concurrency.
27    generation_id: u32,
28}
29
30impl MembershipFilter {
31    /// Create a new include-mode filter with given capacity. All bits start clear.
32    pub fn new_include(vector_count: u64) -> Self {
33        let words = vector_count.div_ceil(64) as usize;
34        Self {
35            mode: FilterMode::Include,
36            bitmap: vec![0u64; words],
37            vector_count,
38            member_count: 0,
39            generation_id: 0,
40        }
41    }
42
43    /// Create a new exclude-mode filter with given capacity. All bits start clear.
44    pub fn new_exclude(vector_count: u64) -> Self {
45        let words = vector_count.div_ceil(64) as usize;
46        Self {
47            mode: FilterMode::Exclude,
48            bitmap: vec![0u64; words],
49            vector_count,
50            member_count: 0,
51            generation_id: 0,
52        }
53    }
54
55    /// Add a vector ID to the filter.
56    pub fn add(&mut self, vector_id: u64) {
57        if vector_id >= self.vector_count {
58            return;
59        }
60        let word = (vector_id / 64) as usize;
61        let bit = vector_id % 64;
62        if word < self.bitmap.len() {
63            let mask = 1u64 << bit;
64            if self.bitmap[word] & mask == 0 {
65                self.bitmap[word] |= mask;
66                self.member_count += 1;
67            }
68        }
69    }
70
71    /// Remove a vector ID from the filter.
72    pub fn remove(&mut self, vector_id: u64) {
73        if vector_id >= self.vector_count {
74            return;
75        }
76        let word = (vector_id / 64) as usize;
77        let bit = vector_id % 64;
78        if word < self.bitmap.len() {
79            let mask = 1u64 << bit;
80            if self.bitmap[word] & mask != 0 {
81                self.bitmap[word] &= !mask;
82                self.member_count -= 1;
83            }
84        }
85    }
86
87    /// Check if a vector ID is in the filter bitmap.
88    fn bitmap_contains(&self, vector_id: u64) -> bool {
89        if vector_id >= self.vector_count {
90            return false;
91        }
92        let word = (vector_id / 64) as usize;
93        let bit = vector_id % 64;
94        if word < self.bitmap.len() {
95            self.bitmap[word] & (1u64 << bit) != 0
96        } else {
97            false
98        }
99    }
100
101    /// Check if a vector ID is visible through this filter.
102    ///
103    /// In Include mode: visible iff the bit is set.
104    /// In Exclude mode: visible iff the bit is NOT set.
105    pub fn contains(&self, vector_id: u64) -> bool {
106        match self.mode {
107            FilterMode::Include => self.bitmap_contains(vector_id),
108            FilterMode::Exclude => !self.bitmap_contains(vector_id),
109        }
110    }
111
112    /// Number of set bits (members in the bitmap).
113    pub fn member_count(&self) -> u64 {
114        self.member_count
115    }
116
117    /// Total vector capacity.
118    pub fn vector_count(&self) -> u64 {
119        self.vector_count
120    }
121
122    /// Filter mode.
123    pub fn mode(&self) -> FilterMode {
124        self.mode
125    }
126
127    /// Generation ID.
128    pub fn generation_id(&self) -> u32 {
129        self.generation_id
130    }
131
132    /// Increment generation ID.
133    pub fn bump_generation(&mut self) {
134        self.generation_id += 1;
135    }
136
137    /// Serialize the bitmap to bytes (just the raw bitmap words).
138    pub fn serialize(&self) -> Vec<u8> {
139        let mut buf = Vec::with_capacity(self.bitmap.len() * 8);
140        for &word in &self.bitmap {
141            buf.extend_from_slice(&word.to_le_bytes());
142        }
143        buf
144    }
145
146    /// Deserialize a MembershipFilter from bitmap bytes and a header.
147    pub fn deserialize(data: &[u8], header: &MembershipHeader) -> Result<Self, RvfError> {
148        let mode = FilterMode::try_from(header.filter_mode)
149            .map_err(|_| RvfError::Code(ErrorCode::MembershipInvalid))?;
150
151        let word_count = header.vector_count.div_ceil(64) as usize;
152        let expected_bytes = word_count * 8;
153        if data.len() < expected_bytes {
154            return Err(RvfError::Code(ErrorCode::MembershipInvalid));
155        }
156
157        let mut bitmap = Vec::with_capacity(word_count);
158        for i in 0..word_count {
159            let offset = i * 8;
160            let word = u64::from_le_bytes([
161                data[offset],
162                data[offset + 1],
163                data[offset + 2],
164                data[offset + 3],
165                data[offset + 4],
166                data[offset + 5],
167                data[offset + 6],
168                data[offset + 7],
169            ]);
170            bitmap.push(word);
171        }
172
173        // Recount set bits
174        let member_count: u64 = bitmap.iter().map(|w| w.count_ones() as u64).sum();
175
176        Ok(Self {
177            mode,
178            bitmap,
179            vector_count: header.vector_count,
180            member_count,
181            generation_id: header.generation_id,
182        })
183    }
184
185    /// Build a MembershipHeader for this filter.
186    pub fn to_header(&self) -> MembershipHeader {
187        let bitmap_bytes = self.serialize();
188        let filter_hash = crate::store::simple_shake256_256(&bitmap_bytes);
189
190        MembershipHeader {
191            magic: MEMBERSHIP_MAGIC,
192            version: 1,
193            filter_type: rvf_types::membership::FilterType::Bitmap as u8,
194            filter_mode: self.mode as u8,
195            vector_count: self.vector_count,
196            member_count: self.member_count,
197            filter_offset: 96, // right after header
198            filter_size: bitmap_bytes.len() as u32,
199            generation_id: self.generation_id,
200            filter_hash,
201            bloom_offset: 0,
202            bloom_size: 0,
203            _reserved: 0,
204            _reserved2: [0u8; 8],
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    //! Unit tests for `MembershipFilter`: visibility semantics in both modes,
    //! member counting, bounds handling, and the serialize/deserialize round trip.

    use super::*;

    /// An include-mode filter with no members exposes nothing (fail-safe).
    #[test]
    fn include_mode_empty_is_empty_view() {
        let filter = MembershipFilter::new_include(100);
        assert!((0..100).all(|id| !filter.contains(id)));
    }

    /// Added IDs become visible in include mode; others stay hidden.
    #[test]
    fn include_mode_add_and_check() {
        let mut filter = MembershipFilter::new_include(100);
        for id in [10, 50, 99] {
            filter.add(id);
        }

        assert!(filter.contains(10));
        assert!(filter.contains(50));
        assert!(filter.contains(99));
        assert!(!filter.contains(0));
        assert!(!filter.contains(11));
        assert_eq!(filter.member_count(), 3);
    }

    /// In exclude mode, setting a bit hides that vector and only that vector.
    #[test]
    fn exclude_mode() {
        let mut filter = MembershipFilter::new_exclude(100);
        // In exclude mode, all are visible when bitmap is empty
        assert!(filter.contains(0));
        assert!(filter.contains(50));

        // Add to bitmap means "exclude this vector"
        filter.add(50);
        assert!(!filter.contains(50));
        assert!(filter.contains(0));
        assert!(filter.contains(99));
    }

    /// `remove` undoes `add`, including the member count.
    #[test]
    fn add_remove() {
        let mut filter = MembershipFilter::new_include(64);
        filter.add(10);
        assert_eq!(filter.member_count(), 1);
        assert!(filter.contains(10));

        filter.remove(10);
        assert_eq!(filter.member_count(), 0);
        assert!(!filter.contains(10));
    }

    /// Removing an ID that was never added (or is out of range) must not
    /// underflow the member count.
    #[test]
    fn remove_absent_is_noop() {
        let mut filter = MembershipFilter::new_include(64);
        filter.remove(5);
        filter.remove(1_000); // beyond vector_count
        assert_eq!(filter.member_count(), 0);

        filter.add(7);
        filter.remove(7);
        filter.remove(7);
        assert_eq!(filter.member_count(), 0);
    }

    /// IDs beyond the capacity are ignored by `add`.
    #[test]
    fn add_out_of_bounds_ignored() {
        let mut filter = MembershipFilter::new_include(10);
        filter.add(100); // beyond vector_count
        assert_eq!(filter.member_count(), 0);
    }

    /// Adding the same ID twice counts it once.
    #[test]
    fn double_add_no_double_count() {
        let mut filter = MembershipFilter::new_include(64);
        filter.add(5);
        filter.add(5);
        assert_eq!(filter.member_count(), 1);
    }

    /// A zero-capacity filter stores nothing and serializes to an empty buffer.
    #[test]
    fn zero_capacity_filter() {
        let mut filter = MembershipFilter::new_include(0);
        filter.add(0);
        assert_eq!(filter.member_count(), 0);
        assert!(!filter.contains(0));
        assert!(filter.serialize().is_empty());
    }

    /// Bitmap bytes plus header reconstruct an equivalent include-mode filter.
    #[test]
    fn serialize_deserialize_round_trip() {
        let mut filter = MembershipFilter::new_include(200);
        for id in [0, 63, 64, 127, 199] {
            filter.add(id);
        }

        let header = filter.to_header();
        let bitmap_data = filter.serialize();

        let filter2 = MembershipFilter::deserialize(&bitmap_data, &header).unwrap();
        assert_eq!(filter2.vector_count(), 200);
        assert_eq!(filter2.member_count(), 5);
        assert!(filter2.contains(0));
        assert!(filter2.contains(63));
        assert!(filter2.contains(64));
        assert!(filter2.contains(127));
        assert!(filter2.contains(199));
        assert!(!filter2.contains(1));
        assert!(!filter2.contains(100));
    }

    /// Exclude-mode semantics and the generation counter survive a round trip.
    #[test]
    fn exclude_round_trip_preserves_semantics() {
        let mut filter = MembershipFilter::new_exclude(70);
        filter.add(3);
        filter.add(69);
        filter.bump_generation();

        let header = filter.to_header();
        let bitmap_data = filter.serialize();

        let restored = MembershipFilter::deserialize(&bitmap_data, &header).unwrap();
        assert_eq!(restored.vector_count(), 70);
        assert_eq!(restored.member_count(), 2);
        assert_eq!(restored.generation_id(), 1);
        assert!(!restored.contains(3));
        assert!(!restored.contains(69));
        assert!(restored.contains(4));
    }

    /// A bitmap shorter than the header requires is rejected, not read past.
    #[test]
    fn deserialize_rejects_truncated_bitmap() {
        let filter = MembershipFilter::new_include(200);
        let header = filter.to_header();
        let bitmap_data = filter.serialize();

        let truncated = &bitmap_data[..bitmap_data.len() - 1];
        assert!(MembershipFilter::deserialize(truncated, &header).is_err());
        assert!(MembershipFilter::deserialize(&[], &header).is_err());
    }

    /// The generation counter starts at zero and increments by one.
    #[test]
    fn generation_bump() {
        let mut filter = MembershipFilter::new_include(10);
        assert_eq!(filter.generation_id(), 0);
        filter.bump_generation();
        assert_eq!(filter.generation_id(), 1);
    }

    /// IDs on either side of a 64-bit word boundary land in the right word.
    #[test]
    fn bitmap_word_boundary() {
        // Test vectors near 64-bit word boundaries
        let mut filter = MembershipFilter::new_include(130);
        for id in [63, 64, 128] {
            filter.add(id);
        }

        assert!(filter.contains(63));
        assert!(filter.contains(64));
        assert!(filter.contains(128));
        assert!(!filter.contains(62));
        assert!(!filter.contains(65));
        assert!(!filter.contains(129));
    }
}