Skip to main content

wave_compiler/frontend/
rust.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! Rust kernel parser producing HIR.
//!
//! Parses a subset of Rust suitable for GPU kernels using the `syn` crate.
//! Supports function definitions with type annotations, arithmetic, comparisons,
//! if/else, array indexing, and GPU intrinsics. Loop statements (`for`, `while`,
//! `loop`) are accepted by `syn` but not yet lowered, so they are currently
//! dropped from the kernel body.
9
10use crate::diagnostics::CompileError;
11use crate::hir::expr::{BinOp, Dimension, Expr, Literal};
12use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
13use crate::hir::stmt::Stmt;
14use crate::hir::types::{AddressSpace, Type};
15
16/// Parse a Rust kernel source string into an HIR Kernel.
17///
18/// # Errors
19///
20/// Returns `CompileError::ParseError` if the source cannot be parsed.
21pub fn parse_rust(source: &str) -> Result<Kernel, CompileError> {
22    let file = syn::parse_file(source).map_err(|e| CompileError::ParseError {
23        message: format!("Rust parse error: {e}"),
24    })?;
25
26    for item in &file.items {
27        if let syn::Item::Fn(func) = item {
28            let has_kernel_attr = func.attrs.iter().any(|a| a.path().is_ident("kernel"));
29            if has_kernel_attr || func.attrs.is_empty() {
30                return lower_function(func);
31            }
32        }
33    }
34
35    Err(CompileError::ParseError {
36        message: "no kernel function found".into(),
37    })
38}
39
40fn lower_function(func: &syn::ItemFn) -> Result<Kernel, CompileError> {
41    let name = func.sig.ident.to_string();
42    let mut params = Vec::new();
43
44    for arg in &func.sig.inputs {
45        if let syn::FnArg::Typed(pat_type) = arg {
46            if let syn::Pat::Ident(ident) = &*pat_type.pat {
47                let param_name = ident.ident.to_string();
48                let (ty, space) = lower_type(&pat_type.ty);
49                params.push(KernelParam {
50                    name: param_name,
51                    ty,
52                    address_space: space,
53                });
54            }
55        }
56    }
57
58    let body = lower_block(&func.block)?;
59
60    Ok(Kernel {
61        name,
62        params,
63        body,
64        attributes: KernelAttributes::default(),
65    })
66}
67
68fn lower_type(ty: &syn::Type) -> (Type, AddressSpace) {
69    match ty {
70        syn::Type::Path(path) => {
71            let ident = path
72                .path
73                .segments
74                .last()
75                .map(|s| s.ident.to_string())
76                .unwrap_or_default();
77            match ident.as_str() {
78                "i32" => (Type::I32, AddressSpace::Private),
79                "f32" => (Type::F32, AddressSpace::Private),
80                "f64" => (Type::F64, AddressSpace::Private),
81                "bool" => (Type::Bool, AddressSpace::Private),
82                _ => (Type::U32, AddressSpace::Private),
83            }
84        }
85        syn::Type::Reference(ref_type) => {
86            if let syn::Type::Slice(_) = &*ref_type.elem {
87                (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
88            } else {
89                lower_type(&ref_type.elem)
90            }
91        }
92        _ => (Type::U32, AddressSpace::Private),
93    }
94}
95
96fn lower_block(block: &syn::Block) -> Result<Vec<Stmt>, CompileError> {
97    let mut stmts = Vec::new();
98    for stmt in &block.stmts {
99        match stmt {
100            syn::Stmt::Local(local) => {
101                if let Some(init) = &local.init {
102                    if let syn::Pat::Ident(ident) = &local.pat {
103                        let value = lower_expr(&init.expr)?;
104                        stmts.push(Stmt::Assign {
105                            target: ident.ident.to_string(),
106                            value,
107                        });
108                    }
109                }
110            }
111            syn::Stmt::Expr(expr, _) => {
112                if let Some(s) = lower_stmt_expr(expr)? {
113                    stmts.push(s);
114                }
115            }
116            _ => {}
117        }
118    }
119    Ok(stmts)
120}
121
122fn lower_stmt_expr(expr: &syn::Expr) -> Result<Option<Stmt>, CompileError> {
123    match expr {
124        syn::Expr::If(if_expr) => {
125            let condition = lower_expr(&if_expr.cond)?;
126            let then_body = lower_block(&if_expr.then_branch)?;
127            let else_body = if let Some((_, else_expr)) = &if_expr.else_branch {
128                if let syn::Expr::Block(block_expr) = &**else_expr {
129                    Some(lower_block(&block_expr.block)?)
130                } else {
131                    None
132                }
133            } else {
134                None
135            };
136            Ok(Some(Stmt::If {
137                condition,
138                then_body,
139                else_body,
140            }))
141        }
142        syn::Expr::Assign(assign) => {
143            let value = lower_expr(&assign.right)?;
144            if let syn::Expr::Path(path) = &*assign.left {
145                let target = path
146                    .path
147                    .segments
148                    .last()
149                    .map(|s| s.ident.to_string())
150                    .unwrap_or_default();
151                Ok(Some(Stmt::Assign { target, value }))
152            } else if let syn::Expr::Index(idx) = &*assign.left {
153                let base = lower_expr(&idx.expr)?;
154                let index = lower_expr(&idx.index)?;
155                let elem_size = Expr::Literal(Literal::Int(4));
156                let offset = Expr::BinOp {
157                    op: BinOp::Mul,
158                    lhs: Box::new(index),
159                    rhs: Box::new(elem_size),
160                };
161                let addr = Expr::BinOp {
162                    op: BinOp::Add,
163                    lhs: Box::new(base),
164                    rhs: Box::new(offset),
165                };
166                Ok(Some(Stmt::Store {
167                    addr,
168                    value,
169                    space: AddressSpace::Device,
170                }))
171            } else {
172                Ok(None)
173            }
174        }
175        syn::Expr::Return(ret) => {
176            let value = ret.expr.as_ref().map(|e| lower_expr(e)).transpose()?;
177            Ok(Some(Stmt::Return { value }))
178        }
179        _ => Ok(None),
180    }
181}
182
183fn lower_expr(expr: &syn::Expr) -> Result<Expr, CompileError> {
184    match expr {
185        syn::Expr::Lit(lit) => lower_lit(lit),
186        syn::Expr::Path(path) => {
187            let name = path
188                .path
189                .segments
190                .last()
191                .map(|s| s.ident.to_string())
192                .unwrap_or_default();
193            Ok(Expr::Var(name))
194        }
195        syn::Expr::Binary(bin) => lower_binary(bin),
196        syn::Expr::Call(call) => lower_call(call),
197        syn::Expr::Index(idx) => {
198            let base = lower_expr(&idx.expr)?;
199            let index = lower_expr(&idx.index)?;
200            Ok(Expr::Index {
201                base: Box::new(base),
202                index: Box::new(index),
203            })
204        }
205        syn::Expr::Paren(paren) => lower_expr(&paren.expr),
206        syn::Expr::Unary(unary) => {
207            let operand = lower_expr(&unary.expr)?;
208            match unary.op {
209                syn::UnOp::Neg(_) => Ok(Expr::UnaryOp {
210                    op: crate::hir::expr::UnaryOp::Neg,
211                    operand: Box::new(operand),
212                }),
213                syn::UnOp::Not(_) => Ok(Expr::UnaryOp {
214                    op: crate::hir::expr::UnaryOp::Not,
215                    operand: Box::new(operand),
216                }),
217                _ => Err(CompileError::ParseError {
218                    message: "unsupported unary op".into(),
219                }),
220            }
221        }
222        _ => Err(CompileError::ParseError {
223            message: "unsupported expression".into(),
224        }),
225    }
226}
227
228fn lower_lit(lit: &syn::ExprLit) -> Result<Expr, CompileError> {
229    match &lit.lit {
230        syn::Lit::Int(i) => {
231            let v: i64 = i.base10_parse().unwrap_or(0);
232            Ok(Expr::Literal(Literal::Int(v)))
233        }
234        syn::Lit::Float(f) => {
235            let v: f64 = f.base10_parse().unwrap_or(0.0);
236            Ok(Expr::Literal(Literal::Float(v)))
237        }
238        syn::Lit::Bool(b) => Ok(Expr::Literal(Literal::Bool(b.value))),
239        _ => Err(CompileError::ParseError {
240            message: "unsupported literal".into(),
241        }),
242    }
243}
244
245fn lower_binary(bin: &syn::ExprBinary) -> Result<Expr, CompileError> {
246    let lhs = lower_expr(&bin.left)?;
247    let rhs = lower_expr(&bin.right)?;
248    let op = match bin.op {
249        syn::BinOp::Add(_) => BinOp::Add,
250        syn::BinOp::Sub(_) => BinOp::Sub,
251        syn::BinOp::Mul(_) => BinOp::Mul,
252        syn::BinOp::Div(_) => BinOp::Div,
253        syn::BinOp::Rem(_) => BinOp::Mod,
254        syn::BinOp::Lt(_) => BinOp::Lt,
255        syn::BinOp::Le(_) => BinOp::Le,
256        syn::BinOp::Gt(_) => BinOp::Gt,
257        syn::BinOp::Ge(_) => BinOp::Ge,
258        syn::BinOp::Eq(_) => BinOp::Eq,
259        syn::BinOp::Ne(_) => BinOp::Ne,
260        syn::BinOp::BitAnd(_) => BinOp::BitAnd,
261        syn::BinOp::BitOr(_) => BinOp::BitOr,
262        syn::BinOp::BitXor(_) => BinOp::BitXor,
263        syn::BinOp::Shl(_) => BinOp::Shl,
264        syn::BinOp::Shr(_) => BinOp::Shr,
265        _ => {
266            return Err(CompileError::ParseError {
267                message: "unsupported binary op".into(),
268            })
269        }
270    };
271    Ok(Expr::BinOp {
272        op,
273        lhs: Box::new(lhs),
274        rhs: Box::new(rhs),
275    })
276}
277
278fn lower_call(call: &syn::ExprCall) -> Result<Expr, CompileError> {
279    if let syn::Expr::Path(path) = &*call.func {
280        let func_name = path
281            .path
282            .segments
283            .last()
284            .map(|s| s.ident.to_string())
285            .unwrap_or_default();
286        match func_name.as_str() {
287            "thread_id" => Ok(Expr::ThreadId(Dimension::X)),
288            "workgroup_id" => Ok(Expr::WorkgroupId(Dimension::X)),
289            "workgroup_size" => Ok(Expr::WorkgroupSize(Dimension::X)),
290            "lane_id" => Ok(Expr::LaneId),
291            "wave_width" => Ok(Expr::WaveWidth),
292            "barrier" => Ok(Expr::Literal(Literal::Int(0))),
293            _ => {
294                let args: Vec<Expr> = call.args.iter().map(lower_expr).collect::<Result<_, _>>()?;
295                Ok(Expr::Call {
296                    func: match func_name.as_str() {
297                        "sqrt" => crate::hir::expr::BuiltinFunc::Sqrt,
298                        "sin" => crate::hir::expr::BuiltinFunc::Sin,
299                        "cos" => crate::hir::expr::BuiltinFunc::Cos,
300                        "abs" => crate::hir::expr::BuiltinFunc::Abs,
301                        "min" => crate::hir::expr::BuiltinFunc::Min,
302                        "max" => crate::hir::expr::BuiltinFunc::Max,
303                        _ => {
304                            return Err(CompileError::ParseError {
305                                message: format!("unknown function: {func_name}"),
306                            })
307                        }
308                    },
309                    args,
310                })
311            }
312        }
313    } else {
314        Err(CompileError::ParseError {
315            message: "unsupported call".into(),
316        })
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// A `#[kernel]` with slice and scalar parameters plus a guarded read.
    const VECTOR_ADD_SRC: &str = r#"
#[kernel]
fn vector_add(a: &[f32], b: &[f32], out: &mut [f32], n: u32) {
    let gid = thread_id();
    if gid < n {
        let a_val = a[gid];
    }
}
"#;

    /// A minimal `#[kernel]` whose body is a single `let` binding.
    const SIMPLE_SRC: &str = r#"
#[kernel]
fn test(n: u32) {
    let x = 42;
}
"#;

    #[test]
    fn test_parse_rust_vector_add() {
        let kernel = parse_rust(VECTOR_ADD_SRC).expect("vector_add kernel should parse");

        assert_eq!(kernel.name, "vector_add");
        assert_eq!(kernel.params.len(), 4);
        // Slices lower to device pointers; the scalar count stays a plain u32.
        assert_eq!(kernel.params[0].ty, Type::Ptr(AddressSpace::Device));
        assert_eq!(kernel.params.last().map(|p| &p.ty), Some(&Type::U32));
    }

    #[test]
    fn test_parse_rust_simple() {
        let kernel = parse_rust(SIMPLE_SRC).expect("simple kernel should parse");

        assert_eq!(kernel.name, "test");
        assert_eq!(kernel.body.len(), 1);
    }
}