Skip to main content

uni_store/storage/
main_vertex.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Main vertex table for unified vertex storage.
5//!
6//! This module implements the main `vertices` table as described in STORAGE_DESIGN.md.
7//! The main table contains all vertices in the graph with:
8//! - `_vid`: Internal vertex ID (primary key)
9//! - `_uid`: Content-addressed unique ID (SHA3-256 hash)
10//! - `ext_id`: Optional external/user-provided ID (globally unique)
11//! - `labels`: List of label names (OpenCypher multi-label)
12//! - `props_json`: All properties as JSONB blob
13//! - `_deleted`: Soft-delete flag
14//! - `_version`: MVCC version
15//! - `_created_at`: Creation timestamp
16//! - `_updated_at`: Update timestamp
17
18use crate::lancedb::LanceDbStore;
19use crate::storage::arrow_convert::build_timestamp_column_from_vid_map;
20use anyhow::{Result, anyhow};
21use arrow_array::builder::{
22    FixedSizeBinaryBuilder, LargeBinaryBuilder, ListBuilder, StringBuilder,
23};
24use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
26use futures::TryStreamExt;
27use lancedb::Table;
28use lancedb::index::Index as LanceDbIndex;
29use lancedb::index::scalar::{BTreeIndexBuilder, LabelListIndexBuilder};
30use lancedb::query::{ExecutableQuery, QueryBase};
31use sha3::{Digest, Sha3_256};
32use std::collections::HashMap;
33use std::sync::Arc;
34use uni_common::Properties;
35use uni_common::core::id::{UniId, Vid};
36
/// Main vertex dataset for the unified `vertices` table.
///
/// This table contains all vertices regardless of label, providing:
/// - Fast ID-based lookups without knowing the label
/// - Global ext_id uniqueness enforcement
/// - Multi-label storage with labels as a list column
///
/// All table access goes through the associated functions, which take a
/// `LanceDbStore` explicitly.
#[derive(Debug, Clone)]
pub struct MainVertexDataset {
    /// Base URI supplied at construction; retained but not read by the
    /// table helpers in this module.
    _base_uri: String,
}
46
47impl MainVertexDataset {
48    /// Create a new MainVertexDataset.
49    pub fn new(base_uri: &str) -> Self {
50        Self {
51            _base_uri: base_uri.to_string(),
52        }
53    }
54
55    /// Get the Arrow schema for the main vertices table.
56    pub fn get_arrow_schema() -> Arc<ArrowSchema> {
57        Arc::new(ArrowSchema::new(vec![
58            Field::new("_vid", DataType::UInt64, false),
59            Field::new("_uid", DataType::FixedSizeBinary(32), true),
60            Field::new("ext_id", DataType::Utf8, true),
61            Field::new(
62                "labels",
63                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
64                false,
65            ),
66            Field::new("props_json", DataType::LargeBinary, true),
67            Field::new("_deleted", DataType::Boolean, false),
68            Field::new("_version", DataType::UInt64, false),
69            Field::new(
70                "_created_at",
71                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
72                true,
73            ),
74            Field::new(
75                "_updated_at",
76                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
77                true,
78            ),
79        ]))
80    }
81
82    /// Get the table name for the main vertices table.
83    pub fn table_name() -> &'static str {
84        "vertices"
85    }
86
87    /// Open the main vertices table.
88    ///
89    /// Returns the LanceDB table handle for querying vertices.
90    pub async fn open_table(store: &LanceDbStore) -> Result<Table> {
91        store
92            .open_table(Self::table_name())
93            .await
94            .map_err(|e| anyhow!("Failed to open main vertices table: {}", e))
95    }
96
97    /// Compute the UniId (content-addressed hash) for a vertex.
98    fn compute_vertex_uid(labels: &[String], ext_id: Option<&str>, props: &Properties) -> UniId {
99        let mut hasher = Sha3_256::new();
100
101        // Hash labels (sorted for consistency)
102        let mut sorted_labels = labels.to_vec();
103        sorted_labels.sort();
104        for label in &sorted_labels {
105            hasher.update(label.as_bytes());
106            hasher.update(b"\0");
107        }
108
109        // Hash ext_id if present
110        if let Some(ext_id) = ext_id {
111            hasher.update(b"ext_id:");
112            hasher.update(ext_id.as_bytes());
113            hasher.update(b"\0");
114        }
115
116        // Hash properties (sorted by key for deterministic hashing)
117        let mut sorted_keys: Vec<_> = props.keys().collect();
118        sorted_keys.sort();
119        for key in sorted_keys {
120            if key == "ext_id" {
121                continue; // Already handled above
122            }
123            if let Some(val) = props.get(key) {
124                hasher.update(key.as_bytes());
125                hasher.update(b":");
126                hasher.update(val.to_string().as_bytes());
127                hasher.update(b"\0");
128            }
129        }
130
131        let result = hasher.finalize();
132        UniId::from_bytes(result.into())
133    }
134
135    /// Build a record batch for the main vertices table.
136    ///
137    /// # Arguments
138    /// * `vertices` - List of (vid, labels, properties, deleted, version) tuples
139    /// * `created_at` - Optional map of Vid -> nanoseconds since epoch
140    /// * `updated_at` - Optional map of Vid -> nanoseconds since epoch
141    pub fn build_record_batch(
142        vertices: &[(Vid, Vec<String>, Properties, bool, u64)],
143        created_at: Option<&HashMap<Vid, i64>>,
144        updated_at: Option<&HashMap<Vid, i64>>,
145    ) -> Result<RecordBatch> {
146        let arrow_schema = Self::get_arrow_schema();
147        let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
148
149        // _vid column
150        let vids: Vec<u64> = vertices.iter().map(|(v, _, _, _, _)| v.as_u64()).collect();
151        columns.push(Arc::new(UInt64Array::from(vids)));
152
153        // _uid column
154        let mut uid_builder = FixedSizeBinaryBuilder::new(32);
155        for (_, labels, props, _, _) in vertices.iter() {
156            let ext_id = props.get("ext_id").and_then(|v| v.as_str());
157            let uid = Self::compute_vertex_uid(labels, ext_id, props);
158            uid_builder.append_value(uid.as_bytes())?;
159        }
160        columns.push(Arc::new(uid_builder.finish()));
161
162        // ext_id column
163        let mut ext_id_builder = StringBuilder::new();
164        for (_, _, props, _, _) in vertices.iter() {
165            if let Some(ext_id_val) = props.get("ext_id").and_then(|v| v.as_str()) {
166                ext_id_builder.append_value(ext_id_val);
167            } else {
168                ext_id_builder.append_null();
169            }
170        }
171        columns.push(Arc::new(ext_id_builder.finish()));
172
173        // labels column (List<String>)
174        let mut labels_builder = ListBuilder::new(StringBuilder::new());
175        for (_, labels, _, _, _) in vertices.iter() {
176            let values_builder = labels_builder.values();
177            for label in labels {
178                values_builder.append_value(label);
179            }
180            labels_builder.append(true);
181        }
182        columns.push(Arc::new(labels_builder.finish()));
183
184        // props_json column (JSONB binary encoding)
185        let mut props_json_builder = LargeBinaryBuilder::new();
186        for (_, _, props, _, _) in vertices.iter() {
187            let jsonb_bytes = {
188                let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
189                let uni_val: uni_common::Value = json_val.into();
190                uni_common::cypher_value_codec::encode(&uni_val)
191            };
192            props_json_builder.append_value(&jsonb_bytes);
193        }
194        columns.push(Arc::new(props_json_builder.finish()));
195
196        // _deleted column
197        let deleted: Vec<bool> = vertices.iter().map(|(_, _, _, d, _)| *d).collect();
198        columns.push(Arc::new(BooleanArray::from(deleted)));
199
200        // _version column
201        let versions: Vec<u64> = vertices.iter().map(|(_, _, _, _, v)| *v).collect();
202        columns.push(Arc::new(UInt64Array::from(versions)));
203
204        // _created_at and _updated_at columns using shared builder
205        let vids = vertices.iter().map(|(v, _, _, _, _)| *v);
206        columns.push(build_timestamp_column_from_vid_map(
207            vids.clone(),
208            created_at,
209        ));
210        columns.push(build_timestamp_column_from_vid_map(vids, updated_at));
211
212        RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
213    }
214
215    /// Write a batch to the main vertices table.
216    ///
217    /// Creates the table if it doesn't exist, otherwise appends to it.
218    pub async fn write_batch_lancedb(store: &LanceDbStore, batch: RecordBatch) -> Result<Table> {
219        let table_name = Self::table_name();
220
221        if store.table_exists(table_name).await? {
222            let table = store.open_table(table_name).await?;
223            store.append_to_table(&table, vec![batch]).await?;
224            Ok(table)
225        } else {
226            store.create_table(table_name, vec![batch]).await
227        }
228    }
229
230    /// Ensure default indexes exist on the main vertices table.
231    pub async fn ensure_default_indexes_lancedb(table: &Table) -> Result<()> {
232        let indices = table
233            .list_indices()
234            .await
235            .map_err(|e| anyhow!("Failed to list indices: {}", e))?;
236
237        // Ensure _vid index (primary key)
238        if !indices
239            .iter()
240            .any(|idx| idx.columns.contains(&"_vid".to_string()))
241        {
242            log::info!("Creating _vid BTree index for main vertices table");
243            if let Err(e) = table
244                .create_index(&["_vid"], LanceDbIndex::BTree(BTreeIndexBuilder::default()))
245                .execute()
246                .await
247            {
248                log::warn!("Failed to create _vid index for main vertices: {}", e);
249            }
250        }
251
252        // Ensure ext_id index (unique lookup)
253        if !indices
254            .iter()
255            .any(|idx| idx.columns.contains(&"ext_id".to_string()))
256        {
257            log::info!("Creating ext_id BTree index for main vertices table");
258            if let Err(e) = table
259                .create_index(
260                    &["ext_id"],
261                    LanceDbIndex::BTree(BTreeIndexBuilder::default()),
262                )
263                .execute()
264                .await
265            {
266                log::warn!("Failed to create ext_id index for main vertices: {}", e);
267            }
268        }
269
270        // Ensure _uid index
271        if !indices
272            .iter()
273            .any(|idx| idx.columns.contains(&"_uid".to_string()))
274        {
275            log::info!("Creating _uid BTree index for main vertices table");
276            if let Err(e) = table
277                .create_index(&["_uid"], LanceDbIndex::BTree(BTreeIndexBuilder::default()))
278                .execute()
279                .await
280            {
281                log::warn!("Failed to create _uid index for main vertices: {}", e);
282            }
283        }
284
285        // Ensure labels LABEL_LIST index (for array_contains() queries)
286        if !indices
287            .iter()
288            .any(|idx| idx.columns.contains(&"labels".to_string()))
289        {
290            log::info!("Creating labels LABEL_LIST index for main vertices table");
291            if let Err(e) = table
292                .create_index(
293                    &["labels"],
294                    LanceDbIndex::LabelList(LabelListIndexBuilder::default()),
295                )
296                .execute()
297                .await
298            {
299                log::warn!("Failed to create labels index for main vertices: {}", e);
300            }
301        }
302
303        Ok(())
304    }
305
306    /// Query the main vertices table for a vertex by ext_id.
307    ///
308    /// Returns the Vid if found, None otherwise.
309    ///
310    /// # Arguments
311    /// * `version` - Optional version high water mark for snapshot isolation.
312    ///   Pass `None` for writer uniqueness checks (global visibility).
313    ///   Pass `Some(hwm)` for query-time snapshot isolation.
314    pub async fn find_by_ext_id(
315        store: &LanceDbStore,
316        ext_id: &str,
317        version: Option<u64>,
318    ) -> Result<Option<Vid>> {
319        let table_name = Self::table_name();
320
321        if !store.table_exists(table_name).await? {
322            return Ok(None);
323        }
324
325        let table = store.open_table(table_name).await?;
326        let mut query = format!(
327            "ext_id = '{}' AND _deleted = false",
328            ext_id.replace('\'', "''")
329        );
330        if let Some(hwm) = version {
331            query.push_str(&format!(" AND _version <= {}", hwm));
332        }
333
334        let batches = table
335            .query()
336            .only_if(query)
337            .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
338            .execute()
339            .await
340            .map_err(|e| anyhow!("Query failed: {}", e))?;
341
342        let results: Vec<RecordBatch> = batches.try_collect().await?;
343
344        for batch in results {
345            if batch.num_rows() > 0
346                && let Some(vid_col) = batch.column_by_name("_vid")
347                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
348            {
349                return Ok(Some(Vid::from(vid_arr.value(0))));
350            }
351        }
352
353        Ok(None)
354    }
355
356    /// Check if an ext_id already exists in the main vertices table.
357    ///
358    /// # Arguments
359    /// * `version` - Optional version high water mark for snapshot isolation.
360    pub async fn ext_id_exists(
361        store: &LanceDbStore,
362        ext_id: &str,
363        version: Option<u64>,
364    ) -> Result<bool> {
365        Ok(Self::find_by_ext_id(store, ext_id, version)
366            .await?
367            .is_some())
368    }
369
370    /// Find labels for a vertex by VID in the main vertices table.
371    ///
372    /// Returns the list of labels if found, None otherwise.
373    ///
374    /// # Arguments
375    /// * `version` - Optional version high water mark for snapshot isolation.
376    pub async fn find_labels_by_vid(
377        store: &LanceDbStore,
378        vid: Vid,
379        version: Option<u64>,
380    ) -> Result<Option<Vec<String>>> {
381        let table_name = Self::table_name();
382
383        if !store.table_exists(table_name).await? {
384            return Ok(None);
385        }
386
387        let table = store.open_table(table_name).await?;
388        let mut query = format!("_vid = {} AND _deleted = false", vid.as_u64());
389        if let Some(hwm) = version {
390            query.push_str(&format!(" AND _version <= {}", hwm));
391        }
392
393        let batches = table
394            .query()
395            .only_if(query)
396            .select(lancedb::query::Select::Columns(vec!["labels".to_string()]))
397            .execute()
398            .await
399            .map_err(|e| anyhow!("Query failed: {}", e))?;
400
401        let results: Vec<RecordBatch> = batches.try_collect().await?;
402
403        for batch in results {
404            if batch.num_rows() > 0
405                && let Some(labels_col) = batch.column_by_name("labels")
406                && let Some(list_arr) = labels_col.as_any().downcast_ref::<arrow_array::ListArray>()
407            {
408                // Labels is a List<Utf8> column
409                let values = list_arr.value(0);
410                if let Some(str_arr) = values.as_any().downcast_ref::<arrow_array::StringArray>() {
411                    let labels: Vec<String> = (0..str_arr.len())
412                        .filter_map(|i| {
413                            if str_arr.is_null(i) {
414                                None
415                            } else {
416                                Some(str_arr.value(i).to_string())
417                            }
418                        })
419                        .collect();
420                    return Ok(Some(labels));
421                }
422            }
423        }
424
425        Ok(None)
426    }
427
428    /// Find all non-deleted VIDs in the main vertices table.
429    ///
430    /// Returns all VIDs where `_deleted = false`.
431    ///
432    /// # Arguments
433    /// * `version` - Optional version high water mark for snapshot isolation.
434    ///
435    /// # Errors
436    ///
437    /// Returns an error if the table query fails.
438    pub async fn find_all_vids(store: &LanceDbStore, version: Option<u64>) -> Result<Vec<Vid>> {
439        let table_name = Self::table_name();
440
441        if !store.table_exists(table_name).await? {
442            return Ok(Vec::new());
443        }
444
445        let table = store.open_table(table_name).await?;
446        let mut query = "_deleted = false".to_string();
447        if let Some(hwm) = version {
448            query.push_str(&format!(" AND _version <= {}", hwm));
449        }
450
451        let batches = table
452            .query()
453            .only_if(query)
454            .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
455            .execute()
456            .await
457            .map_err(|e| anyhow!("Query failed: {}", e))?;
458
459        let results: Vec<RecordBatch> = batches.try_collect().await?;
460
461        let mut vids = Vec::new();
462        for batch in results {
463            if let Some(vid_col) = batch.column_by_name("_vid")
464                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
465            {
466                for i in 0..vid_arr.len() {
467                    if !vid_arr.is_null(i) {
468                        vids.push(Vid::new(vid_arr.value(i)));
469                    }
470                }
471            }
472        }
473
474        Ok(vids)
475    }
476
477    /// Find VIDs by label name in the main vertices table.
478    ///
479    /// Searches for vertices where the labels array contains the given label
480    /// and `_deleted = false`.
481    ///
482    /// # Arguments
483    /// * `version` - Optional version high water mark for snapshot isolation.
484    ///
485    /// # Errors
486    ///
487    /// Returns an error if the table query fails.
488    pub async fn find_vids_by_label_name(
489        store: &LanceDbStore,
490        label: &str,
491        version: Option<u64>,
492    ) -> Result<Vec<Vid>> {
493        let table_name = Self::table_name();
494
495        if !store.table_exists(table_name).await? {
496            return Ok(Vec::new());
497        }
498
499        let table = store.open_table(table_name).await?;
500        // Use SQL array_contains to filter by label
501        let mut query = format!("_deleted = false AND array_contains(labels, '{}')", label);
502        if let Some(hwm) = version {
503            query.push_str(&format!(" AND _version <= {}", hwm));
504        }
505
506        let batches = table
507            .query()
508            .only_if(query)
509            .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
510            .execute()
511            .await
512            .map_err(|e| anyhow!("Query failed: {}", e))?;
513
514        let results: Vec<RecordBatch> = batches.try_collect().await?;
515
516        let mut vids = Vec::new();
517        for batch in results {
518            if let Some(vid_col) = batch.column_by_name("_vid")
519                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
520            {
521                for i in 0..vid_arr.len() {
522                    if !vid_arr.is_null(i) {
523                        vids.push(Vid::new(vid_arr.value(i)));
524                    }
525                }
526            }
527        }
528
529        Ok(vids)
530    }
531
532    /// Find VIDs by multiple label names (intersection semantics).
533    ///
534    /// Returns vertices that have ALL the specified labels.
535    /// Uses `array_contains(labels, 'A') AND array_contains(labels, 'B')` filtering.
536    ///
537    /// # Arguments
538    /// * `version` - Optional version high water mark for snapshot isolation.
539    pub async fn find_vids_by_labels(
540        store: &LanceDbStore,
541        labels: &[&str],
542        version: Option<u64>,
543    ) -> Result<Vec<Vid>> {
544        let table_name = Self::table_name();
545
546        if labels.is_empty() || !store.table_exists(table_name).await? {
547            return Ok(Vec::new());
548        }
549
550        let table = store.open_table(table_name).await?;
551
552        // Build AND conditions for each label
553        let label_conditions: Vec<String> = labels
554            .iter()
555            .map(|label| {
556                let escaped = label.replace('\'', "''");
557                format!("array_contains(labels, '{}')", escaped)
558            })
559            .collect();
560
561        let mut query = format!("_deleted = false AND {}", label_conditions.join(" AND "));
562        if let Some(hwm) = version {
563            query.push_str(&format!(" AND _version <= {}", hwm));
564        }
565
566        let batches = table
567            .query()
568            .only_if(query)
569            .select(lancedb::query::Select::Columns(vec!["_vid".to_string()]))
570            .execute()
571            .await
572            .map_err(|e| anyhow!("Query failed: {}", e))?;
573
574        let results: Vec<RecordBatch> = batches.try_collect().await?;
575
576        let mut vids = Vec::new();
577        for batch in results {
578            if let Some(vid_col) = batch.column_by_name("_vid")
579                && let Some(vid_arr) = vid_col.as_any().downcast_ref::<UInt64Array>()
580            {
581                for i in 0..vid_arr.len() {
582                    if !vid_arr.is_null(i) {
583                        vids.push(Vid::new(vid_arr.value(i)));
584                    }
585                }
586            }
587        }
588
589        Ok(vids)
590    }
591
592    /// Batch-fetch properties for multiple VIDs from the main vertices table.
593    ///
594    /// Returns a HashMap mapping VIDs to their parsed properties.
595    /// Non-deleted vertices are returned with properties from props_json.
596    /// This is used for schemaless vertex scans via DataFusion.
597    ///
598    /// # Arguments
599    /// * `version` - Optional version high water mark for snapshot isolation.
600    ///
601    /// # Errors
602    ///
603    /// Returns an error if the table query fails or JSON parsing fails.
604    pub async fn find_batch_props_by_vids(
605        store: &LanceDbStore,
606        vids: &[Vid],
607        version: Option<u64>,
608    ) -> Result<HashMap<Vid, Properties>> {
609        let table_name = Self::table_name();
610
611        if vids.is_empty() || !store.table_exists(table_name).await? {
612            return Ok(HashMap::new());
613        }
614
615        let table = store.open_table(table_name).await?;
616
617        // Build IN clause for VIDs
618        let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
619        let mut query = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
620        if let Some(hwm) = version {
621            query.push_str(&format!(" AND _version <= {}", hwm));
622        }
623
624        let batches = table
625            .query()
626            .only_if(query)
627            .select(lancedb::query::Select::Columns(vec![
628                "_vid".to_string(),
629                "props_json".to_string(),
630            ]))
631            .execute()
632            .await
633            .map_err(|e| anyhow!("Query failed: {}", e))?;
634
635        let results: Vec<RecordBatch> = batches.try_collect().await?;
636
637        let mut props_map = HashMap::new();
638
639        for batch in results {
640            let vid_col = batch.column_by_name("_vid");
641            let props_col = batch.column_by_name("props_json");
642
643            if let (Some(vid_arr), Some(props_arr)) = (
644                vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
645                props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
646            ) {
647                for i in 0..batch.num_rows() {
648                    if vid_arr.is_null(i) {
649                        continue;
650                    }
651                    let vid = Vid::new(vid_arr.value(i));
652
653                    let props: Properties = if props_arr.is_null(i) || props_arr.value(i).is_empty()
654                    {
655                        Properties::new()
656                    } else {
657                        let bytes = props_arr.value(i);
658                        let uni_val = uni_common::cypher_value_codec::decode(bytes)
659                            .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
660                        let json_val: serde_json::Value = uni_val.into();
661                        serde_json::from_value(json_val)
662                            .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?
663                    };
664
665                    props_map.insert(vid, props);
666                }
667            }
668        }
669
670        Ok(props_map)
671    }
672
673    /// Find properties for a vertex by VID in the main vertices table.
674    ///
675    /// Returns the props_json parsed into a Properties HashMap if found.
676    /// This is used as a fallback for unknown/schemaless labels.
677    ///
678    /// # Arguments
679    /// * `version` - Optional version high water mark for snapshot isolation.
680    ///
681    /// # Errors
682    ///
683    /// Returns an error if the table query fails or JSON parsing fails.
684    pub async fn find_props_by_vid(
685        store: &LanceDbStore,
686        vid: Vid,
687        version: Option<u64>,
688    ) -> Result<Option<Properties>> {
689        let table_name = Self::table_name();
690
691        if !store.table_exists(table_name).await? {
692            return Ok(None);
693        }
694
695        let table = store.open_table(table_name).await?;
696        let mut query = format!("_vid = {} AND _deleted = false", vid.as_u64());
697        if let Some(hwm) = version {
698            query.push_str(&format!(" AND _version <= {}", hwm));
699        }
700
701        let batches = table
702            .query()
703            .only_if(query)
704            .select(lancedb::query::Select::Columns(vec![
705                "props_json".to_string(),
706                "_version".to_string(),
707            ]))
708            .execute()
709            .await
710            .map_err(|e| anyhow!("Query failed: {}", e))?;
711
712        let results: Vec<RecordBatch> = batches.try_collect().await?;
713
714        // Find the row with highest version (latest)
715        let mut best_props: Option<Properties> = None;
716        let mut best_version: u64 = 0;
717
718        for batch in results {
719            let props_col = batch.column_by_name("props_json");
720            let version_col = batch.column_by_name("_version");
721
722            if let (Some(props_arr), Some(ver_arr)) = (
723                props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
724                version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
725            ) {
726                for i in 0..batch.num_rows() {
727                    let version = if ver_arr.is_null(i) {
728                        0
729                    } else {
730                        ver_arr.value(i)
731                    };
732
733                    if version >= best_version {
734                        best_version = version;
735                        if props_arr.is_null(i) || props_arr.value(i).is_empty() {
736                            best_props = Some(Properties::new());
737                        } else {
738                            let bytes = props_arr.value(i);
739                            let uni_val = uni_common::cypher_value_codec::decode(bytes)
740                                .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
741                            let json_val: serde_json::Value = uni_val.into();
742                            let parsed: Properties = serde_json::from_value(json_val)
743                                .map_err(|e| anyhow!("Failed to parse props_json: {}", e))?;
744                            best_props = Some(parsed);
745                        }
746                    }
747                }
748            }
749        }
750
751        Ok(best_props)
752    }
753
754    /// Batch-fetch labels for multiple VIDs from the main vertices table.
755    ///
756    /// # Arguments
757    /// * `version` - Optional version high water mark for snapshot isolation.
758    pub async fn find_batch_labels_by_vids(
759        store: &LanceDbStore,
760        vids: &[Vid],
761        version: Option<u64>,
762    ) -> Result<HashMap<Vid, Vec<String>>> {
763        let table_name = Self::table_name();
764
765        if vids.is_empty() || !store.table_exists(table_name).await? {
766            return Ok(HashMap::new());
767        }
768
769        let table = store.open_table(table_name).await?;
770
771        // Build IN clause for VIDs
772        let vid_list: Vec<String> = vids.iter().map(|v| v.as_u64().to_string()).collect();
773        let mut query = format!("_vid IN ({}) AND _deleted = false", vid_list.join(", "));
774        if let Some(hwm) = version {
775            query.push_str(&format!(" AND _version <= {}", hwm));
776        }
777
778        let batches = table
779            .query()
780            .only_if(query)
781            .select(lancedb::query::Select::Columns(vec![
782                "_vid".to_string(),
783                "labels".to_string(),
784            ]))
785            .execute()
786            .await
787            .map_err(|e| anyhow!("Query failed: {}", e))?;
788
789        let results: Vec<RecordBatch> = batches.try_collect().await?;
790
791        let mut label_map = HashMap::new();
792
793        for batch in results {
794            let vid_col = batch.column_by_name("_vid");
795            let labels_col = batch.column_by_name("labels");
796
797            if let (Some(vid_arr), Some(labels_arr)) = (
798                vid_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
799                labels_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::ListArray>()),
800            ) {
801                for i in 0..batch.num_rows() {
802                    if vid_arr.is_null(i) {
803                        continue;
804                    }
805                    let vid = Vid::new(vid_arr.value(i));
806
807                    let values = labels_arr.value(i);
808                    if let Some(str_arr) =
809                        values.as_any().downcast_ref::<arrow_array::StringArray>()
810                    {
811                        let labels: Vec<String> = (0..str_arr.len())
812                            .filter_map(|j| {
813                                if str_arr.is_null(j) {
814                                    None
815                                } else {
816                                    Some(str_arr.value(j).to_string())
817                                }
818                            })
819                            .collect();
820                        label_map.insert(vid, labels);
821                    }
822                }
823            }
824        }
825
826        Ok(label_map)
827    }
828}
829
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{FixedSizeBinaryArray, StringArray};
    use uni_common::Value;

    /// Build a property map with a `name` and, optionally, an `ext_id`.
    fn sample_props(name: &str, ext_id: Option<&str>) -> Properties {
        let mut props = Properties::new();
        props.insert("name".to_string(), Value::String(name.to_string()));
        if let Some(ext_id) = ext_id {
            props.insert("ext_id".to_string(), Value::String(ext_id.to_string()));
        }
        props
    }

    /// Fetch the `_uid` column of a batch as fixed-size binary.
    fn uid_column(batch: &RecordBatch) -> &FixedSizeBinaryArray {
        batch
            .column_by_name("_uid")
            .unwrap()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap()
    }

    #[test]
    fn test_main_vertex_schema() {
        let schema = MainVertexDataset::get_arrow_schema();
        // build_record_batch fills columns positionally, so order is part of the contract.
        let expected = [
            "_vid",
            "_uid",
            "ext_id",
            "labels",
            "props_json",
            "_deleted",
            "_version",
            "_created_at",
            "_updated_at",
        ];
        let actual: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_build_record_batch() {
        let vertices = vec![(
            Vid::new(1),
            vec!["Person".to_string()],
            sample_props("Alice", Some("user_001")),
            false,
            1u64,
        )];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);

        // ext_id is lifted out of the property map into its own column.
        let ext_ids = batch
            .column_by_name("ext_id")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(ext_ids.value(0), "user_001");
    }

    #[test]
    fn test_uid_ignores_label_order_but_tracks_ext_id() {
        let vertices = vec![
            (
                Vid::new(1),
                vec!["A".to_string(), "B".to_string()],
                sample_props("Bob", Some("x1")),
                false,
                1u64,
            ),
            (
                Vid::new(2),
                vec!["B".to_string(), "A".to_string()],
                sample_props("Bob", Some("x1")),
                false,
                1u64,
            ),
            (
                Vid::new(3),
                vec!["A".to_string(), "B".to_string()],
                sample_props("Bob", Some("x2")),
                false,
                1u64,
            ),
        ];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();
        let uids = uid_column(&batch);
        assert_eq!(uids.value(0), uids.value(1));
        assert_ne!(uids.value(0), uids.value(2));
    }

    #[test]
    fn test_missing_ext_id_is_null() {
        let vertices = vec![(
            Vid::new(7),
            vec!["Thing".to_string()],
            sample_props("widget", None),
            true,
            42u64,
        )];

        let batch = MainVertexDataset::build_record_batch(&vertices, None, None).unwrap();

        let ext_ids = batch
            .column_by_name("ext_id")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert!(ext_ids.is_null(0));

        let deleted = batch
            .column_by_name("_deleted")
            .unwrap()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(deleted.value(0));

        let versions = batch
            .column_by_name("_version")
            .unwrap()
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        assert_eq!(versions.value(0), 42);
    }
}
868}