// uni_store/storage/inverted_index.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3// Rust guideline compliant
4
5//! Inverted index implementation for set membership queries.
6//!
7//! Provides efficient `ANY(x IN list WHERE x IN allowed)` queries by
8//! maintaining a term-to-VID mapping. Supports both full rebuilds and
9//! incremental updates for optimal performance during mutations.
10
11use crate::storage::vertex::VertexDataset;
12use anyhow::{Result, anyhow};
13use arrow_array::types::UInt64Type;
14use arrow_array::{Array, ListArray, RecordBatchIterator, StringArray, UInt64Array};
15use arrow_schema::{DataType, Field, Schema as ArrowSchema};
16use futures::TryStreamExt;
17use lance::Dataset;
18use serde_json::Value;
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21use tracing::{debug, info, instrument};
22use uni_common::core::id::Vid;
23use uni_common::core::schema::InvertedIndexConfig;
24
/// Default memory limit for postings accumulation (256 MB)
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 * 1024 * 1024;

/// Estimate memory usage of a postings map.
///
/// Counts, per entry: the key's UTF-8 bytes plus the `String` struct header,
/// the `Vec` struct header, and the payload of the VID vector. This is a
/// heuristic lower bound used only to decide when to spill a segment; it
/// deliberately ignores `HashMap` bucket overhead, spare capacity, and
/// allocator slack.
fn estimated_postings_memory(postings: &HashMap<String, Vec<u64>>) -> usize {
    postings
        .iter()
        .map(|(k, v)| {
            k.len()
                + std::mem::size_of::<String>()
                + std::mem::size_of::<Vec<u64>>()
                + v.len() * std::mem::size_of::<u64>()
        })
        .sum()
}
35
/// Merge multiple postings segments into a single postings map.
///
/// Per-term VID lists are concatenated in segment order; no deduplication
/// or sorting is performed here.
fn merge_postings_segments(segments: Vec<HashMap<String, Vec<u64>>>) -> HashMap<String, Vec<u64>> {
    let mut combined: HashMap<String, Vec<u64>> = HashMap::new();
    // Flattening the segment list yields every (term, vids) pair in
    // segment order, so postings keep their original relative ordering.
    for (term, vids) in segments.into_iter().flatten() {
        combined.entry(term).or_default().extend(vids);
    }
    combined
}
46
/// Inverted index over a `List<String>` vertex property, backed by a Lance
/// dataset that maps each term to the list of matching vertex IDs.
pub struct InvertedIndex {
    // Open Lance dataset, or `None` until the index is first written.
    dataset: Option<Dataset>,
    // Root URI under which the index path is derived.
    base_uri: String,
    // Vertex label this index applies to.
    label: String,
    // Name of the indexed property column.
    property: String,
    // Index configuration (normalization flag, per-document term cap, ...).
    config: InvertedIndexConfig,
}
54
55impl InvertedIndex {
56    pub async fn new(base_uri: &str, config: InvertedIndexConfig) -> Result<Self> {
57        let path = format!(
58            "{}/indexes/{}/{}_inverted",
59            base_uri, config.label, config.property
60        );
61
62        let dataset = (Dataset::open(&path).await).ok();
63
64        Ok(Self {
65            dataset,
66            base_uri: base_uri.to_string(),
67            label: config.label.clone(),
68            property: config.property.clone(),
69            config,
70        })
71    }
72
73    pub async fn create_if_missing(&mut self) -> Result<()> {
74        if self.dataset.is_some() {
75            return Ok(());
76        }
77        Ok(())
78    }
79
80    pub async fn build_from_dataset(
81        &mut self,
82        vertex_dataset: &VertexDataset,
83        progress: impl Fn(usize),
84    ) -> Result<()> {
85        let mut postings: HashMap<String, Vec<u64>> = HashMap::new();
86        let mut temp_segments: Vec<HashMap<String, Vec<u64>>> = Vec::new();
87        let mut count = 0;
88        let max_memory = DEFAULT_MAX_POSTINGS_MEMORY;
89
90        debug!(property = %self.property, "Building inverted index from dataset");
91
92        if let Ok(ds) = vertex_dataset.open().await {
93            let scanner = ds.scan();
94            let mut stream = scanner.try_into_stream().await?;
95            while let Some(batch) = stream.try_next().await? {
96                debug!(rows = batch.num_rows(), "Processing batch");
97                let vid_col = batch
98                    .column_by_name("_vid")
99                    .ok_or_else(|| anyhow!("Missing _vid"))?
100                    .as_any()
101                    .downcast_ref::<UInt64Array>()
102                    .ok_or_else(|| anyhow!("Invalid _vid type"))?;
103
104                let term_col = batch
105                    .column_by_name(&self.property)
106                    .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
107
108                let list_array =
109                    term_col
110                        .as_any()
111                        .downcast_ref::<ListArray>()
112                        .ok_or_else(|| {
113                            anyhow!(
114                                "Property {} must be List<String>, got {:?}",
115                                self.property,
116                                term_col.data_type()
117                            )
118                        })?;
119
120                let values = list_array
121                    .values()
122                    .as_any()
123                    .downcast_ref::<StringArray>()
124                    .ok_or_else(|| anyhow!("Property {} must be List<String>", self.property))?;
125
126                for i in 0..batch.num_rows() {
127                    let vid = vid_col.value(i);
128
129                    if list_array.is_null(i) {
130                        continue;
131                    }
132
133                    let start = list_array.value_offsets()[i] as usize;
134                    let end = list_array.value_offsets()[i + 1] as usize;
135
136                    let mut terms = HashSet::new();
137                    for j in start..end {
138                        if !values.is_null(j) {
139                            let term = values.value(j);
140                            let term = if self.config.normalize {
141                                term.to_lowercase().trim().to_string()
142                            } else {
143                                term.to_string()
144                            };
145                            terms.insert(term);
146                        }
147                    }
148
149                    if terms.len() > self.config.max_terms_per_doc {
150                        // Truncate logic if needed
151                    }
152
153                    for term in terms {
154                        postings.entry(term).or_default().push(vid);
155                    }
156
157                    count += 1;
158                    if count % 10_000 == 0 {
159                        progress(count);
160                    }
161                }
162
163                // Check if we've exceeded memory limit - flush to temp segment
164                if max_memory > 0 && estimated_postings_memory(&postings) > max_memory {
165                    debug!(
166                        segment = temp_segments.len(),
167                        terms = postings.len(),
168                        "Flushing postings segment due to memory limit"
169                    );
170                    temp_segments.push(std::mem::take(&mut postings));
171                }
172            }
173        } else {
174            debug!("Vertex dataset not found, creating empty index");
175        }
176
177        debug!(
178            terms = postings.len(),
179            segments = temp_segments.len(),
180            "Built inverted index"
181        );
182
183        // Merge and write
184        if temp_segments.is_empty() {
185            // Small dataset: write directly
186            self.write_postings(postings).await?;
187        } else {
188            // Large dataset: merge all segments
189            temp_segments.push(postings); // add remaining
190            info!(segments = temp_segments.len(), "Merging postings segments");
191            let merged = merge_postings_segments(temp_segments);
192            debug!(final_terms = merged.len(), "Merged postings");
193            self.write_postings(merged).await?;
194        }
195
196        Ok(())
197    }
198
199    async fn write_postings(&mut self, postings: HashMap<String, Vec<u64>>) -> Result<()> {
200        let mut terms = Vec::with_capacity(postings.len());
201        let mut vid_lists = Vec::with_capacity(postings.len());
202
203        for (term, vids) in postings {
204            terms.push(term);
205            vid_lists.push(Some(vids.into_iter().map(Some).collect::<Vec<_>>()));
206        }
207
208        let term_array = StringArray::from(terms);
209        let vid_list_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists);
210
211        let batch = arrow_array::RecordBatch::try_from_iter(vec![
212            ("term", Arc::new(term_array) as Arc<dyn arrow_array::Array>),
213            (
214                "vids",
215                Arc::new(vid_list_array) as Arc<dyn arrow_array::Array>,
216            ),
217        ])?;
218
219        let path = format!(
220            "{}/indexes/{}/{}_inverted",
221            self.base_uri, self.label, self.property
222        );
223        let write_params = lance::dataset::WriteParams {
224            mode: lance::dataset::WriteMode::Overwrite,
225            ..Default::default()
226        };
227
228        let iterator = RecordBatchIterator::new(
229            vec![Ok(batch)],
230            Arc::new(ArrowSchema::new(vec![
231                Field::new("term", DataType::Utf8, false),
232                Field::new(
233                    "vids",
234                    DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
235                    false,
236                ),
237            ])),
238        );
239
240        let ds = Dataset::write(iterator, &path, Some(write_params)).await?;
241        self.dataset = Some(ds);
242
243        Ok(())
244    }
245
246    pub async fn query_any(&self, terms: &[String]) -> Result<Vec<Vid>> {
247        let Some(ds) = &self.dataset else {
248            debug!("Inverted index not initialized, returning empty result");
249            return Ok(Vec::new());
250        };
251
252        let normalized: Vec<String> = if self.config.normalize {
253            terms
254                .iter()
255                .map(|t| t.to_lowercase().trim().to_string())
256                .collect()
257        } else {
258            terms.to_vec()
259        };
260
261        if normalized.is_empty() {
262            return Ok(Vec::new());
263        }
264
265        let filter = normalized
266            .iter()
267            .map(|t| format!("term = '{}'", t.replace("'", "''")))
268            .collect::<Vec<_>>()
269            .join(" OR ");
270
271        debug!(filter = %filter, "Querying inverted index");
272
273        let mut scanner = ds.scan();
274        scanner.filter(&filter)?;
275
276        let mut stream = scanner.try_into_stream().await?;
277        let mut result_set: HashSet<u64> = HashSet::new();
278
279        while let Some(batch) = stream.try_next().await? {
280            let vids_col = batch
281                .column_by_name("vids")
282                .ok_or_else(|| anyhow!("Missing vids column"))?
283                .as_any()
284                .downcast_ref::<ListArray>()
285                .ok_or_else(|| anyhow!("Invalid vids column"))?;
286
287            for i in 0..batch.num_rows() {
288                if vids_col.is_null(i) {
289                    continue;
290                }
291
292                let vids_array = vids_col.value(i);
293                let vids = vids_array
294                    .as_any()
295                    .downcast_ref::<UInt64Array>()
296                    .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
297
298                for vid in vids.iter().flatten() {
299                    result_set.insert(vid);
300                }
301            }
302        }
303
304        debug!(count = result_set.len(), "Found matching VIDs");
305
306        Ok(result_set.into_iter().map(Vid::from).collect())
307    }
308
309    /// Loads all postings from the existing dataset.
310    ///
311    /// Returns a map of term to list of VIDs. Returns an empty map if the
312    /// dataset doesn't exist yet.
313    #[instrument(skip(self), level = "debug")]
314    async fn load_postings(&self) -> Result<HashMap<String, HashSet<u64>>> {
315        let Some(ds) = &self.dataset else {
316            return Ok(HashMap::new());
317        };
318
319        let mut postings: HashMap<String, HashSet<u64>> = HashMap::new();
320        let scanner = ds.scan();
321        let mut stream = scanner.try_into_stream().await?;
322
323        while let Some(batch) = stream.try_next().await? {
324            let term_col = batch
325                .column_by_name("term")
326                .ok_or_else(|| anyhow!("Missing term column"))?
327                .as_any()
328                .downcast_ref::<StringArray>()
329                .ok_or_else(|| anyhow!("Invalid term column type"))?;
330
331            let vids_col = batch
332                .column_by_name("vids")
333                .ok_or_else(|| anyhow!("Missing vids column"))?
334                .as_any()
335                .downcast_ref::<ListArray>()
336                .ok_or_else(|| anyhow!("Invalid vids column type"))?;
337
338            for i in 0..batch.num_rows() {
339                if term_col.is_null(i) || vids_col.is_null(i) {
340                    continue;
341                }
342
343                let term = term_col.value(i).to_string();
344                let vids_array = vids_col.value(i);
345                let vids = vids_array
346                    .as_any()
347                    .downcast_ref::<UInt64Array>()
348                    .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
349
350                let entry = postings.entry(term).or_default();
351                for vid in vids.iter().flatten() {
352                    entry.insert(vid);
353                }
354            }
355        }
356
357        Ok(postings)
358    }
359
360    /// Applies incremental updates to the inverted index.
361    ///
362    /// This method efficiently updates the index by:
363    /// 1. Loading existing postings
364    /// 2. Removing VIDs that have been deleted
365    /// 3. Adding new VIDs with their associated terms
366    /// 4. Writing the updated postings back
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if loading or writing postings fails.
371    #[instrument(skip(self, added, removed), level = "info", fields(
372        label = %self.label,
373        property = %self.property,
374        added_count = added.len(),
375        removed_count = removed.len()
376    ))]
377    pub async fn apply_incremental_updates(
378        &mut self,
379        added: &HashMap<Vid, Vec<String>>,
380        removed: &HashSet<Vid>,
381    ) -> Result<()> {
382        info!(
383            added = added.len(),
384            removed = removed.len(),
385            "Applying incremental updates to inverted index"
386        );
387
388        // Load existing postings
389        let mut postings = self.load_postings().await?;
390
391        // Remove VIDs that have been deleted
392        if !removed.is_empty() {
393            let removed_u64: HashSet<u64> = removed.iter().map(|v| v.as_u64()).collect();
394            for vids in postings.values_mut() {
395                vids.retain(|vid| !removed_u64.contains(vid));
396            }
397            // Remove empty terms
398            postings.retain(|_, vids| !vids.is_empty());
399        }
400
401        // Add new VIDs with their terms
402        for (vid, terms) in added {
403            let vid_u64 = vid.as_u64();
404            let normalized_terms: HashSet<String> = if self.config.normalize {
405                terms
406                    .iter()
407                    .map(|t| t.to_lowercase().trim().to_string())
408                    .collect()
409            } else {
410                terms.iter().cloned().collect()
411            };
412
413            // Respect max_terms_per_doc limit
414            let terms_to_add: Vec<_> = if normalized_terms.len() > self.config.max_terms_per_doc {
415                normalized_terms
416                    .into_iter()
417                    .take(self.config.max_terms_per_doc)
418                    .collect()
419            } else {
420                normalized_terms.into_iter().collect()
421            };
422
423            for term in terms_to_add {
424                postings.entry(term).or_default().insert(vid_u64);
425            }
426        }
427
428        // Convert HashSet<u64> to Vec<u64> for writing
429        let postings_vec: HashMap<String, Vec<u64>> = postings
430            .into_iter()
431            .map(|(term, vids)| (term, vids.into_iter().collect()))
432            .collect();
433
434        info!(terms = postings_vec.len(), "Writing updated postings");
435
436        self.write_postings(postings_vec).await?;
437        Ok(())
438    }
439
440    /// Extracts terms from a JSON value representing a `List<String>`.
441    ///
442    /// Returns `None` if the value is not an array of strings.
443    pub fn extract_terms_from_value(&self, value: &Value) -> Option<Vec<String>> {
444        let arr = value.as_array()?;
445        let terms: Vec<String> = arr
446            .iter()
447            .filter_map(|v| v.as_str().map(ToString::to_string))
448            .collect();
449
450        if terms.is_empty() { None } else { Some(terms) }
451    }
452
453    /// Returns true if the index dataset exists.
454    pub fn is_initialized(&self) -> bool {
455        self.dataset.is_some()
456    }
457
458    /// Returns the property name this index is built on.
459    pub fn property(&self) -> &str {
460        &self.property
461    }
462}