Skip to main content

shape_ext_python/
error_mapping.rs

1//! Python traceback -> Shape source location mapping.
2
3use crate::runtime::CompiledFunction;
4
/// One frame of a parsed Python traceback.
///
/// Produced by [`parse_traceback`]; each value corresponds to a single
/// `File "...", line N, in name` header line of the traceback text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PythonFrame {
    /// File name exactly as printed between the quotes (e.g. `script.py` or `<shape>`).
    pub filename: String,
    /// Line number reported for this frame.
    pub line: u32,
    /// Function name following `, in `, or `<unknown>` when the header has no `in` clause.
    pub function: String,
    /// Trimmed source-text line printed beneath the header, if one was present.
    pub text: Option<String>,
}

/// Parse a Python traceback string into structured frames.
///
/// Recognises the standard CPython traceback format:
///
/// ```text
/// Traceback (most recent call last):
///   File "script.py", line 10, in <module>
///     some_code()
///   File "other.py", line 5, in func
///     do_thing()
/// ErrorType: message
/// ```
///
/// Each `File "...", line N, in <name>` line becomes a [`PythonFrame`], in the
/// order it appears (so the most recent call is last). If the line directly
/// after a header is space-indented, non-empty, and is neither another
/// `File "` header nor a `Traceback` line, it is captured (trimmed) in
/// [`PythonFrame::text`] and not examined again. All other lines — the
/// `Traceback` banner, caret markers, the final `ErrorType: message` — are
/// ignored, as are headers that fail to parse.
pub fn parse_traceback(traceback: &str) -> Vec<PythonFrame> {
    let mut lines = traceback.lines().peekable();
    let mut frames = Vec::new();

    while let Some(raw) = lines.next() {
        let (filename, line, function) = match parse_file_line(raw.trim()) {
            Some(parts) => parts,
            None => continue,
        };

        // Consume the following line only when it looks like this frame's source text.
        let text = lines
            .next_if(|next| is_source_text_line(next))
            .map(|next| next.trim().to_string());

        frames.push(PythonFrame {
            filename,
            line,
            function,
            text,
        });
    }

    frames
}

/// Returns `true` when `line` looks like the indented source snippet CPython
/// prints beneath a `File "...", line N` header (rather than another header,
/// a `Traceback` banner, a blank line, or the unindented error summary).
fn is_source_text_line(line: &str) -> bool {
    let trimmed = line.trim();
    line.starts_with(' ')
        && !trimmed.is_empty()
        && !trimmed.starts_with("File \"")
        && !trimmed.starts_with("Traceback")
}

/// Parse a single `File "filename", line N[, in funcname]` header line.
///
/// `trimmed` must already have surrounding whitespace removed. Returns
/// `(filename, line_number, function_name)` — with `function_name` set to
/// `<unknown>` when there is no `in` clause — or `None` if the line is not a
/// header or its line number is not a valid `u32`.
///
/// CPython prints the filename without escaping, so it may itself contain a
/// `"`. The filename is therefore taken up to the *last* `", line ` separator
/// (a function name can never contain a quote) rather than the first quote.
fn parse_file_line(trimmed: &str) -> Option<(String, u32, String)> {
    let rest = trimmed.strip_prefix("File \"")?;
    let (filename, after_line) = rest.rsplit_once("\", line ")?;

    // The line number runs up to the next comma (or end of line).
    let line_end = after_line.find(',').unwrap_or(after_line.len());
    let line_no = after_line[..line_end].trim().parse::<u32>().ok()?;

    // Whatever follows `in ` in the remaining `, in name` clause is the function name.
    let function = after_line[line_end..]
        .split_once("in ")
        .map(|(_, name)| name.trim().to_string())
        .unwrap_or_else(|| "<unknown>".to_string());

    Some((filename.to_string(), line_no, function))
}
101
/// Map a Python line number inside `__shape_fn__` back to the Shape
/// source line number.
///
/// Python lines 0 and 1 both map to `shape_body_start_line`; every later line
/// `k` maps to `shape_body_start_line + (k - 1)`. The addition saturates at
/// `u32::MAX` instead of overflowing, which would panic in debug builds and
/// silently wrap in release builds.
pub fn map_python_line_to_shape(python_line: u32, shape_body_start_line: u32) -> u32 {
    shape_body_start_line.saturating_add(python_line.saturating_sub(1))
}
111
/// Format a Python error raised while running `func`, pointing at the Shape
/// source line when one can be determined.
///
/// The exception's traceback (if any) is rendered with pyo3 and parsed with
/// [`parse_traceback`]. The innermost frame whose filename or function name
/// contains `<shape>` or `__shape__` is mapped back to a Shape line via
/// [`map_python_line_to_shape`] using `func.shape_body_start_line`. Only frame
/// header lines are considered, so a source-text line that merely mentions
/// `<shape>` cannot produce a bogus location.
///
/// Returns `"Python error in '<name>' at line <n>: <err>"` when such a frame is
/// found, otherwise `"Python error in '<name>': <err>"`.
#[cfg(feature = "pyo3")]
pub fn format_python_error(
    py: pyo3::Python<'_>,
    err: &pyo3::PyErr,
    func: &CompiledFunction,
) -> String {
    use pyo3::types::PyTracebackMethods;

    // A missing traceback or a formatting failure simply means no location is reported.
    let traceback_str = err
        .traceback(py)
        .and_then(|tb| tb.format().ok())
        .unwrap_or_default();

    // Tracebacks list the most recent call last, so scan from the end to find
    // the innermost Shape-generated frame.
    let shape_line = parse_traceback(&traceback_str)
        .iter()
        .rev()
        .find(|frame| {
            [frame.filename.as_str(), frame.function.as_str()]
                .iter()
                .any(|part| part.contains("<shape>") || part.contains("__shape__"))
        })
        .map(|frame| map_python_line_to_shape(frame.line, func.shape_body_start_line));

    match shape_line {
        Some(line) => format!("Python error in '{}' at line {}: {}", func.name, line, err),
        None => format!("Python error in '{}': {}", func.name, err),
    }
}
155
156/// Fallback when pyo3 is not enabled.
157#[cfg(not(feature = "pyo3"))]
158pub fn format_python_error(_err: &str, func: &CompiledFunction) -> String {
159    format!("Python error in '{}': pyo3 not enabled", func.name)
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every field of `frame` at once so each test reads as a list of expectations.
    fn assert_frame(
        frame: &PythonFrame,
        filename: &str,
        line: u32,
        function: &str,
        text: Option<&str>,
    ) {
        assert_eq!(frame.filename, filename);
        assert_eq!(frame.line, line);
        assert_eq!(frame.function, function);
        assert_eq!(frame.text.as_deref(), text);
    }

    #[test]
    fn parse_traceback_full_example() {
        let tb = [
            "Traceback (most recent call last):",
            "  File \"script.py\", line 10, in <module>",
            "    some_code()",
            "  File \"other.py\", line 5, in func",
            "    do_thing()",
            "ValueError: bad value",
        ]
        .join("\n");

        let frames = parse_traceback(&tb);
        assert_eq!(frames.len(), 2);
        assert_frame(&frames[0], "script.py", 10, "<module>", Some("some_code()"));
        assert_frame(&frames[1], "other.py", 5, "func", Some("do_thing()"));
    }

    #[test]
    fn parse_traceback_no_source_text() {
        let tb = [
            "Traceback (most recent call last):",
            "  File \"a.py\", line 1, in main",
            "TypeError: oops",
        ]
        .join("\n");

        let frames = parse_traceback(&tb);
        assert_eq!(frames.len(), 1);
        assert_frame(&frames[0], "a.py", 1, "main", None);
    }

    #[test]
    fn parse_traceback_empty_input() {
        assert!(parse_traceback("").is_empty());
    }

    #[test]
    fn parse_traceback_no_traceback_lines() {
        assert!(parse_traceback("RuntimeError: something went wrong").is_empty());
    }

    #[test]
    fn parse_traceback_shape_internal_frame() {
        let tb = "  File \"<shape>\", line 3, in __shape_fn__\n    return x + 1";

        let frames = parse_traceback(tb);
        assert_eq!(frames.len(), 1);
        assert_frame(&frames[0], "<shape>", 3, "__shape_fn__", Some("return x + 1"));
    }

    #[test]
    fn map_python_line_to_shape_basics() {
        // (python_line, expected Shape line) with the body starting at Shape line 10:
        // lines below 2 collapse onto the start, later lines are offset by (line - 1).
        let cases: [(u32, u32); 4] = [(1, 10), (0, 10), (2, 11), (5, 14)];
        for &(python_line, expected) in &cases {
            assert_eq!(map_python_line_to_shape(python_line, 10), expected);
        }
    }
}