stack_graphs/
assert.rs

// -*- coding: utf-8 -*-
// ------------------------------------------------------------------------------------------------
// Copyright © 2022, stack-graphs authors.
// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
// ------------------------------------------------------------------------------------------------

//! Defines assertions that can be run against a stack graph.

10use itertools::Itertools;
11use lsp_positions::Position;
12
13use crate::arena::Handle;
14use crate::graph::File;
15use crate::graph::Node;
16use crate::graph::StackGraph;
17use crate::graph::Symbol;
18use crate::partial::PartialPath;
19use crate::partial::PartialPaths;
20use crate::stitching::Database;
21use crate::stitching::DatabaseCandidates;
22use crate::stitching::ForwardPartialPathStitcher;
23use crate::stitching::StitcherConfig;
24use crate::CancellationError;
25use crate::CancellationFlag;
26
27/// A stack graph assertion
28#[derive(Debug, Clone)]
29pub enum Assertion {
30    Defined {
31        source: AssertionSource,
32        targets: Vec<AssertionTarget>,
33    },
34    Defines {
35        source: AssertionSource,
36        symbols: Vec<Handle<Symbol>>,
37    },
38    Refers {
39        source: AssertionSource,
40        symbols: Vec<Handle<Symbol>>,
41    },
42}
43
44/// Source position of an assertion
45#[derive(Debug, Clone)]
46pub struct AssertionSource {
47    pub file: Handle<File>,
48    pub position: Position,
49}
50
51impl AssertionSource {
52    /// Return an iterator over definitions at this position.
53    pub fn iter_definitions<'a>(
54        &'a self,
55        graph: &'a StackGraph,
56    ) -> impl Iterator<Item = Handle<Node>> + 'a {
57        graph.nodes_for_file(self.file).filter(move |n| {
58            graph[*n].is_definition()
59                && graph
60                    .source_info(*n)
61                    .map(|s| s.span.contains(&self.position))
62                    .unwrap_or(false)
63        })
64    }
65
66    /// Return an iterator over references at this position.
67    pub fn iter_references<'a>(
68        &'a self,
69        graph: &'a StackGraph,
70    ) -> impl Iterator<Item = Handle<Node>> + 'a {
71        graph.nodes_for_file(self.file).filter(move |n| {
72            graph[*n].is_reference()
73                && graph
74                    .source_info(*n)
75                    .map(|s| s.span.contains(&self.position))
76                    .unwrap_or(false)
77        })
78    }
79
80    pub fn display<'a>(&'a self, graph: &'a StackGraph) -> impl std::fmt::Display + 'a {
81        struct Displayer<'a>(&'a AssertionSource, &'a StackGraph);
82        impl std::fmt::Display for Displayer<'_> {
83            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84                write!(
85                    f,
86                    "{}:{}:{}",
87                    self.1[self.0.file],
88                    self.0.position.line + 1,
89                    self.0.position.column.grapheme_offset + 1
90                )
91            }
92        }
93        Displayer(self, graph)
94    }
95}
96
97/// Target line of an assertion
98#[derive(Debug, Clone, PartialEq, Eq, Hash)]
99pub struct AssertionTarget {
100    pub file: Handle<File>,
101    pub line: usize,
102}
103
104impl AssertionTarget {
105    /// Checks if the target matches the node corresponding to the handle in the given graph.
106    pub fn matches_node(&self, node: Handle<Node>, graph: &StackGraph) -> bool {
107        let file = graph[node].file().unwrap();
108        let si = graph.source_info(node).unwrap();
109        let start_line = si.span.start.line;
110        let end_line = si.span.end.line;
111        file == self.file && start_line <= self.line && self.line <= end_line
112    }
113}
114
115/// Error describing assertion failures.
116#[derive(Clone)]
117pub enum AssertionError {
118    NoReferences {
119        source: AssertionSource,
120    },
121    IncorrectlyDefined {
122        source: AssertionSource,
123        references: Vec<Handle<Node>>,
124        missing_targets: Vec<AssertionTarget>,
125        unexpected_paths: Vec<PartialPath>,
126    },
127    IncorrectDefinitions {
128        source: AssertionSource,
129        missing_symbols: Vec<Handle<Symbol>>,
130        unexpected_symbols: Vec<Handle<Symbol>>,
131    },
132    IncorrectReferences {
133        source: AssertionSource,
134        missing_symbols: Vec<Handle<Symbol>>,
135        unexpected_symbols: Vec<Handle<Symbol>>,
136    },
137    Cancelled(CancellationError),
138}
139
140impl From<CancellationError> for AssertionError {
141    fn from(value: CancellationError) -> Self {
142        Self::Cancelled(value)
143    }
144}
145
146impl Assertion {
147    /// Run this assertion against the given graph, using the given paths object for path search.
148    pub fn run(
149        &self,
150        graph: &StackGraph,
151        partials: &mut PartialPaths,
152        db: &mut Database,
153        stitcher_config: StitcherConfig,
154        cancellation_flag: &dyn CancellationFlag,
155    ) -> Result<(), AssertionError> {
156        match self {
157            Self::Defined { source, targets } => self.run_defined(
158                graph,
159                partials,
160                db,
161                source,
162                targets,
163                stitcher_config,
164                cancellation_flag,
165            ),
166            Self::Defines { source, symbols } => self.run_defines(graph, source, symbols),
167            Self::Refers { source, symbols } => self.run_refers(graph, source, symbols),
168        }
169    }
170
171    fn run_defined(
172        &self,
173        graph: &StackGraph,
174        partials: &mut PartialPaths,
175        db: &mut Database,
176        source: &AssertionSource,
177        expected_targets: &Vec<AssertionTarget>,
178        stitcher_config: StitcherConfig,
179        cancellation_flag: &dyn CancellationFlag,
180    ) -> Result<(), AssertionError> {
181        let references = source.iter_references(graph).collect::<Vec<_>>();
182        if references.is_empty() {
183            return Err(AssertionError::NoReferences {
184                source: source.clone(),
185            });
186        }
187
188        let mut actual_paths = Vec::new();
189        for reference in &references {
190            let mut reference_paths = Vec::new();
191            ForwardPartialPathStitcher::find_all_complete_partial_paths(
192                &mut DatabaseCandidates::new(graph, partials, db),
193                vec![*reference],
194                stitcher_config,
195                cancellation_flag,
196                |_, _, p| {
197                    reference_paths.push(p.clone());
198                },
199            )?;
200            for reference_path in &reference_paths {
201                if reference_paths
202                    .iter()
203                    .all(|other| !other.shadows(partials, reference_path))
204                {
205                    actual_paths.push(reference_path.clone());
206                }
207            }
208        }
209
210        let missing_targets = expected_targets
211            .iter()
212            .filter(|t| {
213                !actual_paths
214                    .iter()
215                    .any(|p| t.matches_node(p.end_node, graph))
216            })
217            .cloned()
218            .unique()
219            .collect::<Vec<_>>();
220        let unexpected_paths = actual_paths
221            .iter()
222            .filter(|p| {
223                !expected_targets
224                    .iter()
225                    .any(|t| t.matches_node(p.end_node, graph))
226            })
227            .cloned()
228            .collect::<Vec<_>>();
229        if !missing_targets.is_empty() || !unexpected_paths.is_empty() {
230            return Err(AssertionError::IncorrectlyDefined {
231                source: source.clone(),
232                references,
233                missing_targets,
234                unexpected_paths,
235            });
236        }
237
238        Ok(())
239    }
240
241    fn run_defines(
242        &self,
243        graph: &StackGraph,
244        source: &AssertionSource,
245        expected_symbols: &Vec<Handle<Symbol>>,
246    ) -> Result<(), AssertionError> {
247        let actual_symbols = source
248            .iter_definitions(graph)
249            .filter_map(|d| graph[d].symbol())
250            .collect::<Vec<_>>();
251        let missing_symbols = expected_symbols
252            .iter()
253            .filter(|x| !actual_symbols.contains(*x))
254            .cloned()
255            .unique()
256            .collect::<Vec<_>>();
257        let unexpected_symbols = actual_symbols
258            .iter()
259            .filter(|x| !expected_symbols.contains(*x))
260            .cloned()
261            .unique()
262            .collect::<Vec<_>>();
263        if !missing_symbols.is_empty() || !unexpected_symbols.is_empty() {
264            return Err(AssertionError::IncorrectDefinitions {
265                source: source.clone(),
266                missing_symbols,
267                unexpected_symbols,
268            });
269        }
270        Ok(())
271    }
272
273    fn run_refers(
274        &self,
275        graph: &StackGraph,
276        source: &AssertionSource,
277        expected_symbols: &Vec<Handle<Symbol>>,
278    ) -> Result<(), AssertionError> {
279        let actual_symbols = source
280            .iter_references(graph)
281            .filter_map(|d| graph[d].symbol())
282            .collect::<Vec<_>>();
283        let missing_symbols = expected_symbols
284            .iter()
285            .filter(|x| !actual_symbols.contains(*x))
286            .cloned()
287            .unique()
288            .collect::<Vec<_>>();
289        let unexpected_symbols = actual_symbols
290            .iter()
291            .filter(|x| !expected_symbols.contains(*x))
292            .cloned()
293            .unique()
294            .collect::<Vec<_>>();
295        if !missing_symbols.is_empty() || !unexpected_symbols.is_empty() {
296            return Err(AssertionError::IncorrectReferences {
297                source: source.clone(),
298                missing_symbols,
299                unexpected_symbols,
300            });
301        }
302        Ok(())
303    }
304}