Skip to main content

uni_store/storage/
adjacency.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4use crate::backend::StorageBackend;
5use crate::backend::table_names;
6use crate::backend::types::{ScanRequest, WriteMode};
7use anyhow::{Result, anyhow};
8use arrow_array::{ListArray, RecordBatch, UInt64Array};
9use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
10#[cfg(feature = "lance-backend")]
11use futures::TryStreamExt;
12#[cfg(feature = "lance-backend")]
13use lance::dataset::Dataset;
14use std::collections::HashMap;
15use std::sync::Arc;
16use uni_common::core::id::{Eid, Vid};
17
18/// Type alias for adjacency list data (neighbors, edge_ids).
19type AdjacencyLists = (Vec<Vid>, Vec<Eid>);
20
21/// Type alias for grouped adjacency data by source vertex.
22type GroupedAdjacencyLists = HashMap<Vid, (Vec<Vid>, Vec<Eid>)>;
23
24/// Downcast the neighbors and edge_ids list columns from a RecordBatch.
25fn downcast_adjacency_lists(batch: &RecordBatch) -> Result<(&ListArray, &ListArray)> {
26    let neighbors_list = batch
27        .column_by_name("neighbors")
28        .ok_or(anyhow!("Missing neighbors"))?
29        .as_any()
30        .downcast_ref::<ListArray>()
31        .ok_or(anyhow!("Invalid neighbors type"))?;
32
33    let edge_ids_list = batch
34        .column_by_name("edge_ids")
35        .ok_or(anyhow!("Missing edge_ids"))?
36        .as_any()
37        .downcast_ref::<ListArray>()
38        .ok_or(anyhow!("Invalid edge_ids type"))?;
39
40    Ok((neighbors_list, edge_ids_list))
41}
42
43/// Extract (neighbors, edge_ids) from a single row of the adjacency list columns.
44fn extract_row_adjacency(
45    neighbors_list: &ListArray,
46    edge_ids_list: &ListArray,
47    row_idx: usize,
48) -> Result<(Vec<Vid>, Vec<Eid>)> {
49    let neighbors_array = neighbors_list.value(row_idx);
50    let neighbors_uint64 = neighbors_array
51        .as_any()
52        .downcast_ref::<UInt64Array>()
53        .ok_or(anyhow!("Invalid neighbors inner type"))?;
54
55    let edge_ids_array = edge_ids_list.value(row_idx);
56    let edge_ids_uint64 = edge_ids_array
57        .as_any()
58        .downcast_ref::<UInt64Array>()
59        .ok_or(anyhow!("Invalid edge_ids inner type"))?;
60
61    let neighbors = (0..neighbors_uint64.len())
62        .map(|i| Vid::from(neighbors_uint64.value(i)))
63        .collect();
64    let eids = (0..edge_ids_uint64.len())
65        .map(|i| Eid::from(edge_ids_uint64.value(i)))
66        .collect();
67
68    Ok((neighbors, eids))
69}
70
71/// Extract adjacency data (neighbors, edge IDs) from a single row of a RecordBatch.
72///
73/// Returns `None` if the batch is empty or columns are missing.
74fn extract_adjacency_from_batch(batch: &RecordBatch) -> Result<Option<AdjacencyLists>> {
75    if batch.num_rows() == 0 {
76        return Ok(None);
77    }
78
79    let (neighbors_list, edge_ids_list) = downcast_adjacency_lists(batch)?;
80
81    let mut all_neighbors = Vec::new();
82    let mut all_eids = Vec::new();
83
84    for row_idx in 0..batch.num_rows() {
85        let (neighbors, eids) = extract_row_adjacency(neighbors_list, edge_ids_list, row_idx)?;
86        all_neighbors.extend(neighbors);
87        all_eids.extend(eids);
88    }
89
90    Ok(Some((all_neighbors, all_eids)))
91}
92
93/// Extract adjacency data from a batch, grouped by src_vid.
94///
95/// Returns a HashMap mapping each src_vid to its (neighbors, edge_ids).
96fn extract_adjacency_from_batch_grouped(batch: &RecordBatch) -> Result<GroupedAdjacencyLists> {
97    if batch.num_rows() == 0 {
98        return Ok(HashMap::new());
99    }
100
101    let src_vid_col = batch
102        .column_by_name("src_vid")
103        .ok_or(anyhow!("Missing src_vid"))?
104        .as_any()
105        .downcast_ref::<UInt64Array>()
106        .ok_or(anyhow!("Invalid src_vid type"))?;
107
108    let (neighbors_list, edge_ids_list) = downcast_adjacency_lists(batch)?;
109
110    let mut result: HashMap<Vid, (Vec<Vid>, Vec<Eid>)> = HashMap::new();
111
112    for row_idx in 0..batch.num_rows() {
113        let src_vid = Vid::from(src_vid_col.value(row_idx));
114        let (neighbors, eids) = extract_row_adjacency(neighbors_list, edge_ids_list, row_idx)?;
115        result.insert(src_vid, (neighbors, eids));
116    }
117
118    Ok(result)
119}
120
121pub struct AdjacencyDataset {
122    #[cfg_attr(not(feature = "lance-backend"), allow(dead_code))]
123    uri: String,
124    edge_type: String,
125    direction: String,
126}
127
128impl AdjacencyDataset {
129    pub fn new(base_uri: &str, edge_type: &str, label: &str, direction: &str) -> Self {
130        let uri = format!(
131            "{}/adjacency/{}_{}_{}",
132            base_uri, direction, edge_type, label
133        );
134        Self {
135            uri,
136            edge_type: edge_type.to_string(),
137            direction: direction.to_string(),
138        }
139    }
140
141    #[cfg(feature = "lance-backend")]
142    pub async fn open(&self) -> Result<Arc<Dataset>> {
143        self.open_at(None).await
144    }
145
146    #[cfg(feature = "lance-backend")]
147    pub async fn open_at(&self, version: Option<u64>) -> Result<Arc<Dataset>> {
148        let mut ds = Dataset::open(&self.uri).await?;
149        if let Some(v) = version {
150            ds = ds.checkout_version(v).await?;
151        }
152        Ok(Arc::new(ds))
153    }
154
155    pub fn get_arrow_schema(&self) -> Arc<ArrowSchema> {
156        let fields = vec![
157            Field::new("src_vid", ArrowDataType::UInt64, false),
158            // neighbors: list<uint64>
159            Field::new(
160                "neighbors",
161                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::UInt64, true))),
162                false,
163            ),
164            // edge_ids: list<uint64>
165            Field::new(
166                "edge_ids",
167                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::UInt64, true))),
168                false,
169            ),
170        ];
171
172        Arc::new(ArrowSchema::new(fields))
173    }
174
175    #[cfg(feature = "lance-backend")]
176    pub async fn read_adjacency(&self, vid: Vid) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
177        self.read_adjacency_at(vid, None).await
178    }
179
180    #[cfg(feature = "lance-backend")]
181    pub async fn read_adjacency_at(
182        &self,
183        vid: Vid,
184        version: Option<u64>,
185    ) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
186        let ds = match self.open_at(version).await {
187            Ok(ds) => ds,
188            Err(_) => return Ok(None),
189        };
190
191        let mut stream = ds
192            .scan()
193            .filter(&format!("src_vid = {}", vid.as_u64()))?
194            .try_into_stream()
195            .await?;
196
197        if let Some(batch) = stream.try_next().await? {
198            return extract_adjacency_from_batch(&batch);
199        }
200
201        Ok(None)
202    }
203
204    // ========================================================================
205    // Backend-agnostic Methods
206    // ========================================================================
207
208    /// Read adjacency data for a vertex from the storage backend.
209    ///
210    /// Returns `None` if the table doesn't exist or no data for the vertex.
211    pub async fn read_adjacency_backend(
212        &self,
213        backend: &dyn StorageBackend,
214        vid: Vid,
215    ) -> Result<Option<(Vec<Vid>, Vec<Eid>)>> {
216        let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
217
218        if !backend.table_exists(&table_name).await? {
219            return Ok(None);
220        }
221
222        let filter = format!("src_vid = {}", vid.as_u64());
223        let batches = backend
224            .scan(ScanRequest::all(&table_name).with_filter(filter))
225            .await?;
226
227        for batch in batches {
228            if let Some(result) = extract_adjacency_from_batch(&batch)? {
229                return Ok(Some(result));
230            }
231        }
232
233        Ok(None)
234    }
235
236    /// Read adjacency data for multiple vertices in a single batch query.
237    ///
238    /// Returns a HashMap mapping each vid to its (neighbors, edge_ids).
239    /// VIDs with no adjacency data will not be in the map.
240    pub async fn read_adjacency_backend_batch(
241        &self,
242        backend: &dyn StorageBackend,
243        vids: &[Vid],
244    ) -> Result<HashMap<Vid, (Vec<Vid>, Vec<Eid>)>> {
245        if vids.is_empty() {
246            return Ok(HashMap::new());
247        }
248
249        let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
250
251        if !backend.table_exists(&table_name).await? {
252            return Ok(HashMap::new());
253        }
254
255        // Build IN filter for batch query
256        let vid_list = vids
257            .iter()
258            .map(|v| v.as_u64().to_string())
259            .collect::<Vec<_>>()
260            .join(", ");
261        let filter = format!("src_vid IN ({})", vid_list);
262        let batches = backend
263            .scan(ScanRequest::all(&table_name).with_filter(filter))
264            .await?;
265
266        let mut result = HashMap::new();
267        for batch in batches {
268            let batch_result = extract_adjacency_from_batch_grouped(&batch)?;
269            result.extend(batch_result);
270        }
271
272        Ok(result)
273    }
274
275    /// Open or create an adjacency table via the storage backend.
276    pub async fn open_or_create(&self, backend: &dyn StorageBackend) -> Result<()> {
277        let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
278        let arrow_schema = self.get_arrow_schema();
279        backend
280            .open_or_create_table(&table_name, arrow_schema)
281            .await
282    }
283
284    /// Write a chunk to an adjacency table.
285    ///
286    /// Creates the table if it doesn't exist, otherwise appends to it.
287    pub async fn write_chunk(
288        &self,
289        backend: &dyn StorageBackend,
290        batch: RecordBatch,
291    ) -> Result<()> {
292        let table_name = table_names::adjacency_table_name(&self.edge_type, &self.direction);
293        if backend.table_exists(&table_name).await? {
294            backend
295                .write(&table_name, vec![batch], WriteMode::Append)
296                .await
297        } else {
298            backend.create_table(&table_name, vec![batch]).await
299        }
300    }
301
302    /// Get the table name for this adjacency dataset.
303    pub fn table_name(&self) -> String {
304        table_names::adjacency_table_name(&self.edge_type, &self.direction)
305    }
306
307    /// Replace an adjacency table's contents atomically.
308    ///
309    /// Used by compaction to rewrite the table with merged data.
310    pub async fn replace(&self, backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
311        let table_name = self.table_name();
312        let arrow_schema = self.get_arrow_schema();
313        backend
314            .replace_table_atomic(&table_name, vec![batch], arrow_schema)
315            .await
316    }
317}