1use anyhow::{Result, anyhow};
27use arrow_array::types::{Float32Type, UInt8Type, UInt64Type};
28use arrow_array::{
29 Array, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StructArray, UInt8Array,
30 UInt32Array, UInt64Array,
31};
32use arrow_schema::{DataType, Field, Schema as ArrowSchema};
33use futures::TryStreamExt;
34use lance::Dataset;
35use std::cmp::Reverse;
36use std::collections::{BinaryHeap, HashMap, HashSet};
37use std::sync::Arc;
38use tracing::{debug, info, instrument};
39use uni_common::core::id::Vid;
40use uni_common::core::schema::SparseVectorIndexConfig;
41
/// Soft cap, in bytes (256 MiB), on the estimated size of the postings map
/// being filled during a build. Once `estimated_postings_memory` reports more
/// than this, `build_from_batches` sets the current map aside as a segment and
/// continues with an empty one.
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 * 1024 * 1024;
45
/// In-memory inverted index: term id -> list of `(vid, weight)` postings.
type Postings = HashMap<u32, Vec<(u64, f32)>>;
48
49fn estimated_postings_memory(postings: &Postings) -> usize {
51 postings
52 .values()
53 .map(|v| std::mem::size_of::<u32>() + std::mem::size_of::<Vec<(u64, f32)>>() + v.len() * 12)
54 .sum()
55}
56
57fn merge_postings_segments(segments: Vec<Postings>) -> Postings {
59 let mut merged: Postings = HashMap::new();
60 for segment in segments {
61 for (term, entries) in segment {
62 merged.entry(term).or_default().extend(entries);
63 }
64 }
65 merged
66}
67
68fn read_sparse_row(struct_arr: &StructArray, row: usize) -> Option<(Vec<u32>, Vec<f32>)> {
72 if struct_arr.is_null(row) {
73 return None;
74 }
75 let indices_list = struct_arr
76 .column_by_name("indices")?
77 .as_any()
78 .downcast_ref::<ListArray>()?;
79 let values_list = struct_arr
80 .column_by_name("values")?
81 .as_any()
82 .downcast_ref::<ListArray>()?;
83 let idx_vals = indices_list.value(row);
84 let idx_arr = idx_vals.as_any().downcast_ref::<UInt32Array>()?;
85 let w_vals = values_list.value(row);
86 let w_arr = w_vals.as_any().downcast_ref::<Float32Array>()?;
87 let indices = (0..idx_arr.len()).map(|i| idx_arr.value(i)).collect();
88 let values = (0..w_arr.len()).map(|i| w_arr.value(i)).collect();
89 Some((indices, values))
90}
91
/// Highest quantization code, as a float. `quantize_term` divides a term's
/// largest weight by this to get the per-term scale, so that weight maps to
/// code 255.
const QUANT_LEVELS: f32 = 255.0;
98
99fn quantize_term(weights: &[f32]) -> (Vec<u8>, f32, f32) {
109 let max_weight = weights.iter().copied().fold(0.0f32, f32::max);
110 if max_weight <= 0.0 {
113 return (vec![0u8; weights.len()], 0.0, 0.0);
114 }
115 let scale = max_weight / QUANT_LEVELS;
116 let codes: Vec<u8> = weights
117 .iter()
118 .map(|&w| {
119 (w.clamp(0.0, max_weight) / scale).round() as u8
123 })
124 .collect();
125 let max_code = codes.iter().copied().max().unwrap_or(0);
126 (codes, scale, dequantize(max_code, scale))
127}
128
/// Maps an 8-bit code back to an approximate weight: `code * scale`.
fn dequantize(code: u8, scale: f32) -> f32 {
    let level = f32::from(code);
    level * scale
}
133
/// Borrowed view over one term's stored weights, in either stored encoding.
enum TermWeights<'a> {
    /// 8-bit codes plus the scale used to dequantize them.
    Quantized { codes: &'a UInt8Array, scale: f32 },
    /// Raw `f32` weights, used as stored.
    Lossless(&'a Float32Array),
}
141
142impl TermWeights<'_> {
143 fn get(&self, j: usize) -> f32 {
145 match self {
146 Self::Quantized { codes, scale } => {
147 if codes.is_null(j) {
148 0.0
149 } else {
150 dequantize(codes.value(j), *scale)
151 }
152 }
153 Self::Lossless(arr) => {
154 if arr.is_null(j) {
155 0.0
156 } else {
157 arr.value(j)
158 }
159 }
160 }
161 }
162}
163
164fn term_weights(weights_arr: &dyn Array, row_scale: Option<f32>) -> Result<TermWeights<'_>> {
173 if let Some(codes) = weights_arr.as_any().downcast_ref::<UInt8Array>() {
174 let scale = row_scale
175 .ok_or_else(|| anyhow!("Quantized sparse weights missing weight_scale column"))?;
176 Ok(TermWeights::Quantized { codes, scale })
177 } else if let Some(arr) = weights_arr.as_any().downcast_ref::<Float32Array>() {
178 Ok(TermWeights::Lossless(arr))
179 } else {
180 Err(anyhow!(
181 "Invalid inner weights type: {:?}",
182 weights_arr.data_type()
183 ))
184 }
185}
186
187fn weight_scale_column(batch: &RecordBatch) -> Option<&Float32Array> {
191 batch
192 .column_by_name("weight_scale")
193 .and_then(|c| c.as_any().downcast_ref::<Float32Array>())
194}
195
/// Inverted-index-backed sparse vector index for one `(label, property)` pair,
/// persisted as a Lance dataset.
pub struct SparseVectorIndex {
    /// Postings dataset handle; `None` until one has been opened or written.
    dataset: Option<Dataset>,
    /// Root URI from which the postings path is derived.
    base_uri: String,
    /// Label component of the postings path.
    label: String,
    /// Indexed property name; also the batch column read during builds.
    property: String,
    /// Index configuration; its `quantize` flag selects the stored weight encoding.
    config: SparseVectorIndexConfig,
}
204
205impl std::fmt::Debug for SparseVectorIndex {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 f.debug_struct("SparseVectorIndex")
208 .field("base_uri", &self.base_uri)
209 .field("label", &self.label)
210 .field("property", &self.property)
211 .field("initialized", &self.dataset.is_some())
212 .finish_non_exhaustive()
213 }
214}
215
216impl SparseVectorIndex {
/// Builds the dataset location for an index:
/// `<base_uri>/indexes/<label>/<property>_sparse`.
fn postings_path(base_uri: &str, label: &str, property: &str) -> String {
    let leaf = format!("{property}_sparse");
    [base_uri, "indexes", label, leaf.as_str()].join("/")
}
221
222 pub async fn new(base_uri: &str, config: SparseVectorIndexConfig) -> Result<Self> {
224 let path = Self::postings_path(base_uri, &config.label, &config.property);
225 let dataset = (Dataset::open(&path).await).ok();
226 Ok(Self {
227 dataset,
228 base_uri: base_uri.to_string(),
229 label: config.label.clone(),
230 property: config.property.clone(),
231 config,
232 })
233 }
234
235 fn accumulate_batch(&self, batch: &RecordBatch, postings: &mut Postings) -> Result<usize> {
240 let vid_col = batch
241 .column_by_name("_vid")
242 .ok_or_else(|| anyhow!("Missing _vid"))?
243 .as_any()
244 .downcast_ref::<UInt64Array>()
245 .ok_or_else(|| anyhow!("Invalid _vid type"))?;
246 let term_col = batch
247 .column_by_name(&self.property)
248 .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
249 let struct_arr = term_col
250 .as_any()
251 .downcast_ref::<StructArray>()
252 .ok_or_else(|| {
253 anyhow!(
254 "Property {} must be a sparse-vector struct, got {:?}",
255 self.property,
256 term_col.data_type()
257 )
258 })?;
259 let mut count = 0;
260 for i in 0..batch.num_rows() {
261 let vid = vid_col.value(i);
262 let Some((indices, values)) = read_sparse_row(struct_arr, i) else {
263 continue;
264 };
265 for (term, weight) in indices.into_iter().zip(values) {
266 if !weight.is_finite() {
271 continue;
272 }
273 postings.entry(term).or_default().push((vid, weight));
274 }
275 count += 1;
276 }
277 Ok(count)
278 }
279
280 async fn finish_build(
282 &mut self,
283 postings: Postings,
284 mut temp_segments: Vec<Postings>,
285 ) -> Result<()> {
286 if temp_segments.is_empty() {
287 self.write_postings(postings).await
288 } else {
289 temp_segments.push(postings);
290 info!(
291 segments = temp_segments.len(),
292 "Merging sparse postings segments"
293 );
294 let merged = merge_postings_segments(temp_segments);
295 self.write_postings(merged).await
296 }
297 }
298
299 pub async fn build_from_batches(
305 &mut self,
306 batches: &[RecordBatch],
307 progress: impl Fn(usize),
308 ) -> Result<()> {
309 let mut postings: Postings = HashMap::new();
310 let mut temp_segments: Vec<Postings> = Vec::new();
311 let mut count = 0;
312 for batch in batches {
313 count += self.accumulate_batch(batch, &mut postings)?;
314 progress(count);
315 if estimated_postings_memory(&postings) > DEFAULT_MAX_POSTINGS_MEMORY {
316 temp_segments.push(std::mem::take(&mut postings));
317 }
318 }
319 self.finish_build(postings, temp_segments).await
320 }
321
322 async fn write_postings(&mut self, postings: Postings) -> Result<()> {
330 let quantize = self.config.quantize;
331 let n = postings.len();
332 let mut term_ids = Vec::with_capacity(n);
333 let mut vid_lists: Vec<Option<Vec<Option<u64>>>> = Vec::with_capacity(n);
334 let mut max_impacts = Vec::with_capacity(n);
335 let mut q_weight_lists: Vec<Option<Vec<Option<u8>>>> = Vec::new();
337 let mut q_scales: Vec<f32> = Vec::new();
338 let mut f_weight_lists: Vec<Option<Vec<Option<f32>>>> = Vec::new();
339
340 for (term, entries) in postings {
341 let mut vids = Vec::with_capacity(entries.len());
342 let mut weights = Vec::with_capacity(entries.len());
343 for (vid, weight) in entries {
344 vids.push(Some(vid));
345 weights.push(weight);
346 }
347 term_ids.push(term);
348 vid_lists.push(Some(vids));
349
350 if quantize {
351 let (codes, scale, max_impact) = quantize_term(&weights);
352 q_weight_lists.push(Some(codes.into_iter().map(Some).collect()));
353 q_scales.push(scale);
354 max_impacts.push(max_impact);
355 } else {
356 let mut max_impact = f32::NEG_INFINITY;
360 for &w in &weights {
361 if w > max_impact {
362 max_impact = w;
363 }
364 }
365 if !max_impact.is_finite() {
366 max_impact = 0.0;
367 }
368 max_impacts.push(max_impact);
369 f_weight_lists.push(Some(weights.into_iter().map(Some).collect()));
370 }
371 }
372
373 let term_array = UInt32Array::from(term_ids);
374 let vid_list_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists);
375 let max_impact_array = Float32Array::from(max_impacts);
376
377 let mut columns: Vec<(&str, Arc<dyn Array>)> = vec![
378 ("term_id", Arc::new(term_array) as Arc<dyn Array>),
379 ("vids", Arc::new(vid_list_array) as Arc<dyn Array>),
380 ];
381 if quantize {
382 let weight_list_array =
383 ListArray::from_iter_primitive::<UInt8Type, _, _>(q_weight_lists);
384 columns.push(("weights", Arc::new(weight_list_array) as Arc<dyn Array>));
385 columns.push(("max_impact", Arc::new(max_impact_array) as Arc<dyn Array>));
386 columns.push((
387 "weight_scale",
388 Arc::new(Float32Array::from(q_scales)) as Arc<dyn Array>,
389 ));
390 } else {
391 let weight_list_array =
392 ListArray::from_iter_primitive::<Float32Type, _, _>(f_weight_lists);
393 columns.push(("weights", Arc::new(weight_list_array) as Arc<dyn Array>));
394 columns.push(("max_impact", Arc::new(max_impact_array) as Arc<dyn Array>));
395 }
396
397 let batch = arrow_array::RecordBatch::try_from_iter(columns)?;
398
399 let path = Self::postings_path(&self.base_uri, &self.label, &self.property);
400 let write_params = lance::dataset::WriteParams {
401 mode: lance::dataset::WriteMode::Overwrite,
402 ..Default::default()
403 };
404 let iterator = RecordBatchIterator::new(vec![Ok(batch)], Self::postings_schema(quantize));
405 let ds = Dataset::write(iterator, &path, Some(write_params)).await?;
406 self.dataset = Some(ds);
407 Ok(())
408 }
409
410 fn postings_schema(quantize: bool) -> Arc<ArrowSchema> {
412 let weights_item = if quantize {
413 DataType::UInt8
414 } else {
415 DataType::Float32
416 };
417 let mut fields = vec![
418 Field::new("term_id", DataType::UInt32, false),
419 Field::new(
420 "vids",
421 DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
422 false,
423 ),
424 Field::new(
425 "weights",
426 DataType::List(Arc::new(Field::new("item", weights_item, true))),
427 false,
428 ),
429 Field::new("max_impact", DataType::Float32, false),
430 ];
431 if quantize {
432 fields.push(Field::new("weight_scale", DataType::Float32, false));
433 }
434 Arc::new(ArrowSchema::new(fields))
435 }
436
437 pub async fn query_topk(&self, query: &[(u32, f32)], k: usize) -> Result<Vec<(Vid, f32)>> {
445 let Some(ds) = &self.dataset else {
446 debug!("Sparse index not initialized, returning empty result");
447 return Ok(Vec::new());
448 };
449 if query.is_empty() || k == 0 {
450 return Ok(Vec::new());
451 }
452
453 let query_weights: HashMap<u32, f32> = query.iter().copied().collect();
454 let term_filter = query_weights
455 .keys()
456 .map(|t| t.to_string())
457 .collect::<Vec<_>>()
458 .join(", ");
459 let filter = format!("term_id IN ({term_filter})");
460
461 let mut scanner = ds.scan();
462 scanner.filter(&filter)?;
463 let mut stream = scanner.try_into_stream().await?;
464
465 let mut scores: HashMap<u64, f32> = HashMap::new();
466 while let Some(batch) = stream.try_next().await? {
467 let term_col = batch
468 .column_by_name("term_id")
469 .ok_or_else(|| anyhow!("Missing term_id column"))?
470 .as_any()
471 .downcast_ref::<UInt32Array>()
472 .ok_or_else(|| anyhow!("Invalid term_id column"))?;
473 let vids_col = batch
474 .column_by_name("vids")
475 .ok_or_else(|| anyhow!("Missing vids column"))?
476 .as_any()
477 .downcast_ref::<ListArray>()
478 .ok_or_else(|| anyhow!("Invalid vids column"))?;
479 let weights_col = batch
480 .column_by_name("weights")
481 .ok_or_else(|| anyhow!("Missing weights column"))?
482 .as_any()
483 .downcast_ref::<ListArray>()
484 .ok_or_else(|| anyhow!("Invalid weights column"))?;
485 let weight_scale_col = weight_scale_column(&batch);
486
487 for i in 0..batch.num_rows() {
488 let term = term_col.value(i);
489 let Some(&qw) = query_weights.get(&term) else {
490 continue;
491 };
492 if vids_col.is_null(i) || weights_col.is_null(i) {
493 continue;
494 }
495 let vids_arr = vids_col.value(i);
496 let vids = vids_arr
497 .as_any()
498 .downcast_ref::<UInt64Array>()
499 .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
500 let weights_arr = weights_col.value(i);
501 let weights =
502 term_weights(weights_arr.as_ref(), weight_scale_col.map(|c| c.value(i)))?;
503
504 for j in 0..vids.len() {
505 if vids.is_null(j) {
506 continue;
507 }
508 *scores.entry(vids.value(j)).or_insert(0.0) += qw * weights.get(j);
509 }
510 }
511 }
512
513 Ok(Self::top_k_from_scores(scores, k))
514 }
515
516 fn top_k_from_scores(scores: HashMap<u64, f32>, k: usize) -> Vec<(Vid, f32)> {
519 let mut heap: BinaryHeap<Reverse<HeapEntry>> = BinaryHeap::with_capacity(k + 1);
521 for (vid, score) in scores {
522 heap.push(Reverse(HeapEntry { score, vid }));
523 if heap.len() > k {
524 heap.pop();
525 }
526 }
527 let mut out: Vec<(Vid, f32)> = heap
528 .into_iter()
529 .map(|Reverse(e)| (Vid::from(e.vid), e.score))
530 .collect();
531 out.sort_by(|a, b| {
532 b.1.partial_cmp(&a.1)
533 .unwrap_or(std::cmp::Ordering::Equal)
534 .then(a.0.as_u64().cmp(&b.0.as_u64()))
535 });
536 out
537 }
538
539 #[instrument(skip(self), level = "debug")]
541 async fn load_postings(&self) -> Result<Postings> {
542 let Some(ds) = &self.dataset else {
543 return Ok(HashMap::new());
544 };
545 let mut postings: Postings = HashMap::new();
546 let scanner = ds.scan();
547 let mut stream = scanner.try_into_stream().await?;
548 while let Some(batch) = stream.try_next().await? {
549 let term_col = batch
550 .column_by_name("term_id")
551 .ok_or_else(|| anyhow!("Missing term_id column"))?
552 .as_any()
553 .downcast_ref::<UInt32Array>()
554 .ok_or_else(|| anyhow!("Invalid term_id column"))?;
555 let vids_col = batch
556 .column_by_name("vids")
557 .ok_or_else(|| anyhow!("Missing vids column"))?
558 .as_any()
559 .downcast_ref::<ListArray>()
560 .ok_or_else(|| anyhow!("Invalid vids column"))?;
561 let weights_col = batch
562 .column_by_name("weights")
563 .ok_or_else(|| anyhow!("Missing weights column"))?
564 .as_any()
565 .downcast_ref::<ListArray>()
566 .ok_or_else(|| anyhow!("Invalid weights column"))?;
567 let weight_scale_col = weight_scale_column(&batch);
568
569 for i in 0..batch.num_rows() {
570 if vids_col.is_null(i) || weights_col.is_null(i) {
571 continue;
572 }
573 let term = term_col.value(i);
574 let vids_arr = vids_col.value(i);
575 let vids = vids_arr
576 .as_any()
577 .downcast_ref::<UInt64Array>()
578 .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
579 let weights_arr = weights_col.value(i);
580 let weights =
583 term_weights(weights_arr.as_ref(), weight_scale_col.map(|c| c.value(i)))?;
584 let entry = postings.entry(term).or_default();
585 for j in 0..vids.len() {
586 if !vids.is_null(j) {
587 entry.push((vids.value(j), weights.get(j)));
588 }
589 }
590 }
591 }
592 Ok(postings)
593 }
594
595 #[instrument(skip(self, added, removed), level = "info", fields(
599 label = %self.label,
600 property = %self.property,
601 added_count = added.len(),
602 removed_count = removed.len()
603 ))]
604 pub async fn apply_incremental_updates(
605 &mut self,
606 added: &HashMap<Vid, Vec<(u32, f32)>>,
607 removed: &HashSet<Vid>,
608 ) -> Result<()> {
609 let mut postings = self.load_postings().await?;
610
611 if !removed.is_empty() {
612 let removed_u64: HashSet<u64> = removed.iter().map(|v| v.as_u64()).collect();
613 for entries in postings.values_mut() {
614 entries.retain(|(vid, _)| !removed_u64.contains(vid));
615 }
616 postings.retain(|_, entries| !entries.is_empty());
617 }
618
619 for (vid, terms) in added {
620 let vid_u64 = vid.as_u64();
621 for &(term, weight) in terms {
622 postings.entry(term).or_default().push((vid_u64, weight));
623 }
624 }
625
626 self.write_postings(postings).await?;
627 Ok(())
628 }
629
630 pub fn is_initialized(&self) -> bool {
632 self.dataset.is_some()
633 }
634
635 pub fn property(&self) -> &str {
637 &self.property
638 }
639}
640
/// Candidate held in the top-k heap.
///
/// The ordering means "ranks better": a higher score is greater, and between
/// equal scores the lower vid is greater. A NaN score ranks below every other
/// score, which keeps the order total as `Ord` requires.
struct HeapEntry {
    score: f32,
    vid: u64,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for HeapEntry {}
impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_score = match (self.score.is_nan(), other.score.is_nan()) {
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => self
                .score
                .partial_cmp(&other.score)
                .unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
        };
        // Reversed on purpose: with equal scores the lower vid must outrank
        // the higher one, so a bounded heap evicts the higher vid and the
        // kept set agrees with an ascending-vid tie-break.
        by_score.then(other.vid.cmp(&self.vid))
    }
}
666
// Unit tests for the pure helpers: segment merging, top-k selection, and
// weight quantization. Nothing here touches a Lance dataset.
#[cfg(test)]
mod tests {
    use super::*;

    // A term present in both segments keeps the postings of both; terms
    // unique to one segment are carried over unchanged.
    #[test]
    fn test_merge_postings_segments_overlapping() {
        let seg1: Postings = [(1u32, vec![(10u64, 1.0f32)]), (2, vec![(11, 2.0)])]
            .into_iter()
            .collect();
        let seg2: Postings = [(1u32, vec![(12u64, 3.0f32)]), (3, vec![(13, 4.0)])]
            .into_iter()
            .collect();
        let merged = merge_postings_segments(vec![seg1, seg2]);
        assert_eq!(merged.get(&1).unwrap().len(), 2);
        assert_eq!(merged.get(&2).unwrap(), &vec![(11, 2.0)]);
        assert_eq!(merged.get(&3).unwrap(), &vec![(13, 4.0)]);
    }

    // With k = 2 only the two highest scores survive, best first.
    #[test]
    fn test_top_k_from_scores_orders_desc_and_caps() {
        let scores: HashMap<u64, f32> = [(1u64, 0.5f32), (2, 3.0), (3, 1.0), (4, 2.0)]
            .into_iter()
            .collect();
        let top = SparseVectorIndex::top_k_from_scores(scores, 2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0.as_u64(), 2);
        assert_eq!(top[0].1, 3.0);
        assert_eq!(top[1].0.as_u64(), 4);
        assert_eq!(top[1].1, 2.0);
    }

    // Equal scores are returned in ascending vid order.
    #[test]
    fn test_top_k_tie_break_by_vid() {
        let scores: HashMap<u64, f32> = [(7u64, 1.0f32), (3, 1.0)].into_iter().collect();
        let top = SparseVectorIndex::top_k_from_scores(scores, 2);
        assert_eq!(top[0].0.as_u64(), 3);
        assert_eq!(top[1].0.as_u64(), 7);
    }

    // No candidates means no results, whatever k is.
    #[test]
    fn test_top_k_empty() {
        assert!(SparseVectorIndex::top_k_from_scores(HashMap::new(), 5).is_empty());
    }

    // An all-zero term must not divide by zero: scale and max impact stay 0.
    #[test]
    fn test_quantize_all_zero_term_no_nan() {
        let (codes, scale, max_impact) = quantize_term(&[0.0, 0.0, 0.0]);
        assert_eq!(codes, vec![0, 0, 0]);
        assert_eq!(scale, 0.0);
        assert_eq!(max_impact, 0.0);
        assert!(!scale.is_nan() && !max_impact.is_nan());
    }

    // Negative weights are quantized to code 0.
    #[test]
    fn test_quantize_negative_weights_clamp_to_zero() {
        let (codes, _scale, max_impact) = quantize_term(&[-1.0, -0.5]);
        assert_eq!(codes, vec![0, 0]);
        assert_eq!(max_impact, 0.0);
    }

    // The largest weight gets code 255; every dequantized value stays within
    // max_impact and within half a quantization step of its original.
    #[test]
    fn test_quantize_max_weight_maps_to_top_code() {
        let (codes, scale, max_impact) = quantize_term(&[0.1, 2.0, 1.0]);
        assert_eq!(codes[1], 255);
        for (j, &w) in [0.1f32, 2.0, 1.0].iter().enumerate() {
            assert!(dequantize(codes[j], scale) <= max_impact + f32::EPSILON);
            assert!((dequantize(codes[j], scale) - w).abs() <= scale / 2.0 + 1e-6);
        }
    }

    // Property test over random non-negative weights: one code per weight,
    // dequantized values bounded by max_impact, finite, and within half a
    // step (plus tolerance) of the original weight.
    proptest::proptest! {
        #[test]
        fn prop_quantize_roundtrip_and_bound(
            weights in proptest::collection::vec(0.0f32..1000.0, 1..64)
        ) {
            let (codes, scale, max_impact) = quantize_term(&weights);
            proptest::prop_assert_eq!(codes.len(), weights.len());
            for (j, &w) in weights.iter().enumerate() {
                let dq = dequantize(codes[j], scale);
                proptest::prop_assert!(dq <= max_impact + 1e-4);
                proptest::prop_assert!((dq - w).abs() <= scale / 2.0 + 1e-3);
                proptest::prop_assert!(dq.is_finite());
            }
        }
    }
}