Skip to main content

haystack_core/graph/
structural.rs

1//! WL-inspired structural fingerprinting for entity graph partitioning.
2//!
3//! Adapts the 1-dimensional Weisfeiler-Leman colour refinement algorithm
4//! to Haystack entities. Each entity gets a structural fingerprint based
5//! on its tag set and the tag sets of its k-hop neighbors.
6
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9
10use roaring::RoaringBitmap;
11use rustc_hash::FxHasher;
12
13use crate::data::HDict;
14
15/// Structural index: maps entities to WL fingerprints and partitions
16/// entities by fingerprint for fast structural queries.
17pub struct StructuralIndex {
18    /// ref_val → fingerprint
19    fingerprints: HashMap<String, u64>,
20    /// fingerprint → entity IDs (as roaring bitmap)
21    partitions: HashMap<u64, RoaringBitmap>,
22    /// fingerprint → set of tag names that entities with this fingerprint have
23    partition_tags: HashMap<u64, Vec<String>>,
24    /// Number of WL refinement rounds (default 2).
25    depth: usize,
26    /// Whether the index needs full recomputation.
27    stale: bool,
28}
29
30impl StructuralIndex {
31    pub fn new() -> Self {
32        Self::with_depth(2)
33    }
34
35    pub fn with_depth(depth: usize) -> Self {
36        Self {
37            fingerprints: HashMap::new(),
38            partitions: HashMap::new(),
39            partition_tags: HashMap::new(),
40            depth,
41            stale: true,
42        }
43    }
44
45    /// Compute a fast non-cryptographic hash.
46    fn fx_hash(data: &[u8]) -> u64 {
47        let mut hasher = FxHasher::default();
48        data.hash(&mut hasher);
49        hasher.finish()
50    }
51
52    /// Compute round-0 fingerprint: hash of sorted tag names.
53    fn round0_fingerprint(entity: &HDict) -> u64 {
54        let mut tags: Vec<&str> = entity.tag_names().collect();
55        tags.sort_unstable();
56        let combined = tags.join("\0");
57        Self::fx_hash(combined.as_bytes())
58    }
59
60    /// Maximum entities for full WL refinement (with neighbor propagation).
61    const MAX_ENTITIES_FULL_WL: usize = 50_000;
62    /// Maximum entities for any structural indexing.
63    const MAX_ENTITIES_STRUCTURAL: usize = 200_000;
64
65    /// Full recomputation of the structural index.
66    pub fn compute(
67        &mut self,
68        entities: &HashMap<String, HDict>,
69        id_map: &HashMap<String, usize>,
70        adjacency_targets: impl Fn(&str) -> Vec<String>,
71    ) {
72        self.fingerprints.clear();
73        self.partitions.clear();
74        self.partition_tags.clear();
75
76        if entities.is_empty() {
77            self.stale = false;
78            return;
79        }
80
81        // Skip entirely for very large graphs.
82        if entities.len() > Self::MAX_ENTITIES_STRUCTURAL {
83            self.stale = false;
84            return;
85        }
86
87        // Adaptive depth: tag-hash only for large graphs.
88        let effective_depth = if entities.len() > Self::MAX_ENTITIES_FULL_WL {
89            0
90        } else {
91            self.depth
92        };
93
94        // Round 0: hash of sorted tag names per entity.
95        let mut current: HashMap<String, u64> = HashMap::new();
96        for (ref_val, entity) in entities {
97            current.insert(ref_val.clone(), Self::round0_fingerprint(entity));
98        }
99
100        // Rounds 1..depth: incorporate neighbor fingerprints.
101        for _ in 0..effective_depth {
102            let mut next: HashMap<String, u64> = HashMap::new();
103            for ref_val in entities.keys() {
104                let own_fp = current[ref_val];
105                let mut neighbor_fps: Vec<u64> = adjacency_targets(ref_val)
106                    .iter()
107                    .filter_map(|n| current.get(n).copied())
108                    .collect();
109                neighbor_fps.sort_unstable();
110
111                // Hash: (own_fp, sorted neighbor fingerprints)
112                let mut hasher = FxHasher::default();
113                own_fp.hash(&mut hasher);
114                for nfp in &neighbor_fps {
115                    nfp.hash(&mut hasher);
116                }
117                next.insert(ref_val.clone(), hasher.finish());
118            }
119            current = next;
120        }
121
122        // Build partitions.
123        for (ref_val, fp) in &current {
124            self.fingerprints.insert(ref_val.clone(), *fp);
125            if let Some(&eid) = id_map.get(ref_val) {
126                self.partitions.entry(*fp).or_default().insert(eid as u32);
127            }
128        }
129
130        // Build partition tag sets (store one representative's tags per partition).
131        for (ref_val, fp) in &current {
132            if !self.partition_tags.contains_key(fp)
133                && let Some(entity) = entities.get(ref_val)
134            {
135                let mut tags: Vec<String> = entity.tag_names().map(|s| s.to_string()).collect();
136                tags.sort();
137                self.partition_tags.insert(*fp, tags);
138            }
139        }
140
141        self.stale = false;
142    }
143
144    /// Get the fingerprint for an entity.
145    pub fn fingerprint(&self, ref_val: &str) -> Option<u64> {
146        self.fingerprints.get(ref_val).copied()
147    }
148
149    /// Get the entity IDs in a structural partition.
150    pub fn partition(&self, fingerprint: u64) -> Option<&RoaringBitmap> {
151        self.partitions.get(&fingerprint)
152    }
153
154    /// Find all partitions whose entities have ALL the given required tags.
155    pub fn partitions_with_tags(&self, required_tags: &[&str]) -> RoaringBitmap {
156        let mut result = RoaringBitmap::new();
157        for (fp, tags) in &self.partition_tags {
158            if required_tags.iter().all(|rt| tags.iter().any(|t| t == rt))
159                && let Some(bm) = self.partitions.get(fp)
160            {
161                result |= bm;
162            }
163        }
164        result
165    }
166
167    /// Number of distinct structural partitions.
168    pub fn partition_count(&self) -> usize {
169        self.partitions.len()
170    }
171
172    /// Whether the index needs recomputation.
173    pub fn is_stale(&self) -> bool {
174        self.stale
175    }
176
177    /// Mark the index as needing recomputation.
178    pub fn mark_stale(&mut self) {
179        self.stale = true;
180    }
181
182    /// Get a histogram of fingerprints: fp → count.
183    pub fn histogram(&self) -> HashMap<u64, u64> {
184        self.partitions
185            .iter()
186            .map(|(fp, bm)| (*fp, bm.len()))
187            .collect()
188    }
189}
190
191impl Default for StructuralIndex {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197#[cfg(test)]
198mod tests {
199    use super::*;
200    use crate::kinds::{HRef, Kind};
201
202    fn make_entity(id: &str, tags: &[&str], refs: &[(&str, &str)]) -> HDict {
203        let mut e = HDict::new();
204        e.set("id", Kind::Ref(HRef::from_val(id)));
205        for tag in tags {
206            e.set(*tag, Kind::Marker);
207        }
208        for (tag, target) in refs {
209            e.set(*tag, Kind::Ref(HRef::from_val(*target)));
210        }
211        e
212    }
213
214    fn adjacency_fn<'a>(entities: &'a HashMap<String, HDict>) -> impl Fn(&str) -> Vec<String> + 'a {
215        move |ref_val: &str| match entities.get(ref_val) {
216            Some(e) => e
217                .iter()
218                .filter_map(|(name, val)| {
219                    if name != "id"
220                        && let Kind::Ref(r) = val
221                    {
222                        return Some(r.val.clone());
223                    }
224                    None
225                })
226                .collect(),
227            None => Vec::new(),
228        }
229    }
230
231    #[test]
232    fn identical_structures_get_same_fingerprint() {
233        let mut entities = HashMap::new();
234        let mut id_map = HashMap::new();
235
236        entities.insert(
237            "vav-1".to_string(),
238            make_entity("vav-1", &["vav", "equip"], &[("siteRef", "site-1")]),
239        );
240        entities.insert(
241            "vav-2".to_string(),
242            make_entity("vav-2", &["vav", "equip"], &[("siteRef", "site-1")]),
243        );
244        entities.insert("site-1".to_string(), make_entity("site-1", &["site"], &[]));
245        id_map.insert("vav-1".to_string(), 0);
246        id_map.insert("vav-2".to_string(), 1);
247        id_map.insert("site-1".to_string(), 2);
248
249        let mut si = StructuralIndex::new();
250        si.compute(&entities, &id_map, adjacency_fn(&entities));
251
252        assert_eq!(si.fingerprint("vav-1"), si.fingerprint("vav-2"));
253        assert_ne!(si.fingerprint("vav-1"), si.fingerprint("site-1"));
254        assert_eq!(si.partition_count(), 2);
255    }
256
257    #[test]
258    fn different_structures_get_different_fingerprints() {
259        let mut entities = HashMap::new();
260        let mut id_map = HashMap::new();
261
262        entities.insert(
263            "sensor-1".to_string(),
264            make_entity("sensor-1", &["point", "sensor"], &[]),
265        );
266        entities.insert(
267            "cmd-1".to_string(),
268            make_entity("cmd-1", &["point", "cmd"], &[]),
269        );
270        id_map.insert("sensor-1".to_string(), 0);
271        id_map.insert("cmd-1".to_string(), 1);
272
273        let mut si = StructuralIndex::new();
274        si.compute(&entities, &id_map, |_| Vec::new());
275
276        assert_ne!(si.fingerprint("sensor-1"), si.fingerprint("cmd-1"));
277    }
278
279    #[test]
280    fn partitions_with_tags_returns_matching() {
281        let mut entities = HashMap::new();
282        let mut id_map = HashMap::new();
283
284        entities.insert(
285            "s-1".to_string(),
286            make_entity("s-1", &["point", "sensor", "temp"], &[]),
287        );
288        entities.insert(
289            "s-2".to_string(),
290            make_entity("s-2", &["point", "sensor", "temp"], &[]),
291        );
292        entities.insert(
293            "c-1".to_string(),
294            make_entity("c-1", &["point", "cmd"], &[]),
295        );
296        id_map.insert("s-1".to_string(), 0);
297        id_map.insert("s-2".to_string(), 1);
298        id_map.insert("c-1".to_string(), 2);
299
300        let mut si = StructuralIndex::new();
301        si.compute(&entities, &id_map, |_| Vec::new());
302
303        let result = si.partitions_with_tags(&["point", "sensor"]);
304        assert_eq!(result.len(), 2);
305        assert!(result.contains(0));
306        assert!(result.contains(1));
307        assert!(!result.contains(2));
308    }
309
310    #[test]
311    fn histogram_reflects_partition_sizes() {
312        let mut entities = HashMap::new();
313        let mut id_map = HashMap::new();
314
315        for i in 0..5 {
316            let id = format!("vav-{i}");
317            entities.insert(id.clone(), make_entity(&id, &["vav", "equip"], &[]));
318            id_map.insert(id, i);
319        }
320        entities.insert("site-1".to_string(), make_entity("site-1", &["site"], &[]));
321        id_map.insert("site-1".to_string(), 5);
322
323        let mut si = StructuralIndex::new();
324        si.compute(&entities, &id_map, |_| Vec::new());
325
326        let hist = si.histogram();
327        assert_eq!(hist.len(), 2);
328        assert!(hist.values().any(|&count| count == 5));
329        assert!(hist.values().any(|&count| count == 1));
330    }
331
332    #[test]
333    fn stale_tracking() {
334        let mut si = StructuralIndex::new();
335        assert!(si.is_stale());
336
337        si.compute(&HashMap::new(), &HashMap::new(), |_| Vec::new());
338        assert!(!si.is_stale());
339
340        si.mark_stale();
341        assert!(si.is_stale());
342    }
343
344    #[test]
345    fn empty_graph_produces_no_partitions() {
346        let mut si = StructuralIndex::new();
347        si.compute(&HashMap::new(), &HashMap::new(), |_| Vec::new());
348        assert_eq!(si.partition_count(), 0);
349        assert!(!si.is_stale());
350    }
351}