1use crate::config::RdfConfig;
7use crate::segment::format::PostingListEntry;
8use crate::types::*;
9use std::collections::HashMap;
10
11pub struct RdfBuilder<'a> {
13 config: &'a RdfConfig,
14 dim: u32,
15 vectors: &'a [Vec<f32>],
16 dim_weights: Vec<f32>,
17 doc_freqs: Vec<u32>,
18}
19
20impl<'a> RdfBuilder<'a> {
21 pub fn new(config: &'a RdfConfig, dim: u32, rotated_vectors: &'a [Vec<f32>]) -> Self {
23 let n_vec = rotated_vectors.len();
24 let dim_usize = dim as usize;
25
26 let mut sum = vec![0.0f64; dim_usize];
28 let mut sum_sq = vec![0.0f64; dim_usize];
29 let mut doc_freqs = vec![0u32; dim_usize];
30
31 let top_t = config.top_t as usize;
33
34 for vec in rotated_vectors {
35 let mut scored: Vec<(usize, f32)> =
37 vec.iter().enumerate().map(|(i, &v)| (i, v.abs())).collect();
38 let nth_idx = top_t.min(scored.len()).saturating_sub(1);
39 if nth_idx < scored.len() {
40 scored.select_nth_unstable_by(nth_idx, |a, b| b.1.partial_cmp(&a.1).unwrap());
41 }
42
43 for &(dim_idx, _) in scored.iter().take(top_t) {
44 doc_freqs[dim_idx] += 1;
45 }
46
47 for (i, &v) in vec.iter().enumerate() {
48 sum[i] += v as f64;
49 sum_sq[i] += (v * v) as f64;
50 }
51 }
52
53 let n = n_vec as f64;
55 let mut dim_weights = Vec::with_capacity(dim_usize);
56
57 for d in 0..dim_usize {
58 let mean = sum[d] / n;
59 let var = (sum_sq[d] / n - mean * mean).max(0.0);
60 let std_dev = var.sqrt();
61
62 let df = doc_freqs[d].max(1) as f64;
64 let idf = (n / df).ln();
65
66 let weight = config.idf_weight as f64 * idf + config.var_weight as f64 * std_dev;
68 dim_weights.push(weight as f32);
69 }
70
71 Self {
72 config,
73 dim,
74 vectors: rotated_vectors,
75 dim_weights,
76 doc_freqs,
77 }
78 }
79
80 pub fn dim_weights(&self) -> Vec<f32> {
82 self.dim_weights.clone()
83 }
84
85 pub fn build(&self) -> (Vec<PostingListEntry>, Vec<u8>) {
88 let dim_usize = self.dim as usize;
89 let top_t = self.config.top_t as usize;
90 let stripe_shift = self.config.stripe_shift;
91 let _stripe_size = 1usize << stripe_shift;
92
93 let mut dim_postings: Vec<Vec<(VectorId, bool, u8)>> = vec![Vec::new(); dim_usize];
96
97 let mut dim_max_mag = vec![0.0f32; dim_usize];
99
100 for (_vid, vec) in self.vectors.iter().enumerate() {
101 let mut scored: Vec<(usize, f32, f32)> = vec
103 .iter()
104 .enumerate()
105 .map(|(d, &v)| (d, v.abs() * self.dim_weights[d], v))
106 .collect();
107
108 if scored.len() > top_t {
110 scored.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
111 scored.truncate(top_t);
112 }
113
114 for &(dim_idx, _, value) in &scored {
115 let mag = value.abs();
116 dim_max_mag[dim_idx] = dim_max_mag[dim_idx].max(mag);
117 }
118 }
119
120 for (vid, vec) in self.vectors.iter().enumerate() {
122 let mut scored: Vec<(usize, f32, f32)> = vec
123 .iter()
124 .enumerate()
125 .map(|(d, &v)| (d, v.abs() * self.dim_weights[d], v))
126 .collect();
127
128 if scored.len() > top_t {
129 scored.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
130 scored.truncate(top_t);
131 }
132
133 for &(dim_idx, _, value) in &scored {
134 let sign = value >= 0.0;
135 let mag = value.abs();
136 let max_mag = dim_max_mag[dim_idx].max(1e-10);
137 let mag8 = ((mag / max_mag) * 127.0).min(127.0) as u8;
138
139 dim_postings[dim_idx].push((vid as VectorId, sign, mag8));
140 }
141 }
142
143 let mut directory = Vec::with_capacity(dim_usize);
145 let mut data = Vec::new();
146
147 for dim_idx in 0..dim_usize {
148 let postings = &dim_postings[dim_idx];
149
150 if postings.is_empty() {
151 directory.push(PostingListEntry {
152 offset: data.len() as u64,
153 length: 0,
154 num_stripes: 0,
155 flags: 0,
156 });
157 continue;
158 }
159
160 let offset = data.len() as u64;
161
162 let is_stopword = self.doc_freqs[dim_idx] > self.config.stop_dim_threshold;
164 let flags = if is_stopword {
165 PostingListEntry::FLAG_STOPWORD
166 } else {
167 0
168 };
169
170 let mut stripes: HashMap<StripeId, Vec<(u8, bool, u8)>> = HashMap::new();
172 for &(vid, sign, mag) in postings {
173 let stripe_id = vid >> stripe_shift;
174 let vid_in_stripe = (vid & ((1 << stripe_shift) - 1)) as u8;
175 stripes
176 .entry(stripe_id)
177 .or_default()
178 .push((vid_in_stripe, sign, mag));
179 }
180
181 let mut stripe_ids: Vec<StripeId> = stripes.keys().copied().collect();
183 stripe_ids.sort();
184
185 for stripe_id in &stripe_ids {
187 let entries = stripes.get(stripe_id).unwrap();
188
189 let header = StripeChunkHeader {
191 stripe_id: *stripe_id,
192 count: entries.len() as u16,
193 _pad: 0,
194 };
195 data.extend_from_slice(bytemuck::bytes_of(&header));
196
197 let mut sorted_entries = entries.clone();
199 sorted_entries.sort_by_key(|e| e.0);
200
201 for (vid_in_stripe, sign, mag) in sorted_entries {
202 let posting = RdfPosting::new(vid_in_stripe, sign, mag);
203 data.extend_from_slice(bytemuck::bytes_of(&posting));
204 }
205 }
206
207 directory.push(PostingListEntry {
208 offset,
209 length: postings.len() as u32,
210 num_stripes: stripe_ids.len() as u16,
211 flags,
212 });
213 }
214
215 (directory, data)
216 }
217}
218
219pub struct RdfScorer<'a> {
221 directory: &'a [PostingListEntry],
222 rdf_data: &'a [u8],
223 dim_weights: &'a [f32],
224 stripe_shift: u8,
225 stripe_size: usize,
226 n_vec: u32,
227}
228
229impl<'a> RdfScorer<'a> {
230 pub fn new(
232 directory: &'a [PostingListEntry],
233 rdf_data: &'a [u8],
234 dim_weights: &'a [f32],
235 stripe_shift: u8,
236 n_vec: u32,
237 ) -> Self {
238 Self {
239 directory,
240 rdf_data,
241 dim_weights,
242 stripe_shift,
243 stripe_size: 1 << stripe_shift,
244 n_vec,
245 }
246 }
247
248 pub fn score(&self, query: &[f32], top_t: usize, l_a: usize) -> Vec<ScoredCandidate> {
251 if self.directory.is_empty() {
252 return Vec::new();
253 }
254
255 let _dim = query.len();
256
257 let mut scored: Vec<(usize, f32, f32)> = query
259 .iter()
260 .enumerate()
261 .map(|(d, &v)| {
262 let w = if d < self.dim_weights.len() {
263 self.dim_weights[d]
264 } else {
265 1.0
266 };
267 (d, v.abs() * w, v)
268 })
269 .collect();
270
271 if scored.len() > top_t {
272 scored.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
273 scored.truncate(top_t);
274 }
275
276 let query_dims: Vec<(usize, f32, f32)> = scored
279 .into_iter()
280 .filter(|&(d, _, _)| d < self.directory.len() && !self.directory[d].is_stopword())
281 .collect();
282
283 if query_dims.is_empty() {
284 let query_dims_fallback: Vec<(usize, f32, f32)> = {
290 let mut s: Vec<(usize, f32, f32)> = query
291 .iter()
292 .enumerate()
293 .map(|(d, &v)| {
294 let w = if d < self.dim_weights.len() {
295 self.dim_weights[d]
296 } else {
297 1.0
298 };
299 (d, v.abs() * w, v)
300 })
301 .collect();
302 if s.len() > top_t {
303 s.select_nth_unstable_by(top_t - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
304 s.truncate(top_t);
305 }
306 s.into_iter()
307 .filter(|&(d, _, _)| d < self.directory.len())
308 .collect()
309 };
310 if query_dims_fallback.is_empty() {
311 return Vec::new();
312 }
313 return self.score_with_dims(&query_dims_fallback, l_a);
314 }
315
316 self.score_with_dims(&query_dims, l_a)
317 }
318
319 fn score_with_dims(
322 &self,
323 query_dims: &[(usize, f32, f32)],
324 l_a: usize,
325 ) -> Vec<ScoredCandidate> {
326 let num_stripes = (self.n_vec as usize + self.stripe_size - 1) / self.stripe_size;
328 let mut global_candidates = Vec::new();
329
330 let mut stripe_acc = vec![0.0f32; self.stripe_size];
332
333 for stripe_id in 0..num_stripes as u32 {
335 stripe_acc.iter_mut().for_each(|x| *x = 0.0);
337
338 for &(dim_idx, _, q_value) in query_dims {
339 let entry = &self.directory[dim_idx];
340 if entry.length == 0 {
341 continue;
342 }
343
344 self.accumulate_stripe(
346 entry,
347 stripe_id,
348 q_value,
349 self.dim_weights[dim_idx],
350 &mut stripe_acc,
351 );
352 }
353
354 let base_vid = stripe_id << self.stripe_shift;
356 for (i, &score) in stripe_acc.iter().enumerate() {
357 if score > 0.0 {
358 let vid = base_vid + i as u32;
359 if vid < self.n_vec {
360 global_candidates.push(ScoredCandidate { id: vid, score });
361 }
362 }
363 }
364 }
365
366 if global_candidates.len() <= l_a {
368 global_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
369 return global_candidates;
370 }
371
372 global_candidates
373 .select_nth_unstable_by(l_a - 1, |a, b| b.score.partial_cmp(&a.score).unwrap());
374 global_candidates.truncate(l_a);
375 global_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
376
377 global_candidates
378 }
379
380 fn accumulate_stripe(
382 &self,
383 entry: &PostingListEntry,
384 target_stripe_id: StripeId,
385 q_value: f32,
386 dim_weight: f32,
387 stripe_acc: &mut [f32],
388 ) {
389 let mut offset = entry.offset as usize;
390 let header_size = std::mem::size_of::<StripeChunkHeader>();
391 let posting_size = std::mem::size_of::<RdfPosting>();
392
393 for _ in 0..entry.num_stripes {
394 if offset + header_size > self.rdf_data.len() {
395 break;
396 }
397
398 let header: StripeChunkHeader =
399 unsafe { std::ptr::read_unaligned(self.rdf_data.as_ptr().add(offset) as *const _) };
400 offset += header_size;
401
402 let count = header.count as usize;
403
404 if header.stripe_id == target_stripe_id {
405 for _ in 0..count {
407 if offset + posting_size > self.rdf_data.len() {
408 break;
409 }
410
411 let posting: RdfPosting = unsafe {
412 std::ptr::read_unaligned(self.rdf_data.as_ptr().add(offset) as *const _)
413 };
414 offset += posting_size;
415
416 let vid_in_stripe = posting.vid_in_stripe as usize;
417 let sign = if posting.sign() { 1.0 } else { -1.0 };
418 let mag = posting.magnitude() as f32 / 127.0;
419
420 let contribution = q_value * sign * mag * dim_weight;
422 stripe_acc[vid_in_stripe] += contribution;
423 }
424 return;
425 } else {
426 offset += count * posting_size;
428 }
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` vectors of length `dim`. In vector `i`, component `i % dim`
    /// is `hot` and every other component is `cold`.
    fn spiky_vectors(count: usize, dim: usize, hot: f32, cold: f32) -> Vec<Vec<f32>> {
        (0..count)
            .map(|i| (0..dim).map(|j| if j == i % dim { hot } else { cold }).collect())
            .collect()
    }

    #[test]
    fn test_rdf_build() {
        let config = RdfConfig {
            top_t: 8,
            stripe_shift: 4,
            stop_dim_threshold: 1000,
            idf_weight: 0.5,
            var_weight: 0.5,
        };
        let vectors = spiky_vectors(100, 32, 1.0, 0.1);

        let builder = RdfBuilder::new(&config, 32, &vectors);
        assert!(builder.dim_weights().iter().all(|w| w.is_finite()));

        let (directory, data) = builder.build();
        assert_eq!(directory.len(), 32);
        assert!(!data.is_empty());

        // Every vector has more than top_t components, so each one contributes
        // exactly top_t postings across all lists.
        let total_postings: usize = directory.iter().map(|e| e.length as usize).sum();
        assert_eq!(total_postings, 100 * 8);
    }

    #[test]
    fn test_rdf_scorer() {
        let config = RdfConfig {
            top_t: 4,
            stripe_shift: 4,
            stop_dim_threshold: 1000,
            idf_weight: 0.5,
            var_weight: 0.5,
        };
        let vectors = spiky_vectors(50, 16, 1.0, 0.0);

        let builder = RdfBuilder::new(&config, 16, &vectors);
        let dim_weights = builder.dim_weights();
        let (directory, data) = builder.build();

        let scorer = RdfScorer::new(&directory, &data, &dim_weights, 4, 50);
        let query: Vec<f32> = (0..16).map(|j| if j == 0 { 1.0 } else { 0.0 }).collect();
        let candidates = scorer.score(&query, 4, 10);

        // Only vectors 0, 16, 32 and 48 spike on dimension 0, the query's only
        // non-zero dimension. Their scores tie, and the stable final sort keeps
        // them in id order.
        let ids: Vec<_> = candidates.iter().map(|c| c.id).collect();
        assert_eq!(ids, vec![0, 16, 32, 48]);
        assert!(candidates.iter().all(|c| c.score > 0.0));
    }
}