1use crate::lancedb::LanceDbStore;
19use crate::storage::arrow_convert::build_timestamp_column_from_vid_map;
20use crate::storage::index_utils::ensure_btree_index;
21use anyhow::{Result, anyhow};
22use arrow_array::builder::{
23 FixedSizeBinaryBuilder, LargeBinaryBuilder, ListBuilder, StringBuilder,
24};
25use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
26use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
27use futures::TryStreamExt;
28use futures::future;
29use lancedb::Table;
30use lancedb::index::Index as LanceDbIndex;
31use lancedb::index::scalar::LabelListIndexBuilder;
32use lancedb::query::{ExecutableQuery, QueryBase};
33use sha3::{Digest, Sha3_256};
34use std::collections::HashMap;
35use std::sync::Arc;
36use uni_common::Properties;
37use uni_common::core::id::{UniId, Vid};
38
/// Access point for the main `vertices` table.
///
/// All table operations in this file are associated functions that take the
/// store explicitly; the struct itself only remembers the URI it was built with.
#[derive(Debug)]
pub struct MainVertexDataset {
    /// Base URI supplied to [`MainVertexDataset::new`]. It is stored but not
    /// read by any code visible in this file (hence the leading underscore).
    _base_uri: String,
}
49
50impl MainVertexDataset {
51 pub fn new(base_uri: &str) -> Self {
53 Self {
54 _base_uri: base_uri.to_string(),
55 }
56 }
57
58 pub fn get_arrow_schema() -> Arc<ArrowSchema> {
60 Arc::new(ArrowSchema::new(vec![
61 Field::new("_vid", DataType::UInt64, false),
62 Field::new("_uid", DataType::FixedSizeBinary(32), true),
63 Field::new("ext_id", DataType::Utf8, true),
64 Field::new(
65 "labels",
66 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
67 false,
68 ),
69 Field::new("props_json", DataType::LargeBinary, true),
70 Field::new("_deleted", DataType::Boolean, false),
71 Field::new("_version", DataType::UInt64, false),
72 Field::new(
73 "_created_at",
74 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
75 true,
76 ),
77 Field::new(
78 "_updated_at",
79 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
80 true,
81 ),
82 ]))
83 }
84
85 pub fn table_name() -> &'static str {
87 "vertices"
88 }
89
90 pub async fn open_table(store: &LanceDbStore) -> Result<Table> {
94 store
95 .open_table(Self::table_name())
96 .await
97 .map_err(|e| anyhow!("Failed to open main vertices table: {}", e))
98 }
99
100 fn compute_vertex_uid(labels: &[String], ext_id: Option<&str>, props: &Properties) -> UniId {
102 let mut hasher = Sha3_256::new();
103
104 let mut sorted_labels = labels.to_vec();
106 sorted_labels.sort();
107 for label in &sorted_labels {
108 hasher.update(label.as_bytes());
109 hasher.update(b"\0");
110 }
111
112 if let Some(ext_id) = ext_id {
114 hasher.update(b"ext_id:");
115 hasher.update(ext_id.as_bytes());
116 hasher.update(b"\0");
117 }
118
119 let mut sorted_keys: Vec<_> = props.keys().collect();
121 sorted_keys.sort();
122 for key in sorted_keys {
123 if key == "ext_id" {
124 continue; }
126 if let Some(val) = props.get(key) {
127 hasher.update(key.as_bytes());
128 hasher.update(b":");
129 hasher.update(val.to_string().as_bytes());
130 hasher.update(b"\0");
131 }
132 }
133
134 let result = hasher.finalize();
135 UniId::from_bytes(result.into())
136 }
137
138 pub fn build_record_batch(
145 vertices: &[(Vid, Vec<String>, Properties, bool, u64)],
146 created_at: Option<&HashMap<Vid, i64>>,
147 updated_at: Option<&HashMap<Vid, i64>>,
148 ) -> Result<RecordBatch> {
149 let arrow_schema = Self::get_arrow_schema();
150 let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
151
152 let vids: Vec<u64> = vertices.iter().map(|(v, _, _, _, _)| v.as_u64()).collect();
154 columns.push(Arc::new(UInt64Array::from(vids)));
155
156 let mut uid_builder = FixedSizeBinaryBuilder::new(32);
158 for (_, labels, props, _, _) in vertices.iter() {
159 let ext_id = props.get("ext_id").and_then(|v| v.as_str());
160 let uid = Self::compute_vertex_uid(labels, ext_id, props);
161 uid_builder.append_value(uid.as_bytes())?;
162 }
163 columns.push(Arc::new(uid_builder.finish()));
164
165 let mut ext_id_builder = StringBuilder::new();
167 for (_, _, props, _, _) in vertices.iter() {
168 if let Some(ext_id_val) = props.get("ext_id").and_then(|v| v.as_str()) {
169 ext_id_builder.append_value(ext_id_val);
170 } else {
171 ext_id_builder.append_null();
172 }
173 }
174 columns.push(Arc::new(ext_id_builder.finish()));
175
176 let mut labels_builder = ListBuilder::new(StringBuilder::new());
178 for (_, labels, _, _, _) in vertices.iter() {
179 let values_builder = labels_builder.values();
180 for label in labels {
181 values_builder.append_value(label);
182 }
183 labels_builder.append(true);
184 }
185 columns.push(Arc::new(labels_builder.finish()));
186
187 let mut props_json_builder = LargeBinaryBuilder::new();
189 for (_, _, props, _, _) in vertices.iter() {
190 let jsonb_bytes = {
191 let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
192 let uni_val: uni_common::Value = json_val.into();
193 uni_common::cypher_value_codec::encode(&uni_val)
194 };
195 props_json_builder.append_value(&jsonb_bytes);
196 }
197 columns.push(Arc::new(props_json_builder.finish()));
198
199 let deleted: Vec<bool> = vertices.iter().map(|(_, _, _, d, _)| *d).collect();
201 columns.push(Arc::new(BooleanArray::from(deleted)));
202
203 let versions: Vec<u64> = vertices.iter().map(|(_, _, _, _, v)| *v).collect();
205 columns.push(Arc::new(UInt64Array::from(versions)));
206
207 let vids = vertices.iter().map(|(v, _, _, _, _)| *v);
209 columns.push(build_timestamp_column_from_vid_map(
210 vids.clone(),
211 created_at,
212 ));
213 columns.push(build_timestamp_column_from_vid_map(vids, updated_at));
214
215 RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
216 }
217
218 pub async fn write_batch_lancedb(store: &LanceDbStore, batch: RecordBatch) -> Result<Table> {
222 let table_name = Self::table_name();
223
224 if store.table_exists(table_name).await? {
225 let table = store.open_table(table_name).await?;
226 store.append_to_table(&table, vec![batch]).await?;
227 Ok(table)
228 } else {
229 store.create_table(table_name, vec![batch]).await
230 }
231 }
232
233 pub async fn ensure_default_indexes_lancedb(table: &Table) -> Result<()> {
235 let indices = table
236 .list_indices()
237 .await
238 .map_err(|e| anyhow!("Failed to list indices: {}", e))?;
239
240 future::join_all(
242 ["_vid", "ext_id", "_uid"]
243 .iter()
244 .map(|col| ensure_btree_index(table, &indices, col, "main vertices")),
245 )
246 .await;
247
248 if !indices
250 .iter()
251 .any(|idx| idx.columns.iter().any(|c| c == "labels"))
252 {
253 log::info!("Creating labels LABEL_LIST index for main vertices table");
254 if let Err(e) = table
255 .create_index(
256 &["labels"],
257 LanceDbIndex::LabelList(LabelListIndexBuilder::default()),
258 )
259 .execute()
260 .await
261 {
262 log::warn!("Failed to create labels index for main vertices: {}", e);
263 }
264 }
265
266 Ok(())
267 }
268
269 pub async fn find_by_ext_id(
278 store: &LanceDbStore,
279 ext_id: &str,
280 version: Option<u64>,
281 ) -> Result<Option<Vid>> {
282 let table_name = Self::table_name();
283
284 if !store.table_exists(table_name).await? {
285 return Ok(None);
286 }
287
288 let table = store.open_table(table_name).await?;
289 let mut query = format!(
290 "ext_id = '{}' AND _deleted = false",
291 ext_id.replace('\'', "''")
292 );
293 if let Some(hwm) = version {
294 query.push_str(&format!(" AND _version <= {}", hwm));
295 }
296
297 let batches = table
298 .query()
299 .only_if(query)
300 .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
301 .execute()
302 .await
303 .map_err(|e| anyhow!("Query failed: {}", e))?;
304
305 let results: Vec<RecordBatch> = batches.try_collect().await?;
306
307 for batch in results {
308 if batch.num_rows() > 0
309 && let Some(vid_col) = batch.column_by_name("_vid")
310 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
311 {
312 return Ok(Some(Vid::from(vid_arr.value(0))));
313 }
314 }
315
316 Ok(None)
317 }
318
319 pub async fn ext_id_exists(
324 store: &LanceDbStore,
325 ext_id: &str,
326 version: Option<u64>,
327 ) -> Result<bool> {
328 Ok(Self::find_by_ext_id(store, ext_id, version)
329 .await?
330 .is_some())
331 }
332
333 pub async fn find_labels_by_vid(
340 store: &LanceDbStore,
341 vid: Vid,
342 version: Option<u64>,
343 ) -> Result<Option<Vec<String>>> {
344 let table_name = Self::table_name();
345
346 if !store.table_exists(table_name).await? {
347 return Ok(None);
348 }
349
350 let table = store.open_table(table_name).await?;
351 let mut query = format!("_vid = {} AND _deleted = false", vid.as_u64());
352 if let Some(hwm) = version {
353 query.push_str(&format!(" AND _version <= {}", hwm));
354 }
355
356 let batches = table
357 .query()
358 .only_if(query)
359 .select(lancedb::query::Select::Columns(vec!["labels".to_string()]))
360 .execute()
361 .await
362 .map_err(|e| anyhow!("Query failed: {}", e))?;
363
364 let results: Vec<RecordBatch> = batches.try_collect().await?;
365
366 for batch in results {
367 if batch.num_rows() > 0
368 && let Some(labels_col) = batch.column_by_name("labels")
369 && let Some(list_arr) = labels_col.as_any().downcast_ref::<arrow_array::ListArray>()
370 {
371 let values = list_arr.value(0);
373 if let Some(str_arr) = values.as_any().downcast_ref::<arrow_array::StringArray>() {
374 let labels: Vec<String> = (0..str_arr.len())
375 .filter_map(|i| {
376 if str_arr.is_null(i) {
377 None
378 } else {
379 Some(str_arr.value(i).to_string())
380 }
381 })
382 .collect();
383 return Ok(Some(labels));
384 }
385 }
386 }
387
388 Ok(None)
389 }
390
391 pub async fn find_all_vids(store: &LanceDbStore, version: Option<u64>) -> Result<Vec<Vid>> {
402 let table_name = Self::table_name();
403
404 if !store.table_exists(table_name).await? {
405 return Ok(Vec::new());
406 }
407
408 let table = store.open_table(table_name).await?;
409 let mut query = "_deleted = false".to_string();
410 if let Some(hwm) = version {
411 query.push_str(&format!(" AND _version <= {}", hwm));
412 }
413
414 let batches = table
415 .query()
416 .only_if(query)
417 .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
418 .execute()
419 .await
420 .map_err(|e| anyhow!("Query failed: {}", e))?;
421
422 let results: Vec<RecordBatch> = batches.try_collect().await?;
423
424 let mut vids = Vec::new();
425 for batch in results {
426 if let Some(vid_col) = batch.column_by_name("_vid")
427 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
428 {
429 for i in 0..vid_arr.len() {
430 if !vid_arr.is_null(i) {
431 vids.push(Vid::new(vid_arr.value(i)));
432 }
433 }
434 }
435 }
436
437 Ok(vids)
438 }
439
440 pub async fn find_vids_by_label_name(
452 store: &LanceDbStore,
453 label: &str,
454 version: Option<u64>,
455 ) -> Result<Vec<Vid>> {
456 let table_name = Self::table_name();
457
458 if !store.table_exists(table_name).await? {
459 return Ok(Vec::new());
460 }
461
462 let table = store.open_table(table_name).await?;
463 let mut query = format!("_deleted = false AND array_contains(labels, '{}')", label);
465 if let Some(hwm) = version {
466 query.push_str(&format!(" AND _version <= {}", hwm));
467 }
468
469 let batches = table
470 .query()
471 .only_if(query)
472 .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
473 .execute()
474 .await
475 .map_err(|e| anyhow!("Query failed: {}", e))?;
476
477 let results: Vec<RecordBatch> = batches.try_collect().await?;
478
479 let mut vids = Vec::new();
480 for batch in results {
481 if let Some(vid_col) = batch.column_by_name("_vid")
482 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
483 {
484 for i in 0..vid_arr.len() {
485 if !vid_arr.is_null(i) {
486 vids.push(Vid::new(vid_arr.value(i)));
487 }
488 }
489 }
490 }
491
492 Ok(vids)
493 }
494
495 pub async fn find_vids_by_labels(
503 store: &LanceDbStore,
504 labels: &[&str],
505 version: Option<u64>,
506 ) -> Result<Vec<Vid>> {
507 let table_name = Self::table_name();
508
509 if labels.is_empty() || !store.table_exists(table_name).await? {
510 return Ok(Vec::new());
511 }
512
513 let table = store.open_table(table_name).await?;
514
515 let label_conditions: Vec<String> = labels
517 .iter()
518 .map(|label| {
519 let escaped = label.replace('\'', "''");
520 format!("array_contains(labels, '{}')", escaped)
521 })
522 .collect();
523
524 let mut query = format!("_deleted = false AND {}", label_conditions.join(" AND "));
525 if let Some(hwm) = version {
526 query.push_str(&format!(" AND _version <= {}", hwm));
527 }
528
529 let batches = table
530 .query()
531 .only_if(query)
532 .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
533 .execute()
534 .await
535 .map_err(|e| anyhow!("Query failed: {}", e))?;
536
537 let results: Vec<RecordBatch> = batches.try_collect().await?;
538
539 let mut vids = Vec::new();
540 for batch in results {
541 if let Some(vid_col) = batch.column_by_name("_vid")
542 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
543 {
544 for i in 0..vid_arr.len() {
545 if !vid_arr.is_null(i) {
546 vids.push(Vid::new(vid_arr.value(i)));
547 }
548 }
549 }
550 }
551
552 Ok(vids)
553 }
554
555 pub async fn find_batch_props_by_vids(
568 store: &LanceDbStore,
569 vids: &[Vid],
570 version: Option<u64>,
571 ) -> Result<HashMap<Vid, Properties>> {
572 let table_name = Self::table_name();
573
574 if vids.is_empty() || !store.table_exists(table_name).await? {
575 return Ok(HashMap::new());
576 }
577
578 let table = store.open_table(table_name).await?;
579
580 let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
582 let mut query = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
583 if let Some(hwm) = version {
584 query.push_str(&format!(" AND _version <= {}", hwm));
585 }
586
587 let batches = table
588 .query()
589 .only_if(query)
590 .select(lancedb::query::Select::Columns(vec![
591 "_vid".to_string(),
592 "props_json".to_string(),
593 ]))
594 .execute()
595 .await
596 .map_err(|e| anyhow!("Query failed: {}", e))?;
597
598 let results: Vec<RecordBatch> = batches.try_collect().await?;
599
600 let mut props_map = HashMap::new();
601
602 for batch in results {
603 let vid_col = batch.column_by_name("_vid");
604 let props_col = batch.column_by_name("props_json");
605
606 if let (Some(vid_arr), Some(props_arr)) = (
607 vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
608 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
609 ) {
610 for i in 0..batch.num_rows() {
611 if vid_arr.is_null(i) {
612 continue;
613 }
614 let vid = Vid::new(vid_arr.value(i));
615
616 let props: Properties = if props_arr.is_null(i) || props_arr.value(i).is_empty()
617 {
618 Properties::new()
619 } else {
620 let bytes = props_arr.value(i);
621 let uni_val = uni_common::cypher_value_codec::decode(bytes)
622 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
623 let json_val: serde_json::Value = uni_val.into();
624 serde_json::from_value(json_val)
625 .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?
626 };
627
628 props_map.insert(vid, props);
629 }
630 }
631 }
632
633 Ok(props_map)
634 }
635
636 pub async fn find_props_by_vid(
648 store: &LanceDbStore,
649 vid: Vid,
650 version: Option<u64>,
651 ) -> Result<Option<Properties>> {
652 let table_name = Self::table_name();
653
654 if !store.table_exists(table_name).await? {
655 return Ok(None);
656 }
657
658 let table = store.open_table(table_name).await?;
659 let mut query = format!("_vid = {} AND _deleted = false", vid.as_u64());
660 if let Some(hwm) = version {
661 query.push_str(&format!(" AND _version <= {}", hwm));
662 }
663
664 let batches = table
665 .query()
666 .only_if(query)
667 .select(lancedb::query::Select::Columns(vec![
668 "props_json".to_string(),
669 "_version".to_string(),
670 ]))
671 .execute()
672 .await
673 .map_err(|e| anyhow!("Query failed: {}", e))?;
674
675 let results: Vec<RecordBatch> = batches.try_collect().await?;
676
677 let mut best_props: Option<Properties> = None;
679 let mut best_version: u64 = 0;
680
681 for batch in results {
682 let props_col = batch.column_by_name("props_json");
683 let version_col = batch.column_by_name("_version");
684
685 if let (Some(props_arr), Some(ver_arr)) = (
686 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
687 version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
688 ) {
689 for i in 0..batch.num_rows() {
690 let version = if ver_arr.is_null(i) {
691 0
692 } else {
693 ver_arr.value(i)
694 };
695
696 if version >= best_version {
697 best_version = version;
698 if props_arr.is_null(i) || props_arr.value(i).is_empty() {
699 best_props = Some(Properties::new());
700 } else {
701 let bytes = props_arr.value(i);
702 let uni_val = uni_common::cypher_value_codec::decode(bytes)
703 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
704 let json_val: serde_json::Value = uni_val.into();
705 let parsed: Properties = serde_json::from_value(json_val)
706 .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?;
707 best_props = Some(parsed);
708 }
709 }
710 }
711 }
712 }
713
714 Ok(best_props)
715 }
716
717 pub async fn find_batch_labels_by_vids(
722 store: &LanceDbStore,
723 vids: &[Vid],
724 version: Option<u64>,
725 ) -> Result<HashMap<Vid, Vec<String>>> {
726 let table_name = Self::table_name();
727
728 if vids.is_empty() || !store.table_exists(table_name).await? {
729 return Ok(HashMap::new());
730 }
731
732 let table = store.open_table(table_name).await?;
733
734 let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
736 let mut query = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
737 if let Some(hwm) = version {
738 query.push_str(&format!(" AND _version <= {}", hwm));
739 }
740
741 let batches = table
742 .query()
743 .only_if(query)
744 .select(lancedb::query::Select::Columns(vec![
745 "_vid".to_string(),
746 "labels".to_string(),
747 ]))
748 .execute()
749 .await
750 .map_err(|e| anyhow!("Query failed: {}", e))?;
751
752 let results: Vec<RecordBatch> = batches.try_collect().await?;
753
754 let mut label_map = HashMap::new();
755
756 for batch in results {
757 let vid_col = batch.column_by_name("_vid");
758 let labels_col = batch.column_by_name("labels");
759
760 if let (Some(vid_arr), Some(labels_arr)) = (
761 vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
762 labels_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::ListArray>()),
763 ) {
764 for i in 0..batch.num_rows() {
765 if vid_arr.is_null(i) {
766 continue;
767 }
768 let vid = Vid::new(vid_arr.value(i));
769
770 let values = labels_arr.value(i);
771 if let Some(str_arr) =
772 values.as_any().downcast_ref::<arrow_array::StringArray>()
773 {
774 let labels: Vec<String> = (0..str_arr.len())
775 .filter_map(|j| {
776 if str_arr.is_null(j) {
777 None
778 } else {
779 Some(str_arr.value(j).to_string())
780 }
781 })
782 .collect();
783 label_map.insert(vid, labels);
784 }
785 }
786 }
787 }
788
789 Ok(label_map)
790 }
791}
792
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;

    /// The schema exposes exactly the nine expected columns.
    #[test]
    fn test_main_vertex_schema() {
        let schema = MainVertexDataset::get_arrow_schema();
        assert_eq!(schema.fields().len(), 9);

        let expected_columns = [
            "_vid",
            "_uid",
            "ext_id",
            "labels",
            "props_json",
            "_deleted",
            "_version",
            "_created_at",
            "_updated_at",
        ];
        for column in expected_columns {
            assert!(schema.field_with_name(column).is_ok());
        }
    }

    /// A single vertex produces a one-row batch whose `ext_id` column is
    /// populated from the `ext_id` property.
    #[test]
    fn test_build_record_batch() {
        use uni_common::Value;

        let props = HashMap::from([
            ("name".to_string(), Value::String("Alice".to_string())),
            ("ext_id".to_string(), Value::String("user_001".to_string())),
        ]);
        let vertices = vec![(Vid::new(1), vec!["Person".to_string()], props, false, 1u64)];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);

        let ext_ids = batch
            .column_by_name("ext_id")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(ext_ids.value(0), "user_001");
    }
}