1use crate::storage::vertex::VertexDataset;
12use anyhow::{Result, anyhow};
13use arrow_array::types::UInt64Type;
14use arrow_array::{Array, ListArray, RecordBatchIterator, StringArray, UInt64Array};
15use arrow_schema::{DataType, Field, Schema as ArrowSchema};
16use futures::TryStreamExt;
17use lance::Dataset;
18use serde_json::Value;
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21use tracing::{debug, info, instrument};
22use uni_common::core::id::Vid;
23use uni_common::core::schema::InvertedIndexConfig;
24
/// Soft cap on the estimated in-memory size of a postings map (256 MiB)
/// before it is spilled to a temporary segment during index builds.
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 * 1024 * 1024;

/// Rough estimate, in bytes, of the heap memory held by a postings map.
///
/// Counts each term's UTF-8 byte length, the fixed `Vec` header, and the
/// stored VIDs. It deliberately ignores `HashMap` bucket overhead and any
/// spare `String`/`Vec` capacity, so it is a lower bound that is only used
/// for segment-spill decisions, not accounting.
fn estimated_postings_memory(postings: &HashMap<String, Vec<u64>>) -> usize {
    postings
        .iter()
        .map(|(term, vids)| {
            // `size_of::<u64>()` instead of a magic `8` for the per-VID cost.
            term.len() + std::mem::size_of::<Vec<u64>>() + vids.len() * std::mem::size_of::<u64>()
        })
        .sum()
}
35
/// Merges several postings segments into a single map.
///
/// Segments are consumed in order, so VIDs for a shared term are appended
/// in segment order. VIDs are concatenated as-is: no deduplication and no
/// sorting is performed here.
fn merge_postings_segments(segments: Vec<HashMap<String, Vec<u64>>>) -> HashMap<String, Vec<u64>> {
    segments
        .into_iter()
        .flatten()
        .fold(HashMap::new(), |mut acc, (term, vids)| {
            acc.entry(term).or_default().extend(vids);
            acc
        })
}
46
/// On-disk inverted index mapping string terms to the vertex IDs (VIDs)
/// whose indexed `List<String>` property contains that term.
pub struct InvertedIndex {
    // Backing Lance dataset; `None` until the index has been built or a
    // previously written index was opened successfully.
    dataset: Option<Dataset>,
    // Root URI under which the index path is derived
    // (`{base_uri}/indexes/{label}/{property}_inverted`).
    base_uri: String,
    // Vertex label this index covers (copied from the config).
    label: String,
    // Name of the property column being indexed (copied from the config).
    property: String,
    // Build/query options (normalization, max terms per doc, ...).
    config: InvertedIndexConfig,
}
58
59impl std::fmt::Debug for InvertedIndex {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("InvertedIndex")
62 .field("base_uri", &self.base_uri)
63 .field("label", &self.label)
64 .field("property", &self.property)
65 .field("initialized", &self.dataset.is_some())
66 .finish_non_exhaustive()
67 }
68}
69
70impl InvertedIndex {
71 pub async fn new(base_uri: &str, config: InvertedIndexConfig) -> Result<Self> {
73 let path = format!(
74 "{}/indexes/{}/{}_inverted",
75 base_uri, config.label, config.property
76 );
77
78 let dataset = (Dataset::open(&path).await).ok();
79
80 Ok(Self {
81 dataset,
82 base_uri: base_uri.to_string(),
83 label: config.label.clone(),
84 property: config.property.clone(),
85 config,
86 })
87 }
88
89 pub async fn build_from_dataset(
94 &mut self,
95 vertex_dataset: &VertexDataset,
96 progress: impl Fn(usize),
97 ) -> Result<()> {
98 let mut postings: HashMap<String, Vec<u64>> = HashMap::new();
99 let mut temp_segments: Vec<HashMap<String, Vec<u64>>> = Vec::new();
100 let mut count = 0;
101 let max_memory = DEFAULT_MAX_POSTINGS_MEMORY;
102
103 debug!(property = %self.property, "Building inverted index from dataset");
104
105 if let Ok(ds) = vertex_dataset.open().await {
106 let scanner = ds.scan();
107 let mut stream = scanner.try_into_stream().await?;
108 while let Some(batch) = stream.try_next().await? {
109 debug!(rows = batch.num_rows(), "Processing batch");
110 let vid_col = batch
111 .column_by_name("_vid")
112 .ok_or_else(|| anyhow!("Missing _vid"))?
113 .as_any()
114 .downcast_ref::<UInt64Array>()
115 .ok_or_else(|| anyhow!("Invalid _vid type"))?;
116
117 let term_col = batch
118 .column_by_name(&self.property)
119 .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
120
121 let list_array =
122 term_col
123 .as_any()
124 .downcast_ref::<ListArray>()
125 .ok_or_else(|| {
126 anyhow!(
127 "Property {} must be List<String>, got {:?}",
128 self.property,
129 term_col.data_type()
130 )
131 })?;
132
133 let values = list_array
134 .values()
135 .as_any()
136 .downcast_ref::<StringArray>()
137 .ok_or_else(|| anyhow!("Property {} must be List<String>", self.property))?;
138
139 for i in 0..batch.num_rows() {
140 let vid = vid_col.value(i);
141
142 if list_array.is_null(i) {
143 continue;
144 }
145
146 let start = list_array.value_offsets()[i] as usize;
147 let end = list_array.value_offsets()[i + 1] as usize;
148
149 let mut terms = HashSet::new();
150 for j in start..end {
151 if !values.is_null(j) {
152 let term = values.value(j);
153 let term = if self.config.normalize {
154 term.to_lowercase().trim().to_string()
155 } else {
156 term.to_string()
157 };
158 terms.insert(term);
159 }
160 }
161
162 if terms.len() > self.config.max_terms_per_doc {
163 }
165
166 for term in terms {
167 postings.entry(term).or_default().push(vid);
168 }
169
170 count += 1;
171 if count % 10_000 == 0 {
172 progress(count);
173 }
174 }
175
176 if max_memory > 0 && estimated_postings_memory(&postings) > max_memory {
178 debug!(
179 segment = temp_segments.len(),
180 terms = postings.len(),
181 "Flushing postings segment due to memory limit"
182 );
183 temp_segments.push(std::mem::take(&mut postings));
184 }
185 }
186 } else {
187 debug!("Vertex dataset not found, creating empty index");
188 }
189
190 debug!(
191 terms = postings.len(),
192 segments = temp_segments.len(),
193 "Built inverted index"
194 );
195
196 if temp_segments.is_empty() {
198 self.write_postings(postings).await?;
200 } else {
201 temp_segments.push(postings); info!(segments = temp_segments.len(), "Merging postings segments");
204 let merged = merge_postings_segments(temp_segments);
205 debug!(final_terms = merged.len(), "Merged postings");
206 self.write_postings(merged).await?;
207 }
208
209 Ok(())
210 }
211
212 async fn write_postings(&mut self, postings: HashMap<String, Vec<u64>>) -> Result<()> {
214 let mut terms = Vec::with_capacity(postings.len());
215 let mut vid_lists = Vec::with_capacity(postings.len());
216
217 for (term, vids) in postings {
218 terms.push(term);
219 vid_lists.push(Some(vids.into_iter().map(Some).collect::<Vec<_>>()));
220 }
221
222 let term_array = StringArray::from(terms);
223 let vid_list_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists);
224
225 let batch = arrow_array::RecordBatch::try_from_iter(vec![
226 ("term", Arc::new(term_array) as Arc<dyn arrow_array::Array>),
227 (
228 "vids",
229 Arc::new(vid_list_array) as Arc<dyn arrow_array::Array>,
230 ),
231 ])?;
232
233 let path = format!(
234 "{}/indexes/{}/{}_inverted",
235 self.base_uri, self.label, self.property
236 );
237 let write_params = lance::dataset::WriteParams {
238 mode: lance::dataset::WriteMode::Overwrite,
239 ..Default::default()
240 };
241
242 let iterator = RecordBatchIterator::new(
243 vec![Ok(batch)],
244 Arc::new(ArrowSchema::new(vec![
245 Field::new("term", DataType::Utf8, false),
246 Field::new(
247 "vids",
248 DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
249 false,
250 ),
251 ])),
252 );
253
254 let ds = Dataset::write(iterator, &path, Some(write_params)).await?;
255 self.dataset = Some(ds);
256
257 Ok(())
258 }
259
260 pub async fn query_any(&self, terms: &[String]) -> Result<Vec<Vid>> {
262 let Some(ds) = &self.dataset else {
263 debug!("Inverted index not initialized, returning empty result");
264 return Ok(Vec::new());
265 };
266
267 let normalized: Vec<String> = if self.config.normalize {
268 terms
269 .iter()
270 .map(|t| t.to_lowercase().trim().to_string())
271 .collect()
272 } else {
273 terms.to_vec()
274 };
275
276 if normalized.is_empty() {
277 return Ok(Vec::new());
278 }
279
280 let filter = normalized
281 .iter()
282 .map(|t| format!("term = '{}'", t.replace("'", "''")))
283 .collect::<Vec<_>>()
284 .join(" OR ");
285
286 debug!(filter = %filter, "Querying inverted index");
287
288 let mut scanner = ds.scan();
289 scanner.filter(&filter)?;
290
291 let mut stream = scanner.try_into_stream().await?;
292 let mut result_set: HashSet<u64> = HashSet::new();
293
294 while let Some(batch) = stream.try_next().await? {
295 let vids_col = batch
296 .column_by_name("vids")
297 .ok_or_else(|| anyhow!("Missing vids column"))?
298 .as_any()
299 .downcast_ref::<ListArray>()
300 .ok_or_else(|| anyhow!("Invalid vids column"))?;
301
302 for i in 0..batch.num_rows() {
303 if vids_col.is_null(i) {
304 continue;
305 }
306
307 let vids_array = vids_col.value(i);
308 let vids = vids_array
309 .as_any()
310 .downcast_ref::<UInt64Array>()
311 .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
312
313 for vid in vids.iter().flatten() {
314 result_set.insert(vid);
315 }
316 }
317 }
318
319 debug!(count = result_set.len(), "Found matching VIDs");
320
321 Ok(result_set.into_iter().map(Vid::from).collect())
322 }
323
324 #[instrument(skip(self), level = "debug")]
329 async fn load_postings(&self) -> Result<HashMap<String, HashSet<u64>>> {
330 let Some(ds) = &self.dataset else {
331 return Ok(HashMap::new());
332 };
333
334 let mut postings: HashMap<String, HashSet<u64>> = HashMap::new();
335 let scanner = ds.scan();
336 let mut stream = scanner.try_into_stream().await?;
337
338 while let Some(batch) = stream.try_next().await? {
339 let term_col = batch
340 .column_by_name("term")
341 .ok_or_else(|| anyhow!("Missing term column"))?
342 .as_any()
343 .downcast_ref::<StringArray>()
344 .ok_or_else(|| anyhow!("Invalid term column type"))?;
345
346 let vids_col = batch
347 .column_by_name("vids")
348 .ok_or_else(|| anyhow!("Missing vids column"))?
349 .as_any()
350 .downcast_ref::<ListArray>()
351 .ok_or_else(|| anyhow!("Invalid vids column type"))?;
352
353 for i in 0..batch.num_rows() {
354 if term_col.is_null(i) || vids_col.is_null(i) {
355 continue;
356 }
357
358 let term = term_col.value(i).to_string();
359 let vids_array = vids_col.value(i);
360 let vids = vids_array
361 .as_any()
362 .downcast_ref::<UInt64Array>()
363 .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
364
365 let entry = postings.entry(term).or_default();
366 for vid in vids.iter().flatten() {
367 entry.insert(vid);
368 }
369 }
370 }
371
372 Ok(postings)
373 }
374
375 #[instrument(skip(self, added, removed), level = "info", fields(
387 label = %self.label,
388 property = %self.property,
389 added_count = added.len(),
390 removed_count = removed.len()
391 ))]
392 pub async fn apply_incremental_updates(
393 &mut self,
394 added: &HashMap<Vid, Vec<String>>,
395 removed: &HashSet<Vid>,
396 ) -> Result<()> {
397 info!(
398 added = added.len(),
399 removed = removed.len(),
400 "Applying incremental updates to inverted index"
401 );
402
403 let mut postings = self.load_postings().await?;
405
406 if !removed.is_empty() {
408 let removed_u64: HashSet<u64> = removed.iter().map(|v| v.as_u64()).collect();
409 for vids in postings.values_mut() {
410 vids.retain(|vid| !removed_u64.contains(vid));
411 }
412 postings.retain(|_, vids| !vids.is_empty());
414 }
415
416 for (vid, terms) in added {
418 let vid_u64 = vid.as_u64();
419 let normalized_terms: HashSet<String> = if self.config.normalize {
420 terms
421 .iter()
422 .map(|t| t.to_lowercase().trim().to_string())
423 .collect()
424 } else {
425 terms.iter().cloned().collect()
426 };
427
428 let terms_to_add: Vec<_> = if normalized_terms.len() > self.config.max_terms_per_doc {
430 normalized_terms
431 .into_iter()
432 .take(self.config.max_terms_per_doc)
433 .collect()
434 } else {
435 normalized_terms.into_iter().collect()
436 };
437
438 for term in terms_to_add {
439 postings.entry(term).or_default().insert(vid_u64);
440 }
441 }
442
443 let postings_vec: HashMap<String, Vec<u64>> = postings
445 .into_iter()
446 .map(|(term, vids)| (term, vids.into_iter().collect()))
447 .collect();
448
449 info!(terms = postings_vec.len(), "Writing updated postings");
450
451 self.write_postings(postings_vec).await?;
452 Ok(())
453 }
454
455 pub fn extract_terms_from_value(&self, value: &Value) -> Option<Vec<String>> {
459 let arr = value.as_array()?;
460 let terms: Vec<String> = arr
461 .iter()
462 .filter_map(|v| v.as_str().map(ToString::to_string))
463 .collect();
464
465 if terms.is_empty() { None } else { Some(terms) }
466 }
467
468 pub fn is_initialized(&self) -> bool {
470 self.dataset.is_some()
471 }
472
473 pub fn property(&self) -> &str {
475 &self.property
476 }
477}