Skip to main content

haystack_core/graph/
csr.rs

// Compressed Sparse Row (CSR) adjacency — read-optimized graph traversal.
//
// Provides cache-friendly, contiguous-memory layout for ref edges.
// Forward edges are sorted by source entity ID; reverse edges by target.
// All edges for a single vertex are contiguous in memory, eliminating
// HashMap/SmallVec pointer-chasing overhead during traversal.
//
// This is a **read-only snapshot** rebuilt from the mutable RefAdjacency.
// It is not updated incrementally — call `rebuild()` after mutations.
10
11use std::collections::HashSet;
12
/// Read-optimized compressed sparse row (CSR) adjacency for graph traversal.
///
/// Memory layout for N entity rows with E total forward edges:
/// - Forward: `fwd_offsets[N+1]` + `fwd_targets[E]` + `fwd_tags[E]`
/// - Reverse: a separate CSR keyed by target ref_val
///   (`rev_keys` + `rev_offsets` + `rev_sources` + `rev_tags`)
///
/// The forward edges of entity `eid` occupy one contiguous range:
/// `fwd_targets[fwd_offsets[eid]..fwd_offsets[eid + 1]]`, with the matching
/// tag names at the same indices of `fwd_tags`.
#[derive(Debug, Clone)]
pub struct CsrAdjacency {
    // ── Forward edges (source entity_id → targets) ──
    /// `fwd_offsets[eid]..fwd_offsets[eid+1]` is the range in
    /// `fwd_targets`/`fwd_tags` for edges from entity `eid`.
    /// Length = max_entity_id + 2.
    fwd_offsets: Vec<usize>,
    /// Target ref_vals, contiguous per source entity.
    fwd_targets: Vec<String>,
    /// Ref tag names (parallel to `fwd_targets`).
    fwd_tags: Vec<String>,

    // ── Reverse edges (target ref_val → source entity_ids) ──
    /// Sorted unique target ref_vals.
    rev_keys: Vec<String>,
    /// `rev_offsets[i]..rev_offsets[i+1]` is the range in `rev_sources`/`rev_tags`
    /// for the target `rev_keys[i]`. Length = `rev_keys.len() + 1`.
    rev_offsets: Vec<usize>,
    /// Source entity IDs, contiguous per target.
    rev_sources: Vec<usize>,
    /// Ref tag names (parallel to `rev_sources`).
    rev_tags: Vec<String>,
}
40
41impl CsrAdjacency {
42    /// Build a CSR snapshot from the mutable HashMap-based adjacency.
43    pub fn from_ref_adjacency(adj: &super::adjacency::RefAdjacency, max_entity_id: usize) -> Self {
44        // ── Forward CSR ──
45        let fwd_data = adj.forward_raw();
46        let num_rows = max_entity_id + 1;
47        let mut fwd_offsets = vec![0usize; num_rows + 1];
48        let mut fwd_targets = Vec::new();
49        let mut fwd_tags = Vec::new();
50
51        // Count edges per entity.
52        for (&eid, edges) in fwd_data {
53            if eid < num_rows {
54                fwd_offsets[eid + 1] = edges.len();
55            }
56        }
57        // Prefix sum.
58        for i in 1..=num_rows {
59            fwd_offsets[i] += fwd_offsets[i - 1];
60        }
61        // Fill edge arrays.
62        fwd_targets.resize(fwd_offsets[num_rows], String::new());
63        fwd_tags.resize(fwd_offsets[num_rows], String::new());
64        // Track current write position per entity.
65        let mut write_pos: Vec<usize> = fwd_offsets[..num_rows].to_vec();
66        for (&eid, edges) in fwd_data {
67            if eid < num_rows {
68                for (tag, target) in edges {
69                    let pos = write_pos[eid];
70                    fwd_targets[pos] = target.clone();
71                    fwd_tags[pos] = tag.clone();
72                    write_pos[eid] += 1;
73                }
74            }
75        }
76
77        // ── Reverse CSR ──
78        let rev_data = adj.reverse_raw();
79        // Collect and sort keys for binary search.
80        let mut rev_keys: Vec<String> = rev_data.keys().cloned().collect();
81        rev_keys.sort();
82
83        let mut rev_offsets = vec![0usize; rev_keys.len() + 1];
84        let mut rev_sources = Vec::new();
85        let mut rev_tags = Vec::new();
86
87        for (i, key) in rev_keys.iter().enumerate() {
88            if let Some(edges) = rev_data.get(key) {
89                rev_offsets[i + 1] = rev_offsets[i] + edges.len();
90                for (tag, src) in edges {
91                    rev_sources.push(*src);
92                    rev_tags.push(tag.clone());
93                }
94            } else {
95                rev_offsets[i + 1] = rev_offsets[i];
96            }
97        }
98
99        Self {
100            fwd_offsets,
101            fwd_targets,
102            fwd_tags,
103            rev_keys,
104            rev_offsets,
105            rev_sources,
106            rev_tags,
107        }
108    }
109
110    /// Get target ref values from an entity, optionally filtered by ref type.
111    ///
112    /// Returns a contiguous slice view (no allocation for unfiltered case).
113    pub fn targets_from(&self, entity_id: usize, ref_type: Option<&str>) -> Vec<String> {
114        if entity_id + 1 >= self.fwd_offsets.len() {
115            return Vec::new();
116        }
117        let start = self.fwd_offsets[entity_id];
118        let end = self.fwd_offsets[entity_id + 1];
119
120        match ref_type {
121            None => self.fwd_targets[start..end].to_vec(),
122            Some(rt) => (start..end)
123                .filter(|&i| self.fwd_tags[i] == rt)
124                .map(|i| self.fwd_targets[i].clone())
125                .collect(),
126        }
127    }
128
129    /// Get source entity ids that reference `target_ref_val`, optionally
130    /// filtered by ref type.
131    pub fn sources_to(&self, target_ref_val: &str, ref_type: Option<&str>) -> Vec<usize> {
132        let idx = match self
133            .rev_keys
134            .binary_search_by(|k| k.as_str().cmp(target_ref_val))
135        {
136            Ok(i) => i,
137            Err(_) => return Vec::new(),
138        };
139        let start = self.rev_offsets[idx];
140        let end = self.rev_offsets[idx + 1];
141
142        match ref_type {
143            None => self.rev_sources[start..end].to_vec(),
144            Some(rt) => (start..end)
145                .filter(|&i| self.rev_tags[i] == rt)
146                .map(|i| self.rev_sources[i])
147                .collect(),
148        }
149    }
150
151    /// Get targets from an entity, overlaying a patch buffer.
152    pub fn targets_from_patched(
153        &self,
154        entity_id: usize,
155        ref_type: Option<&str>,
156        patch: &CsrPatch,
157    ) -> Vec<String> {
158        // If entity was removed/invalidated in patch, skip base CSR for this entity.
159        let mut results = if patch.removed_entities.contains(&entity_id) {
160            Vec::new()
161        } else {
162            self.targets_from(entity_id, ref_type)
163        };
164
165        // Add patch edges for this entity.
166        for (eid, tag, target) in &patch.added_edges {
167            if *eid == entity_id && ref_type.is_none_or(|rt| tag == rt) {
168                results.push(target.clone());
169            }
170        }
171        results
172    }
173
174    /// Get source entity IDs that reference target, overlaying a patch buffer.
175    pub fn sources_to_patched(
176        &self,
177        target_ref_val: &str,
178        ref_type: Option<&str>,
179        patch: &CsrPatch,
180    ) -> Vec<usize> {
181        // Base CSR sources, excluding removed entities.
182        let mut results: Vec<usize> = self
183            .sources_to(target_ref_val, ref_type)
184            .into_iter()
185            .filter(|eid| !patch.removed_entities.contains(eid))
186            .collect();
187
188        // Add patch edges that target this ref_val.
189        for (eid, tag, target) in &patch.added_edges {
190            if target == target_ref_val
191                && ref_type.is_none_or(|rt| tag == rt)
192                && !results.contains(eid)
193            {
194                results.push(*eid);
195            }
196        }
197        results
198    }
199
200    /// Total number of forward edges stored.
201    pub fn edge_count(&self) -> usize {
202        self.fwd_targets.len()
203    }
204}
205
/// Buffer of pending edge changes, overlaid on the base CSR at query time.
#[derive(Debug, Default)]
pub struct CsrPatch {
    /// Forward edges added since the snapshot: (entity_id, ref_tag, target_ref_val).
    added_edges: Vec<(usize, String, String)>,
    /// IDs of entities whose snapshot forward edges are no longer valid.
    removed_entities: HashSet<usize>,
}

impl CsrPatch {
    /// Create a patch with no recorded operations.
    pub fn new() -> Self {
        Self {
            added_edges: Vec::new(),
            removed_entities: HashSet::new(),
        }
    }

    /// Append a forward edge `eid --tag--> target` to the patch.
    pub fn add_edge(&mut self, eid: usize, tag: &str, target: &str) {
        let edge = (eid, String::from(tag), String::from(target));
        self.added_edges.push(edge);
    }

    /// Mark an entity's snapshot edges as invalidated (removed or changed).
    pub fn remove_entity(&mut self, eid: usize) {
        // The set already ignores repeats, so the "was it new" flag is unused.
        let _newly_marked = self.removed_entities.insert(eid);
    }

    /// Number of recorded operations: added edges plus invalidated entities.
    pub fn len(&self) -> usize {
        let edge_ops = self.added_edges.len();
        let removal_ops = self.removed_entities.len();
        edge_ops + removal_ops
    }

    /// `true` when nothing has been recorded in the patch.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::adjacency::RefAdjacency;

    /// Small fixture graph: one equip and two points hanging off one site.
    /// Returns the adjacency together with the highest entity ID in use.
    fn build_test_adjacency() -> (RefAdjacency, usize) {
        let mut adj = RefAdjacency::new();
        // equip-1 (eid=1) -> site-1 via siteRef
        adj.add(1, "siteRef", "site-1");
        // point-1 (eid=2) -> equip-1 via equipRef, site-1 via siteRef
        adj.add(2, "equipRef", "equip-1");
        adj.add(2, "siteRef", "site-1");
        // point-2 (eid=3) -> equip-1 via equipRef
        adj.add(3, "equipRef", "equip-1");
        // `from_ref_adjacency` takes the max entity ID and sizes the rows itself.
        (adj, 3)
    }

    #[test]
    fn csr_targets_from_all() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let targets = csr.targets_from(2, None);
        assert_eq!(targets.len(), 2);
        assert!(targets.contains(&"equip-1".to_string()));
        assert!(targets.contains(&"site-1".to_string()));
    }

    #[test]
    fn csr_targets_from_filtered() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let targets = csr.targets_from(2, Some("siteRef"));
        assert_eq!(targets, vec!["site-1".to_string()]);
    }

    #[test]
    fn csr_sources_to_all() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let mut sources = csr.sources_to("site-1", None);
        sources.sort();
        assert_eq!(sources, vec![1, 2]);
    }

    #[test]
    fn csr_sources_to_filtered() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let sources = csr.sources_to("equip-1", Some("equipRef"));
        assert_eq!(sources.len(), 2);
        assert!(sources.contains(&2));
        assert!(sources.contains(&3));
    }

    #[test]
    fn csr_nonexistent_entity() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        assert!(csr.targets_from(99, None).is_empty());
    }

    #[test]
    fn csr_nonexistent_target() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        assert!(csr.sources_to("nonexistent", None).is_empty());
    }

    #[test]
    fn csr_edge_count() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        assert_eq!(csr.edge_count(), 4); // 1 + 2 + 1
    }

    #[test]
    fn csr_empty_graph() {
        let adj = RefAdjacency::new();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, 0);
        assert!(csr.targets_from(0, None).is_empty());
        assert!(csr.sources_to("anything", None).is_empty());
        assert_eq!(csr.edge_count(), 0);
    }

    #[test]
    fn csr_empty_patch_matches_base() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        let patch = CsrPatch::new();
        assert!(patch.is_empty());

        assert_eq!(
            csr.targets_from_patched(2, None, &patch),
            csr.targets_from(2, None)
        );
        assert_eq!(
            csr.sources_to_patched("site-1", None, &patch),
            csr.sources_to("site-1", None)
        );
    }

    #[test]
    fn csr_patch_overlays_forward_edges() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let mut patch = CsrPatch::new();
        // Entity 1 is invalidated and re-pointed at a different site.
        patch.remove_entity(1);
        patch.add_edge(1, "siteRef", "site-2");
        // Entity 3 gains an edge on top of its base edges.
        patch.add_edge(3, "siteRef", "site-1");
        assert_eq!(patch.len(), 3);

        assert_eq!(
            csr.targets_from_patched(1, None, &patch),
            vec!["site-2".to_string()]
        );
        // Base edges come first, patch edges are appended after them.
        assert_eq!(
            csr.targets_from_patched(3, None, &patch),
            vec!["equip-1".to_string(), "site-1".to_string()]
        );
        assert_eq!(
            csr.targets_from_patched(3, Some("siteRef"), &patch),
            vec!["site-1".to_string()]
        );
        // Entities the patch does not mention are unchanged.
        assert_eq!(csr.targets_from_patched(2, None, &patch).len(), 2);
    }

    #[test]
    fn csr_patch_overlays_reverse_edges() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let mut patch = CsrPatch::new();
        patch.remove_entity(1);
        patch.add_edge(1, "siteRef", "site-2");
        patch.add_edge(3, "siteRef", "site-1");

        // Entity 1 drops out of site-1, entity 3 is appended from the patch.
        assert_eq!(csr.sources_to_patched("site-1", None, &patch), vec![2, 3]);
        assert_eq!(csr.sources_to_patched("site-2", None, &patch), vec![1]);
        assert!(csr
            .sources_to_patched("site-1", Some("equipRef"), &patch)
            .is_empty());
    }

    #[test]
    fn csr_patch_does_not_duplicate_existing_source() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let mut patch = CsrPatch::new();
        // Entity 2 already references site-1 in the base snapshot.
        patch.add_edge(2, "siteRef", "site-1");

        let mut sources = csr.sources_to_patched("site-1", None, &patch);
        sources.sort();
        assert_eq!(sources, vec![1, 2]);
    }
}