typst_ide/
matchers.rs

1use ecow::EcoString;
2use typst::foundations::{Module, Value};
3use typst::syntax::ast::AstNode;
4use typst::syntax::{LinkedNode, Span, SyntaxKind, ast};
5
6use crate::{IdeWorld, analyze_import};
7
/// Find the named items starting from the given position.
///
/// The search walks outward from `position`. At each level of the syntax tree
/// it visits the current node and then each preceding (non-trivia) sibling,
/// and afterwards moves up to the parent. Every binding found along the way
/// is handed to `recv`:
///
/// - identifiers bound by `let` bindings ([`NamedItem::Fn`] for closure
///   bindings such as `let f(x) = ..`, [`NamedItem::Var`] otherwise),
/// - the module bound by an `import` and the items it imports,
/// - the pattern bindings of an enclosing `for` loop, except when the search
///   comes from the loop's iterable expression,
/// - the parameters of an enclosing closure, only when the search comes from
///   the closure's body.
///
/// The search stops as soon as `recv` returns `Some`, and that value is
/// returned. If `recv` never returns `Some`, the result is `None`.
///
/// Because a node is visited before its preceding siblings, a `let` binding
/// whose initializer contains `position` reports its own bindings as well.
pub fn named_items<T>(
    world: &dyn IdeWorld,
    position: LinkedNode,
    mut recv: impl FnMut(NamedItem) -> Option<T>,
) -> Option<T> {
    let mut ancestor = Some(position);
    while let Some(node) = &ancestor {
        // Visit this node, then walk backwards through its earlier siblings.
        let mut sibling = Some(node.clone());
        while let Some(node) = &sibling {
            if let Some(v) = node.cast::<ast::LetBinding>() {
                // Both variants are `fn(ast::Ident) -> NamedItem`, so the
                // chosen constructor can be applied to every binding below.
                let kind = if matches!(v.kind(), ast::LetBindingKind::Closure(..)) {
                    NamedItem::Fn
                } else {
                    NamedItem::Var
                };
                for ident in v.kind().bindings() {
                    if let Some(res) = recv(kind(ident)) {
                        return Some(res);
                    }
                }
            }

            if let Some(v) = node.cast::<ast::ModuleImport>() {
                let imports = v.imports();
                let source = v.source();

                // Resolve the import source to a value via `analyze_import`.
                // This may be `None`, e.g. when resolution fails.
                let source_value = node
                    .find(source.span())
                    .and_then(|source| analyze_import(world, &source));
                let source_value = source_value.as_ref();

                // Only module values are attached to `NamedItem::Module`.
                let module = source_value.and_then(|value| match value {
                    Value::Module(module) => Some(module),
                    _ => None,
                });

                let name_and_span = match (imports, v.new_name()) {
                    // ```plain
                    // import "foo" as name
                    // import "foo" as name: ..
                    // ```
                    (_, Some(name)) => Some((name.get().clone(), name.span())),
                    // ```plain
                    // import "foo"
                    // ```
                    (None, None) => v.bare_name().ok().map(|name| (name, source.span())),
                    // ```plain
                    // import "foo": ..
                    // ```
                    (Some(..), None) => None,
                };

                // Seeing the module itself.
                if let Some((name, span)) = name_and_span
                    && let Some(res) = recv(NamedItem::Module(&name, span, module))
                {
                    return Some(res);
                }

                // Seeing the imported items.
                match imports {
                    // ```plain
                    // import "foo";
                    // ```
                    None => {}
                    // ```plain
                    // import "foo": *;
                    // ```
                    // Every binding in the source's scope becomes visible.
                    Some(ast::Imports::Wildcard) => {
                        if let Some(scope) = source_value.and_then(Value::scope) {
                            for (name, binding) in scope.iter() {
                                let item = NamedItem::Import(
                                    name,
                                    binding.span(),
                                    Some(binding.read()),
                                );
                                if let Some(res) = recv(item) {
                                    return Some(res);
                                }
                            }
                        }
                    }
                    // ```plain
                    // import "foo": items;
                    // ```
                    Some(ast::Imports::Items(items)) => {
                        for item in items.iter() {
                            // Resolve a possibly nested path such as `a.b`.
                            // The first segment is looked up in the source's
                            // scope, and each further segment in the scope of
                            // the previous binding's value.
                            let mut iter = item.path().iter();
                            let mut binding = source_value
                                .and_then(Value::scope)
                                .zip(iter.next())
                                .and_then(|(scope, first)| scope.get(&first));

                            for ident in iter {
                                binding = binding.and_then(|binding| {
                                    binding.read().scope()?.get(&ident)
                                });
                            }

                            // Report under the bound name (the renamed name,
                            // if any). Use the resolved binding's span and
                            // value when available. Otherwise fall back to the
                            // span of the name in the import statement.
                            let bound = item.bound_name();
                            let (span, value) = match binding {
                                Some(binding) => (binding.span(), Some(binding.read())),
                                None => (bound.span(), None),
                            };

                            let item = NamedItem::Import(bound.get(), span, value);
                            if let Some(res) = recv(item) {
                                return Some(res);
                            }
                        }
                    }
                }
            }

            sibling = node.prev_sibling();
        }

        if let Some(parent) = node.parent() {
            // Loop bindings are in scope everywhere in the loop except the
            // iterable, which is the child directly following `in`.
            if let Some(v) = parent.cast::<ast::ForLoop>()
                && node.prev_sibling_kind() != Some(SyntaxKind::In)
            {
                let pattern = v.pattern();
                for ident in pattern.bindings() {
                    if let Some(res) = recv(NamedItem::Var(ident)) {
                        return Some(res);
                    }
                }
            }

            // Closure parameters are only in scope inside the closure body.
            // They are not in scope inside the parameter list itself, e.g.
            // in default values of named parameters.
            if let Some(v) = parent.cast::<ast::Closure>().filter(|v| {
                // Check if the node is in the body of the closure.
                let body = parent.find(v.body().span());
                body.is_some_and(|n| n.find(node.span()).is_some())
            }) {
                for param in v.params().children() {
                    match param {
                        ast::Param::Pos(pattern) => {
                            for ident in pattern.bindings() {
                                if let Some(t) = recv(NamedItem::Var(ident)) {
                                    return Some(t);
                                }
                            }
                        }
                        ast::Param::Named(n) => {
                            if let Some(t) = recv(NamedItem::Var(n.name())) {
                                return Some(t);
                            }
                        }
                        // An anonymous spread (`..`) binds nothing.
                        ast::Param::Spread(s) => {
                            if let Some(sink_ident) = s.sink_ident()
                                && let Some(t) = recv(NamedItem::Var(sink_ident))
                            {
                                return Some(t);
                            }
                        }
                    }
                }
            }

            ancestor = Some(parent.clone());
            continue;
        }

        // Reached the root without `recv` accepting anything.
        break;
    }

    None
}
177
178/// An item that is named.
179pub enum NamedItem<'a> {
180    /// A variable item.
181    Var(ast::Ident<'a>),
182    /// A function item.
183    Fn(ast::Ident<'a>),
184    /// A (imported) module.
185    Module(&'a EcoString, Span, Option<&'a Module>),
186    /// An imported item.
187    Import(&'a EcoString, Span, Option<&'a Value>),
188}
189
190impl<'a> NamedItem<'a> {
191    pub(crate) fn name(&self) -> &'a EcoString {
192        match self {
193            NamedItem::Var(ident) => ident.get(),
194            NamedItem::Fn(ident) => ident.get(),
195            NamedItem::Module(name, _, _) => name,
196            NamedItem::Import(name, _, _) => name,
197        }
198    }
199
200    pub(crate) fn value(&self) -> Option<Value> {
201        match self {
202            NamedItem::Var(..) | NamedItem::Fn(..) => None,
203            NamedItem::Module(_, _, value) => value.cloned().map(Value::Module),
204            NamedItem::Import(_, _, value) => value.cloned(),
205        }
206    }
207
208    pub(crate) fn span(&self) -> Span {
209        match *self {
210            NamedItem::Var(name) | NamedItem::Fn(name) => name.span(),
211            NamedItem::Module(_, span, _) => span,
212            NamedItem::Import(_, span, _) => span,
213        }
214    }
215}
216
217/// Categorize an expression into common classes IDE functionality can operate
218/// on.
219pub fn deref_target(node: LinkedNode<'_>) -> Option<DerefTarget<'_>> {
220    // Move to the first ancestor that is an expression.
221    let mut ancestor = node;
222    while !ancestor.is::<ast::Expr>() {
223        ancestor = ancestor.parent()?.clone();
224    }
225
226    // Identify convenient expression kinds.
227    let expr_node = ancestor;
228    let expr = expr_node.cast::<ast::Expr>()?;
229    Some(match expr {
230        ast::Expr::Label(_) => DerefTarget::Label(expr_node),
231        ast::Expr::Ref(_) => DerefTarget::Ref(expr_node),
232        ast::Expr::FuncCall(call) => {
233            DerefTarget::Callee(expr_node.find(call.callee().span())?)
234        }
235        ast::Expr::SetRule(set) => {
236            DerefTarget::Callee(expr_node.find(set.target().span())?)
237        }
238        ast::Expr::Ident(_) | ast::Expr::MathIdent(_) | ast::Expr::FieldAccess(_) => {
239            DerefTarget::VarAccess(expr_node)
240        }
241        ast::Expr::Str(_) => {
242            let parent = expr_node.parent()?;
243            if parent.kind() == SyntaxKind::ModuleImport {
244                DerefTarget::ImportPath(expr_node)
245            } else if parent.kind() == SyntaxKind::ModuleInclude {
246                DerefTarget::IncludePath(expr_node)
247            } else {
248                DerefTarget::Code(expr_node)
249            }
250        }
251        _ if expr.hash()
252            || matches!(expr_node.kind(), SyntaxKind::MathIdent | SyntaxKind::Error) =>
253        {
254            DerefTarget::Code(expr_node)
255        }
256        _ => return None,
257    })
258}
259
260/// Classes of expressions that can be operated on by IDE functionality.
261#[derive(Debug, Clone)]
262pub enum DerefTarget<'a> {
263    /// A variable access expression.
264    ///
265    /// It can be either an identifier or a field access.
266    VarAccess(LinkedNode<'a>),
267    /// A function call expression.
268    Callee(LinkedNode<'a>),
269    /// An import path expression.
270    ImportPath(LinkedNode<'a>),
271    /// An include path expression.
272    IncludePath(LinkedNode<'a>),
273    /// Any code expression.
274    Code(LinkedNode<'a>),
275    /// A label expression.
276    Label(LinkedNode<'a>),
277    /// A reference expression.
278    Ref(LinkedNode<'a>),
279}
280
#[cfg(test)]
mod tests {
    use std::borrow::Borrow;

    use ecow::EcoString;
    use typst::foundations::Value;
    use typst::syntax::{LinkedNode, Side};

    use super::named_items;
    use crate::tests::{FilePos, TestWorld, WorldLike};

    /// Names (and values, if resolved) reported by `named_items`, in the
    /// order they were reported.
    type Response = Vec<(EcoString, Option<Value>)>;

    /// Chainable assertions over a [`Response`].
    trait ResponseExt {
        /// Asserts that every name in `includes` was reported.
        fn must_include<'a>(&self, includes: impl IntoIterator<Item = &'a str>) -> &Self;
        /// Asserts that no name in `excludes` was reported.
        fn must_exclude<'a>(&self, excludes: impl IntoIterator<Item = &'a str>) -> &Self;
        /// Asserts that the given name was reported with exactly the given
        /// value.
        fn must_include_value(&self, name_value: (&str, Option<&Value>)) -> &Self;
    }

    impl ResponseExt for Response {
        #[track_caller]
        fn must_include<'a>(&self, includes: impl IntoIterator<Item = &'a str>) -> &Self {
            for item in includes {
                assert!(
                    self.iter().any(|v| v.0 == item),
                    "{item:?} was not contained in {self:?}",
                );
            }
            self
        }

        #[track_caller]
        fn must_exclude<'a>(&self, excludes: impl IntoIterator<Item = &'a str>) -> &Self {
            for item in excludes {
                assert!(
                    !self.iter().any(|v| v.0 == item),
                    "{item:?} was wrongly contained in {self:?}",
                );
            }
            self
        }

        #[track_caller]
        fn must_include_value(&self, name_value: (&str, Option<&Value>)) -> &Self {
            assert!(
                self.iter().any(|v| (v.0.as_str(), v.1.as_ref()) == name_value),
                "{name_value:?} was not contained in {self:?}",
            );
            self
        }
    }

    /// Collects every named item visible from the leaf at `pos`.
    #[track_caller]
    fn test(world: impl WorldLike, pos: impl FilePos) -> Response {
        let world = world.acquire();
        let world = world.borrow();
        let (source, cursor) = pos.resolve(world);
        let node = LinkedNode::new(source.root());
        let leaf = node.leaf_at(cursor, Side::After).unwrap();
        let mut items = vec![];
        named_items(world, leaf, |s| {
            // `value()` already returns an owned `Option<Value>`.
            items.push((s.name().clone(), s.value()));
            None::<()>
        });
        items
    }

    #[test]
    fn test_named_items_simple() {
        let s = "#let a = 1;#let b = 2;";
        test(s, 8).must_include(["a"]).must_exclude(["b"]);
        test(s, 15).must_include(["b"]);
    }

    #[test]
    fn test_named_items_param() {
        let pos = "#let f(a) = 1;#let b = 2;";
        test(pos, 12).must_include(["a"]);
        test(pos, 19).must_include(["b", "f"]).must_exclude(["a"]);

        let named = "#let f(a: b) = 1;#let b = 2;";
        test(named, 15).must_include(["a", "f"]).must_exclude(["b"]);
    }

    #[test]
    fn test_named_items_for_loop() {
        // The loop variable is visible in the body (offset 18) but not in the
        // iterable expression (offset 11).
        let s = "#for x in (1,2) { x }";
        test(s, 18).must_include(["x"]);
        test(s, 11).must_exclude(["x"]);
    }

    #[test]
    fn test_named_items_import() {
        test("#import \"foo.typ\"", 2).must_include(["foo"]);
        test("#import \"foo.typ\" as bar", 2)
            .must_include(["bar"])
            .must_exclude(["foo"]);
    }

    #[test]
    fn test_named_items_import_items() {
        test("#import \"foo.typ\": a; #(a);", 2)
            .must_include(["a"])
            .must_exclude(["foo"]);

        let world = TestWorld::new("#import \"foo.typ\": a.b; #(b);")
            .with_source("foo.typ", "#import \"a.typ\"")
            .with_source("a.typ", "#let b = 1;");
        test(&world, 2).must_include_value(("b", Some(&Value::Int(1))));
    }

    #[test]
    fn test_named_items_import_wildcard() {
        let world = TestWorld::new("#import \"foo.typ\": *; #x")
            .with_source("foo.typ", "#let x = 1;");
        test(&world, 2).must_include_value(("x", Some(&Value::Int(1))));
    }
}