Skip to main content

uni_store/storage/
inverted_index.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3// Rust guideline compliant
4
5//! Inverted index implementation for set membership queries.
6//!
7//! Provides efficient `ANY(x IN list WHERE x IN allowed)` queries by
8//! maintaining a term-to-VID mapping. Supports both full rebuilds and
9//! incremental updates for optimal performance during mutations.
10
11use crate::storage::vertex::VertexDataset;
12use anyhow::{Result, anyhow};
13use arrow_array::types::UInt64Type;
14use arrow_array::{Array, ListArray, RecordBatchIterator, StringArray, UInt64Array};
15use arrow_schema::{DataType, Field, Schema as ArrowSchema};
16use futures::TryStreamExt;
17use lance::Dataset;
18use serde_json::Value;
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21use tracing::{debug, info, instrument};
22use uni_common::core::id::Vid;
23use uni_common::core::schema::InvertedIndexConfig;
24
/// Estimated postings size (256 MiB) above which `build_from_dataset` moves the
/// current postings map into a new in-memory segment.
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 << 20;
27
/// Rough byte estimate of a postings map's footprint.
///
/// Each entry contributes its term's UTF-8 length, one `Vec<u64>` header, and
/// 8 bytes per stored VID. Hash-table overhead and unused `Vec` capacity are
/// not counted, so the real footprint is somewhat larger.
fn estimated_postings_memory(postings: &HashMap<String, Vec<u64>>) -> usize {
    const VEC_HEADER_BYTES: usize = std::mem::size_of::<Vec<u64>>();
    const VID_BYTES: usize = std::mem::size_of::<u64>();

    postings.iter().fold(0, |total, (term, vids)| {
        total + term.len() + VEC_HEADER_BYTES + vids.len() * VID_BYTES
    })
}
35
/// Combines several postings segments into one map.
///
/// Segments are consumed in order, so for every term the VIDs from earlier
/// segments come before those from later ones. Duplicate VIDs are kept.
fn merge_postings_segments(segments: Vec<HashMap<String, Vec<u64>>>) -> HashMap<String, Vec<u64>> {
    segments.into_iter().flatten().fold(
        HashMap::<String, Vec<u64>>::new(),
        |mut combined, (term, vids)| {
            combined.entry(term).or_default().extend(vids);
            combined
        },
    )
}
46
47/// Term-to-VID inverted index for efficient set-membership queries.
48///
49/// Supports both full rebuilds from a vertex dataset and incremental updates
50/// for optimal performance during small mutations.
51pub struct InvertedIndex {
52    dataset: Option<Dataset>,
53    base_uri: String,
54    label: String,
55    property: String,
56    config: InvertedIndexConfig,
57}
58
59impl std::fmt::Debug for InvertedIndex {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("InvertedIndex")
62            .field("base_uri", &self.base_uri)
63            .field("label", &self.label)
64            .field("property", &self.property)
65            .field("initialized", &self.dataset.is_some())
66            .finish_non_exhaustive()
67    }
68}
69
70impl InvertedIndex {
71    /// Open or initialize an inverted index at `base_uri` for the given config.
72    pub async fn new(base_uri: &str, config: InvertedIndexConfig) -> Result<Self> {
73        let path = format!(
74            "{}/indexes/{}/{}_inverted",
75            base_uri, config.label, config.property
76        );
77
78        let dataset = (Dataset::open(&path).await).ok();
79
80        Ok(Self {
81            dataset,
82            base_uri: base_uri.to_string(),
83            label: config.label.clone(),
84            property: config.property.clone(),
85            config,
86        })
87    }
88
89    /// Rebuild the index from scratch by scanning `vertex_dataset`.
90    ///
91    /// Calls `progress` every 10 000 vertices with the running document count.
92    /// Uses segmented accumulation to stay within the 256 MB memory limit.
93    pub async fn build_from_dataset(
94        &mut self,
95        vertex_dataset: &VertexDataset,
96        progress: impl Fn(usize),
97    ) -> Result<()> {
98        let mut postings: HashMap<String, Vec<u64>> = HashMap::new();
99        let mut temp_segments: Vec<HashMap<String, Vec<u64>>> = Vec::new();
100        let mut count = 0;
101        let max_memory = DEFAULT_MAX_POSTINGS_MEMORY;
102
103        debug!(property = %self.property, "Building inverted index from dataset");
104
105        if let Ok(ds) = vertex_dataset.open().await {
106            let scanner = ds.scan();
107            let mut stream = scanner.try_into_stream().await?;
108            while let Some(batch) = stream.try_next().await? {
109                debug!(rows = batch.num_rows(), "Processing batch");
110                let vid_col = batch
111                    .column_by_name("_vid")
112                    .ok_or_else(|| anyhow!("Missing _vid"))?
113                    .as_any()
114                    .downcast_ref::<UInt64Array>()
115                    .ok_or_else(|| anyhow!("Invalid _vid type"))?;
116
117                let term_col = batch
118                    .column_by_name(&self.property)
119                    .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
120
121                let list_array =
122                    term_col
123                        .as_any()
124                        .downcast_ref::<ListArray>()
125                        .ok_or_else(|| {
126                            anyhow!(
127                                "Property {} must be List<String>, got {:?}",
128                                self.property,
129                                term_col.data_type()
130                            )
131                        })?;
132
133                let values = list_array
134                    .values()
135                    .as_any()
136                    .downcast_ref::<StringArray>()
137                    .ok_or_else(|| anyhow!("Property {} must be List<String>", self.property))?;
138
139                for i in 0..batch.num_rows() {
140                    let vid = vid_col.value(i);
141
142                    if list_array.is_null(i) {
143                        continue;
144                    }
145
146                    let start = list_array.value_offsets()[i] as usize;
147                    let end = list_array.value_offsets()[i + 1] as usize;
148
149                    let mut terms = HashSet::new();
150                    for j in start..end {
151                        if !values.is_null(j) {
152                            let term = values.value(j);
153                            let term = if self.config.normalize {
154                                term.to_lowercase().trim().to_string()
155                            } else {
156                                term.to_string()
157                            };
158                            terms.insert(term);
159                        }
160                    }
161
162                    if terms.len() > self.config.max_terms_per_doc {
163                        // Truncate logic if needed
164                    }
165
166                    for term in terms {
167                        postings.entry(term).or_default().push(vid);
168                    }
169
170                    count += 1;
171                    if count % 10_000 == 0 {
172                        progress(count);
173                    }
174                }
175
176                // Check if we've exceeded memory limit - flush to temp segment
177                if max_memory > 0 && estimated_postings_memory(&postings) > max_memory {
178                    debug!(
179                        segment = temp_segments.len(),
180                        terms = postings.len(),
181                        "Flushing postings segment due to memory limit"
182                    );
183                    temp_segments.push(std::mem::take(&mut postings));
184                }
185            }
186        } else {
187            debug!("Vertex dataset not found, creating empty index");
188        }
189
190        debug!(
191            terms = postings.len(),
192            segments = temp_segments.len(),
193            "Built inverted index"
194        );
195
196        // Merge and write
197        if temp_segments.is_empty() {
198            // Small dataset: write directly
199            self.write_postings(postings).await?;
200        } else {
201            // Large dataset: merge all segments
202            temp_segments.push(postings); // add remaining
203            info!(segments = temp_segments.len(), "Merging postings segments");
204            let merged = merge_postings_segments(temp_segments);
205            debug!(final_terms = merged.len(), "Merged postings");
206            self.write_postings(merged).await?;
207        }
208
209        Ok(())
210    }
211
212    /// Overwrite the on-disk postings with the provided map.
213    async fn write_postings(&mut self, postings: HashMap<String, Vec<u64>>) -> Result<()> {
214        let mut terms = Vec::with_capacity(postings.len());
215        let mut vid_lists = Vec::with_capacity(postings.len());
216
217        for (term, vids) in postings {
218            terms.push(term);
219            vid_lists.push(Some(vids.into_iter().map(Some).collect::<Vec<_>>()));
220        }
221
222        let term_array = StringArray::from(terms);
223        let vid_list_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists);
224
225        let batch = arrow_array::RecordBatch::try_from_iter(vec![
226            ("term", Arc::new(term_array) as Arc<dyn arrow_array::Array>),
227            (
228                "vids",
229                Arc::new(vid_list_array) as Arc<dyn arrow_array::Array>,
230            ),
231        ])?;
232
233        let path = format!(
234            "{}/indexes/{}/{}_inverted",
235            self.base_uri, self.label, self.property
236        );
237        let write_params = lance::dataset::WriteParams {
238            mode: lance::dataset::WriteMode::Overwrite,
239            ..Default::default()
240        };
241
242        let iterator = RecordBatchIterator::new(
243            vec![Ok(batch)],
244            Arc::new(ArrowSchema::new(vec![
245                Field::new("term", DataType::Utf8, false),
246                Field::new(
247                    "vids",
248                    DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
249                    false,
250                ),
251            ])),
252        );
253
254        let ds = Dataset::write(iterator, &path, Some(write_params)).await?;
255        self.dataset = Some(ds);
256
257        Ok(())
258    }
259
260    /// Return all VIDs whose term list intersects `terms` (OR semantics).
261    pub async fn query_any(&self, terms: &[String]) -> Result<Vec<Vid>> {
262        let Some(ds) = &self.dataset else {
263            debug!("Inverted index not initialized, returning empty result");
264            return Ok(Vec::new());
265        };
266
267        let normalized: Vec<String> = if self.config.normalize {
268            terms
269                .iter()
270                .map(|t| t.to_lowercase().trim().to_string())
271                .collect()
272        } else {
273            terms.to_vec()
274        };
275
276        if normalized.is_empty() {
277            return Ok(Vec::new());
278        }
279
280        let filter = normalized
281            .iter()
282            .map(|t| format!("term = '{}'", t.replace("'", "''")))
283            .collect::<Vec<_>>()
284            .join(" OR ");
285
286        debug!(filter = %filter, "Querying inverted index");
287
288        let mut scanner = ds.scan();
289        scanner.filter(&filter)?;
290
291        let mut stream = scanner.try_into_stream().await?;
292        let mut result_set: HashSet<u64> = HashSet::new();
293
294        while let Some(batch) = stream.try_next().await? {
295            let vids_col = batch
296                .column_by_name("vids")
297                .ok_or_else(|| anyhow!("Missing vids column"))?
298                .as_any()
299                .downcast_ref::<ListArray>()
300                .ok_or_else(|| anyhow!("Invalid vids column"))?;
301
302            for i in 0..batch.num_rows() {
303                if vids_col.is_null(i) {
304                    continue;
305                }
306
307                let vids_array = vids_col.value(i);
308                let vids = vids_array
309                    .as_any()
310                    .downcast_ref::<UInt64Array>()
311                    .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
312
313                for vid in vids.iter().flatten() {
314                    result_set.insert(vid);
315                }
316            }
317        }
318
319        debug!(count = result_set.len(), "Found matching VIDs");
320
321        Ok(result_set.into_iter().map(Vid::from).collect())
322    }
323
324    /// Loads all postings from the existing dataset.
325    ///
326    /// Returns a map of term to list of VIDs. Returns an empty map if the
327    /// dataset doesn't exist yet.
328    #[instrument(skip(self), level = "debug")]
329    async fn load_postings(&self) -> Result<HashMap<String, HashSet<u64>>> {
330        let Some(ds) = &self.dataset else {
331            return Ok(HashMap::new());
332        };
333
334        let mut postings: HashMap<String, HashSet<u64>> = HashMap::new();
335        let scanner = ds.scan();
336        let mut stream = scanner.try_into_stream().await?;
337
338        while let Some(batch) = stream.try_next().await? {
339            let term_col = batch
340                .column_by_name("term")
341                .ok_or_else(|| anyhow!("Missing term column"))?
342                .as_any()
343                .downcast_ref::<StringArray>()
344                .ok_or_else(|| anyhow!("Invalid term column type"))?;
345
346            let vids_col = batch
347                .column_by_name("vids")
348                .ok_or_else(|| anyhow!("Missing vids column"))?
349                .as_any()
350                .downcast_ref::<ListArray>()
351                .ok_or_else(|| anyhow!("Invalid vids column type"))?;
352
353            for i in 0..batch.num_rows() {
354                if term_col.is_null(i) || vids_col.is_null(i) {
355                    continue;
356                }
357
358                let term = term_col.value(i).to_string();
359                let vids_array = vids_col.value(i);
360                let vids = vids_array
361                    .as_any()
362                    .downcast_ref::<UInt64Array>()
363                    .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
364
365                let entry = postings.entry(term).or_default();
366                for vid in vids.iter().flatten() {
367                    entry.insert(vid);
368                }
369            }
370        }
371
372        Ok(postings)
373    }
374
375    /// Applies incremental updates to the inverted index.
376    ///
377    /// This method efficiently updates the index by:
378    /// 1. Loading existing postings
379    /// 2. Removing VIDs that have been deleted
380    /// 3. Adding new VIDs with their associated terms
381    /// 4. Writing the updated postings back
382    ///
383    /// # Errors
384    ///
385    /// Returns an error if loading or writing postings fails.
386    #[instrument(skip(self, added, removed), level = "info", fields(
387        label = %self.label,
388        property = %self.property,
389        added_count = added.len(),
390        removed_count = removed.len()
391    ))]
392    pub async fn apply_incremental_updates(
393        &mut self,
394        added: &HashMap<Vid, Vec<String>>,
395        removed: &HashSet<Vid>,
396    ) -> Result<()> {
397        info!(
398            added = added.len(),
399            removed = removed.len(),
400            "Applying incremental updates to inverted index"
401        );
402
403        // Load existing postings
404        let mut postings = self.load_postings().await?;
405
406        // Remove VIDs that have been deleted
407        if !removed.is_empty() {
408            let removed_u64: HashSet<u64> = removed.iter().map(|v| v.as_u64()).collect();
409            for vids in postings.values_mut() {
410                vids.retain(|vid| !removed_u64.contains(vid));
411            }
412            // Remove empty terms
413            postings.retain(|_, vids| !vids.is_empty());
414        }
415
416        // Add new VIDs with their terms
417        for (vid, terms) in added {
418            let vid_u64 = vid.as_u64();
419            let normalized_terms: HashSet<String> = if self.config.normalize {
420                terms
421                    .iter()
422                    .map(|t| t.to_lowercase().trim().to_string())
423                    .collect()
424            } else {
425                terms.iter().cloned().collect()
426            };
427
428            // Respect max_terms_per_doc limit
429            let terms_to_add: Vec<_> = if normalized_terms.len() > self.config.max_terms_per_doc {
430                normalized_terms
431                    .into_iter()
432                    .take(self.config.max_terms_per_doc)
433                    .collect()
434            } else {
435                normalized_terms.into_iter().collect()
436            };
437
438            for term in terms_to_add {
439                postings.entry(term).or_default().insert(vid_u64);
440            }
441        }
442
443        // Convert HashSet<u64> to Vec<u64> for writing
444        let postings_vec: HashMap<String, Vec<u64>> = postings
445            .into_iter()
446            .map(|(term, vids)| (term, vids.into_iter().collect()))
447            .collect();
448
449        info!(terms = postings_vec.len(), "Writing updated postings");
450
451        self.write_postings(postings_vec).await?;
452        Ok(())
453    }
454
455    /// Extracts terms from a JSON value representing a `List<String>`.
456    ///
457    /// Returns `None` if the value is not an array of strings.
458    pub fn extract_terms_from_value(&self, value: &Value) -> Option<Vec<String>> {
459        let arr = value.as_array()?;
460        let terms: Vec<String> = arr
461            .iter()
462            .filter_map(|v| v.as_str().map(ToString::to_string))
463            .collect();
464
465        if terms.is_empty() { None } else { Some(terms) }
466    }
467
468    /// Returns true if the index dataset exists.
469    pub fn is_initialized(&self) -> bool {
470        self.dataset.is_some()
471    }
472
473    /// Returns the property name this index is built on.
474    pub fn property(&self) -> &str {
475        &self.property
476    }
477}