1use crate::backend::StorageBackend;
5use crate::backend::table_names;
6use crate::backend::types::{ScanRequest, WriteMode};
7use anyhow::{Result, anyhow};
8use arrow_array::{ListArray, RecordBatch, UInt64Array};
9use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
10#[cfg(feature = "lance-backend")]
11use futures::TryStreamExt;
12#[cfg(feature = "lance-backend")]
13use lance::dataset::Dataset;
14use std::collections::HashMap;
15use std::sync::Arc;
16use uni_common::core::id::{Eid, Vid};
17
18type AdjacencyLists = (Vec<Vid>, Vec<Eid>);
20
21type GroupedAdjacencyLists = HashMap<Vid, (Vec<Vid>, Vec<Eid>)>;
23
24fn downcast_adjacency_lists(batch: &RecordBatch) -> Result<(&ListArray, &ListArray)> {
26 let neighbors_list = batch
27 .column_by_name("neighbors")
28 .ok_or(anyhow!("Missing neighbors"))?
29 .as_any()
30 .downcast_ref::<ListArray>()
31 .ok_or(anyhow!("Invalid neighbors type"))?;
32
33 let edge_ids_list = batch
34 .column_by_name("edge_ids")
35 .ok_or(anyhow!("Missing edge_ids"))?
36 .as_any()
37 .downcast_ref::<ListArray>()
38 .ok_or(anyhow!("Invalid edge_ids type"))?;
39
40 Ok((neighbors_list, edge_ids_list))
41}
42
43fn extract_row_adjacency(
45 neighbors_list: &ListArray,
46 edge_ids_list: &ListArray,
47 row_idx: usize,
48) -> Result<(Vec<Vid>, Vec<Eid>)> {
49 let neighbors_array = neighbors_list.value(row_idx);
50 let neighbors_uint64 = neighbors_array
51 .as_any()
52 .downcast_ref::<UInt64Array>()
53 .ok_or(anyhow!("Invalid neighbors inner type"))?;
54
55 let edge_ids_array = edge_ids_list.value(row_idx);
56 let edge_ids_uint64 = edge_ids_array
57 .as_any()
58 .downcast_ref::<UInt64Array>()
59 .ok_or(anyhow!("Invalid edge_ids inner type"))?;
60
61 let neighbors = (0..neighbors_uint64.len())
62 .map(|i| Vid::from(neighbors_uint64.value(i)))
63 .collect();
64 let eids = (0..edge_ids_uint64.len())
65 .map(|i| Eid::from(edge_ids_uint64.value(i)))
66 .collect();
67
68 Ok((neighbors, eids))
69}
70
71fn extract_adjacency_from_batch(batch: &RecordBatch) -> Result<Option<AdjacencyLists>> {
75 if batch.num_rows() == 0 {
76 return Ok(None);
77 }
78
79 let (neighbors_list, edge_ids_list) = downcast_adjacency_lists(batch)?;
80
81 let mut all_neighbors = Vec::new();
82 let mut all_eids = Vec::new();
83
84 for row_idx in 0..batch.num_rows() {
85 let (neighbors, eids) = extract_row_adjacency(neighbors_list, edge_ids_list, row_idx)?;
86 all_neighbors.extend(neighbors);
87 all_eids.extend(eids);
88 }
89
90 Ok(Some((all_neighbors, all_eids)))
91}
92
93fn extract_adjacency_from_batch_grouped(batch: &RecordBatch) -> Result<GroupedAdjacencyLists> {
97 if batch.num_rows() == 0 {
98 return Ok(HashMap::new());
99 }
100
101 let src_vid_col = batch
102 .column_by_name("src_vid")
103 .ok_or(anyhow!("Missing src_vid"))?
104 .as_any()
105 .downcast_ref::<UInt64Array>()
106 .ok_or(anyhow!("Invalid src_vid type"))?;
107
108 let (neighbors_list, edge_ids_list) = downcast_adjacency_lists(batch)?;
109
110 let mut result: HashMap<Vid, (Vec<Vid>, Vec<Eid>)> = HashMap::new();
111
112 for row_idx in 0..batch.num_rows() {
113 let src_vid = Vid::from(src_vid_col.value(row_idx));
114 let (neighbors, eids) = extract_row_adjacency(neighbors_list, edge_ids_list, row_idx)?;
115 result.insert(src_vid, (neighbors, eids));
116 }
117
118 Ok(result)
119}
120
/// Handle to the adjacency storage for one edge type in one direction.
///
/// Holds the location of the Lance dataset (`uri`) plus the `edge_type` and
/// `direction` used to derive the backend table name.
#[derive(Debug, Clone)]
pub struct AdjacencyDataset {
    /// Lance dataset URI; only read by the `lance-backend` code paths.
    #[cfg_attr(not(feature = "lance-backend"), allow(dead_code))]
    uri: String,
    /// Edge type this adjacency belongs to.
    edge_type: String,
    /// Direction string as passed to the constructor.
    direction: String,
}
127
128impl AdjacencyDataset {
129 pub fn new(base_uri: &str, edge_type: &str, label: &str, direction: &str) -> Self {
130 let uri = format!(
131 "{}/adjacency/{}_{}_{}",
132 base_uri, direction, edge_type, label
133 );
134 Self {
135 uri,
136 edge_type: edge_type.to_string(),
137 direction: direction.to_string(),
138 }
139 }
140
141 #[cfg(feature = "lance-backend")]
142 pub async fn open(&self) -> Result<Arc<Dataset>> {
143 self.open_at(None).await
144 }
145
146 #[cfg(feature = "lance-backend")]
147 pub async fn open_at(&self, version: Option<u64>) -> Result<Arc<Dataset>> {
148 let mut ds = Dataset::open(&self.uri).await?;
149 if let Some(v) = version {
150 ds = ds.checkout_version(v).await?;
151 }
152 Ok(Arc::new(ds))
153 }
154
155 pub fn get_arrow_schema(&self) -> Arc<ArrowSchema> {
156 let fields = vec![
157 Field::new("src_vid", ArrowDataType::UInt64, false),
158 Field::new(
160 "neighbors",
161 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::UInt64, true))),
162 false,
163 ),
164 Field::new(
166 "edge_ids",
167 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::UInt64, true))),
168 false,
169 ),
170 ];
171
172 Arc::new(ArrowSchema::new(fields))
173 }
174
175 #[cfg(feature = "lance-backend")]
176 pub async fn read_adjacency(&self, vid: Vid) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
177 self.read_adjacency_at(vid, None).await
178 }
179
180 #[cfg(feature = "lance-backend")]
181 pub async fn read_adjacency_at(
182 &self,
183 vid: Vid,
184 version: Option<u64>,
185 ) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
186 let ds = match self.open_at(version).await {
187 Ok(ds) => ds,
188 Err(_) => return Ok(None),
189 };
190
191 let mut stream = ds
192 .scan()
193 .filter(&format!("src_vid = {}", vid.as_u64()))?
194 .try_into_stream()
195 .await?;
196
197 if let Some(batch) = stream.try_next().await? {
198 return extract_adjacency_from_batch(&batch);
199 }
200
201 Ok(None)
202 }
203
204 pub async fn read_adjacency_backend(
212 &self,
213 backend: &dyn StorageBackend,
214 vid: Vid,
215 ) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
216 let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
217
218 if !backend.table_exists(&table_name).await? {
219 return Ok(None);
220 }
221
222 let filter = format!("src_vid = {}", vid.as_u64());
223 let batches = backend
224 .scan(ScanRequest::all(&table_name).with_filter(filter))
225 .await?;
226
227 for batch in batches {
228 if let Some(result) = extract_adjacency_from_batch(&batch)? {
229 return Ok(Some(result));
230 }
231 }
232
233 Ok(None)
234 }
235
236 pub async fn read_adjacency_backend_batch(
241 &self,
242 backend: &dyn StorageBackend,
243 vids: &[Vid],
244 ) -> Result<HashMap<Vid, (Vec<Vid>, Vec<Eid>)>> {
245 if vids.is_empty() {
246 return Ok(HashMap::new());
247 }
248
249 let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
250
251 if !backend.table_exists(&table_name).await? {
252 return Ok(HashMap::new());
253 }
254
255 let vid_list = vids
257 .iter()
258 .map(|v| v.as_u64().to_string())
259 .collect::<Vec<_>>()
260 .join(", ");
261 let filter = format!("src_vid IN ({})", vid_list);
262 let batches = backend
263 .scan(ScanRequest::all(&table_name).with_filter(filter))
264 .await?;
265
266 let mut result = HashMap::new();
267 for batch in batches {
268 let batch_result = extract_adjacency_from_batch_grouped(&batch)?;
269 result.extend(batch_result);
270 }
271
272 Ok(result)
273 }
274
275 pub async fn open_or_create(&self, backend: &dyn StorageBackend) -> Result<()> {
277 let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
278 let arrow_schema = self.get_arrow_schema();
279 backend
280 .open_or_create_table(&table_name, arrow_schema)
281 .await
282 }
283
284 pub async fn write_chunk(
288 &self,
289 backend: &dyn StorageBackend,
290 batch: RecordBatch,
291 ) -> Result<()> {
292 let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
293 if backend.table_exists(&table_name).await? {
294 backend
295 .write(&table_name, vec![batch], WriteMode::Append)
296 .await
297 } else {
298 backend.create_table(&table_name, vec![batch]).await
299 }
300 }
301
302 pub fn table_name(&self) -> String {
304 table_names::adjacency_table_name(&self.edge_type, &self.direction)
305 }
306
307 pub async fn replace(&self, backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
311 let table_name = self.table_name();
312 let arrow_schema = self.get_arrow_schema();
313 backend
314 .replace_table_atomic(&table_name, vec![batch], arrow_schema)
315 .await
316 }
317}