Skip to main content

token_goblin_runtime/
wire.rs

1#![allow(rustdoc::private_intra_doc_links)]
2//! Guest-side wire format for the dylib boundary.
//! It's an internal module, but feel free to enjoy the docs.
4
5use std::{ops::Range, panic::UnwindSafe};
6
7use proc_macro2::{LexError, Span, TokenStream, TokenTree};
8
9use crate::wire::panic::{PanicLocation, PanicReport};
10
11/// Guest output packet returned from dylib `entry`.
12/// THIS TYPE SHOULD MATCH `token_goblin::span_recovery::Output`
13pub struct Output {
14    pub text: String,
15    pub spans: Vec<OutputEntry>,
16}
/// Single entry in the output.
///
/// THIS TYPE SHOULD MATCH `token_goblin::span_recovery::OutputEntry`
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct OutputEntry {
    /// `true` when the entry belongs to the tokens emitted for a captured panic.
    pub is_panic: bool,
    /// Range associated with the leaf token this entry describes.
    pub range: Range<usize>,
}
impl OutputEntry {
    /// Build an entry from its panic flag and range.
    #[must_use]
    pub fn new(is_panic: bool, range: Range<usize>) -> Self {
        Self { is_panic, range }
    }
}
30
31/// The entry for `charm` that handle input/output converions, and panics handling.
32///
33/// check out:
34/// - [`panic::run_and_catch`] for panic handling.
35/// - [`parse_input`] for input parsing.
36/// - [`collect_outputs`] for output serialization.
37///
38#[allow(clippy::missing_panics_doc, reason = "panic is in catch block")]
39pub fn entry(input: &str, body: impl FnOnce(TokenStream) -> TokenStream + UnwindSafe) -> Output {
40    let mut panics = None;
41
42    let (tokens, anchor) = panic::run_and_catch(|| {
43        let (input, anchor) = parse_input(input).unwrap();
44        (body(input), anchor)
45    })
46    .unwrap_or_else(|e| {
47        let (msg, location) = panic_to_compile_error(e);
48        panics = Some((
49            msg.clone(),
50            location // TODO: Rewrite it with normal types
51                .map(|loc| Range {
52                    start: loc.line as usize,
53                    end: loc.column as usize,
54                })
55                .unwrap_or_default(),
56        ));
57        (TokenStream::new(), None)
58    });
59
60    collect_outputs(tokens, panics, anchor)
61}
62fn panic_to_compile_error(e: PanicReport) -> (TokenStream, Option<PanicLocation>) {
63    let message = format!(
64        "panic in charm (at {location}): {error}",
65        error = e.message,
66        location = e.location.as_ref().map_or_else(
67            || "<unknown location>".to_string(),
68            PanicLocation::to_string
69        )
70    );
71    let msg = syn::Error::new(Span::call_site(), message).to_compile_error();
72
73    (msg, e.location)
74}
75
76/// Parse canonical host input text into a local fallback token stream.
77///
78/// Returns:
79/// - `TokenStream` suitable for further parsing
80/// - anchor - span that is used in `output` to filter spans from external source (e.g. embedded `TokenStream::from_str`)
81///
82/// # Errors
83/// - `LexError` - if input is not a valid token stream.
84///
85pub fn parse_input(source: &str) -> Result<(TokenStream, Option<Span>), LexError> {
86    let tokens: TokenStream = source.parse()?;
87    let anchor = first_leaf_span(&tokens);
88    Ok((tokens, anchor))
89}
90
91/// Serialize macro output into text plus flattened leaf-token source ranges.
92#[allow(clippy::needless_pass_by_value, reason = "consume token stream")]
93#[must_use]
94pub fn collect_outputs(
95    tokens: TokenStream,
96    error: Option<(TokenStream, Range<usize>)>,
97    anchor: Option<Span>,
98) -> Output {
99    let mut output = {
100        // collect external tts first
101        let resulted_stream = crate::ux::flush_output(tokens);
102
103        let source_range_fn = |span: Span| source_range(span, anchor);
104        // then regular output with remapped spans
105        Output {
106            text: resulted_stream.to_string(),
107            spans: flatten_leaf_spans(&resulted_stream, &source_range_fn)
108                .into_iter()
109                .map(|range| OutputEntry::new(false, range))
110                .collect(),
111        }
112    };
113
114    // And then panic with location info
115    if let Some((panics, range)) = error {
116        let error_source_ranges = |_| range.clone();
117        output.text.push(' ');
118        output.text.push_str(&panics.to_string());
119
120        output.spans.extend(
121            flatten_leaf_spans(&panics, &error_source_ranges)
122                .into_iter()
123                .map(|range| OutputEntry::new(true, range)),
124        );
125    }
126    output
127}
128
129fn first_leaf_span(tokens: &TokenStream) -> Option<Span> {
130    for token in tokens.clone() {
131        match token {
132            TokenTree::Group(group) => {
133                if let Some(span) = first_leaf_span(&group.stream()) {
134                    return Some(span);
135                }
136            }
137            TokenTree::Ident(ident) => return Some(ident.span()),
138            TokenTree::Punct(punct) => return Some(punct.span()),
139            TokenTree::Literal(literal) => return Some(literal.span()),
140        }
141    }
142    None
143}
144
145fn flatten_leaf_spans(
146    tokens: &TokenStream,
147    source_range_fn: &dyn Fn(Span) -> Range<usize>,
148) -> Vec<Range<usize>> {
149    let mut spans = Vec::new();
150    collect_leaf_spans(tokens, &mut spans, source_range_fn);
151    spans
152}
153
154fn collect_leaf_spans(
155    tokens: &TokenStream,
156    spans: &mut Vec<Range<usize>>,
157    source_range_fn: &dyn Fn(Span) -> Range<usize>,
158) {
159    for token in tokens.clone() {
160        match token {
161            TokenTree::Group(group) => collect_leaf_spans(&group.stream(), spans, source_range_fn),
162            TokenTree::Ident(ident) => spans.push(source_range_fn(ident.span())),
163            TokenTree::Punct(punct) => spans.push(source_range_fn(punct.span())),
164            TokenTree::Literal(literal) => spans.push(source_range_fn(literal.span())),
165        }
166    }
167}
168const CALL_SITE_RANGE: Range<usize> = 0..0;
169fn source_range(span: Span, anchor: Option<Span>) -> Range<usize> {
170    let range = span.byte_range();
171    if range.is_empty() {
172        return CALL_SITE_RANGE;
173    }
174
175    match anchor {
176        Some(anchor) if anchor.join(span).is_some() => range,
177        // Either it already call_site, or it is just span from "virtual file"
178        // in both cases we map it to call_site
179        _ => CALL_SITE_RANGE,
180    }
181}
182
/// Handles panics and converts them to compile errors with location info.
mod panic {
    use core::fmt;
    use std::{
        any::Any,
        cell::RefCell,
        fmt::Display,
        panic::{self, AssertUnwindSafe, PanicHookInfo},
    };

    /// Where a panic was raised, as reported to the panic hook.
    #[derive(Debug)]
    pub struct PanicLocation {
        pub file: String,
        pub line: u32,
        pub column: u32,
    }
    impl Display for PanicLocation {
        /// Formats as `file:line:column`.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}:{}:{}", self.file, self.line, self.column)
        }
    }

    /// Message and (optional) location of a captured panic.
    #[derive(Debug)]
    pub struct PanicReport {
        pub message: String,
        pub location: Option<PanicLocation>,
    }

    thread_local! {
        /// Report written by the capturing hook for the most recent panic on this thread.
        static LAST_PANIC: RefCell<Option<PanicReport>> = const { RefCell::new(None) };
    }

    /// Extract a readable message from a panic payload; payloads that are
    /// neither `&str` nor `String` get a fixed placeholder.
    fn panic_payload_to_string(payload: &(dyn Any + Send)) -> String {
        payload
            .downcast_ref::<&str>()
            .map(|text| (*text).to_string())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "<non-string panic payload>".to_string())
    }

    /// Install a panic hook for processing panics, return old one.
    ///
    /// The installed hook stores a [`PanicReport`] in `LAST_PANIC` and does
    /// nothing else.
    fn install_panic_hook() -> Box<dyn Fn(&PanicHookInfo<'_>) + 'static + Sync + Send> {
        let previous_hook = panic::take_hook();

        panic::set_hook(Box::new(|info: &PanicHookInfo<'_>| {
            let report = PanicReport {
                message: panic_payload_to_string(info.payload()),
                location: info.location().map(|at| PanicLocation {
                    file: at.file().to_string(),
                    line: at.line(),
                    column: at.column(),
                }),
            };
            LAST_PANIC.set(Some(report));
        }));

        previous_hook
    }

    /// Run `f`, converting a panic into a [`PanicReport`].
    ///
    /// The capturing hook is installed only for the duration of the call and
    /// the previous hook is put back afterwards. If the hook did not record a
    /// report, one is built from the unwind payload, without a location.
    pub fn run_and_catch<F, R>(f: F) -> Result<R, PanicReport>
    where
        F: FnOnce() -> R,
    {
        // Forget any report left over from an earlier run on this thread.
        LAST_PANIC.set(None);

        let previous_hook = install_panic_hook();

        let outcome = panic::catch_unwind(AssertUnwindSafe(f)).map_err(|payload| {
            let fallback_message = panic_payload_to_string(payload.as_ref());
            LAST_PANIC.take().unwrap_or(PanicReport {
                message: fallback_message,
                location: None,
            })
        });

        panic::set_hook(previous_hook);
        outcome
    }
}
273
// Unit tests for input parsing, span remapping, panic capture and the full
// `entry` flow.
#[cfg(test)]
#[allow(clippy::single_range_in_vec_init)]
mod tests {
    use std::str::FromStr as _;

    use proc_macro2::{Literal, TokenTree};
    use syn::Ident;

    use super::*;

    /// Return the first token of `tokens`, which must be a literal.
    fn single_literal(tokens: &TokenStream) -> Literal {
        match tokens.clone().into_iter().next().expect("one token") {
            TokenTree::Literal(literal) => literal,
            other => panic!("expected literal, got {other:?}"),
        }
    }

    // The anchor returned by `parse_input` joins with a span of that same input.
    #[test]
    fn input_anchor_joins_input_span() {
        let input = parse_input("12").expect("valid token stream");
        let literal = single_literal(&input.0);
        assert!(input.1.unwrap().join(literal.span()).is_some());
    }

    // A span from a separately parsed stream does not join with the input anchor.
    #[test]
    fn unrelated_parse_does_not_join_input_anchor() {
        let input = parse_input("12").expect("valid token stream");
        let generated = TokenStream::from_str(
            "
        12
    ",
        )
        .unwrap();
        let literal = single_literal(&generated);
        assert!(input.1.unwrap().join(literal.span()).is_none());
    }

    // Tokens from a separately parsed stream are reported with the 0..0 range.
    #[test]
    fn output_maps_unrelated_spans_to_call_site() {
        let input = parse_input("12").expect("valid token stream");
        let generated = TokenStream::from_str(
            "
        12
    ",
        )
        .unwrap();
        let out = collect_outputs(generated, None, input.1);
        assert_eq!(out.spans, [OutputEntry::new(false, 0..0)]);
    }

    // Tokens that come straight from the input keep their byte range.
    #[test]
    fn output_preserves_input_relative_spans() {
        let input = parse_input("12").expect("valid token stream");
        let out = collect_outputs(input.0.clone(), None, input.1);
        assert_eq!(out.text, "12");
        assert_eq!(out.spans, [OutputEntry::new(false, 0..2)]);
    }

    // `run_and_catch` reports both the panic message and a location.
    #[test]
    fn panic_capture_hook_captures_panic() {
        let report = panic::run_and_catch(|| {
            panic!("test panic");
        })
        .unwrap_err();
        assert_eq!(report.message, "test panic");
        assert!(report.location.is_some());
    }

    // End to end: the input token keeps its range, the appended call-site
    // ident is reported as 0..0.
    #[test]
    fn test_entry_full_flow() {
        let input = "12";
        let body = |mut input: TokenStream| {
            input.extend([Ident::new("foo", Span::call_site())]);
            input
        };
        let output = entry(input, body);
        assert_eq!(output.text, "12 foo");
        assert_eq!(
            output.spans,
            [OutputEntry::new(false, 0..2), OutputEntry::new(false, 0..0)]
        );
    }

    // Calling `entry` twice with the same input yields the same result.
    #[test]
    fn test_checks_that_entry_cleanup() {
        let input = "12";
        let body = |mut input: TokenStream| {
            input.extend([Ident::new("foo", Span::call_site())]);
            input
        };
        let output = entry(input, body);
        assert_eq!(output.text, "12 foo");
        // second call should produce same output
        let output = entry(input, body);
        assert_eq!(output.text, "12 foo");
        assert_eq!(
            output.spans,
            [OutputEntry::new(false, 0..2), OutputEntry::new(false, 0..0)]
        );
    }

    // After `entry` has swallowed a panic from `body`, a later panic must
    // still unwind out of the test (checked through `should_panic`).
    #[test]
    #[should_panic(expected = "second panic")]
    fn test_check_panic_hook_cleanup() {
        let input = "12";
        let body = |_: TokenStream| {
            panic!("test panic");
        };
        let output = entry(input, body);
        assert!(output.text.contains("test panic"));
        panic!("second panic");
    }
}