1use crate::lancedb::LanceDbStore;
5use anyhow::{Result, anyhow};
6use arrow_array::{ListArray, RecordBatch, UInt64Array};
7use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
8use futures::TryStreamExt;
9use lance::dataset::Dataset;
10use lancedb::Table;
11use std::collections::HashMap;
12use std::sync::Arc;
13use uni_common::core::id::{Eid, Vid};
14
15type AdjacencyLists = (Vec<Vid>, Vec<Eid>);
17
18type GroupedAdjacencyLists = HashMap<Vid, (Vec<Vid>, Vec<Eid>)>;
20
21fn downcast_adjacency_lists(batch: &RecordBatch) -> Result<(&ListArray, &ListArray)> {
23 let neighbors_list = batch
24 .column_by_name("neighbors")
25 .ok_or(anyhow!("Missing neighbors"))?
26 .as_any()
27 .downcast_ref::<ListArray>()
28 .ok_or(anyhow!("Invalid neighbors type"))?;
29
30 let edge_ids_list = batch
31 .column_by_name("edge_ids")
32 .ok_or(anyhow!("Missing edge_ids"))?
33 .as_any()
34 .downcast_ref::<ListArray>()
35 .ok_or(anyhow!("Invalid edge_ids type"))?;
36
37 Ok((neighbors_list, edge_ids_list))
38}
39
40fn extract_row_adjacency(
42 neighbors_list: &ListArray,
43 edge_ids_list: &ListArray,
44 row_idx: usize,
45) -> Result<(Vec<Vid>, Vec<Eid>)> {
46 let neighbors_array = neighbors_list.value(row_idx);
47 let neighbors_uint64 = neighbors_array
48 .as_any()
49 .downcast_ref::<UInt64Array>()
50 .ok_or(anyhow!("Invalid neighbors inner type"))?;
51
52 let edge_ids_array = edge_ids_list.value(row_idx);
53 let edge_ids_uint64 = edge_ids_array
54 .as_any()
55 .downcast_ref::<UInt64Array>()
56 .ok_or(anyhow!("Invalid edge_ids inner type"))?;
57
58 let neighbors = (0..neighbors_uint64.len())
59 .map(|i| Vid::from(neighbors_uint64.value(i)))
60 .collect();
61 let eids = (0..edge_ids_uint64.len())
62 .map(|i| Eid::from(edge_ids_uint64.value(i)))
63 .collect();
64
65 Ok((neighbors, eids))
66}
67
68fn extract_adjacency_from_batch(batch: &RecordBatch) -> Result<Option<AdjacencyLists>> {
72 if batch.num_rows() == 0 {
73 return Ok(None);
74 }
75
76 let (neighbors_list, edge_ids_list) = downcast_adjacency_lists(batch)?;
77
78 let mut all_neighbors = Vec::new();
79 let mut all_eids = Vec::new();
80
81 for row_idx in 0..batch.num_rows() {
82 let (neighbors, eids) = extract_row_adjacency(neighbors_list, edge_ids_list, row_idx)?;
83 all_neighbors.extend(neighbors);
84 all_eids.extend(eids);
85 }
86
87 Ok(Some((all_neighbors, all_eids)))
88}
89
90fn extract_adjacency_from_batch_grouped(batch: &RecordBatch) -> Result<GroupedAdjacencyLists> {
94 if batch.num_rows() == 0 {
95 return Ok(HashMap::new());
96 }
97
98 let src_vid_col = batch
99 .column_by_name("src_vid")
100 .ok_or(anyhow!("Missing src_vid"))?
101 .as_any()
102 .downcast_ref::<UInt64Array>()
103 .ok_or(anyhow!("Invalid src_vid type"))?;
104
105 let (neighbors_list, edge_ids_list) = downcast_adjacency_lists(batch)?;
106
107 let mut result: HashMap<Vid, (Vec<Vid>, Vec<Eid>)> = HashMap::new();
108
109 for row_idx in 0..batch.num_rows() {
110 let src_vid = Vid::from(src_vid_col.value(row_idx));
111 let (neighbors, eids) = extract_row_adjacency(neighbors_list, edge_ids_list, row_idx)?;
112 result.insert(src_vid, (neighbors, eids));
113 }
114
115 Ok(result)
116}
117
/// Handle to adjacency storage for one edge type / direction.
///
/// `uri` locates the Lance dataset. `edge_type` and `direction` are used to
/// derive the LanceDB adjacency table name.
#[derive(Debug, Clone)]
pub struct AdjacencyDataset {
    uri: String,
    edge_type: String,
    direction: String,
}
123
124impl AdjacencyDataset {
125 pub fn new(base_uri: &str, edge_type: &str, label: &str, direction: &str) -> Self {
126 let uri = format!(
127 "{}/adjacency/{}_{}_{}",
128 base_uri, direction, edge_type, label
129 );
130 Self {
131 uri,
132 edge_type: edge_type.to_string(),
133 direction: direction.to_string(),
134 }
135 }
136
137 pub async fn open(&self) -> Result<Arc<Dataset>> {
138 self.open_at(None).await
139 }
140
141 pub async fn open_at(&self, version: Option<u64>) -> Result<Arc<Dataset>> {
142 let mut ds = Dataset::open(&self.uri).await?;
143 if let Some(v) = version {
144 ds = ds.checkout_version(v).await?;
145 }
146 Ok(Arc::new(ds))
147 }
148
149 pub fn get_arrow_schema(&self) -> Arc<ArrowSchema> {
150 let fields = vec![
151 Field::new("src_vid", ArrowDataType::UInt64, false),
152 Field::new(
154 "neighbors",
155 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::UInt64, true))),
156 false,
157 ),
158 Field::new(
160 "edge_ids",
161 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::UInt64, true))),
162 false,
163 ),
164 ];
165
166 Arc::new(ArrowSchema::new(fields))
167 }
168
169 pub async fn read_adjacency(&self, vid: Vid) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
170 self.read_adjacency_at(vid, None).await
171 }
172
173 pub async fn read_adjacency_at(
174 &self,
175 vid: Vid,
176 version: Option<u64>,
177 ) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
178 let ds = match self.open_at(version).await {
179 Ok(ds) => ds,
180 Err(_) => return Ok(None),
181 };
182
183 let mut stream = ds
184 .scan()
185 .filter(&format!("src_vid = {}", vid.as_u64()))?
186 .try_into_stream()
187 .await?;
188
189 if let Some(batch) = stream.try_next().await? {
190 return extract_adjacency_from_batch(&batch);
191 }
192
193 Ok(None)
194 }
195
196 pub async fn read_adjacency_lancedb(
204 &self,
205 store: &LanceDbStore,
206 vid: Vid,
207 ) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
208 let table = match self.open_lancedb(store).await {
209 Ok(t) => t,
210 Err(_) => return Ok(None),
211 };
212
213 use lancedb::query::{ExecutableQuery, QueryBase};
214
215 let query = table.query().only_if(format!("src_vid = {}", vid.as_u64()));
216 let stream = query.execute().await?;
217 let batches: Vec<RecordBatch> = stream.try_collect().await?;
218
219 for batch in batches {
220 if let Some(result) = extract_adjacency_from_batch(&batch)? {
221 return Ok(Some(result));
222 }
223 }
224
225 Ok(None)
226 }
227
228 pub async fn read_adjacency_lancedb_batch(
233 &self,
234 store: &LanceDbStore,
235 vids: &[Vid],
236 ) -> Result<HashMap<Vid, (Vec<Vid>, Vec<Eid>)>> {
237 if vids.is_empty() {
238 return Ok(HashMap::new());
239 }
240
241 let table = match self.open_lancedb(store).await {
242 Ok(t) => t,
243 Err(_) => return Ok(HashMap::new()),
244 };
245
246 use lancedb::query::{ExecutableQuery, QueryBase};
247
248 let vid_list = vids
250 .iter()
251 .map(|v| v.as_u64().to_string())
252 .collect::<Vec<_>>()
253 .join(", ");
254 let query = table.query().only_if(format!("src_vid IN ({})", vid_list));
255 let stream = query.execute().await?;
256 let batches: Vec<RecordBatch> = stream.try_collect().await?;
257
258 let mut result = HashMap::new();
259 for batch in batches {
260 let batch_result = extract_adjacency_from_batch_grouped(&batch)?;
261 result.extend(batch_result);
262 }
263
264 Ok(result)
265 }
266
267 pub async fn open_lancedb(&self, store: &LanceDbStore) -> Result<Table> {
269 store
270 .open_adjacency_table(&self.edge_type, &self.direction)
271 .await
272 }
273
274 pub async fn open_or_create_lancedb(&self, store: &LanceDbStore) -> Result<Table> {
276 let arrow_schema = self.get_arrow_schema();
277 store
278 .open_or_create_adjacency_table(&self.edge_type, &self.direction, arrow_schema)
279 .await
280 }
281
282 pub async fn write_chunk_lancedb(
286 &self,
287 store: &LanceDbStore,
288 batch: RecordBatch,
289 ) -> Result<Table> {
290 let table_name = LanceDbStore::adjacency_table_name(&self.edge_type, &self.direction);
291
292 if store.table_exists(&table_name).await? {
293 let table = store.open_table(&table_name).await?;
294 store.append_to_table(&table, vec![batch]).await?;
295 Ok(table)
296 } else {
297 store.create_table(&table_name, vec![batch]).await
298 }
299 }
300
301 pub fn lancedb_table_name(&self) -> String {
303 LanceDbStore::adjacency_table_name(&self.edge_type, &self.direction)
304 }
305
306 pub async fn replace_lancedb(&self, store: &LanceDbStore, batch: RecordBatch) -> Result<Table> {
311 let table_name = self.lancedb_table_name();
312 let arrow_schema = self.get_arrow_schema();
313 store
314 .replace_table_atomic(&table_name, vec![batch], arrow_schema)
315 .await
316 }
317}