Skip to main content

wave_compiler/frontend/
cpp.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! C/C++ kernel parser producing HIR.
5//!
6//! Parses a subset of C/C++ suitable for GPU kernels using line-based
7//! parsing for the restricted kernel subset. Supports `__kernel` functions,
8//! basic types, arithmetic, if/else, and array indexing.
9
10use crate::diagnostics::CompileError;
11use crate::hir::expr::{BinOp, Dimension, Expr, Literal};
12use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
13use crate::hir::stmt::Stmt;
14use crate::hir::types::{AddressSpace, Type};
15
16/// Parse a C/C++ kernel source string into an HIR Kernel.
17///
18/// # Errors
19///
20/// Returns `CompileError::ParseError` if the source cannot be parsed.
21pub fn parse_cpp(source: &str) -> Result<Kernel, CompileError> {
22    let lines: Vec<&str> = source.lines().collect();
23    let mut parser = CppParser::new(&lines);
24    parser.parse_kernel()
25}
26
/// Cursor over the source lines of a single kernel translation unit.
struct CppParser<'a> {
    /// Index of the next line to be consumed.
    pos: usize,
    /// The source split into lines (untrimmed).
    lines: &'a [&'a str],
}
31
32impl<'a> CppParser<'a> {
33    fn new(lines: &'a [&'a str]) -> Self {
34        Self { lines, pos: 0 }
35    }
36
37    fn parse_kernel(&mut self) -> Result<Kernel, CompileError> {
38        while self.pos < self.lines.len() {
39            let line = self.lines[self.pos].trim();
40            if line.contains("__kernel") || line.contains("void ") || line.contains("__global__") {
41                return self.parse_function();
42            }
43            self.pos += 1;
44        }
45        Err(CompileError::ParseError {
46            message: "no kernel function found".into(),
47        })
48    }
49
50    fn parse_function(&mut self) -> Result<Kernel, CompileError> {
51        let line = self.lines[self.pos].trim().to_string();
52        let paren_start = line.find('(').ok_or_else(|| CompileError::ParseError {
53            message: "expected '(' in function definition".into(),
54        })?;
55        let paren_end = line.find(')').ok_or_else(|| CompileError::ParseError {
56            message: "expected ')' in function definition".into(),
57        })?;
58
59        let before_paren = &line[..paren_start];
60        let name = before_paren
61            .split_whitespace()
62            .last()
63            .unwrap_or("kernel")
64            .to_string();
65
66        let params_str = &line[paren_start + 1..paren_end];
67        let params = Self::parse_params(params_str);
68
69        self.pos += 1;
70        while self.pos < self.lines.len() && self.lines[self.pos].trim() == "{" {
71            self.pos += 1;
72        }
73
74        let body = self.parse_body()?;
75
76        Ok(Kernel {
77            name,
78            params,
79            body,
80            attributes: KernelAttributes::default(),
81        })
82    }
83
84    fn parse_params(s: &str) -> Vec<KernelParam> {
85        let mut params = Vec::new();
86        for param in s.split(',') {
87            let param = param.trim();
88            if param.is_empty() {
89                continue;
90            }
91            let parts: Vec<&str> = param.split_whitespace().collect();
92            if parts.is_empty() {
93                continue;
94            }
95            let name_part = parts.last().unwrap().trim_start_matches('*');
96            let is_pointer = param.contains('*');
97            let type_str = parts[0];
98
99            let (ty, space) = if is_pointer {
100                (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
101            } else {
102                match type_str {
103                    "float" | "f32" => (Type::F32, AddressSpace::Private),
104                    "int" | "i32" => (Type::I32, AddressSpace::Private),
105                    "double" | "f64" => (Type::F64, AddressSpace::Private),
106                    "bool" => (Type::Bool, AddressSpace::Private),
107                    _ => (Type::U32, AddressSpace::Private),
108                }
109            };
110
111            params.push(KernelParam {
112                name: name_part.to_string(),
113                ty,
114                address_space: space,
115            });
116        }
117        params
118    }
119
120    fn parse_body(&mut self) -> Result<Vec<Stmt>, CompileError> {
121        let mut stmts = Vec::new();
122        while self.pos < self.lines.len() {
123            let line = self.lines[self.pos].trim();
124            if line == "}" {
125                self.pos += 1;
126                break;
127            }
128            if line.is_empty() || line.starts_with("//") {
129                self.pos += 1;
130                continue;
131            }
132            if line.starts_with("if ") || line.starts_with("if(") {
133                stmts.push(self.parse_if()?);
134            } else if line.contains('=') && !line.contains("==") && !line.contains("!=") {
135                stmts.push(self.parse_assignment()?);
136            } else {
137                self.pos += 1;
138            }
139        }
140        Ok(stmts)
141    }
142
143    fn parse_if(&mut self) -> Result<Stmt, CompileError> {
144        let line = self.lines[self.pos].trim();
145        let cond_start = line.find('(').unwrap_or(3);
146        let cond_end = line.rfind(')').unwrap_or(line.len());
147        let cond_str = &line[cond_start + 1..cond_end];
148        let condition = parse_c_expr(cond_str)?;
149        self.pos += 1;
150
151        while self.pos < self.lines.len() && self.lines[self.pos].trim() == "{" {
152            self.pos += 1;
153        }
154
155        let then_body = self.parse_body()?;
156
157        let else_body = if self.pos < self.lines.len() {
158            let next = self.lines[self.pos].trim();
159            if next.starts_with("else") || next == "} else {" {
160                self.pos += 1;
161                while self.pos < self.lines.len() && self.lines[self.pos].trim() == "{" {
162                    self.pos += 1;
163                }
164                Some(self.parse_body()?)
165            } else {
166                None
167            }
168        } else {
169            None
170        };
171
172        Ok(Stmt::If {
173            condition,
174            then_body,
175            else_body,
176        })
177    }
178
179    fn parse_assignment(&mut self) -> Result<Stmt, CompileError> {
180        let line = self.lines[self.pos].trim().trim_end_matches(';');
181        self.pos += 1;
182
183        let parts: Vec<&str> = line.splitn(2, '=').collect();
184        if parts.len() != 2 {
185            return Err(CompileError::ParseError {
186                message: format!("invalid assignment: {line}"),
187            });
188        }
189
190        let lhs = parts[0].trim();
191        let rhs = parts[1].trim();
192
193        let lhs_clean = lhs
194            .trim_start_matches("uint32_t ")
195            .trim_start_matches("int ")
196            .trim_start_matches("float ")
197            .trim_start_matches("double ")
198            .trim_start_matches("auto ")
199            .trim();
200
201        let value = parse_c_expr(rhs)?;
202
203        if lhs_clean.contains('[') {
204            let bracket_pos = lhs_clean.find('[').unwrap();
205            let bracket_end = lhs_clean.find(']').unwrap();
206            let base_name = lhs_clean[..bracket_pos].trim();
207            let index_str = &lhs_clean[bracket_pos + 1..bracket_end];
208
209            let base = Expr::Var(base_name.to_string());
210            let index = parse_c_expr(index_str)?;
211            let offset = Expr::BinOp {
212                op: BinOp::Mul,
213                lhs: Box::new(index),
214                rhs: Box::new(Expr::Literal(Literal::Int(4))),
215            };
216            let addr = Expr::BinOp {
217                op: BinOp::Add,
218                lhs: Box::new(base),
219                rhs: Box::new(offset),
220            };
221            return Ok(Stmt::Store {
222                addr,
223                value,
224                space: AddressSpace::Device,
225            });
226        }
227
228        Ok(Stmt::Assign {
229            target: lhs_clean.to_string(),
230            value,
231        })
232    }
233}
234
235fn parse_c_expr(s: &str) -> Result<Expr, CompileError> {
236    let s = s.trim();
237
238    for &(op_str, op) in &[(" + ", BinOp::Add), (" - ", BinOp::Sub)] {
239        if let Some(pos) = s.rfind(op_str) {
240            let lhs = parse_c_expr(&s[..pos])?;
241            let rhs = parse_c_expr(&s[pos + op_str.len()..])?;
242            return Ok(Expr::BinOp {
243                op,
244                lhs: Box::new(lhs),
245                rhs: Box::new(rhs),
246            });
247        }
248    }
249
250    for &(op_str, op) in &[
251        (" * ", BinOp::Mul),
252        (" / ", BinOp::Div),
253        (" % ", BinOp::Mod),
254    ] {
255        if let Some(pos) = s.rfind(op_str) {
256            let lhs = parse_c_expr(&s[..pos])?;
257            let rhs = parse_c_expr(&s[pos + op_str.len()..])?;
258            return Ok(Expr::BinOp {
259                op,
260                lhs: Box::new(lhs),
261                rhs: Box::new(rhs),
262            });
263        }
264    }
265
266    for &(op_str, op) in &[
267        (" < ", BinOp::Lt),
268        (" <= ", BinOp::Le),
269        (" > ", BinOp::Gt),
270        (" >= ", BinOp::Ge),
271        (" == ", BinOp::Eq),
272        (" != ", BinOp::Ne),
273    ] {
274        if let Some(pos) = s.rfind(op_str) {
275            let lhs = parse_c_expr(&s[..pos])?;
276            let rhs = parse_c_expr(&s[pos + op_str.len()..])?;
277            return Ok(Expr::BinOp {
278                op,
279                lhs: Box::new(lhs),
280                rhs: Box::new(rhs),
281            });
282        }
283    }
284
285    if s.starts_with('(') && s.ends_with(')') {
286        return parse_c_expr(&s[1..s.len() - 1]);
287    }
288
289    match s {
290        "thread_id()" => return Ok(Expr::ThreadId(Dimension::X)),
291        "workgroup_id()" => return Ok(Expr::WorkgroupId(Dimension::X)),
292        _ => {}
293    }
294
295    if let Some(bracket_pos) = s.find('[') {
296        if s.ends_with(']') {
297            let base = &s[..bracket_pos];
298            let index = &s[bracket_pos + 1..s.len() - 1];
299            return Ok(Expr::Index {
300                base: Box::new(parse_c_expr(base)?),
301                index: Box::new(parse_c_expr(index)?),
302            });
303        }
304    }
305
306    if let Ok(v) = s.parse::<i64>() {
307        return Ok(Expr::Literal(Literal::Int(v)));
308    }
309    if let Ok(v) = s.parse::<f64>() {
310        return Ok(Expr::Literal(Literal::Float(v)));
311    }
312
313    if s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.is_empty() {
314        return Ok(Expr::Var(s.to_string()));
315    }
316
317    Err(CompileError::ParseError {
318        message: format!("cannot parse C expression: '{s}'"),
319    })
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise vector addition with a bounds-checked store.
    const VECTOR_ADD: &str = r#"
__kernel void vector_add(float* a, float* b, float* out, uint32_t n) {
    uint32_t gid = thread_id();
    if (gid < n) {
        out[gid] = a[gid] + b[gid];
    }
}
"#;

    /// Minimal kernel whose body is a single declaration.
    const SIMPLE: &str = r#"
void test(uint32_t n) {
    uint32_t x = 42;
}
"#;

    #[test]
    fn test_parse_cpp_vector_add() {
        let kernel = parse_cpp(VECTOR_ADD).expect("vector_add kernel should parse");
        assert_eq!(kernel.name, "vector_add");

        let params = &kernel.params;
        assert_eq!(params.len(), 4);
        assert!(params[0].ty.is_pointer());
        assert_eq!(params[3].ty, Type::U32);
    }

    #[test]
    fn test_parse_cpp_simple() {
        let kernel = parse_cpp(SIMPLE).expect("simple kernel should parse");
        assert_eq!(kernel.name, "test");
        assert_eq!(kernel.body.len(), 1);
    }
}