1use crate::diagnostics::CompileError;
11use crate::hir::expr::{BinOp, Dimension, Expr, Literal};
12use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
13use crate::hir::stmt::Stmt;
14use crate::hir::types::{AddressSpace, Type};
15
16pub fn parse_cpp(source: &str) -> Result<Kernel, CompileError> {
22 let lines: Vec<&str> = source.lines().collect();
23 let mut parser = CppParser::new(&lines);
24 parser.parse_kernel()
25}
26
/// Line-oriented cursor over kernel source text.
///
/// Holds the source split into lines together with the index of the line to examine next.
struct CppParser<'src> {
    /// Index into `lines` of the line currently being parsed.
    pos: usize,
    /// The kernel source, one entry per line.
    lines: &'src [&'src str],
}
31
32impl<'a> CppParser<'a> {
33 fn new(lines: &'a [&'a str]) -> Self {
34 Self { lines, pos: 0 }
35 }
36
37 fn parse_kernel(&mut self) -> Result<Kernel, CompileError> {
38 while self.pos < self.lines.len() {
39 let line = self.lines[self.pos].trim();
40 if line.contains("__kernel") || line.contains("void ") || line.contains("__global__") {
41 return self.parse_function();
42 }
43 self.pos += 1;
44 }
45 Err(CompileError::ParseError {
46 message: "no kernel function found".into(),
47 })
48 }
49
50 fn parse_function(&mut self) -> Result<Kernel, CompileError> {
51 let line = self.lines[self.pos].trim().to_string();
52 let paren_start = line.find('(').ok_or_else(|| CompileError::ParseError {
53 message: "expected '(' in function definition".into(),
54 })?;
55 let paren_end = line.find(')').ok_or_else(|| CompileError::ParseError {
56 message: "expected ')' in function definition".into(),
57 })?;
58
59 let before_paren = &line[..paren_start];
60 let name = before_paren
61 .split_whitespace()
62 .last()
63 .unwrap_or("kernel")
64 .to_string();
65
66 let params_str = &line[paren_start + 1..paren_end];
67 let params = Self::parse_params(params_str);
68
69 self.pos += 1;
70 while self.pos < self.lines.len() && self.lines[self.pos].trim() == "{" {
71 self.pos += 1;
72 }
73
74 let body = self.parse_body()?;
75
76 Ok(Kernel {
77 name,
78 params,
79 body,
80 attributes: KernelAttributes::default(),
81 })
82 }
83
84 fn parse_params(s: &str) -> Vec<KernelParam> {
85 let mut params = Vec::new();
86 for param in s.split(',') {
87 let param = param.trim();
88 if param.is_empty() {
89 continue;
90 }
91 let parts: Vec<&str> = param.split_whitespace().collect();
92 if parts.is_empty() {
93 continue;
94 }
95 let name_part = parts.last().unwrap().trim_start_matches('*');
96 let is_pointer = param.contains('*');
97 let type_str = parts[0];
98
99 let (ty, space) = if is_pointer {
100 (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
101 } else {
102 match type_str {
103 "float" | "f32" => (Type::F32, AddressSpace::Private),
104 "int" | "i32" => (Type::I32, AddressSpace::Private),
105 "double" | "f64" => (Type::F64, AddressSpace::Private),
106 "bool" => (Type::Bool, AddressSpace::Private),
107 _ => (Type::U32, AddressSpace::Private),
108 }
109 };
110
111 params.push(KernelParam {
112 name: name_part.to_string(),
113 ty,
114 address_space: space,
115 });
116 }
117 params
118 }
119
120 fn parse_body(&mut self) -> Result<Vec<Stmt>, CompileError> {
121 let mut stmts = Vec::new();
122 while self.pos < self.lines.len() {
123 let line = self.lines[self.pos].trim();
124 if line == "}" {
125 self.pos += 1;
126 break;
127 }
128 if line.is_empty() || line.starts_with("//") {
129 self.pos += 1;
130 continue;
131 }
132 if line.starts_with("if ") || line.starts_with("if(") {
133 stmts.push(self.parse_if()?);
134 } else if line.contains('=') && !line.contains("==") && !line.contains("!=") {
135 stmts.push(self.parse_assignment()?);
136 } else {
137 self.pos += 1;
138 }
139 }
140 Ok(stmts)
141 }
142
143 fn parse_if(&mut self) -> Result<Stmt, CompileError> {
144 let line = self.lines[self.pos].trim();
145 let cond_start = line.find('(').unwrap_or(3);
146 let cond_end = line.rfind(')').unwrap_or(line.len());
147 let cond_str = &line[cond_start + 1..cond_end];
148 let condition = parse_c_expr(cond_str)?;
149 self.pos += 1;
150
151 while self.pos < self.lines.len() && self.lines[self.pos].trim() == "{" {
152 self.pos += 1;
153 }
154
155 let then_body = self.parse_body()?;
156
157 let else_body = if self.pos < self.lines.len() {
158 let next = self.lines[self.pos].trim();
159 if next.starts_with("else") || next == "} else {" {
160 self.pos += 1;
161 while self.pos < self.lines.len() && self.lines[self.pos].trim() == "{" {
162 self.pos += 1;
163 }
164 Some(self.parse_body()?)
165 } else {
166 None
167 }
168 } else {
169 None
170 };
171
172 Ok(Stmt::If {
173 condition,
174 then_body,
175 else_body,
176 })
177 }
178
179 fn parse_assignment(&mut self) -> Result<Stmt, CompileError> {
180 let line = self.lines[self.pos].trim().trim_end_matches(';');
181 self.pos += 1;
182
183 let parts: Vec<&str> = line.splitn(2, '=').collect();
184 if parts.len() != 2 {
185 return Err(CompileError::ParseError {
186 message: format!("invalid assignment: {line}"),
187 });
188 }
189
190 let lhs = parts[0].trim();
191 let rhs = parts[1].trim();
192
193 let lhs_clean = lhs
194 .trim_start_matches("uint32_t ")
195 .trim_start_matches("int ")
196 .trim_start_matches("float ")
197 .trim_start_matches("double ")
198 .trim_start_matches("auto ")
199 .trim();
200
201 let value = parse_c_expr(rhs)?;
202
203 if lhs_clean.contains('[') {
204 let bracket_pos = lhs_clean.find('[').unwrap();
205 let bracket_end = lhs_clean.find(']').unwrap();
206 let base_name = lhs_clean[..bracket_pos].trim();
207 let index_str = &lhs_clean[bracket_pos + 1..bracket_end];
208
209 let base = Expr::Var(base_name.to_string());
210 let index = parse_c_expr(index_str)?;
211 let offset = Expr::BinOp {
212 op: BinOp::Mul,
213 lhs: Box::new(index),
214 rhs: Box::new(Expr::Literal(Literal::Int(4))),
215 };
216 let addr = Expr::BinOp {
217 op: BinOp::Add,
218 lhs: Box::new(base),
219 rhs: Box::new(offset),
220 };
221 return Ok(Stmt::Store {
222 addr,
223 value,
224 space: AddressSpace::Device,
225 });
226 }
227
228 Ok(Stmt::Assign {
229 target: lhs_clean.to_string(),
230 value,
231 })
232 }
233}
234
235fn parse_c_expr(s: &str) -> Result<Expr, CompileError> {
236 let s = s.trim();
237
238 for &(op_str, op) in &[(" + ", BinOp::Add), (" - ", BinOp::Sub)] {
239 if let Some(pos) = s.rfind(op_str) {
240 let lhs = parse_c_expr(&s[..pos])?;
241 let rhs = parse_c_expr(&s[pos + op_str.len()..])?;
242 return Ok(Expr::BinOp {
243 op,
244 lhs: Box::new(lhs),
245 rhs: Box::new(rhs),
246 });
247 }
248 }
249
250 for &(op_str, op) in &[
251 (" * ", BinOp::Mul),
252 (" / ", BinOp::Div),
253 (" % ", BinOp::Mod),
254 ] {
255 if let Some(pos) = s.rfind(op_str) {
256 let lhs = parse_c_expr(&s[..pos])?;
257 let rhs = parse_c_expr(&s[pos + op_str.len()..])?;
258 return Ok(Expr::BinOp {
259 op,
260 lhs: Box::new(lhs),
261 rhs: Box::new(rhs),
262 });
263 }
264 }
265
266 for &(op_str, op) in &[
267 (" < ", BinOp::Lt),
268 (" <= ", BinOp::Le),
269 (" > ", BinOp::Gt),
270 (" >= ", BinOp::Ge),
271 (" == ", BinOp::Eq),
272 (" != ", BinOp::Ne),
273 ] {
274 if let Some(pos) = s.rfind(op_str) {
275 let lhs = parse_c_expr(&s[..pos])?;
276 let rhs = parse_c_expr(&s[pos + op_str.len()..])?;
277 return Ok(Expr::BinOp {
278 op,
279 lhs: Box::new(lhs),
280 rhs: Box::new(rhs),
281 });
282 }
283 }
284
285 if s.starts_with('(') && s.ends_with(')') {
286 return parse_c_expr(&s[1..s.len() - 1]);
287 }
288
289 match s {
290 "thread_id()" => return Ok(Expr::ThreadId(Dimension::X)),
291 "workgroup_id()" => return Ok(Expr::WorkgroupId(Dimension::X)),
292 _ => {}
293 }
294
295 if let Some(bracket_pos) = s.find('[') {
296 if s.ends_with(']') {
297 let base = &s[..bracket_pos];
298 let index = &s[bracket_pos + 1..s.len() - 1];
299 return Ok(Expr::Index {
300 base: Box::new(parse_c_expr(base)?),
301 index: Box::new(parse_c_expr(index)?),
302 });
303 }
304 }
305
306 if let Ok(v) = s.parse::<i64>() {
307 return Ok(Expr::Literal(Literal::Int(v)));
308 }
309 if let Ok(v) = s.parse::<f64>() {
310 return Ok(Expr::Literal(Literal::Float(v)));
311 }
312
313 if s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.is_empty() {
314 return Ok(Expr::Var(s.to_string()));
315 }
316
317 Err(CompileError::ParseError {
318 message: format!("cannot parse C expression: '{s}'"),
319 })
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // Element-wise addition: three buffer parameters followed by a scalar element count.
    #[test]
    fn test_parse_cpp_vector_add() {
        let source = r#"
__kernel void vector_add(float* a, float* b, float* out, uint32_t n) {
    uint32_t gid = thread_id();
    if (gid < n) {
        out[gid] = a[gid] + b[gid];
    }
}
"#;
        let Kernel { name, params, .. } =
            parse_cpp(source).expect("vector_add kernel should parse");

        assert_eq!(name, "vector_add");
        assert_eq!(params.len(), 4);
        // The buffers are pointers; the trailing count is a plain `uint32_t` scalar.
        assert!(params[0].ty.is_pointer());
        assert_eq!(params[3].ty, Type::U32);
    }

    // A plain `void` function with a single declaration-assignment in its body.
    #[test]
    fn test_parse_cpp_simple() {
        let source = r#"
void test(uint32_t n) {
    uint32_t x = 42;
}
"#;
        let Kernel { name, body, .. } = parse_cpp(source).expect("simple kernel should parse");

        assert_eq!(name, "test");
        assert_eq!(body.len(), 1);
    }
}