Skip to main content

tree_sitter_cli/
highlight.rs

1use std::{
2    collections::{BTreeMap, HashSet},
3    fmt::Write,
4    fs,
5    io::{self, Write as _},
6    path::{self, Path, PathBuf},
7    str,
8    sync::{atomic::AtomicUsize, Arc},
9    time::Instant,
10};
11
12use ansi_colours::{ansi256_from_rgb, rgb_from_ansi256};
13use anstyle::{Ansi256Color, AnsiColor, Color, Effects, RgbColor};
14use anyhow::Result;
15use log::{info, warn};
16use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
17use serde_json::{json, Value};
18use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter, HtmlRenderer};
19use tree_sitter_loader::Loader;
20
/// Opening of the generated HTML document: the doctype, the start of `<head>`,
/// the page title and the base stylesheet for the line-number/line table.
///
/// `<head>` is deliberately left open here; it is closed by
/// [`HTML_BODY_HEADER`], so further `<style>` blocks can be emitted in between.
pub const HTML_HEAD_HEADER: &str = "
<!doctype HTML>
<head>
  <title>Tree-sitter Highlighting</title>
  <style>
    body {
      font-family: monospace
    }
    .line-number {
      user-select: none;
      text-align: right;
      color: rgba(27,31,35,.3);
      padding: 0 10px;
    }
    .line {
      white-space: pre;
    }
  </style>";
39
/// Closes the `<head>` element opened by [`HTML_HEAD_HEADER`] and opens `<body>`.
pub const HTML_BODY_HEADER: &str = "
</head>
<body>
";
44
/// Closes the `<body>` element opened by [`HTML_BODY_HEADER`].
pub const HTML_FOOTER: &str = "
</body>
";
48
/// The presentation of a single highlight, in both output formats.
#[derive(Debug, Default)]
pub struct Style {
    /// Terminal (ANSI) rendition, used when highlighting to a terminal.
    pub ansi: anstyle::Style,
    /// CSS declarations for HTML output, or `None` when the theme entry
    /// did not describe a style.
    pub css: Option<String>,
}
54
/// A highlighting theme: a set of highlight names and the style of each.
///
/// `styles` and `highlight_names` are parallel vectors: the (de)serializers in
/// this file fill and read them in lock-step, so `styles[i]` is the style for
/// `highlight_names[i]`.
#[derive(Debug)]
pub struct Theme {
    /// Style of each highlight, indexed like `highlight_names`.
    pub styles: Vec<Style>,
    /// Highlight names such as `"function.builtin"`, indexed like `styles`.
    pub highlight_names: Vec<String>,
}
60
/// Serializable wrapper holding a [`Theme`] under the `theme` key.
#[derive(Default, Deserialize, Serialize)]
pub struct ThemeConfig {
    /// The theme; falls back to `Theme::default()` when the key is absent.
    #[serde(default)]
    pub theme: Theme,
}
66
67impl Theme {
68    pub fn load(path: &path::Path) -> io::Result<Self> {
69        let json = fs::read_to_string(path)?;
70        Ok(serde_json::from_str(&json).unwrap_or_default())
71    }
72
73    #[must_use]
74    pub fn default_style(&self) -> Style {
75        Style::default()
76    }
77}
78
79impl<'de> Deserialize<'de> for Theme {
80    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
81    where
82        D: Deserializer<'de>,
83    {
84        let mut styles = Vec::new();
85        let mut highlight_names = Vec::new();
86        if let Ok(colors) = BTreeMap::<String, Value>::deserialize(deserializer) {
87            styles.reserve(colors.len());
88            highlight_names.reserve(colors.len());
89            for (name, style_value) in colors {
90                let mut style = Style::default();
91                parse_style(&mut style, style_value);
92                highlight_names.push(name);
93                styles.push(style);
94            }
95        }
96        Ok(Self {
97            styles,
98            highlight_names,
99        })
100    }
101}
102
103impl Serialize for Theme {
104    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
105    where
106        S: Serializer,
107    {
108        let mut map = serializer.serialize_map(Some(self.styles.len()))?;
109        for (name, style) in self.highlight_names.iter().zip(&self.styles) {
110            let style = &style.ansi;
111            let color = style.get_fg_color().map(|color| match color {
112                Color::Ansi(color) => match color {
113                    AnsiColor::Black => json!("black"),
114                    AnsiColor::Blue => json!("blue"),
115                    AnsiColor::Cyan => json!("cyan"),
116                    AnsiColor::Green => json!("green"),
117                    AnsiColor::Magenta => json!("purple"),
118                    AnsiColor::Red => json!("red"),
119                    AnsiColor::White => json!("white"),
120                    AnsiColor::Yellow => json!("yellow"),
121                    _ => unreachable!(),
122                },
123                Color::Ansi256(Ansi256Color(n)) => json!(n),
124                Color::Rgb(RgbColor(r, g, b)) => json!(format!("#{r:x?}{g:x?}{b:x?}")),
125            });
126            let effects = style.get_effects();
127            if effects.contains(Effects::BOLD)
128                || effects.contains(Effects::ITALIC)
129                || effects.contains(Effects::UNDERLINE)
130            {
131                let mut style_json = BTreeMap::new();
132                if let Some(color) = color {
133                    style_json.insert("color", color);
134                }
135                if effects.contains(Effects::BOLD) {
136                    style_json.insert("bold", Value::Bool(true));
137                }
138                if effects.contains(Effects::ITALIC) {
139                    style_json.insert("italic", Value::Bool(true));
140                }
141                if effects.contains(Effects::UNDERLINE) {
142                    style_json.insert("underline", Value::Bool(true));
143                }
144                map.serialize_entry(&name, &style_json)?;
145            } else if let Some(color) = color {
146                map.serialize_entry(&name, &color)?;
147            } else {
148                map.serialize_entry(&name, &Value::Null)?;
149            }
150        }
151        map.end()
152    }
153}
154
/// The built-in theme.
///
/// It is written in the same JSON shape the deserializer accepts: a bare
/// number is an ANSI 256-color palette index, an object adds text effects,
/// and `null` leaves the highlight unstyled.
impl Default for Theme {
    fn default() -> Self {
        serde_json::from_value(json!({
            "attribute": {"color": 124, "italic": true},
            "comment": {"color": 245, "italic": true},
            "constant": 94,
            "constant.builtin": {"color": 94, "bold": true},
            "constructor": 136,
            "embedded": null,
            "function": 26,
            "function.builtin": {"color": 26, "bold": true},
            "keyword": 56,
            "module": 136,
            "number": {"color": 94, "bold": true},
            "operator": {"color": 239, "bold": true},
            "property": 124,
            "property.builtin": {"color": 124, "bold": true},
            "punctuation": 239,
            "punctuation.bracket": 239,
            "punctuation.delimiter": 239,
            "punctuation.special": 239,
            "string": 28,
            "string.special": 30,
            "tag": 18,
            "type": 23,
            "type.builtin": {"color": 23, "bold": true},
            "variable": 252,
            "variable.builtin": {"color": 252, "bold": true},
            "variable.parameter": {"color": 252, "underline": true}
        }))
        .unwrap()
    }
}
188
189fn parse_style(style: &mut Style, json: Value) {
190    if let Value::Object(entries) = json {
191        for (property_name, value) in entries {
192            match property_name.as_str() {
193                "bold" if value == Value::Bool(true) => {
194                    style.ansi = style.ansi.bold();
195                }
196                "italic" if value == Value::Bool(true) => {
197                    style.ansi = style.ansi.italic();
198                }
199                "underline" if value == Value::Bool(true) => {
200                    style.ansi = style.ansi.underline();
201                }
202                "color" => {
203                    if let Some(color) = parse_color(value) {
204                        style.ansi = style.ansi.fg_color(Some(color));
205                    }
206                }
207                _ => {}
208            }
209        }
210        style.css = Some(style_to_css(style.ansi));
211    } else if let Some(color) = parse_color(json) {
212        style.ansi = style.ansi.fg_color(Some(color));
213        style.css = Some(style_to_css(style.ansi));
214    } else {
215        style.css = None;
216    }
217
218    if let Some(Color::Rgb(RgbColor(red, green, blue))) = style.ansi.get_fg_color() {
219        if !terminal_supports_truecolor() {
220            let ansi256 = Color::Ansi256(Ansi256Color(ansi256_from_rgb((red, green, blue))));
221            style.ansi = style.ansi.fg_color(Some(ansi256));
222        }
223    }
224}
225
226fn parse_color(json: Value) -> Option<Color> {
227    match json {
228        Value::Number(n) => n.as_u64().map(|n| Color::Ansi256(Ansi256Color(n as u8))),
229        Value::String(s) => match s.to_lowercase().as_str() {
230            "black" => Some(Color::Ansi(AnsiColor::Black)),
231            "blue" => Some(Color::Ansi(AnsiColor::Blue)),
232            "cyan" => Some(Color::Ansi(AnsiColor::Cyan)),
233            "green" => Some(Color::Ansi(AnsiColor::Green)),
234            "purple" => Some(Color::Ansi(AnsiColor::Magenta)),
235            "red" => Some(Color::Ansi(AnsiColor::Red)),
236            "white" => Some(Color::Ansi(AnsiColor::White)),
237            "yellow" => Some(Color::Ansi(AnsiColor::Yellow)),
238            s => {
239                if let Some((red, green, blue)) = hex_string_to_rgb(s) {
240                    Some(Color::Rgb(RgbColor(red, green, blue)))
241                } else {
242                    None
243                }
244            }
245        },
246        _ => None,
247    }
248}
249
/// Parses a `#rrggbb` color string into its red, green and blue channels.
///
/// The string must start with `#` followed by at least six ASCII hex digits
/// (either case). Anything after the sixth digit is ignored. Returns `None`
/// for every other input; arbitrary (including non-ASCII) text never panics.
fn hex_string_to_rgb(s: &str) -> Option<(u8, u8, u8)> {
    // `get` yields `None` both when fewer than six bytes follow the `#` and
    // when byte 6 falls inside a multi-byte character.
    let digits = s.strip_prefix('#')?.get(..6)?;

    // `from_str_radix` would accept a leading `+`, so validate the digits
    // explicitly. This also guarantees the two-byte slices below are on
    // character boundaries.
    if !digits.bytes().all(|byte| byte.is_ascii_hexdigit()) {
        return None;
    }

    let channel = |start: usize| u8::from_str_radix(&digits[start..start + 2], 16).ok();
    Some((channel(0)?, channel(2)?, channel(4)?))
}
265
266fn style_to_css(style: anstyle::Style) -> String {
267    let mut result = String::new();
268    let effects = style.get_effects();
269    if effects.contains(Effects::UNDERLINE) {
270        write!(&mut result, "text-decoration: underline;").unwrap();
271    }
272    if effects.contains(Effects::BOLD) {
273        write!(&mut result, "font-weight: bold;").unwrap();
274    }
275    if effects.contains(Effects::ITALIC) {
276        write!(&mut result, "font-style: italic;").unwrap();
277    }
278    if let Some(color) = style.get_fg_color() {
279        write_color(&mut result, color);
280    }
281    result
282}
283
284fn write_color(buffer: &mut String, color: Color) {
285    match color {
286        Color::Ansi(color) => match color {
287            AnsiColor::Black => write!(buffer, "color: black").unwrap(),
288            AnsiColor::Red => write!(buffer, "color: red").unwrap(),
289            AnsiColor::Green => write!(buffer, "color: green").unwrap(),
290            AnsiColor::Yellow => write!(buffer, "color: yellow").unwrap(),
291            AnsiColor::Blue => write!(buffer, "color: blue").unwrap(),
292            AnsiColor::Magenta => write!(buffer, "color: purple").unwrap(),
293            AnsiColor::Cyan => write!(buffer, "color: cyan").unwrap(),
294            AnsiColor::White => write!(buffer, "color: white").unwrap(),
295            _ => unreachable!(),
296        },
297        Color::Ansi256(Ansi256Color(n)) => {
298            let (r, g, b) = rgb_from_ansi256(n);
299            write!(buffer, "color: #{r:02x}{g:02x}{b:02x}").unwrap();
300        }
301        Color::Rgb(RgbColor(r, g, b)) => write!(buffer, "color: #{r:02x}{g:02x}{b:02x}").unwrap(),
302    }
303}
304
/// Reports whether the terminal advertises 24-bit color support.
///
/// This is the case exactly when the `COLORTERM` environment variable is set
/// to `truecolor` or `24bit`.
fn terminal_supports_truecolor() -> bool {
    match std::env::var("COLORTERM") {
        Ok(value) => matches!(value.as_str(), "truecolor" | "24bit"),
        Err(_) => false,
    }
}
309
/// Options controlling a call to [`highlight`].
pub struct HighlightOptions {
    /// Theme supplying the style and class name of every highlight.
    pub theme: Theme,
    /// When set, capture names that do not conform to the standard set are
    /// reported through the logger before highlighting.
    pub check: bool,
    /// Optional file of capture names (one per line, `;` starts a comment)
    /// passed to the conformance check; only read when `check` is set.
    pub captures_path: Option<PathBuf>,
    /// In HTML output, emit `style='…'` attributes instead of `class='…'`.
    pub inline_styles: bool,
    /// Produce an HTML document instead of ANSI-colored text.
    pub html: bool,
    /// Suppress the file name line and, in HTML mode, the whole document.
    /// ANSI-colored text is still written.
    pub quiet: bool,
    /// Log the elapsed time once highlighting has finished.
    pub print_time: bool,
    /// Cancellation flag handed to the highlighter.
    pub cancellation_flag: Arc<AtomicUsize>,
}
320
321pub fn highlight(
322    loader: &Loader,
323    path: &Path,
324    name: &str,
325    config: &HighlightConfiguration,
326    print_name: bool,
327    opts: &HighlightOptions,
328) -> Result<()> {
329    if opts.check {
330        let names = if let Some(path) = opts.captures_path.as_deref() {
331            let file = fs::read_to_string(path)?;
332            let capture_names = file
333                .lines()
334                .filter_map(|line| {
335                    if line.trim().is_empty() || line.trim().starts_with(';') {
336                        return None;
337                    }
338                    line.split(';').next().map(|s| s.trim().trim_matches('"'))
339                })
340                .collect::<HashSet<_>>();
341            config.nonconformant_capture_names(&capture_names)
342        } else {
343            config.nonconformant_capture_names(&HashSet::new())
344        };
345        if names.is_empty() {
346            info!("All highlight captures conform to standards.");
347        } else {
348            warn!(
349                "Non-standard highlight {} detected:\n* {}",
350                if names.len() > 1 {
351                    "captures"
352                } else {
353                    "capture"
354                },
355                names.join("\n* ")
356            );
357        }
358    }
359
360    let source = fs::read(path)?;
361    let stdout = io::stdout();
362    let mut stdout = stdout.lock();
363    let time = Instant::now();
364    let mut highlighter = Highlighter::new();
365    let events =
366        highlighter.highlight(config, &source, Some(&opts.cancellation_flag), |string| {
367            loader.highlight_config_for_injection_string(string)
368        })?;
369    let theme = &opts.theme;
370
371    if !opts.quiet && print_name {
372        writeln!(&mut stdout, "{name}")?;
373    }
374
375    if opts.html {
376        if !opts.quiet {
377            writeln!(&mut stdout, "{HTML_HEAD_HEADER}")?;
378            writeln!(&mut stdout, "  <style>")?;
379            let names = theme.highlight_names.iter();
380            let styles = theme.styles.iter();
381            for (name, style) in names.zip(styles) {
382                if let Some(css) = &style.css {
383                    writeln!(&mut stdout, "    .{name} {{ {css}; }}")?;
384                }
385            }
386            writeln!(&mut stdout, "  </style>")?;
387            writeln!(&mut stdout, "{HTML_BODY_HEADER}")?;
388        }
389
390        let mut renderer = HtmlRenderer::new();
391        renderer.render(events, &source, &move |highlight, output| {
392            if opts.inline_styles {
393                output.extend(b"style='");
394                output.extend(
395                    theme.styles[highlight.0]
396                        .css
397                        .as_ref()
398                        .map_or_else(|| "".as_bytes(), |css_style| css_style.as_bytes()),
399                );
400                output.extend(b"'");
401            } else {
402                output.extend(b"class='");
403                let mut parts = theme.highlight_names[highlight.0].split('.').peekable();
404                while let Some(part) = parts.next() {
405                    output.extend(part.as_bytes());
406                    if parts.peek().is_some() {
407                        output.extend(b" ");
408                    }
409                }
410                output.extend(b"'");
411            }
412        })?;
413
414        if !opts.quiet {
415            writeln!(&mut stdout, "<table>")?;
416            for (i, line) in renderer.lines().enumerate() {
417                writeln!(
418                    &mut stdout,
419                    "<tr><td class=line-number>{}</td><td class=line>{line}</td></tr>",
420                    i + 1,
421                )?;
422            }
423            writeln!(&mut stdout, "</table>")?;
424            writeln!(&mut stdout, "{HTML_FOOTER}")?;
425        }
426    } else {
427        let mut style_stack = vec![theme.default_style().ansi];
428        for event in events {
429            match event? {
430                HighlightEvent::HighlightStart(highlight) => {
431                    style_stack.push(theme.styles[highlight.0].ansi);
432                }
433                HighlightEvent::HighlightEnd => {
434                    style_stack.pop();
435                }
436                HighlightEvent::Source { start, end } => {
437                    let style = style_stack.last().unwrap();
438                    write!(&mut stdout, "{style}").unwrap();
439                    stdout.write_all(&source[start..end])?;
440                    write!(&mut stdout, "{style:#}").unwrap();
441                }
442            }
443        }
444    }
445
446    if opts.print_time {
447        info!("Time: {}ms", time.elapsed().as_millis());
448    }
449
450    Ok(())
451}
452
#[cfg(test)]
mod tests {
    use std::env;

    use super::*;

    const JUNGLE_GREEN: &str = "#26A69A";
    const DARK_CYAN: &str = "#00AF87";

    /// Checks how hex colors are turned into terminal and CSS styles,
    /// depending on the truecolor support advertised through `COLORTERM`.
    ///
    /// The test changes the process environment and restores the original
    /// value of `COLORTERM` at the end. The same `style` value is reused for
    /// all three `parse_style` calls.
    #[test]
    fn test_parse_style() {
        let original_environment_variable = env::var("COLORTERM");

        let mut style = Style::default();
        assert_eq!(style.ansi.get_fg_color(), None);
        assert_eq!(style.css, None);

        // darkcyan is an ANSI color and is preserved
        env::set_var("COLORTERM", "");
        parse_style(&mut style, Value::String(DARK_CYAN.to_string()));
        assert_eq!(
            style.ansi.get_fg_color(),
            Some(Color::Ansi256(Ansi256Color(36)))
        );
        assert_eq!(style.css, Some("color: #00af87".to_string()));

        // junglegreen is not an ANSI color and is preserved when the terminal supports it
        env::set_var("COLORTERM", "truecolor");
        parse_style(&mut style, Value::String(JUNGLE_GREEN.to_string()));
        assert_eq!(
            style.ansi.get_fg_color(),
            Some(Color::Rgb(RgbColor(38, 166, 154)))
        );
        assert_eq!(style.css, Some("color: #26a69a".to_string()));

        // junglegreen gets approximated as cadetblue when the terminal does not support it
        // (the CSS still carries the exact RGB value)
        env::set_var("COLORTERM", "");
        parse_style(&mut style, Value::String(JUNGLE_GREEN.to_string()));
        assert_eq!(
            style.ansi.get_fg_color(),
            Some(Color::Ansi256(Ansi256Color(72)))
        );
        assert_eq!(style.css, Some("color: #26a69a".to_string()));

        // Put the environment back the way it was found.
        if let Ok(environment_variable) = original_environment_variable {
            env::set_var("COLORTERM", environment_variable);
        } else {
            env::remove_var("COLORTERM");
        }
    }
}
503}