1use crate::backend::StorageBackend;
19use crate::backend::table_names;
20use crate::backend::types::{ScalarIndexType, ScanRequest, WriteMode};
21use crate::storage::arrow_convert::build_timestamp_column_from_eid_map;
22use anyhow::{Result, anyhow};
23use arrow_array::builder::{LargeBinaryBuilder, StringBuilder};
24use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
26use std::collections::HashMap;
27use std::sync::Arc;
28use uni_common::Properties;
29use uni_common::core::id::{Eid, Vid};
30
/// Accessor for the main edge table: schema definition, record-batch
/// building, writes, default index creation and point/scan lookups.
#[derive(Debug)]
pub struct MainEdgeDataset {
    // Base URI supplied at construction; currently unused — all I/O is
    // routed through the `StorageBackend` passed to each method.
    _base_uri: String,
}
40
41impl MainEdgeDataset {
42 pub fn new(base_uri: &str) -> Self {
44 Self {
45 _base_uri: base_uri.to_string(),
46 }
47 }
48
49 pub fn get_arrow_schema() -> Arc<ArrowSchema> {
51 Arc::new(ArrowSchema::new(vec![
52 Field::new("_eid", DataType::UInt64, false),
53 Field::new("src_vid", DataType::UInt64, false),
54 Field::new("dst_vid", DataType::UInt64, false),
55 Field::new("type", DataType::Utf8, false),
56 Field::new("props_json", DataType::LargeBinary, true),
57 Field::new("_deleted", DataType::Boolean, false),
58 Field::new("_version", DataType::UInt64, false),
59 Field::new(
60 "_created_at",
61 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
62 true,
63 ),
64 Field::new(
65 "_updated_at",
66 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
67 true,
68 ),
69 ]))
70 }
71
72 pub fn table_name() -> &'static str {
74 "edges"
75 }
76
77 pub fn build_record_batch(
84 edges: &[(Eid, Vid, Vid, String, Properties, bool, u64)],
85 created_at: Option<&HashMap<Eid, i64>>,
86 updated_at: Option<&HashMap<Eid, i64>>,
87 ) -> Result<RecordBatch> {
88 let arrow_schema = Self::get_arrow_schema();
89 let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
90
91 let eids: Vec<u64> = edges
93 .iter()
94 .map(|(e, _, _, _, _, _, _)| e.as_u64())
95 .collect();
96 columns.push(Arc::new(UInt64Array::from(eids)));
97
98 let src_vids: Vec<u64> = edges
100 .iter()
101 .map(|(_, s, _, _, _, _, _)| s.as_u64())
102 .collect();
103 columns.push(Arc::new(UInt64Array::from(src_vids)));
104
105 let dst_vids: Vec<u64> = edges
107 .iter()
108 .map(|(_, _, d, _, _, _, _)| d.as_u64())
109 .collect();
110 columns.push(Arc::new(UInt64Array::from(dst_vids)));
111
112 let mut type_builder = StringBuilder::new();
114 for (_, _, _, edge_type, _, _, _) in edges.iter() {
115 type_builder.append_value(edge_type);
116 }
117 columns.push(Arc::new(type_builder.finish()));
118
119 let mut props_json_builder = LargeBinaryBuilder::new();
121 for (_, _, _, _, props, _, _) in edges.iter() {
122 let jsonb_bytes = {
123 let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
124 let uni_val: uni_common::Value = json_val.into();
125 uni_common::cypher_value_codec::encode(&uni_val)
126 };
127 props_json_builder.append_value(&jsonb_bytes);
128 }
129 columns.push(Arc::new(props_json_builder.finish()));
130
131 let deleted: Vec<bool> = edges.iter().map(|(_, _, _, _, _, d, _)| *d).collect();
133 columns.push(Arc::new(BooleanArray::from(deleted)));
134
135 let versions: Vec<u64> = edges.iter().map(|(_, _, _, _, _, _, v)| *v).collect();
137 columns.push(Arc::new(UInt64Array::from(versions)));
138
139 let eids = edges.iter().map(|(e, _, _, _, _, _, _)| *e);
141 columns.push(build_timestamp_column_from_eid_map(
142 eids.clone(),
143 created_at,
144 ));
145 columns.push(build_timestamp_column_from_eid_map(eids, updated_at));
146
147 RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
148 }
149
150 pub async fn write_batch(backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
154 let table_name = table_names::main_edge_table_name();
155
156 if backend.table_exists(table_name).await? {
157 backend
158 .write(table_name, vec![batch], WriteMode::Append)
159 .await
160 } else {
161 backend.create_table(table_name, vec![batch]).await
162 }
163 }
164
165 pub async fn ensure_default_indexes(backend: &dyn StorageBackend) -> Result<()> {
167 let table_name = table_names::main_edge_table_name();
168 let _ = backend
169 .create_scalar_index(table_name, "_eid", ScalarIndexType::BTree)
170 .await;
171 let _ = backend
172 .create_scalar_index(table_name, "src_vid", ScalarIndexType::BTree)
173 .await;
174 let _ = backend
175 .create_scalar_index(table_name, "dst_vid", ScalarIndexType::BTree)
176 .await;
177 let _ = backend
178 .create_scalar_index(table_name, "type", ScalarIndexType::BTree)
179 .await;
180 Ok(())
181 }
182
183 pub async fn find_by_eid(
185 backend: &dyn StorageBackend,
186 eid: Eid,
187 ) -> Result<Option<(Vid, Vid, String, Properties)>> {
188 let filter = format!("_eid = {}", eid.as_u64());
189 let results = Self::execute_query(backend, &filter, None).await?;
190
191 for batch in results {
192 if batch.num_rows() > 0 {
193 let src_vid_col = batch.column_by_name("src_vid");
194 let dst_vid_col = batch.column_by_name("dst_vid");
195 let type_col = batch.column_by_name("type");
196 let props_col = batch.column_by_name("props_json");
197
198 if let (Some(src), Some(dst), Some(typ), Some(props)) =
199 (src_vid_col, dst_vid_col, type_col, props_col)
200 && let (Some(src_arr), Some(dst_arr), Some(type_arr), Some(props_arr)) = (
201 src.as_any().downcast_ref::<UInt64Array>(),
202 dst.as_any().downcast_ref::<UInt64Array>(),
203 typ.as_any().downcast_ref::<arrow_array::StringArray>(),
204 props
205 .as_any()
206 .downcast_ref::<arrow_array::LargeBinaryArray>(),
207 )
208 {
209 let src_vid = Vid::from(src_arr.value(0));
210 let dst_vid = Vid::from(dst_arr.value(0));
211 let edge_type = type_arr.value(0).to_string();
212 let properties: Properties = if props_arr.is_null(0)
213 || props_arr.value(0).is_empty()
214 {
215 Properties::new()
216 } else {
217 let uni_val = uni_common::cypher_value_codec::decode(props_arr.value(0))
218 .unwrap_or(uni_common::Value::Null);
219 let json_val: serde_json::Value = uni_val.into();
220 serde_json::from_value(json_val).unwrap_or_default()
221 };
222
223 return Ok(Some((src_vid, dst_vid, edge_type, properties)));
224 }
225 }
226 }
227
228 Ok(None)
229 }
230
231 async fn execute_query(
235 backend: &dyn StorageBackend,
236 filter: &str,
237 columns: Option<Vec<&str>>,
238 ) -> Result<Vec<RecordBatch>> {
239 let table_name = table_names::main_edge_table_name();
240
241 if !backend.table_exists(table_name).await? {
242 return Ok(Vec::new());
243 }
244
245 let mut request = ScanRequest::all(table_name).with_filter(filter);
246 if let Some(cols) = columns {
247 request = request.with_columns(cols.into_iter().map(String::from).collect());
248 }
249
250 backend.scan(request).await
251 }
252
253 fn extract_eids(batches: &[RecordBatch]) -> Vec<Eid> {
255 let mut eids = Vec::new();
256 for batch in batches {
257 if let Some(eid_col) = batch.column_by_name("_eid")
258 && let Some(eid_arr) = eid_col.as_any().downcast_ref::<UInt64Array>()
259 {
260 for i in 0..eid_arr.len() {
261 if !eid_arr.is_null(i) {
262 eids.push(Eid::new(eid_arr.value(i)));
263 }
264 }
265 }
266 }
267 eids
268 }
269
270 pub async fn find_all_eids(backend: &dyn StorageBackend) -> Result<Vec<Eid>> {
272 let batches = Self::execute_query(backend, "_deleted = false", Some(vec!["_eid"])).await?;
273 Ok(Self::extract_eids(&batches))
274 }
275
276 pub async fn find_eids_by_type_name(
278 backend: &dyn StorageBackend,
279 type_name: &str,
280 ) -> Result<Vec<Eid>> {
281 let filter = format!(
282 "_deleted = false AND type = '{}'",
283 type_name.replace('\'', "''")
284 );
285 let batches = Self::execute_query(backend, &filter, Some(vec!["_eid"])).await?;
286 Ok(Self::extract_eids(&batches))
287 }
288
289 pub async fn find_props_by_eid(
294 backend: &dyn StorageBackend,
295 eid: Eid,
296 ) -> Result<Option<Properties>> {
297 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
298 let batches =
299 Self::execute_query(backend, &filter, Some(vec!["props_json", "_version"])).await?;
300
301 if batches.is_empty() {
302 return Ok(None);
303 }
304
305 let mut best_props: Option<Properties> = None;
307 let mut best_version: u64 = 0;
308
309 for batch in &batches {
310 let props_col = batch.column_by_name("props_json");
311 let version_col = batch.column_by_name("_version");
312
313 if let (Some(props_arr), Some(ver_arr)) = (
314 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
315 version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
316 ) {
317 for i in 0..batch.num_rows() {
318 let version = if ver_arr.is_null(i) {
319 0
320 } else {
321 ver_arr.value(i)
322 };
323
324 if version >= best_version {
325 best_version = version;
326 best_props = Some(Self::parse_props_json(props_arr, i)?);
327 }
328 }
329 }
330 }
331
332 Ok(best_props)
333 }
334
335 fn parse_props_json(arr: &arrow_array::LargeBinaryArray, idx: usize) -> Result<Properties> {
337 if arr.is_null(idx) || arr.value(idx).is_empty() {
338 return Ok(Properties::new());
339 }
340 let bytes = arr.value(idx);
341 let uni_val = uni_common::cypher_value_codec::decode(bytes)
342 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
343 let json_val: serde_json::Value = uni_val.into();
344 serde_json::from_value(json_val).map_err(|e| anyhow!("Failed to parse props_json: {}", e))
345 }
346
347 pub async fn find_type_by_eid(
349 backend: &dyn StorageBackend,
350 eid: Eid,
351 ) -> Result<Option<String>> {
352 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
353 let batches = Self::execute_query(backend, &filter, Some(vec!["type"])).await?;
354
355 for batch in batches {
356 if batch.num_rows() > 0
357 && let Some(type_col) = batch.column_by_name("type")
358 && let Some(type_arr) = type_col.as_any().downcast_ref::<arrow_array::StringArray>()
359 && !type_arr.is_null(0)
360 {
361 return Ok(Some(type_arr.value(0).to_string()));
362 }
363 }
364
365 Ok(None)
366 }
367
368 pub async fn find_edges_by_type_name(
372 backend: &dyn StorageBackend,
373 type_name: &str,
374 ) -> Result<Vec<(Eid, Vid, Vid, Properties)>> {
375 let filter = format!(
376 "_deleted = false AND type = '{}'",
377 type_name.replace('\'', "''")
378 );
379 let batches = Self::execute_query(backend, &filter, None).await?;
381
382 let mut edges = Vec::new();
383 for batch in &batches {
384 Self::extract_edges_from_batch(batch, &mut edges)?;
385 }
386
387 Ok(edges)
388 }
389
390 pub async fn find_edges_by_type_names(
395 backend: &dyn StorageBackend,
396 type_names: &[&str],
397 ) -> Result<Vec<(Eid, Vid, Vid, String, Properties)>> {
398 if type_names.is_empty() {
399 return Ok(Vec::new());
400 }
401
402 let escaped_types: Vec<String> = type_names
404 .iter()
405 .map(|t| format!("'{}'", t.replace('\'', "''")))
406 .collect();
407 let filter = format!(
408 "_deleted = false AND type IN ({})",
409 escaped_types.join(", ")
410 );
411
412 let batches = Self::execute_query(backend, &filter, None).await?;
414
415 let mut edges = Vec::new();
416 for batch in &batches {
417 Self::extract_edges_with_type_from_batch(batch, &mut edges)?;
418 }
419
420 Ok(edges)
421 }
422
423 fn extract_edges_from_batch(
425 batch: &RecordBatch,
426 edges: &mut Vec<(Eid, Vid, Vid, Properties)>,
427 ) -> Result<()> {
428 let mut edges_with_type = Vec::new();
430 Self::extract_edges_with_type_from_batch(batch, &mut edges_with_type)?;
431 edges.extend(
432 edges_with_type
433 .into_iter()
434 .map(|(eid, src, dst, _type, props)| (eid, src, dst, props)),
435 );
436 Ok(())
437 }
438
439 fn extract_edges_with_type_from_batch(
441 batch: &RecordBatch,
442 edges: &mut Vec<(Eid, Vid, Vid, String, Properties)>,
443 ) -> Result<()> {
444 let Some(eid_arr) = batch
445 .column_by_name("_eid")
446 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
447 else {
448 return Ok(());
449 };
450 let Some(src_arr) = batch
451 .column_by_name("src_vid")
452 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
453 else {
454 return Ok(());
455 };
456 let Some(dst_arr) = batch
457 .column_by_name("dst_vid")
458 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
459 else {
460 return Ok(());
461 };
462 let type_arr = batch
463 .column_by_name("type")
464 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
465 let props_arr = batch
466 .column_by_name("props_json")
467 .and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());
468
469 for i in 0..batch.num_rows() {
470 if eid_arr.is_null(i) || src_arr.is_null(i) || dst_arr.is_null(i) {
471 continue;
472 }
473
474 let eid = Eid::new(eid_arr.value(i));
475 let src_vid = Vid::new(src_arr.value(i));
476 let dst_vid = Vid::new(dst_arr.value(i));
477 let edge_type = type_arr
478 .filter(|arr| !arr.is_null(i))
479 .map(|arr| arr.value(i).to_string())
480 .unwrap_or_default();
481 let props = props_arr
482 .map(|arr| Self::parse_props_json(arr, i))
483 .transpose()?
484 .unwrap_or_default();
485
486 edges.push((eid, src_vid, dst_vid, edge_type, props));
487 }
488
489 Ok(())
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// The schema must contain exactly the nine expected columns.
    #[test]
    fn test_main_edge_schema() {
        let schema = MainEdgeDataset::get_arrow_schema();
        let expected = [
            "_eid",
            "src_vid",
            "dst_vid",
            "type",
            "props_json",
            "_deleted",
            "_version",
            "_created_at",
            "_updated_at",
        ];
        assert_eq!(schema.fields().len(), expected.len());
        for name in expected {
            assert!(schema.field_with_name(name).is_ok());
        }
    }

    /// One edge tuple round-trips into a one-row, nine-column batch.
    #[test]
    fn test_build_record_batch() {
        use uni_common::Value;

        let props = HashMap::from([("weight".to_string(), Value::Float(0.5))]);
        let edges = vec![(
            Eid::new(1),
            Vid::new(1),
            Vid::new(2),
            "KNOWS".to_string(),
            props,
            false,
            1u64,
        )];

        let batch = MainEdgeDataset::build_record_batch(&edges, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);
    }
}
532}