1use crate::diagnostics::CompileError;
11use crate::hir::expr::{BinOp, Dimension, Expr, Literal};
12use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
13use crate::hir::stmt::Stmt;
14use crate::hir::types::{AddressSpace, Type};
15
16pub fn parse_rust(source: &str) -> Result<Kernel, CompileError> {
22 let file = syn::parse_file(source).map_err(|e| CompileError::ParseError {
23 message: format!("Rust parse error: {e}"),
24 })?;
25
26 for item in &file.items {
27 if let syn::Item::Fn(func) = item {
28 let has_kernel_attr = func.attrs.iter().any(|a| a.path().is_ident("kernel"));
29 if has_kernel_attr || func.attrs.is_empty() {
30 return lower_function(func);
31 }
32 }
33 }
34
35 Err(CompileError::ParseError {
36 message: "no kernel function found".into(),
37 })
38}
39
40fn lower_function(func: &syn::ItemFn) -> Result<Kernel, CompileError> {
41 let name = func.sig.ident.to_string();
42 let mut params = Vec::new();
43
44 for arg in &func.sig.inputs {
45 if let syn::FnArg::Typed(pat_type) = arg {
46 if let syn::Pat::Ident(ident) = &*pat_type.pat {
47 let param_name = ident.ident.to_string();
48 let (ty, space) = lower_type(&pat_type.ty);
49 params.push(KernelParam {
50 name: param_name,
51 ty,
52 address_space: space,
53 });
54 }
55 }
56 }
57
58 let body = lower_block(&func.block)?;
59
60 Ok(Kernel {
61 name,
62 params,
63 body,
64 attributes: KernelAttributes::default(),
65 })
66}
67
68fn lower_type(ty: &syn::Type) -> (Type, AddressSpace) {
69 match ty {
70 syn::Type::Path(path) => {
71 let ident = path
72 .path
73 .segments
74 .last()
75 .map(|s| s.ident.to_string())
76 .unwrap_or_default();
77 match ident.as_str() {
78 "i32" => (Type::I32, AddressSpace::Private),
79 "f32" => (Type::F32, AddressSpace::Private),
80 "f64" => (Type::F64, AddressSpace::Private),
81 "bool" => (Type::Bool, AddressSpace::Private),
82 _ => (Type::U32, AddressSpace::Private),
83 }
84 }
85 syn::Type::Reference(ref_type) => {
86 if let syn::Type::Slice(_) = &*ref_type.elem {
87 (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
88 } else {
89 lower_type(&ref_type.elem)
90 }
91 }
92 _ => (Type::U32, AddressSpace::Private),
93 }
94}
95
96fn lower_block(block: &syn::Block) -> Result<Vec<Stmt>, CompileError> {
97 let mut stmts = Vec::new();
98 for stmt in &block.stmts {
99 match stmt {
100 syn::Stmt::Local(local) => {
101 if let Some(init) = &local.init {
102 if let syn::Pat::Ident(ident) = &local.pat {
103 let value = lower_expr(&init.expr)?;
104 stmts.push(Stmt::Assign {
105 target: ident.ident.to_string(),
106 value,
107 });
108 }
109 }
110 }
111 syn::Stmt::Expr(expr, _) => {
112 if let Some(s) = lower_stmt_expr(expr)? {
113 stmts.push(s);
114 }
115 }
116 _ => {}
117 }
118 }
119 Ok(stmts)
120}
121
122fn lower_stmt_expr(expr: &syn::Expr) -> Result<Option<Stmt>, CompileError> {
123 match expr {
124 syn::Expr::If(if_expr) => {
125 let condition = lower_expr(&if_expr.cond)?;
126 let then_body = lower_block(&if_expr.then_branch)?;
127 let else_body = if let Some((_, else_expr)) = &if_expr.else_branch {
128 if let syn::Expr::Block(block_expr) = &**else_expr {
129 Some(lower_block(&block_expr.block)?)
130 } else {
131 None
132 }
133 } else {
134 None
135 };
136 Ok(Some(Stmt::If {
137 condition,
138 then_body,
139 else_body,
140 }))
141 }
142 syn::Expr::Assign(assign) => {
143 let value = lower_expr(&assign.right)?;
144 if let syn::Expr::Path(path) = &*assign.left {
145 let target = path
146 .path
147 .segments
148 .last()
149 .map(|s| s.ident.to_string())
150 .unwrap_or_default();
151 Ok(Some(Stmt::Assign { target, value }))
152 } else if let syn::Expr::Index(idx) = &*assign.left {
153 let base = lower_expr(&idx.expr)?;
154 let index = lower_expr(&idx.index)?;
155 let elem_size = Expr::Literal(Literal::Int(4));
156 let offset = Expr::BinOp {
157 op: BinOp::Mul,
158 lhs: Box::new(index),
159 rhs: Box::new(elem_size),
160 };
161 let addr = Expr::BinOp {
162 op: BinOp::Add,
163 lhs: Box::new(base),
164 rhs: Box::new(offset),
165 };
166 Ok(Some(Stmt::Store {
167 addr,
168 value,
169 space: AddressSpace::Device,
170 }))
171 } else {
172 Ok(None)
173 }
174 }
175 syn::Expr::Return(ret) => {
176 let value = ret.expr.as_ref().map(|e| lower_expr(e)).transpose()?;
177 Ok(Some(Stmt::Return { value }))
178 }
179 _ => Ok(None),
180 }
181}
182
183fn lower_expr(expr: &syn::Expr) -> Result<Expr, CompileError> {
184 match expr {
185 syn::Expr::Lit(lit) => lower_lit(lit),
186 syn::Expr::Path(path) => {
187 let name = path
188 .path
189 .segments
190 .last()
191 .map(|s| s.ident.to_string())
192 .unwrap_or_default();
193 Ok(Expr::Var(name))
194 }
195 syn::Expr::Binary(bin) => lower_binary(bin),
196 syn::Expr::Call(call) => lower_call(call),
197 syn::Expr::Index(idx) => {
198 let base = lower_expr(&idx.expr)?;
199 let index = lower_expr(&idx.index)?;
200 Ok(Expr::Index {
201 base: Box::new(base),
202 index: Box::new(index),
203 })
204 }
205 syn::Expr::Paren(paren) => lower_expr(&paren.expr),
206 syn::Expr::Unary(unary) => {
207 let operand = lower_expr(&unary.expr)?;
208 match unary.op {
209 syn::UnOp::Neg(_) => Ok(Expr::UnaryOp {
210 op: crate::hir::expr::UnaryOp::Neg,
211 operand: Box::new(operand),
212 }),
213 syn::UnOp::Not(_) => Ok(Expr::UnaryOp {
214 op: crate::hir::expr::UnaryOp::Not,
215 operand: Box::new(operand),
216 }),
217 _ => Err(CompileError::ParseError {
218 message: "unsupported unary op".into(),
219 }),
220 }
221 }
222 _ => Err(CompileError::ParseError {
223 message: "unsupported expression".into(),
224 }),
225 }
226}
227
228fn lower_lit(lit: &syn::ExprLit) -> Result<Expr, CompileError> {
229 match &lit.lit {
230 syn::Lit::Int(i) => {
231 let v: i64 = i.base10_parse().unwrap_or(0);
232 Ok(Expr::Literal(Literal::Int(v)))
233 }
234 syn::Lit::Float(f) => {
235 let v: f64 = f.base10_parse().unwrap_or(0.0);
236 Ok(Expr::Literal(Literal::Float(v)))
237 }
238 syn::Lit::Bool(b) => Ok(Expr::Literal(Literal::Bool(b.value))),
239 _ => Err(CompileError::ParseError {
240 message: "unsupported literal".into(),
241 }),
242 }
243}
244
245fn lower_binary(bin: &syn::ExprBinary) -> Result<Expr, CompileError> {
246 let lhs = lower_expr(&bin.left)?;
247 let rhs = lower_expr(&bin.right)?;
248 let op = match bin.op {
249 syn::BinOp::Add(_) => BinOp::Add,
250 syn::BinOp::Sub(_) => BinOp::Sub,
251 syn::BinOp::Mul(_) => BinOp::Mul,
252 syn::BinOp::Div(_) => BinOp::Div,
253 syn::BinOp::Rem(_) => BinOp::Mod,
254 syn::BinOp::Lt(_) => BinOp::Lt,
255 syn::BinOp::Le(_) => BinOp::Le,
256 syn::BinOp::Gt(_) => BinOp::Gt,
257 syn::BinOp::Ge(_) => BinOp::Ge,
258 syn::BinOp::Eq(_) => BinOp::Eq,
259 syn::BinOp::Ne(_) => BinOp::Ne,
260 syn::BinOp::BitAnd(_) => BinOp::BitAnd,
261 syn::BinOp::BitOr(_) => BinOp::BitOr,
262 syn::BinOp::BitXor(_) => BinOp::BitXor,
263 syn::BinOp::Shl(_) => BinOp::Shl,
264 syn::BinOp::Shr(_) => BinOp::Shr,
265 _ => {
266 return Err(CompileError::ParseError {
267 message: "unsupported binary op".into(),
268 })
269 }
270 };
271 Ok(Expr::BinOp {
272 op,
273 lhs: Box::new(lhs),
274 rhs: Box::new(rhs),
275 })
276}
277
278fn lower_call(call: &syn::ExprCall) -> Result<Expr, CompileError> {
279 if let syn::Expr::Path(path) = &*call.func {
280 let func_name = path
281 .path
282 .segments
283 .last()
284 .map(|s| s.ident.to_string())
285 .unwrap_or_default();
286 match func_name.as_str() {
287 "thread_id" => Ok(Expr::ThreadId(Dimension::X)),
288 "workgroup_id" => Ok(Expr::WorkgroupId(Dimension::X)),
289 "workgroup_size" => Ok(Expr::WorkgroupSize(Dimension::X)),
290 "lane_id" => Ok(Expr::LaneId),
291 "wave_width" => Ok(Expr::WaveWidth),
292 "barrier" => Ok(Expr::Literal(Literal::Int(0))),
293 _ => {
294 let args: Vec<Expr> = call.args.iter().map(lower_expr).collect::<Result<_, _>>()?;
295 Ok(Expr::Call {
296 func: match func_name.as_str() {
297 "sqrt" => crate::hir::expr::BuiltinFunc::Sqrt,
298 "sin" => crate::hir::expr::BuiltinFunc::Sin,
299 "cos" => crate::hir::expr::BuiltinFunc::Cos,
300 "abs" => crate::hir::expr::BuiltinFunc::Abs,
301 "min" => crate::hir::expr::BuiltinFunc::Min,
302 "max" => crate::hir::expr::BuiltinFunc::Max,
303 _ => {
304 return Err(CompileError::ParseError {
305 message: format!("unknown function: {func_name}"),
306 })
307 }
308 },
309 args,
310 })
311 }
312 }
313 } else {
314 Err(CompileError::ParseError {
315 message: "unsupported call".into(),
316 })
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as a kernel, failing the test if lowering returns an error.
    fn lower(src: &str) -> Kernel {
        match parse_rust(src) {
            Ok(kernel) => kernel,
            Err(err) => panic!("kernel failed to parse: {err:?}"),
        }
    }

    #[test]
    fn test_parse_rust_vector_add() {
        let kernel = lower(
            r#"
#[kernel]
fn vector_add(a: &[f32], b: &[f32], out: &mut [f32], n: u32) {
    let gid = thread_id();
    if gid < n {
        let a_val = a[gid];
    }
}
"#,
        );

        assert_eq!(kernel.name, "vector_add");
        assert_eq!(kernel.params.len(), 4);

        // Slice parameters become device pointers; scalars stay scalars.
        let first = &kernel.params[0];
        let last = &kernel.params[3];
        assert_eq!(first.ty, Type::Ptr(AddressSpace::Device));
        assert_eq!(last.ty, Type::U32);
    }

    #[test]
    fn test_parse_rust_simple() {
        let kernel = lower(
            r#"
#[kernel]
fn test(n: u32) {
    let x = 42;
}
"#,
        );

        assert_eq!(kernel.name, "test");
        // The single `let` lowers to exactly one assignment.
        assert_eq!(kernel.body.len(), 1);
    }
}