1use crate::backend::StorageBackend;
19use crate::backend::table_names;
20use crate::backend::types::{ScalarIndexType, ScanRequest, WriteMode};
21use crate::storage::arrow_convert::build_timestamp_column_from_vid_map;
22use anyhow::{Result, anyhow};
23use arrow_array::builder::{
24 FixedSizeBinaryBuilder, LargeBinaryBuilder, ListBuilder, StringBuilder,
25};
26use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
27use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
28use sha3::{Digest, Sha3_256};
29use std::collections::HashMap;
30use std::sync::Arc;
31use uni_common::Properties;
32use uni_common::core::id::{UniId, Vid};
33
/// Handle for the main vertex table.
///
/// All table operations are associated functions on this type that take the
/// storage backend explicitly; the handle itself only remembers the URI it
/// was created with.
#[derive(Debug)]
pub struct MainVertexDataset {
    /// URI passed to the constructor. Kept on the handle but not consulted by
    /// any of the table operations defined in this file.
    _base_uri: String,
}
44
45impl MainVertexDataset {
46 pub fn new(base_uri: &str) -> Self {
48 Self {
49 _base_uri: base_uri.to_string(),
50 }
51 }
52
53 pub fn get_arrow_schema() -> Arc<ArrowSchema> {
55 Arc::new(ArrowSchema::new(vec![
56 Field::new("_vid", DataType::UInt64, false),
57 Field::new("_uid", DataType::FixedSizeBinary(32), true),
58 Field::new("ext_id", DataType::Utf8, true),
59 Field::new(
60 "labels",
61 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
62 false,
63 ),
64 Field::new("props_json", DataType::LargeBinary, true),
65 Field::new("_deleted", DataType::Boolean, false),
66 Field::new("_version", DataType::UInt64, false),
67 Field::new(
68 "_created_at",
69 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
70 true,
71 ),
72 Field::new(
73 "_updated_at",
74 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
75 true,
76 ),
77 ]))
78 }
79
80 pub fn table_name() -> &'static str {
82 table_names::main_vertex_table_name()
83 }
84
85 fn compute_vertex_uid(labels: &[String], ext_id: Option<&str>, props: &Properties) -> UniId {
87 let mut hasher = Sha3_256::new();
88
89 let mut sorted_labels = labels.to_vec();
91 sorted_labels.sort();
92 for label in &sorted_labels {
93 hasher.update(label.as_bytes());
94 hasher.update(b"\0");
95 }
96
97 if let Some(ext_id) = ext_id {
99 hasher.update(b"ext_id:");
100 hasher.update(ext_id.as_bytes());
101 hasher.update(b"\0");
102 }
103
104 let mut sorted_keys: Vec<_> = props.keys().collect();
106 sorted_keys.sort();
107 for key in sorted_keys {
108 if key == "ext_id" {
109 continue; }
111 if let Some(val) = props.get(key) {
112 hasher.update(key.as_bytes());
113 hasher.update(b":");
114 hasher.update(val.to_string().as_bytes());
115 hasher.update(b"\0");
116 }
117 }
118
119 let result = hasher.finalize();
120 UniId::from_bytes(result.into())
121 }
122
123 pub fn build_record_batch(
130 vertices: &[(Vid, Vec<String>, Properties, bool, u64)],
131 created_at: Option<&HashMap<Vid, i64>>,
132 updated_at: Option<&HashMap<Vid, i64>>,
133 ) -> Result<RecordBatch> {
134 let arrow_schema = Self::get_arrow_schema();
135 let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
136
137 let vids: Vec<u64> = vertices.iter().map(|(v, _, _, _, _)| v.as_u64()).collect();
139 columns.push(Arc::new(UInt64Array::from(vids)));
140
141 let mut uid_builder = FixedSizeBinaryBuilder::new(32);
143 for (_, labels, props, _, _) in vertices.iter() {
144 let ext_id = props.get("ext_id").and_then(|v| v.as_str());
145 let uid = Self::compute_vertex_uid(labels, ext_id, props);
146 uid_builder.append_value(uid.as_bytes())?;
147 }
148 columns.push(Arc::new(uid_builder.finish()));
149
150 let mut ext_id_builder = StringBuilder::new();
152 for (_, _, props, _, _) in vertices.iter() {
153 if let Some(ext_id_val) = props.get("ext_id").and_then(|v| v.as_str()) {
154 ext_id_builder.append_value(ext_id_val);
155 } else {
156 ext_id_builder.append_null();
157 }
158 }
159 columns.push(Arc::new(ext_id_builder.finish()));
160
161 let mut labels_builder = ListBuilder::new(StringBuilder::new());
163 for (_, labels, _, _, _) in vertices.iter() {
164 let values_builder = labels_builder.values();
165 for label in labels {
166 values_builder.append_value(label);
167 }
168 labels_builder.append(true);
169 }
170 columns.push(Arc::new(labels_builder.finish()));
171
172 let mut props_json_builder = LargeBinaryBuilder::new();
174 for (_, _, props, _, _) in vertices.iter() {
175 let jsonb_bytes = {
176 let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
177 let uni_val: uni_common::Value = json_val.into();
178 uni_common::cypher_value_codec::encode(&uni_val)
179 };
180 props_json_builder.append_value(&jsonb_bytes);
181 }
182 columns.push(Arc::new(props_json_builder.finish()));
183
184 let deleted: Vec<bool> = vertices.iter().map(|(_, _, _, d, _)| *d).collect();
186 columns.push(Arc::new(BooleanArray::from(deleted)));
187
188 let versions: Vec<u64> = vertices.iter().map(|(_, _, _, _, v)| *v).collect();
190 columns.push(Arc::new(UInt64Array::from(versions)));
191
192 let vids = vertices.iter().map(|(v, _, _, _, _)| *v);
194 columns.push(build_timestamp_column_from_vid_map(
195 vids.clone(),
196 created_at,
197 ));
198 columns.push(build_timestamp_column_from_vid_map(vids, updated_at));
199
200 RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
201 }
202
203 pub async fn write_batch(backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
207 let table_name = table_names::main_vertex_table_name();
208
209 if backend.table_exists(table_name).await? {
210 backend
211 .write(table_name, vec![batch], WriteMode::Append)
212 .await
213 } else {
214 backend.create_table(table_name, vec![batch]).await
215 }
216 }
217
218 pub async fn ensure_default_indexes(backend: &dyn StorageBackend) -> Result<()> {
220 let table_name = table_names::main_vertex_table_name();
221
222 let _ = backend
224 .create_scalar_index(table_name, "_vid", ScalarIndexType::BTree)
225 .await;
226 let _ = backend
227 .create_scalar_index(table_name, "ext_id", ScalarIndexType::BTree)
228 .await;
229 let _ = backend
230 .create_scalar_index(table_name, "_uid", ScalarIndexType::BTree)
231 .await;
232
233 let _ = backend
235 .create_scalar_index(table_name, "labels", ScalarIndexType::LabelList)
236 .await;
237
238 Ok(())
239 }
240
241 pub async fn find_by_ext_id(
250 backend: &dyn StorageBackend,
251 ext_id: &str,
252 version: Option<u64>,
253 ) -> Result<Option<Vid>> {
254 let table_name = table_names::main_vertex_table_name();
255
256 if !backend.table_exists(table_name).await? {
257 return Ok(None);
258 }
259
260 let mut filter = format!(
261 "ext_id = '{}' AND _deleted = false",
262 ext_id.replace('\'', "''")
263 );
264 if let Some(hwm) = version {
265 filter.push_str(&format!(" AND _version <= {}", hwm));
266 }
267
268 let results = backend
269 .scan(
270 ScanRequest::all(table_name)
271 .with_filter(filter)
272 .with_columns(vec!["_vid".to_string()]),
273 )
274 .await?;
275
276 for batch in results {
277 if batch.num_rows() > 0
278 && let Some(vid_col) = batch.column_by_name("_vid")
279 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
280 {
281 return Ok(Some(Vid::from(vid_arr.value(0))));
282 }
283 }
284
285 Ok(None)
286 }
287
288 pub async fn ext_id_exists(
293 backend: &dyn StorageBackend,
294 ext_id: &str,
295 version: Option<u64>,
296 ) -> Result<bool> {
297 Ok(Self::find_by_ext_id(backend, ext_id, version)
298 .await?
299 .is_some())
300 }
301
302 pub async fn find_labels_by_vid(
309 backend: &dyn StorageBackend,
310 vid: Vid,
311 version: Option<u64>,
312 ) -> Result<Option<Vec<String>>> {
313 let table_name = table_names::main_vertex_table_name();
314
315 if !backend.table_exists(table_name).await? {
316 return Ok(None);
317 }
318
319 let mut filter = format!("_vid = {} AND _deleted = false", vid.as_u64());
320 if let Some(hwm) = version {
321 filter.push_str(&format!(" AND _version <= {}", hwm));
322 }
323
324 let results = backend
325 .scan(
326 ScanRequest::all(table_name)
327 .with_filter(filter)
328 .with_columns(vec!["labels".to_string()]),
329 )
330 .await?;
331
332 for batch in results {
333 if batch.num_rows() > 0
334 && let Some(labels_col) = batch.column_by_name("labels")
335 && let Some(list_arr) = labels_col.as_any().downcast_ref::<arrow_array::ListArray>()
336 {
337 let values = list_arr.value(0);
339 if let Some(str_arr) = values.as_any().downcast_ref::<arrow_array::StringArray>() {
340 let labels: Vec<String> = (0..str_arr.len())
341 .filter_map(|i| {
342 if str_arr.is_null(i) {
343 None
344 } else {
345 Some(str_arr.value(i).to_string())
346 }
347 })
348 .collect();
349 return Ok(Some(labels));
350 }
351 }
352 }
353
354 Ok(None)
355 }
356
357 pub async fn find_all_vids(
368 backend: &dyn StorageBackend,
369 version: Option<u64>,
370 ) -> Result<Vec<Vid>> {
371 let table_name = table_names::main_vertex_table_name();
372
373 if !backend.table_exists(table_name).await? {
374 return Ok(Vec::new());
375 }
376
377 let mut filter = "_deleted = false".to_string();
378 if let Some(hwm) = version {
379 filter.push_str(&format!(" AND _version <= {}", hwm));
380 }
381
382 let results = backend
383 .scan(
384 ScanRequest::all(table_name)
385 .with_filter(filter)
386 .with_columns(vec!["_vid".to_string()]),
387 )
388 .await?;
389
390 let mut vids = Vec::new();
391 for batch in results {
392 if let Some(vid_col) = batch.column_by_name("_vid")
393 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
394 {
395 for i in 0..vid_arr.len() {
396 if !vid_arr.is_null(i) {
397 vids.push(Vid::new(vid_arr.value(i)));
398 }
399 }
400 }
401 }
402
403 Ok(vids)
404 }
405
406 pub async fn find_vids_by_label_name(
418 backend: &dyn StorageBackend,
419 label: &str,
420 version: Option<u64>,
421 ) -> Result<Vec<Vid>> {
422 let table_name = table_names::main_vertex_table_name();
423
424 if !backend.table_exists(table_name).await? {
425 return Ok(Vec::new());
426 }
427
428 let mut filter = format!("_deleted = false AND array_contains(labels, '{}')", label);
430 if let Some(hwm) = version {
431 filter.push_str(&format!(" AND _version <= {}", hwm));
432 }
433
434 let results = backend
435 .scan(
436 ScanRequest::all(table_name)
437 .with_filter(filter)
438 .with_columns(vec!["_vid".to_string()]),
439 )
440 .await?;
441
442 let mut vids = Vec::new();
443 for batch in results {
444 if let Some(vid_col) = batch.column_by_name("_vid")
445 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
446 {
447 for i in 0..vid_arr.len() {
448 if !vid_arr.is_null(i) {
449 vids.push(Vid::new(vid_arr.value(i)));
450 }
451 }
452 }
453 }
454
455 Ok(vids)
456 }
457
458 pub async fn find_vids_by_labels(
466 backend: &dyn StorageBackend,
467 labels: &[&str],
468 version: Option<u64>,
469 ) -> Result<Vec<Vid>> {
470 let table_name = table_names::main_vertex_table_name();
471
472 if labels.is_empty() || !backend.table_exists(table_name).await? {
473 return Ok(Vec::new());
474 }
475
476 let label_conditions: Vec<String> = labels
478 .iter()
479 .map(|label| {
480 let escaped = label.replace('\'', "''");
481 format!("array_contains(labels, '{}')", escaped)
482 })
483 .collect();
484
485 let mut filter = format!("_deleted = false AND {}", label_conditions.join(" AND "));
486 if let Some(hwm) = version {
487 filter.push_str(&format!(" AND _version <= {}", hwm));
488 }
489
490 let results = backend
491 .scan(
492 ScanRequest::all(table_name)
493 .with_filter(filter)
494 .with_columns(vec!["_vid".to_string()]),
495 )
496 .await?;
497
498 let mut vids = Vec::new();
499 for batch in results {
500 if let Some(vid_col) = batch.column_by_name("_vid")
501 && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
502 {
503 for i in 0..vid_arr.len() {
504 if !vid_arr.is_null(i) {
505 vids.push(Vid::new(vid_arr.value(i)));
506 }
507 }
508 }
509 }
510
511 Ok(vids)
512 }
513
514 pub async fn find_batch_props_by_vids(
527 backend: &dyn StorageBackend,
528 vids: &[Vid],
529 version: Option<u64>,
530 ) -> Result<HashMap<Vid, Properties>> {
531 let table_name = table_names::main_vertex_table_name();
532
533 if vids.is_empty() || !backend.table_exists(table_name).await? {
534 return Ok(HashMap::new());
535 }
536
537 let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
539 let mut filter = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
540 if let Some(hwm) = version {
541 filter.push_str(&format!(" AND _version <= {}", hwm));
542 }
543
544 let results = backend
545 .scan(
546 ScanRequest::all(table_name)
547 .with_filter(filter)
548 .with_columns(vec!["_vid".to_string(), "props_json".to_string()]),
549 )
550 .await?;
551
552 let mut props_map = HashMap::new();
553
554 for batch in results {
555 let vid_col = batch.column_by_name("_vid");
556 let props_col = batch.column_by_name("props_json");
557
558 if let (Some(vid_arr), Some(props_arr)) = (
559 vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
560 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
561 ) {
562 for i in 0..batch.num_rows() {
563 if vid_arr.is_null(i) {
564 continue;
565 }
566 let vid = Vid::new(vid_arr.value(i));
567
568 let props: Properties = if props_arr.is_null(i) || props_arr.value(i).is_empty()
569 {
570 Properties::new()
571 } else {
572 let bytes = props_arr.value(i);
573 let uni_val = uni_common::cypher_value_codec::decode(bytes)
574 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
575 let json_val: serde_json::Value = uni_val.into();
576 serde_json::from_value(json_val)
577 .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?
578 };
579
580 props_map.insert(vid, props);
581 }
582 }
583 }
584
585 Ok(props_map)
586 }
587
588 pub async fn find_props_by_vid(
600 backend: &dyn StorageBackend,
601 vid: Vid,
602 version: Option<u64>,
603 ) -> Result<Option<Properties>> {
604 let table_name = table_names::main_vertex_table_name();
605
606 if !backend.table_exists(table_name).await? {
607 return Ok(None);
608 }
609
610 let mut filter = format!("_vid = {} AND _deleted = false", vid.as_u64());
611 if let Some(hwm) = version {
612 filter.push_str(&format!(" AND _version <= {}", hwm));
613 }
614
615 let results = backend
616 .scan(
617 ScanRequest::all(table_name)
618 .with_filter(filter)
619 .with_columns(vec!["props_json".to_string(), "_version".to_string()]),
620 )
621 .await?;
622
623 let mut best_props: Option<Properties> = None;
625 let mut best_version: u64 = 0;
626
627 for batch in results {
628 let props_col = batch.column_by_name("props_json");
629 let version_col = batch.column_by_name("_version");
630
631 if let (Some(props_arr), Some(ver_arr)) = (
632 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
633 version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
634 ) {
635 for i in 0..batch.num_rows() {
636 let version = if ver_arr.is_null(i) {
637 0
638 } else {
639 ver_arr.value(i)
640 };
641
642 if version >= best_version {
643 best_version = version;
644 if props_arr.is_null(i) || props_arr.value(i).is_empty() {
645 best_props = Some(Properties::new());
646 } else {
647 let bytes = props_arr.value(i);
648 let uni_val = uni_common::cypher_value_codec::decode(bytes)
649 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
650 let json_val: serde_json::Value = uni_val.into();
651 let parsed: Properties = serde_json::from_value(json_val)
652 .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?;
653 best_props = Some(parsed);
654 }
655 }
656 }
657 }
658 }
659
660 Ok(best_props)
661 }
662
663 pub async fn find_batch_labels_by_vids(
668 backend: &dyn StorageBackend,
669 vids: &[Vid],
670 version: Option<u64>,
671 ) -> Result<HashMap<Vid, Vec<String>>> {
672 let table_name = table_names::main_vertex_table_name();
673
674 if vids.is_empty() || !backend.table_exists(table_name).await? {
675 return Ok(HashMap::new());
676 }
677
678 let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
680 let mut filter = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
681 if let Some(hwm) = version {
682 filter.push_str(&format!(" AND _version <= {}", hwm));
683 }
684
685 let results = backend
686 .scan(
687 ScanRequest::all(table_name)
688 .with_filter(filter)
689 .with_columns(vec!["_vid".to_string(), "labels".to_string()]),
690 )
691 .await?;
692
693 let mut label_map = HashMap::new();
694
695 for batch in results {
696 let vid_col = batch.column_by_name("_vid");
697 let labels_col = batch.column_by_name("labels");
698
699 if let (Some(vid_arr), Some(labels_arr)) = (
700 vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
701 labels_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::ListArray>()),
702 ) {
703 for i in 0..batch.num_rows() {
704 if vid_arr.is_null(i) {
705 continue;
706 }
707 let vid = Vid::new(vid_arr.value(i));
708
709 let values = labels_arr.value(i);
710 if let Some(str_arr) =
711 values.as_any().downcast_ref::<arrow_array::StringArray>()
712 {
713 let labels: Vec<String> = (0..str_arr.len())
714 .filter_map(|j| {
715 if str_arr.is_null(j) {
716 None
717 } else {
718 Some(str_arr.value(j).to_string())
719 }
720 })
721 .collect();
722 label_map.insert(vid, labels);
723 }
724 }
725 }
726 }
727
728 Ok(label_map)
729 }
730}
731
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;

    /// The schema exposes exactly the nine expected columns.
    #[test]
    fn test_main_vertex_schema() {
        let schema = MainVertexDataset::get_arrow_schema();
        assert_eq!(schema.fields().len(), 9);

        let expected_columns = [
            "_vid",
            "_uid",
            "ext_id",
            "labels",
            "props_json",
            "_deleted",
            "_version",
            "_created_at",
            "_updated_at",
        ];
        for column in expected_columns {
            assert!(schema.field_with_name(column).is_ok());
        }
    }

    /// A single vertex yields a one-row, nine-column batch whose `ext_id`
    /// column carries the string value of the `ext_id` property.
    #[test]
    fn test_build_record_batch() {
        use uni_common::Value;

        let props = HashMap::from([
            ("name".to_string(), Value::String("Alice".to_string())),
            ("ext_id".to_string(), Value::String("user_001".to_string())),
        ]);
        let vertices = vec![(Vid::new(1), vec!["Person".to_string()], props, false, 1u64)];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);

        let ext_id_arr = batch
            .column_by_name("ext_id")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(ext_id_arr.value(0), "user_001");
    }
}