1use crate::backend::StorageBackend;
19use crate::backend::table_names;
20use crate::backend::types::{ScalarIndexType, ScanRequest, WriteMode};
21use crate::storage::arrow_convert::build_timestamp_column_from_eid_map;
22use anyhow::{Result, anyhow};
23use arrow_array::builder::{LargeBinaryBuilder, StringBuilder};
24use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
26use std::collections::HashMap;
27use std::sync::Arc;
28use uni_common::Properties;
29use uni_common::core::id::{Eid, Vid};
30
/// Dataset accessor for the main edge table: Arrow schema definition,
/// record-batch construction, and filtered query helpers executed against a
/// `StorageBackend`.
///
/// All functionality lives in associated functions; the struct itself only
/// stores the base URI it was constructed with.
#[derive(Debug)]
pub struct MainEdgeDataset {
    // Currently unused (leading underscore silences dead-code lints);
    // presumably kept for future per-dataset addressing — TODO confirm.
    _base_uri: String,
}
40
41impl MainEdgeDataset {
42 pub fn new(base_uri: &str) -> Self {
44 Self {
45 _base_uri: base_uri.to_string(),
46 }
47 }
48
49 pub fn get_arrow_schema() -> Arc<ArrowSchema> {
51 Arc::new(ArrowSchema::new(vec![
52 Field::new("_eid", DataType::UInt64, false),
53 Field::new("src_vid", DataType::UInt64, false),
54 Field::new("dst_vid", DataType::UInt64, false),
55 Field::new("type", DataType::Utf8, false),
56 Field::new("props_json", DataType::LargeBinary, true),
57 Field::new("_deleted", DataType::Boolean, false),
58 Field::new("_version", DataType::UInt64, false),
59 Field::new(
60 "_created_at",
61 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
62 true,
63 ),
64 Field::new(
65 "_updated_at",
66 DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
67 true,
68 ),
69 ]))
70 }
71
    /// Logical table name for edges.
    ///
    /// NOTE(review): the query/write paths in this impl resolve the table
    /// name via `table_names::main_edge_table_name()` instead of this
    /// constant — confirm the two stay in sync.
    pub fn table_name() -> &'static str {
        "edges"
    }
76
77 pub fn build_record_batch(
84 edges: &[(Eid, Vid, Vid, String, Properties, bool, u64)],
85 created_at: Option<&HashMap<Eid, i64>>,
86 updated_at: Option<&HashMap<Eid, i64>>,
87 ) -> Result<RecordBatch> {
88 let arrow_schema = Self::get_arrow_schema();
89 let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());
90
91 let eids: Vec<u64> = edges
93 .iter()
94 .map(|(e, _, _, _, _, _, _)| e.as_u64())
95 .collect();
96 columns.push(Arc::new(UInt64Array::from(eids)));
97
98 let src_vids: Vec<u64> = edges
100 .iter()
101 .map(|(_, s, _, _, _, _, _)| s.as_u64())
102 .collect();
103 columns.push(Arc::new(UInt64Array::from(src_vids)));
104
105 let dst_vids: Vec<u64> = edges
107 .iter()
108 .map(|(_, _, d, _, _, _, _)| d.as_u64())
109 .collect();
110 columns.push(Arc::new(UInt64Array::from(dst_vids)));
111
112 let mut type_builder = StringBuilder::new();
114 for (_, _, _, edge_type, _, _, _) in edges.iter() {
115 type_builder.append_value(edge_type);
116 }
117 columns.push(Arc::new(type_builder.finish()));
118
119 let mut props_json_builder = LargeBinaryBuilder::new();
121 for (_, _, _, _, props, _, _) in edges.iter() {
122 let jsonb_bytes = {
123 let json_val = serde_json::to_value(props).unwrap_or(serde_json::json!({}));
124 let uni_val: uni_common::Value = json_val.into();
125 uni_common::cypher_value_codec::encode(&uni_val)
126 };
127 props_json_builder.append_value(&jsonb_bytes);
128 }
129 columns.push(Arc::new(props_json_builder.finish()));
130
131 let deleted: Vec<bool> = edges.iter().map(|(_, _, _, _, _, d, _)| *d).collect();
133 columns.push(Arc::new(BooleanArray::from(deleted)));
134
135 let versions: Vec<u64> = edges.iter().map(|(_, _, _, _, _, _, v)| *v).collect();
137 columns.push(Arc::new(UInt64Array::from(versions)));
138
139 let eids = edges.iter().map(|(e, _, _, _, _, _, _)| *e);
141 columns.push(build_timestamp_column_from_eid_map(
142 eids.clone(),
143 created_at,
144 ));
145 columns.push(build_timestamp_column_from_eid_map(eids, updated_at));
146
147 RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
148 }
149
150 pub async fn write_batch(backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
154 let table_name = table_names::main_edge_table_name();
155
156 if backend.table_exists(table_name).await? {
157 backend
158 .write(table_name, vec![batch], WriteMode::Append)
159 .await
160 } else {
161 backend.create_table(table_name, vec![batch]).await
162 }
163 }
164
165 pub async fn ensure_default_indexes(backend: &dyn StorageBackend) -> Result<()> {
167 let table_name = table_names::main_edge_table_name();
168 let _ = backend
169 .create_scalar_index(table_name, "_eid", ScalarIndexType::BTree)
170 .await;
171 let _ = backend
172 .create_scalar_index(table_name, "src_vid", ScalarIndexType::BTree)
173 .await;
174 let _ = backend
175 .create_scalar_index(table_name, "dst_vid", ScalarIndexType::BTree)
176 .await;
177 let _ = backend
178 .create_scalar_index(table_name, "type", ScalarIndexType::BTree)
179 .await;
180 Ok(())
181 }
182
    /// Finds an edge by `eid` and returns `(src_vid, dst_vid, type, props)`.
    ///
    /// NOTE(review): the filter matches on `_eid` only — unlike
    /// `find_props_by_eid`/`find_type_by_eid` it does NOT exclude
    /// `_deleted = true` rows, and it returns row 0 of the first matching
    /// batch rather than the highest `_version`. Confirm callers expect
    /// tombstoned/stale rows to be visible here.
    pub async fn find_by_eid(
        backend: &dyn StorageBackend,
        eid: Eid,
    ) -> Result<Option<(Vid, Vid, String, Properties)>> {
        let filter = format!("_eid = {}", eid.as_u64());
        let results = Self::execute_query(backend, &filter, None).await?;

        for batch in results {
            if batch.num_rows() > 0 {
                // Look up the four needed columns by name; a missing column
                // (or a failed downcast below) makes this batch fall through.
                let src_vid_col = batch.column_by_name("src_vid");
                let dst_vid_col = batch.column_by_name("dst_vid");
                let type_col = batch.column_by_name("type");
                let props_col = batch.column_by_name("props_json");

                if let (Some(src), Some(dst), Some(typ), Some(props)) =
                    (src_vid_col, dst_vid_col, type_col, props_col)
                    && let (Some(src_arr), Some(dst_arr), Some(type_arr), Some(props_arr)) = (
                        src.as_any().downcast_ref::<UInt64Array>(),
                        dst.as_any().downcast_ref::<UInt64Array>(),
                        typ.as_any().downcast_ref::<arrow_array::StringArray>(),
                        props
                            .as_any()
                            .downcast_ref::<arrow_array::LargeBinaryArray>(),
                    )
                {
                    // Only row 0 is inspected; additional rows are ignored.
                    let src_vid = Vid::from(src_arr.value(0));
                    let dst_vid = Vid::from(dst_arr.value(0));
                    let edge_type = type_arr.value(0).to_string();
                    // Null/empty props cells decode to an empty map; codec
                    // decode failures are swallowed (Null -> default map)
                    // rather than surfaced as errors.
                    let properties: Properties = if props_arr.is_null(0)
                        || props_arr.value(0).is_empty()
                    {
                        Properties::new()
                    } else {
                        let uni_val = uni_common::cypher_value_codec::decode(props_arr.value(0))
                            .unwrap_or(uni_common::Value::Null);
                        let json_val: serde_json::Value = uni_val.into();
                        serde_json::from_value(json_val).unwrap_or_default()
                    };

                    return Ok(Some((src_vid, dst_vid, edge_type, properties)));
                }
            }
        }

        Ok(None)
    }
230
231 pub async fn exists_by_eid(backend: &dyn StorageBackend, eid: Eid) -> Result<bool> {
237 let filter = format!("_eid = {}", eid.as_u64());
238 let batches = Self::execute_query(backend, &filter, Some(vec!["_eid"])).await?;
239 Ok(!batches.is_empty() && batches.iter().any(|b| b.num_rows() > 0))
240 }
241
242 async fn execute_query(
246 backend: &dyn StorageBackend,
247 filter: &str,
248 columns: Option<Vec<&str>>,
249 ) -> Result<Vec<RecordBatch>> {
250 let table_name = table_names::main_edge_table_name();
251
252 if !backend.table_exists(table_name).await? {
253 return Ok(Vec::new());
254 }
255
256 let mut request = ScanRequest::all(table_name).with_filter(filter);
257 if let Some(cols) = columns {
258 request = request.with_columns(cols.into_iter().map(String::from).collect());
259 }
260
261 backend.scan(request).await
262 }
263
264 fn extract_eids(batches: &[RecordBatch]) -> Vec<Eid> {
266 let mut eids = Vec::new();
267 for batch in batches {
268 if let Some(eid_col) = batch.column_by_name("_eid")
269 && let Some(eid_arr) = eid_col.as_any().downcast_ref::<UInt64Array>()
270 {
271 for i in 0..eid_arr.len() {
272 if !eid_arr.is_null(i) {
273 eids.push(Eid::new(eid_arr.value(i)));
274 }
275 }
276 }
277 }
278 eids
279 }
280
    /// Returns the `Eid`s of all live (non-deleted) edges, projecting only
    /// the `_eid` column.
    pub async fn find_all_eids(backend: &dyn StorageBackend) -> Result<Vec<Eid>> {
        let batches = Self::execute_query(backend, "_deleted = false", Some(vec!["_eid"])).await?;
        Ok(Self::extract_eids(&batches))
    }
286
287 pub async fn find_eids_by_type_name(
289 backend: &dyn StorageBackend,
290 type_name: &str,
291 ) -> Result<Vec<Eid>> {
292 let filter = format!(
293 "_deleted = false AND type = '{}'",
294 type_name.replace('\'', "''")
295 );
296 let batches = Self::execute_query(backend, &filter, Some(vec!["_eid"])).await?;
297 Ok(Self::extract_eids(&batches))
298 }
299
300 pub async fn find_props_by_eid(
305 backend: &dyn StorageBackend,
306 eid: Eid,
307 ) -> Result<Option<Properties>> {
308 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
309 let batches =
310 Self::execute_query(backend, &filter, Some(vec!["props_json", "_version"])).await?;
311
312 if batches.is_empty() {
313 return Ok(None);
314 }
315
316 let mut best_props: Option<Properties> = None;
318 let mut best_version: u64 = 0;
319
320 for batch in &batches {
321 let props_col = batch.column_by_name("props_json");
322 let version_col = batch.column_by_name("_version");
323
324 if let (Some(props_arr), Some(ver_arr)) = (
325 props_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>()),
326 version_col.and_then(|c| c.as_any().downcast_ref::<UInt64Array>()),
327 ) {
328 for i in 0..batch.num_rows() {
329 let version = if ver_arr.is_null(i) {
330 0
331 } else {
332 ver_arr.value(i)
333 };
334
335 if version >= best_version {
336 best_version = version;
337 best_props = Some(Self::parse_props_json(props_arr, i)?);
338 }
339 }
340 }
341 }
342
343 Ok(best_props)
344 }
345
346 fn parse_props_json(arr: &arrow_array::LargeBinaryArray, idx: usize) -> Result<Properties> {
348 if arr.is_null(idx) || arr.value(idx).is_empty() {
349 return Ok(Properties::new());
350 }
351 let bytes = arr.value(idx);
352 let uni_val = uni_common::cypher_value_codec::decode(bytes)
353 .map_err(|e| anyhow!("Failed to decode CypherValue: {}", e))?;
354 let json_val: serde_json::Value = uni_val.into();
355 serde_json::from_value(json_val).map_err(|e| anyhow!("Failed to parse props_json: {}", e))
356 }
357
358 pub async fn find_type_by_eid(
360 backend: &dyn StorageBackend,
361 eid: Eid,
362 ) -> Result<Option<String>> {
363 let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
364 let batches = Self::execute_query(backend, &filter, Some(vec!["type"])).await?;
365
366 for batch in batches {
367 if batch.num_rows() > 0
368 && let Some(type_col) = batch.column_by_name("type")
369 && let Some(type_arr) = type_col.as_any().downcast_ref::<arrow_array::StringArray>()
370 && !type_arr.is_null(0)
371 {
372 return Ok(Some(type_arr.value(0).to_string()));
373 }
374 }
375
376 Ok(None)
377 }
378
379 pub async fn find_edges_by_type_name(
383 backend: &dyn StorageBackend,
384 type_name: &str,
385 ) -> Result<Vec<(Eid, Vid, Vid, Properties)>> {
386 let filter = format!(
387 "_deleted = false AND type = '{}'",
388 type_name.replace('\'', "''")
389 );
390 let batches = Self::execute_query(backend, &filter, None).await?;
392
393 let mut edges = Vec::new();
394 for batch in &batches {
395 Self::extract_edges_from_batch(batch, &mut edges)?;
396 }
397
398 Ok(edges)
399 }
400
401 pub async fn find_edges_by_type_names(
406 backend: &dyn StorageBackend,
407 type_names: &[&str],
408 ) -> Result<Vec<(Eid, Vid, Vid, String, Properties)>> {
409 if type_names.is_empty() {
410 return Ok(Vec::new());
411 }
412
413 let escaped_types: Vec<String> = type_names
415 .iter()
416 .map(|t| format!("'{}'", t.replace('\'', "''")))
417 .collect();
418 let filter = format!(
419 "_deleted = false AND type IN ({})",
420 escaped_types.join(", ")
421 );
422
423 let batches = Self::execute_query(backend, &filter, None).await?;
425
426 let mut edges = Vec::new();
427 for batch in &batches {
428 Self::extract_edges_with_type_from_batch(batch, &mut edges)?;
429 }
430
431 Ok(edges)
432 }
433
434 fn extract_edges_from_batch(
436 batch: &RecordBatch,
437 edges: &mut Vec<(Eid, Vid, Vid, Properties)>,
438 ) -> Result<()> {
439 let mut edges_with_type = Vec::new();
441 Self::extract_edges_with_type_from_batch(batch, &mut edges_with_type)?;
442 edges.extend(
443 edges_with_type
444 .into_iter()
445 .map(|(eid, src, dst, _type, props)| (eid, src, dst, props)),
446 );
447 Ok(())
448 }
449
450 fn extract_edges_with_type_from_batch(
452 batch: &RecordBatch,
453 edges: &mut Vec<(Eid, Vid, Vid, String, Properties)>,
454 ) -> Result<()> {
455 let Some(eid_arr) = batch
456 .column_by_name("_eid")
457 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
458 else {
459 return Ok(());
460 };
461 let Some(src_arr) = batch
462 .column_by_name("src_vid")
463 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
464 else {
465 return Ok(());
466 };
467 let Some(dst_arr) = batch
468 .column_by_name("dst_vid")
469 .and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
470 else {
471 return Ok(());
472 };
473 let type_arr = batch
474 .column_by_name("type")
475 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
476 let props_arr = batch
477 .column_by_name("props_json")
478 .and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());
479
480 for i in 0..batch.num_rows() {
481 if eid_arr.is_null(i) || src_arr.is_null(i) || dst_arr.is_null(i) {
482 continue;
483 }
484
485 let eid = Eid::new(eid_arr.value(i));
486 let src_vid = Vid::new(src_arr.value(i));
487 let dst_vid = Vid::new(dst_arr.value(i));
488 let edge_type = type_arr
489 .filter(|arr| !arr.is_null(i))
490 .map(|arr| arr.value(i).to_string())
491 .unwrap_or_default();
492 let props = props_arr
493 .map(|arr| Self::parse_props_json(arr, i))
494 .transpose()?
495 .unwrap_or_default();
496
497 edges.push((eid, src_vid, dst_vid, edge_type, props));
498 }
499
500 Ok(())
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// The schema must expose exactly the nine expected columns.
    #[test]
    fn test_main_edge_schema() {
        let schema = MainEdgeDataset::get_arrow_schema();
        assert_eq!(schema.fields().len(), 9);

        let expected = [
            "_eid",
            "src_vid",
            "dst_vid",
            "type",
            "props_json",
            "_deleted",
            "_version",
            "_created_at",
            "_updated_at",
        ];
        for name in expected {
            assert!(schema.field_with_name(name).is_ok(), "missing column {name}");
        }
    }

    /// A single-edge batch round-trips through build_record_batch with the
    /// full column set.
    #[test]
    fn test_build_record_batch() {
        use uni_common::Value;

        let mut props = HashMap::new();
        props.insert("weight".to_string(), Value::Float(0.5));

        let edges = vec![(
            Eid::new(1),
            Vid::new(1),
            Vid::new(2),
            "KNOWS".to_string(),
            props,
            false,
            1u64,
        )];

        let batch = MainEdgeDataset::build_record_batch(&edges, None, None).unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 9);
    }
}
523 #[test]
524 fn test_build_record_batch() {
525 use uni_common::Value;
526 let mut props = HashMap::new();
527 props.insert("weight".to_string(), Value::Float(0.5));
528
529 let edges = vec![(
530 Eid::new(1),
531 Vid::new(1),
532 Vid::new(2),
533 "KNOWS".to_string(),
534 props,
535 false,
536 1u64,
537 )];
538
539 let batch = MainEdgeDataset::build_record_batch(&edges, None, None).unwrap();
540 assert_eq!(batch.num_rows(), 1);
541 assert_eq!(batch.num_columns(), 9);
542 }
543}