Skip to main content

ryo_analysis/query/
graph_v2.rs

1//! CodeGraphV2 - Data-Oriented Design implementation of CodeGraph.
2//!
3//! Key improvements over V1:
4//! - **petgraph-free**: Direct SymbolId-based operations
5//! - **String-free**: Uses SymbolId/FileId instead of String/PathBuf
6//! - **SoA Layout**: Cache-efficient edge storage
7//! - **SmallVec**: Stack allocation for common cases
8
9use crate::define_index;
10use crate::symbol::{FileId, SymbolId};
11use crate::SymbolKind;
12use serde::{Deserialize, Serialize};
13use slotmap::SecondaryMap;
14use smallvec::SmallVec;
15use std::collections::{HashMap, HashSet};
16
17// ============================================================================
18// Index Types
19// ============================================================================
20
// Typed index generated by the crate-local `define_index!` macro.
// Within this file it is created with `EdgeId::from_raw(u32)` and turned back
// into a `Vec` position with `as_usize()`; ids are handed out by
// `CodeGraphV2::add_edge` as the current length of the edge array.
define_index! {
    /// Index into the edges array.
    pub struct EdgeId;
}
25
// Typed index generated by the crate-local `define_index!` macro.
// Ids are handed out by `CodeGraphV2::add_match_expr` as the current length of
// the match-expression array and resolved again through `as_usize()`.
define_index! {
    /// Index into the match expressions array.
    pub struct MatchExprId;
}
30
31// ============================================================================
32// Edge Types
33// ============================================================================
34
/// Kind of relationship carried by an edge in [`CodeGraphV2`].
///
/// Every edge is directed (`EdgeData::from` → `EdgeData::to`). CodeGraphV2
/// tracks three kinds of relationships:
/// - **Contains**: Structural parent-child (module → item, struct → field)
/// - **Calls**: Function call chains (caller → callee)
/// - **Implements**: Trait implementation (implementor → trait)
///
/// Type references (field types, parameter types, etc.) are tracked by
/// TypeFlowGraphV2, not here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CodeEdgeV2 {
    /// Parent contains child (module → item, struct → field).
    /// `children_of` / `parent_of` follow this kind.
    Contains,
    /// Caller calls callee.
    /// `callers_of` / `callees_of` and the call-chain traversals follow this kind.
    Calls,
    /// Implementor implements trait/type.
    /// `implementors_of` / `impl_count` follow this kind.
    Implements,
}
53
/// One directed edge of the graph (compact, `Copy` representation).
///
/// Instances live in `CodeGraphV2::edges` and are addressed by [`EdgeId`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EdgeData {
    /// Source symbol (parent / caller / implementor, depending on `kind`).
    pub from: SymbolId,
    /// Target symbol (child / callee / trait, depending on `kind`).
    pub to: SymbolId,
    /// Edge kind.
    pub kind: CodeEdgeV2,
}
64
65// ============================================================================
66// Match Expression (String-free)
67// ============================================================================
68
/// Location and subject of a single `match` expression (String-free).
///
/// Instead of storing file_path as PathBuf and enum_name as String,
/// we use FileId and SymbolId for efficient storage and lookup.
/// The function that contains the expression is not stored here; it is the
/// key under which the expression is registered via
/// `CodeGraphV2::add_match_expr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MatchExprDataV2 {
    /// File containing the match expression.
    pub file_id: FileId,
    /// The enum being matched (SymbolId); `match_exprs_for_enum` filters on it.
    pub enum_id: SymbolId,
    /// Byte offset in the file.
    pub offset: u32,
    /// Line number (1-indexed).
    pub line: u32,
}
84
85// ============================================================================
86// CodeGraphV2
87// ============================================================================
88
/// Symbol relationship graph with Data-Oriented Design.
///
/// # Design Principles
///
/// 1. **petgraph-free**: No NodeIndex, direct SymbolId operations
/// 2. **Dense edge storage**: Edges live in one contiguous array addressed by `EdgeId`
/// 3. **String-free**: All references use SymbolId/FileId
/// 4. **SmallVec**: Stack allocation for typical adjacency sizes
///
/// # Memory Layout
///
/// ```text
/// CodeGraphV2
/// ├── Edge Storage
/// │   └── edges: Vec<EdgeData>
/// ├── Adjacency Lists (SymbolId → EdgeIds)
/// │   ├── outgoing: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>
/// │   └── incoming: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>
/// ├── Nodes
/// │   └── nodes: SecondaryMap<SymbolId, ()>
/// ├── Indices
/// │   ├── by_kind: HashMap<SymbolKind, SmallVec<[SymbolId; 16]>>
/// │   └── crate_roots: SmallVec<[SymbolId; 4]>
/// └── Match Expressions
///     ├── match_expr_index: SecondaryMap<SymbolId, SmallVec<[MatchExprId; 2]>>
///     └── match_exprs: Vec<MatchExprDataV2>
/// ```
///
/// NOTE(review): earlier docs call the edge storage "SoA", but `edges` is a
/// single `Vec` of `EdgeData` structs (array-of-structs) — confirm the intended
/// terminology.
///
/// `edges` and `match_exprs` are append-only in this file: removal operations
/// only update the adjacency/index maps, so both vectors may hold entries that
/// are no longer reachable.
///
/// NOTE: Deserialize is NOT derived because CodeGraphV2 contains SecondaryMap<SymbolId, ...>
/// and SymbolId is process-specific. Serialize is kept for debugging/inspection.
#[derive(Clone, Default, Serialize)]
pub struct CodeGraphV2 {
    // === Edge Storage ===
    /// All edges in the graph, indexed by `EdgeId`.
    edges: Vec<EdgeData>,

    // === Adjacency Lists ===
    /// SymbolId → outgoing EdgeIds (edges whose `from` is the key).
    outgoing: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>,
    /// SymbolId → incoming EdgeIds (edges whose `to` is the key).
    incoming: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>,

    // === Nodes (tracked symbols) ===
    /// All symbols in the graph (used as a set; the value carries no data).
    nodes: SecondaryMap<SymbolId, ()>,

    // === Indices ===
    /// SymbolKind → SymbolIds index, filled explicitly via `add_to_kind_index`.
    by_kind: HashMap<SymbolKind, SmallVec<[SymbolId; 16]>>,
    /// Crate root symbols, in insertion order and without duplicates.
    crate_roots: SmallVec<[SymbolId; 4]>,

    // === Match Expressions ===
    /// Function SymbolId → MatchExprIds.
    match_expr_index: SecondaryMap<SymbolId, SmallVec<[MatchExprId; 2]>>,
    /// All match expression data, indexed by `MatchExprId`.
    match_exprs: Vec<MatchExprDataV2>,
}
144
145impl CodeGraphV2 {
146    /// Create a new empty graph.
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    /// Create a graph with pre-allocated capacity.
152    pub fn with_capacity(_nodes: usize, edges: usize) -> Self {
153        Self {
154            edges: Vec::with_capacity(edges),
155            outgoing: SecondaryMap::new(),
156            incoming: SecondaryMap::new(),
157            nodes: SecondaryMap::new(),
158            by_kind: HashMap::new(),
159            crate_roots: SmallVec::new(),
160            match_expr_index: SecondaryMap::new(),
161            match_exprs: Vec::new(),
162        }
163    }
164
165    // ========================================================================
166    // Node Management
167    // ========================================================================
168
169    /// Add a symbol to the graph.
170    ///
171    /// Returns true if the symbol was newly added, false if it already existed.
172    pub fn add_node(&mut self, id: SymbolId) -> bool {
173        if self.nodes.contains_key(id) {
174            return false;
175        }
176        self.nodes.insert(id, ());
177        true
178    }
179
180    /// Check if a symbol exists in the graph.
181    #[inline]
182    pub fn contains(&self, id: SymbolId) -> bool {
183        self.nodes.contains_key(id)
184    }
185
186    /// Remove a symbol and all its edges from the graph.
187    ///
188    /// Returns true if the symbol was removed.
189    pub fn remove_node(&mut self, id: SymbolId) -> bool {
190        if self.nodes.remove(id).is_none() {
191            return false;
192        }
193
194        // Remove from adjacency lists
195        self.outgoing.remove(id);
196        self.incoming.remove(id);
197
198        // Remove from kind index
199        for symbols in self.by_kind.values_mut() {
200            symbols.retain(|s| *s != id);
201        }
202
203        // Remove from crate roots
204        self.crate_roots.retain(|s| *s != id);
205
206        // Remove match expressions for this function
207        self.match_expr_index.remove(id);
208
209        // Note: We don't remove edges from self.edges to avoid O(n) scan.
210        // The adjacency lists ensure orphaned edges are never traversed.
211        // Compaction can be done periodically if needed.
212
213        true
214    }
215
216    /// Clear all outgoing edges from a symbol.
217    ///
218    /// Used for incremental updates: clear old edges before rebuilding.
219    /// Does not remove the node itself or its incoming edges.
220    pub fn clear_outgoing_edges(&mut self, id: SymbolId) {
221        if let Some(edge_ids) = self.outgoing.remove(id) {
222            // Remove from incoming adjacency lists of target nodes
223            for edge_id in edge_ids.iter().copied() {
224                if let Some(edge) = self.edges.get(edge_id.as_usize()) {
225                    let target = edge.to;
226                    if let Some(incoming) = self.incoming.get_mut(target) {
227                        incoming.retain(|eid| *eid != edge_id);
228                    }
229                }
230            }
231        }
232    }
233
234    // ========================================================================
235    // Edge Management
236    // ========================================================================
237
238    /// Add an edge between two symbols.
239    ///
240    /// Both symbols are automatically added to the graph if not present.
241    pub fn add_edge(&mut self, from: SymbolId, to: SymbolId, kind: CodeEdgeV2) -> EdgeId {
242        // Ensure nodes exist
243        self.add_node(from);
244        self.add_node(to);
245
246        // Create edge
247        let edge_id = EdgeId::from_raw(self.edges.len() as u32);
248        self.edges.push(EdgeData { from, to, kind });
249
250        // Update adjacency lists
251        self.outgoing
252            .entry(from)
253            .expect("caller must supply a SymbolId already present in the SlotMap")
254            .or_default()
255            .push(edge_id);
256        self.incoming
257            .entry(to)
258            .expect("caller must supply a SymbolId already present in the SlotMap")
259            .or_default()
260            .push(edge_id);
261
262        edge_id
263    }
264
265    /// Get edge data by EdgeId.
266    #[inline]
267    pub fn edge(&self, id: EdgeId) -> Option<&EdgeData> {
268        self.edges.get(id.as_usize())
269    }
270
271    /// Check if an edge exists between two symbols.
272    pub fn has_edge(&self, from: SymbolId, to: SymbolId, kind: CodeEdgeV2) -> bool {
273        self.outgoing
274            .get(from)
275            .map(|edges| {
276                edges.iter().any(|&eid| {
277                    self.edges
278                        .get(eid.as_usize())
279                        .map(|e| e.to == to && e.kind == kind)
280                        .unwrap_or(false)
281                })
282            })
283            .unwrap_or(false)
284    }
285
286    // ========================================================================
287    // Crate Roots
288    // ========================================================================
289
290    /// Add a crate root symbol.
291    pub fn add_crate_root(&mut self, id: SymbolId) {
292        self.add_node(id);
293        if !self.crate_roots.contains(&id) {
294            self.crate_roots.push(id);
295        }
296    }
297
298    /// Get crate root symbols.
299    #[inline]
300    pub fn crate_roots(&self) -> &[SymbolId] {
301        &self.crate_roots
302    }
303
304    // ========================================================================
305    // Kind Index
306    // ========================================================================
307
308    /// Add a symbol to the kind index.
309    pub fn add_to_kind_index(&mut self, id: SymbolId, kind: SymbolKind) {
310        let symbols = self.by_kind.entry(kind).or_default();
311        if !symbols.contains(&id) {
312            symbols.push(id);
313        }
314    }
315
316    /// Iterate over symbols of a specific kind.
317    pub fn iter_by_kind(&self, kind: SymbolKind) -> impl Iterator<Item = SymbolId> + '_ {
318        self.by_kind
319            .get(&kind)
320            .into_iter()
321            .flat_map(|v| v.iter().copied())
322    }
323
324    // ========================================================================
325    // Graph Traversal
326    // ========================================================================
327
328    /// Get outgoing edges from a symbol.
329    pub fn outgoing_edges(&self, id: SymbolId) -> impl Iterator<Item = &EdgeData> + '_ {
330        self.outgoing
331            .get(id)
332            .into_iter()
333            .flat_map(|edges| edges.iter())
334            .filter_map(|&eid| self.edges.get(eid.as_usize()))
335    }
336
337    /// Get incoming edges to a symbol.
338    pub fn incoming_edges(&self, id: SymbolId) -> impl Iterator<Item = &EdgeData> + '_ {
339        self.incoming
340            .get(id)
341            .into_iter()
342            .flat_map(|edges| edges.iter())
343            .filter_map(|&eid| self.edges.get(eid.as_usize()))
344    }
345
346    /// Find callers of a symbol (deduplicated).
347    pub fn callers_of(&self, id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
348        let mut seen = HashSet::new();
349        self.incoming_edges(id)
350            .filter(|e| e.kind == CodeEdgeV2::Calls)
351            .map(|e| e.from)
352            .filter(move |&id| seen.insert(id))
353    }
354
355    /// Find callees of a symbol (deduplicated).
356    pub fn callees_of(&self, id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
357        let mut seen = HashSet::new();
358        self.outgoing_edges(id)
359            .filter(|e| e.kind == CodeEdgeV2::Calls)
360            .map(|e| e.to)
361            .filter(move |&id| seen.insert(id))
362    }
363
364    /// Find implementors of a trait.
365    pub fn implementors_of(&self, trait_id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
366        self.incoming_edges(trait_id)
367            .filter(|e| e.kind == CodeEdgeV2::Implements)
368            .map(|e| e.from)
369    }
370
371    /// Find children (contained symbols) of a parent.
372    pub fn children_of(&self, parent_id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
373        self.outgoing_edges(parent_id)
374            .filter(|e| e.kind == CodeEdgeV2::Contains)
375            .map(|e| e.to)
376    }
377
378    /// Find the parent (container) of a symbol.
379    pub fn parent_of(&self, id: SymbolId) -> Option<SymbolId> {
380        self.incoming_edges(id)
381            .find(|e| e.kind == CodeEdgeV2::Contains)
382            .map(|e| e.from)
383    }
384
385    /// Get the reference count for a symbol (call references only).
386    pub fn reference_count(&self, id: SymbolId) -> usize {
387        self.incoming_edges(id)
388            .filter(|e| e.kind == CodeEdgeV2::Calls)
389            .count()
390    }
391
392    /// Get the impl count for a symbol.
393    pub fn impl_count(&self, id: SymbolId) -> usize {
394        self.incoming_edges(id)
395            .filter(|e| e.kind == CodeEdgeV2::Implements)
396            .count()
397    }
398
399    // ========================================================================
400    // Match Expressions
401    // ========================================================================
402
403    /// Add a match expression to a function.
404    pub fn add_match_expr(&mut self, function_id: SymbolId, data: MatchExprDataV2) -> MatchExprId {
405        let expr_id = MatchExprId::from_raw(self.match_exprs.len() as u32);
406        self.match_exprs.push(data);
407
408        self.match_expr_index
409            .entry(function_id)
410            .expect("caller must supply a function SymbolId already present in the SlotMap")
411            .or_default()
412            .push(expr_id);
413
414        expr_id
415    }
416
417    /// Get match expressions in a function.
418    pub fn match_exprs_in(
419        &self,
420        function_id: SymbolId,
421    ) -> impl Iterator<Item = &MatchExprDataV2> + '_ {
422        self.match_expr_index
423            .get(function_id)
424            .into_iter()
425            .flat_map(|ids| ids.iter())
426            .filter_map(|&id| self.match_exprs.get(id.as_usize()))
427    }
428
429    /// Find all match expressions that match on a given enum.
430    pub fn match_exprs_for_enum(
431        &self,
432        enum_id: SymbolId,
433    ) -> impl Iterator<Item = (SymbolId, &MatchExprDataV2)> + '_ {
434        self.match_expr_index
435            .iter()
436            .flat_map(move |(func_id, ids)| {
437                ids.iter()
438                    .filter_map(|&id| self.match_exprs.get(id.as_usize()))
439                    .filter(move |data| data.enum_id == enum_id)
440                    .map(move |data| (func_id, data))
441            })
442    }
443
    /// Get total number of match expressions ever stored.
    ///
    /// This is the length of the backing array; entries whose function was
    /// later removed with `remove_node` are still counted because only the
    /// per-function index is cleared there.
    pub fn match_expr_count(&self) -> usize {
        self.match_exprs.len()
    }
448
449    // ========================================================================
450    // Statistics
451    // ========================================================================
452
    /// Get the number of nodes (symbols currently tracked by the graph).
    #[inline]
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
458
    /// Get the number of edges held in the edge array.
    ///
    /// The array is never compacted in this file: edges detached by
    /// `remove_node` or `clear_outgoing_edges` remain stored and are therefore
    /// still included in this count.
    #[inline]
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
464
    /// Check if the graph is empty, i.e. tracks no nodes.
    ///
    /// Only the node set is consulted; the edge and match-expression arrays
    /// are not inspected.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
470
471    // ========================================================================
472    // Call Chain Analysis (Transitive Traversal)
473    // ========================================================================
474
    /// Find all callers transitively up to max_depth.
    ///
    /// Returns one [`ChainNode`] per reachable caller, in breadth-first order.
    /// `depth` is the number of hops from the starting symbol; each symbol is
    /// reported once, at the depth it was first reached. The starting symbol
    /// itself is never included, and `max_depth == 0` yields an empty list.
    ///
    /// # Example
    /// ```text
    /// A calls B, B calls C, C calls D
    /// callers_chain(D, 3) returns nodes: C (depth 1), B (depth 2), A (depth 3)
    /// ```
    pub fn callers_chain(&self, start: SymbolId, max_depth: usize) -> Vec<ChainNode> {
        self.traverse_chain(start, max_depth, ChainDirection::Callers)
    }
488
    /// Find all callees transitively up to max_depth.
    ///
    /// Returns one [`ChainNode`] per reachable callee, in breadth-first order.
    /// `depth` is the number of hops from the starting symbol; each symbol is
    /// reported once, at the depth it was first reached. The starting symbol
    /// itself is never included, and `max_depth == 0` yields an empty list.
    ///
    /// # Example
    /// ```text
    /// A calls B, B calls C, C calls D
    /// callees_chain(A, 3) returns nodes: B (depth 1), C (depth 2), D (depth 3)
    /// ```
    pub fn callees_chain(&self, start: SymbolId, max_depth: usize) -> Vec<ChainNode> {
        self.traverse_chain(start, max_depth, ChainDirection::Callees)
    }
502
503    /// Internal BFS traversal for call chain analysis.
504    fn traverse_chain(
505        &self,
506        start: SymbolId,
507        max_depth: usize,
508        direction: ChainDirection,
509    ) -> Vec<ChainNode> {
510        use std::collections::{HashSet, VecDeque};
511
512        let mut result = Vec::new();
513        let mut visited = HashSet::new();
514        let mut queue = VecDeque::new();
515
516        visited.insert(start);
517        queue.push_back((start, 0usize));
518
519        while let Some((current, depth)) = queue.pop_front() {
520            if depth > 0 {
521                result.push(ChainNode {
522                    symbol: current,
523                    depth,
524                });
525            }
526
527            if depth >= max_depth {
528                continue;
529            }
530
531            let neighbors: Vec<SymbolId> = match direction {
532                ChainDirection::Callers => self.callers_of(current).collect(),
533                ChainDirection::Callees => self.callees_of(current).collect(),
534                ChainDirection::TypeUsers | ChainDirection::TypeDeps => {
535                    unreachable!("TypeUsers/TypeDeps must use TypeFlowGraphV2")
536                }
537            };
538
539            for neighbor in neighbors {
540                if !visited.contains(&neighbor) {
541                    visited.insert(neighbor);
542                    queue.push_back((neighbor, depth + 1));
543                }
544            }
545        }
546
547        result
548    }
549
550    /// Get full chain result with statistics.
551    pub fn analyze_chain(
552        &self,
553        start: SymbolId,
554        max_depth: usize,
555        direction: ChainDirection,
556    ) -> ChainResult {
557        let nodes = self.traverse_chain(start, max_depth, direction);
558
559        let mut by_depth: HashMap<usize, usize> = HashMap::new();
560        for node in &nodes {
561            *by_depth.entry(node.depth).or_default() += 1;
562        }
563
564        let max_actual_depth = nodes.iter().map(|n| n.depth).max().unwrap_or(0);
565
566        ChainResult {
567            start,
568            direction,
569            max_depth,
570            nodes,
571            max_actual_depth,
572            by_depth,
573        }
574    }
575}
576
577// ============================================================================
578// Call Chain Types
579// ============================================================================
580
/// Direction for chain traversal.
///
/// Covers both call chains (CodeGraphV2) and type reference chains (TypeFlowGraphV2).
/// `CodeGraphV2` itself only handles `Callers` and `Callees`; its traversal
/// treats the two type directions as unreachable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChainDirection {
    /// Follow incoming Calls edges (who calls this?)
    Callers,
    /// Follow outgoing Calls edges (what does this call?)
    Callees,
    /// Follow type reference edges: who uses this type? (type → containers)
    TypeUsers,
    /// Follow type reference edges: what types does this use? (container → types)
    TypeDeps,
}
595
596impl std::fmt::Display for ChainDirection {
597    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598        match self {
599            ChainDirection::Callers => write!(f, "callers"),
600            ChainDirection::Callees => write!(f, "callees"),
601            ChainDirection::TypeUsers => write!(f, "type_users"),
602            ChainDirection::TypeDeps => write!(f, "type_deps"),
603        }
604    }
605}
606
/// A node in the chain with depth information.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChainNode {
    /// The symbol at this position in the chain.
    pub symbol: SymbolId,
    /// Number of hops from the starting symbol (1 = direct neighbour,
    /// 2 = reached through one intermediate symbol, etc.)
    pub depth: usize,
}
615
/// Result of a chain analysis operation.
///
/// Shared by both CodeGraphV2 (call chains) and TypeFlowGraphV2 (type chains).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
    /// Starting symbol of the analysis (not itself listed in `nodes`).
    pub start: SymbolId,
    /// Direction of traversal.
    pub direction: ChainDirection,
    /// Maximum depth requested.
    pub max_depth: usize,
    /// All nodes found in the chain.
    pub nodes: Vec<ChainNode>,
    /// Maximum depth actually reached (0 when `nodes` is empty).
    pub max_actual_depth: usize,
    /// Count of nodes at each depth level; depths with no nodes have no entry.
    pub by_depth: HashMap<usize, usize>,
}
634
635impl ChainResult {
636    /// Get total number of nodes in the chain.
637    pub fn total_count(&self) -> usize {
638        self.nodes.len()
639    }
640
641    /// Get nodes at a specific depth.
642    pub fn at_depth(&self, depth: usize) -> impl Iterator<Item = &ChainNode> {
643        self.nodes.iter().filter(move |n| n.depth == depth)
644    }
645
646    /// Check if the chain is empty.
647    pub fn is_empty(&self) -> bool {
648        self.nodes.is_empty()
649    }
650
651    /// Get symbols as a flat list.
652    pub fn symbols(&self) -> impl Iterator<Item = SymbolId> + '_ {
653        self.nodes.iter().map(|n| n.symbol)
654    }
655}
656
657// ============================================================================
658// Tests
659// ============================================================================
660
661#[cfg(test)]
662mod tests {
663    use super::*;
664    use crate::symbol::{SymbolPath, SymbolRegistry};
665
666    fn setup() -> (SymbolRegistry, SymbolId, SymbolId, SymbolId) {
667        let mut registry = SymbolRegistry::new();
668        let id1 = registry
669            .register(SymbolPath::parse("foo::Bar").unwrap(), SymbolKind::Struct)
670            .unwrap();
671        let id2 = registry
672            .register(SymbolPath::parse("foo::baz").unwrap(), SymbolKind::Function)
673            .unwrap();
674        let id3 = registry
675            .register(SymbolPath::parse("foo::qux").unwrap(), SymbolKind::Function)
676            .unwrap();
677        (registry, id1, id2, id3)
678    }
679
680    #[test]
681    fn test_add_node() {
682        let (_, id1, _, _) = setup();
683        let mut graph = CodeGraphV2::new();
684
685        assert!(graph.add_node(id1));
686        assert!(!graph.add_node(id1)); // Already exists
687        assert!(graph.contains(id1));
688    }
689
690    #[test]
691    fn test_add_edge() {
692        let (_, id1, id2, _) = setup();
693        let mut graph = CodeGraphV2::new();
694
695        graph.add_edge(id1, id2, CodeEdgeV2::Contains);
696
697        assert!(graph.contains(id1));
698        assert!(graph.contains(id2));
699        assert!(graph.has_edge(id1, id2, CodeEdgeV2::Contains));
700        assert!(!graph.has_edge(id2, id1, CodeEdgeV2::Contains));
701    }
702
703    #[test]
704    fn test_callers_of() {
705        let (_, id1, id2, id3) = setup();
706        let mut graph = CodeGraphV2::new();
707
708        graph.add_edge(id1, id3, CodeEdgeV2::Calls);
709        graph.add_edge(id2, id3, CodeEdgeV2::Calls);
710
711        let callers: Vec<_> = graph.callers_of(id3).collect();
712        assert_eq!(callers.len(), 2);
713        assert!(callers.contains(&id1));
714        assert!(callers.contains(&id2));
715    }
716
717    #[test]
718    fn test_children_of() {
719        let (_, id1, id2, id3) = setup();
720        let mut graph = CodeGraphV2::new();
721
722        graph.add_edge(id1, id2, CodeEdgeV2::Contains);
723        graph.add_edge(id1, id3, CodeEdgeV2::Contains);
724
725        let children: Vec<_> = graph.children_of(id1).collect();
726        assert_eq!(children.len(), 2);
727    }
728
729    #[test]
730    fn test_parent_of() {
731        let (_, id1, id2, _) = setup();
732        let mut graph = CodeGraphV2::new();
733
734        graph.add_edge(id1, id2, CodeEdgeV2::Contains);
735
736        assert_eq!(graph.parent_of(id2), Some(id1));
737        assert_eq!(graph.parent_of(id1), None);
738    }
739
740    #[test]
741    fn test_remove_node() {
742        let (_, id1, id2, _) = setup();
743        let mut graph = CodeGraphV2::new();
744
745        graph.add_edge(id1, id2, CodeEdgeV2::Calls);
746        assert_eq!(graph.node_count(), 2);
747
748        assert!(graph.remove_node(id1));
749        assert_eq!(graph.node_count(), 1);
750        assert!(!graph.contains(id1));
751        assert!(graph.contains(id2));
752    }
753
754    #[test]
755    fn test_kind_index() {
756        let (_, id1, id2, id3) = setup();
757        let mut graph = CodeGraphV2::new();
758
759        graph.add_node(id1);
760        graph.add_node(id2);
761        graph.add_node(id3);
762
763        graph.add_to_kind_index(id1, SymbolKind::Struct);
764        graph.add_to_kind_index(id2, SymbolKind::Function);
765        graph.add_to_kind_index(id3, SymbolKind::Function);
766
767        let structs: Vec<_> = graph.iter_by_kind(SymbolKind::Struct).collect();
768        assert_eq!(structs.len(), 1);
769
770        let functions: Vec<_> = graph.iter_by_kind(SymbolKind::Function).collect();
771        assert_eq!(functions.len(), 2);
772    }
773
774    // ========================================================================
775    // Call Chain Tests
776    // ========================================================================
777
778    fn setup_chain() -> (
779        SymbolRegistry,
780        SymbolId,
781        SymbolId,
782        SymbolId,
783        SymbolId,
784        SymbolId,
785    ) {
786        let mut registry = SymbolRegistry::new();
787        // Create a call chain: a -> b -> c -> d -> e
788        let a = registry
789            .register(
790                SymbolPath::parse("test::fn_a").unwrap(),
791                SymbolKind::Function,
792            )
793            .unwrap();
794        let b = registry
795            .register(
796                SymbolPath::parse("test::fn_b").unwrap(),
797                SymbolKind::Function,
798            )
799            .unwrap();
800        let c = registry
801            .register(
802                SymbolPath::parse("test::fn_c").unwrap(),
803                SymbolKind::Function,
804            )
805            .unwrap();
806        let d = registry
807            .register(
808                SymbolPath::parse("test::fn_d").unwrap(),
809                SymbolKind::Function,
810            )
811            .unwrap();
812        let e = registry
813            .register(
814                SymbolPath::parse("test::fn_e").unwrap(),
815                SymbolKind::Function,
816            )
817            .unwrap();
818        (registry, a, b, c, d, e)
819    }
820
821    #[test]
822    fn test_callers_chain_simple() {
823        let (_, a, b, c, d, _) = setup_chain();
824        let mut graph = CodeGraphV2::new();
825
826        // a calls b, b calls c, c calls d
827        graph.add_edge(a, b, CodeEdgeV2::Calls);
828        graph.add_edge(b, c, CodeEdgeV2::Calls);
829        graph.add_edge(c, d, CodeEdgeV2::Calls);
830
831        // From d, callers are: c (depth 1), b (depth 2), a (depth 3)
832        let chain = graph.callers_chain(d, 10);
833        assert_eq!(chain.len(), 3);
834
835        // Check depths
836        let c_node = chain.iter().find(|n| n.symbol == c).unwrap();
837        assert_eq!(c_node.depth, 1);
838
839        let b_node = chain.iter().find(|n| n.symbol == b).unwrap();
840        assert_eq!(b_node.depth, 2);
841
842        let a_node = chain.iter().find(|n| n.symbol == a).unwrap();
843        assert_eq!(a_node.depth, 3);
844    }
845
846    #[test]
847    fn test_callees_chain_simple() {
848        let (_, a, b, c, d, _) = setup_chain();
849        let mut graph = CodeGraphV2::new();
850
851        // a calls b, b calls c, c calls d
852        graph.add_edge(a, b, CodeEdgeV2::Calls);
853        graph.add_edge(b, c, CodeEdgeV2::Calls);
854        graph.add_edge(c, d, CodeEdgeV2::Calls);
855
856        // From a, callees are: b (depth 1), c (depth 2), d (depth 3)
857        let chain = graph.callees_chain(a, 10);
858        assert_eq!(chain.len(), 3);
859
860        let b_node = chain.iter().find(|n| n.symbol == b).unwrap();
861        assert_eq!(b_node.depth, 1);
862
863        let c_node = chain.iter().find(|n| n.symbol == c).unwrap();
864        assert_eq!(c_node.depth, 2);
865
866        let d_node = chain.iter().find(|n| n.symbol == d).unwrap();
867        assert_eq!(d_node.depth, 3);
868    }
869
870    #[test]
871    fn test_chain_with_max_depth() {
872        let (_, a, b, c, d, _) = setup_chain();
873        let mut graph = CodeGraphV2::new();
874
875        graph.add_edge(a, b, CodeEdgeV2::Calls);
876        graph.add_edge(b, c, CodeEdgeV2::Calls);
877        graph.add_edge(c, d, CodeEdgeV2::Calls);
878
879        // Limit to depth 2
880        let chain = graph.callees_chain(a, 2);
881        assert_eq!(chain.len(), 2); // b and c only, not d
882
883        let symbols: Vec<_> = chain.iter().map(|n| n.symbol).collect();
884        assert!(symbols.contains(&b));
885        assert!(symbols.contains(&c));
886        assert!(!symbols.contains(&d));
887    }
888
889    #[test]
890    fn test_chain_with_cycle() {
891        let (_, a, b, c, _, _) = setup_chain();
892        let mut graph = CodeGraphV2::new();
893
894        // Create a cycle: a -> b -> c -> a
895        graph.add_edge(a, b, CodeEdgeV2::Calls);
896        graph.add_edge(b, c, CodeEdgeV2::Calls);
897        graph.add_edge(c, a, CodeEdgeV2::Calls);
898
899        // Should not infinite loop, visited set prevents it
900        let chain = graph.callees_chain(a, 10);
901        assert_eq!(chain.len(), 2); // b and c (a is start, not included)
902    }
903
904    #[test]
905    fn test_analyze_chain() {
906        let (_, a, b, c, d, e) = setup_chain();
907        let mut graph = CodeGraphV2::new();
908
909        // Linear chain: a -> b -> c -> d -> e
910        graph.add_edge(a, b, CodeEdgeV2::Calls);
911        graph.add_edge(b, c, CodeEdgeV2::Calls);
912        graph.add_edge(c, d, CodeEdgeV2::Calls);
913        graph.add_edge(d, e, CodeEdgeV2::Calls);
914
915        let result = graph.analyze_chain(a, 10, ChainDirection::Callees);
916
917        assert_eq!(result.start, a);
918        assert_eq!(result.direction, ChainDirection::Callees);
919        assert_eq!(result.total_count(), 4);
920        assert_eq!(result.max_actual_depth, 4);
921
922        // Check depth distribution
923        assert_eq!(*result.by_depth.get(&1).unwrap_or(&0), 1); // b at depth 1
924        assert_eq!(*result.by_depth.get(&2).unwrap_or(&0), 1); // c at depth 2
925        assert_eq!(*result.by_depth.get(&3).unwrap_or(&0), 1); // d at depth 3
926        assert_eq!(*result.by_depth.get(&4).unwrap_or(&0), 1); // e at depth 4
927    }
928}