Skip to main content

sqry_db/queries/
reachability.rs

//! Reachability derived query.
//!
//! Computes the set of nodes reachable from a given root set by following
//! edges of a specified kind. Used by `find_unused` to identify entry points
//! and compute the reachable set.

7use std::collections::HashSet;
8use std::sync::Arc;
9
10use sqry_core::graph::unified::concurrent::GraphSnapshot;
11use sqry_core::graph::unified::edge::kind::EdgeKind;
12use sqry_core::graph::unified::node::id::NodeId;
13
14use crate::QueryDb;
15use crate::dependency::record_file_dep;
16use crate::query::DerivedQuery;

// PN3 cold-start persistence: `ReachabilityKey` and `ReachableSet` are
// serialized via postcard at cache-insert time. `EdgeKind` and `NodeId`
// already derive `Serialize`/`Deserialize` in sqry-core.

22/// Key for a reachability query: a set of root nodes + edge kind.
23#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
24pub struct ReachabilityKey {
25    /// Root nodes to start BFS from.
26    pub roots: Vec<NodeId>,
27    /// Edge kind to follow.
28    pub edge_kind: EdgeKind,
29}
30
31/// Result of a reachability query: the set of reachable nodes.
32// HashSet is serde-able because NodeId derives Serialize/Deserialize.
33#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
34pub struct ReachableSet {
35    /// All nodes reachable from the root set (includes roots themselves).
36    pub reachable: HashSet<NodeId>,
37}
38
/// Derived query that runs a breadth-first search from a root set and
/// returns every node it can reach.
///
/// # Invalidation
///
/// Declares `TRACKS_EDGE_REVISION = true`, so cached results are dropped
/// whenever any edge changes. A newly added edge may connect a previously
/// unreachable node to the root set.
pub struct ReachabilityQuery;

47impl DerivedQuery for ReachabilityQuery {
48    type Key = ReachabilityKey;
49    type Value = Arc<ReachableSet>;
50    const QUERY_TYPE_ID: u32 = crate::queries::type_ids::REACHABILITY;
51    const TRACKS_EDGE_REVISION: bool = true;
52
53    fn execute(
54        key: &ReachabilityKey,
55        _db: &QueryDb,
56        snapshot: &GraphSnapshot,
57    ) -> Arc<ReachableSet> {
58        // Record all files as deps (global topology query)
59        for (fid, _seg) in snapshot.file_segments().iter() {
60            record_file_dep(fid);
61        }
62
63        // BFS from roots
64        let mut reachable = HashSet::new();
65        let mut queue: std::collections::VecDeque<NodeId> = key.roots.iter().copied().collect();
66
67        for &root in &key.roots {
68            reachable.insert(root);
69        }
70
71        while let Some(node) = queue.pop_front() {
72            for edge_ref in &snapshot.edges().edges_from(node) {
73                if std::mem::discriminant(&edge_ref.kind) == std::mem::discriminant(&key.edge_kind)
74                    && reachable.insert(edge_ref.target)
75                {
76                    queue.push_back(edge_ref.target);
77                }
78            }
79        }
80
81        Arc::new(ReachableSet { reachable })
82    }
83}
84
// ============================================================================
// PN3 serde round-trip tests
// ============================================================================

#[cfg(test)]
mod serde_roundtrip {
    //! Postcard round-trip tests for the PN3 cold-start cache types.

    use super::*;
    use postcard::{from_bytes, to_allocvec};

    /// A key must decode to a value equal to the one that was encoded.
    #[test]
    fn reachability_key_roundtrip() {
        let original = ReachabilityKey {
            roots: vec![NodeId::new(1, 1), NodeId::new(5, 2)],
            edge_kind: EdgeKind::Calls {
                argument_count: 0,
                is_async: false,
            },
        };
        let bytes = to_allocvec(&original).expect("serialize failed");
        let decoded: ReachabilityKey = from_bytes(&bytes).expect("deserialize failed");
        assert_eq!(decoded, original);
    }

    /// A reachable set must decode to a set with exactly the same members.
    #[test]
    fn reachable_set_roundtrip() {
        let reachable = HashSet::from([NodeId::new(10, 1), NodeId::new(20, 1)]);
        let original = ReachableSet { reachable };
        let bytes = to_allocvec(&original).expect("serialize failed");
        let decoded: ReachableSet = from_bytes(&bytes).expect("deserialize failed");
        // `ReachableSet` does not implement `PartialEq`, but its `HashSet` field
        // does. Set equality ignores iteration order, so it is a valid check
        // even though encoding order is unspecified.
        assert_eq!(decoded.reachable, original.reachable);
    }

    /// An empty set (e.g. the result for a key with no roots) must round-trip.
    #[test]
    fn empty_reachable_set_roundtrip() {
        let original = ReachableSet {
            reachable: HashSet::new(),
        };
        let bytes = to_allocvec(&original).expect("serialize failed");
        let decoded: ReachableSet = from_bytes(&bytes).expect("deserialize failed");
        assert!(decoded.reachable.is_empty());
    }
}