1use crate::lancedb::LanceDbStore;
19use crate::storage::arrow_convert::build_timestamp_column_from_eid_map;
20use anyhow::{Result, anyhow};
21use arrow_array::builder::{LargeBinaryBuilder, StringBuilder};
22use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
23use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
24use futures::TryStreamExt;
25use lancedb::Table;
26use lancedb::index::Index as LanceDbIndex;
27use lancedb::index::scalar::BTreeIndexBuilder;
28use lancedb::query::{ExecutableQuery, QueryBase, Select};
29use std::collections::HashMap;
30use std::sync::Arc;
31use uni_common::Properties;
32use uni_common::core::id::{Eid, Vid};
33
/// Handle for the main (type-agnostic) edge table stored in LanceDB.
///
/// All query/write helpers on this type are associated functions that take a
/// `LanceDbStore`; the handle itself only remembers the base URI it was built with.
#[derive(Debug, Clone)]
pub struct MainEdgeDataset {
    /// Base URI supplied to `new`; currently stored but not read by this module.
    _base_uri: String,
}
42
43impl MainEdgeDataset {
44 pub fn new(base_uri: &str) -> Self {
46 Self {
47 _base_uri: base_uri.to_string(),
48 }
49 }
50
51 pub fn get_arrow_schema() -> Arc<ArrowSchema> {
53 Arc::new(ArrowSchema::new(vec![
54 Field::new("_eid", DataType::UInt64, false),
55 Field::new("src_vid", DataType::UInt64, false),
56 Field::new("dst_vid", DataType::UInt64, false),
57 Field::new("type", DataType::Utf8, false),
58 Field::new("props_json", DataType::LargeBinary, true),
59 Field::new("_deleted", DataType::Boolean, false),
60 Field::new("_version", DataType::UInt64, false),
61 Field::new(
62 "_created_at",
63 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
64 true,
65 ),
66 Field::new(
67 "_updated_at",
68 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
69 true,
70 ),
71 ]))
72 }
73
74 pub fn table_name() -> &'static str {
76 "edges"
77 }
78
79 pub fn build_record_batch(
86 edges: &[(Eid, Vid, Vid, String, Properties, bool, u64)],
87 created_at: Option<&HashMap<Eid, i64>>,
88 updated_at: Option<&HashMap<Eid, i64>>,
89 ) -> Result<RecordBatch> {
90 let arrow_schema = Self::get_arrow_schema();
91 let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
92
93 let eids: Vec<u64> = edges
95 .iter()
96 .map(|(e, _, _, _, _, _, _)| e.as_u64())
97 .collect();
98 columns.push(Arc::new(UInt64Array::from(eids)));
99
100 let src_vids: Vec<u64> = edges
102 .iter()
103 .map(|(_, s, _, _, _, _, _)| s.as_u64())
104 .collect();
105 columns.push(Arc::new(UInt64Array::from(src_vids)));
106
107 let dst_vids: Vec<u64> = edges
109 .iter()
110 .map(|(_, _, d, _, _, _, _)| d.as_u64())
111 .collect();
112 columns.push(Arc::new(UInt64Array::from(dst_vids)));
113
114 let mut type_builder = StringBuilder::new();
116 for (_, _, _, edge_type, _, _, _) in edges.iter() {
117 type_builder.append_value(edge_type);
118 }
119 columns.push(Arc::new(type_builder.finish()));
120
121 let mut props_json_builder = LargeBinaryBuilder::new();
123 for (_, _, _, _, props, _, _) in edges.iter() {
124 let jsonb_bytes = {
125 let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
126 let uni_val: uni_common::Value = json_val.into();
127 uni_common::cypher_value_codec::encode(&uni_val)
128 };
129 props_json_builder.append_value(&jsonb_bytes);
130 }
131 columns.push(Arc::new(props_json_builder.finish()));
132
133 let deleted: Vec<bool> = edges.iter().map(|(_, _, _, _, _, d, _)| *d).collect();
135 columns.push(Arc::new(BooleanArray::from(deleted)));
136
137 let versions: Vec<u64> = edges.iter().map(|(_, _, _, _, _, _, v)| *v).collect();
139 columns.push(Arc::new(UInt64Array::from(versions)));
140
141 let eids = edges.iter().map(|(e, _, _, _, _, _, _)| *e);
143 columns.push(build_timestamp_column_from_eid_map(
144 eids.clone(),
145 created_at,
146 ));
147 columns.push(build_timestamp_column_from_eid_map(eids, updated_at));
148
149 RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
150 }
151
152 pub async fn write_batch_lancedb(store: &LanceDbStore, batch: RecordBatch) -> Result<Table> {
156 let table_name = Self::table_name();
157
158 if store.table_exists(table_name).await? {
159 let table = store.open_table(table_name).await?;
160 store.append_to_table(&table, vec![batch]).await?;
161 Ok(table)
162 } else {
163 store.create_table(table_name, vec![batch]).await
164 }
165 }
166
167 pub async fn ensure_default_indexes_lancedb(table: &Table) -> Result<()> {
169 let indices = table
170 .list_indices()
171 .await
172 .map_err(|e| anyhow!("Failed to list indices: {}", e))?;
173
174 if !indices
176 .iter()
177 .any(|idx| idx.columns.contains(&"_eid".to_string()))
178 {
179 log::info!("Creating _eid BTree index for main edges table");
180 if let Err(e) = table
181 .create_index(&["_eid"], LanceDbIndex::BTree(BTreeIndexBuilder::default()))
182 .execute()
183 .await
184 {
185 log::warn!("Failed to create _eid index for main edges: {}", e);
186 }
187 }
188
189 if !indices
191 .iter()
192 .any(|idx| idx.columns.contains(&"src_vid".to_string()))
193 {
194 log::info!("Creating src_vid BTree index for main edges table");
195 if let Err(e) = table
196 .create_index(
197 &["src_vid"],
198 LanceDbIndex::BTree(BTreeIndexBuilder::default()),
199 )
200 .execute()
201 .await
202 {
203 log::warn!("Failed to create src_vid index for main edges: {}", e);
204 }
205 }
206
207 if !indices
209 .iter()
210 .any(|idx| idx.columns.contains(&"dst_vid".to_string()))
211 {
212 log::info!("Creating dst_vid BTree index for main edges table");
213 if let Err(e) = table
214 .create_index(
215 &["dst_vid"],
216 LanceDbIndex::BTree(BTreeIndexBuilder::default()),
217 )
218 .execute()
219 .await
220 {
221 log::warn!("Failed to create dst_vid index for main edges: {}", e);
222 }
223 }
224
225 if !indices
227 .iter()
228 .any(|idx| idx.columns.contains(&"type".to_string()))
229 {
230 log::info!("Creating type BTree index for main edges table");
231 if let Err(e) = table
232 .create_index(&["type"], LanceDbIndex::BTree(BTreeIndexBuilder::default()))
233 .execute()
234 .await
235 {
236 log::warn!("Failed to create type index for main edges: {}", e);
237 }
238 }
239
240 Ok(())
241 }
242
243 pub async fn find_by_eid(
245 store: &LanceDbStore,
246 eid: Eid,
247 ) -> Result<Option<(Vid, Vid, String, Properties)>> {
248 let table_name = Self::table_name();
249
250 if !store.table_exists(table_name).await? {
251 return Ok(None);
252 }
253
254 let table = store.open_table(table_name).await?;
255 let query = format!("_eid = {}", eid.as_u64());
256
257 let batches = table
258 .query()
259 .only_if(query)
260 .execute()
261 .await
262 .map_err(|e| anyhow!("Query failed: {}", e))?;
263
264 let results: Vec<RecordBatch> = batches.try_collect().await?;
265
266 for batch in results {
267 if batch.num_rows() > 0 {
268 let src_vid_col = batch.column_by_name("src_vid");
269 let dst_vid_col = batch.column_by_name("dst_vid");
270 let type_col = batch.column_by_name("type");
271 let props_col = batch.column_by_name("props_json");
272
273 if let (Some(src), Some(dst), Some(typ), Some(props)) =
274 (src_vid_col, dst_vid_col, type_col, props_col)
275 && let (Some(src_arr), Some(dst_arr), Some(type_arr), Some(props_arr)) = (
276 src.as_any().downcast_ref::<UInt64Array>(),
277 dst.as_any().downcast_ref::<UInt64Array>(),
278 typ.as_any().downcast_ref::<arrow_array::StringArray>(),
279 props
280 .as_any()
281 .downcast_ref::<arrow_array::LargeBinaryArray>(),
282 )
283 {
284 let src_vid = Vid::from(src_arr.value(0));
285 let dst_vid = Vid::from(dst_arr.value(0));
286 let edge_type = type_arr.value(0).to_string();
287 let properties: Properties = if props_arr.is_null(0)
288 || props_arr.value(0).is_empty()
289 {
290 Properties::new()
291 } else {
292 let uni_val = uni_common::cypher_value_codec::decode(props_arr.value(0))
293 .unwrap_or(uni_common::Value::Null);
294 let json_val: serde_json::Value = uni_val.into();
295 serde_json::from_value(json_val).unwrap_or_default()
296 };
297
298 return Ok(Some((src_vid, dst_vid, edge_type, properties)));
299 }
300 }
301 }
302
303 Ok(None)
304 }
305
306 pub async fn open_table(store: &LanceDbStore) -> Result<Option<Table>> {
310 let table_name = Self::table_name();
311
312 if !store.table_exists(table_name).await? {
313 return Ok(None);
314 }
315
316 let table = store.open_table(table_name).await?;
317 Ok(Some(table))
318 }
319
320 async fn execute_query(
324 store: &LanceDbStore,
325 filter: &str,
326 columns: Option<Vec<&str>>,
327 ) -> Result<Vec<RecordBatch>> {
328 let Some(table) = Self::open_table(store).await? else {
329 return Ok(Vec::new());
330 };
331
332 let mut query = table.query();
333 query = query.only_if(filter);
334
335 if let Some(cols) = columns {
336 query = query.select(Select::Columns(
337 cols.into_iter().map(String::from).collect(),
338 ));
339 }
340
341 let batches = query
342 .execute()
343 .await
344 .map_err(|e| anyhow!("Query failed: {}", e))?;
345
346 batches.try_collect().await.map_err(Into::into)
347 }
348
349 fn extract_eids(batches: &[RecordBatch]) -> Vec<Eid> {
351 let mut eids = Vec::new();
352 for batch in batches {
353 if let Some(eid_col) = batch.column_by_name("_eid")
354 && let Some(eid_arr) = eid_col.as_any().downcast_ref::<UInt64Array>()
355 {
356 for i in 0..eid_arr.len() {
357 if !eid_arr.is_null(i) {
358 eids.push(Eid::new(eid_arr.value(i)));
359 }
360 }
361 }
362 }
363 eids
364 }
365
366 pub async fn find_all_eids(store: &LanceDbStore) -> Result<Vec<Eid>> {
368 let batches = Self::execute_query(store, "_deleted = false", Some(vec!["_eid"])).await?;
369 Ok(Self::extract_eids(&batches))
370 }
371
372 pub async fn find_eids_by_type_name(store: &LanceDbStore, type_name: &str) -> Result<Vec<Eid>> {
374 let filter = format!(
375 "_deleted = false AND type = '{}'",
376 type_name.replace('\'', "''")
377 );
378 let batches = Self::execute_query(store, &filter, Some(vec!["_eid"])).await?;
379 Ok(Self::extract_eids(&batches))
380 }
381
382 pub async fn find_props_by_eid(store: &LanceDbStore, eid: Eid) -> Result<Option<Properties>> {
387 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
388 let batches =
389 Self::execute_query(store, &filter, Some(vec!["props_json", "_version"])).await?;
390
391 if batches.is_empty() {
392 return Ok(None);
393 }
394
395 let mut best_props: Option<Properties> = None;
397 let mut best_version: u64 = 0;
398
399 for batch in &batches {
400 let props_col = batch.column_by_name("props_json");
401 let version_col = batch.column_by_name("_version");
402
403 if let (Some(props_arr), Some(ver_arr)) = (
404 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
405 version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
406 ) {
407 for i in 0..batch.num_rows() {
408 let version = if ver_arr.is_null(i) {
409 0
410 } else {
411 ver_arr.value(i)
412 };
413
414 if version >= best_version {
415 best_version = version;
416 best_props = Some(Self::parse_props_json(props_arr, i)?);
417 }
418 }
419 }
420 }
421
422 Ok(best_props)
423 }
424
425 fn parse_props_json(arr: &arrow_array::LargeBinaryArray, idx: usize) -> Result<Properties> {
427 if arr.is_null(idx) || arr.value(idx).is_empty() {
428 return Ok(Properties::new());
429 }
430 let bytes = arr.value(idx);
431 let uni_val = uni_common::cypher_value_codec::decode(bytes)
432 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
433 let json_val: serde_json::Value = uni_val.into();
434 serde_json::from_value(json_val).map_err(|e| anyhow!("Failed to parse props_json: {}", e))
435 }
436
437 pub async fn find_type_by_eid(store: &LanceDbStore, eid: Eid) -> Result<Option<String>> {
439 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
440 let batches = Self::execute_query(store, &filter, Some(vec!["type"])).await?;
441
442 for batch in batches {
443 if batch.num_rows() > 0
444 && let Some(type_col) = batch.column_by_name("type")
445 && let Some(type_arr) = type_col.as_any().downcast_ref::<arrow_array::StringArray>()
446 && !type_arr.is_null(0)
447 {
448 return Ok(Some(type_arr.value(0).to_string()));
449 }
450 }
451
452 Ok(None)
453 }
454
455 pub async fn find_edges_by_type_name(
459 store: &LanceDbStore,
460 type_name: &str,
461 ) -> Result<Vec<(Eid, Vid, Vid, Properties)>> {
462 let filter = format!(
463 "_deleted = false AND type = '{}'",
464 type_name.replace('\'', "''")
465 );
466 let batches = Self::execute_query(store, &filter, None).await?;
468
469 let mut edges = Vec::new();
470 for batch in &batches {
471 Self::extract_edges_from_batch(batch, &mut edges)?;
472 }
473
474 Ok(edges)
475 }
476
477 pub async fn find_edges_by_type_names(
482 store: &LanceDbStore,
483 type_names: &[&str],
484 ) -> Result<Vec<(Eid, Vid, Vid, String, Properties)>> {
485 if type_names.is_empty() {
486 return Ok(Vec::new());
487 }
488
489 let escaped_types: Vec<String> = type_names
491 .iter()
492 .map(|t| format!("'{}'", t.replace('\'', "''")))
493 .collect();
494 let filter = format!(
495 "_deleted = false AND type IN ({})",
496 escaped_types.join(", ")
497 );
498
499 let batches = Self::execute_query(store, &filter, None).await?;
501
502 let mut edges = Vec::new();
503 for batch in &batches {
504 Self::extract_edges_with_type_from_batch(batch, &mut edges)?;
505 }
506
507 Ok(edges)
508 }
509
510 fn extract_edges_from_batch(
512 batch: &RecordBatch,
513 edges: &mut Vec<(Eid, Vid, Vid, Properties)>,
514 ) -> Result<()> {
515 let mut edges_with_type = Vec::new();
517 Self::extract_edges_with_type_from_batch(batch, &mut edges_with_type)?;
518 edges.extend(
519 edges_with_type
520 .into_iter()
521 .map(|(eid, src, dst, _type, props)| (eid, src, dst, props)),
522 );
523 Ok(())
524 }
525
526 fn extract_edges_with_type_from_batch(
528 batch: &RecordBatch,
529 edges: &mut Vec<(Eid, Vid, Vid, String, Properties)>,
530 ) -> Result<()> {
531 let Some(eid_arr) = batch
532 .column_by_name("_eid")
533 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
534 else {
535 return Ok(());
536 };
537 let Some(src_arr) = batch
538 .column_by_name("src_vid")
539 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
540 else {
541 return Ok(());
542 };
543 let Some(dst_arr) = batch
544 .column_by_name("dst_vid")
545 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
546 else {
547 return Ok(());
548 };
549 let type_arr = batch
550 .column_by_name("type")
551 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
552 let props_arr = batch
553 .column_by_name("props_json")
554 .and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());
555
556 for i in 0..batch.num_rows() {
557 if eid_arr.is_null(i) || src_arr.is_null(i) || dst_arr.is_null(i) {
558 continue;
559 }
560
561 let eid = Eid::new(eid_arr.value(i));
562 let src_vid = Vid::new(src_arr.value(i));
563 let dst_vid = Vid::new(dst_arr.value(i));
564 let edge_type = type_arr
565 .filter(|arr| !arr.is_null(i))
566 .map(|arr| arr.value(i).to_string())
567 .unwrap_or_default();
568 let props = props_arr
569 .map(|arr| Self::parse_props_json(arr, i))
570 .transpose()?
571 .unwrap_or_default();
572
573 edges.push((eid, src_vid, dst_vid, edge_type, props));
574 }
575
576 Ok(())
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// Column names the edge schema must expose.
    const EDGE_COLUMNS: [&str; 9] = [
        "_eid",
        "src_vid",
        "dst_vid",
        "type",
        "props_json",
        "_deleted",
        "_version",
        "_created_at",
        "_updated_at",
    ];

    #[test]
    fn test_main_edge_schema() {
        let schema = MainEdgeDataset::get_arrow_schema();
        assert_eq!(schema.fields().len(), EDGE_COLUMNS.len());
        for column in EDGE_COLUMNS {
            assert!(
                schema.field_with_name(column).is_ok(),
                "schema is missing column {column}"
            );
        }
    }

    #[test]
    fn test_build_record_batch() {
        use uni_common::Value;

        let weighted = HashMap::from([("weight".to_string(), Value::Float(0.5))]);
        let rows = vec![(
            Eid::new(1),
            Vid::new(1),
            Vid::new(2),
            "KNOWS".to_string(),
            weighted,
            false,
            1u64,
        )];

        let built = MainEdgeDataset::build_record_batch(&rows, None, None).unwrap();
        assert_eq!(built.num_rows(), 1);
        assert_eq!(built.num_columns(), EDGE_COLUMNS.len());
    }
}
619}