Skip to main content

uni_store/storage/
main_vertex.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Main vertex table for unified vertex storage.
5//!
6//! This module implements the main `vertices` table as described in STORAGE_DESIGN.md.
7//! The main table contains all vertices in the graph with:
8//! - `_vid`: Internal vertex ID (primary key)
9//! - `_uid`: Content-addressed unique ID (SHA3-256 hash)
10//! - `ext_id`: Optional external/user-provided ID (globally unique)
11//! - `labels`: List of label names (OpenCypher multi-label)
12//! - `props_json`: All properties as JSONB blob
13//! - `_deleted`: Soft-delete flag
14//! - `_version`: MVCC version
15//! - `_created_at`: Creation timestamp
16//! - `_updated_at`: Update timestamp
17
18use crate::backend::StorageBackend;
19use crate::backend::table_names;
20use crate::backend::types::{ScalarIndexType, ScanRequest, WriteMode};
21use crate::storage::arrow_convert::build_timestamp_column_from_vid_map;
22use anyhow::{Result, anyhow};
23use arrow_array::builder::{
24    FixedSizeBinaryBuilder, LargeBinaryBuilder, ListBuilder, StringBuilder,
25};
26use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
27use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
28use sha3::{Digest, Sha3_256};
29use std::collections::HashMap;
30use std::sync::Arc;
31use uni_common::Properties;
32use uni_common::core::id::{UniId, Vid};
33
/// Handle for the unified `vertices` table, which holds every vertex no
/// matter which label(s) it carries.
///
/// Storing all vertices in one table provides:
/// - ID-based lookups that do not require knowing a vertex's label
/// - a single place to enforce global `ext_id` uniqueness
/// - multi-label vertices, with the labels kept in a list column
#[derive(Debug)]
pub struct MainVertexDataset {
    // Base URI supplied to `new`. It is only stored, never read, within this
    // module (hence the leading underscore).
    _base_uri: String,
}
44
45impl MainVertexDataset {
46    /// Create a new MainVertexDataset.
47    pub fn new(base_uri: &str) -> Self {
48        Self {
49            _base_uri: base_uri.to_string(),
50        }
51    }
52
53    /// Get the Arrow schema for the main vertices table.
54    pub fn get_arrow_schema() -> Arc<ArrowSchema> {
55        Arc::new(ArrowSchema::new(vec![
56            Field::new("_vid", DataType::UInt64, false),
57            Field::new("_uid", DataType::FixedSizeBinary(32), true),
58            Field::new("ext_id", DataType::Utf8, true),
59            Field::new(
60                "labels",
61                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
62                false,
63            ),
64            Field::new("props_json", DataType::LargeBinary, true),
65            Field::new("_deleted", DataType::Boolean, false),
66            Field::new("_version", DataType::UInt64, false),
67            Field::new(
68                "_created_at",
69                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
70                true,
71            ),
72            Field::new(
73                "_updated_at",
74                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
75                true,
76            ),
77        ]))
78    }
79
80    /// Get the table name for the main vertices table.
81    pub fn table_name() -> &'static str {
82        table_names::main_vertex_table_name()
83    }
84
85    /// Compute the UniId (content-addressed hash) for a vertex.
86    fn compute_vertex_uid(labels: &[String], ext_id: Option<&str>, props: &Properties) -> UniId {
87        let mut hasher = Sha3_256::new();
88
89        // Hash labels (sorted for consistency)
90        let mut sorted_labels = labels.to_vec();
91        sorted_labels.sort();
92        for label in &sorted_labels {
93            hasher.update(label.as_bytes());
94            hasher.update(b"\0");
95        }
96
97        // Hash ext_id if present
98        if let Some(ext_id) = ext_id {
99            hasher.update(b"ext_id:");
100            hasher.update(ext_id.as_bytes());
101            hasher.update(b"\0");
102        }
103
104        // Hash properties (sorted by key for deterministic hashing)
105        let mut sorted_keys: Vec<_> = props.keys().collect();
106        sorted_keys.sort();
107        for key in sorted_keys {
108            if key == "ext_id" {
109                continue; // Already handled above
110            }
111            if let Some(val) = props.get(key) {
112                hasher.update(key.as_bytes());
113                hasher.update(b":");
114                hasher.update(val.to_string().as_bytes());
115                hasher.update(b"\0");
116            }
117        }
118
119        let result = hasher.finalize();
120        UniId::from_bytes(result.into())
121    }
122
123    /// Build a record batch for the main vertices table.
124    ///
125    /// # Arguments
126    /// * `vertices` - List of (vid, labels, properties, deleted, version) tuples
127    /// * `created_at` - Optional map of Vid -> nanoseconds since epoch
128    /// * `updated_at` - Optional map of Vid -> nanoseconds since epoch
129    pub fn build_record_batch(
130        vertices: &[(Vid, Vec<String>, Properties, bool, u64)],
131        created_at: Option<&HashMap<Vid, i64>>,
132        updated_at: Option<&HashMap<Vid, i64>>,
133    ) -> Result<RecordBatch> {
134        let arrow_schema = Self::get_arrow_schema();
135        let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
136
137        // _vid column
138        let vids: Vec<u64> = vertices.iter().map(|(v, _, _, _, _)| v.as_u64()).collect();
139        columns.push(Arc::new(UInt64Array::from(vids)));
140
141        // _uid column
142        let mut uid_builder = FixedSizeBinaryBuilder::new(32);
143        for (_, labels, props, _, _) in vertices.iter() {
144            let ext_id = props.get("ext_id").and_then(|v| v.as_str());
145            let uid = Self::compute_vertex_uid(labels, ext_id, props);
146            uid_builder.append_value(uid.as_bytes())?;
147        }
148        columns.push(Arc::new(uid_builder.finish()));
149
150        // ext_id column
151        let mut ext_id_builder = StringBuilder::new();
152        for (_, _, props, _, _) in vertices.iter() {
153            if let Some(ext_id_val) = props.get("ext_id").and_then(|v| v.as_str()) {
154                ext_id_builder.append_value(ext_id_val);
155            } else {
156                ext_id_builder.append_null();
157            }
158        }
159        columns.push(Arc::new(ext_id_builder.finish()));
160
161        // labels column (List<String>)
162        let mut labels_builder = ListBuilder::new(StringBuilder::new());
163        for (_, labels, _, _, _) in vertices.iter() {
164            let values_builder = labels_builder.values();
165            for label in labels {
166                values_builder.append_value(label);
167            }
168            labels_builder.append(true);
169        }
170        columns.push(Arc::new(labels_builder.finish()));
171
172        // props_json column (JSONB binary encoding)
173        let mut props_json_builder = LargeBinaryBuilder::new();
174        for (_, _, props, _, _) in vertices.iter() {
175            let jsonb_bytes = {
176                let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
177                let uni_val: uni_common::Value = json_val.into();
178                uni_common::cypher_value_codec::encode(&uni_val)
179            };
180            props_json_builder.append_value(&jsonb_bytes);
181        }
182        columns.push(Arc::new(props_json_builder.finish()));
183
184        // _deleted column
185        let deleted: Vec<bool> = vertices.iter().map(|(_, _, _, d, _)| *d).collect();
186        columns.push(Arc::new(BooleanArray::from(deleted)));
187
188        // _version column
189        let versions: Vec<u64> = vertices.iter().map(|(_, _, _, _, v)| *v).collect();
190        columns.push(Arc::new(UInt64Array::from(versions)));
191
192        // _created_at and _updated_at columns using shared builder
193        let vids = vertices.iter().map(|(v, _, _, _, _)| *v);
194        columns.push(build_timestamp_column_from_vid_map(
195            vids.clone(),
196            created_at,
197        ));
198        columns.push(build_timestamp_column_from_vid_map(vids, updated_at));
199
200        RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
201    }
202
203    /// Write a batch to the main vertices table.
204    ///
205    /// Creates the table if it doesn't exist, otherwise appends to it.
206    pub async fn write_batch(backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
207        let table_name = table_names::main_vertex_table_name();
208
209        if backend.table_exists(table_name).await? {
210            backend
211                .write(table_name, vec![batch], WriteMode::Append)
212                .await
213        } else {
214            backend.create_table(table_name, vec![batch]).await
215        }
216    }
217
218    /// Ensure default indexes exist on the main vertices table.
219    pub async fn ensure_default_indexes(backend: &dyn StorageBackend) -> Result<()> {
220        let table_name = table_names::main_vertex_table_name();
221
222        // BTree indexes for primary key and lookup columns
223        let _ = backend
224            .create_scalar_index(table_name, "_vid", ScalarIndexType::BTree)
225            .await;
226        let _ = backend
227            .create_scalar_index(table_name, "ext_id", ScalarIndexType::BTree)
228            .await;
229        let _ = backend
230            .create_scalar_index(table_name, "_uid", ScalarIndexType::BTree)
231            .await;
232
233        // LabelList index for array_contains() queries on labels
234        let _ = backend
235            .create_scalar_index(table_name, "labels", ScalarIndexType::LabelList)
236            .await;
237
238        Ok(())
239    }
240
241    /// Query the main vertices table for a vertex by ext_id.
242    ///
243    /// Returns the Vid if found, None otherwise.
244    ///
245    /// # Arguments
246    /// * `version` - Optional version high water mark for snapshot isolation.
247    ///   Pass `None` for writer uniqueness checks (global visibility).
248    ///   Pass `Some(hwm)` for query-time snapshot isolation.
249    pub async fn find_by_ext_id(
250        backend: &dyn StorageBackend,
251        ext_id: &str,
252        version: Option<u64>,
253    ) -> Result<Option<Vid>> {
254        let table_name = table_names::main_vertex_table_name();
255
256        if !backend.table_exists(table_name).await? {
257            return Ok(None);
258        }
259
260        let mut filter = format!(
261            "ext_id = '{}' AND _deleted = false",
262            ext_id.replace('\'', "''")
263        );
264        if let Some(hwm) = version {
265            filter.push_str(&format!(" AND _version <= {}", hwm));
266        }
267
268        let results = backend
269            .scan(
270                ScanRequest::all(table_name)
271                    .with_filter(filter)
272                    .with_columns(vec!["_vid".to_string()]),
273            )
274            .await?;
275
276        for batch in results {
277            if batch.num_rows() > 0
278                && let Some(vid_col) = batch.column_by_name("_vid")
279                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
280            {
281                return Ok(Some(Vid::from(vid_arr.value(0))));
282            }
283        }
284
285        Ok(None)
286    }
287
288    /// Check if an ext_id already exists in the main vertices table.
289    ///
290    /// # Arguments
291    /// * `version` - Optional version high water mark for snapshot isolation.
292    pub async fn ext_id_exists(
293        backend: &dyn StorageBackend,
294        ext_id: &str,
295        version: Option<u64>,
296    ) -> Result<bool> {
297        Ok(Self::find_by_ext_id(backend, ext_id, version)
298            .await?
299            .is_some())
300    }
301
302    /// Find labels for a vertex by VID in the main vertices table.
303    ///
304    /// Returns the list of labels if found, None otherwise.
305    ///
306    /// # Arguments
307    /// * `version` - Optional version high water mark for snapshot isolation.
308    pub async fn find_labels_by_vid(
309        backend: &dyn StorageBackend,
310        vid: Vid,
311        version: Option<u64>,
312    ) -> Result<Option<Vec<String>>> {
313        let table_name = table_names::main_vertex_table_name();
314
315        if !backend.table_exists(table_name).await? {
316            return Ok(None);
317        }
318
319        let mut filter = format!("_vid = {} AND _deleted = false", vid.as_u64());
320        if let Some(hwm) = version {
321            filter.push_str(&format!(" AND _version <= {}", hwm));
322        }
323
324        let results = backend
325            .scan(
326                ScanRequest::all(table_name)
327                    .with_filter(filter)
328                    .with_columns(vec!["labels".to_string()]),
329            )
330            .await?;
331
332        for batch in results {
333            if batch.num_rows() > 0
334                && let Some(labels_col) = batch.column_by_name("labels")
335                && let Some(list_arr) = labels_col.as_any().downcast_ref::<arrow_array::ListArray>()
336            {
337                // Labels is a List<Utf8> column
338                let values = list_arr.value(0);
339                if let Some(str_arr) = values.as_any().downcast_ref::<arrow_array::StringArray>() {
340                    let labels: Vec<String> = (0..str_arr.len())
341                        .filter_map(|i| {
342                            if str_arr.is_null(i) {
343                                None
344                            } else {
345                                Some(str_arr.value(i).to_string())
346                            }
347                        })
348                        .collect();
349                    return Ok(Some(labels));
350                }
351            }
352        }
353
354        Ok(None)
355    }
356
357    /// Find all non-deleted VIDs in the main vertices table.
358    ///
359    /// Returns all VIDs where `_deleted = false`.
360    ///
361    /// # Arguments
362    /// * `version` - Optional version high water mark for snapshot isolation.
363    ///
364    /// # Errors
365    ///
366    /// Returns an error if the table query fails.
367    pub async fn find_all_vids(
368        backend: &dyn StorageBackend,
369        version: Option<u64>,
370    ) -> Result<Vec<Vid>> {
371        let table_name = table_names::main_vertex_table_name();
372
373        if !backend.table_exists(table_name).await? {
374            return Ok(Vec::new());
375        }
376
377        let mut filter = "_deleted = false".to_string();
378        if let Some(hwm) = version {
379            filter.push_str(&format!(" AND _version <= {}", hwm));
380        }
381
382        let results = backend
383            .scan(
384                ScanRequest::all(table_name)
385                    .with_filter(filter)
386                    .with_columns(vec!["_vid".to_string()]),
387            )
388            .await?;
389
390        let mut vids = Vec::new();
391        for batch in results {
392            if let Some(vid_col) = batch.column_by_name("_vid")
393                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
394            {
395                for i in 0..vid_arr.len() {
396                    if !vid_arr.is_null(i) {
397                        vids.push(Vid::new(vid_arr.value(i)));
398                    }
399                }
400            }
401        }
402
403        Ok(vids)
404    }
405
406    /// Find VIDs by label name in the main vertices table.
407    ///
408    /// Searches for vertices where the labels array contains the given label
409    /// and `_deleted = false`.
410    ///
411    /// # Arguments
412    /// * `version` - Optional version high water mark for snapshot isolation.
413    ///
414    /// # Errors
415    ///
416    /// Returns an error if the table query fails.
417    pub async fn find_vids_by_label_name(
418        backend: &dyn StorageBackend,
419        label: &str,
420        version: Option<u64>,
421    ) -> Result<Vec<Vid>> {
422        let table_name = table_names::main_vertex_table_name();
423
424        if !backend.table_exists(table_name).await? {
425            return Ok(Vec::new());
426        }
427
428        // Use SQL array_contains to filter by label
429        let mut filter = format!("_deleted = false AND array_contains(labels, '{}')", label);
430        if let Some(hwm) = version {
431            filter.push_str(&format!(" AND _version <= {}", hwm));
432        }
433
434        let results = backend
435            .scan(
436                ScanRequest::all(table_name)
437                    .with_filter(filter)
438                    .with_columns(vec!["_vid".to_string()]),
439            )
440            .await?;
441
442        let mut vids = Vec::new();
443        for batch in results {
444            if let Some(vid_col) = batch.column_by_name("_vid")
445                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
446            {
447                for i in 0..vid_arr.len() {
448                    if !vid_arr.is_null(i) {
449                        vids.push(Vid::new(vid_arr.value(i)));
450                    }
451                }
452            }
453        }
454
455        Ok(vids)
456    }
457
458    /// Find VIDs by multiple label names (intersection semantics).
459    ///
460    /// Returns vertices that have ALL the specified labels.
461    /// Uses `array_contains(labels, 'A') AND array_contains(labels, 'B')` filtering.
462    ///
463    /// # Arguments
464    /// * `version` - Optional version high water mark for snapshot isolation.
465    pub async fn find_vids_by_labels(
466        backend: &dyn StorageBackend,
467        labels: &[&str],
468        version: Option<u64>,
469    ) -> Result<Vec<Vid>> {
470        let table_name = table_names::main_vertex_table_name();
471
472        if labels.is_empty() || !backend.table_exists(table_name).await? {
473            return Ok(Vec::new());
474        }
475
476        // Build AND conditions for each label
477        let label_conditions: Vec<String> = labels
478            .iter()
479            .map(|label| {
480                let escaped = label.replace('\'', "''");
481                format!("array_contains(labels, '{}')", escaped)
482            })
483            .collect();
484
485        let mut filter = format!("_deleted = false AND {}", label_conditions.join(" AND "));
486        if let Some(hwm) = version {
487            filter.push_str(&format!(" AND _version <= {}", hwm));
488        }
489
490        let results = backend
491            .scan(
492                ScanRequest::all(table_name)
493                    .with_filter(filter)
494                    .with_columns(vec!["_vid".to_string()]),
495            )
496            .await?;
497
498        let mut vids = Vec::new();
499        for batch in results {
500            if let Some(vid_col) = batch.column_by_name("_vid")
501                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
502            {
503                for i in 0..vid_arr.len() {
504                    if !vid_arr.is_null(i) {
505                        vids.push(Vid::new(vid_arr.value(i)));
506                    }
507                }
508            }
509        }
510
511        Ok(vids)
512    }
513
514    /// Batch-fetch properties for multiple VIDs from the main vertices table.
515    ///
516    /// Returns a HashMap mapping VIDs to their parsed properties.
517    /// Non-deleted vertices are returned with properties from props_json.
518    /// This is used for schemaless vertex scans via DataFusion.
519    ///
520    /// # Arguments
521    /// * `version` - Optional version high water mark for snapshot isolation.
522    ///
523    /// # Errors
524    ///
525    /// Returns an error if the table query fails or JSON parsing fails.
526    pub async fn find_batch_props_by_vids(
527        backend: &dyn StorageBackend,
528        vids: &[Vid],
529        version: Option<u64>,
530    ) -> Result<HashMap<Vid, Properties>> {
531        let table_name = table_names::main_vertex_table_name();
532
533        if vids.is_empty() || !backend.table_exists(table_name).await? {
534            return Ok(HashMap::new());
535        }
536
537        // Build IN clause for VIDs
538        let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
539        let mut filter = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
540        if let Some(hwm) = version {
541            filter.push_str(&format!(" AND _version <= {}", hwm));
542        }
543
544        let results = backend
545            .scan(
546                ScanRequest::all(table_name)
547                    .with_filter(filter)
548                    .with_columns(vec!["_vid".to_string(), "props_json".to_string()]),
549            )
550            .await?;
551
552        let mut props_map = HashMap::new();
553
554        for batch in results {
555            let vid_col = batch.column_by_name("_vid");
556            let props_col = batch.column_by_name("props_json");
557
558            if let (Some(vid_arr), Some(props_arr)) = (
559                vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
560                props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
561            ) {
562                for i in 0..batch.num_rows() {
563                    if vid_arr.is_null(i) {
564                        continue;
565                    }
566                    let vid = Vid::new(vid_arr.value(i));
567
568                    let props: Properties = if props_arr.is_null(i) || props_arr.value(i).is_empty()
569                    {
570                        Properties::new()
571                    } else {
572                        let bytes = props_arr.value(i);
573                        let uni_val = uni_common::cypher_value_codec::decode(bytes)
574                            .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
575                        let json_val: serde_json::Value = uni_val.into();
576                        serde_json::from_value(json_val)
577                            .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?
578                    };
579
580                    props_map.insert(vid, props);
581                }
582            }
583        }
584
585        Ok(props_map)
586    }
587
588    /// Find properties for a vertex by VID in the main vertices table.
589    ///
590    /// Returns the props_json parsed into a Properties HashMap if found.
591    /// This is used as a fallback for unknown/schemaless labels.
592    ///
593    /// # Arguments
594    /// * `version` - Optional version high water mark for snapshot isolation.
595    ///
596    /// # Errors
597    ///
598    /// Returns an error if the table query fails or JSON parsing fails.
599    pub async fn find_props_by_vid(
600        backend: &dyn StorageBackend,
601        vid: Vid,
602        version: Option<u64>,
603    ) -> Result<Option<Properties>> {
604        let table_name = table_names::main_vertex_table_name();
605
606        if !backend.table_exists(table_name).await? {
607            return Ok(None);
608        }
609
610        let mut filter = format!("_vid = {} AND _deleted = false", vid.as_u64());
611        if let Some(hwm) = version {
612            filter.push_str(&format!(" AND _version <= {}", hwm));
613        }
614
615        let results = backend
616            .scan(
617                ScanRequest::all(table_name)
618                    .with_filter(filter)
619                    .with_columns(vec!["props_json".to_string(), "_version".to_string()]),
620            )
621            .await?;
622
623        // Find the row with highest version (latest)
624        let mut best_props: Option<Properties> = None;
625        let mut best_version: u64 = 0;
626
627        for batch in results {
628            let props_col = batch.column_by_name("props_json");
629            let version_col = batch.column_by_name("_version");
630
631            if let (Some(props_arr), Some(ver_arr)) = (
632                props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
633                version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
634            ) {
635                for i in 0..batch.num_rows() {
636                    let version = if ver_arr.is_null(i) {
637                        0
638                    } else {
639                        ver_arr.value(i)
640                    };
641
642                    if version >= best_version {
643                        best_version = version;
644                        if props_arr.is_null(i) || props_arr.value(i).is_empty() {
645                            best_props = Some(Properties::new());
646                        } else {
647                            let bytes = props_arr.value(i);
648                            let uni_val = uni_common::cypher_value_codec::decode(bytes)
649                                .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
650                            let json_val: serde_json::Value = uni_val.into();
651                            let parsed: Properties = serde_json::from_value(json_val)
652                                .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?;
653                            best_props = Some(parsed);
654                        }
655                    }
656                }
657            }
658        }
659
660        Ok(best_props)
661    }
662
663    /// Batch-fetch labels for multiple VIDs from the main vertices table.
664    ///
665    /// # Arguments
666    /// * `version` - Optional version high water mark for snapshot isolation.
667    pub async fn find_batch_labels_by_vids(
668        backend: &dyn StorageBackend,
669        vids: &[Vid],
670        version: Option<u64>,
671    ) -> Result<HashMap<Vid, Vec<String>>> {
672        let table_name = table_names::main_vertex_table_name();
673
674        if vids.is_empty() || !backend.table_exists(table_name).await? {
675            return Ok(HashMap::new());
676        }
677
678        // Build IN clause for VIDs
679        let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
680        let mut filter = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
681        if let Some(hwm) = version {
682            filter.push_str(&format!(" AND _version <= {}", hwm));
683        }
684
685        let results = backend
686            .scan(
687                ScanRequest::all(table_name)
688                    .with_filter(filter)
689                    .with_columns(vec!["_vid".to_string(), "labels".to_string()]),
690            )
691            .await?;
692
693        let mut label_map = HashMap::new();
694
695        for batch in results {
696            let vid_col = batch.column_by_name("_vid");
697            let labels_col = batch.column_by_name("labels");
698
699            if let (Some(vid_arr), Some(labels_arr)) = (
700                vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
701                labels_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::ListArray>()),
702            ) {
703                for i in 0..batch.num_rows() {
704                    if vid_arr.is_null(i) {
705                        continue;
706                    }
707                    let vid = Vid::new(vid_arr.value(i));
708
709                    let values = labels_arr.value(i);
710                    if let Some(str_arr) =
711                        values.as_any().downcast_ref::<arrow_array::StringArray>()
712                    {
713                        let labels: Vec<String> = (0..str_arr.len())
714                            .filter_map(|j| {
715                                if str_arr.is_null(j) {
716                                    None
717                                } else {
718                                    Some(str_arr.value(j).to_string())
719                                }
720                            })
721                            .collect();
722                        label_map.insert(vid, labels);
723                    }
724                }
725            }
726        }
727
728        Ok(label_map)
729    }
730}
731
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use uni_common::Value;

    /// Build a property map from `(key, string value)` pairs.
    fn string_props(pairs: &[(&str, &str)]) -> Properties {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), Value::String(v.to_string())))
            .collect()
    }

    #[test]
    fn test_main_vertex_schema() {
        let schema = MainVertexDataset::get_arrow_schema();
        assert_eq!(schema.fields().len(), 9);
        assert!(schema.field_with_name("_vid").is_ok());
        assert!(schema.field_with_name("_uid").is_ok());
        assert!(schema.field_with_name("ext_id").is_ok());
        assert!(schema.field_with_name("labels").is_ok());
        assert!(schema.field_with_name("props_json").is_ok());
        assert!(schema.field_with_name("_deleted").is_ok());
        assert!(schema.field_with_name("_version").is_ok());
        assert!(schema.field_with_name("_created_at").is_ok());
        assert!(schema.field_with_name("_updated_at").is_ok());
    }

    #[test]
    fn test_build_record_batch() {
        let props = string_props(&[("name", "Alice"), ("ext_id", "user_001")]);
        let vertices = vec![(Vid::new(1), vec!["Person".to_string()], props, false, 1u64)];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);

        // ext_id is extracted from the string-valued "ext_id" property.
        let ext_id_col = batch.column_by_name("ext_id").unwrap();
        let ext_id_arr = ext_id_col.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(ext_id_arr.value(0), "user_001");
    }

    #[test]
    fn test_build_record_batch_without_ext_id_is_null() {
        let props = string_props(&[("name", "Bob")]);
        let vertices = vec![(Vid::new(2), vec!["Person".to_string()], props, false, 1u64)];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        let ext_id_col = batch.column_by_name("ext_id").unwrap();
        let ext_id_arr = ext_id_col.as_any().downcast_ref::<StringArray>().unwrap();
        assert!(ext_id_arr.is_null(0));
    }

    #[test]
    fn test_build_record_batch_labels_column() {
        let vertices = vec![(
            Vid::new(3),
            vec!["Person".to_string(), "Employee".to_string()],
            Properties::new(),
            false,
            1u64,
        )];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        let labels_col = batch.column_by_name("labels").unwrap();
        let labels_arr = labels_col
            .as_any()
            .downcast_ref::<arrow_array::ListArray>()
            .unwrap();
        let values = labels_arr.value(0);
        let strings = values.as_any().downcast_ref::<StringArray>().unwrap();
        let labels: Vec<&str> = strings.iter().flatten().collect();
        // The stored column keeps the caller's order (only the UID hash sorts).
        assert_eq!(labels, vec!["Person", "Employee"]);
    }

    #[test]
    fn test_build_record_batch_empty() {
        let batch = MainVertexDataset::build_record_batch(&[], None, None).unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 9);
    }

    #[test]
    fn test_vertex_uid_ignores_label_order() {
        let props = string_props(&[("name", "Alice")]);
        let ab = MainVertexDataset::compute_vertex_uid(
            &["A".to_string(), "B".to_string()],
            None,
            &props,
        );
        let ba = MainVertexDataset::compute_vertex_uid(
            &["B".to_string(), "A".to_string()],
            None,
            &props,
        );
        assert_eq!(ab.as_bytes(), ba.as_bytes());
    }

    #[test]
    fn test_vertex_uid_depends_on_properties() {
        let labels = vec!["Person".to_string()];
        let alice =
            MainVertexDataset::compute_vertex_uid(&labels, None, &string_props(&[("name", "Alice")]));
        let bob =
            MainVertexDataset::compute_vertex_uid(&labels, None, &string_props(&[("name", "Bob")]));
        assert_ne!(alice.as_bytes(), bob.as_bytes());
    }
}