1use std::{borrow::Cow, ops::Range};
10
11use itertools::Itertools;
12use typst_syntax::{ast::*, LinkedNode, Source, Span, SyntaxKind, SyntaxNode};
13
14use crate::{
15 pretty::Mode,
16 utils::{self, indent_4_to_2},
17 AttrStore, Error, PrettyPrinter, Typstyle,
18};
19
/// Outcome of processing a byte range of a source file.
#[derive(Clone, Debug)]
pub struct RangeResult {
    /// Byte range of the source that `content` corresponds to.
    pub source_range: Range<usize>,
    /// Text produced for `source_range`.
    pub content: String,
}
28
29impl RangeResult {
30 fn empty(pos: usize) -> Self {
31 Self {
32 source_range: pos..pos,
33 content: String::new(),
34 }
35 }
36}
37
38impl Typstyle {
39 pub fn format_source_range(
50 &self,
51 source: Source,
52 utf8_range: Range<usize>,
53 ) -> Result<RangeResult, Error> {
54 let trimmed_range = trim_range(source.text(), utf8_range);
55 let (node, mode) = get_node_and_mode_for_range(&source, trimmed_range.clone())?;
56
57 let Some((node, node_range)) = refine_node_range(node, trimmed_range.clone()) else {
58 return Ok(RangeResult::empty(trimmed_range.start)); };
60
61 let attrs = AttrStore::new(&node); let printer = PrettyPrinter::new(self.config.clone(), attrs);
63 let doc = printer.try_convert_with_mode(&node, mode)?;
64
65 let indent = utils::count_spaces_after_last_newline(source.text(), node_range.start);
67 let text = doc
68 .nest(indent as isize)
69 .print(self.config.max_width)
70 .to_string();
71
72 Ok(RangeResult {
73 source_range: node_range,
74 content: text,
75 })
76 }
77
78 pub fn format_source_range_ir(
87 &self,
88 source: Source,
89 utf8_range: Range<usize>,
90 ) -> Result<RangeResult, Error> {
91 let trimmed_range = trim_range(source.text(), utf8_range);
92 let (node, mode) = get_node_and_mode_for_range(&source, trimmed_range.clone())?;
93
94 let Some((node, node_range)) = refine_node_range(node, trimmed_range.clone()) else {
95 return Ok(RangeResult::empty(trimmed_range.start)); };
97
98 let attrs = AttrStore::new(&node);
99 let printer = PrettyPrinter::new(self.config.clone(), attrs);
100 let doc = printer.try_convert_with_mode(&node, mode)?;
101
102 let ir = indent_4_to_2(&format!("{doc:#?}"));
103
104 Ok(RangeResult {
105 source_range: node_range,
106 content: ir,
107 })
108 }
109}
110
111pub fn format_range_ast(source: &Source, utf8_range: Range<usize>) -> Result<RangeResult, Error> {
121 let node = get_node_for_range(source, utf8_range)?;
122 Ok(RangeResult {
123 source_range: node.range(),
124 content: indent_4_to_2(&format!("{node:#?}")),
125 })
126}
127
128fn refine_node_range(
129 node: LinkedNode,
130 range: Range<usize>,
131) -> Option<(Cow<SyntaxNode>, Range<usize>)> {
132 match node.kind() {
133 SyntaxKind::Markup | SyntaxKind::Code | SyntaxKind::Math => {
134 let inner = node
136 .children()
137 .skip_while(|it| it.range().end <= range.start)
138 .take_while(|it| it.range().start < range.end)
139 .collect_vec();
140 let sub_range = inner.first()?.range().start..inner.last()?.range().end;
141
142 let new_node = SyntaxNode::inner(
146 node.kind(),
147 inner.into_iter().map(|it| it.get().clone()).collect(),
148 );
149 Some((Cow::Owned(new_node), sub_range))
150 }
151 _ => Some((Cow::Borrowed(node.get()), node.range())),
152 }
153}
154
155fn get_node_for_range(source: &Source, utf8_range: Range<usize>) -> Result<LinkedNode<'_>, Error> {
156 let trimmed_range = trim_range(source.text(), utf8_range);
158
159 get_node_and_mode_for_range(source, trimmed_range).map(|(node, _)| node)
160}
161
/// Shrinks `rng` so that the selected slice of `s` neither starts nor ends
/// with whitespace.
///
/// A range that selects only whitespace collapses to the empty range at its
/// start.
///
/// The range comes from callers and may not be a valid slice of `s`, so it is
/// sanitised before slicing instead of panicking:
/// - both ends are clamped to `s.len()`;
/// - an inverted range (`end < start`) is treated as empty at `start`;
/// - an end that falls inside a multi-byte character is widened to the
///   enclosing character boundaries.
fn trim_range(s: &str, mut rng: Range<usize>) -> Range<usize> {
    // Clamp into the text and collapse inverted ranges onto their start.
    rng.start = rng.start.min(s.len());
    rng.end = rng.end.min(s.len()).max(rng.start);

    // Snap outwards to char boundaries. `0` and `s.len()` are always
    // boundaries, so neither loop can leave the string.
    while !s.is_char_boundary(rng.start) {
        rng.start -= 1;
    }
    while !s.is_char_boundary(rng.end) {
        rng.end += 1;
    }

    // Trim the tail first, so a whitespace-only selection ends up empty at
    // `rng.start` instead of at `rng.end`.
    rng.end = rng.start + s[rng.clone()].trim_end().len();
    rng.start = rng.end - s[rng.clone()].trim_start().len();
    rng
}
168
169fn get_node_and_mode_for_range(
170 source: &Source,
171 utf8_range: Range<usize>,
172) -> Result<(LinkedNode<'_>, Mode), Error> {
173 get_node_cover_range(source, utf8_range)
174 .filter(|(node, _)| !node.erroneous())
175 .ok_or(Error::SyntaxError)
176}
177
178fn get_node_cover_range(source: &Source, range: Range<usize>) -> Option<(LinkedNode<'_>, Mode)> {
180 let range = range.start..range.end.min(source.len_bytes());
181 get_node_cover_range_impl(range, LinkedNode::new(source.root()), Mode::Markup)
182 .and_then(|(span, mode)| source.find(span).map(|node| (node, mode)))
183}
184
185fn get_node_cover_range_impl(
186 range: Range<usize>,
187 node: LinkedNode<'_>,
188 mode: Mode,
189) -> Option<(Span, Mode)> {
190 let mode = match node.kind() {
191 SyntaxKind::Markup => Mode::Markup,
192 SyntaxKind::CodeBlock => Mode::Code,
193 SyntaxKind::Equation => Mode::Math,
194 _ => mode,
195 };
196
197 for child in node.children() {
199 if let Some(res) = get_node_cover_range_impl(range.clone(), child, mode) {
200 return Some(res);
201 }
202 }
203
204 let node_range = node.range();
206 (node_range.start <= range.start
207 && node_range.end >= range.end
208 && (node.is::<Markup>()
209 || node.is::<Code>()
210 || node.is::<Math>()
211 || node.is::<Expr>()
212 || node.is::<Pattern>()))
213 .then(|| (node.span(), mode))
214 }
216
// Snapshot tests for range formatting. Each case formats a line/column
// selection with the default `Typstyle` and checks the reported byte range
// and the produced text against inline snapshots.
217#[cfg(test)]
218mod tests {
219 use insta::{assert_debug_snapshot, assert_snapshot};
220
221 use super::*;
222
    // Converts the (line, column) endpoints of `lc_range` to byte offsets in
    // `content`, formats that byte range with `Typstyle::default()`, and
    // unwraps every step, so a failed conversion or format panics the test.
223 fn test(content: &str, lc_range: Range<(usize, usize)>) -> RangeResult {
224 let source = Source::detached(content);
225 let range = source
226 .line_column_to_byte(lc_range.start.0, lc_range.start.1)
227 .unwrap()
228 ..source
229 .line_column_to_byte(lc_range.end.0, lc_range.end.1)
230 .unwrap();
231
232 let t = Typstyle::default();
233 t.format_source_range(source, range).unwrap()
234 }
235
    // Non-empty selection across two lines of markup; the snapshot expects
    // both embedded expressions to be reformatted.
236 #[test]
237 fn cover_markup() {
238 let res = test(
239 "
240#(1+1)
241#(2+2)
242#(3+3)",
243 (1, 1)..(2, 2),
244 );
245
246 assert_debug_snapshot!(res.source_range, @"2..14");
247 assert_snapshot!(res.content, @r"
248 (1 + 1)
249 #(2 + 2)
250 ");
251 }
252
    // Empty selection (start == end) in markup; the snapshot expects the
    // single expression at that position to be reformatted.
253 #[test]
254 fn cover_markup_empty() {
255 let res = test(
256 "
257#(1+1)
258#(2+2)",
259 (1, 1)..(1, 1),
260 );
261
262 assert_debug_snapshot!(res.source_range, @"2..7");
263 assert_snapshot!(res.content, @"(1 + 1)");
264 }
265
    // Empty selection in plain text; the snapshot expects empty content.
266 #[test]
267 fn cover_markup_empty2() {
268 let res = test(" a b ", (0, 3)..(0, 3));
269
270 assert_debug_snapshot!(res.source_range, @"2..5");
271 assert_snapshot!(res.content, @"");
272 }
273
    // Non-empty selection across two lines inside a `#{ ... }` code block.
274 #[test]
275 fn cover_code() {
276 let res = test(
277 r#"""
278#{
279("1"+"1")
280("2"+"2")
281("3"+"3")
282}"""#,
283 (2, 2)..(3, 3),
284 );
285
286 assert_debug_snapshot!(res.source_range, @"6..25");
287 assert_snapshot!(res.content, @r#"
288 ("1" + "1")
289 ("2" + "2")
290 "#);
291 }
292
    // Empty selection inside a code block; the snapshot expects only the
    // string literal at that position to be emitted.
293 #[test]
294 fn cover_code_empty() {
295 let res = test(
296 r#"""
297#{
298("1"+"1")
299("2"+"2")
300}"""#,
301 (2, 2)..(2, 2),
302 );
303
304 assert_debug_snapshot!(res.source_range, @"7..10");
305 assert_snapshot!(res.content, @r#""1""#);
306 }
307
    // Non-empty selection across two lines inside a `$ ... $` equation.
308 #[test]
309 fn cover_math() {
310 let res = test(
311 r#"""$
312sin( x )
313cos( y )
314tan( z )
315$"""#,
316 (2, 2)..(3, 3),
317 );
318
319 assert_debug_snapshot!(res.source_range, @"13..30");
320 assert_snapshot!(res.content, @r"
321 cos(y)
322 tan(z)
323 ");
324 }
325}