sqry_db/queries/
reachability.rs1use std::collections::HashSet;
8use std::sync::Arc;
9
10use sqry_core::graph::unified::concurrent::GraphSnapshot;
11use sqry_core::graph::unified::edge::kind::EdgeKind;
12use sqry_core::graph::unified::node::id::NodeId;
13
14use crate::QueryDb;
15use crate::dependency::record_file_dep;
16use crate::query::DerivedQuery;
17
18#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
24pub struct ReachabilityKey {
25 pub roots: Vec<NodeId>,
27 pub edge_kind: EdgeKind,
29}
30
31#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
34pub struct ReachableSet {
35 pub reachable: HashSet<NodeId>,
37}
38
/// Derived query computing the set of nodes reachable from
/// [`ReachabilityKey::roots`] by following only edges whose variant matches
/// [`ReachabilityKey::edge_kind`].
pub struct ReachabilityQuery;
46
47impl DerivedQuery for ReachabilityQuery {
48 type Key = ReachabilityKey;
49 type Value = Arc<ReachableSet>;
50 const QUERY_TYPE_ID: u32 = crate::queries::type_ids::REACHABILITY;
51 const TRACKS_EDGE_REVISION: bool = true;
52
53 fn execute(
54 key: &ReachabilityKey,
55 _db: &QueryDb,
56 snapshot: &GraphSnapshot,
57 ) -> Arc<ReachableSet> {
58 for (fid, _seg) in snapshot.file_segments().iter() {
60 record_file_dep(fid);
61 }
62
63 let mut reachable = HashSet::new();
65 let mut queue: std::collections::VecDeque<NodeId> = key.roots.iter().copied().collect();
66
67 for &root in &key.roots {
68 reachable.insert(root);
69 }
70
71 while let Some(node) = queue.pop_front() {
72 for edge_ref in &snapshot.edges().edges_from(node) {
73 if std::mem::discriminant(&edge_ref.kind) == std::mem::discriminant(&key.edge_kind)
74 && reachable.insert(edge_ref.target)
75 {
76 queue.push_back(edge_ref.target);
77 }
78 }
79 }
80
81 Arc::new(ReachableSet { reachable })
82 }
83}
84
#[cfg(test)]
mod serde_roundtrip {
    //! Postcard encode/decode round-trips for the reachability key and value.

    use super::*;
    use postcard::{from_bytes, to_allocvec};

    #[test]
    fn reachability_key_roundtrip() {
        let key = ReachabilityKey {
            roots: vec![NodeId::new(1, 1), NodeId::new(5, 2)],
            edge_kind: EdgeKind::Calls {
                argument_count: 0,
                is_async: false,
            },
        };

        let encoded = to_allocvec(&key).expect("serialize failed");
        let round_tripped = from_bytes::<ReachabilityKey>(&encoded).expect("deserialize failed");

        assert_eq!(round_tripped, key);
    }

    #[test]
    fn reachable_set_roundtrip() {
        let set = ReachableSet {
            reachable: HashSet::from([NodeId::new(10, 1), NodeId::new(20, 1)]),
        };

        let encoded = to_allocvec(&set).expect("serialize failed");
        let round_tripped = from_bytes::<ReachableSet>(&encoded).expect("deserialize failed");

        // Set equality spelled out: same size, and every original node present.
        assert_eq!(round_tripped.reachable.len(), set.reachable.len());
        for node in &set.reachable {
            assert!(
                round_tripped.reachable.contains(node),
                "node {node:?} missing from decoded ReachableSet"
            );
        }
    }
}