Skip to main content

sochdb_vector/segment/
rdf.rs

//! RDF (Rare-Dominant Fingerprint) builder and posting list handling.
//!
//! RDF uses IR-style inverted lists over a sparse fingerprint of each vector.
//! Posting lists are stored in VID-striped chunks for cache-friendly scoring.
5
6use crate::config::RdfConfig;
7use crate::segment::format::PostingListEntry;
8use crate::types::*;
9use std::collections::HashMap;
10
11/// Builder for RDF posting lists
12pub struct RdfBuilder<'a> {
13    config: &'a RdfConfig,
14    dim: u32,
15    vectors: &'a [Vec<f32>],
16    dim_weights: Vec<f32>,
17    doc_freqs: Vec<u32>,
18}
19
20impl<'a> RdfBuilder<'a> {
21    /// Create a new RDF builder
22    pub fn new(config: &'a RdfConfig, dim: u32, rotated_vectors: &'a [Vec<f32>]) -> Self {
23        let n_vec = rotated_vectors.len();
24        let dim_usize = dim as usize;
25
26        // Compute dimension statistics
27        let mut sum = vec![0.0f64; dim_usize];
28        let mut sum_sq = vec![0.0f64; dim_usize];
29        let mut doc_freqs = vec![0u32; dim_usize];
30
31        // Track which dims appear in each vector's top-t
32        let top_t = config.top_t as usize;
33
34        for vec in rotated_vectors {
35            // Find top-t dims by absolute value
36            let mut scored: Vec<(usize, f32)> =
37                vec.iter().enumerate().map(|(i, &v)| (i, v.abs())).collect();
38            let nth_idx = top_t.min(scored.len()).saturating_sub(1);
39            if nth_idx < scored.len() {
40                scored.select_nth_unstable_by(nth_idx, |a, b| b.1.partial_cmp(&a.1).unwrap());
41            }
42
43            for &(dim_idx, _) in scored.iter().take(top_t) {
44                doc_freqs[dim_idx] += 1;
45            }
46
47            for (i, &v) in vec.iter().enumerate() {
48                sum[i] += v as f64;
49                sum_sq[i] += (v * v) as f64;
50            }
51        }
52
53        // Compute dimension weights: w[d] = α·idf[d] + β·sqrt(var[d])
54        let n = n_vec as f64;
55        let mut dim_weights = Vec::with_capacity(dim_usize);
56
57        for d in 0..dim_usize {
58            let mean = sum[d] / n;
59            let var = (sum_sq[d] / n - mean * mean).max(0.0);
60            let std_dev = var.sqrt();
61
62            // IDF-like weight: log(N / df)
63            let df = doc_freqs[d].max(1) as f64;
64            let idf = (n / df).ln();
65
66            // Combined weight
67            let weight = config.idf_weight as f64 * idf + config.var_weight as f64 * std_dev;
68            dim_weights.push(weight as f32);
69        }
70
71        Self {
72            config,
73            dim,
74            vectors: rotated_vectors,
75            dim_weights,
76            doc_freqs,
77        }
78    }
79
80    /// Get dimension weights
81    pub fn dim_weights(&self) -> Vec<f32> {
82        self.dim_weights.clone()
83    }
84
85    /// Build posting lists with striped storage
86    /// Returns (directory, concatenated posting data)
87    pub fn build(&self) -> (Vec<PostingListEntry>, Vec<u8>) {
88        let dim_usize = self.dim as usize;
89        let top_t = self.config.top_t as usize;
90        let stripe_shift = self.config.stripe_shift;
91        let _stripe_size = 1usize << stripe_shift;
92
93        // Collect postings per dimension
94        // Each posting: (vid, sign, magnitude)
95        let mut dim_postings: Vec<Vec<(VectorId, bool, u8)>> = vec![Vec::new(); dim_usize];
96
97        // Compute per-dimension magnitude scales for quantization
98        let mut dim_max_mag = vec![0.0f32; dim_usize];
99
100        for (_vid, vec) in self.vectors.iter().enumerate() {
101            // Score each dim: |v[d]| * w[d]
102            let mut scored: Vec<(usize, f32, f32)> = vec
103                .iter()
104                .enumerate()
105                .map(|(d, &v)| (d, v.abs() * self.dim_weights[d], v))
106                .collect();
107
108            // Select top-t by score
109            if scored.len() > top_t {
110                scored.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
111                scored.truncate(top_t);
112            }
113
114            for &(dim_idx, _, value) in &scored {
115                let mag = value.abs();
116                dim_max_mag[dim_idx] = dim_max_mag[dim_idx].max(mag);
117            }
118        }
119
120        // Second pass: collect postings with quantized magnitudes
121        for (vid, vec) in self.vectors.iter().enumerate() {
122            let mut scored: Vec<(usize, f32, f32)> = vec
123                .iter()
124                .enumerate()
125                .map(|(d, &v)| (d, v.abs() * self.dim_weights[d], v))
126                .collect();
127
128            if scored.len() > top_t {
129                scored.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
130                scored.truncate(top_t);
131            }
132
133            for &(dim_idx, _, value) in &scored {
134                let sign = value >= 0.0;
135                let mag = value.abs();
136                let max_mag = dim_max_mag[dim_idx].max(1e-10);
137                let mag8 = ((mag / max_mag) * 127.0).min(127.0) as u8;
138
139                dim_postings[dim_idx].push((vid as VectorId, sign, mag8));
140            }
141        }
142
143        // Build striped posting lists
144        let mut directory = Vec::with_capacity(dim_usize);
145        let mut data = Vec::new();
146
147        for dim_idx in 0..dim_usize {
148            let postings = &dim_postings[dim_idx];
149
150            if postings.is_empty() {
151                directory.push(PostingListEntry {
152                    offset: data.len() as u64,
153                    length: 0,
154                    num_stripes: 0,
155                    flags: 0,
156                });
157                continue;
158            }
159
160            let offset = data.len() as u64;
161
162            // Check if this is a stop-dim
163            let is_stopword = self.doc_freqs[dim_idx] > self.config.stop_dim_threshold;
164            let flags = if is_stopword {
165                PostingListEntry::FLAG_STOPWORD
166            } else {
167                0
168            };
169
170            // Group postings by stripe
171            let mut stripes: HashMap<StripeId, Vec<(u8, bool, u8)>> = HashMap::new();
172            for &(vid, sign, mag) in postings {
173                let stripe_id = vid >> stripe_shift;
174                let vid_in_stripe = (vid & ((1 << stripe_shift) - 1)) as u8;
175                stripes
176                    .entry(stripe_id)
177                    .or_default()
178                    .push((vid_in_stripe, sign, mag));
179            }
180
181            // Sort stripes by stripe_id
182            let mut stripe_ids: Vec<StripeId> = stripes.keys().copied().collect();
183            stripe_ids.sort();
184
185            // Write stripe chunks
186            for stripe_id in &stripe_ids {
187                let entries = stripes.get(stripe_id).unwrap();
188
189                // Write stripe header
190                let header = StripeChunkHeader {
191                    stripe_id: *stripe_id,
192                    count: entries.len() as u16,
193                    _pad: 0,
194                };
195                data.extend_from_slice(bytemuck::bytes_of(&header));
196
197                // Write entries sorted by vid_in_stripe
198                let mut sorted_entries = entries.clone();
199                sorted_entries.sort_by_key(|e| e.0);
200
201                for (vid_in_stripe, sign, mag) in sorted_entries {
202                    let posting = RdfPosting::new(vid_in_stripe, sign, mag);
203                    data.extend_from_slice(bytemuck::bytes_of(&posting));
204                }
205            }
206
207            directory.push(PostingListEntry {
208                offset,
209                length: postings.len() as u32,
210                num_stripes: stripe_ids.len() as u16,
211                flags,
212            });
213        }
214
215        (directory, data)
216    }
217}
218
219/// RDF scorer for query-time candidate generation
220pub struct RdfScorer<'a> {
221    directory: &'a [PostingListEntry],
222    rdf_data: &'a [u8],
223    dim_weights: &'a [f32],
224    stripe_shift: u8,
225    stripe_size: usize,
226    n_vec: u32,
227}
228
229impl<'a> RdfScorer<'a> {
230    /// Create a new RDF scorer
231    pub fn new(
232        directory: &'a [PostingListEntry],
233        rdf_data: &'a [u8],
234        dim_weights: &'a [f32],
235        stripe_shift: u8,
236        n_vec: u32,
237    ) -> Self {
238        Self {
239            directory,
240            rdf_data,
241            dim_weights,
242            stripe_shift,
243            stripe_size: 1 << stripe_shift,
244            n_vec,
245        }
246    }
247
248    /// Score candidates using RDF
249    /// Returns top L_A candidates by score (higher = better)
250    pub fn score(&self, query: &[f32], top_t: usize, l_a: usize) -> Vec<ScoredCandidate> {
251        if self.directory.is_empty() {
252            return Vec::new();
253        }
254
255        let _dim = query.len();
256
257        // Find top-t query dimensions by |q[d]| * w[d]
258        let mut scored: Vec<(usize, f32, f32)> = query
259            .iter()
260            .enumerate()
261            .map(|(d, &v)| {
262                let w = if d < self.dim_weights.len() {
263                    self.dim_weights[d]
264                } else {
265                    1.0
266                };
267                (d, v.abs() * w, v)
268            })
269            .collect();
270
271        if scored.len() > top_t {
272            scored.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
273            scored.truncate(top_t);
274        }
275
276        // Get query dims (excluding stopwords — but keep dims that pass
277        // the directory bounds check even if all are stopwords)
278        let query_dims: Vec<(usize, f32, f32)> = scored
279            .into_iter()
280            .filter(|&(d, _, _)| d < self.directory.len() && !self.directory[d].is_stopword())
281            .collect();
282
283        if query_dims.is_empty() {
284            // All query dimensions were stopwords — fall back to using
285            // the original dimensions WITHOUT the stopword filter.
286            // Returning empty here would cause zero recall for common
287            // queries where all top-t dimensions happen to be stop-dims.
288            // IDF-based dim_weights will naturally downweight these.
289            let query_dims_fallback: Vec<(usize, f32, f32)> = {
290                let mut s: Vec<(usize, f32, f32)> = query
291                    .iter()
292                    .enumerate()
293                    .map(|(d, &v)| {
294                        let w = if d < self.dim_weights.len() {
295                            self.dim_weights[d]
296                        } else {
297                            1.0
298                        };
299                        (d, v.abs() * w, v)
300                    })
301                    .collect();
302                if s.len() > top_t {
303                    s.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
304                    s.truncate(top_t);
305                }
306                s.into_iter()
307                    .filter(|&(d, _, _)| d < self.directory.len())
308                    .collect()
309            };
310            if query_dims_fallback.is_empty() {
311                return Vec::new();
312            }
313            return self.score_with_dims(&query_dims_fallback, l_a);
314        }
315
316        self.score_with_dims(&query_dims, l_a)
317    }
318
319    /// Internal scoring with a given set of query dimensions.
320    /// Separated from `score()` to allow fallback when all dims are stopwords.
321    fn score_with_dims(
322        &self,
323        query_dims: &[(usize, f32, f32)],
324        l_a: usize,
325    ) -> Vec<ScoredCandidate> {
326        // Use stripe-based accumulation
327        let num_stripes = (self.n_vec as usize + self.stripe_size - 1) / self.stripe_size;
328        let mut global_candidates = Vec::new();
329
330        // Allocate stripe accumulator ONCE, clear per-stripe (avoids N allocs)
331        let mut stripe_acc = vec![0.0f32; self.stripe_size];
332
333        // Process stripe by stripe for cache locality
334        for stripe_id in 0..num_stripes as u32 {
335            // Clear accumulator (memset — vectorizes to single SIMD store)
336            stripe_acc.iter_mut().for_each(|x| *x = 0.0);
337
338            for &(dim_idx, _, q_value) in query_dims {
339                let entry = &self.directory[dim_idx];
340                if entry.length == 0 {
341                    continue;
342                }
343
344                // Find and process the stripe chunk for this dimension
345                self.accumulate_stripe(
346                    entry,
347                    stripe_id,
348                    q_value,
349                    self.dim_weights[dim_idx],
350                    &mut stripe_acc,
351                );
352            }
353
354            // Collect non-zero scores from this stripe
355            let base_vid = stripe_id << self.stripe_shift;
356            for (i, &score) in stripe_acc.iter().enumerate() {
357                if score > 0.0 {
358                    let vid = base_vid + i as u32;
359                    if vid < self.n_vec {
360                        global_candidates.push(ScoredCandidate { id: vid, score });
361                    }
362                }
363            }
364        }
365
366        // Select top L_A
367        if global_candidates.len() <= l_a {
368            global_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
369            return global_candidates;
370        }
371
372        global_candidates
373            .select_nth_unstable_by(l_a - 1, |a, b| b.score.partial_cmp(&a.score).unwrap());
374        global_candidates.truncate(l_a);
375        global_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
376
377        global_candidates
378    }
379
380    /// Accumulate scores for a specific stripe from one dimension's posting list
381    fn accumulate_stripe(
382        &self,
383        entry: &PostingListEntry,
384        target_stripe_id: StripeId,
385        q_value: f32,
386        dim_weight: f32,
387        stripe_acc: &mut [f32],
388    ) {
389        let mut offset = entry.offset as usize;
390        let header_size = std::mem::size_of::<StripeChunkHeader>();
391        let posting_size = std::mem::size_of::<RdfPosting>();
392
393        for _ in 0..entry.num_stripes {
394            if offset + header_size > self.rdf_data.len() {
395                break;
396            }
397
398            let header: StripeChunkHeader =
399                unsafe { std::ptr::read_unaligned(self.rdf_data.as_ptr().add(offset) as *const _) };
400            offset += header_size;
401
402            let count = header.count as usize;
403
404            if header.stripe_id == target_stripe_id {
405                // Process this stripe
406                for _ in 0..count {
407                    if offset + posting_size > self.rdf_data.len() {
408                        break;
409                    }
410
411                    let posting: RdfPosting = unsafe {
412                        std::ptr::read_unaligned(self.rdf_data.as_ptr().add(offset) as *const _)
413                    };
414                    offset += posting_size;
415
416                    let vid_in_stripe = posting.vid_in_stripe as usize;
417                    let sign = if posting.sign() { 1.0 } else { -1.0 };
418                    let mag = posting.magnitude() as f32 / 127.0;
419
420                    // Score contribution: q_value * sign * mag * weight
421                    let contribution = q_value * sign * mag * dim_weight;
422                    stripe_acc[vid_in_stripe] += contribution;
423                }
424                return;
425            } else {
426                // Skip this stripe
427                offset += count * posting_size;
428            }
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Building over 100 vectors of 32 dims yields one directory entry per
    /// dimension and exactly `top_t` postings per vector.
    #[test]
    fn test_rdf_build() {
        let config = RdfConfig {
            top_t: 8,
            stripe_shift: 4, // 16 vids per stripe
            stop_dim_threshold: 1000,
            idf_weight: 0.5,
            var_weight: 0.5,
        };

        // One dominant component per vector, small non-zero values elsewhere.
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| {
                (0..32)
                    .map(|j| if j == (i % 32) { 1.0 } else { 0.1 })
                    .collect()
            })
            .collect();

        let builder = RdfBuilder::new(&config, 32, &vectors);
        let (directory, data) = builder.build();

        assert_eq!(directory.len(), 32);
        assert!(!data.is_empty());

        // Every vector has more than `top_t` dims, so each contributes exactly 8 postings.
        let total_postings: usize = directory.iter().map(|e| e.length as usize).sum();
        assert_eq!(total_postings, 100 * 8);
    }

    /// A one-hot query on dimension 0 retrieves exactly the vectors whose
    /// one-hot dimension is 0.
    #[test]
    fn test_rdf_scorer() {
        let config = RdfConfig {
            top_t: 4,
            stripe_shift: 4,
            stop_dim_threshold: 1000,
            idf_weight: 0.5,
            var_weight: 0.5,
        };

        // One-hot vectors: vector i is 1.0 at dimension i % 16.
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| {
                (0..16)
                    .map(|j| if j == (i % 16) { 1.0 } else { 0.0 })
                    .collect()
            })
            .collect();

        let builder = RdfBuilder::new(&config, 16, &vectors);
        let dim_weights = builder.dim_weights();
        let (directory, data) = builder.build();

        let scorer = RdfScorer::new(&directory, &data, &dim_weights, 4, 50);

        // Query matching vector 0 pattern
        let query: Vec<f32> = (0..16).map(|j| if j == 0 { 1.0 } else { 0.0 }).collect();
        let candidates = scorer.score(&query, 4, 10);

        assert!(!candidates.is_empty());

        // Only vectors 0, 16, 32 and 48 are non-zero on dimension 0; every
        // other posting contributes a score of exactly zero and is dropped.
        let mut ids: Vec<_> = candidates.iter().map(|c| c.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 16, 32, 48]);
        assert!(candidates.iter().all(|c| c.score > 0.0));
    }
}