Skip to main content

wave_compiler/frontend/
python.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Python kernel parser producing HIR.
5//!
6//! Parses a subset of Python suitable for GPU kernels: function definitions
7//! with type annotations, arithmetic, comparisons, if/else, for `range()`,
8//! array indexing, and GPU intrinsics (`thread_id`, barrier, etc.).
9//! Uses line-by-line parsing for the restricted kernel subset.
10
11use crate::diagnostics::CompileError;
12use crate::hir::expr::{BinOp, BuiltinFunc, Dimension, Expr, Literal, UnaryOp};
13use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
14use crate::hir::stmt::Stmt;
15use crate::hir::types::{AddressSpace, Type};
16
17/// Parse a Python kernel source string into an HIR Kernel.
18///
19/// # Errors
20///
21/// Returns `CompileError::ParseError` if the source cannot be parsed.
22pub fn parse_python(source: &str) -> Result<Kernel, CompileError> {
23    let lines: Vec<&str> = source.lines().collect();
24    let mut parser = PythonParser::new(&lines);
25    parser.parse_kernel()
26}
27
28struct PythonParser<'a> {
29    lines: &'a [&'a str],
30    pos: usize,
31}
32
33impl<'a> PythonParser<'a> {
34    fn new(lines: &'a [&'a str]) -> Self {
35        Self { lines, pos: 0 }
36    }
37
38    fn parse_kernel(&mut self) -> Result<Kernel, CompileError> {
39        while self.pos < self.lines.len() {
40            let line = self.lines[self.pos].trim();
41            if line.is_empty()
42                || line.starts_with('#')
43                || line.starts_with("from ")
44                || line.starts_with("import ")
45            {
46                self.pos += 1;
47                continue;
48            }
49            if line == "@kernel" {
50                self.pos += 1;
51                continue;
52            }
53            if line.starts_with("def ") {
54                return self.parse_def();
55            }
56            self.pos += 1;
57        }
58        Err(CompileError::ParseError {
59            message: "no kernel function found".into(),
60        })
61    }
62
63    fn parse_def(&mut self) -> Result<Kernel, CompileError> {
64        let line = self.lines[self.pos].trim();
65        let after_def = line
66            .strip_prefix("def ")
67            .ok_or_else(|| CompileError::ParseError {
68                message: "expected 'def'".into(),
69            })?;
70
71        let paren_start = after_def
72            .find('(')
73            .ok_or_else(|| CompileError::ParseError {
74                message: "expected '(' in function definition".into(),
75            })?;
76        let name = after_def[..paren_start].trim().to_string();
77
78        let paren_end = after_def
79            .find(')')
80            .ok_or_else(|| CompileError::ParseError {
81                message: "expected ')' in function definition".into(),
82            })?;
83        let params_str = &after_def[paren_start + 1..paren_end];
84        let params = Self::parse_params(params_str);
85
86        self.pos += 1;
87
88        let indent = self.get_body_indent()?;
89        let body = self.parse_body(indent)?;
90
91        Ok(Kernel {
92            name,
93            params,
94            body,
95            attributes: KernelAttributes::default(),
96        })
97    }
98
99    fn parse_params(params_str: &str) -> Vec<KernelParam> {
100        let mut params = Vec::new();
101        for param_token in params_str.split(',') {
102            let param_token = param_token.trim();
103            if param_token.is_empty() {
104                continue;
105            }
106            let parts: Vec<&str> = param_token.splitn(2, ':').collect();
107            let param_name = parts[0].trim().to_string();
108            let (ty, addr_space) = if parts.len() > 1 {
109                Self::parse_type_annotation(parts[1].trim())
110            } else {
111                (Type::U32, AddressSpace::Private)
112            };
113            params.push(KernelParam {
114                name: param_name,
115                ty,
116                address_space: addr_space,
117            });
118        }
119        params
120    }
121
122    fn parse_type_annotation(ann: &str) -> (Type, AddressSpace) {
123        match ann {
124            "u32" | "int" => (Type::U32, AddressSpace::Private),
125            "i32" => (Type::I32, AddressSpace::Private),
126            "f32" | "float" => (Type::F32, AddressSpace::Private),
127            "f16" => (Type::F16, AddressSpace::Private),
128            "f64" => (Type::F64, AddressSpace::Private),
129            "bool" => (Type::Bool, AddressSpace::Private),
130            s if s.contains("[:]") || s.contains("[]") => {
131                (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
132            }
133            _ => (Type::U32, AddressSpace::Private),
134        }
135    }
136
137    fn get_body_indent(&self) -> Result<usize, CompileError> {
138        if self.pos >= self.lines.len() {
139            return Err(CompileError::ParseError {
140                message: "expected function body".into(),
141            });
142        }
143        let line = self.lines[self.pos];
144        Ok(line.len() - line.trim_start().len())
145    }
146
147    fn parse_body(&mut self, indent: usize) -> Result<Vec<Stmt>, CompileError> {
148        let mut stmts = Vec::new();
149        while self.pos < self.lines.len() {
150            let line = self.lines[self.pos];
151            if line.trim().is_empty() || line.trim().starts_with('#') {
152                self.pos += 1;
153                continue;
154            }
155            let current_indent = line.len() - line.trim_start().len();
156            if current_indent < indent {
157                break;
158            }
159            let trimmed = line.trim();
160            if trimmed.starts_with("if ") {
161                stmts.push(self.parse_if()?);
162            } else if trimmed.starts_with("for ") {
163                stmts.push(self.parse_for()?);
164            } else if trimmed.starts_with("while ") {
165                stmts.push(self.parse_while()?);
166            } else if trimmed == "return" || trimmed.starts_with("return ") {
167                stmts.push(self.parse_return()?);
168            } else if trimmed == "barrier()" {
169                stmts.push(Stmt::Barrier);
170                self.pos += 1;
171            } else if trimmed.contains('=') && !trimmed.contains("==") {
172                stmts.push(self.parse_assignment()?);
173            } else {
174                self.pos += 1;
175            }
176        }
177        Ok(stmts)
178    }
179
180    fn parse_if(&mut self) -> Result<Stmt, CompileError> {
181        let line = self.lines[self.pos].trim();
182        let cond_str = line
183            .strip_prefix("if ")
184            .and_then(|s| s.strip_suffix(':'))
185            .ok_or_else(|| CompileError::ParseError {
186                message: format!("invalid if statement: {line}"),
187            })?;
188        let condition = self.parse_expr(cond_str.trim())?;
189        self.pos += 1;
190
191        let then_indent = self.get_body_indent()?;
192        let then_body = self.parse_body(then_indent)?;
193
194        let else_body = if self.pos < self.lines.len() {
195            let next = self.lines[self.pos].trim();
196            if next.starts_with("else:") || next.starts_with("elif ") {
197                self.pos += 1;
198                let else_indent = self.get_body_indent()?;
199                Some(self.parse_body(else_indent)?)
200            } else {
201                None
202            }
203        } else {
204            None
205        };
206
207        Ok(Stmt::If {
208            condition,
209            then_body,
210            else_body,
211        })
212    }
213
214    fn parse_for(&mut self) -> Result<Stmt, CompileError> {
215        let line = self.lines[self.pos].trim();
216        let inner = line
217            .strip_prefix("for ")
218            .and_then(|s| s.strip_suffix(':'))
219            .ok_or_else(|| CompileError::ParseError {
220                message: format!("invalid for statement: {line}"),
221            })?;
222
223        let parts: Vec<&str> = inner.splitn(2, " in ").collect();
224        if parts.len() != 2 {
225            return Err(CompileError::ParseError {
226                message: format!("invalid for statement: {line}"),
227            });
228        }
229        let var = parts[0].trim().to_string();
230        let range_str = parts[1].trim();
231
232        let (start, end, step) = self.parse_range(range_str)?;
233
234        self.pos += 1;
235        let body_indent = self.get_body_indent()?;
236        let body = self.parse_body(body_indent)?;
237
238        Ok(Stmt::For {
239            var,
240            start,
241            end,
242            step,
243            body,
244        })
245    }
246
247    fn parse_range(&self, s: &str) -> Result<(Expr, Expr, Expr), CompileError> {
248        let inner = s
249            .strip_prefix("range(")
250            .and_then(|s| s.strip_suffix(')'))
251            .ok_or_else(|| CompileError::ParseError {
252                message: format!("expected range(...), got {s}"),
253            })?;
254
255        let args: Vec<&str> = inner.split(',').collect();
256        match args.len() {
257            1 => Ok((
258                Expr::Literal(Literal::Int(0)),
259                self.parse_expr(args[0].trim())?,
260                Expr::Literal(Literal::Int(1)),
261            )),
262            2 => Ok((
263                self.parse_expr(args[0].trim())?,
264                self.parse_expr(args[1].trim())?,
265                Expr::Literal(Literal::Int(1)),
266            )),
267            3 => Ok((
268                self.parse_expr(args[0].trim())?,
269                self.parse_expr(args[1].trim())?,
270                self.parse_expr(args[2].trim())?,
271            )),
272            _ => Err(CompileError::ParseError {
273                message: "range() takes 1-3 arguments".into(),
274            }),
275        }
276    }
277
278    fn parse_while(&mut self) -> Result<Stmt, CompileError> {
279        let line = self.lines[self.pos].trim();
280        let cond_str = line
281            .strip_prefix("while ")
282            .and_then(|s| s.strip_suffix(':'))
283            .ok_or_else(|| CompileError::ParseError {
284                message: format!("invalid while statement: {line}"),
285            })?;
286        let condition = self.parse_expr(cond_str.trim())?;
287        self.pos += 1;
288
289        let body_indent = self.get_body_indent()?;
290        let body = self.parse_body(body_indent)?;
291
292        Ok(Stmt::While { condition, body })
293    }
294
295    fn parse_return(&mut self) -> Result<Stmt, CompileError> {
296        let line = self.lines[self.pos].trim();
297        self.pos += 1;
298        if line == "return" {
299            return Ok(Stmt::Return { value: None });
300        }
301        let val_str = line.strip_prefix("return ").unwrap_or("");
302        if val_str.is_empty() {
303            Ok(Stmt::Return { value: None })
304        } else {
305            Ok(Stmt::Return {
306                value: Some(self.parse_expr(val_str)?),
307            })
308        }
309    }
310
311    fn parse_assignment(&mut self) -> Result<Stmt, CompileError> {
312        let line = self.lines[self.pos].trim().to_string();
313        self.pos += 1;
314
315        if let Some(bracket_pos) = line.find('[') {
316            if let Some(eq_pos) = line.find('=') {
317                if bracket_pos < eq_pos
318                    && !line[..eq_pos].ends_with('!')
319                    && !line[..eq_pos].ends_with('<')
320                    && !line[..eq_pos].ends_with('>')
321                {
322                    let base_name = line[..bracket_pos].trim();
323                    let bracket_end =
324                        line[..eq_pos]
325                            .rfind(']')
326                            .ok_or_else(|| CompileError::ParseError {
327                                message: format!("missing ']' in: {line}"),
328                            })?;
329                    let index_str = &line[bracket_pos + 1..bracket_end];
330                    let value_str = line[eq_pos + 1..].trim();
331
332                    let base = self.parse_expr(base_name)?;
333                    let index = self.parse_expr(index_str)?;
334                    let value = self.parse_expr(value_str)?;
335
336                    let elem_size = Expr::Literal(Literal::Int(4));
337                    let offset = Expr::BinOp {
338                        op: BinOp::Mul,
339                        lhs: Box::new(index),
340                        rhs: Box::new(elem_size),
341                    };
342                    let addr = Expr::BinOp {
343                        op: BinOp::Add,
344                        lhs: Box::new(base),
345                        rhs: Box::new(offset),
346                    };
347
348                    return Ok(Stmt::Store {
349                        addr,
350                        value,
351                        space: AddressSpace::Device,
352                    });
353                }
354            }
355        }
356
357        let eq_pos = line.find('=').ok_or_else(|| CompileError::ParseError {
358            message: format!("expected '=' in assignment: {line}"),
359        })?;
360
361        if eq_pos > 0
362            && (line.as_bytes()[eq_pos - 1] == b'!'
363                || line.as_bytes()[eq_pos - 1] == b'<'
364                || line.as_bytes()[eq_pos - 1] == b'>')
365        {
366            return Err(CompileError::ParseError {
367                message: format!("unexpected operator in: {line}"),
368            });
369        }
370        if eq_pos + 1 < line.len() && line.as_bytes()[eq_pos + 1] == b'=' {
371            return Err(CompileError::ParseError {
372                message: format!("comparison in assignment position: {line}"),
373            });
374        }
375
376        let raw_target = line[..eq_pos].trim();
377        let target = if let Some(colon_pos) = raw_target.find(':') {
378            raw_target[..colon_pos].trim().to_string()
379        } else {
380            raw_target.to_string()
381        };
382        let value_str = line[eq_pos + 1..].trim();
383        let value = self.parse_expr(value_str)?;
384
385        Ok(Stmt::Assign { target, value })
386    }
387
388    fn parse_expr(&self, s: &str) -> Result<Expr, CompileError> {
389        let s = s.trim();
390
391        for &(op_str, op) in &[(" + ", BinOp::Add), (" - ", BinOp::Sub)] {
392            if let Some(pos) = find_top_level_op(s, op_str) {
393                let lhs = self.parse_expr(&s[..pos])?;
394                let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
395                return Ok(Expr::BinOp {
396                    op,
397                    lhs: Box::new(lhs),
398                    rhs: Box::new(rhs),
399                });
400            }
401        }
402
403        for &(op_str, op) in &[
404            (" * ", BinOp::Mul),
405            (" // ", BinOp::FloorDiv),
406            (" / ", BinOp::Div),
407            (" % ", BinOp::Mod),
408        ] {
409            if let Some(pos) = find_top_level_op(s, op_str) {
410                let lhs = self.parse_expr(&s[..pos])?;
411                let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
412                return Ok(Expr::BinOp {
413                    op,
414                    lhs: Box::new(lhs),
415                    rhs: Box::new(rhs),
416                });
417            }
418        }
419
420        for &(op_str, op) in &[
421            (" < ", BinOp::Lt),
422            (" <= ", BinOp::Le),
423            (" > ", BinOp::Gt),
424            (" >= ", BinOp::Ge),
425            (" == ", BinOp::Eq),
426            (" != ", BinOp::Ne),
427        ] {
428            if let Some(pos) = find_top_level_op(s, op_str) {
429                let lhs = self.parse_expr(&s[..pos])?;
430                let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
431                return Ok(Expr::BinOp {
432                    op,
433                    lhs: Box::new(lhs),
434                    rhs: Box::new(rhs),
435                });
436            }
437        }
438
439        for &(op_str, op) in &[
440            (" & ", BinOp::BitAnd),
441            (" | ", BinOp::BitOr),
442            (" ^ ", BinOp::BitXor),
443            (" << ", BinOp::Shl),
444            (" >> ", BinOp::Shr),
445        ] {
446            if let Some(pos) = find_top_level_op(s, op_str) {
447                let lhs = self.parse_expr(&s[..pos])?;
448                let rhs = self.parse_expr(&s[pos + op_str.len()..])?;
449                return Ok(Expr::BinOp {
450                    op,
451                    lhs: Box::new(lhs),
452                    rhs: Box::new(rhs),
453                });
454            }
455        }
456
457        if s.starts_with('(') && s.ends_with(')') {
458            return self.parse_expr(&s[1..s.len() - 1]);
459        }
460
461        if s.starts_with('-') && s.len() > 1 {
462            let inner = self.parse_expr(&s[1..])?;
463            return Ok(Expr::UnaryOp {
464                op: UnaryOp::Neg,
465                operand: Box::new(inner),
466            });
467        }
468
469        self.parse_atom(s)
470    }
471
472    fn parse_atom(&self, s: &str) -> Result<Expr, CompileError> {
473        let s = s.trim();
474
475        match s {
476            "thread_id()" | "thread_id_x()" => return Ok(Expr::ThreadId(Dimension::X)),
477            "thread_id_y()" => return Ok(Expr::ThreadId(Dimension::Y)),
478            "thread_id_z()" => return Ok(Expr::ThreadId(Dimension::Z)),
479            "workgroup_id()" | "workgroup_id_x()" => return Ok(Expr::WorkgroupId(Dimension::X)),
480            "workgroup_size()" | "workgroup_size_x()" => {
481                return Ok(Expr::WorkgroupSize(Dimension::X))
482            }
483            "lane_id()" => return Ok(Expr::LaneId),
484            "wave_width()" => return Ok(Expr::WaveWidth),
485            "True" | "true" => return Ok(Expr::Literal(Literal::Bool(true))),
486            "False" | "false" => return Ok(Expr::Literal(Literal::Bool(false))),
487            _ => {}
488        }
489
490        if let Some(paren_pos) = s.find('(') {
491            if s.ends_with(')') {
492                let func_name = &s[..paren_pos];
493                let args_str = &s[paren_pos + 1..s.len() - 1];
494                return self.parse_call(func_name, args_str);
495            }
496        }
497
498        if let Some(bracket_pos) = s.find('[') {
499            if s.ends_with(']') {
500                let base = &s[..bracket_pos];
501                let index = &s[bracket_pos + 1..s.len() - 1];
502                return Ok(Expr::Index {
503                    base: Box::new(self.parse_expr(base)?),
504                    index: Box::new(self.parse_expr(index)?),
505                });
506            }
507        }
508
509        if let Ok(v) = s.parse::<i64>() {
510            return Ok(Expr::Literal(Literal::Int(v)));
511        }
512        if let Ok(v) = s.parse::<f64>() {
513            return Ok(Expr::Literal(Literal::Float(v)));
514        }
515
516        if s.starts_with("0x") || s.starts_with("0X") {
517            if let Ok(v) = i64::from_str_radix(&s[2..], 16) {
518                return Ok(Expr::Literal(Literal::Int(v)));
519            }
520        }
521
522        if is_valid_identifier(s) {
523            return Ok(Expr::Var(s.to_string()));
524        }
525
526        Err(CompileError::ParseError {
527            message: format!("cannot parse expression: '{s}'"),
528        })
529    }
530
531    fn parse_call(&self, func_name: &str, args_str: &str) -> Result<Expr, CompileError> {
532        let args: Vec<Expr> = if args_str.trim().is_empty() {
533            Vec::new()
534        } else {
535            args_str
536                .split(',')
537                .map(|a| self.parse_expr(a.trim()))
538                .collect::<Result<_, _>>()?
539        };
540
541        match func_name {
542            "sqrt" => Ok(Expr::Call {
543                func: BuiltinFunc::Sqrt,
544                args,
545            }),
546            "sin" => Ok(Expr::Call {
547                func: BuiltinFunc::Sin,
548                args,
549            }),
550            "cos" => Ok(Expr::Call {
551                func: BuiltinFunc::Cos,
552                args,
553            }),
554            "exp2" => Ok(Expr::Call {
555                func: BuiltinFunc::Exp2,
556                args,
557            }),
558            "log2" => Ok(Expr::Call {
559                func: BuiltinFunc::Log2,
560                args,
561            }),
562            "abs" => Ok(Expr::Call {
563                func: BuiltinFunc::Abs,
564                args,
565            }),
566            "min" => Ok(Expr::Call {
567                func: BuiltinFunc::Min,
568                args,
569            }),
570            "max" => Ok(Expr::Call {
571                func: BuiltinFunc::Max,
572                args,
573            }),
574            "atomic_add" => Ok(Expr::Call {
575                func: BuiltinFunc::AtomicAdd,
576                args,
577            }),
578            "thread_id" => Ok(Expr::ThreadId(Dimension::X)),
579            "workgroup_id" => Ok(Expr::WorkgroupId(Dimension::X)),
580            "workgroup_size" => Ok(Expr::WorkgroupSize(Dimension::X)),
581            "lane_id" => Ok(Expr::LaneId),
582            "wave_width" => Ok(Expr::WaveWidth),
583            "int" | "u32" => {
584                if args.len() == 1 {
585                    Ok(Expr::Cast {
586                        expr: Box::new(args.into_iter().next().unwrap()),
587                        to: Type::U32,
588                    })
589                } else {
590                    Err(CompileError::ParseError {
591                        message: "int() takes 1 argument".to_string(),
592                    })
593                }
594            }
595            "float" | "f32" => {
596                if args.len() == 1 {
597                    Ok(Expr::Cast {
598                        expr: Box::new(args.into_iter().next().unwrap()),
599                        to: Type::F32,
600                    })
601                } else {
602                    Err(CompileError::ParseError {
603                        message: "float() takes 1 argument".to_string(),
604                    })
605                }
606            }
607            _ => Err(CompileError::ParseError {
608                message: format!("unknown function: {func_name}"),
609            }),
610        }
611    }
612}
613
/// Finds the rightmost occurrence of `op` in `s` that is not nested inside
/// parentheses or square brackets, returning its starting byte offset.
///
/// The string is scanned from right to left. Nesting is tracked from the
/// byte under the right edge of the search window: closers (`)`, `]`) raise
/// it and openers (`(`, `[`) lower it. A match counts only at nesting zero.
/// Returns `None` when `op` is longer than `s` or no such occurrence exists.
fn find_top_level_op(s: &str, op: &str) -> Option<usize> {
    let haystack = s.as_bytes();
    let needle = op.as_bytes();
    let last_start = haystack.len().checked_sub(needle.len())?;

    let mut nesting = 0i32;
    for start in (0..=last_start).rev() {
        let end = start + needle.len();
        match haystack[end - 1] {
            b')' | b']' => nesting += 1,
            b'(' | b'[' => nesting -= 1,
            _ => {}
        }
        if nesting == 0 && &haystack[start..end] == needle {
            return Some(start);
        }
    }
    None
}
642
/// Returns true when `s` is non-empty, starts with an alphabetic character
/// or `_`, and every following character is alphanumeric or `_`.
fn is_valid_identifier(s: &str) -> bool {
    let mut rest = s.chars();
    match rest.next() {
        Some(head) if head.is_alphabetic() || head == '_' => {
            rest.all(|c| c.is_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Parser over no source lines, for exercising expression parsing alone.
    fn expr_parser() -> PythonParser<'static> {
        PythonParser::new(&[])
    }

    /// A full kernel with pointer and scalar parameters, an assignment and
    /// a guarded store parses into two top-level statements.
    #[test]
    fn test_parse_vector_add() {
        let source = r#"
from wave import kernel, f32, thread_id

@kernel
def vector_add(a: f32[:], b: f32[:], out: f32[:], n: u32):
    gid = thread_id()
    if gid < n:
        out[gid] = a[gid] + b[gid]
"#;
        let kernel = parse_python(source).unwrap();
        assert_eq!(kernel.name, "vector_add");
        assert_eq!(kernel.params.len(), 4);

        let first = &kernel.params[0];
        assert_eq!(first.name, "a");
        assert_eq!(first.ty, Type::Ptr(AddressSpace::Device));

        let last = &kernel.params[3];
        assert_eq!(last.name, "n");
        assert_eq!(last.ty, Type::U32);

        assert_eq!(kernel.body.len(), 2);
    }

    /// Two plain assignments produce two statements.
    #[test]
    fn test_parse_simple_assign() {
        let source = r#"
@kernel
def test(n: u32):
    x = 42
    y = x + 1
"#;
        let kernel = parse_python(source).unwrap();
        assert_eq!(kernel.name, "test");
        assert_eq!(kernel.body.len(), 2);
    }

    /// Addition binds looser than multiplication, so it is the root node.
    #[test]
    fn test_parse_expressions() {
        let expr = expr_parser().parse_expr("a + b * c").unwrap();
        assert!(
            matches!(expr, Expr::BinOp { op: BinOp::Add, .. }),
            "expected Add at top level"
        );
    }

    /// `a[i]` parses to an index expression over two variables.
    #[test]
    fn test_parse_array_index() {
        let expr = expr_parser().parse_expr("a[i]").unwrap();
        if let Expr::Index { base, index } = &expr {
            assert_eq!(**base, Expr::Var("a".into()));
            assert_eq!(**index, Expr::Var("i".into()));
        } else {
            panic!("expected Index");
        }
    }

    /// `thread_id()` maps to the X-dimension thread id.
    #[test]
    fn test_parse_thread_id() {
        let expr = expr_parser().parse_expr("thread_id()").unwrap();
        assert_eq!(expr, Expr::ThreadId(Dimension::X));
    }

    /// Integer and floating-point literals parse to the matching variants.
    #[test]
    fn test_parse_literal() {
        let parser = expr_parser();

        let int_expr = parser.parse_expr("42").unwrap();
        assert_eq!(int_expr, Expr::Literal(Literal::Int(42)));

        let float_expr = parser.parse_expr("3.14").unwrap();
        assert_eq!(float_expr, Expr::Literal(Literal::Float(3.14)));
    }
}