typstyle_core/
partial.rs

1//! Range-based formatting and analysis utilities for Typstyle.
2//!
3//! All byte ranges in this module are specified as UTF-8 offsets.
4//! If you have a UTF-16 range or line-column range (as used in LSP or some editors),
5//! you must convert it to a UTF-8 byte offset before calling these functions.
6//! The `typst_syntax::Source` API provides helpers for converting between line-column,
7//! UTF-16, and UTF-8 offsets.
8
9use std::{borrow::Cow, ops::Range};
10
11use itertools::Itertools;
12use typst_syntax::{ast::*, LinkedNode, Source, Span, SyntaxKind, SyntaxNode};
13
14use crate::{
15    pretty::Mode,
16    utils::{self, indent_4_to_2},
17    AttrStore, Error, PrettyPrinter, Typstyle,
18};
19
/// Result of a range-based formatting or analysis operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RangeResult {
    /// The source range that was actually processed.
    ///
    /// This may be larger than the requested range, because it is widened to
    /// include complete syntax nodes.
    pub source_range: Range<usize>,
    /// The output for the range (formatted text, IR, AST, etc.).
    pub content: String,
}

impl RangeResult {
    /// An empty result anchored at `pos`: an empty range with empty content,
    /// used to signal "no edit".
    fn empty(pos: usize) -> Self {
        Self {
            source_range: pos..pos,
            content: String::new(),
        }
    }
}
37
38impl Typstyle {
39    /// Format the smallest syntax node that fully covers the given byte range.
40    ///
41    /// The formatted range may be larger than the input to ensure valid syntax.
42    ///
43    /// # Arguments
44    /// - `source`: The source code.
45    /// - `utf8_range`: The UTF-8 byte range to format.
46    ///
47    /// # Returns
48    /// A `RangeResult` with the formatted text and actual node range.
49    pub fn format_source_range(
50        &self,
51        source: Source,
52        utf8_range: Range<usize>,
53    ) -> Result<RangeResult, Error> {
54        let trimmed_range = trim_range(source.text(), utf8_range);
55        let (node, mode) = get_node_and_mode_for_range(&source, trimmed_range.clone())?;
56
57        let Some((node, node_range)) = refine_node_range(node, trimmed_range.clone()) else {
58            return Ok(RangeResult::empty(trimmed_range.start)); // No edit
59        };
60
61        let attrs = AttrStore::new(&node); // Here we only compute the attributes of that subtree.
62        let printer = PrettyPrinter::new(self.config.clone(), attrs);
63        let doc = printer.try_convert_with_mode(&node, mode)?;
64
65        // Infer indent from context.
66        let indent = utils::count_spaces_after_last_newline(source.text(), node_range.start);
67        let text = doc
68            .nest(indent as isize)
69            .print(self.config.max_width)
70            .to_string();
71
72        Ok(RangeResult {
73            source_range: node_range,
74            content: text,
75        })
76    }
77
78    /// Get the pretty IR for the smallest syntax node covering the given byte range.
79    ///
80    /// # Arguments
81    /// - `source`: The source code.
82    /// - `utf8_range`: The UTF-8 byte range to analyze.
83    ///
84    /// # Returns
85    /// A `RangeResult` with the IR and actual node range.
86    pub fn format_source_range_ir(
87        &self,
88        source: Source,
89        utf8_range: Range<usize>,
90    ) -> Result<RangeResult, Error> {
91        let trimmed_range = trim_range(source.text(), utf8_range);
92        let (node, mode) = get_node_and_mode_for_range(&source, trimmed_range.clone())?;
93
94        let Some((node, node_range)) = refine_node_range(node, trimmed_range.clone()) else {
95            return Ok(RangeResult::empty(trimmed_range.start)); // No edit
96        };
97
98        let attrs = AttrStore::new(&node);
99        let printer = PrettyPrinter::new(self.config.clone(), attrs);
100        let doc = printer.try_convert_with_mode(&node, mode)?;
101
102        let ir = indent_4_to_2(&format!("{doc:#?}"));
103
104        Ok(RangeResult {
105            source_range: node_range,
106            content: ir,
107        })
108    }
109}
110
111/// Formats the smallest syntax node covering the given byte range as a debug AST string
112/// with 2-space indentation. Returns the node's actual source range and formatted AST.
113///
114/// # Arguments
115/// - `source`: The source code.
116/// - `utf8_range`: The UTF-8 byte range to analyze.
117///
118/// # Returns
119/// A `RangeResult` with the node's range and formatted AST.
120pub fn format_range_ast(source: &Source, utf8_range: Range<usize>) -> Result<RangeResult, Error> {
121    let node = get_node_for_range(source, utf8_range)?;
122    Ok(RangeResult {
123        source_range: node.range(),
124        content: indent_4_to_2(&format!("{node:#?}")),
125    })
126}
127
128fn refine_node_range(
129    node: LinkedNode,
130    range: Range<usize>,
131) -> Option<(Cow<SyntaxNode>, Range<usize>)> {
132    match node.kind() {
133        SyntaxKind::Markup | SyntaxKind::Code | SyntaxKind::Math => {
134            // find the smallest children covering the range
135            let inner = node
136                .children()
137                .skip_while(|it| it.range().end <= range.start)
138                .take_while(|it| it.range().start < range.end)
139                .collect_vec();
140            let sub_range = inner.first()?.range().start..inner.last()?.range().end;
141
142            // Create a synthetic (mock) syntax node for fine-grained selection.
143            // This is a key part of the functionality, as it allows us to refine
144            // the range to the smallest children covering the specified range.
145            let new_node = SyntaxNode::inner(
146                node.kind(),
147                inner.into_iter().map(|it| it.get().clone()).collect(),
148            );
149            Some((Cow::Owned(new_node), sub_range))
150        }
151        _ => Some((Cow::Borrowed(node.get()), node.range())),
152    }
153}
154
155fn get_node_for_range(source: &Source, utf8_range: Range<usize>) -> Result<LinkedNode<'_>, Error> {
156    // Trim the given range to ensure no space aside.
157    let trimmed_range = trim_range(source.text(), utf8_range);
158
159    get_node_and_mode_for_range(source, trimmed_range).map(|(node, _)| node)
160}
161
/// Shrink `rng` so that it excludes leading and trailing whitespace of `s`.
///
/// The range is first made safe to slice `s` with:
/// - `end` is clamped to `s.len()` and `start` is clamped to `end`, so an
///   out-of-bounds or reversed range collapses instead of panicking.
/// - `start` moves backward and `end` forward to the nearest UTF-8 character
///   boundary, so a range that splits a multi-byte character covers it whole.
///
/// The result always lies on character boundaries within `s`. A selection
/// made up only of whitespace yields an empty range.
fn trim_range(s: &str, rng: Range<usize>) -> Range<usize> {
    let mut end = rng.end.min(s.len());
    let mut start = rng.start.min(end);
    // `is_char_boundary(0)` and `is_char_boundary(s.len())` are always true,
    // so both loops terminate within `s`.
    while !s.is_char_boundary(start) {
        start -= 1;
    }
    while !s.is_char_boundary(end) {
        end += 1;
    }

    let end = start + s[start..end].trim_end().len();
    let start = end - s[start..end].trim_start().len();
    start..end
}
168
169fn get_node_and_mode_for_range(
170    source: &Source,
171    utf8_range: Range<usize>,
172) -> Result<(LinkedNode<'_>, Mode), Error> {
173    get_node_cover_range(source, utf8_range)
174        .filter(|(node, _)| !node.erroneous())
175        .ok_or(Error::SyntaxError)
176}
177
178/// Get a Markup/Expr/Pattern node from source with minimal span that covering the given range.
179fn get_node_cover_range(source: &Source, range: Range<usize>) -> Option<(LinkedNode<'_>, Mode)> {
180    let range = range.start..range.end.min(source.len_bytes());
181    get_node_cover_range_impl(range, LinkedNode::new(source.root()), Mode::Markup)
182        .and_then(|(span, mode)| source.find(span).map(|node| (node, mode)))
183}
184
185fn get_node_cover_range_impl(
186    range: Range<usize>,
187    node: LinkedNode<'_>,
188    mode: Mode,
189) -> Option<(Span, Mode)> {
190    let mode = match node.kind() {
191        SyntaxKind::Markup => Mode::Markup,
192        SyntaxKind::CodeBlock => Mode::Code,
193        SyntaxKind::Equation => Mode::Math,
194        _ => mode,
195    };
196
197    // First, try to find a child node that covers the range
198    for child in node.children() {
199        if let Some(res) = get_node_cover_range_impl(range.clone(), child, mode) {
200            return Some(res);
201        }
202    }
203
204    // If no child covers the range, check if this node covers it
205    let node_range = node.range();
206    (node_range.start <= range.start
207        && node_range.end >= range.end
208        && (node.is::<Markup>()
209            || node.is::<Code>()
210            || node.is::<Math>()
211            || node.is::<Expr>()
212            || node.is::<Pattern>()))
213    .then(|| (node.span(), mode))
214    // It returns span to avoid problems with borrowing.
215}
216
/// Snapshot tests for range formatting via `Typstyle::format_source_range`.
#[cfg(test)]
mod tests {
    use insta::{assert_debug_snapshot, assert_snapshot};

    use super::*;

    /// Format `content` over the range given as (line, column) start/end pairs.
    ///
    /// The pairs are converted to byte offsets with `Source::line_column_to_byte`.
    /// The default `Typstyle` is used, and the test panics if conversion or
    /// formatting fails.
    fn test(content: &str, lc_range: Range<(usize, usize)>) -> RangeResult {
        let source = Source::detached(content);
        let range = source
            .line_column_to_byte(lc_range.start.0, lc_range.start.1)
            .unwrap()
            ..source
                .line_column_to_byte(lc_range.end.0, lc_range.end.1)
                .unwrap();

        let t = Typstyle::default();
        t.format_source_range(source, range).unwrap()
    }

    /// A selection spanning two markup lines formats both embedded expressions.
    #[test]
    fn cover_markup() {
        let res = test(
            "
#(1+1)
#(2+2)
#(3+3)",
            (1, 1)..(2, 2),
        );

        assert_debug_snapshot!(res.source_range, @"2..14");
        assert_snapshot!(res.content, @r"
        (1 + 1)
        #(2 + 2)
        ");
    }

    /// An empty selection inside an expression formats that expression.
    #[test]
    fn cover_markup_empty() {
        let res = test(
            "
#(1+1)
#(2+2)",
            (1, 1)..(1, 1),
        );

        assert_debug_snapshot!(res.source_range, @"2..7");
        assert_snapshot!(res.content, @"(1 + 1)");
    }

    /// An empty selection inside whitespace covers that whitespace run and
    /// produces empty content.
    #[test]
    fn cover_markup_empty2() {
        let res = test(" a   b ", (0, 3)..(0, 3));

        assert_debug_snapshot!(res.source_range, @"2..5");
        assert_snapshot!(res.content, @"");
    }

    /// A selection spanning two lines of a code block formats only those lines.
    #[test]
    fn cover_code() {
        let res = test(
            r#"""
#{
("1"+"1")
("2"+"2")
("3"+"3")
}"""#,
            (2, 2)..(3, 3),
        );

        assert_debug_snapshot!(res.source_range, @"6..25");
        assert_snapshot!(res.content, @r#"
        ("1" + "1")
        ("2" + "2")
        "#);
    }

    /// An empty selection inside a code block formats the innermost covering node.
    #[test]
    fn cover_code_empty() {
        let res = test(
            r#"""
#{
("1"+"1")
("2"+"2")
}"""#,
            (2, 2)..(2, 2),
        );

        assert_debug_snapshot!(res.source_range, @"7..10");
        assert_snapshot!(res.content, @r#""1""#);
    }

    /// A selection spanning two lines of an equation formats only those lines.
    #[test]
    fn cover_math() {
        let res = test(
            r#"""$
sin( x )
cos( y )
tan( z )
$"""#,
            (2, 2)..(3, 3),
        );

        assert_debug_snapshot!(res.source_range, @"13..30");
        assert_snapshot!(res.content, @r"
        cos(y)
        tan(z)
        ");
    }
}