1use crate::lancedb::LanceDbStore;
19use crate::storage::arrow_convert::build_timestamp_column_from_eid_map;
20use crate::storage::index_utils::ensure_btree_index;
21use anyhow::{Result, anyhow};
22use arrow_array::builder::{LargeBinaryBuilder, StringBuilder};
23use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
24use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
25use futures::TryStreamExt;
26use futures::future;
27use lancedb::Table;
28use lancedb::query::{ExecutableQuery, QueryBase, Select};
29use std::collections::HashMap;
30use std::sync::Arc;
31use uni_common::Properties;
32use uni_common::core::id::{Eid, Vid};
33
/// Dataset wrapper for the main `edges` table stored in LanceDB.
///
/// Holds only the base URI of the store; every operation takes the
/// `LanceDbStore` explicitly, so the type acts mostly as a namespace
/// for schema construction and query helpers.
#[derive(Debug)]
pub struct MainEdgeDataset {
    // Currently unread (underscore prefix silences the dead-code lint);
    // retained from construction for debugging / future use.
    _base_uri: String,
}
43
44impl MainEdgeDataset {
45 pub fn new(base_uri: &str) -> Self {
47 Self {
48 _base_uri: base_uri.to_string(),
49 }
50 }
51
52 pub fn get_arrow_schema() -> Arc<ArrowSchema> {
54 Arc::new(ArrowSchema::new(vec![
55 Field::new("_eid", DataType::UInt64, false),
56 Field::new("src_vid", DataType::UInt64, false),
57 Field::new("dst_vid", DataType::UInt64, false),
58 Field::new("type", DataType::Utf8, false),
59 Field::new("props_json", DataType::LargeBinary, true),
60 Field::new("_deleted", DataType::Boolean, false),
61 Field::new("_version", DataType::UInt64, false),
62 Field::new(
63 "_created_at",
64 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
65 true,
66 ),
67 Field::new(
68 "_updated_at",
69 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
70 true,
71 ),
72 ]))
73 }
74
75 pub fn table_name() -> &'static str {
77 "edges"
78 }
79
80 pub fn build_record_batch(
87 edges: &[(Eid, Vid, Vid, String, Properties, bool, u64)],
88 created_at: Option<&HashMap<Eid, i64>>,
89 updated_at: Option<&HashMap<Eid, i64>>,
90 ) -> Result<RecordBatch> {
91 let arrow_schema = Self::get_arrow_schema();
92 let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
93
94 let eids: Vec<u64> = edges
96 .iter()
97 .map(|(e, _, _, _, _, _, _)| e.as_u64())
98 .collect();
99 columns.push(Arc::new(UInt64Array::from(eids)));
100
101 let src_vids: Vec<u64> = edges
103 .iter()
104 .map(|(_, s, _, _, _, _, _)| s.as_u64())
105 .collect();
106 columns.push(Arc::new(UInt64Array::from(src_vids)));
107
108 let dst_vids: Vec<u64> = edges
110 .iter()
111 .map(|(_, _, d, _, _, _, _)| d.as_u64())
112 .collect();
113 columns.push(Arc::new(UInt64Array::from(dst_vids)));
114
115 let mut type_builder = StringBuilder::new();
117 for (_, _, _, edge_type, _, _, _) in edges.iter() {
118 type_builder.append_value(edge_type);
119 }
120 columns.push(Arc::new(type_builder.finish()));
121
122 let mut props_json_builder = LargeBinaryBuilder::new();
124 for (_, _, _, _, props, _, _) in edges.iter() {
125 let jsonb_bytes = {
126 let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
127 let uni_val: uni_common::Value = json_val.into();
128 uni_common::cypher_value_codec::encode(&uni_val)
129 };
130 props_json_builder.append_value(&jsonb_bytes);
131 }
132 columns.push(Arc::new(props_json_builder.finish()));
133
134 let deleted: Vec<bool> = edges.iter().map(|(_, _, _, _, _, d, _)| *d).collect();
136 columns.push(Arc::new(BooleanArray::from(deleted)));
137
138 let versions: Vec<u64> = edges.iter().map(|(_, _, _, _, _, _, v)| *v).collect();
140 columns.push(Arc::new(UInt64Array::from(versions)));
141
142 let eids = edges.iter().map(|(e, _, _, _, _, _, _)| *e);
144 columns.push(build_timestamp_column_from_eid_map(
145 eids.clone(),
146 created_at,
147 ));
148 columns.push(build_timestamp_column_from_eid_map(eids, updated_at));
149
150 RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
151 }
152
153 pub async fn write_batch_lancedb(store: &LanceDbStore, batch: RecordBatch) -> Result<Table> {
157 let table_name = Self::table_name();
158
159 if store.table_exists(table_name).await? {
160 let table = store.open_table(table_name).await?;
161 store.append_to_table(&table, vec![batch]).await?;
162 Ok(table)
163 } else {
164 store.create_table(table_name, vec![batch]).await
165 }
166 }
167
168 pub async fn ensure_default_indexes_lancedb(table: &Table) -> Result<()> {
170 let indices = table
171 .list_indices()
172 .await
173 .map_err(|e| anyhow!("Failed to list indices: {}", e))?;
174
175 future::join_all(
176 ["_eid", "src_vid", "dst_vid", "type"]
177 .iter()
178 .map(|col| ensure_btree_index(table, &indices, col, "main edges")),
179 )
180 .await;
181
182 Ok(())
183 }
184
185 pub async fn find_by_eid(
187 store: &LanceDbStore,
188 eid: Eid,
189 ) -> Result<Option<(Vid, Vid, String, Properties)>> {
190 let table_name = Self::table_name();
191
192 if !store.table_exists(table_name).await? {
193 return Ok(None);
194 }
195
196 let table = store.open_table(table_name).await?;
197 let query = format!("_eid = {}", eid.as_u64());
198
199 let batches = table
200 .query()
201 .only_if(query)
202 .execute()
203 .await
204 .map_err(|e| anyhow!("Query failed: {}", e))?;
205
206 let results: Vec<RecordBatch> = batches.try_collect().await?;
207
208 for batch in results {
209 if batch.num_rows() > 0 {
210 let src_vid_col = batch.column_by_name("src_vid");
211 let dst_vid_col = batch.column_by_name("dst_vid");
212 let type_col = batch.column_by_name("type");
213 let props_col = batch.column_by_name("props_json");
214
215 if let (Some(src), Some(dst), Some(typ), Some(props)) =
216 (src_vid_col, dst_vid_col, type_col, props_col)
217 && let (Some(src_arr), Some(dst_arr), Some(type_arr), Some(props_arr)) = (
218 src.as_any().downcast_ref::<UInt64Array>(),
219 dst.as_any().downcast_ref::<UInt64Array>(),
220 typ.as_any().downcast_ref::<arrow_array::StringArray>(),
221 props
222 .as_any()
223 .downcast_ref::<arrow_array::LargeBinaryArray>(),
224 )
225 {
226 let src_vid = Vid::from(src_arr.value(0));
227 let dst_vid = Vid::from(dst_arr.value(0));
228 let edge_type = type_arr.value(0).to_string();
229 let properties: Properties = if props_arr.is_null(0)
230 || props_arr.value(0).is_empty()
231 {
232 Properties::new()
233 } else {
234 let uni_val = uni_common::cypher_value_codec::decode(props_arr.value(0))
235 .unwrap_or(uni_common::Value::Null);
236 let json_val: serde_json::Value = uni_val.into();
237 serde_json::from_value(json_val).unwrap_or_default()
238 };
239
240 return Ok(Some((src_vid, dst_vid, edge_type, properties)));
241 }
242 }
243 }
244
245 Ok(None)
246 }
247
248 pub async fn open_table(store: &LanceDbStore) -> Result<Option<Table>> {
252 let table_name = Self::table_name();
253
254 if !store.table_exists(table_name).await? {
255 return Ok(None);
256 }
257
258 let table = store.open_table(table_name).await?;
259 Ok(Some(table))
260 }
261
262 async fn execute_query(
266 store: &LanceDbStore,
267 filter: &str,
268 columns: Option<Vec<&str>>,
269 ) -> Result<Vec<RecordBatch>> {
270 let Some(table) = Self::open_table(store).await? else {
271 return Ok(Vec::new());
272 };
273
274 let mut query = table.query();
275 query = query.only_if(filter);
276
277 if let Some(cols) = columns {
278 query = query.select(Select::Columns(
279 cols.into_iter().map(String::from).collect(),
280 ));
281 }
282
283 let batches = query
284 .execute()
285 .await
286 .map_err(|e| anyhow!("Query failed: {}", e))?;
287
288 batches.try_collect().await.map_err(Into::into)
289 }
290
291 fn extract_eids(batches: &[RecordBatch]) -> Vec<Eid> {
293 let mut eids = Vec::new();
294 for batch in batches {
295 if let Some(eid_col) = batch.column_by_name("_eid")
296 && let Some(eid_arr) = eid_col.as_any().downcast_ref::<UInt64Array>()
297 {
298 for i in 0..eid_arr.len() {
299 if !eid_arr.is_null(i) {
300 eids.push(Eid::new(eid_arr.value(i)));
301 }
302 }
303 }
304 }
305 eids
306 }
307
308 pub async fn find_all_eids(store: &LanceDbStore) -> Result<Vec<Eid>> {
310 let batches = Self::execute_query(store, "_deleted = false", Some(vec!["_eid"])).await?;
311 Ok(Self::extract_eids(&batches))
312 }
313
314 pub async fn find_eids_by_type_name(store: &LanceDbStore, type_name: &str) -> Result<Vec<Eid>> {
316 let filter = format!(
317 "_deleted = false AND type = '{}'",
318 type_name.replace('\'', "''")
319 );
320 let batches = Self::execute_query(store, &filter, Some(vec!["_eid"])).await?;
321 Ok(Self::extract_eids(&batches))
322 }
323
324 pub async fn find_props_by_eid(store: &LanceDbStore, eid: Eid) -> Result<Option<Properties>> {
329 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
330 let batches =
331 Self::execute_query(store, &filter, Some(vec!["props_json", "_version"])).await?;
332
333 if batches.is_empty() {
334 return Ok(None);
335 }
336
337 let mut best_props: Option<Properties> = None;
339 let mut best_version: u64 = 0;
340
341 for batch in &batches {
342 let props_col = batch.column_by_name("props_json");
343 let version_col = batch.column_by_name("_version");
344
345 if let (Some(props_arr), Some(ver_arr)) = (
346 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
347 version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
348 ) {
349 for i in 0..batch.num_rows() {
350 let version = if ver_arr.is_null(i) {
351 0
352 } else {
353 ver_arr.value(i)
354 };
355
356 if version >= best_version {
357 best_version = version;
358 best_props = Some(Self::parse_props_json(props_arr, i)?);
359 }
360 }
361 }
362 }
363
364 Ok(best_props)
365 }
366
367 fn parse_props_json(arr: &arrow_array::LargeBinaryArray, idx: usize) -> Result<Properties> {
369 if arr.is_null(idx) || arr.value(idx).is_empty() {
370 return Ok(Properties::new());
371 }
372 let bytes = arr.value(idx);
373 let uni_val = uni_common::cypher_value_codec::decode(bytes)
374 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
375 let json_val: serde_json::Value = uni_val.into();
376 serde_json::from_value(json_val).map_err(|e| anyhow!("Failed to parse props_json: {}", e))
377 }
378
379 pub async fn find_type_by_eid(store: &LanceDbStore, eid: Eid) -> Result<Option<String>> {
381 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
382 let batches = Self::execute_query(store, &filter, Some(vec!["type"])).await?;
383
384 for batch in batches {
385 if batch.num_rows() > 0
386 && let Some(type_col) = batch.column_by_name("type")
387 && let Some(type_arr) = type_col.as_any().downcast_ref::<arrow_array::StringArray>()
388 && !type_arr.is_null(0)
389 {
390 return Ok(Some(type_arr.value(0).to_string()));
391 }
392 }
393
394 Ok(None)
395 }
396
397 pub async fn find_edges_by_type_name(
401 store: &LanceDbStore,
402 type_name: &str,
403 ) -> Result<Vec<(Eid, Vid, Vid, Properties)>> {
404 let filter = format!(
405 "_deleted = false AND type = '{}'",
406 type_name.replace('\'', "''")
407 );
408 let batches = Self::execute_query(store, &filter, None).await?;
410
411 let mut edges = Vec::new();
412 for batch in &batches {
413 Self::extract_edges_from_batch(batch, &mut edges)?;
414 }
415
416 Ok(edges)
417 }
418
419 pub async fn find_edges_by_type_names(
424 store: &LanceDbStore,
425 type_names: &[&str],
426 ) -> Result<Vec<(Eid, Vid, Vid, String, Properties)>> {
427 if type_names.is_empty() {
428 return Ok(Vec::new());
429 }
430
431 let escaped_types: Vec<String> = type_names
433 .iter()
434 .map(|t| format!("'{}'", t.replace('\'', "''")))
435 .collect();
436 let filter = format!(
437 "_deleted = false AND type IN ({})",
438 escaped_types.join(", ")
439 );
440
441 let batches = Self::execute_query(store, &filter, None).await?;
443
444 let mut edges = Vec::new();
445 for batch in &batches {
446 Self::extract_edges_with_type_from_batch(batch, &mut edges)?;
447 }
448
449 Ok(edges)
450 }
451
452 fn extract_edges_from_batch(
454 batch: &RecordBatch,
455 edges: &mut Vec<(Eid, Vid, Vid, Properties)>,
456 ) -> Result<()> {
457 let mut edges_with_type = Vec::new();
459 Self::extract_edges_with_type_from_batch(batch, &mut edges_with_type)?;
460 edges.extend(
461 edges_with_type
462 .into_iter()
463 .map(|(eid, src, dst, _type, props)| (eid, src, dst, props)),
464 );
465 Ok(())
466 }
467
468 fn extract_edges_with_type_from_batch(
470 batch: &RecordBatch,
471 edges: &mut Vec<(Eid, Vid, Vid, String, Properties)>,
472 ) -> Result<()> {
473 let Some(eid_arr) = batch
474 .column_by_name("_eid")
475 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
476 else {
477 return Ok(());
478 };
479 let Some(src_arr) = batch
480 .column_by_name("src_vid")
481 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
482 else {
483 return Ok(());
484 };
485 let Some(dst_arr) = batch
486 .column_by_name("dst_vid")
487 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
488 else {
489 return Ok(());
490 };
491 let type_arr = batch
492 .column_by_name("type")
493 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
494 let props_arr = batch
495 .column_by_name("props_json")
496 .and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());
497
498 for i in 0..batch.num_rows() {
499 if eid_arr.is_null(i) || src_arr.is_null(i) || dst_arr.is_null(i) {
500 continue;
501 }
502
503 let eid = Eid::new(eid_arr.value(i));
504 let src_vid = Vid::new(src_arr.value(i));
505 let dst_vid = Vid::new(dst_arr.value(i));
506 let edge_type = type_arr
507 .filter(|arr| !arr.is_null(i))
508 .map(|arr| arr.value(i).to_string())
509 .unwrap_or_default();
510 let props = props_arr
511 .map(|arr| Self::parse_props_json(arr, i))
512 .transpose()?
513 .unwrap_or_default();
514
515 edges.push((eid, src_vid, dst_vid, edge_type, props));
516 }
517
518 Ok(())
519 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// The schema must expose exactly the nine expected columns.
    #[test]
    fn test_main_edge_schema() {
        let schema = MainEdgeDataset::get_arrow_schema();
        assert_eq!(schema.fields().len(), 9);

        let expected = [
            "_eid",
            "src_vid",
            "dst_vid",
            "type",
            "props_json",
            "_deleted",
            "_version",
            "_created_at",
            "_updated_at",
        ];
        for name in expected {
            assert!(schema.field_with_name(name).is_ok(), "missing field {name}");
        }
    }

    /// A single edge tuple round-trips into a one-row, nine-column batch.
    #[test]
    fn test_build_record_batch() {
        use uni_common::Value;

        let props = HashMap::from([("weight".to_string(), Value::Float(0.5))]);
        let edges = vec![(
            Eid::new(1),
            Vid::new(1),
            Vid::new(2),
            "KNOWS".to_string(),
            props,
            false,
            1u64,
        )];

        let batch = MainEdgeDataset::build_record_batch(&edges, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);
    }
}