Skip to main content

typst_ide/
matchers.rs

1use ecow::EcoString;
2use typst::foundations::{Module, Value};
3use typst::syntax::ast::AstNode;
4use typst::syntax::{LinkedNode, Span, SyntaxKind, ast};
5
6use crate::{IdeWorld, analyze_import};
7
8/// Find the named items starting from the given position.
9pub fn named_items<T>(
10    world: &dyn IdeWorld,
11    position: LinkedNode,
12    mut recv: impl FnMut(NamedItem) -> Option<T>,
13) -> Option<T> {
14    let mut ancestor = Some(position);
15    while let Some(node) = &ancestor {
16        let mut sibling = Some(node.clone());
17        while let Some(node) = &sibling {
18            if let Some(v) = node.cast::<ast::LetBinding>() {
19                let kind = if matches!(v.kind(), ast::LetBindingKind::Closure(..)) {
20                    NamedItem::Fn
21                } else {
22                    NamedItem::Var
23                };
24                for ident in v.kind().bindings() {
25                    if let Some(res) = recv(kind(ident)) {
26                        return Some(res);
27                    }
28                }
29            }
30
31            if let Some(v) = node.cast::<ast::ModuleImport>() {
32                let imports = v.imports();
33                let source = v.source();
34
35                let source_value = node
36                    .find(source.span())
37                    .and_then(|source| analyze_import(world, &source));
38                let source_value = source_value.as_ref();
39
40                let module = source_value.and_then(|value| match value {
41                    Value::Module(module) => Some(module),
42                    _ => None,
43                });
44
45                let name_and_span = match (imports, v.new_name()) {
46                    // ```plain
47                    // import "foo" as name
48                    // import "foo" as name: ..
49                    // ```
50                    (_, Some(name)) => Some((name.get().clone(), name.span())),
51                    // ```plain
52                    // import "foo"
53                    // ```
54                    (None, None) => v.bare_name().ok().map(|name| (name, source.span())),
55                    // ```plain
56                    // import "foo": ..
57                    // ```
58                    (Some(..), None) => None,
59                };
60
61                // Seeing the module itself.
62                if let Some((name, span)) = name_and_span
63                    && let Some(res) = recv(NamedItem::Module(&name, span, module))
64                {
65                    return Some(res);
66                }
67
68                // Seeing the imported items.
69                match imports {
70                    // ```plain
71                    // import "foo";
72                    // ```
73                    None => {}
74                    // ```plain
75                    // import "foo": *;
76                    // ```
77                    Some(ast::Imports::Wildcard) => {
78                        if let Some(scope) = source_value.and_then(Value::scope) {
79                            for (name, binding) in scope.iter() {
80                                let item = NamedItem::Import(
81                                    name,
82                                    binding.span(),
83                                    Some(binding.read()),
84                                );
85                                if let Some(res) = recv(item) {
86                                    return Some(res);
87                                }
88                            }
89                        }
90                    }
91                    // ```plain
92                    // import "foo": items;
93                    // ```
94                    Some(ast::Imports::Items(items)) => {
95                        for item in items.iter() {
96                            let mut iter = item.path().iter();
97                            let mut binding = source_value
98                                .and_then(Value::scope)
99                                .zip(iter.next())
100                                .and_then(|(scope, first)| scope.get(&first));
101
102                            for ident in iter {
103                                binding = binding.and_then(|binding| {
104                                    binding.read().scope()?.get(&ident)
105                                });
106                            }
107
108                            let bound = item.bound_name();
109                            let (span, value) = match binding {
110                                Some(binding) => (binding.span(), Some(binding.read())),
111                                None => (bound.span(), None),
112                            };
113
114                            let item = NamedItem::Import(bound.get(), span, value);
115                            if let Some(res) = recv(item) {
116                                return Some(res);
117                            }
118                        }
119                    }
120                }
121            }
122
123            sibling = node.prev_sibling();
124        }
125
126        if let Some(parent) = node.parent() {
127            if let Some(v) = parent.cast::<ast::ForLoop>()
128                && node.prev_sibling_kind() != Some(SyntaxKind::In)
129            {
130                let pattern = v.pattern();
131                for ident in pattern.bindings() {
132                    if let Some(res) = recv(NamedItem::Var(ident)) {
133                        return Some(res);
134                    }
135                }
136            }
137
138            if let Some(v) = parent.cast::<ast::Closure>().filter(|v| {
139                // Check if the node is in the body of the closure.
140                let body = parent.find(v.body().span());
141                body.is_some_and(|n| n.find(node.span()).is_some())
142            }) {
143                for param in v.params().children() {
144                    match param {
145                        ast::Param::Pos(pattern) => {
146                            for ident in pattern.bindings() {
147                                if let Some(t) = recv(NamedItem::Var(ident)) {
148                                    return Some(t);
149                                }
150                            }
151                        }
152                        ast::Param::Named(n) => {
153                            if let Some(t) = recv(NamedItem::Var(n.name())) {
154                                return Some(t);
155                            }
156                        }
157                        ast::Param::Spread(s) => {
158                            if let Some(sink_ident) = s.sink_ident()
159                                && let Some(t) = recv(NamedItem::Var(sink_ident))
160                            {
161                                return Some(t);
162                            }
163                        }
164                    }
165                }
166            }
167
168            ancestor = Some(parent.clone());
169            continue;
170        }
171
172        break;
173    }
174
175    None
176}
177
178/// An item that is named.
179pub enum NamedItem<'a> {
180    /// A variable item.
181    Var(ast::Ident<'a>),
182    /// A function item.
183    Fn(ast::Ident<'a>),
184    /// A (imported) module.
185    Module(&'a EcoString, Span, Option<&'a Module>),
186    /// An imported item.
187    Import(&'a EcoString, Span, Option<&'a Value>),
188}
189
190impl<'a> NamedItem<'a> {
191    pub(crate) fn name(&self) -> &'a EcoString {
192        match self {
193            NamedItem::Var(ident) => ident.get(),
194            NamedItem::Fn(ident) => ident.get(),
195            NamedItem::Module(name, _, _) => name,
196            NamedItem::Import(name, _, _) => name,
197        }
198    }
199
200    pub(crate) fn value(&self) -> Option<Value> {
201        match self {
202            NamedItem::Var(..) | NamedItem::Fn(..) => None,
203            NamedItem::Module(_, _, value) => value.cloned().map(Value::Module),
204            NamedItem::Import(_, _, value) => value.cloned(),
205        }
206    }
207
208    pub(crate) fn span(&self) -> Span {
209        match *self {
210            NamedItem::Var(name) | NamedItem::Fn(name) => name.span(),
211            NamedItem::Module(_, span, _) => span,
212            NamedItem::Import(_, span, _) => span,
213        }
214    }
215}
216
217/// Categorize an expression into common classes IDE functionality can operate
218/// on.
219pub fn deref_target(node: LinkedNode<'_>) -> Option<DerefTarget<'_>> {
220    // Move to the first ancestor that is an expression.
221    let mut ancestor = node;
222    while !ancestor.is::<ast::Expr>() {
223        ancestor = ancestor.parent()?.clone();
224    }
225
226    // Identify convenient expression kinds.
227    let expr_node = ancestor;
228    let expr = expr_node.cast::<ast::Expr>()?;
229    Some(match expr {
230        ast::Expr::Label(_) => DerefTarget::Label(expr_node),
231        ast::Expr::Ref(_) => DerefTarget::Ref(expr_node),
232        // TODO: Add MathCall?
233        ast::Expr::FuncCall(call) => {
234            DerefTarget::Callee(expr_node.find(call.callee().span())?)
235        }
236        ast::Expr::SetRule(set) => {
237            DerefTarget::Callee(expr_node.find(set.target().span())?)
238        }
239        ast::Expr::Ident(_)
240        | ast::Expr::FieldAccess(_)
241        | ast::Expr::MathIdent(_)
242        | ast::Expr::MathFieldAccess(_) => DerefTarget::VarAccess(expr_node),
243        ast::Expr::Str(_) => {
244            let parent = expr_node.parent()?;
245            if parent.kind() == SyntaxKind::ModuleImport {
246                DerefTarget::ImportPath(expr_node)
247            } else if parent.kind() == SyntaxKind::ModuleInclude {
248                DerefTarget::IncludePath(expr_node)
249            } else {
250                DerefTarget::Code(expr_node)
251            }
252        }
253        _ if expr.hash()
254            || matches!(expr_node.kind(), SyntaxKind::MathIdent | SyntaxKind::Error) =>
255        {
256            DerefTarget::Code(expr_node)
257        }
258        _ => return None,
259    })
260}
261
262/// Classes of expressions that can be operated on by IDE functionality.
263#[derive(Debug, Clone)]
264pub enum DerefTarget<'a> {
265    /// A variable access expression.
266    ///
267    /// It can be either an identifier or a field access.
268    VarAccess(LinkedNode<'a>),
269    /// A function call expression.
270    Callee(LinkedNode<'a>),
271    /// An import path expression.
272    ImportPath(LinkedNode<'a>),
273    /// An include path expression.
274    IncludePath(LinkedNode<'a>),
275    /// Any code expression.
276    Code(LinkedNode<'a>),
277    /// A label expression.
278    Label(LinkedNode<'a>),
279    /// A reference expression.
280    Ref(LinkedNode<'a>),
281}
282
#[cfg(test)]
mod tests {
    use std::borrow::Borrow;

    use ecow::EcoString;
    use typst::foundations::Value;
    use typst::syntax::{LinkedNode, Side};

    use super::named_items;
    use crate::tests::{FilePos, TestWorld, WorldLike};

    /// The name and (optional) value of every item reported by `named_items`,
    /// in the order in which they were reported.
    type Response = Vec<(EcoString, Option<Value>)>;

    /// Chainable assertions on a [`Response`].
    trait ResponseExt {
        /// Asserts that every given name was reported.
        fn must_include<'a>(&self, includes: impl IntoIterator<Item = &'a str>) -> &Self;
        /// Asserts that none of the given names was reported.
        fn must_exclude<'a>(&self, excludes: impl IntoIterator<Item = &'a str>) -> &Self;
        /// Asserts that the given name was reported with exactly this value.
        fn must_include_value(&self, name_value: (&str, Option<&Value>)) -> &Self;
    }

    impl ResponseExt for Response {
        #[track_caller]
        fn must_include<'a>(&self, includes: impl IntoIterator<Item = &'a str>) -> &Self {
            for item in includes {
                assert!(
                    self.iter().any(|v| v.0 == item),
                    "{item:?} was not contained in {self:?}",
                );
            }
            self
        }

        #[track_caller]
        fn must_exclude<'a>(&self, excludes: impl IntoIterator<Item = &'a str>) -> &Self {
            for item in excludes {
                assert!(
                    !self.iter().any(|v| v.0 == item),
                    "{item:?} was wrongly contained in {self:?}",
                );
            }
            self
        }

        #[track_caller]
        fn must_include_value(&self, name_value: (&str, Option<&Value>)) -> &Self {
            assert!(
                self.iter().any(|v| (v.0.as_str(), v.1.as_ref()) == name_value),
                "{name_value:?} was not contained in {self:?}",
            );
            self
        }
    }

    /// Collects all named items reported for the leaf after the given
    /// position.
    #[track_caller]
    fn test(world: impl WorldLike, pos: impl FilePos) -> Response {
        let world = world.acquire();
        let world = world.borrow();
        let (source, cursor) = pos.resolve(world);
        let node = LinkedNode::new(source.root());
        let leaf = node.leaf_at(cursor, Side::After).unwrap();
        let mut items = vec![];
        // Returning `None` from the callback keeps the search going, so every
        // item is recorded.
        named_items(world, leaf, |s| {
            // `value()` already returns an owned value; no clone is needed.
            items.push((s.name().clone(), s.value()));
            None::<()>
        });
        items
    }

    #[test]
    fn test_named_items_simple() {
        let s = "#let a = 1;#let b = 2;";
        test(s, 8).must_include(["a"]).must_exclude(["b"]);
        test(s, 15).must_include(["b"]);
    }

    #[test]
    fn test_named_items_param() {
        let pos = "#let f(a) = 1;#let b = 2;";
        test(pos, 12).must_include(["a"]);
        test(pos, 19).must_include(["b", "f"]).must_exclude(["a"]);

        let named = "#let f(a: b) = 1;#let b = 2;";
        test(named, 15).must_include(["a", "f"]).must_exclude(["b"]);
    }

    #[test]
    fn test_named_items_import() {
        test("#import \"foo.typ\"", 2).must_include(["foo"]);
        test("#import \"foo.typ\" as bar", 2)
            .must_include(["bar"])
            .must_exclude(["foo"]);
    }

    #[test]
    fn test_named_items_import_items() {
        test("#import \"foo.typ\": a; #(a);", 2)
            .must_include(["a"])
            .must_exclude(["foo"]);

        let world = TestWorld::new("#import \"foo.typ\": a.b; #(b);")
            .with_source("foo.typ", "#import \"a.typ\"")
            .with_source("a.typ", "#let b = 1;");
        test(&world, 2).must_include_value(("b", Some(&Value::Int(1))));
    }
}