1use crate::args::Cli;
6use crate::commands::graph::loader::{GraphLoadConfig, load_unified_graph};
7use crate::index_discovery::find_nearest_index;
8use crate::output::OutputStreams;
9use anyhow::{Context, Result, anyhow};
10use serde::Serialize;
11use sqry_core::graph::unified::edge::EdgeKind;
12use sqry_core::graph::unified::node::NodeId;
13use std::collections::{HashSet, VecDeque};
14
/// Top-level result of the subgraph command, serialized to JSON or rendered
/// as text by `format_subgraph_text`.
#[derive(Debug, Serialize)]
struct SubgraphOutput {
    /// Seed symbols exactly as requested by the caller (copied from the
    /// `symbols` argument, whether or not each one resolved to a node).
    seeds: Vec<String>,
    /// Nodes of the extracted subgraph, ordered by qualified name.
    nodes: Vec<SubgraphNode>,
    /// Deduplicated edges, ordered by (source, target, kind).
    edges: Vec<SubgraphEdge>,
    /// Summary counters for the subgraph.
    stats: SubgraphStats,
}
27
/// One node of the extracted subgraph, in display/serialization form.
#[derive(Debug, Clone, Serialize)]
struct SubgraphNode {
    /// Identifier used by edges to reference this node; currently a copy of
    /// `qualified_name`.
    id: String,
    /// Simple name of the symbol (empty when the name cannot be resolved).
    name: String,
    /// Qualified name; falls back to `name` when the node has none.
    qualified_name: String,
    /// `Debug` rendering of the node kind.
    kind: String,
    /// Display form of the defining file's path (empty when unresolved).
    file: String,
    /// Start line of the node as recorded in the graph.
    line: u32,
    /// Display language derived from the file extension, or `"Unknown"`.
    language: String,
    /// Whether this node is one of the resolved seed nodes.
    is_seed: bool,
    /// BFS depth at which the node was first discovered (0 for seeds).
    depth: usize,
}
42
/// One directed edge of the extracted subgraph, in display/serialization form.
#[derive(Debug, Clone, Serialize)]
struct SubgraphEdge {
    /// Qualified name of the edge's source node (simple name as fallback).
    source: String,
    /// Qualified name of the edge's target node (simple name as fallback).
    target: String,
    /// `Debug` rendering of the edge kind.
    kind: String,
}
49
/// Summary counters reported alongside the subgraph.
#[derive(Debug, Serialize)]
struct SubgraphStats {
    /// Number of nodes in the output.
    node_count: usize,
    /// Number of edges in the output (after deduplication).
    edge_count: usize,
    /// Depth figure reported by the BFS traversal
    /// (taken from `SubgraphBfsResult::max_depth_reached`).
    max_depth_reached: usize,
}
56
57fn find_seed_nodes(
59 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
60 symbols: &[String],
61) -> Vec<NodeId> {
62 let strings = graph.strings();
63 let mut seed_nodes: Vec<NodeId> = Vec::new();
64
65 for symbol in symbols {
66 let found = graph.nodes().iter().find(|(_, entry)| {
67 if let Some(qn_id) = entry.qualified_name
69 && let Some(qn) = strings.resolve(qn_id)
70 && (qn.as_ref() == symbol.as_str() || qn.contains(symbol.as_str()))
71 {
72 return true;
73 }
74 if let Some(name) = strings.resolve(entry.name)
76 && name.as_ref() == symbol.as_str()
77 {
78 return true;
79 }
80 false
81 });
82
83 if let Some((node_id, _)) = found {
84 seed_nodes.push(node_id);
85 }
86 }
87
88 seed_nodes
89}
90
/// Raw outcome of the breadth-first traversal, before conversion to the
/// display structs.
struct SubgraphBfsResult {
    /// Every node admitted to the subgraph (seeds plus discovered nodes).
    visited: HashSet<NodeId>,
    /// Depth at which each admitted node was first reached (seeds are 0).
    node_depths: std::collections::HashMap<NodeId, usize>,
    /// Traversed edges as (source, target, `Debug`-formatted kind). May hold
    /// duplicates and edges whose other endpoint was never admitted; both
    /// are filtered out later by `build_subgraph_edges`.
    collected_edges: Vec<(NodeId, NodeId, String)>,
    /// Depth figure computed by the traversal and surfaced in the stats.
    max_depth_reached: usize,
}
98
99#[allow(clippy::too_many_arguments)]
101fn process_callee_edges(
102 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
103 node_id: NodeId,
104 include_imports: bool,
105 collected_edges: &mut Vec<(NodeId, NodeId, String)>,
106 visited: &mut HashSet<NodeId>,
107 node_depths: &mut std::collections::HashMap<NodeId, usize>,
108 queue: &mut VecDeque<(NodeId, usize)>,
109 depth: usize,
110 max_nodes: usize,
111) {
112 for edge_ref in graph.edges().edges_from(node_id) {
113 let is_call = matches!(edge_ref.kind, EdgeKind::Calls { .. });
114 let is_import = matches!(edge_ref.kind, EdgeKind::Imports { .. });
115
116 if is_call || (include_imports && is_import) {
117 let kind_str = format!("{:?}", edge_ref.kind);
118 collected_edges.push((node_id, edge_ref.target, kind_str));
119
120 if !visited.contains(&edge_ref.target) && visited.len() < max_nodes {
121 visited.insert(edge_ref.target);
122 node_depths.insert(edge_ref.target, depth + 1);
123 queue.push_back((edge_ref.target, depth + 1));
124 }
125 }
126 }
127}
128
129fn process_caller_edges(
131 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
132 node_id: NodeId,
133 collected_edges: &mut Vec<(NodeId, NodeId, String)>,
134 visited: &mut HashSet<NodeId>,
135 node_depths: &mut std::collections::HashMap<NodeId, usize>,
136 queue: &mut VecDeque<(NodeId, usize)>,
137 depth: usize,
138 max_nodes: usize,
139) {
140 for edge_ref in graph.edges().edges_to(node_id) {
141 if matches!(edge_ref.kind, EdgeKind::Calls { .. }) {
142 let kind_str = format!("{:?}", edge_ref.kind);
143 collected_edges.push((edge_ref.source, node_id, kind_str));
144
145 if !visited.contains(&edge_ref.source) && visited.len() < max_nodes {
146 visited.insert(edge_ref.source);
147 node_depths.insert(edge_ref.source, depth + 1);
148 queue.push_back((edge_ref.source, depth + 1));
149 }
150 }
151 }
152}
153
154#[allow(clippy::similar_names)]
156fn collect_subgraph_bfs(
157 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
158 seed_nodes: &[NodeId],
159 max_depth: usize,
160 max_nodes: usize,
161 include_callers: bool,
162 include_callees: bool,
163 include_imports: bool,
164) -> SubgraphBfsResult {
165 let mut visited: HashSet<NodeId> = HashSet::new();
166 let mut node_depths: std::collections::HashMap<NodeId, usize> =
167 std::collections::HashMap::new();
168 let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
169 let mut collected_edges: Vec<(NodeId, NodeId, String)> = Vec::new();
170
171 for &seed in seed_nodes {
173 visited.insert(seed);
174 node_depths.insert(seed, 0);
175 queue.push_back((seed, 0));
176 }
177
178 let mut max_depth_reached = 0;
179
180 while let Some((node_id, depth)) = queue.pop_front() {
181 if visited.len() >= max_nodes {
182 break;
183 }
184 if depth >= max_depth {
185 continue;
186 }
187
188 max_depth_reached = max_depth_reached.max(depth);
189
190 if include_callees {
192 process_callee_edges(
193 graph,
194 node_id,
195 include_imports,
196 &mut collected_edges,
197 &mut visited,
198 &mut node_depths,
199 &mut queue,
200 depth,
201 max_nodes,
202 );
203 }
204
205 if include_callers {
207 process_caller_edges(
208 graph,
209 node_id,
210 &mut collected_edges,
211 &mut visited,
212 &mut node_depths,
213 &mut queue,
214 depth,
215 max_nodes,
216 );
217 }
218 }
219
220 SubgraphBfsResult {
221 visited,
222 node_depths,
223 collected_edges,
224 max_depth_reached,
225 }
226}
227
/// Maps a file extension (without the dot) to a human-readable language
/// name. Extensions that are not in the table are returned unchanged.
fn extension_to_display_language(ext: &str) -> &str {
    // Each row: the extensions that share a display name, then that name.
    const DISPLAY_NAMES: &[(&[&str], &str)] = &[
        (&["rs"], "Rust"),
        (&["py"], "Python"),
        (&["js"], "JavaScript"),
        (&["ts"], "TypeScript"),
        (&["go"], "Go"),
        (&["java"], "Java"),
        (&["c", "h"], "C"),
        (&["cpp", "hpp", "cc"], "C++"),
        (&["rb"], "Ruby"),
        (&["swift"], "Swift"),
        (&["kt"], "Kotlin"),
    ];

    DISPLAY_NAMES
        .iter()
        .find(|(extensions, _)| extensions.contains(&ext))
        .map_or(ext, |(_, label)| *label)
}
245
246fn build_subgraph_nodes(
248 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
249 bfs: &SubgraphBfsResult,
250 seed_nodes: &[NodeId],
251) -> Vec<SubgraphNode> {
252 let strings = graph.strings();
253 let files = graph.files();
254 let seed_set: HashSet<_> = seed_nodes.iter().collect();
255
256 let mut nodes: Vec<SubgraphNode> = bfs
257 .visited
258 .iter()
259 .filter_map(|&node_id| {
260 let entry = graph.nodes().get(node_id)?;
261 let name = strings
262 .resolve(entry.name)
263 .map(|s| s.to_string())
264 .unwrap_or_default();
265 let qualified_name = entry
266 .qualified_name
267 .and_then(|id| strings.resolve(id))
268 .map_or_else(|| name.clone(), |s| s.to_string());
269
270 let file_path = files
271 .resolve(entry.file)
272 .map(|p| p.display().to_string())
273 .unwrap_or_default();
274
275 let language = files.resolve(entry.file).map_or_else(
277 || "Unknown".to_string(),
278 |p| {
279 p.extension()
280 .and_then(|ext| ext.to_str())
281 .map_or("Unknown", extension_to_display_language)
282 .to_string()
283 },
284 );
285
286 Some(SubgraphNode {
287 id: qualified_name.clone(),
288 name,
289 qualified_name,
290 kind: format!("{:?}", entry.kind),
291 file: file_path,
292 line: entry.start_line,
293 language,
294 is_seed: seed_set.contains(&node_id),
295 depth: *bfs.node_depths.get(&node_id).unwrap_or(&0),
296 })
297 })
298 .collect();
299
300 nodes.sort_by(|a, b| a.qualified_name.cmp(&b.qualified_name));
302 nodes
303}
304
305fn build_subgraph_edges(
307 graph: &sqry_core::graph::unified::concurrent::CodeGraph,
308 bfs: &SubgraphBfsResult,
309) -> Vec<SubgraphEdge> {
310 let strings = graph.strings();
311
312 let node_names: std::collections::HashMap<NodeId, String> = bfs
314 .visited
315 .iter()
316 .filter_map(|&node_id| {
317 let entry = graph.nodes().get(node_id)?;
318 let name = strings
319 .resolve(entry.name)
320 .map(|s| s.to_string())
321 .unwrap_or_default();
322 let qn = entry
323 .qualified_name
324 .and_then(|id| strings.resolve(id))
325 .map_or_else(|| name, |s| s.to_string());
326 Some((node_id, qn))
327 })
328 .collect();
329
330 let mut edges: Vec<SubgraphEdge> = bfs
331 .collected_edges
332 .iter()
333 .filter(|(src, tgt, _)| bfs.visited.contains(src) && bfs.visited.contains(tgt))
334 .filter_map(|(src, tgt, kind)| {
335 let src_name = node_names.get(src)?.clone();
336 let tgt_name = node_names.get(tgt)?.clone();
337 Some(SubgraphEdge {
338 source: src_name,
339 target: tgt_name,
340 kind: kind.clone(),
341 })
342 })
343 .collect();
344
345 edges.sort_by(|a, b| (&a.source, &a.target, &a.kind).cmp(&(&b.source, &b.target, &b.kind)));
347 edges.dedup_by(|a, b| a.source == b.source && a.target == b.target && a.kind == b.kind);
348 edges
349}
350
351#[allow(clippy::similar_names)]
356pub fn run_subgraph(
358 cli: &Cli,
359 symbols: &[String],
360 path: Option<&str>,
361 max_depth: usize,
362 max_nodes: usize,
363 include_callers: bool,
364 include_callees: bool,
365 include_imports: bool,
366) -> Result<()> {
367 let mut streams = OutputStreams::new();
368
369 if symbols.is_empty() {
370 return Err(anyhow!("At least one seed symbol is required"));
371 }
372
373 let search_path = path.map_or_else(
375 || std::env::current_dir().unwrap_or_default(),
376 std::path::PathBuf::from,
377 );
378
379 let index_location = find_nearest_index(&search_path);
380 let Some(ref loc) = index_location else {
381 streams
382 .write_diagnostic("No .sqry-index found. Run 'sqry index' first to build the index.")?;
383 return Ok(());
384 };
385
386 let config = GraphLoadConfig::default();
388 let graph = load_unified_graph(&loc.index_root, &config)
389 .context("Failed to load graph. Run 'sqry index' to build the graph.")?;
390
391 let seed_nodes = find_seed_nodes(&graph, symbols);
393 if seed_nodes.is_empty() {
394 streams.write_diagnostic("No seed symbols found in the graph.")?;
395 return Ok(());
396 }
397
398 let bfs = collect_subgraph_bfs(
400 &graph,
401 &seed_nodes,
402 max_depth,
403 max_nodes,
404 include_callers,
405 include_callees,
406 include_imports,
407 );
408
409 let nodes = build_subgraph_nodes(&graph, &bfs, &seed_nodes);
411 let edges = build_subgraph_edges(&graph, &bfs);
412
413 let stats = SubgraphStats {
414 node_count: nodes.len(),
415 edge_count: edges.len(),
416 max_depth_reached: bfs.max_depth_reached,
417 };
418
419 let output = SubgraphOutput {
420 seeds: symbols.to_vec(),
421 nodes,
422 edges,
423 stats,
424 };
425
426 if cli.json {
428 let json = serde_json::to_string_pretty(&output).context("Failed to serialize to JSON")?;
429 streams.write_result(&json)?;
430 } else {
431 let text = format_subgraph_text(&output);
432 streams.write_result(&text)?;
433 }
434
435 Ok(())
436}
437
438fn format_subgraph_text(output: &SubgraphOutput) -> String {
439 let mut lines = Vec::new();
440
441 lines.push(format!(
442 "Subgraph around {} seed(s): {}",
443 output.seeds.len(),
444 output.seeds.join(", ")
445 ));
446 lines.push(format!(
447 "Stats: {} nodes, {} edges, max depth {}",
448 output.stats.node_count, output.stats.edge_count, output.stats.max_depth_reached
449 ));
450 lines.push(String::new());
451
452 lines.push("Nodes:".to_string());
453 for node in &output.nodes {
454 let seed_marker = if node.is_seed { " [SEED]" } else { "" };
455 lines.push(format!(
456 " {} [{}] depth={}{} ",
457 node.qualified_name, node.kind, node.depth, seed_marker
458 ));
459 lines.push(format!(" {}:{}", node.file, node.line));
460 }
461
462 if !output.edges.is_empty() {
463 lines.push(String::new());
464 lines.push("Edges:".to_string());
465 for edge in &output.edges {
466 lines.push(format!(
467 " {} --[{}]--> {}",
468 edge.source, edge.kind, edge.target
469 ));
470 }
471 }
472
473 lines.join("\n")
474}