1use crate::diagnostics::CompileError;
12use crate::hir::expr::{BinOp, BuiltinFunc, Dimension, Expr, Literal, UnaryOp};
13use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
14use crate::hir::stmt::Stmt;
15use crate::hir::types::{AddressSpace, Type};
16
17pub fn parse_python(source: &str) -> Result<Kernel, CompileError> {
23 let lines: Vec<&str> = source.lines().collect();
24 let mut parser = PythonParser::new(&lines);
25 parser.parse_kernel()
26}
27
/// Cursor over the lines of one kernel source.
struct PythonParser<'a> {
    /// The source text split into lines, borrowed from the caller.
    lines: &'a [&'a str],
    /// Index into `lines` of the next line to examine.
    pos: usize,
}
32
33impl<'a> PythonParser<'a> {
34 fn new(lines: &'a [&'a str]) -> Self {
35 Self { lines, pos: 0 }
36 }
37
38 fn parse_kernel(&mut self) -> Result<Kernel, CompileError> {
39 while self.pos < self.lines.len() {
40 let line = self.lines[self.pos].trim();
41 if line.is_empty()
42 || line.starts_with('#')
43 || line.starts_with("from ")
44 || line.starts_with("import ")
45 {
46 self.pos += 1;
47 continue;
48 }
49 if line == "@kernel" {
50 self.pos += 1;
51 continue;
52 }
53 if line.starts_with("def ") {
54 return self.parse_def();
55 }
56 self.pos += 1;
57 }
58 Err(CompileError::ParseError {
59 message: "no kernel function found".into(),
60 })
61 }
62
63 fn parse_def(&mut self) -> Result<Kernel, CompileError> {
64 let line = self.lines[self.pos].trim();
65 let after_def = line
66 .strip_prefix("def ")
67 .ok_or_else(|| CompileError::ParseError {
68 message: "expected 'def'".into(),
69 })?;
70
71 let paren_start = after_def
72 .find('(')
73 .ok_or_else(|| CompileError::ParseError {
74 message: "expected '(' in function definition".into(),
75 })?;
76 let name = after_def[..paren_start].trim().to_string();
77
78 let paren_end = after_def
79 .find(')')
80 .ok_or_else(|| CompileError::ParseError {
81 message: "expected ')' in function definition".into(),
82 })?;
83 let params_str = &after_def[paren_start + 1..paren_end];
84 let params = Self::parse_params(params_str);
85
86 self.pos += 1;
87
88 let indent = self.get_body_indent()?;
89 let body = self.parse_body(indent)?;
90
91 Ok(Kernel {
92 name,
93 params,
94 body,
95 attributes: KernelAttributes::default(),
96 })
97 }
98
99 fn parse_params(params_str: &str) -> Vec<KernelParam> {
100 let mut params = Vec::new();
101 for param_token in params_str.split(',') {
102 let param_token = param_token.trim();
103 if param_token.is_empty() {
104 continue;
105 }
106 let parts: Vec<&str> = param_token.splitn(2, ':').collect();
107 let param_name = parts[0].trim().to_string();
108 let (ty, addr_space) = if parts.len() > 1 {
109 Self::parse_type_annotation(parts[1].trim())
110 } else {
111 (Type::U32, AddressSpace::Private)
112 };
113 params.push(KernelParam {
114 name: param_name,
115 ty,
116 address_space: addr_space,
117 });
118 }
119 params
120 }
121
122 fn parse_type_annotation(ann: &str) -> (Type, AddressSpace) {
123 match ann {
124 "u32" | "int" => (Type::U32, AddressSpace::Private),
125 "i32" => (Type::I32, AddressSpace::Private),
126 "f32" | "float" => (Type::F32, AddressSpace::Private),
127 "f16" => (Type::F16, AddressSpace::Private),
128 "f64" => (Type::F64, AddressSpace::Private),
129 "bool" => (Type::Bool, AddressSpace::Private),
130 s if s.contains("[:]") || s.contains("[]") => {
131 (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
132 }
133 _ => (Type::U32, AddressSpace::Private),
134 }
135 }
136
137 fn get_body_indent(&self) -> Result<usize, CompileError> {
138 if self.pos >= self.lines.len() {
139 return Err(CompileError::ParseError {
140 message: "expected function body".into(),
141 });
142 }
143 let line = self.lines[self.pos];
144 Ok(line.len() - line.trim_start().len())
145 }
146
147 fn parse_body(&mut self, indent: usize) -> Result<Vec<Stmt>, CompileError> {
148 let mut stmts = Vec::new();
149 while self.pos < self.lines.len() {
150 let line = self.lines[self.pos];
151 if line.trim().is_empty() || line.trim().starts_with('#') {
152 self.pos += 1;
153 continue;
154 }
155 let current_indent = line.len() - line.trim_start().len();
156 if current_indent < indent {
157 break;
158 }
159 let trimmed = line.trim();
160 if trimmed.starts_with("if ") {
161 stmts.push(self.parse_if()?);
162 } else if trimmed.starts_with("for ") {
163 stmts.push(self.parse_for()?);
164 } else if trimmed.starts_with("while ") {
165 stmts.push(self.parse_while()?);
166 } else if trimmed == "return" || trimmed.starts_with("return ") {
167 stmts.push(self.parse_return()?);
168 } else if trimmed == "barrier()" {
169 stmts.push(Stmt::Barrier);
170 self.pos += 1;
171 } else if trimmed.contains('=') && !trimmed.contains("==") {
172 stmts.push(self.parse_assignment()?);
173 } else {
174 self.pos += 1;
175 }
176 }
177 Ok(stmts)
178 }
179
180 fn parse_if(&mut self) -> Result<Stmt, CompileError> {
181 let line = self.lines[self.pos].trim();
182 let cond_str = line
183 .strip_prefix("if ")
184 .and_then(|s| s.strip_suffix(':'))
185 .ok_or_else(|| CompileError::ParseError {
186 message: format!("invalid if statement: {line}"),
187 })?;
188 let condition = self.parse_expr(cond_str.trim())?;
189 self.pos += 1;
190
191 let then_indent = self.get_body_indent()?;
192 let then_body = self.parse_body(then_indent)?;
193
194 let else_body = if self.pos < self.lines.len() {
195 let next = self.lines[self.pos].trim();
196 if next.starts_with("else:") || next.starts_with("elif ") {
197 self.pos += 1;
198 let else_indent = self.get_body_indent()?;
199 Some(self.parse_body(else_indent)?)
200 } else {
201 None
202 }
203 } else {
204 None
205 };
206
207 Ok(Stmt::If {
208 condition,
209 then_body,
210 else_body,
211 })
212 }
213
214 fn parse_for(&mut self) -> Result<Stmt, CompileError> {
215 let line = self.lines[self.pos].trim();
216 let inner = line
217 .strip_prefix("for ")
218 .and_then(|s| s.strip_suffix(':'))
219 .ok_or_else(|| CompileError::ParseError {
220 message: format!("invalid for statement: {line}"),
221 })?;
222
223 let parts: Vec<&str> = inner.splitn(2, " in ").collect();
224 if parts.len() != 2 {
225 return Err(CompileError::ParseError {
226 message: format!("invalid for statement: {line}"),
227 });
228 }
229 let var = parts[0].trim().to_string();
230 let range_str = parts[1].trim();
231
232 let (start, end, step) = self.parse_range(range_str)?;
233
234 self.pos += 1;
235 let body_indent = self.get_body_indent()?;
236 let body = self.parse_body(body_indent)?;
237
238 Ok(Stmt::For {
239 var,
240 start,
241 end,
242 step,
243 body,
244 })
245 }
246
247 fn parse_range(&self, s: &str) -> Result<(Expr, Expr, Expr), CompileError> {
248 let inner = s
249 .strip_prefix("range(")
250 .and_then(|s| s.strip_suffix(')'))
251 .ok_or_else(|| CompileError::ParseError {
252 message: format!("expected range(...), got {s}"),
253 })?;
254
255 let args: Vec<&str> = inner.split(',').collect();
256 match args.len() {
257 1 => Ok((
258 Expr::Literal(Literal::Int(0)),
259 self.parse_expr(args[0].trim())?,
260 Expr::Literal(Literal::Int(1)),
261 )),
262 2 => Ok((
263 self.parse_expr(args[0].trim())?,
264 self.parse_expr(args[1].trim())?,
265 Expr::Literal(Literal::Int(1)),
266 )),
267 3 => Ok((
268 self.parse_expr(args[0].trim())?,
269 self.parse_expr(args[1].trim())?,
270 self.parse_expr(args[2].trim())?,
271 )),
272 _ => Err(CompileError::ParseError {
273 message: "range() takes 1-3 arguments".into(),
274 }),
275 }
276 }
277
278 fn parse_while(&mut self) -> Result<Stmt, CompileError> {
279 let line = self.lines[self.pos].trim();
280 let cond_str = line
281 .strip_prefix("while ")
282 .and_then(|s| s.strip_suffix(':'))
283 .ok_or_else(|| CompileError::ParseError {
284 message: format!("invalid while statement: {line}"),
285 })?;
286 let condition = self.parse_expr(cond_str.trim())?;
287 self.pos += 1;
288
289 let body_indent = self.get_body_indent()?;
290 let body = self.parse_body(body_indent)?;
291
292 Ok(Stmt::While { condition, body })
293 }
294
295 fn parse_return(&mut self) -> Result<Stmt, CompileError> {
296 let line = self.lines[self.pos].trim();
297 self.pos += 1;
298 if line == "return" {
299 return Ok(Stmt::Return { value: None });
300 }
301 let val_str = line.strip_prefix("return ").unwrap_or("");
302 if val_str.is_empty() {
303 Ok(Stmt::Return { value: None })
304 } else {
305 Ok(Stmt::Return {
306 value: Some(self.parse_expr(val_str)?),
307 })
308 }
309 }
310
311 fn parse_assignment(&mut self) -> Result<Stmt, CompileError> {
312 let line = self.lines[self.pos].trim().to_string();
313 self.pos += 1;
314
315 if let Some(bracket_pos) = line.find('[') {
316 if let Some(eq_pos) = line.find('=') {
317 if bracket_pos < eq_pos
318 && !line[..eq_pos].ends_with('!')
319 && !line[..eq_pos].ends_with('<')
320 && !line[..eq_pos].ends_with('>')
321 {
322 let base_name = line[..bracket_pos].trim();
323 let bracket_end =
324 line[..eq_pos]
325 .rfind(']')
326 .ok_or_else(|| CompileError::ParseError {
327 message: format!("missing ']' in: {line}"),
328 })?;
329 let index_str = &line[bracket_pos + 1..bracket_end];
330 let value_str = line[eq_pos + 1..].trim();
331
332 let base = self.parse_expr(base_name)?;
333 let index = self.parse_expr(index_str)?;
334 let value = self.parse_expr(value_str)?;
335
336 let elem_size = Expr::Literal(Literal::Int(4));
337 let offset = Expr::BinOp {
338 op: BinOp::Mul,
339 lhs: Box::new(index),
340 rhs: Box::new(elem_size),
341 };
342 let addr = Expr::BinOp {
343 op: BinOp::Add,
344 lhs: Box::new(base),
345 rhs: Box::new(offset),
346 };
347
348 return Ok(Stmt::Store {
349 addr,
350 value,
351 space: AddressSpace::Device,
352 });
353 }
354 }
355 }
356
357 let eq_pos = line.find('=').ok_or_else(|| CompileError::ParseError {
358 message: format!("expected '=' in assignment: {line}"),
359 })?;
360
361 if eq_pos > 0
362 && (line.as_bytes()[eq_pos - 1] == b'!'
363 || line.as_bytes()[eq_pos - 1] == b'<'
364 || line.as_bytes()[eq_pos - 1] == b'>')
365 {
366 return Err(CompileError::ParseError {
367 message: format!("unexpected operator in: {line}"),
368 });
369 }
370 if eq_pos + 1 < line.len() && line.as_bytes()[eq_pos + 1] == b'=' {
371 return Err(CompileError::ParseError {
372 message: format!("comparison in assignment position: {line}"),
373 });
374 }
375
376 let raw_target = line[..eq_pos].trim();
377 let target = if let Some(colon_pos) = raw_target.find(':') {
378 raw_target[..colon_pos].trim().to_string()
379 } else {
380 raw_target.to_string()
381 };
382 let value_str = line[eq_pos + 1..].trim();
383 let value = self.parse_expr(value_str)?;
384
385 Ok(Stmt::Assign { target, value })
386 }
387
388 fn parse_expr(&self, s: &str) -> Result<Expr, CompileError> {
389 let s = s.trim();
390
391 for &(op_str, op) in &[(" + ", BinOp::Add), (" - ", BinOp::Sub)] {
392 if let Some(pos) = find_top_level_op(s, op_str) {
393 let lhs = self.parse_expr(&s[..pos])?;
394 let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
395 return Ok(Expr::BinOp {
396 op,
397 lhs: Box::new(lhs),
398 rhs: Box::new(rhs),
399 });
400 }
401 }
402
403 for &(op_str, op) in &[
404 (" * ", BinOp::Mul),
405 (" // ", BinOp::FloorDiv),
406 (" / ", BinOp::Div),
407 (" % ", BinOp::Mod),
408 ] {
409 if let Some(pos) = find_top_level_op(s, op_str) {
410 let lhs = self.parse_expr(&s[..pos])?;
411 let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
412 return Ok(Expr::BinOp {
413 op,
414 lhs: Box::new(lhs),
415 rhs: Box::new(rhs),
416 });
417 }
418 }
419
420 for &(op_str, op) in &[
421 (" < ", BinOp::Lt),
422 (" <= ", BinOp::Le),
423 (" > ", BinOp::Gt),
424 (" >= ", BinOp::Ge),
425 (" == ", BinOp::Eq),
426 (" != ", BinOp::Ne),
427 ] {
428 if let Some(pos) = find_top_level_op(s, op_str) {
429 let lhs = self.parse_expr(&s[..pos])?;
430 let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
431 return Ok(Expr::BinOp {
432 op,
433 lhs: Box::new(lhs),
434 rhs: Box::new(rhs),
435 });
436 }
437 }
438
439 for &(op_str, op) in &[
440 (" & ", BinOp::BitAnd),
441 (" | ", BinOp::BitOr),
442 (" ^ ", BinOp::BitXor),
443 (" << ", BinOp::Shl),
444 (" >> ", BinOp::Shr),
445 ] {
446 if let Some(pos) = find_top_level_op(s, op_str) {
447 let lhs = self.parse_expr(&s[..pos])?;
448 let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
449 return Ok(Expr::BinOp {
450 op,
451 lhs: Box::new(lhs),
452 rhs: Box::new(rhs),
453 });
454 }
455 }
456
457 if s.starts_with('(') && s.ends_with(')') {
458 return self.parse_expr(&s[1..s.len() - 1]);
459 }
460
461 if s.starts_with('-') && s.len() > 1 {
462 let inner = self.parse_expr(&s[1..])?;
463 return Ok(Expr::UnaryOp {
464 op: UnaryOp::Neg,
465 operand: Box::new(inner),
466 });
467 }
468
469 self.parse_atom(s)
470 }
471
472 fn parse_atom(&self, s: &str) -> Result<Expr, CompileError> {
473 let s = s.trim();
474
475 match s {
476 "thread_id()" | "thread_id_x()" => return Ok(Expr::ThreadId(Dimension::X)),
477 "thread_id_y()" => return Ok(Expr::ThreadId(Dimension::Y)),
478 "thread_id_z()" => return Ok(Expr::ThreadId(Dimension::Z)),
479 "workgroup_id()" | "workgroup_id_x()" => return Ok(Expr::WorkgroupId(Dimension::X)),
480 "workgroup_size()" | "workgroup_size_x()" => {
481 return Ok(Expr::WorkgroupSize(Dimension::X))
482 }
483 "lane_id()" => return Ok(Expr::LaneId),
484 "wave_width()" => return Ok(Expr::WaveWidth),
485 "True" | "true" => return Ok(Expr::Literal(Literal::Bool(true))),
486 "False" | "false" => return Ok(Expr::Literal(Literal::Bool(false))),
487 _ => {}
488 }
489
490 if let Some(paren_pos) = s.find('(') {
491 if s.ends_with(')') {
492 let func_name = &s[..paren_pos];
493 let args_str = &s[paren_pos + 1..s.len() - 1];
494 return self.parse_call(func_name, args_str);
495 }
496 }
497
498 if let Some(bracket_pos) = s.find('[') {
499 if s.ends_with(']') {
500 let base = &s[..bracket_pos];
501 let index = &s[bracket_pos + 1..s.len() - 1];
502 return Ok(Expr::Index {
503 base: Box::new(self.parse_expr(base)?),
504 index: Box::new(self.parse_expr(index)?),
505 });
506 }
507 }
508
509 if let Ok(v) = s.parse::<i64>() {
510 return Ok(Expr::Literal(Literal::Int(v)));
511 }
512 if let Ok(v) = s.parse::<f64>() {
513 return Ok(Expr::Literal(Literal::Float(v)));
514 }
515
516 if s.starts_with("0x") || s.starts_with("0X") {
517 if let Ok(v) = i64::from_str_radix(&s[2..], 16) {
518 return Ok(Expr::Literal(Literal::Int(v)));
519 }
520 }
521
522 if is_valid_identifier(s) {
523 return Ok(Expr::Var(s.to_string()));
524 }
525
526 Err(CompileError::ParseError {
527 message: format!("cannot parse expression: '{s}'"),
528 })
529 }
530
531 fn parse_call(&self, func_name: &str, args_str: &str) -> Result<Expr, CompileError> {
532 let args: Vec<Expr> = if args_str.trim().is_empty() {
533 Vec::new()
534 } else {
535 args_str
536 .split(',')
537 .map(|a| self.parse_expr(a.trim()))
538 .collect::<Result<_, _>>()?
539 };
540
541 match func_name {
542 "sqrt" => Ok(Expr::Call {
543 func: BuiltinFunc::Sqrt,
544 args,
545 }),
546 "sin" => Ok(Expr::Call {
547 func: BuiltinFunc::Sin,
548 args,
549 }),
550 "cos" => Ok(Expr::Call {
551 func: BuiltinFunc::Cos,
552 args,
553 }),
554 "exp2" => Ok(Expr::Call {
555 func: BuiltinFunc::Exp2,
556 args,
557 }),
558 "log2" => Ok(Expr::Call {
559 func: BuiltinFunc::Log2,
560 args,
561 }),
562 "abs" => Ok(Expr::Call {
563 func: BuiltinFunc::Abs,
564 args,
565 }),
566 "min" => Ok(Expr::Call {
567 func: BuiltinFunc::Min,
568 args,
569 }),
570 "max" => Ok(Expr::Call {
571 func: BuiltinFunc::Max,
572 args,
573 }),
574 "atomic_add" => Ok(Expr::Call {
575 func: BuiltinFunc::AtomicAdd,
576 args,
577 }),
578 "thread_id" => Ok(Expr::ThreadId(Dimension::X)),
579 "workgroup_id" => Ok(Expr::WorkgroupId(Dimension::X)),
580 "workgroup_size" => Ok(Expr::WorkgroupSize(Dimension::X)),
581 "lane_id" => Ok(Expr::LaneId),
582 "wave_width" => Ok(Expr::WaveWidth),
583 "int" | "u32" => {
584 if args.len() == 1 {
585 Ok(Expr::Cast {
586 expr: Box::new(args.into_iter().next().unwrap()),
587 to: Type::U32,
588 })
589 } else {
590 Err(CompileError::ParseError {
591 message: "int() takes 1 argument".to_string(),
592 })
593 }
594 }
595 "float" | "f32" => {
596 if args.len() == 1 {
597 Ok(Expr::Cast {
598 expr: Box::new(args.into_iter().next().unwrap()),
599 to: Type::F32,
600 })
601 } else {
602 Err(CompileError::ParseError {
603 message: "float() takes 1 argument".to_string(),
604 })
605 }
606 }
607 _ => Err(CompileError::ParseError {
608 message: format!("unknown function: {func_name}"),
609 }),
610 }
611 }
612}
613
/// Finds the rightmost occurrence of `op` in `s` that is not nested inside
/// parentheses or brackets.
///
/// The string is scanned from right to left while tracking nesting depth:
/// `)` and `]` open a nested region, `(` and `[` close it. A match counts
/// only while the depth is zero.
///
/// Returns the byte offset at which the match starts, or `None` when there
/// is no top-level match, when `op` is empty, or when `s` is shorter than
/// `op`. Quoted string literals are not recognised; their contents are
/// scanned like any other text.
fn find_top_level_op(s: &str, op: &str) -> Option<usize> {
    let haystack = s.as_bytes();
    let needle = op.as_bytes();

    // An empty operator has no meaningful position, and the window
    // arithmetic below would underflow for it.
    if needle.is_empty() || haystack.len() < needle.len() {
        return None;
    }

    let mut depth = 0i32;
    // `end` is the exclusive end of the candidate window. Each step also
    // accounts for the byte that has just entered the window on its right.
    for end in (needle.len()..=haystack.len()).rev() {
        match haystack[end - 1] {
            b')' | b']' => depth += 1,
            b'(' | b'[' => depth -= 1,
            _ => {}
        }
        let start = end - needle.len();
        if depth == 0 && &haystack[start..end] == needle {
            return Some(start);
        }
    }
    None
}
642
/// Reports whether `s` is a non-empty name made of alphanumeric characters
/// and underscores that does not start with a digit.
///
/// "Alphabetic" and "alphanumeric" are the Unicode-aware `char` predicates,
/// so non-ASCII letters are accepted.
fn is_valid_identifier(s: &str) -> bool {
    let mut rest = s.chars();
    match rest.next() {
        Some(first) if first.is_alphabetic() || first == '_' => {
            rest.all(|c| c.is_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// A full kernel: parameters with pointer and scalar annotations, and a
    /// body consisting of an assignment followed by an `if`.
    #[test]
    fn test_parse_vector_add() {
        let source = r#"
from wave import kernel, f32, thread_id

@kernel
def vector_add(a: f32[:], b: f32[:], out: f32[:], n: u32):
    gid = thread_id()
    if gid < n:
        out[gid] = a[gid] + b[gid]
"#;
        let kernel = parse_python(source).unwrap();
        assert_eq!(kernel.name, "vector_add");
        assert_eq!(kernel.params.len(), 4);
        assert_eq!(kernel.params[0].name, "a");
        assert_eq!(kernel.params[0].ty, Type::Ptr(AddressSpace::Device));
        assert_eq!(kernel.params[3].name, "n");
        assert_eq!(kernel.params[3].ty, Type::U32);
        assert_eq!(kernel.body.len(), 2);
    }

    /// Two plain assignments produce two statements.
    #[test]
    fn test_parse_simple_assign() {
        let source = r#"
@kernel
def test(n: u32):
    x = 42
    y = x + 1
"#;
        let kernel = parse_python(source).unwrap();
        assert_eq!(kernel.name, "test");
        assert_eq!(kernel.body.len(), 2);
    }

    /// Multiplication binds tighter than addition, so `+` is the root.
    #[test]
    fn test_parse_expressions() {
        let parser = PythonParser::new(&[]);
        let expr = parser.parse_expr("a + b * c").unwrap();
        assert!(
            matches!(expr, Expr::BinOp { op: BinOp::Add, .. }),
            "expected Add at top level"
        );
    }

    /// A subscript parses into an `Index` of two variables.
    #[test]
    fn test_parse_array_index() {
        let parser = PythonParser::new(&[]);
        let expr = parser.parse_expr("a[i]").unwrap();
        match &expr {
            Expr::Index { base, index } => {
                assert_eq!(**base, Expr::Var("a".into()));
                assert_eq!(**index, Expr::Var("i".into()));
            }
            _ => panic!("expected Index"),
        }
    }

    /// The bare `thread_id()` intrinsic refers to the X dimension.
    #[test]
    fn test_parse_thread_id() {
        let parser = PythonParser::new(&[]);
        let expr = parser.parse_expr("thread_id()").unwrap();
        assert_eq!(expr, Expr::ThreadId(Dimension::X));
    }

    /// Integer and float literals.
    #[test]
    fn test_parse_literal() {
        let parser = PythonParser::new(&[]);
        assert_eq!(
            parser.parse_expr("42").unwrap(),
            Expr::Literal(Literal::Int(42))
        );
        // 2.5 is exactly representable, so the equality check is exact, and
        // it does not resemble a named constant (3.14 would trigger
        // `clippy::approx_constant`).
        assert_eq!(
            parser.parse_expr("2.5").unwrap(),
            Expr::Literal(Literal::Float(2.5))
        );
    }
}