spydecy_python/
parser.rs

1//! Python AST parser using PyO3
2//!
3//! This module uses PyO3 to invoke Python's `ast` module for parsing.
4
5use anyhow::{Context, Result};
6use pyo3::prelude::*;
7use pyo3::types::PyModule;
8use serde::{Deserialize, Serialize};
9
10/// Python AST node (simplified representation)
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PythonAST {
13    /// Node type (e.g., "Module", "FunctionDef", "Call")
14    pub node_type: String,
15    /// Line number
16    pub lineno: Option<usize>,
17    /// Column offset
18    pub col_offset: Option<usize>,
19    /// Child nodes
20    pub children: Vec<PythonAST>,
21    /// Node attributes (name, value, etc.)
22    pub attributes: std::collections::HashMap<String, String>,
23}
24
25impl PythonAST {
26    /// Create a new AST node
27    #[must_use]
28    pub fn new(node_type: String) -> Self {
29        Self {
30            node_type,
31            lineno: None,
32            col_offset: None,
33            children: Vec::new(),
34            attributes: std::collections::HashMap::new(),
35        }
36    }
37}
38
39/// Parse Python source code into AST
40///
41/// # Errors
42///
43/// Returns an error if the Python code cannot be parsed
44pub fn parse(source: &str, filename: &str) -> Result<PythonAST> {
45    Python::with_gil(|py| parse_with_python(py, source, filename))
46}
47
48/// Parse Python source using Python's ast module
49fn parse_with_python(py: Python<'_>, source: &str, filename: &str) -> Result<PythonAST> {
50    // Import Python's ast module
51    let ast_module =
52        PyModule::import_bound(py, "ast").context("Failed to import Python ast module")?;
53
54    // Parse the source code
55    let ast_obj = ast_module
56        .call_method1("parse", (source, filename))
57        .context("Failed to parse Python source code")?;
58
59    // Convert Python AST to our simplified AST representation
60    extract_ast_node(&ast_obj)
61}
62
63/// Extract AST node information from Python object
64fn extract_ast_node(obj: &Bound<'_, PyAny>) -> Result<PythonAST> {
65    let node_type = obj
66        .getattr("__class__")?
67        .getattr("__name__")?
68        .extract::<String>()?;
69
70    let mut ast = PythonAST::new(node_type.clone());
71
72    // Extract line number and column offset
73    extract_location_info(obj, &mut ast);
74
75    // Extract node-specific attributes
76    extract_node_attributes(obj, &node_type, &mut ast)?;
77
78    Ok(ast)
79}
80
81/// Extract location information (line number and column offset)
82fn extract_location_info(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) {
83    if let Ok(lineno) = obj.getattr("lineno") {
84        ast.lineno = lineno.extract().ok();
85    }
86    if let Ok(col_offset) = obj.getattr("col_offset") {
87        ast.col_offset = col_offset.extract().ok();
88    }
89}
90
91/// Extract node-specific attributes based on node type
92fn extract_node_attributes(
93    obj: &Bound<'_, PyAny>,
94    node_type: &str,
95    ast: &mut PythonAST,
96) -> Result<()> {
97    match node_type {
98        "Module" => extract_module_attrs(obj, ast)?,
99        "FunctionDef" => extract_function_def_attrs(obj, ast)?,
100        "Return" => extract_return_attrs(obj, ast)?,
101        "Call" => extract_call_attrs(obj, ast)?,
102        "Name" => extract_name_attrs(obj, ast)?,
103        _ => extract_default_attrs(obj, ast)?,
104    }
105    Ok(())
106}
107
108/// Extract Module node attributes
109fn extract_module_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
110    if let Ok(body) = obj.getattr("body") {
111        ast.children = extract_list(&body)?;
112    }
113    Ok(())
114}
115
116/// Extract FunctionDef node attributes
117fn extract_function_def_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
118    if let Ok(name) = obj.getattr("name") {
119        ast.attributes.insert("name".to_string(), name.extract()?);
120    }
121    if let Ok(body) = obj.getattr("body") {
122        ast.children = extract_list(&body)?;
123    }
124    Ok(())
125}
126
127/// Extract Return node attributes
128fn extract_return_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
129    if let Ok(value) = obj.getattr("value") {
130        if !value.is_none() {
131            ast.children.push(extract_ast_node(&value)?);
132        }
133    }
134    Ok(())
135}
136
137/// Extract Call node attributes
138fn extract_call_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
139    if let Ok(func) = obj.getattr("func") {
140        ast.children.push(extract_ast_node(&func)?);
141    }
142    if let Ok(args) = obj.getattr("args") {
143        ast.children.extend(extract_list(&args)?);
144    }
145    Ok(())
146}
147
148/// Extract Name node attributes
149fn extract_name_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
150    if let Ok(id) = obj.getattr("id") {
151        ast.attributes.insert("id".to_string(), id.extract()?);
152    }
153    Ok(())
154}
155
156/// Extract default attributes for unknown node types
157#[allow(clippy::unnecessary_wraps)]
158fn extract_default_attrs(obj: &Bound<'_, PyAny>, ast: &mut PythonAST) -> Result<()> {
159    if let Ok(value) = obj.getattr("value") {
160        if !value.is_none() {
161            if let Ok(child) = extract_ast_node(&value) {
162                ast.children.push(child);
163            }
164        }
165    }
166    Ok(())
167}
168
169/// Extract a list of AST nodes
170fn extract_list(list: &Bound<'_, PyAny>) -> Result<Vec<PythonAST>> {
171    let mut nodes = Vec::new();
172    for item in list.iter()? {
173        let item = item?;
174        nodes.push(extract_ast_node(&item)?);
175    }
176    Ok(nodes)
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain function definition parses to a `Module` root with children.
    #[test]
    fn test_parse_simple_function() {
        let code = r"
def my_len(x):
    return len(x)
";
        let root = parse(code, "test.py").unwrap();

        assert_eq!(root.node_type, "Module");
        assert!(!root.children.is_empty());
    }

    /// Type annotations do not prevent parsing.
    #[test]
    fn test_parse_with_type_hints() {
        let code = r"
def my_len(x: list) -> int:
    return len(x)
";
        let root = parse(code, "test.py").unwrap();

        assert_eq!(root.node_type, "Module");
    }

    /// Malformed source is reported as an error rather than a panic.
    #[test]
    fn test_parse_invalid_syntax() {
        let outcome = parse("def invalid syntax here", "test.py");

        assert!(outcome.is_err());
    }
}