// the_code_graph_domain/use_cases/flow.rs

1use crate::analysis::flow::{brandes_betweenness, detect_entry_points, enumerate_flows};
2use crate::error::Result;
3use crate::model::*;
4use crate::ports::GraphStore;
5use std::collections::{HashMap, HashSet, VecDeque};
6
/// Use case for execution-flow analysis over a [`GraphStore`].
///
/// Wraps a store and exposes entry-point detection, flow enumeration, and
/// criticality scoring built on the `crate::analysis::flow` algorithms.
#[derive(Debug)]
pub struct FlowUseCase<S> {
    /// Source of the symbols and edges that every analysis reads.
    store: S,
}
10
11impl<S: GraphStore> FlowUseCase<S> {
12    pub fn new(store: S) -> Self {
13        Self { store }
14    }
15
16    /// Full flow analysis: detect entry points, enumerate flows, compute criticality.
17    pub fn analyze(&self, config: &FlowConfig) -> Result<FlowAnalysis> {
18        let symbols = self.store.all_symbols()?;
19        let edges = self.store.all_edges()?;
20
21        let entry_points = detect_entry_points(&symbols, &edges, config);
22        let flows = enumerate_flows(&entry_points, &edges, config);
23
24        let nodes: HashSet<String> = symbols.iter().map(|s| s.qualified_name.clone()).collect();
25        let betweenness = brandes_betweenness(&nodes, &edges);
26
27        let entry_set: HashSet<&str> = entry_points
28            .iter()
29            .map(|e| e.qualified_name.as_str())
30            .collect();
31
32        // Count flows per node
33        let mut flow_counts: HashMap<String, usize> = HashMap::new();
34        for flow in &flows {
35            for node in &flow.path {
36                *flow_counts.entry(node.clone()).or_default() += 1;
37            }
38        }
39
40        let mut criticality: Vec<CriticalityScore> = betweenness
41            .iter()
42            .map(|(name, &score)| CriticalityScore {
43                qualified_name: name.clone(),
44                betweenness: score,
45                flow_count: flow_counts.get(name).copied().unwrap_or(0),
46                is_entry_point: entry_set.contains(name.as_str()),
47            })
48            .collect();
49        criticality.sort_by(|a, b| {
50            b.betweenness
51                .partial_cmp(&a.betweenness)
52                .unwrap_or(std::cmp::Ordering::Equal)
53        });
54
55        let stats = FlowStats {
56            total_entry_points: entry_points.len(),
57            total_flows: flows.len(),
58            max_depth: flows.iter().map(|f| f.depth).max().unwrap_or(0),
59            avg_depth: if flows.is_empty() {
60                0.0
61            } else {
62                flows.iter().map(|f| f.depth as f64).sum::<f64>() / flows.len() as f64
63            },
64        };
65
66        Ok(FlowAnalysis {
67            entry_points,
68            flows,
69            criticality,
70            stats,
71        })
72    }
73
74    /// Find flows passing through a specific symbol (optimized: backward BFS first).
75    pub fn flows_through(
76        &self,
77        qualified_name: &str,
78        config: &FlowConfig,
79    ) -> Result<Vec<ExecutionFlow>> {
80        let symbols = self.store.all_symbols()?;
81        let edges = self.store.all_edges()?;
82        let entry_points = detect_entry_points(&symbols, &edges, config);
83
84        // Backward BFS from target through HIGH-CONFIDENCE edges only
85        let high_edges: Vec<&Edge> = edges
86            .iter()
87            .filter(|e| e.kind.confidence() == Confidence::High)
88            .collect();
89        let mut reachable_entries = HashSet::new();
90        let mut visited = HashSet::new();
91        let mut queue = VecDeque::new();
92        queue.push_back(qualified_name.to_string());
93        visited.insert(qualified_name.to_string());
94        while let Some(node) = queue.pop_front() {
95            if entry_points.iter().any(|ep| ep.qualified_name == node) {
96                reachable_entries.insert(node.clone());
97            }
98            for edge in &high_edges {
99                if edge.target == node && !visited.contains(&edge.source) {
100                    visited.insert(edge.source.clone());
101                    queue.push_back(edge.source.clone());
102                }
103            }
104        }
105
106        // DFS only from reachable entry points, filter to paths containing target
107        let filtered_entries: Vec<EntryPoint> = entry_points
108            .into_iter()
109            .filter(|ep| reachable_entries.contains(&ep.qualified_name))
110            .collect();
111        let all_flows = enumerate_flows(&filtered_entries, &edges, config);
112        Ok(all_flows
113            .into_iter()
114            .filter(|f| f.path.contains(&qualified_name.to_string()))
115            .collect())
116    }
117
118    /// Get criticality scores sorted descending by betweenness.
119    pub fn criticality(&self) -> Result<Vec<CriticalityScore>> {
120        let analysis = self.analyze(&FlowConfig::default())?;
121        Ok(analysis.criticality)
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::InMemoryGraphStore;

    /// Builds a non-async, non-test function symbol on lines 1-10 of `file`,
    /// qualified as `file::name`.
    fn function_symbol(
        name: &str,
        file: &str,
        visibility: Visibility,
        is_exported: bool,
    ) -> SymbolNode {
        SymbolNode {
            name: name.into(),
            qualified_name: format!("{file}::{name}"),
            kind: SymbolKind::Function,
            location: Location {
                file: file.into(),
                line_start: 1,
                line_end: 10,
                col_start: 0,
                col_end: 0,
            },
            visibility,
            is_exported,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: None,
        }
    }

    /// Builds an edge of `kind` from `source` to `target` without metadata.
    fn edge(kind: EdgeKind, source: &str, target: &str) -> Edge {
        Edge {
            kind,
            source: source.into(),
            target: target.into(),
            metadata: None,
        }
    }

    /// Two public, exported functions: `main` calls `connect` through a `Calls` edge.
    fn build_store() -> InMemoryGraphStore {
        let mut store = InMemoryGraphStore::new();
        store.insert_symbol(function_symbol(
            "main",
            "src/main.rs",
            Visibility::Public,
            true,
        ));
        store.insert_symbol(function_symbol(
            "connect",
            "src/db.rs",
            Visibility::Public,
            true,
        ));
        store.insert_edge(edge(
            EdgeKind::Calls,
            "src/main.rs::main",
            "src/db.rs::connect",
        ));
        store
    }

    #[test]
    fn analyze_returns_flows_and_criticality() {
        let uc = FlowUseCase::new(build_store());
        let analysis = uc.analyze(&FlowConfig::default()).unwrap();
        assert!(!analysis.entry_points.is_empty());
        assert!(!analysis.flows.is_empty());
        assert!(!analysis.criticality.is_empty());
    }

    #[test]
    fn flows_through_filters_correctly() {
        let uc = FlowUseCase::new(build_store());
        let flows = uc
            .flows_through("src/db.rs::connect", &FlowConfig::default())
            .unwrap();
        // Without this, the loop below would pass vacuously on an empty result.
        assert!(
            !flows.is_empty(),
            "main -> connect should yield at least one flow through connect"
        );
        for flow in &flows {
            assert!(flow.path.contains(&"src/db.rs::connect".to_string()));
        }
    }

    #[test]
    fn criticality_returns_sorted_scores() {
        let uc = FlowUseCase::new(build_store());
        let scores = uc.criticality().unwrap();
        assert!(!scores.is_empty());
        for w in scores.windows(2) {
            assert!(w[0].betweenness >= w[1].betweenness);
        }
    }

    #[test]
    fn flows_through_nonexistent_symbol_returns_empty() {
        let uc = FlowUseCase::new(build_store());
        let flows = uc
            .flows_through("nonexistent::symbol", &FlowConfig::default())
            .unwrap();
        assert!(flows.is_empty());
    }

    #[test]
    fn flows_through_ignores_medium_confidence_reachability() {
        let mut store = InMemoryGraphStore::new();
        store.insert_symbol(function_symbol(
            "main",
            "src/main.rs",
            Visibility::Public,
            true,
        ));
        store.insert_symbol(function_symbol(
            "util",
            "src/util.rs",
            Visibility::Private,
            false,
        ));
        // Only Medium-confidence edge connecting them
        store.insert_edge(edge(
            EdgeKind::ImportsFrom,
            "src/main.rs::main",
            "src/util.rs::util",
        ));
        let uc = FlowUseCase::new(store);
        let flows = uc
            .flows_through("src/util.rs::util", &FlowConfig::default())
            .unwrap();
        assert!(
            flows.is_empty(),
            "backward BFS must filter on High-confidence edges only"
        );
    }

    #[test]
    fn analyze_empty_graph_returns_zeros() {
        let uc = FlowUseCase::new(InMemoryGraphStore::new());
        let analysis = uc.analyze(&FlowConfig::default()).unwrap();
        assert!(analysis.entry_points.is_empty());
        assert!(analysis.flows.is_empty());
        assert_eq!(analysis.stats.total_entry_points, 0);
        assert_eq!(analysis.stats.total_flows, 0);
        assert_eq!(analysis.stats.max_depth, 0);
        assert_eq!(analysis.stats.avg_depth, 0.0);
    }
}