Skip to main content

uni_store/storage/
sparse_index.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team
// Rust guideline compliant
4
//! Scored sparse-vector (SPLADE / learned-sparse) inverted index.
//!
//! Forked from [`super::inverted_index`], with three differences that make it
//! a *scored* index rather than a set-membership one:
//!
//! 1. Postings are `term_id (u32) -> [(vid, weight)]` instead of
//!    `term (String) -> [vid]`. Scoring is a dot product over shared terms.
//! 2. The on-disk schema carries per-term `weights` and a `max_impact` upper
//!    bound (the prerequisite for P2 block-max pruning).
//! 3. [`SparseVectorIndex::query_topk`] returns scored, ranked results via a
//!    dot-product accumulator + bounded min-heap, not an unordered VID set.
//!
//! Weights are stored as 8-bit per-term-quantized codes by default (config
//! `quantize`; near-lossless and ~4× smaller); `quantize = false` stores
//! lossless `f32` instead. Both encodings are read transparently by the same
//! reader (the `weights` list element type is the discriminator), which also
//! makes legacy `f32`-only segments forward-compatible without a rebuild.
//! MVCC / tombstone correctness is applied by the *query orchestration* layer
//! (uni-query), exactly as the dense `vector_search` path does — this module
//! is the storage kernel.
25
26use anyhow::{Result, anyhow};
27use arrow_array::types::{Float32Type, UInt8Type, UInt64Type};
28use arrow_array::{
29    Array, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StructArray, UInt8Array,
30    UInt32Array, UInt64Array,
31};
32use arrow_schema::{DataType, Field, Schema as ArrowSchema};
33use futures::TryStreamExt;
34use lance::Dataset;
35use std::cmp::Reverse;
36use std::collections::{BinaryHeap, HashMap, HashSet};
37use std::sync::Arc;
38use tracing::{debug, info, instrument};
39use uni_common::core::id::Vid;
40use uni_common::core::schema::SparseVectorIndexConfig;
41
/// Estimated in-memory postings size (256 MiB) above which the current
/// accumulation is set aside as a separate segment during a build. Same limit
/// as the set-membership inverted index.
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 << 20;

/// Posting lists keyed by term id: `term_id -> [(vid, weight)]`.
type Postings = HashMap<u32, Vec<(u64, f32)>>;
48
49/// Estimate memory usage of a postings map (term key + parallel vid/weight pairs).
50fn estimated_postings_memory(postings: &Postings) -> usize {
51    postings
52        .values()
53        .map(|v| std::mem::size_of::<u32>() + std::mem::size_of::<Vec<(u64, f32)>>() + v.len() * 12)
54        .sum()
55}
56
57/// Merge multiple postings segments into one (concatenates per-term lists).
58fn merge_postings_segments(segments: Vec<Postings>) -> Postings {
59    let mut merged: Postings = HashMap::new();
60    for segment in segments {
61        for (term, entries) in segment {
62            merged.entry(term).or_default().extend(entries);
63        }
64    }
65    merged
66}
67
68/// Read a sparse vector `(indices, values)` from row `row` of a `Struct
69/// { indices: List<UInt32>, values: List<Float32> }` column. Returns `None` if
70/// the struct row is null (deleted / absent).
71fn read_sparse_row(struct_arr: &StructArray, row: usize) -> Option<(Vec<u32>, Vec<f32>)> {
72    if struct_arr.is_null(row) {
73        return None;
74    }
75    let indices_list = struct_arr
76        .column_by_name("indices")?
77        .as_any()
78        .downcast_ref::<ListArray>()?;
79    let values_list = struct_arr
80        .column_by_name("values")?
81        .as_any()
82        .downcast_ref::<ListArray>()?;
83    let idx_vals = indices_list.value(row);
84    let idx_arr = idx_vals.as_any().downcast_ref::<UInt32Array>()?;
85    let w_vals = values_list.value(row);
86    let w_arr = w_vals.as_any().downcast_ref::<Float32Array>()?;
87    let indices = (0..idx_arr.len()).map(|i| idx_arr.value(i)).collect();
88    let values = (0..w_arr.len()).map(|i| w_arr.value(i)).collect();
89    Some((indices, values))
90}
91
/// Number of 8-bit quantization levels above zero.
///
/// Learned-sparse / SPLADE weights are non-negative (ReLU), so the full
/// unsigned `0..=255` range is used: the per-term scale is `max_weight / 255`,
/// giving twice the resolution of a signed `i8` scheme for the same width.
const QUANT_LEVELS: f32 = 255.0;

/// Quantize one term's weights to 8-bit codes with a shared per-term scale.
///
/// Returns `(codes, scale, max_impact)`:
/// - `code as f32 * scale` reconstructs each weight.
/// - `max_impact` is computed from the *dequantized* weights. It therefore
///   stays a valid upper bound on the values scoring actually multiplies,
///   which is the invariant any future block-max pruning depends on.
///
/// Finite weights are clamped to `[0, max_weight]`. Learned-sparse weights
/// are non-negative, so a stray negative has no 8-bit code and maps to zero.
///
/// Non-finite weights (`NaN`, `±inf`) also map to code 0 and are excluded when
/// computing `max_weight`. This mirrors `accumulate_batch`, which skips
/// non-finite postings. Without it, a single `+inf` would give `scale = inf`
/// and `max_impact = 0 * inf = NaN`, poisoning every score for this term.
fn quantize_term(weights: &[f32]) -> (Vec<u8>, f32, f32) {
    let max_weight = weights
        .iter()
        .copied()
        .filter(|w| w.is_finite())
        .fold(0.0f32, f32::max);
    // All-zero / all-negative / all-non-finite term: scale 0, every code 0.
    // This also guards the `w / scale` division against `0 / 0 = NaN`.
    if max_weight <= 0.0 {
        return (vec![0u8; weights.len()], 0.0, 0.0);
    }
    let scale = max_weight / QUANT_LEVELS;
    let codes: Vec<u8> = weights
        .iter()
        .map(|&w| {
            if w.is_finite() {
                // Round to nearest, never truncate. Truncation biases every
                // weight down and would let a dequantized value exceed
                // `max_impact`. The `as u8` cast saturates, so fp drift at the
                // top of the range is safe.
                (w.clamp(0.0, max_weight) / scale).round() as u8
            } else {
                0
            }
        })
        .collect();
    let max_code = codes.iter().copied().max().unwrap_or(0);
    (codes, scale, dequantize(max_code, scale))
}

/// Reconstruct an approximate weight from an 8-bit code and its term scale.
fn dequantize(code: u8, scale: f32) -> f32 {
    f32::from(code) * scale
}
133
134/// A borrowed view over one term's posting weights that yields `f32` regardless
135/// of on-disk encoding: quantized (`UInt8` codes + a per-term scale) or lossless
136/// (`Float32` — legacy segments and `quantize = false`).
137enum TermWeights<'a> {
138    Quantized { codes: &'a UInt8Array, scale: f32 },
139    Lossless(&'a Float32Array),
140}
141
142impl TermWeights<'_> {
143    /// Weight at posting position `j` (`0.0` for a null element).
144    fn get(&self, j: usize) -> f32 {
145        match self {
146            Self::Quantized { codes, scale } => {
147                if codes.is_null(j) {
148                    0.0
149                } else {
150                    dequantize(codes.value(j), *scale)
151                }
152            }
153            Self::Lossless(arr) => {
154                if arr.is_null(j) {
155                    0.0
156                } else {
157                    arr.value(j)
158                }
159            }
160        }
161    }
162}
163
164/// Build a [`TermWeights`] view over one posting row's `weights` element array.
165///
166/// `row_scale` is the row's `weight_scale` value, required when the elements are
167/// quantized `UInt8` codes and absent for lossless `Float32` segments.
168///
169/// # Errors
170/// Returns an error if the element type is neither `UInt8` nor `Float32`, or if
171/// quantized codes arrive without a `weight_scale`.
172fn term_weights(weights_arr: &dyn Array, row_scale: Option<f32>) -> Result<TermWeights<'_>> {
173    if let Some(codes) = weights_arr.as_any().downcast_ref::<UInt8Array>() {
174        let scale = row_scale
175            .ok_or_else(|| anyhow!("Quantized sparse weights missing weight_scale column"))?;
176        Ok(TermWeights::Quantized { codes, scale })
177    } else if let Some(arr) = weights_arr.as_any().downcast_ref::<Float32Array>() {
178        Ok(TermWeights::Lossless(arr))
179    } else {
180        Err(anyhow!(
181            "Invalid inner weights type: {:?}",
182            weights_arr.data_type()
183        ))
184    }
185}
186
187/// Read the optional per-term `weight_scale` column from a postings batch.
188///
189/// Present only for quantized segments; absent for lossless / legacy ones.
190fn weight_scale_column(batch: &RecordBatch) -> Option<&Float32Array> {
191    batch
192        .column_by_name("weight_scale")
193        .and_then(|c| c.as_any().downcast_ref::<Float32Array>())
194}
195
196/// Scored sparse-vector inverted index over a `DataType::SparseVector` column.
197pub struct SparseVectorIndex {
198    dataset: Option<Dataset>,
199    base_uri: String,
200    label: String,
201    property: String,
202    config: SparseVectorIndexConfig,
203}
204
205impl std::fmt::Debug for SparseVectorIndex {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        f.debug_struct("SparseVectorIndex")
208            .field("base_uri", &self.base_uri)
209            .field("label", &self.label)
210            .field("property", &self.property)
211            .field("initialized", &self.dataset.is_some())
212            .finish_non_exhaustive()
213    }
214}
215
216impl SparseVectorIndex {
217    /// The on-disk dataset path for this index's postings.
218    fn postings_path(base_uri: &str, label: &str, property: &str) -> String {
219        format!("{base_uri}/indexes/{label}/{property}_sparse")
220    }
221
222    /// Open or initialize a sparse index at `base_uri` for the given config.
223    pub async fn new(base_uri: &str, config: SparseVectorIndexConfig) -> Result<Self> {
224        let path = Self::postings_path(base_uri, &config.label, &config.property);
225        let dataset = (Dataset::open(&path).await).ok();
226        Ok(Self {
227            dataset,
228            base_uri: base_uri.to_string(),
229            label: config.label.clone(),
230            property: config.property.clone(),
231            config,
232        })
233    }
234
235    /// Accumulate one record batch's sparse rows into `postings`. Returns the
236    /// count of documents (rows carrying a non-null sparse value) processed.
237    /// Deleted rows store a null struct (see `build_sparse_vector_column`), so
238    /// `read_sparse_row` skips them and they never enter the postings.
239    fn accumulate_batch(&self, batch: &RecordBatch, postings: &mut Postings) -> Result<usize> {
240        let vid_col = batch
241            .column_by_name("_vid")
242            .ok_or_else(|| anyhow!("Missing _vid"))?
243            .as_any()
244            .downcast_ref::<UInt64Array>()
245            .ok_or_else(|| anyhow!("Invalid _vid type"))?;
246        let term_col = batch
247            .column_by_name(&self.property)
248            .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
249        let struct_arr = term_col
250            .as_any()
251            .downcast_ref::<StructArray>()
252            .ok_or_else(|| {
253                anyhow!(
254                    "Property {} must be a sparse-vector struct, got {:?}",
255                    self.property,
256                    term_col.data_type()
257                )
258            })?;
259        let mut count = 0;
260        for i in 0..batch.num_rows() {
261            let vid = vid_col.value(i);
262            let Some((indices, values)) = read_sparse_row(struct_arr, i) else {
263                continue;
264            };
265            for (term, weight) in indices.into_iter().zip(values) {
266                // Defense in depth: ingest validation now rejects non-finite weights, but
267                // a corrupt / manually-spliced on-disk segment must not poison scoring —
268                // a single NaN weight would make a vid's accumulated dot product NaN and
269                // corrupt top-k ordering (issue #95). Skip non-finite postings on read.
270                if !weight.is_finite() {
271                    continue;
272                }
273                postings.entry(term).or_default().push((vid, weight));
274            }
275            count += 1;
276        }
277        Ok(count)
278    }
279
280    /// Merge accumulated postings (+ any spilled segments) and write to disk.
281    async fn finish_build(
282        &mut self,
283        postings: Postings,
284        mut temp_segments: Vec<Postings>,
285    ) -> Result<()> {
286        if temp_segments.is_empty() {
287            self.write_postings(postings).await
288        } else {
289            temp_segments.push(postings);
290            info!(
291                segments = temp_segments.len(),
292                "Merging sparse postings segments"
293            );
294            let merged = merge_postings_segments(temp_segments);
295            self.write_postings(merged).await
296        }
297    }
298
299    /// Rebuild the index from already-scanned record batches (the storage
300    /// backend's view of the flushed vertex table). This is the canonical
301    /// backfill path: the LanceDB-managed table is opened by the backend, not
302    /// by a raw `lance::Dataset::open` (whose physical path differs). Uses
303    /// segmented accumulation to stay within the memory limit.
304    pub async fn build_from_batches(
305        &mut self,
306        batches: &[RecordBatch],
307        progress: impl Fn(usize),
308    ) -> Result<()> {
309        let mut postings: Postings = HashMap::new();
310        let mut temp_segments: Vec<Postings> = Vec::new();
311        let mut count = 0;
312        for batch in batches {
313            count += self.accumulate_batch(batch, &mut postings)?;
314            progress(count);
315            if estimated_postings_memory(&postings) > DEFAULT_MAX_POSTINGS_MEMORY {
316                temp_segments.push(std::mem::take(&mut postings));
317            }
318        }
319        self.finish_build(postings, temp_segments).await
320    }
321
322    /// Overwrite the on-disk postings with the provided map.
323    ///
324    /// Schema (quantized, default): `(term_id: UInt32, vids: List<UInt64>,
325    /// weights: List<UInt8>, max_impact: Float32, weight_scale: Float32)`.
326    /// Lossless (`quantize = false`): `weights` is `List<Float32>` and the
327    /// `weight_scale` column is omitted. `max_impact` is the per-term maximum
328    /// (dequantized) weight — the upper bound P2 block-max pruning consumes.
329    async fn write_postings(&mut self, postings: Postings) -> Result<()> {
330        let quantize = self.config.quantize;
331        let n = postings.len();
332        let mut term_ids = Vec::with_capacity(n);
333        let mut vid_lists: Vec<Option<Vec<Option<u64>>>> = Vec::with_capacity(n);
334        let mut max_impacts = Vec::with_capacity(n);
335        // Exactly one weights representation is populated, per `quantize`.
336        let mut q_weight_lists: Vec<Option<Vec<Option<u8>>>> = Vec::new();
337        let mut q_scales: Vec<f32> = Vec::new();
338        let mut f_weight_lists: Vec<Option<Vec<Option<f32>>>> = Vec::new();
339
340        for (term, entries) in postings {
341            let mut vids = Vec::with_capacity(entries.len());
342            let mut weights = Vec::with_capacity(entries.len());
343            for (vid, weight) in entries {
344                vids.push(Some(vid));
345                weights.push(weight);
346            }
347            term_ids.push(term);
348            vid_lists.push(Some(vids));
349
350            if quantize {
351                let (codes, scale, max_impact) = quantize_term(&weights);
352                q_weight_lists.push(Some(codes.into_iter().map(Some).collect()));
353                q_scales.push(scale);
354                max_impacts.push(max_impact);
355            } else {
356                // True maximum weight (the P2 block-max upper bound). Start at
357                // NEG_INFINITY so an all-negative-weight term records its real
358                // max rather than a spurious 0.0; empty terms fall back to 0.0.
359                let mut max_impact = f32::NEG_INFINITY;
360                for &w in &weights {
361                    if w > max_impact {
362                        max_impact = w;
363                    }
364                }
365                if !max_impact.is_finite() {
366                    max_impact = 0.0;
367                }
368                max_impacts.push(max_impact);
369                f_weight_lists.push(Some(weights.into_iter().map(Some).collect()));
370            }
371        }
372
373        let term_array = UInt32Array::from(term_ids);
374        let vid_list_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists);
375        let max_impact_array = Float32Array::from(max_impacts);
376
377        let mut columns: Vec<(&str, Arc<dyn Array>)> = vec![
378            ("term_id", Arc::new(term_array) as Arc<dyn Array>),
379            ("vids", Arc::new(vid_list_array) as Arc<dyn Array>),
380        ];
381        if quantize {
382            let weight_list_array =
383                ListArray::from_iter_primitive::<UInt8Type, _, _>(q_weight_lists);
384            columns.push(("weights", Arc::new(weight_list_array) as Arc<dyn Array>));
385            columns.push(("max_impact", Arc::new(max_impact_array) as Arc<dyn Array>));
386            columns.push((
387                "weight_scale",
388                Arc::new(Float32Array::from(q_scales)) as Arc<dyn Array>,
389            ));
390        } else {
391            let weight_list_array =
392                ListArray::from_iter_primitive::<Float32Type, _, _>(f_weight_lists);
393            columns.push(("weights", Arc::new(weight_list_array) as Arc<dyn Array>));
394            columns.push(("max_impact", Arc::new(max_impact_array) as Arc<dyn Array>));
395        }
396
397        let batch = arrow_array::RecordBatch::try_from_iter(columns)?;
398
399        let path = Self::postings_path(&self.base_uri, &self.label, &self.property);
400        let write_params = lance::dataset::WriteParams {
401            mode: lance::dataset::WriteMode::Overwrite,
402            ..Default::default()
403        };
404        let iterator = RecordBatchIterator::new(vec![Ok(batch)], Self::postings_schema(quantize));
405        let ds = Dataset::write(iterator, &path, Some(write_params)).await?;
406        self.dataset = Some(ds);
407        Ok(())
408    }
409
410    /// Arrow schema of the postings dataset for the given `quantize` mode.
411    fn postings_schema(quantize: bool) -> Arc<ArrowSchema> {
412        let weights_item = if quantize {
413            DataType::UInt8
414        } else {
415            DataType::Float32
416        };
417        let mut fields = vec![
418            Field::new("term_id", DataType::UInt32, false),
419            Field::new(
420                "vids",
421                DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
422                false,
423            ),
424            Field::new(
425                "weights",
426                DataType::List(Arc::new(Field::new("item", weights_item, true))),
427                false,
428            ),
429            Field::new("max_impact", DataType::Float32, false),
430        ];
431        if quantize {
432            fields.push(Field::new("weight_scale", DataType::Float32, false));
433        }
434        Arc::new(ArrowSchema::new(fields))
435    }
436
437    /// Score the corpus against `query` (`[(term_id, weight)]`) by dot product
438    /// and return the top `k` `(Vid, score)` pairs, highest score first.
439    ///
440    /// P1 brute-force document-at-a-time: filter postings to the query terms,
441    /// accumulate per-vid dot products, then drain a bounded min-heap. No
442    /// MVCC/tombstone filtering happens here — the orchestration layer applies
443    /// it, mirroring the dense `vector_search` path.
444    pub async fn query_topk(&self, query: &[(u32, f32)], k: usize) -> Result<Vec<(Vid, f32)>> {
445        let Some(ds) = &self.dataset else {
446            debug!("Sparse index not initialized, returning empty result");
447            return Ok(Vec::new());
448        };
449        if query.is_empty() || k == 0 {
450            return Ok(Vec::new());
451        }
452
453        let query_weights: HashMap<u32, f32> = query.iter().copied().collect();
454        let term_filter = query_weights
455            .keys()
456            .map(|t| t.to_string())
457            .collect::<Vec<_>>()
458            .join(", ");
459        let filter = format!("term_id IN ({term_filter})");
460
461        let mut scanner = ds.scan();
462        scanner.filter(&filter)?;
463        let mut stream = scanner.try_into_stream().await?;
464
465        let mut scores: HashMap<u64, f32> = HashMap::new();
466        while let Some(batch) = stream.try_next().await? {
467            let term_col = batch
468                .column_by_name("term_id")
469                .ok_or_else(|| anyhow!("Missing term_id column"))?
470                .as_any()
471                .downcast_ref::<UInt32Array>()
472                .ok_or_else(|| anyhow!("Invalid term_id column"))?;
473            let vids_col = batch
474                .column_by_name("vids")
475                .ok_or_else(|| anyhow!("Missing vids column"))?
476                .as_any()
477                .downcast_ref::<ListArray>()
478                .ok_or_else(|| anyhow!("Invalid vids column"))?;
479            let weights_col = batch
480                .column_by_name("weights")
481                .ok_or_else(|| anyhow!("Missing weights column"))?
482                .as_any()
483                .downcast_ref::<ListArray>()
484                .ok_or_else(|| anyhow!("Invalid weights column"))?;
485            let weight_scale_col = weight_scale_column(&batch);
486
487            for i in 0..batch.num_rows() {
488                let term = term_col.value(i);
489                let Some(&qw) = query_weights.get(&term) else {
490                    continue;
491                };
492                if vids_col.is_null(i) || weights_col.is_null(i) {
493                    continue;
494                }
495                let vids_arr = vids_col.value(i);
496                let vids = vids_arr
497                    .as_any()
498                    .downcast_ref::<UInt64Array>()
499                    .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
500                let weights_arr = weights_col.value(i);
501                let weights =
502                    term_weights(weights_arr.as_ref(), weight_scale_col.map(|c| c.value(i)))?;
503
504                for j in 0..vids.len() {
505                    if vids.is_null(j) {
506                        continue;
507                    }
508                    *scores.entry(vids.value(j)).or_insert(0.0) += qw * weights.get(j);
509                }
510            }
511        }
512
513        Ok(Self::top_k_from_scores(scores, k))
514    }
515
516    /// Drain a score map into the top-`k` `(Vid, score)` pairs (descending),
517    /// using a bounded min-heap so memory stays O(k).
518    fn top_k_from_scores(scores: HashMap<u64, f32>, k: usize) -> Vec<(Vid, f32)> {
519        // Min-heap keyed on score (via OrderedF32 + Reverse) capped at k.
520        let mut heap: BinaryHeap<Reverse<HeapEntry>> = BinaryHeap::with_capacity(k + 1);
521        for (vid, score) in scores {
522            heap.push(Reverse(HeapEntry { score, vid }));
523            if heap.len() > k {
524                heap.pop();
525            }
526        }
527        let mut out: Vec<(Vid, f32)> = heap
528            .into_iter()
529            .map(|Reverse(e)| (Vid::from(e.vid), e.score))
530            .collect();
531        out.sort_by(|a, b| {
532            b.1.partial_cmp(&a.1)
533                .unwrap_or(std::cmp::Ordering::Equal)
534                .then(a.0.as_u64().cmp(&b.0.as_u64()))
535        });
536        out
537    }
538
539    /// Load all postings from disk into a memory map. Empty if no dataset yet.
540    #[instrument(skip(self), level = "debug")]
541    async fn load_postings(&self) -> Result<Postings> {
542        let Some(ds) = &self.dataset else {
543            return Ok(HashMap::new());
544        };
545        let mut postings: Postings = HashMap::new();
546        let scanner = ds.scan();
547        let mut stream = scanner.try_into_stream().await?;
548        while let Some(batch) = stream.try_next().await? {
549            let term_col = batch
550                .column_by_name("term_id")
551                .ok_or_else(|| anyhow!("Missing term_id column"))?
552                .as_any()
553                .downcast_ref::<UInt32Array>()
554                .ok_or_else(|| anyhow!("Invalid term_id column"))?;
555            let vids_col = batch
556                .column_by_name("vids")
557                .ok_or_else(|| anyhow!("Missing vids column"))?
558                .as_any()
559                .downcast_ref::<ListArray>()
560                .ok_or_else(|| anyhow!("Invalid vids column"))?;
561            let weights_col = batch
562                .column_by_name("weights")
563                .ok_or_else(|| anyhow!("Missing weights column"))?
564                .as_any()
565                .downcast_ref::<ListArray>()
566                .ok_or_else(|| anyhow!("Invalid weights column"))?;
567            let weight_scale_col = weight_scale_column(&batch);
568
569            for i in 0..batch.num_rows() {
570                if vids_col.is_null(i) || weights_col.is_null(i) {
571                    continue;
572                }
573                let term = term_col.value(i);
574                let vids_arr = vids_col.value(i);
575                let vids = vids_arr
576                    .as_any()
577                    .downcast_ref::<UInt64Array>()
578                    .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
579                let weights_arr = weights_col.value(i);
580                // Dequantizes into f32 so the load-modify-write update path stays
581                // in f32 space and re-quantizes on the next `write_postings`.
582                let weights =
583                    term_weights(weights_arr.as_ref(), weight_scale_col.map(|c| c.value(i)))?;
584                let entry = postings.entry(term).or_default();
585                for j in 0..vids.len() {
586                    if !vids.is_null(j) {
587                        entry.push((vids.value(j), weights.get(j)));
588                    }
589                }
590            }
591        }
592        Ok(postings)
593    }
594
595    /// Apply incremental updates: drop removed VIDs from every posting, then add
596    /// the new vertices' `(term, weight)` pairs, and rewrite. Mirrors the
597    /// set-membership inverted index's load-modify-write semantics.
598    #[instrument(skip(self, added, removed), level = "info", fields(
599        label = %self.label,
600        property = %self.property,
601        added_count = added.len(),
602        removed_count = removed.len()
603    ))]
604    pub async fn apply_incremental_updates(
605        &mut self,
606        added: &HashMap<Vid, Vec<(u32, f32)>>,
607        removed: &HashSet<Vid>,
608    ) -> Result<()> {
609        let mut postings = self.load_postings().await?;
610
611        if !removed.is_empty() {
612            let removed_u64: HashSet<u64> = removed.iter().map(|v| v.as_u64()).collect();
613            for entries in postings.values_mut() {
614                entries.retain(|(vid, _)| !removed_u64.contains(vid));
615            }
616            postings.retain(|_, entries| !entries.is_empty());
617        }
618
619        for (vid, terms) in added {
620            let vid_u64 = vid.as_u64();
621            for &(term, weight) in terms {
622                postings.entry(term).or_default().push((vid_u64, weight));
623            }
624        }
625
626        self.write_postings(postings).await?;
627        Ok(())
628    }
629
630    /// Returns true if the index dataset exists.
631    pub fn is_initialized(&self) -> bool {
632        self.dataset.is_some()
633    }
634
635    /// Returns the property name this index is built on.
636    pub fn property(&self) -> &str {
637        &self.property
638    }
639}
640
/// Heap entry ordered by score (NaN treated as smallest), tie-broken by vid.
///
/// `NaN` sorts below every number and equal to another `NaN`, which keeps
/// [`Ord`] a genuine total order.
///
/// A bare `partial_cmp(..).unwrap_or(Equal)` treats `NaN` as equal to every
/// score and falls through to the vid tie-break. That breaks transitivity,
/// which lets the bounded `BinaryHeap` in `top_k_from_scores` evict a real
/// score while keeping a `NaN`.
struct HeapEntry {
    score: f32,
    vid: u64,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for HeapEntry {}
impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering::{Equal, Greater, Less};
        let by_score = match (self.score.is_nan(), other.score.is_nan()) {
            (true, true) => Equal,
            (true, false) => Less,
            (false, true) => Greater,
            // Both numbers: `partial_cmp` is always `Some` here.
            (false, false) => self.score.partial_cmp(&other.score).unwrap_or(Equal),
        };
        by_score.then(self.vid.cmp(&other.vid))
    }
}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect literal `(term, postings)` pairs into a [`Postings`] map.
    fn postings_of<const N: usize>(pairs: [(u32, Vec<(u64, f32)>); N]) -> Postings {
        pairs.into_iter().collect()
    }

    /// Collect literal `(vid, score)` pairs into a score map.
    fn scores_of(pairs: &[(u64, f32)]) -> HashMap<u64, f32> {
        pairs.iter().copied().collect()
    }

    /// Flatten ranked results to plain `(vid, score)` pairs for comparison.
    fn ranked(top: &[(Vid, f32)]) -> Vec<(u64, f32)> {
        top.iter().map(|(vid, score)| (vid.as_u64(), *score)).collect()
    }

    #[test]
    fn test_merge_postings_segments_overlapping() {
        let first = postings_of([(1, vec![(10, 1.0)]), (2, vec![(11, 2.0)])]);
        let second = postings_of([(1, vec![(12, 3.0)]), (3, vec![(13, 4.0)])]);
        let merged = merge_postings_segments(vec![first, second]);
        assert_eq!(merged[&1].len(), 2);
        assert_eq!(merged[&2], vec![(11, 2.0)]);
        assert_eq!(merged[&3], vec![(13, 4.0)]);
    }

    #[test]
    fn test_top_k_from_scores_orders_desc_and_caps() {
        let top = SparseVectorIndex::top_k_from_scores(
            scores_of(&[(1, 0.5), (2, 3.0), (3, 1.0), (4, 2.0)]),
            2,
        );
        assert_eq!(ranked(&top), vec![(2, 3.0), (4, 2.0)]);
    }

    #[test]
    fn test_top_k_tie_break_by_vid() {
        let top = SparseVectorIndex::top_k_from_scores(scores_of(&[(7, 1.0), (3, 1.0)]), 2);
        // Equal scores → lower vid first (deterministic).
        let vids: Vec<u64> = ranked(&top).into_iter().map(|(vid, _)| vid).collect();
        assert_eq!(vids, vec![3, 7]);
    }

    #[test]
    fn test_top_k_empty() {
        let top = SparseVectorIndex::top_k_from_scores(HashMap::new(), 5);
        assert!(top.is_empty());
    }

    #[test]
    fn test_quantize_all_zero_term_no_nan() {
        let (codes, scale, max_impact) = quantize_term(&[0.0; 3]);
        assert_eq!(codes, [0u8; 3]);
        assert_eq!(scale, 0.0);
        assert_eq!(max_impact, 0.0);
        assert!(!scale.is_nan());
        assert!(!max_impact.is_nan());
    }

    #[test]
    fn test_quantize_negative_weights_clamp_to_zero() {
        // Learned-sparse weights are non-negative; a stray negative has no code.
        let (codes, _, max_impact) = quantize_term(&[-1.0, -0.5]);
        assert_eq!(codes, [0u8; 2]);
        assert_eq!(max_impact, 0.0);
    }

    #[test]
    fn test_quantize_max_weight_maps_to_top_code() {
        let weights = [0.1f32, 2.0, 1.0];
        let (codes, scale, max_impact) = quantize_term(&weights);
        // The largest weight lands on the top code (255).
        assert_eq!(codes[1], 255);
        for (&code, &original) in codes.iter().zip(weights.iter()) {
            let restored = dequantize(code, scale);
            // Rank safety for future block-max pruning: max_impact bounds every
            // dequantized weight.
            assert!(restored <= max_impact + f32::EPSILON);
            // Round-trip error stays within half a quantization step.
            assert!((restored - original).abs() <= scale / 2.0 + 1e-6);
        }
    }

    proptest::proptest! {
        #[test]
        fn prop_quantize_roundtrip_and_bound(
            weights in proptest::collection::vec(0.0f32..1000.0, 1..64)
        ) {
            let (codes, scale, max_impact) = quantize_term(&weights);
            proptest::prop_assert_eq!(codes.len(), weights.len());
            for (&code, &original) in codes.iter().zip(&weights) {
                let restored = dequantize(code, scale);
                proptest::prop_assert!(restored.is_finite());
                // max_impact upper-bounds every dequantized weight.
                proptest::prop_assert!(restored <= max_impact + 1e-4);
                // Reconstruction is within half a step of the original.
                proptest::prop_assert!((restored - original).abs() <= scale / 2.0 + 1e-3);
            }
        }
    }
}