Skip to main content

haystack_core/graph/
csr.rs

// Compressed Sparse Row (CSR) adjacency — read-optimized graph traversal.
//
// Provides a cache-friendly, contiguous-memory layout for ref edges.
// Forward edges are grouped by source entity ID; reverse edges by target
// ref value (targets are kept sorted so they can be binary-searched).
// All edges for a single vertex are contiguous in memory, eliminating
// HashMap/SmallVec pointer-chasing overhead during traversal.
//
// This is a **read-only snapshot** built from the mutable RefAdjacency.
// It is not updated incrementally — build a fresh snapshot with
// `CsrAdjacency::from_ref_adjacency` after mutations.
10
/// Read-optimized compressed sparse row (CSR) adjacency for graph traversal.
///
/// Memory layout for N entity rows with E total forward edges:
/// - Forward: `fwd_offsets[N+1]` + `fwd_targets[E]` + `fwd_tags[E]`
/// - Reverse: a separate CSR keyed by target ref_val
///   (`rev_keys` + `rev_offsets` + `rev_sources` + `rev_tags`)
///
/// The forward edges of entity `eid` occupy one contiguous range:
/// `fwd_targets[fwd_offsets[eid]..fwd_offsets[eid + 1]]`.
#[derive(Debug, Clone)]
pub struct CsrAdjacency {
    // ── Forward edges (source entity_id → targets) ──
    /// `fwd_offsets[eid]..fwd_offsets[eid + 1]` is the range in
    /// `fwd_targets`/`fwd_tags` for edges from entity `eid`.
    /// Length = max_entity_id + 2.
    fwd_offsets: Vec<usize>,
    /// Target ref_vals, contiguous per source entity.
    fwd_targets: Vec<String>,
    /// Ref tag names (parallel to `fwd_targets`).
    fwd_tags: Vec<String>,

    // ── Reverse edges (target ref_val → source entity_ids) ──
    /// Sorted unique target ref_vals; position `i` here selects row `i` of
    /// `rev_offsets`.
    rev_keys: Vec<String>,
    /// `rev_offsets[i]..rev_offsets[i + 1]` is the range in
    /// `rev_sources`/`rev_tags` for `rev_keys[i]`. Length = rev_keys.len() + 1.
    rev_offsets: Vec<usize>,
    /// Source entity IDs, contiguous per target.
    rev_sources: Vec<usize>,
    /// Ref tag names (parallel to `rev_sources`).
    rev_tags: Vec<String>,
}
38
39impl CsrAdjacency {
40    /// Build a CSR snapshot from the mutable HashMap-based adjacency.
41    pub fn from_ref_adjacency(adj: &super::adjacency::RefAdjacency, max_entity_id: usize) -> Self {
42        // ── Forward CSR ──
43        let fwd_data = adj.forward_raw();
44        let num_rows = max_entity_id + 1;
45        let mut fwd_offsets = vec![0usize; num_rows + 1];
46        let mut fwd_targets = Vec::new();
47        let mut fwd_tags = Vec::new();
48
49        // Count edges per entity.
50        for (&eid, edges) in fwd_data {
51            if eid < num_rows {
52                fwd_offsets[eid + 1] = edges.len();
53            }
54        }
55        // Prefix sum.
56        for i in 1..=num_rows {
57            fwd_offsets[i] += fwd_offsets[i - 1];
58        }
59        // Fill edge arrays.
60        fwd_targets.resize(fwd_offsets[num_rows], String::new());
61        fwd_tags.resize(fwd_offsets[num_rows], String::new());
62        // Track current write position per entity.
63        let mut write_pos: Vec<usize> = fwd_offsets[..num_rows].to_vec();
64        for (&eid, edges) in fwd_data {
65            if eid < num_rows {
66                for (tag, target) in edges {
67                    let pos = write_pos[eid];
68                    fwd_targets[pos] = target.clone();
69                    fwd_tags[pos] = tag.clone();
70                    write_pos[eid] += 1;
71                }
72            }
73        }
74
75        // ── Reverse CSR ──
76        let rev_data = adj.reverse_raw();
77        // Collect and sort keys for binary search.
78        let mut rev_keys: Vec<String> = rev_data.keys().cloned().collect();
79        rev_keys.sort();
80
81        let mut rev_offsets = vec![0usize; rev_keys.len() + 1];
82        let mut rev_sources = Vec::new();
83        let mut rev_tags = Vec::new();
84
85        for (i, key) in rev_keys.iter().enumerate() {
86            if let Some(edges) = rev_data.get(key) {
87                rev_offsets[i + 1] = rev_offsets[i] + edges.len();
88                for (tag, src) in edges {
89                    rev_sources.push(*src);
90                    rev_tags.push(tag.clone());
91                }
92            } else {
93                rev_offsets[i + 1] = rev_offsets[i];
94            }
95        }
96
97        Self {
98            fwd_offsets,
99            fwd_targets,
100            fwd_tags,
101            rev_keys,
102            rev_offsets,
103            rev_sources,
104            rev_tags,
105        }
106    }
107
108    /// Get target ref values from an entity, optionally filtered by ref type.
109    ///
110    /// Returns a contiguous slice view (no allocation for unfiltered case).
111    pub fn targets_from(&self, entity_id: usize, ref_type: Option<&str>) -> Vec<String> {
112        if entity_id + 1 >= self.fwd_offsets.len() {
113            return Vec::new();
114        }
115        let start = self.fwd_offsets[entity_id];
116        let end = self.fwd_offsets[entity_id + 1];
117
118        match ref_type {
119            None => self.fwd_targets[start..end].to_vec(),
120            Some(rt) => (start..end)
121                .filter(|&i| self.fwd_tags[i] == rt)
122                .map(|i| self.fwd_targets[i].clone())
123                .collect(),
124        }
125    }
126
127    /// Get source entity ids that reference `target_ref_val`, optionally
128    /// filtered by ref type.
129    pub fn sources_to(&self, target_ref_val: &str, ref_type: Option<&str>) -> Vec<usize> {
130        let idx = match self
131            .rev_keys
132            .binary_search_by(|k| k.as_str().cmp(target_ref_val))
133        {
134            Ok(i) => i,
135            Err(_) => return Vec::new(),
136        };
137        let start = self.rev_offsets[idx];
138        let end = self.rev_offsets[idx + 1];
139
140        match ref_type {
141            None => self.rev_sources[start..end].to_vec(),
142            Some(rt) => (start..end)
143                .filter(|&i| self.rev_tags[i] == rt)
144                .map(|i| self.rev_sources[i])
145                .collect(),
146        }
147    }
148
149    /// Total number of forward edges stored.
150    pub fn edge_count(&self) -> usize {
151        self.fwd_targets.len()
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::adjacency::RefAdjacency;

    /// Small fixture graph plus the largest entity id it uses.
    fn build_test_adjacency() -> (RefAdjacency, usize) {
        let mut adj = RefAdjacency::new();
        // equip-1 (eid=1) -> site-1 via siteRef
        adj.add(1, "siteRef", "site-1");
        // point-1 (eid=2) -> equip-1 via equipRef, site-1 via siteRef
        adj.add(2, "equipRef", "equip-1");
        adj.add(2, "siteRef", "site-1");
        // point-2 (eid=3) -> equip-1 via equipRef
        adj.add(3, "equipRef", "equip-1");
        // `from_ref_adjacency` sizes its rows as `max_entity_id + 1` itself,
        // so the largest id in use is the right value to hand over.
        (adj, 3)
    }

    #[test]
    fn csr_targets_from_all() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let targets = csr.targets_from(2, None);
        assert_eq!(targets.len(), 2);
        assert!(targets.contains(&"equip-1".to_string()));
        assert!(targets.contains(&"site-1".to_string()));
    }

    #[test]
    fn csr_targets_from_filtered() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let targets = csr.targets_from(2, Some("siteRef"));
        assert_eq!(targets, vec!["site-1".to_string()]);
    }

    #[test]
    fn csr_targets_from_last_row() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        // The entity with the largest id must still get a forward row.
        assert_eq!(csr.targets_from(max, None), vec!["equip-1".to_string()]);
    }

    #[test]
    fn csr_targets_from_entity_without_edges() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        // eid 0 is inside the row range but has no outgoing refs.
        assert!(csr.targets_from(0, None).is_empty());
    }

    #[test]
    fn csr_targets_from_unmatched_ref_type() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        assert!(csr.targets_from(2, Some("floorRef")).is_empty());
    }

    #[test]
    fn csr_sources_to_all() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let mut sources = csr.sources_to("site-1", None);
        sources.sort();
        assert_eq!(sources, vec![1, 2]);
    }

    #[test]
    fn csr_sources_to_filtered() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        let sources = csr.sources_to("equip-1", Some("equipRef"));
        assert_eq!(sources.len(), 2);
        assert!(sources.contains(&2));
        assert!(sources.contains(&3));
    }

    #[test]
    fn csr_sources_to_unmatched_ref_type() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);

        // site-1 is only ever referenced through siteRef.
        assert!(csr.sources_to("site-1", Some("equipRef")).is_empty());
    }

    #[test]
    fn csr_nonexistent_entity() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        assert!(csr.targets_from(99, None).is_empty());
        // First id past the last row.
        assert!(csr.targets_from(max + 1, None).is_empty());
    }

    #[test]
    fn csr_nonexistent_target() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        assert!(csr.sources_to("nonexistent", None).is_empty());
    }

    #[test]
    fn csr_edge_count() {
        let (adj, max) = build_test_adjacency();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, max);
        assert_eq!(csr.edge_count(), 4); // 1 + 2 + 1
    }

    #[test]
    fn csr_empty_graph() {
        let adj = RefAdjacency::new();
        let csr = CsrAdjacency::from_ref_adjacency(&adj, 0);
        assert!(csr.targets_from(0, None).is_empty());
        assert!(csr.sources_to("anything", None).is_empty());
        assert_eq!(csr.edge_count(), 0);
    }
}