1use crate::storage::vertex::VertexDataset;
12use anyhow::{Result, anyhow};
13use arrow_array::types::UInt64Type;
14use arrow_array::{Array, ListArray, RecordBatchIterator, StringArray, UInt64Array};
15use arrow_schema::{DataType, Field, Schema as ArrowSchema};
16use futures::TryStreamExt;
17use lance::Dataset;
18use serde_json::Value;
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21use tracing::{debug, info, instrument};
22use uni_common::core::id::Vid;
23use uni_common::core::schema::InvertedIndexConfig;
24
/// Soft cap (256 MiB) on the estimated heap size of the in-memory postings
/// map during a bulk build; once exceeded, the map is flushed into a
/// temporary segment and merged back at the end of the build.
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 * 1024 * 1024;
27
/// Rough estimate, in bytes, of the heap memory held by a postings map.
///
/// Per entry we count the term's byte length, the `Vec` header, and eight
/// bytes per stored vid; `String`/`HashMap` bookkeeping overhead and spare
/// `Vec` capacity are deliberately ignored — this only needs to be accurate
/// enough to trigger segment flushes.
fn estimated_postings_memory(postings: &HashMap<String, Vec<u64>>) -> usize {
    let mut total = 0usize;
    for (term, vids) in postings {
        total += term.len();
        total += std::mem::size_of::<Vec<u64>>();
        total += vids.len() * std::mem::size_of::<u64>();
    }
    total
}
35
/// Merges several postings segments into a single map.
///
/// Vid lists for the same term are concatenated in segment order; no
/// de-duplication or sorting is performed here.
fn merge_postings_segments(segments: Vec<HashMap<String, Vec<u64>>>) -> HashMap<String, Vec<u64>> {
    segments
        .into_iter()
        .fold(HashMap::new(), |mut acc, segment| {
            for (term, vids) in segment {
                acc.entry(term).or_default().extend(vids);
            }
            acc
        })
}
46
/// Secondary index mapping string terms to the vertex ids that contain them,
/// persisted as a Lance dataset at
/// `{base_uri}/indexes/{label}/{property}_inverted` with schema
/// `(term: Utf8, vids: List<UInt64>)`.
pub struct InvertedIndex {
    // Open handle to the backing Lance dataset; `None` until the index has
    // been successfully opened or written at least once.
    dataset: Option<Dataset>,
    // Storage root from which the index path is derived.
    base_uri: String,
    // Vertex label this index covers (copied from the config).
    label: String,
    // Indexed property name; its vertex column must be `List<String>`.
    property: String,
    // Behavior knobs: term normalization, per-document term cap, etc.
    config: InvertedIndexConfig,
}
54
55impl InvertedIndex {
56 pub async fn new(base_uri: &str, config: InvertedIndexConfig) -> Result<Self> {
57 let path = format!(
58 "{}/indexes/{}/{}_inverted",
59 base_uri, config.label, config.property
60 );
61
62 let dataset = (Dataset::open(&path).await).ok();
63
64 Ok(Self {
65 dataset,
66 base_uri: base_uri.to_string(),
67 label: config.label.clone(),
68 property: config.property.clone(),
69 config,
70 })
71 }
72
73 pub async fn create_if_missing(&mut self) -> Result<()> {
74 if self.dataset.is_some() {
75 return Ok(());
76 }
77 Ok(())
78 }
79
80 pub async fn build_from_dataset(
81 &mut self,
82 vertex_dataset: &VertexDataset,
83 progress: impl Fn(usize),
84 ) -> Result<()> {
85 let mut postings: HashMap<String, Vec<u64>> = HashMap::new();
86 let mut temp_segments: Vec<HashMap<String, Vec<u64>>> = Vec::new();
87 let mut count = 0;
88 let max_memory = DEFAULT_MAX_POSTINGS_MEMORY;
89
90 debug!(property = %self.property, "Building inverted index from dataset");
91
92 if let Ok(ds) = vertex_dataset.open().await {
93 let scanner = ds.scan();
94 let mut stream = scanner.try_into_stream().await?;
95 while let Some(batch) = stream.try_next().await? {
96 debug!(rows = batch.num_rows(), "Processing batch");
97 let vid_col = batch
98 .column_by_name("_vid")
99 .ok_or_else(|| anyhow!("Missing _vid"))?
100 .as_any()
101 .downcast_ref::<UInt64Array>()
102 .ok_or_else(|| anyhow!("Invalid _vid type"))?;
103
104 let term_col = batch
105 .column_by_name(&self.property)
106 .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
107
108 let list_array =
109 term_col
110 .as_any()
111 .downcast_ref::<ListArray>()
112 .ok_or_else(|| {
113 anyhow!(
114 "Property {} must be List<String>, got {:?}",
115 self.property,
116 term_col.data_type()
117 )
118 })?;
119
120 let values = list_array
121 .values()
122 .as_any()
123 .downcast_ref::<StringArray>()
124 .ok_or_else(|| anyhow!("Property {} must be List<String>", self.property))?;
125
126 for i in 0..batch.num_rows() {
127 let vid = vid_col.value(i);
128
129 if list_array.is_null(i) {
130 continue;
131 }
132
133 let start = list_array.value_offsets()[i] as usize;
134 let end = list_array.value_offsets()[i + 1] as usize;
135
136 let mut terms = HashSet::new();
137 for j in start..end {
138 if !values.is_null(j) {
139 let term = values.value(j);
140 let term = if self.config.normalize {
141 term.to_lowercase().trim().to_string()
142 } else {
143 term.to_string()
144 };
145 terms.insert(term);
146 }
147 }
148
149 if terms.len() > self.config.max_terms_per_doc {
150 }
152
153 for term in terms {
154 postings.entry(term).or_default().push(vid);
155 }
156
157 count += 1;
158 if count % 10_000 == 0 {
159 progress(count);
160 }
161 }
162
163 if max_memory > 0 && estimated_postings_memory(&postings) > max_memory {
165 debug!(
166 segment = temp_segments.len(),
167 terms = postings.len(),
168 "Flushing postings segment due to memory limit"
169 );
170 temp_segments.push(std::mem::take(&mut postings));
171 }
172 }
173 } else {
174 debug!("Vertex dataset not found, creating empty index");
175 }
176
177 debug!(
178 terms = postings.len(),
179 segments = temp_segments.len(),
180 "Built inverted index"
181 );
182
183 if temp_segments.is_empty() {
185 self.write_postings(postings).await?;
187 } else {
188 temp_segments.push(postings); info!(segments = temp_segments.len(), "Merging postings segments");
191 let merged = merge_postings_segments(temp_segments);
192 debug!(final_terms = merged.len(), "Merged postings");
193 self.write_postings(merged).await?;
194 }
195
196 Ok(())
197 }
198
199 async fn write_postings(&mut self, postings: HashMap<String, Vec<u64>>) -> Result<()> {
200 let mut terms = Vec::with_capacity(postings.len());
201 let mut vid_lists = Vec::with_capacity(postings.len());
202
203 for (term, vids) in postings {
204 terms.push(term);
205 vid_lists.push(Some(vids.into_iter().map(Some).collect::<Vec<_>>()));
206 }
207
208 let term_array = StringArray::from(terms);
209 let vid_list_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists);
210
211 let batch = arrow_array::RecordBatch::try_from_iter(vec![
212 ("term", Arc::new(term_array) as Arc<dyn arrow_array::Array>),
213 (
214 "vids",
215 Arc::new(vid_list_array) as Arc<dyn arrow_array::Array>,
216 ),
217 ])?;
218
219 let path = format!(
220 "{}/indexes/{}/{}_inverted",
221 self.base_uri, self.label, self.property
222 );
223 let write_params = lance::dataset::WriteParams {
224 mode: lance::dataset::WriteMode::Overwrite,
225 ..Default::default()
226 };
227
228 let iterator = RecordBatchIterator::new(
229 vec![Ok(batch)],
230 Arc::new(ArrowSchema::new(vec![
231 Field::new("term", DataType::Utf8, false),
232 Field::new(
233 "vids",
234 DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
235 false,
236 ),
237 ])),
238 );
239
240 let ds = Dataset::write(iterator, &path, Some(write_params)).await?;
241 self.dataset = Some(ds);
242
243 Ok(())
244 }
245
246 pub async fn query_any(&self, terms: &[String]) -> Result<Vec<Vid>> {
247 let Some(ds) = &self.dataset else {
248 debug!("Inverted index not initialized, returning empty result");
249 return Ok(Vec::new());
250 };
251
252 let normalized: Vec<String> = if self.config.normalize {
253 terms
254 .iter()
255 .map(|t| t.to_lowercase().trim().to_string())
256 .collect()
257 } else {
258 terms.to_vec()
259 };
260
261 if normalized.is_empty() {
262 return Ok(Vec::new());
263 }
264
265 let filter = normalized
266 .iter()
267 .map(|t| format!("term = '{}'", t.replace("'", "''")))
268 .collect::<Vec<_>>()
269 .join(" OR ");
270
271 debug!(filter = %filter, "Querying inverted index");
272
273 let mut scanner = ds.scan();
274 scanner.filter(&filter)?;
275
276 let mut stream = scanner.try_into_stream().await?;
277 let mut result_set: HashSet<u64> = HashSet::new();
278
279 while let Some(batch) = stream.try_next().await? {
280 let vids_col = batch
281 .column_by_name("vids")
282 .ok_or_else(|| anyhow!("Missing vids column"))?
283 .as_any()
284 .downcast_ref::<ListArray>()
285 .ok_or_else(|| anyhow!("Invalid vids column"))?;
286
287 for i in 0..batch.num_rows() {
288 if vids_col.is_null(i) {
289 continue;
290 }
291
292 let vids_array = vids_col.value(i);
293 let vids = vids_array
294 .as_any()
295 .downcast_ref::<UInt64Array>()
296 .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
297
298 for vid in vids.iter().flatten() {
299 result_set.insert(vid);
300 }
301 }
302 }
303
304 debug!(count = result_set.len(), "Found matching VIDs");
305
306 Ok(result_set.into_iter().map(Vid::from).collect())
307 }
308
309 #[instrument(skip(self), level = "debug")]
314 async fn load_postings(&self) -> Result<HashMap<String, HashSet<u64>>> {
315 let Some(ds) = &self.dataset else {
316 return Ok(HashMap::new());
317 };
318
319 let mut postings: HashMap<String, HashSet<u64>> = HashMap::new();
320 let scanner = ds.scan();
321 let mut stream = scanner.try_into_stream().await?;
322
323 while let Some(batch) = stream.try_next().await? {
324 let term_col = batch
325 .column_by_name("term")
326 .ok_or_else(|| anyhow!("Missing term column"))?
327 .as_any()
328 .downcast_ref::<StringArray>()
329 .ok_or_else(|| anyhow!("Invalid term column type"))?;
330
331 let vids_col = batch
332 .column_by_name("vids")
333 .ok_or_else(|| anyhow!("Missing vids column"))?
334 .as_any()
335 .downcast_ref::<ListArray>()
336 .ok_or_else(|| anyhow!("Invalid vids column type"))?;
337
338 for i in 0..batch.num_rows() {
339 if term_col.is_null(i) || vids_col.is_null(i) {
340 continue;
341 }
342
343 let term = term_col.value(i).to_string();
344 let vids_array = vids_col.value(i);
345 let vids = vids_array
346 .as_any()
347 .downcast_ref::<UInt64Array>()
348 .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
349
350 let entry = postings.entry(term).or_default();
351 for vid in vids.iter().flatten() {
352 entry.insert(vid);
353 }
354 }
355 }
356
357 Ok(postings)
358 }
359
360 #[instrument(skip(self, added, removed), level = "info", fields(
372 label = %self.label,
373 property = %self.property,
374 added_count = added.len(),
375 removed_count = removed.len()
376 ))]
377 pub async fn apply_incremental_updates(
378 &mut self,
379 added: &HashMap<Vid, Vec<String>>,
380 removed: &HashSet<Vid>,
381 ) -> Result<()> {
382 info!(
383 added = added.len(),
384 removed = removed.len(),
385 "Applying incremental updates to inverted index"
386 );
387
388 let mut postings = self.load_postings().await?;
390
391 if !removed.is_empty() {
393 let removed_u64: HashSet<u64> = removed.iter().map(|v| v.as_u64()).collect();
394 for vids in postings.values_mut() {
395 vids.retain(|vid| !removed_u64.contains(vid));
396 }
397 postings.retain(|_, vids| !vids.is_empty());
399 }
400
401 for (vid, terms) in added {
403 let vid_u64 = vid.as_u64();
404 let normalized_terms: HashSet<String> = if self.config.normalize {
405 terms
406 .iter()
407 .map(|t| t.to_lowercase().trim().to_string())
408 .collect()
409 } else {
410 terms.iter().cloned().collect()
411 };
412
413 let terms_to_add: Vec<_> = if normalized_terms.len() > self.config.max_terms_per_doc {
415 normalized_terms
416 .into_iter()
417 .take(self.config.max_terms_per_doc)
418 .collect()
419 } else {
420 normalized_terms.into_iter().collect()
421 };
422
423 for term in terms_to_add {
424 postings.entry(term).or_default().insert(vid_u64);
425 }
426 }
427
428 let postings_vec: HashMap<String, Vec<u64>> = postings
430 .into_iter()
431 .map(|(term, vids)| (term, vids.into_iter().collect()))
432 .collect();
433
434 info!(terms = postings_vec.len(), "Writing updated postings");
435
436 self.write_postings(postings_vec).await?;
437 Ok(())
438 }
439
440 pub fn extract_terms_from_value(&self, value: &Value) -> Option<Vec<String>> {
444 let arr = value.as_array()?;
445 let terms: Vec<String> = arr
446 .iter()
447 .filter_map(|v| v.as_str().map(ToString::to_string))
448 .collect();
449
450 if terms.is_empty() { None } else { Some(terms) }
451 }
452
453 pub fn is_initialized(&self) -> bool {
455 self.dataset.is_some()
456 }
457
458 pub fn property(&self) -> &str {
460 &self.property
461 }
462}